1use crate::adapter::{
8 ChatMessage, ChatRole, ModelAdapter, ModelConfig, ModelError, ModelRequest, ModelResponse,
9 StructuredRequest,
10};
11use async_trait::async_trait;
12use serde_json::{json, Value};
13use tracing::{debug, instrument};
14
/// Base URL of the Gemini REST API; model endpoint URLs are built by appending to it.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Model name a freshly constructed `GoogleAdapter` uses as its default model.
const DEFAULT_MODEL: &str = "gemini-2.0-flash";
/// `maxOutputTokens` value sent when the request config does not set `max_tokens`.
const DEFAULT_MAX_TOKENS: u32 = 4096;
18
19pub struct GoogleAdapter {
24 client: reqwest::Client,
25 api_key: String,
26 default_model: String,
27}
28
29impl GoogleAdapter {
30 pub fn new(api_key: impl Into<String>) -> Self {
31 Self {
32 client: reqwest::Client::new(),
33 api_key: api_key.into(),
34 default_model: DEFAULT_MODEL.into(),
35 }
36 }
37
38 pub fn from_env() -> Result<Self, ModelError> {
40 let key = std::env::var("GOOGLE_API_KEY")
41 .or_else(|_| std::env::var("GEMINI_API_KEY"))
42 .map_err(|_| ModelError::Network("GOOGLE_API_KEY or GEMINI_API_KEY not set".into()))?;
43 Ok(Self::new(key))
44 }
45
46 pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
47 self.default_model = model.into();
48 self
49 }
50
51 async fn call_api(&self, model: &str, body: Value) -> Result<Value, ModelError> {
52 let url = format!(
54 "{}/models/{}:generateContent?key={}",
55 GEMINI_API_BASE, model, self.api_key
56 );
57
58 let resp = self
59 .client
60 .post(&url)
61 .json(&body)
62 .send()
63 .await
64 .map_err(|e| ModelError::Network(e.to_string()))?;
65
66 let status = resp.status().as_u16();
67 let body_text = resp
68 .text()
69 .await
70 .map_err(|e| ModelError::Network(e.to_string()))?;
71
72 if status == 429 {
73 return Err(ModelError::RateLimited {
74 retry_after_secs: 60,
75 });
76 }
77 if status != 200 {
78 return Err(ModelError::Api {
79 status,
80 body: body_text,
81 });
82 }
83
84 serde_json::from_str(&body_text).map_err(|e| ModelError::Serialization(e.to_string()))
85 }
86
87 fn build_request_body(
88 &self,
89 messages: &[ChatMessage],
90 config: &ModelConfig,
91 response_mime_type: Option<&str>,
92 ) -> Value {
93 let max_tokens = config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
94
95 let contents: Vec<Value> = messages
97 .iter()
98 .filter(|m| !matches!(m.role, ChatRole::System))
99 .map(|m| {
100 let role = match m.role {
101 ChatRole::User | ChatRole::Tool => "user",
102 ChatRole::Assistant => "model",
103 ChatRole::System => unreachable!(),
104 };
105 json!({
106 "role": role,
107 "parts": [{"text": m.content}]
108 })
109 })
110 .collect();
111
112 let mut generation_config = json!({
113 "maxOutputTokens": max_tokens,
114 });
115
116 if let Some(temp) = config.temperature {
117 generation_config["temperature"] = json!(temp);
118 }
119 if let Some(stops) = &config.stop_sequences {
120 generation_config["stopSequences"] = json!(stops);
121 }
122 if let Some(mime) = response_mime_type {
123 generation_config["responseMimeType"] = json!(mime);
124 }
125
126 let mut body = json!({
127 "contents": contents,
128 "generationConfig": generation_config,
129 });
130
131 let system_text = config.system_prompt.as_deref().or_else(|| {
133 messages
134 .iter()
135 .find(|m| matches!(m.role, ChatRole::System))
136 .map(|m| m.content.as_str())
137 });
138
139 if let Some(sys) = system_text {
140 body["systemInstruction"] = json!({
141 "parts": [{"text": sys}]
142 });
143 }
144
145 body
146 }
147
148 fn parse_response(&self, resp: Value) -> Result<ModelResponse, ModelError> {
149 let candidate = resp["candidates"]
151 .as_array()
152 .and_then(|cs| cs.first())
153 .ok_or_else(|| ModelError::Api {
154 status: 200,
155 body: "no candidates in response".into(),
156 })?;
157
158 let content = candidate["content"]["parts"]
159 .as_array()
160 .and_then(|parts| parts.first())
161 .and_then(|p| p["text"].as_str())
162 .unwrap_or("")
163 .to_string();
164
165 let finish_reason = candidate["finishReason"]
166 .as_str()
167 .unwrap_or("STOP")
168 .to_string();
169
170 let usage = &resp["usageMetadata"];
172 let input_tokens = usage["promptTokenCount"].as_u64().unwrap_or(0);
173 let output_tokens = usage["candidatesTokenCount"].as_u64().unwrap_or(0);
174
175 let model = resp["modelVersion"]
177 .as_str()
178 .unwrap_or(&self.default_model)
179 .to_string();
180
181 Ok(ModelResponse {
182 content,
183 model,
184 finish_reason,
185 input_tokens,
186 output_tokens,
187 structured: None,
188 })
189 }
190}
191
192#[async_trait]
193impl ModelAdapter for GoogleAdapter {
194 fn system_name(&self) -> &'static str {
195 "google"
196 }
197
198 fn default_model(&self) -> &str {
199 &self.default_model
200 }
201
202 #[instrument(skip(self, request), fields(
203 gen_ai.system = "google",
204 gen_ai.request.model = tracing::field::Empty,
205 gen_ai.usage.input_tokens = tracing::field::Empty,
206 gen_ai.usage.output_tokens = tracing::field::Empty,
207 ))]
208 async fn chat(&self, request: ModelRequest) -> Result<ModelResponse, ModelError> {
209 let model = request
210 .config
211 .model
212 .as_deref()
213 .unwrap_or(&self.default_model)
214 .to_string();
215 tracing::Span::current().record("gen_ai.request.model", model.as_str());
216
217 debug!(model = %model, "Calling Gemini generateContent API");
218
219 let body = self.build_request_body(&request.messages, &request.config, None);
220 let resp_json = self.call_api(&model, body).await?;
221 let response = self.parse_response(resp_json)?;
222
223 tracing::Span::current()
224 .record("gen_ai.usage.input_tokens", response.input_tokens)
225 .record("gen_ai.usage.output_tokens", response.output_tokens);
226
227 Ok(response)
228 }
229
230 #[instrument(skip(self, request), fields(
231 gen_ai.system = "google",
232 gen_ai.request.model = tracing::field::Empty,
233 ))]
234 async fn structured_output(
235 &self,
236 request: StructuredRequest,
237 ) -> Result<ModelResponse, ModelError> {
238 let model = request
239 .config
240 .model
241 .as_deref()
242 .unwrap_or(&self.default_model)
243 .to_string();
244 tracing::Span::current().record("gen_ai.request.model", model.as_str());
245
246 let mut config = request.config.clone();
249 let schema_str = serde_json::to_string_pretty(&request.output_schema)
250 .map_err(|e| ModelError::Serialization(e.to_string()))?;
251 let system = config.system_prompt.get_or_insert_with(String::new);
252 system.push_str(&format!(
253 "\n\nRespond ONLY with a valid JSON object matching this schema:\n{schema_str}"
254 ));
255
256 let body = self.build_request_body(&request.messages, &config, Some("application/json"));
257 let resp_json = self.call_api(&model, body).await?;
258 let mut response = self.parse_response(resp_json)?;
259
260 let structured =
262 serde_json::from_str::<serde_json::Value>(&response.content).map_err(|e| {
263 ModelError::Serialization(format!("structured output parse error: {e}"))
264 })?;
265 response.structured = Some(structured);
266
267 Ok(response)
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// A configured system prompt must land in `systemInstruction`, while the user turn
    /// stays in `contents` and the token limit is forwarded.
    #[test]
    fn test_build_request_body_system_instruction() {
        let config = ModelConfig {
            model: Some("gemini-2.0-flash".into()),
            system_prompt: Some("You are helpful.".into()),
            max_tokens: Some(100),
            ..Default::default()
        };
        let history = [ChatMessage::user("Hello")];

        let request_json =
            GoogleAdapter::new("test-key").build_request_body(&history, &config, None);

        let instruction = request_json["systemInstruction"]["parts"][0]["text"]
            .as_str()
            .unwrap();
        assert!(instruction.contains("You are helpful"));
        assert_eq!(request_json["contents"][0]["role"], "user");
        assert_eq!(request_json["generationConfig"]["maxOutputTokens"], 100);
    }

    /// A well-formed single-candidate response yields its text, usage and finish reason.
    #[test]
    fn test_parse_response() {
        let usage = json!({
            "promptTokenCount": 10,
            "candidatesTokenCount": 3,
            "totalTokenCount": 13
        });
        let candidate = json!({
            "content": {
                "parts": [{"text": "Hello!"}],
                "role": "model"
            },
            "finishReason": "STOP"
        });
        let api_reply = json!({
            "candidates": [candidate],
            "usageMetadata": usage,
            "modelVersion": "gemini-2.0-flash"
        });

        let result = GoogleAdapter::new("test-key")
            .parse_response(api_reply)
            .unwrap();

        assert_eq!(result.content, "Hello!");
        assert_eq!(result.input_tokens, 10);
        assert_eq!(result.output_tokens, 3);
        assert_eq!(result.finish_reason, "STOP");
    }
}