agent_sdk_providers/impls/
gemini.rs1pub(crate) mod data;
7
8use crate::attachments::validate_request_attachments;
9use crate::provider::LlmProvider;
10use crate::streaming::{StreamBox, StreamDelta, StreamErrorKind};
11use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, ThinkingConfig};
12use anyhow::Result;
13use async_trait::async_trait;
14use data::{
15 ApiContent, ApiFunctionCallingConfig, ApiGenerateContentRequest, ApiGenerateContentResponse,
16 ApiGenerationConfig, ApiPart, ApiUsageMetadata, build_api_contents, build_content_blocks,
17 convert_tools_to_config, gemini_response_schema, map_finish_reason, map_thinking_config,
18};
19use reqwest::StatusCode;
20
/// Root of the public Gemini REST API (`v1beta`).
///
/// Requests are sent to `{base}/models/{model}:generateContent` or
/// `:streamGenerateContent`; override per provider with `GeminiProvider::with_base_url`.
const API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

// --- Gemini 3.1 -----------------------------------------------------------

/// Gemini 3.1 Pro (preview).
pub const MODEL_GEMINI_31_PRO: &str = "gemini-3.1-pro-preview";
/// Gemini 3.1 Flash-Lite (preview).
pub const MODEL_GEMINI_31_FLASH_LITE: &str = "gemini-3.1-flash-lite-preview";

// --- Gemini 3 -------------------------------------------------------------

/// Gemini 3 Pro.
///
/// NOTE(review): unlike the other 3.x ids this one has no `-preview` suffix and uses
/// `3.0`; confirm it matches the published model list.
pub const MODEL_GEMINI_3_PRO: &str = "gemini-3.0-pro";
/// Gemini 3 Flash (preview).
pub const MODEL_GEMINI_3_FLASH: &str = "gemini-3-flash-preview";

// --- Gemini 2.5 -----------------------------------------------------------

/// Gemini 2.5 Pro.
pub const MODEL_GEMINI_25_PRO: &str = "gemini-2.5-pro";
/// Gemini 2.5 Flash.
pub const MODEL_GEMINI_25_FLASH: &str = "gemini-2.5-flash";

// --- Gemini 2.0 -----------------------------------------------------------

/// Gemini 2.0 Flash.
pub const MODEL_GEMINI_2_FLASH: &str = "gemini-2.0-flash";
/// Gemini 2.0 Flash-Lite.
pub const MODEL_GEMINI_2_FLASH_LITE: &str = "gemini-2.0-flash-lite";
40
41#[derive(Clone)]
43pub struct GeminiProvider {
44 client: reqwest::Client,
45 api_key: String,
46 model: String,
47 base_url: String,
48 thinking: Option<ThinkingConfig>,
49 use_header_auth: bool,
52 extra_headers: Vec<(String, String)>,
54}
55
56impl GeminiProvider {
57 pub const API_KEY_ENV: &'static str = "GEMINI_API_KEY";
59
60 #[must_use]
62 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
63 Self {
64 client: reqwest::Client::new(),
65 api_key: api_key.into(),
66 model: model.into(),
67 base_url: API_BASE_URL.to_owned(),
68 thinking: None,
69 use_header_auth: true,
70 extra_headers: Vec::new(),
71 }
72 }
73
74 #[must_use]
82 pub fn from_env() -> Self {
83 Self::try_from_env().unwrap_or_else(|e| panic!("{e}"))
84 }
85
86 pub fn try_from_env() -> Result<Self> {
93 let api_key = std::env::var(Self::API_KEY_ENV).map_err(|_| {
94 anyhow::anyhow!("environment variable `{}` is not set", Self::API_KEY_ENV)
95 })?;
96 Ok(Self::flash(api_key))
97 }
98
99 #[must_use]
101 pub fn flash(api_key: impl Into<String>) -> Self {
102 Self::new(api_key, MODEL_GEMINI_3_FLASH)
103 }
104
105 #[must_use]
107 pub fn flash_lite_31(api_key: String) -> Self {
108 Self::new(api_key, MODEL_GEMINI_31_FLASH_LITE.to_owned())
109 }
110
111 #[must_use]
113 pub fn flash_lite(api_key: String) -> Self {
114 Self::new(api_key, MODEL_GEMINI_2_FLASH_LITE.to_owned())
115 }
116
117 #[must_use]
119 pub fn pro_31(api_key: String) -> Self {
120 Self::new(api_key, MODEL_GEMINI_31_PRO.to_owned())
121 }
122
123 #[must_use]
125 pub fn pro(api_key: String) -> Self {
126 Self::new(api_key, MODEL_GEMINI_31_PRO.to_owned())
127 }
128
129 #[must_use]
131 pub const fn with_thinking(mut self, thinking: ThinkingConfig) -> Self {
132 self.thinking = Some(thinking);
133 self
134 }
135
136 #[must_use]
138 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
139 self.base_url = base_url.into();
140 self
141 }
142
143 #[must_use]
146 pub const fn with_header_auth(mut self) -> Self {
147 self.use_header_auth = true;
148 self
149 }
150
151 #[must_use]
153 pub fn with_extra_headers(mut self, headers: Vec<(String, String)>) -> Self {
154 self.extra_headers = headers;
155 self
156 }
157
158 fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
161 let builder = if self.api_key.is_empty() {
162 builder
163 } else if self.use_header_auth {
164 builder.header("x-goog-api-key", &self.api_key)
165 } else {
166 builder.query(&[("key", &self.api_key)])
167 };
168 self.extra_headers
169 .iter()
170 .fold(builder, |b, (k, v)| b.header(k.as_str(), v.as_str()))
171 }
172}
173
174#[async_trait]
175#[allow(clippy::too_many_lines)]
176impl LlmProvider for GeminiProvider {
177 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
178 let thinking = match self.resolve_thinking_config(request.thinking.as_ref()) {
179 Ok(thinking) => thinking,
180 Err(error) => return Ok(ChatOutcome::InvalidRequest(error.to_string())),
181 };
182 if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
183 return Ok(ChatOutcome::InvalidRequest(error.to_string()));
184 }
185 let contents = build_api_contents(&request.messages);
186 let tools = request.tools.map(convert_tools_to_config);
187 let tool_config = request
188 .tool_choice
189 .as_ref()
190 .map(ApiFunctionCallingConfig::from_tool_choice);
191 let system_instruction = if request.system.is_empty() {
192 None
193 } else {
194 Some(ApiContent {
195 role: None,
196 parts: vec![ApiPart::Text {
197 text: request.system.clone(),
198 thought_signature: None,
199 }],
200 })
201 };
202
203 let thinking_config = thinking.as_ref().map(map_thinking_config);
204 let (response_mime_type, response_schema) =
205 request.response_format.as_ref().map_or((None, None), |rf| {
206 (
207 Some("application/json"),
208 Some(gemini_response_schema(&rf.schema)),
209 )
210 });
211
212 let api_request = ApiGenerateContentRequest {
213 contents: &contents,
214 system_instruction: system_instruction.as_ref(),
215 tools: tools.as_ref().map(std::slice::from_ref),
216 tool_config,
217 generation_config: Some(ApiGenerationConfig {
218 max_output_tokens: Some(request.max_tokens),
219 thinking_config,
220 response_mime_type,
221 response_schema,
222 }),
223 cached_content: request.cached_content.as_deref(),
224 };
225
226 log::debug!(
227 "Gemini LLM request model={} max_tokens={}",
228 self.model,
229 request.max_tokens
230 );
231
232 let builder = self
233 .client
234 .post(format!(
235 "{}/models/{}:generateContent",
236 self.base_url, self.model
237 ))
238 .header("Content-Type", "application/json");
239 let response = self
240 .apply_auth(builder)
241 .json(&api_request)
242 .send()
243 .await
244 .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;
245
246 let status = response.status();
247 let bytes = response
248 .bytes()
249 .await
250 .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;
251
252 log::debug!(
253 "Gemini LLM response status={} body_len={}",
254 status,
255 bytes.len()
256 );
257
258 if status == StatusCode::TOO_MANY_REQUESTS {
259 return Ok(ChatOutcome::RateLimited);
260 }
261
262 if status.is_server_error() {
263 let body = String::from_utf8_lossy(&bytes);
264 log::error!("Gemini server error status={status} body={body}");
265 return Ok(ChatOutcome::ServerError(body.into_owned()));
266 }
267
268 if status.is_client_error() {
269 let body = String::from_utf8_lossy(&bytes);
270 log::warn!("Gemini client error status={status} body={body}");
271 return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
272 }
273
274 let api_response: ApiGenerateContentResponse = serde_json::from_slice(&bytes)
275 .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;
276
277 let candidate = api_response
278 .candidates
279 .into_iter()
280 .next()
281 .ok_or_else(|| anyhow::anyhow!("no candidates in response"))?;
282
283 let content = build_content_blocks(&candidate.content);
284
285 if content.is_empty() && !candidate.content.parts.is_empty() {
286 log::warn!(
287 "Gemini parts not converted to content blocks raw_parts={:?}",
288 candidate.content.parts
289 );
290 }
291
292 let has_tool_calls = content
293 .iter()
294 .any(|b| matches!(b, agent_sdk_foundation::llm::ContentBlock::ToolUse { .. }));
295
296 let stop_reason = candidate
297 .finish_reason
298 .as_ref()
299 .map(|r| map_finish_reason(r, has_tool_calls));
300
301 let usage = api_response
302 .usage_metadata
303 .unwrap_or(ApiUsageMetadata {
304 prompt: 0,
305 candidates: 0,
306 cached_content: 0,
307 })
308 .into_usage();
309
310 Ok(ChatOutcome::Success(ChatResponse {
311 id: String::new(),
312 content,
313 model: self.model.clone(),
314 stop_reason,
315 usage,
316 }))
317 }
318
319 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
320 Box::pin(async_stream::stream! {
321 let thinking = match self.resolve_thinking_config(request.thinking.as_ref()) {
322 Ok(thinking) => thinking,
323 Err(error) => {
324 yield Ok(StreamDelta::Error {
325 message: error.to_string(),
326 kind: StreamErrorKind::InvalidRequest,
327 });
328 return;
329 }
330 };
331 if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
332 yield Ok(StreamDelta::Error {
333 message: error.to_string(),
334 kind: StreamErrorKind::InvalidRequest,
335 });
336 return;
337 }
338 let contents = build_api_contents(&request.messages);
339 let tools = request.tools.map(convert_tools_to_config);
340 let tool_config = request
341 .tool_choice
342 .as_ref()
343 .map(ApiFunctionCallingConfig::from_tool_choice);
344 let system_instruction = if request.system.is_empty() {
345 None
346 } else {
347 Some(ApiContent {
348 role: None,
349 parts: vec![ApiPart::Text {
350 text: request.system.clone(),
351 thought_signature: None,
352 }],
353 })
354 };
355
356 let thinking_config = thinking.as_ref().map(map_thinking_config);
357 let (response_mime_type, response_schema) = request
358 .response_format
359 .as_ref()
360 .map_or((None, None), |rf| {
361 (
362 Some("application/json"),
363 Some(gemini_response_schema(&rf.schema)),
364 )
365 });
366
367 let api_request = ApiGenerateContentRequest {
368 contents: &contents,
369 system_instruction: system_instruction.as_ref(),
370 tools: tools.as_ref().map(std::slice::from_ref),
371 tool_config,
372 generation_config: Some(ApiGenerationConfig {
373 max_output_tokens: Some(request.max_tokens),
374 thinking_config,
375 response_mime_type,
376 response_schema,
377 }),
378 cached_content: request.cached_content.as_deref(),
379 };
380
381 log::debug!(
382 "Gemini streaming LLM request model={} max_tokens={}",
383 self.model,
384 request.max_tokens
385 );
386
387 let stream_builder = self
388 .client
389 .post(format!(
390 "{}/models/{}:streamGenerateContent",
391 self.base_url, self.model
392 ))
393 .header("Content-Type", "application/json")
394 .query(&[("alt", "sse")]);
395 let Ok(response) = self
396 .apply_auth(stream_builder)
397 .json(&api_request)
398 .send()
399 .await
400 else {
401 yield Err(anyhow::anyhow!("request failed"));
402 return;
403 };
404
405 let status = response.status();
406 if !status.is_success() {
407 let body = response.text().await.unwrap_or_default();
408 let kind = if status == StatusCode::TOO_MANY_REQUESTS {
409 StreamErrorKind::RateLimited
410 } else if status.is_server_error() {
411 StreamErrorKind::ServerError
412 } else {
413 StreamErrorKind::InvalidRequest
414 };
415 log::warn!("Gemini error status={status} body={body}");
416 yield Ok(StreamDelta::Error {
417 message: body,
418 kind,
419 });
420 return;
421 }
422
423 let mut inner = data::stream_gemini_response(response);
424 while let Some(item) = futures::StreamExt::next(&mut inner).await {
425 yield item;
426 }
427 })
428 }
429
430 fn model(&self) -> &str {
431 &self.model
432 }
433
434 fn provider(&self) -> &'static str {
435 "gemini"
436 }
437
438 fn configured_thinking(&self) -> Option<&ThinkingConfig> {
439 self.thinking.as_ref()
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// API key used by tests that do not care about its value.
    const TEST_KEY: &str = "test-api-key";

    /// Asserts that `provider` targets `expected_model` and identifies itself as Gemini.
    fn assert_gemini_model(provider: &GeminiProvider, expected_model: &str) {
        assert_eq!(provider.model(), expected_model);
        assert_eq!(provider.provider(), "gemini");
    }

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let provider = GeminiProvider::new(String::from(TEST_KEY), String::from("custom-model"));
        assert_gemini_model(&provider, "custom-model");
    }

    #[test]
    fn test_flash_factory_creates_flash_provider() {
        let provider = GeminiProvider::flash(String::from(TEST_KEY));
        assert_gemini_model(&provider, MODEL_GEMINI_3_FLASH);
    }

    #[test]
    fn test_flash_lite_factory_creates_flash_lite_provider() {
        let provider = GeminiProvider::flash_lite(String::from(TEST_KEY));
        assert_gemini_model(&provider, MODEL_GEMINI_2_FLASH_LITE);
    }

    #[test]
    fn test_flash_lite_31_factory_creates_flash_lite_provider() {
        let provider = GeminiProvider::flash_lite_31(String::from(TEST_KEY));
        assert_gemini_model(&provider, MODEL_GEMINI_31_FLASH_LITE);
    }

    #[test]
    fn test_pro_factory_creates_pro_provider() {
        let provider = GeminiProvider::pro(String::from(TEST_KEY));
        assert_gemini_model(&provider, MODEL_GEMINI_31_PRO);
    }

    #[test]
    fn test_pro_31_factory_creates_pro_provider() {
        let provider = GeminiProvider::pro_31(String::from(TEST_KEY));
        assert_gemini_model(&provider, MODEL_GEMINI_31_PRO);
    }

    #[test]
    fn test_model_constants_have_expected_values() {
        let table = [
            (MODEL_GEMINI_31_PRO, "gemini-3.1-pro-preview"),
            (MODEL_GEMINI_31_FLASH_LITE, "gemini-3.1-flash-lite-preview"),
            (MODEL_GEMINI_3_FLASH, "gemini-3-flash-preview"),
            (MODEL_GEMINI_3_PRO, "gemini-3.0-pro"),
            (MODEL_GEMINI_25_FLASH, "gemini-2.5-flash"),
            (MODEL_GEMINI_25_PRO, "gemini-2.5-pro"),
            (MODEL_GEMINI_2_FLASH, "gemini-2.0-flash"),
            (MODEL_GEMINI_2_FLASH_LITE, "gemini-2.0-flash-lite"),
        ];
        for (actual, expected) in table {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_gemini_20_models_reject_thinking() {
        let provider = GeminiProvider::flash_lite(String::from(TEST_KEY));
        let outcome = provider.validate_thinking_config(Some(&ThinkingConfig::new(10_000)));
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("thinking is not supported"));
    }

    #[test]
    fn test_default_uses_header_auth() {
        let provider = GeminiProvider::new(String::from("test-key"), String::from("model"));
        assert!(
            provider.use_header_auth,
            "Default should use header auth for security"
        );
    }

    #[test]
    fn test_provider_is_cloneable() {
        let original = GeminiProvider::new(String::from(TEST_KEY), String::from("test-model"));
        let duplicate = original.clone();

        assert_eq!(original.model(), duplicate.model());
        assert_eq!(original.provider(), duplicate.provider());
    }
}