Skip to main content

agent_sdk_providers/impls/
gemini.rs

1//! Google Gemini API provider implementation.
2//!
3//! This module provides an implementation of `LlmProvider` for the Google Gemini
4//! API (`generativelanguage.googleapis.com`).
5
6pub(crate) mod data;
7
8use crate::attachments::validate_request_attachments;
9use crate::provider::LlmProvider;
10use crate::streaming::{StreamBox, StreamDelta, StreamErrorKind};
11use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, ThinkingConfig};
12use anyhow::Result;
13use async_trait::async_trait;
14use data::{
15    ApiContent, ApiFunctionCallingConfig, ApiGenerateContentRequest, ApiGenerateContentResponse,
16    ApiGenerationConfig, ApiPart, ApiUsageMetadata, build_api_contents, build_content_blocks,
17    convert_tools_to_config, gemini_response_schema, map_finish_reason, map_thinking_config,
18};
19use reqwest::StatusCode;
20
/// Root of the Gemini REST API; model endpoints are appended to this URL
/// unless the provider's base URL is overridden.
const API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

/// Seconds allowed for establishing a connection (same value as the
/// Anthropic/Vertex providers).
const CONNECT_TIMEOUT_SECS: u64 = 30;
/// Seconds between TCP keepalive probes, so long streaming connections are
/// not dropped while idle.
const TCP_KEEPALIVE_SECS: u64 = 30;
/// Per-request timeout, in seconds, applied on the **non-streaming** `chat()`
/// path only. It bounds a black-holed endpoint so a single turn cannot hang
/// the agent loop forever. Streaming requests intentionally have no overall
/// timeout.
const CHAT_READ_TIMEOUT_SECS: u64 = 5 * 60;
31
32/// Build the shared HTTP client with connect + keepalive timeouts, falling back
33/// to a default client (with a logged warning) if the builder fails.
34fn build_http_client() -> reqwest::Client {
35    reqwest::Client::builder()
36        .connect_timeout(std::time::Duration::from_secs(CONNECT_TIMEOUT_SECS))
37        .tcp_keepalive(std::time::Duration::from_secs(TCP_KEEPALIVE_SECS))
38        .build()
39        .unwrap_or_else(|error| {
40            log::warn!(
41                "failed to build Gemini HTTP client with timeouts ({error}); using default client"
42            );
43            reqwest::Client::new()
44        })
45}
46
/// Gemini 3.1 series: Pro (preview) model identifier.
pub const MODEL_GEMINI_31_PRO: &str = "gemini-3.1-pro-preview";
/// Gemini 3.1 series: Flash Lite (preview) model identifier.
pub const MODEL_GEMINI_31_FLASH_LITE: &str = "gemini-3.1-flash-lite-preview";

/// Gemini 3 series: Flash (preview) model identifier.
pub const MODEL_GEMINI_3_FLASH: &str = "gemini-3-flash-preview";

/// Legacy Gemini 3.0 Pro model identifier, kept for explicit opt-in.
pub const MODEL_GEMINI_3_PRO: &str = "gemini-3.0-pro";

/// Gemini 2.5 series: Flash model identifier.
pub const MODEL_GEMINI_25_FLASH: &str = "gemini-2.5-flash";
/// Gemini 2.5 series: Pro model identifier.
pub const MODEL_GEMINI_25_PRO: &str = "gemini-2.5-pro";

/// Gemini 2.0 series: Flash model identifier.
pub const MODEL_GEMINI_2_FLASH: &str = "gemini-2.0-flash";
/// Gemini 2.0 series: Flash Lite model identifier.
pub const MODEL_GEMINI_2_FLASH_LITE: &str = "gemini-2.0-flash-lite";
64
65/// Google Gemini LLM provider.
66#[derive(Clone)]
67pub struct GeminiProvider {
68    client: reqwest::Client,
69    api_key: String,
70    model: String,
71    base_url: String,
72    thinking: Option<ThinkingConfig>,
73    /// When true, send the API key via `x-goog-api-key` header instead of a
74    /// query parameter. Required when routing through proxies.
75    use_header_auth: bool,
76    /// Extra headers applied to every request (e.g. for gateway authentication).
77    extra_headers: Vec<(String, String)>,
78}
79
80impl GeminiProvider {
81    /// The conventional environment variable holding the Gemini API key.
82    pub const API_KEY_ENV: &'static str = "GEMINI_API_KEY";
83
84    /// Create a new Gemini provider with the specified API key and model.
85    #[must_use]
86    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
87        Self {
88            client: build_http_client(),
89            api_key: api_key.into(),
90            model: model.into(),
91            base_url: API_BASE_URL.to_owned(),
92            thinking: None,
93            use_header_auth: true,
94            extra_headers: Vec::new(),
95        }
96    }
97
98    /// Effective output-token budget for a request.
99    ///
100    /// Mirrors the Anthropic provider: when the caller did not explicitly set
101    /// `max_tokens`, substitute the provider/model default
102    /// ([`default_max_tokens`](LlmProvider::default_max_tokens)) instead of
103    /// silently capping at `ChatRequest::DEFAULT_MAX_TOKENS`.
104    fn effective_max_tokens(&self, request: &ChatRequest) -> u32 {
105        if request.max_tokens_explicit {
106            request.max_tokens
107        } else {
108            self.default_max_tokens()
109        }
110    }
111
112    /// Create a provider using Gemini Flash, reading the API key from the
113    /// conventional [`GEMINI_API_KEY`](Self::API_KEY_ENV) environment variable.
114    ///
115    /// # Panics
116    ///
117    /// Panics if `GEMINI_API_KEY` is not set. Prefer
118    /// [`try_from_env`](Self::try_from_env) outside of examples/tests.
119    #[must_use]
120    pub fn from_env() -> Self {
121        Self::try_from_env().unwrap_or_else(|e| panic!("{e}"))
122    }
123
124    /// Create a provider using Gemini Flash, reading the API key from the
125    /// conventional [`GEMINI_API_KEY`](Self::API_KEY_ENV) environment variable.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if `GEMINI_API_KEY` is unset or not valid UTF-8.
130    pub fn try_from_env() -> Result<Self> {
131        let api_key = std::env::var(Self::API_KEY_ENV).map_err(|_| {
132            anyhow::anyhow!("environment variable `{}` is not set", Self::API_KEY_ENV)
133        })?;
134        Ok(Self::flash(api_key))
135    }
136
137    /// Create a provider using Gemini 3 Flash Preview (fast and capable, current default).
138    #[must_use]
139    pub fn flash(api_key: impl Into<String>) -> Self {
140        Self::new(api_key, MODEL_GEMINI_3_FLASH)
141    }
142
143    /// Create a provider using Gemini 3.1 Flash Lite Preview.
144    #[must_use]
145    pub fn flash_lite_31(api_key: String) -> Self {
146        Self::new(api_key, MODEL_GEMINI_31_FLASH_LITE.to_owned())
147    }
148
149    /// Create a provider using Gemini 2.0 Flash Lite (fastest, most cost-effective).
150    #[must_use]
151    pub fn flash_lite(api_key: String) -> Self {
152        Self::new(api_key, MODEL_GEMINI_2_FLASH_LITE.to_owned())
153    }
154
155    /// Create a provider using Gemini 3.1 Pro Preview.
156    #[must_use]
157    pub fn pro_31(api_key: String) -> Self {
158        Self::new(api_key, MODEL_GEMINI_31_PRO.to_owned())
159    }
160
161    /// Create a provider using Gemini 3.1 Pro Preview (current recommended pro model).
162    #[must_use]
163    pub fn pro(api_key: String) -> Self {
164        Self::new(api_key, MODEL_GEMINI_31_PRO.to_owned())
165    }
166
167    /// Set the provider-owned thinking configuration for this model.
168    #[must_use]
169    pub const fn with_thinking(mut self, thinking: ThinkingConfig) -> Self {
170        self.thinking = Some(thinking);
171        self
172    }
173
174    /// Override the base URL.
175    #[must_use]
176    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
177        self.base_url = base_url.into();
178        self
179    }
180
181    /// Send the API key via `x-goog-api-key` header instead of `?key=` query
182    /// parameter. Required when routing through proxies.
183    #[must_use]
184    pub const fn with_header_auth(mut self) -> Self {
185        self.use_header_auth = true;
186        self
187    }
188
189    /// Add extra HTTP headers applied to every request.
190    #[must_use]
191    pub fn with_extra_headers(mut self, headers: Vec<(String, String)>) -> Self {
192        self.extra_headers = headers;
193        self
194    }
195
196    /// Apply auth + extra headers. Skips provider auth when `api_key` is
197    /// empty (BYOK gateway mode).
198    fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
199        let builder = if self.api_key.is_empty() {
200            builder
201        } else if self.use_header_auth {
202            builder.header("x-goog-api-key", &self.api_key)
203        } else {
204            builder.query(&[("key", &self.api_key)])
205        };
206        self.extra_headers
207            .iter()
208            .fold(builder, |b, (k, v)| b.header(k.as_str(), v.as_str()))
209    }
210}
211
212#[async_trait]
213#[allow(clippy::too_many_lines)]
214impl LlmProvider for GeminiProvider {
215    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
216        let thinking = match self.resolve_thinking_config(request.thinking.as_ref()) {
217            Ok(thinking) => thinking,
218            Err(error) => return Ok(ChatOutcome::InvalidRequest(error.to_string())),
219        };
220        if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
221            return Ok(ChatOutcome::InvalidRequest(error.to_string()));
222        }
223        let contents = build_api_contents(&request.messages);
224        let tools = request
225            .tools
226            .as_ref()
227            .map(|t| convert_tools_to_config(t.clone()));
228        let tool_config = request
229            .tool_choice
230            .as_ref()
231            .map(ApiFunctionCallingConfig::from_tool_choice);
232        let system_instruction = if request.system.is_empty() {
233            None
234        } else {
235            Some(ApiContent {
236                role: None,
237                parts: vec![ApiPart::Text {
238                    text: request.system.clone(),
239                    thought_signature: None,
240                }],
241            })
242        };
243
244        let thinking_config = thinking.as_ref().map(map_thinking_config);
245        let (response_mime_type, response_schema) =
246            request.response_format.as_ref().map_or((None, None), |rf| {
247                (
248                    Some("application/json"),
249                    Some(gemini_response_schema(&rf.schema)),
250                )
251            });
252
253        let max_tokens = self.effective_max_tokens(&request);
254        let api_request = ApiGenerateContentRequest {
255            contents: &contents,
256            system_instruction: system_instruction.as_ref(),
257            tools: tools.as_ref().map(std::slice::from_ref),
258            tool_config,
259            generation_config: Some(ApiGenerationConfig {
260                max_output_tokens: Some(max_tokens),
261                thinking_config,
262                response_mime_type,
263                response_schema,
264            }),
265            cached_content: request.cached_content.as_deref(),
266        };
267
268        log::debug!(
269            "Gemini LLM request model={} max_tokens={}",
270            self.model,
271            max_tokens
272        );
273
274        let builder = self
275            .client
276            .post(format!(
277                "{}/models/{}:generateContent",
278                self.base_url, self.model
279            ))
280            .header("Content-Type", "application/json")
281            .timeout(std::time::Duration::from_secs(CHAT_READ_TIMEOUT_SECS));
282        let response = self
283            .apply_auth(builder)
284            .json(&api_request)
285            .send()
286            .await
287            .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;
288
289        let status = response.status();
290        let bytes = response
291            .bytes()
292            .await
293            .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;
294
295        log::debug!(
296            "Gemini LLM response status={} body_len={}",
297            status,
298            bytes.len()
299        );
300
301        if status == StatusCode::TOO_MANY_REQUESTS {
302            return Ok(ChatOutcome::RateLimited);
303        }
304
305        if status.is_server_error() {
306            let body = String::from_utf8_lossy(&bytes);
307            log::error!("Gemini server error status={status} body={body}");
308            return Ok(ChatOutcome::ServerError(body.into_owned()));
309        }
310
311        if status.is_client_error() {
312            let body = String::from_utf8_lossy(&bytes);
313            log::warn!("Gemini client error status={status} body={body}");
314            return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
315        }
316
317        let api_response: ApiGenerateContentResponse = serde_json::from_slice(&bytes)
318            .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;
319
320        let candidate = api_response
321            .candidates
322            .into_iter()
323            .next()
324            .ok_or_else(|| anyhow::anyhow!("no candidates in response"))?;
325
326        let content = build_content_blocks(&candidate.content);
327
328        if content.is_empty() && !candidate.content.parts.is_empty() {
329            log::warn!(
330                "Gemini parts not converted to content blocks raw_parts={:?}",
331                candidate.content.parts
332            );
333        }
334
335        let has_tool_calls = content
336            .iter()
337            .any(|b| matches!(b, agent_sdk_foundation::llm::ContentBlock::ToolUse { .. }));
338
339        let stop_reason = candidate
340            .finish_reason
341            .as_ref()
342            .map(|r| map_finish_reason(r, has_tool_calls));
343
344        let usage = api_response
345            .usage_metadata
346            .unwrap_or(ApiUsageMetadata {
347                prompt: 0,
348                candidates: 0,
349                cached_content: 0,
350            })
351            .into_usage();
352
353        Ok(ChatOutcome::Success(ChatResponse {
354            id: String::new(),
355            content,
356            model: self.model.clone(),
357            stop_reason,
358            usage,
359        }))
360    }
361
362    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
363        Box::pin(async_stream::stream! {
364            let thinking = match self.resolve_thinking_config(request.thinking.as_ref()) {
365                Ok(thinking) => thinking,
366                Err(error) => {
367                    yield Ok(StreamDelta::Error {
368                        message: error.to_string(),
369                        kind: StreamErrorKind::InvalidRequest,
370                    });
371                    return;
372                }
373            };
374            if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
375                yield Ok(StreamDelta::Error {
376                    message: error.to_string(),
377                    kind: StreamErrorKind::InvalidRequest,
378                });
379                return;
380            }
381            let contents = build_api_contents(&request.messages);
382            let tools = request
383            .tools
384            .as_ref()
385            .map(|t| convert_tools_to_config(t.clone()));
386            let tool_config = request
387                .tool_choice
388                .as_ref()
389                .map(ApiFunctionCallingConfig::from_tool_choice);
390            let system_instruction = if request.system.is_empty() {
391                None
392            } else {
393                Some(ApiContent {
394                    role: None,
395                    parts: vec![ApiPart::Text {
396                        text: request.system.clone(),
397                        thought_signature: None,
398                    }],
399                })
400            };
401
402            let thinking_config = thinking.as_ref().map(map_thinking_config);
403            let (response_mime_type, response_schema) = request
404                .response_format
405                .as_ref()
406                .map_or((None, None), |rf| {
407                    (
408                        Some("application/json"),
409                        Some(gemini_response_schema(&rf.schema)),
410                    )
411                });
412
413            let max_tokens = self.effective_max_tokens(&request);
414            let api_request = ApiGenerateContentRequest {
415                contents: &contents,
416                system_instruction: system_instruction.as_ref(),
417                tools: tools.as_ref().map(std::slice::from_ref),
418                tool_config,
419                generation_config: Some(ApiGenerationConfig {
420                    max_output_tokens: Some(max_tokens),
421                    thinking_config,
422                    response_mime_type,
423                    response_schema,
424                }),
425                cached_content: request.cached_content.as_deref(),
426            };
427
428            log::debug!(
429                "Gemini streaming LLM request model={} max_tokens={}",
430                self.model,
431                max_tokens
432            );
433
434            let stream_builder = self
435                .client
436                .post(format!(
437                    "{}/models/{}:streamGenerateContent",
438                    self.base_url, self.model
439                ))
440                .header("Content-Type", "application/json")
441                .query(&[("alt", "sse")]);
442            let response = match self
443                .apply_auth(stream_builder)
444                .json(&api_request)
445                .send()
446                .await
447            {
448                Ok(r) => r,
449                Err(e) => {
450                    // Include the cause so 401 detection / diagnostics survive.
451                    yield Err(anyhow::anyhow!("request failed: {e}"));
452                    return;
453                }
454            };
455
456            let status = response.status();
457            if !status.is_success() {
458                let body = response.text().await.unwrap_or_default();
459                let kind = if status == StatusCode::TOO_MANY_REQUESTS {
460                    StreamErrorKind::RateLimited
461                } else if status.is_server_error() {
462                    StreamErrorKind::ServerError
463                } else {
464                    StreamErrorKind::InvalidRequest
465                };
466                log::warn!("Gemini error status={status} body={body}");
467                yield Ok(StreamDelta::Error {
468                    message: body,
469                    kind,
470                });
471                return;
472            }
473
474            let mut inner = data::stream_gemini_response(response);
475            while let Some(item) = futures::StreamExt::next(&mut inner).await {
476                yield item;
477            }
478        })
479    }
480
481    fn model(&self) -> &str {
482        &self.model
483    }
484
485    fn provider(&self) -> &'static str {
486        "gemini"
487    }
488
489    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
490        self.thinking.as_ref()
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder API key; none of these tests send a request.
    const TEST_KEY: &str = "test-api-key";

    /// Check that `provider` reports `expected_model` and the `gemini` provider name.
    fn assert_gemini_model(provider: &GeminiProvider, expected_model: &str) {
        assert_eq!(provider.model(), expected_model);
        assert_eq!(provider.provider(), "gemini");
    }

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let provider = GeminiProvider::new(TEST_KEY, "custom-model");

        assert_gemini_model(&provider, "custom-model");
    }

    #[test]
    fn test_flash_factory_creates_flash_provider() {
        let provider = GeminiProvider::flash(TEST_KEY);

        assert_gemini_model(&provider, MODEL_GEMINI_3_FLASH);
    }

    #[test]
    fn test_flash_lite_factory_creates_flash_lite_provider() {
        let provider = GeminiProvider::flash_lite(TEST_KEY.to_owned());

        assert_gemini_model(&provider, MODEL_GEMINI_2_FLASH_LITE);
    }

    #[test]
    fn test_flash_lite_31_factory_creates_flash_lite_provider() {
        let provider = GeminiProvider::flash_lite_31(TEST_KEY.to_owned());

        assert_gemini_model(&provider, MODEL_GEMINI_31_FLASH_LITE);
    }

    #[test]
    fn test_pro_factory_creates_pro_provider() {
        let provider = GeminiProvider::pro(TEST_KEY.to_owned());

        assert_gemini_model(&provider, MODEL_GEMINI_31_PRO);
    }

    #[test]
    fn test_pro_31_factory_creates_pro_provider() {
        let provider = GeminiProvider::pro_31(TEST_KEY.to_owned());

        assert_gemini_model(&provider, MODEL_GEMINI_31_PRO);
    }

    #[test]
    fn test_model_constants_have_expected_values() {
        let expectations = [
            (MODEL_GEMINI_31_PRO, "gemini-3.1-pro-preview"),
            (MODEL_GEMINI_31_FLASH_LITE, "gemini-3.1-flash-lite-preview"),
            (MODEL_GEMINI_3_FLASH, "gemini-3-flash-preview"),
            (MODEL_GEMINI_3_PRO, "gemini-3.0-pro"),
            (MODEL_GEMINI_25_FLASH, "gemini-2.5-flash"),
            (MODEL_GEMINI_25_PRO, "gemini-2.5-pro"),
            (MODEL_GEMINI_2_FLASH, "gemini-2.0-flash"),
            (MODEL_GEMINI_2_FLASH_LITE, "gemini-2.0-flash-lite"),
        ];

        for (actual, expected) in expectations {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_gemini_20_models_reject_thinking() {
        let provider = GeminiProvider::flash_lite(TEST_KEY.to_owned());
        let requested = ThinkingConfig::new(10_000);

        let error = provider
            .validate_thinking_config(Some(&requested))
            .unwrap_err();

        assert!(error.to_string().contains("thinking is not supported"));
    }

    #[test]
    fn test_default_uses_header_auth() {
        let provider = GeminiProvider::new("test-key", "model");

        assert!(
            provider.use_header_auth,
            "Default should use header auth for security"
        );
    }

    #[test]
    fn test_provider_is_cloneable() {
        let provider = GeminiProvider::new(TEST_KEY, "test-model");
        let cloned = provider.clone();

        assert_eq!(provider.model(), cloned.model());
        assert_eq!(provider.provider(), cloned.provider());
    }

    /// Minimal single-message request with the given token budget and
    /// explicitness flag; every optional field is left unset.
    fn request_with_max_tokens(max_tokens: u32, explicit: bool) -> ChatRequest {
        ChatRequest {
            system: String::new(),
            messages: vec![agent_sdk_foundation::llm::Message::user("hi")],
            tools: None,
            max_tokens,
            max_tokens_explicit: explicit,
            session_id: None,
            cached_content: None,
            thinking: None,
            tool_choice: None,
            response_format: None,
        }
    }

    #[test]
    fn test_effective_max_tokens_honors_explicit_budget() {
        let provider = GeminiProvider::pro(TEST_KEY.to_owned());
        let explicit_request = request_with_max_tokens(123, true);

        assert_eq!(provider.effective_max_tokens(&explicit_request), 123);
    }

    #[test]
    fn test_effective_max_tokens_uses_default_when_implicit() {
        // An implicit budget must resolve to the provider/model default rather
        // than being silently capped at ChatRequest::DEFAULT_MAX_TOKENS.
        let provider = GeminiProvider::pro(TEST_KEY.to_owned());
        let implicit_request = request_with_max_tokens(4096, false);

        let expected = provider.default_max_tokens();
        assert_eq!(provider.effective_max_tokens(&implicit_request), expected);
    }
}