asimov_gemini_module/
lib.rs

1// This is free and unencumbered software released into the public domain.
2
3#![no_std]
4#![forbid(unsafe_code)]
5
6use asimov_module::{
7    prelude::*,
8    secrecy::{ExposeSecret, SecretString},
9    tracing,
10};
11use core::error::Error;
12use serde_json::{Value, json};
13
14#[derive(Clone, Debug, bon::Builder)]
15#[builder(on(String, into))]
16pub struct Options {
17    #[builder(default = "https://generativelanguage.googleapis.com")]
18    pub endpoint: String,
19
20    #[builder(default = "gemini-2.5-flash")]
21    pub model: String,
22
23    pub max_tokens: Option<usize>,
24
25    #[builder(into)]
26    pub api_key: SecretString,
27}
28
29pub fn generate(input: impl AsRef<str>, options: &Options) -> Result<Vec<String>, Box<dyn Error>> {
30    let mut req = json!({
31        "contents": {
32            "parts": [
33                {"text": input.as_ref()}
34            ]
35        },
36    });
37
38    if let Some(max_tokens) = options.max_tokens {
39        req["generationConfig"] = json!({"maxOutputTokens": max_tokens})
40    }
41
42    let mut resp = ureq::Agent::config_builder()
43        .http_status_as_error(false)
44        .user_agent("asimov-gemini-module")
45        .build()
46        .new_agent()
47        .post(format!(
48            "{}/v1beta/models/{}:generateContent",
49            options.endpoint, options.model
50        ))
51        .header("x-goog-api-key", options.api_key.expose_secret())
52        .header("content-type", "application/json")
53        .send_json(&req)
54        .inspect_err(|e| tracing::error!("HTTP request failed: {e}"))?;
55    tracing::debug!(response = ?resp);
56
57    let status = resp.status();
58    tracing::debug!(status = status.to_string());
59
60    let resp: Value = resp
61        .body_mut()
62        .read_json()
63        .inspect_err(|e| tracing::error!("unable to read HTTP response body: {e}"))?;
64    tracing::debug!(body = ?resp);
65
66    if !status.is_success() {
67        tracing::error!("Received an error response: {status}");
68
69        // {
70        //   "error": {
71        //     "code": 400,
72        //     "message": "API key not valid. Please pass a valid API key.",
73        //     "status": "INVALID_ARGUMENT"
74        //   }
75        // }
76        if let Some(message) = resp["error"]["message"].as_str() {
77            return Err(message.into());
78        }
79    }
80
81    // {
82    //   "candidates": [
83    //     {
84    //       "content": {
85    //         "parts": [
86    //           {
87    //             "text": "..."
88    //           }
89    //         ],
90    //         "role": "model"
91    //       },
92    //       "finishReason": "STOP",
93    //       "index": 0
94    //     }
95    //   ],
96    //   "usageMetadata": {
97    //     "promptTokenCount": 8,
98    //     "candidatesTokenCount": 15,
99    //     "totalTokenCount": 1191,
100    //     "promptTokensDetails": [
101    //       {
102    //         "modality": "TEXT",
103    //         "tokenCount": 8
104    //       }
105    //     ],
106    //     "thoughtsTokenCount": 1168
107    //   },
108    //   "modelVersion": "gemini-2.5-flash",
109    //   "responseId": "..."
110    // }
111
112    let mut responses = Vec::new();
113
114    if let Some(chunks) = resp["candidates"].as_array() {
115        for chunk in chunks {
116            let content = &chunk["content"];
117            if !content.is_object() {
118                continue;
119            }
120
121            if content["role"].as_str().is_none_or(|r| r != "model") {
122                continue;
123            }
124
125            if let Some(parts) = content["parts"].as_array() {
126                for part in parts {
127                    if let Some(text) = part["text"].as_str() {
128                        responses.push(text.to_string())
129                    }
130                }
131            }
132
133            if let Some(stop_reason) = chunk["finishReason"].as_str() {
134                tracing::debug!(stop_reason);
135            }
136        }
137    }
138
139    Ok(responses)
140}