Skip to main content

jamjet_models/
google.rs

1//! Google Gemini adapter (Generative Language API).
2//!
3//! Supports gemini-2.0-flash, gemini-1.5-flash, gemini-1.5-pro, etc.
4//! Reads `GOOGLE_API_KEY` or `GEMINI_API_KEY` from the environment.
5//! Uses the REST API directly (no SDK dependency).
6
7use crate::adapter::{
8    ChatMessage, ChatRole, ModelAdapter, ModelConfig, ModelError, ModelRequest, ModelResponse,
9    StructuredRequest,
10};
11use async_trait::async_trait;
12use serde_json::{json, Value};
13use tracing::{debug, instrument};
14
/// Root URL of the Generative Language REST API (the `v1beta` surface).
const GEMINI_API_BASE: &'static str = "https://generativelanguage.googleapis.com/v1beta";
/// Model a freshly constructed adapter uses until `with_default_model` overrides it.
const DEFAULT_MODEL: &'static str = "gemini-2.0-flash";
/// `generationConfig.maxOutputTokens` value sent when the request sets no limit.
const DEFAULT_MAX_TOKENS: u32 = 4096;
18
19/// Google Gemini adapter via the Generative Language REST API.
20///
21/// Uses API key authentication (not OAuth). Supports all Gemini models
22/// available through Google AI Studio.
23pub struct GoogleAdapter {
24    client: reqwest::Client,
25    api_key: String,
26    default_model: String,
27}
28
29impl GoogleAdapter {
30    pub fn new(api_key: impl Into<String>) -> Self {
31        Self {
32            client: reqwest::Client::new(),
33            api_key: api_key.into(),
34            default_model: DEFAULT_MODEL.into(),
35        }
36    }
37
38    /// Create adapter from `GOOGLE_API_KEY` or `GEMINI_API_KEY` env var.
39    pub fn from_env() -> Result<Self, ModelError> {
40        let key = std::env::var("GOOGLE_API_KEY")
41            .or_else(|_| std::env::var("GEMINI_API_KEY"))
42            .map_err(|_| ModelError::Network("GOOGLE_API_KEY or GEMINI_API_KEY not set".into()))?;
43        Ok(Self::new(key))
44    }
45
46    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
47        self.default_model = model.into();
48        self
49    }
50
51    async fn call_api(&self, model: &str, body: Value) -> Result<Value, ModelError> {
52        // Gemini API URL: /v1beta/models/{model}:generateContent?key={key}
53        let url = format!(
54            "{}/models/{}:generateContent?key={}",
55            GEMINI_API_BASE, model, self.api_key
56        );
57
58        let resp = self
59            .client
60            .post(&url)
61            .json(&body)
62            .send()
63            .await
64            .map_err(|e| ModelError::Network(e.to_string()))?;
65
66        let status = resp.status().as_u16();
67        let body_text = resp
68            .text()
69            .await
70            .map_err(|e| ModelError::Network(e.to_string()))?;
71
72        if status == 429 {
73            return Err(ModelError::RateLimited {
74                retry_after_secs: 60,
75            });
76        }
77        if status != 200 {
78            return Err(ModelError::Api {
79                status,
80                body: body_text,
81            });
82        }
83
84        serde_json::from_str(&body_text).map_err(|e| ModelError::Serialization(e.to_string()))
85    }
86
87    fn build_request_body(
88        &self,
89        messages: &[ChatMessage],
90        config: &ModelConfig,
91        response_mime_type: Option<&str>,
92    ) -> Value {
93        let max_tokens = config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
94
95        // Gemini uses "contents" with roles "user" and "model" (not "assistant").
96        let contents: Vec<Value> = messages
97            .iter()
98            .filter(|m| !matches!(m.role, ChatRole::System))
99            .map(|m| {
100                let role = match m.role {
101                    ChatRole::User | ChatRole::Tool => "user",
102                    ChatRole::Assistant => "model",
103                    ChatRole::System => unreachable!(),
104                };
105                json!({
106                    "role": role,
107                    "parts": [{"text": m.content}]
108                })
109            })
110            .collect();
111
112        let mut generation_config = json!({
113            "maxOutputTokens": max_tokens,
114        });
115
116        if let Some(temp) = config.temperature {
117            generation_config["temperature"] = json!(temp);
118        }
119        if let Some(stops) = &config.stop_sequences {
120            generation_config["stopSequences"] = json!(stops);
121        }
122        if let Some(mime) = response_mime_type {
123            generation_config["responseMimeType"] = json!(mime);
124        }
125
126        let mut body = json!({
127            "contents": contents,
128            "generationConfig": generation_config,
129        });
130
131        // System instruction (separate from contents in Gemini API).
132        let system_text = config.system_prompt.as_deref().or_else(|| {
133            messages
134                .iter()
135                .find(|m| matches!(m.role, ChatRole::System))
136                .map(|m| m.content.as_str())
137        });
138
139        if let Some(sys) = system_text {
140            body["systemInstruction"] = json!({
141                "parts": [{"text": sys}]
142            });
143        }
144
145        body
146    }
147
148    fn parse_response(&self, resp: Value) -> Result<ModelResponse, ModelError> {
149        // Extract text from candidates[0].content.parts[0].text
150        let candidate = resp["candidates"]
151            .as_array()
152            .and_then(|cs| cs.first())
153            .ok_or_else(|| ModelError::Api {
154                status: 200,
155                body: "no candidates in response".into(),
156            })?;
157
158        let content = candidate["content"]["parts"]
159            .as_array()
160            .and_then(|parts| parts.first())
161            .and_then(|p| p["text"].as_str())
162            .unwrap_or("")
163            .to_string();
164
165        let finish_reason = candidate["finishReason"]
166            .as_str()
167            .unwrap_or("STOP")
168            .to_string();
169
170        // Token counts from usageMetadata.
171        let usage = &resp["usageMetadata"];
172        let input_tokens = usage["promptTokenCount"].as_u64().unwrap_or(0);
173        let output_tokens = usage["candidatesTokenCount"].as_u64().unwrap_or(0);
174
175        // Model name from modelVersion if available.
176        let model = resp["modelVersion"]
177            .as_str()
178            .unwrap_or(&self.default_model)
179            .to_string();
180
181        Ok(ModelResponse {
182            content,
183            model,
184            finish_reason,
185            input_tokens,
186            output_tokens,
187            structured: None,
188        })
189    }
190}
191
192#[async_trait]
193impl ModelAdapter for GoogleAdapter {
194    fn system_name(&self) -> &'static str {
195        "google"
196    }
197
198    fn default_model(&self) -> &str {
199        &self.default_model
200    }
201
202    #[instrument(skip(self, request), fields(
203        gen_ai.system = "google",
204        gen_ai.request.model = tracing::field::Empty,
205        gen_ai.usage.input_tokens = tracing::field::Empty,
206        gen_ai.usage.output_tokens = tracing::field::Empty,
207    ))]
208    async fn chat(&self, request: ModelRequest) -> Result<ModelResponse, ModelError> {
209        let model = request
210            .config
211            .model
212            .as_deref()
213            .unwrap_or(&self.default_model)
214            .to_string();
215        tracing::Span::current().record("gen_ai.request.model", model.as_str());
216
217        debug!(model = %model, "Calling Gemini generateContent API");
218
219        let body = self.build_request_body(&request.messages, &request.config, None);
220        let resp_json = self.call_api(&model, body).await?;
221        let response = self.parse_response(resp_json)?;
222
223        tracing::Span::current()
224            .record("gen_ai.usage.input_tokens", response.input_tokens)
225            .record("gen_ai.usage.output_tokens", response.output_tokens);
226
227        Ok(response)
228    }
229
230    #[instrument(skip(self, request), fields(
231        gen_ai.system = "google",
232        gen_ai.request.model = tracing::field::Empty,
233    ))]
234    async fn structured_output(
235        &self,
236        request: StructuredRequest,
237    ) -> Result<ModelResponse, ModelError> {
238        let model = request
239            .config
240            .model
241            .as_deref()
242            .unwrap_or(&self.default_model)
243            .to_string();
244        tracing::Span::current().record("gen_ai.request.model", model.as_str());
245
246        // Gemini supports responseMimeType: "application/json" for JSON mode.
247        // Append schema to system prompt.
248        let mut config = request.config.clone();
249        let schema_str = serde_json::to_string_pretty(&request.output_schema)
250            .map_err(|e| ModelError::Serialization(e.to_string()))?;
251        let system = config.system_prompt.get_or_insert_with(String::new);
252        system.push_str(&format!(
253            "\n\nRespond ONLY with a valid JSON object matching this schema:\n{schema_str}"
254        ));
255
256        let body = self.build_request_body(&request.messages, &config, Some("application/json"));
257        let resp_json = self.call_api(&model, body).await?;
258        let mut response = self.parse_response(resp_json)?;
259
260        // Parse JSON from response content.
261        let structured =
262            serde_json::from_str::<serde_json::Value>(&response.content).map_err(|e| {
263                ModelError::Serialization(format!("structured output parse error: {e}"))
264            })?;
265        response.structured = Some(structured);
266
267        Ok(response)
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_request_body_system_instruction() {
        let adapter = GoogleAdapter::new("test-key");
        let messages = vec![ChatMessage::user("Hello")];
        let config = ModelConfig {
            model: Some("gemini-2.0-flash".into()),
            system_prompt: Some("You are helpful.".into()),
            max_tokens: Some(100),
            ..Default::default()
        };
        let body = adapter.build_request_body(&messages, &config, None);

        assert!(body["systemInstruction"]["parts"][0]["text"]
            .as_str()
            .unwrap()
            .contains("You are helpful"));
        assert_eq!(body["contents"][0]["role"], "user");
        assert_eq!(body["generationConfig"]["maxOutputTokens"], 100);
    }

    /// Defaults and optional generation settings land in `generationConfig`.
    #[test]
    fn test_build_request_body_generation_options() {
        let adapter = GoogleAdapter::new("test-key");
        let messages = vec![ChatMessage::user("Hi")];
        let config = ModelConfig {
            system_prompt: None,
            max_tokens: None,
            temperature: Some(0.5),
            ..Default::default()
        };
        let body = adapter.build_request_body(&messages, &config, Some("application/json"));

        let generation = &body["generationConfig"];
        assert_eq!(generation["maxOutputTokens"], DEFAULT_MAX_TOKENS);
        assert_eq!(generation["temperature"], 0.5);
        assert_eq!(generation["responseMimeType"], "application/json");
        assert_eq!(body["contents"][0]["parts"][0]["text"], "Hi");
        assert!(body.get("systemInstruction").is_none());
    }

    #[test]
    fn test_parse_response() {
        let adapter = GoogleAdapter::new("test-key");
        let resp = json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello!"}],
                    "role": "model"
                },
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 3,
                "totalTokenCount": 13
            },
            "modelVersion": "gemini-2.0-flash"
        });

        let parsed = adapter.parse_response(resp).unwrap();
        assert_eq!(parsed.content, "Hello!");
        assert_eq!(parsed.input_tokens, 10);
        assert_eq!(parsed.output_tokens, 3);
        assert_eq!(parsed.finish_reason, "STOP");
        assert_eq!(parsed.model, "gemini-2.0-flash");
    }

    /// Missing optional fields fall back to documented defaults.
    #[test]
    fn test_parse_response_fallbacks() {
        let adapter = GoogleAdapter::new("test-key").with_default_model("gemini-1.5-pro");
        let resp = json!({
            "candidates": [{ "content": { "parts": [{"text": "ok"}] } }]
        });

        let parsed = adapter.parse_response(resp).unwrap();
        assert_eq!(parsed.content, "ok");
        assert_eq!(parsed.finish_reason, "STOP");
        assert_eq!(parsed.model, "gemini-1.5-pro");
        assert_eq!(parsed.input_tokens, 0);
        assert_eq!(parsed.output_tokens, 0);
        assert!(parsed.structured.is_none());
    }

    /// A response without candidates (e.g. a blocked prompt) is an API error.
    #[test]
    fn test_parse_response_without_candidates() {
        let adapter = GoogleAdapter::new("test-key");

        let empty = adapter
            .parse_response(json!({"candidates": []}))
            .err()
            .expect("empty candidates must be an error");
        assert!(matches!(empty, ModelError::Api { status: 200, .. }));

        let blocked = adapter
            .parse_response(json!({"promptFeedback": {"blockReason": "SAFETY"}}))
            .err()
            .expect("blocked prompt must be an error");
        match blocked {
            ModelError::Api { status, body } => {
                assert_eq!(status, 200);
                assert!(body.contains("no candidates"));
            }
            _ => panic!("expected ModelError::Api"),
        }
    }
}