Skip to main content

agent_sdk_providers/impls/
gemini.rs

1//! Google Gemini API provider implementation.
2//!
3//! This module provides an implementation of `LlmProvider` for the Google Gemini
4//! API (`generativelanguage.googleapis.com`).
5
6pub(crate) mod data;
7
8use crate::attachments::validate_request_attachments;
9use crate::provider::LlmProvider;
10use crate::streaming::{StreamBox, StreamDelta, StreamErrorKind};
11use agent_sdk_foundation::llm::{ChatOutcome, ChatRequest, ChatResponse, ThinkingConfig};
12use anyhow::Result;
13use async_trait::async_trait;
14use data::{
15    ApiContent, ApiFunctionCallingConfig, ApiGenerateContentRequest, ApiGenerateContentResponse,
16    ApiGenerationConfig, ApiPart, ApiUsageMetadata, build_api_contents, build_content_blocks,
17    convert_tools_to_config, gemini_response_schema, map_finish_reason, map_thinking_config,
18};
19use reqwest::StatusCode;
20
/// Default base URL for the Gemini REST API; `GeminiProvider::new` starts from
/// this value and `with_base_url` can override it.
const API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

// Gemini 3.1 series
/// Model identifier for Gemini 3.1 Pro Preview.
pub const MODEL_GEMINI_31_PRO: &str = "gemini-3.1-pro-preview";
/// Model identifier for Gemini 3.1 Flash Lite Preview.
pub const MODEL_GEMINI_31_FLASH_LITE: &str = "gemini-3.1-flash-lite-preview";

// Gemini 3 series
/// Model identifier for Gemini 3 Flash Preview.
pub const MODEL_GEMINI_3_FLASH: &str = "gemini-3-flash-preview";

// Legacy Gemini 3.0 Pro model kept for explicit opt-in.
/// Model identifier for the legacy Gemini 3.0 Pro model.
pub const MODEL_GEMINI_3_PRO: &str = "gemini-3.0-pro";

// Gemini 2.5 series
/// Model identifier for Gemini 2.5 Flash.
pub const MODEL_GEMINI_25_FLASH: &str = "gemini-2.5-flash";
/// Model identifier for Gemini 2.5 Pro.
pub const MODEL_GEMINI_25_PRO: &str = "gemini-2.5-pro";

// Gemini 2.0 series
/// Model identifier for Gemini 2.0 Flash.
pub const MODEL_GEMINI_2_FLASH: &str = "gemini-2.0-flash";
/// Model identifier for Gemini 2.0 Flash Lite.
pub const MODEL_GEMINI_2_FLASH_LITE: &str = "gemini-2.0-flash-lite";
40
/// Google Gemini LLM provider.
///
/// Holds the HTTP client, credentials, target model and per-provider request
/// options used by the `LlmProvider` implementation in this module.
#[derive(Clone)]
pub struct GeminiProvider {
    /// HTTP client used to send every request.
    client: reqwest::Client,
    /// Gemini API key. When empty, `apply_auth` attaches no provider
    /// credentials at all.
    api_key: String,
    /// Model identifier interpolated into the request URL.
    model: String,
    /// Base URL that the `/models/{model}:...` endpoints are appended to.
    base_url: String,
    /// Provider-level thinking configuration, if one was set via
    /// `with_thinking`.
    thinking: Option<ThinkingConfig>,
    /// When true, send the API key via `x-goog-api-key` header instead of a
    /// query parameter. Required when routing through proxies.
    use_header_auth: bool,
    /// Extra headers applied to every request (e.g. for gateway authentication).
    extra_headers: Vec<(String, String)>,
}
55
56impl GeminiProvider {
57    /// The conventional environment variable holding the Gemini API key.
58    pub const API_KEY_ENV: &'static str = "GEMINI_API_KEY";
59
60    /// Create a new Gemini provider with the specified API key and model.
61    #[must_use]
62    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
63        Self {
64            client: reqwest::Client::new(),
65            api_key: api_key.into(),
66            model: model.into(),
67            base_url: API_BASE_URL.to_owned(),
68            thinking: None,
69            use_header_auth: true,
70            extra_headers: Vec::new(),
71        }
72    }
73
74    /// Create a provider using Gemini Flash, reading the API key from the
75    /// conventional [`GEMINI_API_KEY`](Self::API_KEY_ENV) environment variable.
76    ///
77    /// # Panics
78    ///
79    /// Panics if `GEMINI_API_KEY` is not set. Prefer
80    /// [`try_from_env`](Self::try_from_env) outside of examples/tests.
81    #[must_use]
82    pub fn from_env() -> Self {
83        Self::try_from_env().unwrap_or_else(|e| panic!("{e}"))
84    }
85
86    /// Create a provider using Gemini Flash, reading the API key from the
87    /// conventional [`GEMINI_API_KEY`](Self::API_KEY_ENV) environment variable.
88    ///
89    /// # Errors
90    ///
91    /// Returns an error if `GEMINI_API_KEY` is unset or not valid UTF-8.
92    pub fn try_from_env() -> Result<Self> {
93        let api_key = std::env::var(Self::API_KEY_ENV).map_err(|_| {
94            anyhow::anyhow!("environment variable `{}` is not set", Self::API_KEY_ENV)
95        })?;
96        Ok(Self::flash(api_key))
97    }
98
99    /// Create a provider using Gemini 3 Flash Preview (fast and capable, current default).
100    #[must_use]
101    pub fn flash(api_key: impl Into<String>) -> Self {
102        Self::new(api_key, MODEL_GEMINI_3_FLASH)
103    }
104
105    /// Create a provider using Gemini 3.1 Flash Lite Preview.
106    #[must_use]
107    pub fn flash_lite_31(api_key: String) -> Self {
108        Self::new(api_key, MODEL_GEMINI_31_FLASH_LITE.to_owned())
109    }
110
111    /// Create a provider using Gemini 2.0 Flash Lite (fastest, most cost-effective).
112    #[must_use]
113    pub fn flash_lite(api_key: String) -> Self {
114        Self::new(api_key, MODEL_GEMINI_2_FLASH_LITE.to_owned())
115    }
116
117    /// Create a provider using Gemini 3.1 Pro Preview.
118    #[must_use]
119    pub fn pro_31(api_key: String) -> Self {
120        Self::new(api_key, MODEL_GEMINI_31_PRO.to_owned())
121    }
122
123    /// Create a provider using Gemini 3.1 Pro Preview (current recommended pro model).
124    #[must_use]
125    pub fn pro(api_key: String) -> Self {
126        Self::new(api_key, MODEL_GEMINI_31_PRO.to_owned())
127    }
128
129    /// Set the provider-owned thinking configuration for this model.
130    #[must_use]
131    pub const fn with_thinking(mut self, thinking: ThinkingConfig) -> Self {
132        self.thinking = Some(thinking);
133        self
134    }
135
136    /// Override the base URL.
137    #[must_use]
138    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
139        self.base_url = base_url.into();
140        self
141    }
142
143    /// Send the API key via `x-goog-api-key` header instead of `?key=` query
144    /// parameter. Required when routing through proxies.
145    #[must_use]
146    pub const fn with_header_auth(mut self) -> Self {
147        self.use_header_auth = true;
148        self
149    }
150
151    /// Add extra HTTP headers applied to every request.
152    #[must_use]
153    pub fn with_extra_headers(mut self, headers: Vec<(String, String)>) -> Self {
154        self.extra_headers = headers;
155        self
156    }
157
158    /// Apply auth + extra headers. Skips provider auth when `api_key` is
159    /// empty (BYOK gateway mode).
160    fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
161        let builder = if self.api_key.is_empty() {
162            builder
163        } else if self.use_header_auth {
164            builder.header("x-goog-api-key", &self.api_key)
165        } else {
166            builder.query(&[("key", &self.api_key)])
167        };
168        self.extra_headers
169            .iter()
170            .fold(builder, |b, (k, v)| b.header(k.as_str(), v.as_str()))
171    }
172}
173
174#[async_trait]
175#[allow(clippy::too_many_lines)]
176impl LlmProvider for GeminiProvider {
177    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
178        let thinking = match self.resolve_thinking_config(request.thinking.as_ref()) {
179            Ok(thinking) => thinking,
180            Err(error) => return Ok(ChatOutcome::InvalidRequest(error.to_string())),
181        };
182        if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
183            return Ok(ChatOutcome::InvalidRequest(error.to_string()));
184        }
185        let contents = build_api_contents(&request.messages);
186        let tools = request.tools.map(convert_tools_to_config);
187        let tool_config = request
188            .tool_choice
189            .as_ref()
190            .map(ApiFunctionCallingConfig::from_tool_choice);
191        let system_instruction = if request.system.is_empty() {
192            None
193        } else {
194            Some(ApiContent {
195                role: None,
196                parts: vec![ApiPart::Text {
197                    text: request.system.clone(),
198                    thought_signature: None,
199                }],
200            })
201        };
202
203        let thinking_config = thinking.as_ref().map(map_thinking_config);
204        let (response_mime_type, response_schema) =
205            request.response_format.as_ref().map_or((None, None), |rf| {
206                (
207                    Some("application/json"),
208                    Some(gemini_response_schema(&rf.schema)),
209                )
210            });
211
212        let api_request = ApiGenerateContentRequest {
213            contents: &contents,
214            system_instruction: system_instruction.as_ref(),
215            tools: tools.as_ref().map(std::slice::from_ref),
216            tool_config,
217            generation_config: Some(ApiGenerationConfig {
218                max_output_tokens: Some(request.max_tokens),
219                thinking_config,
220                response_mime_type,
221                response_schema,
222            }),
223            cached_content: request.cached_content.as_deref(),
224        };
225
226        log::debug!(
227            "Gemini LLM request model={} max_tokens={}",
228            self.model,
229            request.max_tokens
230        );
231
232        let builder = self
233            .client
234            .post(format!(
235                "{}/models/{}:generateContent",
236                self.base_url, self.model
237            ))
238            .header("Content-Type", "application/json");
239        let response = self
240            .apply_auth(builder)
241            .json(&api_request)
242            .send()
243            .await
244            .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;
245
246        let status = response.status();
247        let bytes = response
248            .bytes()
249            .await
250            .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;
251
252        log::debug!(
253            "Gemini LLM response status={} body_len={}",
254            status,
255            bytes.len()
256        );
257
258        if status == StatusCode::TOO_MANY_REQUESTS {
259            return Ok(ChatOutcome::RateLimited);
260        }
261
262        if status.is_server_error() {
263            let body = String::from_utf8_lossy(&bytes);
264            log::error!("Gemini server error status={status} body={body}");
265            return Ok(ChatOutcome::ServerError(body.into_owned()));
266        }
267
268        if status.is_client_error() {
269            let body = String::from_utf8_lossy(&bytes);
270            log::warn!("Gemini client error status={status} body={body}");
271            return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
272        }
273
274        let api_response: ApiGenerateContentResponse = serde_json::from_slice(&bytes)
275            .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;
276
277        let candidate = api_response
278            .candidates
279            .into_iter()
280            .next()
281            .ok_or_else(|| anyhow::anyhow!("no candidates in response"))?;
282
283        let content = build_content_blocks(&candidate.content);
284
285        if content.is_empty() && !candidate.content.parts.is_empty() {
286            log::warn!(
287                "Gemini parts not converted to content blocks raw_parts={:?}",
288                candidate.content.parts
289            );
290        }
291
292        let has_tool_calls = content
293            .iter()
294            .any(|b| matches!(b, agent_sdk_foundation::llm::ContentBlock::ToolUse { .. }));
295
296        let stop_reason = candidate
297            .finish_reason
298            .as_ref()
299            .map(|r| map_finish_reason(r, has_tool_calls));
300
301        let usage = api_response
302            .usage_metadata
303            .unwrap_or(ApiUsageMetadata {
304                prompt: 0,
305                candidates: 0,
306                cached_content: 0,
307            })
308            .into_usage();
309
310        Ok(ChatOutcome::Success(ChatResponse {
311            id: String::new(),
312            content,
313            model: self.model.clone(),
314            stop_reason,
315            usage,
316        }))
317    }
318
319    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
320        Box::pin(async_stream::stream! {
321            let thinking = match self.resolve_thinking_config(request.thinking.as_ref()) {
322                Ok(thinking) => thinking,
323                Err(error) => {
324                    yield Ok(StreamDelta::Error {
325                        message: error.to_string(),
326                        kind: StreamErrorKind::InvalidRequest,
327                    });
328                    return;
329                }
330            };
331            if let Err(error) = validate_request_attachments(self.provider(), self.model(), &request) {
332                yield Ok(StreamDelta::Error {
333                    message: error.to_string(),
334                    kind: StreamErrorKind::InvalidRequest,
335                });
336                return;
337            }
338            let contents = build_api_contents(&request.messages);
339            let tools = request.tools.map(convert_tools_to_config);
340            let tool_config = request
341                .tool_choice
342                .as_ref()
343                .map(ApiFunctionCallingConfig::from_tool_choice);
344            let system_instruction = if request.system.is_empty() {
345                None
346            } else {
347                Some(ApiContent {
348                    role: None,
349                    parts: vec![ApiPart::Text {
350                        text: request.system.clone(),
351                        thought_signature: None,
352                    }],
353                })
354            };
355
356            let thinking_config = thinking.as_ref().map(map_thinking_config);
357            let (response_mime_type, response_schema) = request
358                .response_format
359                .as_ref()
360                .map_or((None, None), |rf| {
361                    (
362                        Some("application/json"),
363                        Some(gemini_response_schema(&rf.schema)),
364                    )
365                });
366
367            let api_request = ApiGenerateContentRequest {
368                contents: &contents,
369                system_instruction: system_instruction.as_ref(),
370                tools: tools.as_ref().map(std::slice::from_ref),
371                tool_config,
372                generation_config: Some(ApiGenerationConfig {
373                    max_output_tokens: Some(request.max_tokens),
374                    thinking_config,
375                    response_mime_type,
376                    response_schema,
377                }),
378                cached_content: request.cached_content.as_deref(),
379            };
380
381            log::debug!(
382                "Gemini streaming LLM request model={} max_tokens={}",
383                self.model,
384                request.max_tokens
385            );
386
387            let stream_builder = self
388                .client
389                .post(format!(
390                    "{}/models/{}:streamGenerateContent",
391                    self.base_url, self.model
392                ))
393                .header("Content-Type", "application/json")
394                .query(&[("alt", "sse")]);
395            let Ok(response) = self
396                .apply_auth(stream_builder)
397                .json(&api_request)
398                .send()
399                .await
400            else {
401                yield Err(anyhow::anyhow!("request failed"));
402                return;
403            };
404
405            let status = response.status();
406            if !status.is_success() {
407                let body = response.text().await.unwrap_or_default();
408                let kind = if status == StatusCode::TOO_MANY_REQUESTS {
409                    StreamErrorKind::RateLimited
410                } else if status.is_server_error() {
411                    StreamErrorKind::ServerError
412                } else {
413                    StreamErrorKind::InvalidRequest
414                };
415                log::warn!("Gemini error status={status} body={body}");
416                yield Ok(StreamDelta::Error {
417                    message: body,
418                    kind,
419                });
420                return;
421            }
422
423            let mut inner = data::stream_gemini_response(response);
424            while let Some(item) = futures::StreamExt::next(&mut inner).await {
425                yield item;
426            }
427        })
428    }
429
430    fn model(&self) -> &str {
431        &self.model
432    }
433
434    fn provider(&self) -> &'static str {
435        "gemini"
436    }
437
438    fn configured_thinking(&self) -> Option<&ThinkingConfig> {
439        self.thinking.as_ref()
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder credential shared by the factory tests.
    const TEST_KEY: &str = "test-api-key";

    /// Checks that `provider` targets `expected_model` and reports the Gemini
    /// provider name.
    fn assert_gemini_model(provider: &GeminiProvider, expected_model: &str) {
        assert_eq!(provider.model(), expected_model);
        assert_eq!(provider.provider(), "gemini");
    }

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let custom = GeminiProvider::new(TEST_KEY, "custom-model");
        assert_gemini_model(&custom, "custom-model");
    }

    #[test]
    fn test_flash_factory_creates_flash_provider() {
        let flash = GeminiProvider::flash(TEST_KEY);
        assert_gemini_model(&flash, MODEL_GEMINI_3_FLASH);
    }

    #[test]
    fn test_flash_lite_factory_creates_flash_lite_provider() {
        let lite = GeminiProvider::flash_lite(TEST_KEY.to_owned());
        assert_gemini_model(&lite, MODEL_GEMINI_2_FLASH_LITE);
    }

    #[test]
    fn test_flash_lite_31_factory_creates_flash_lite_provider() {
        let lite = GeminiProvider::flash_lite_31(TEST_KEY.to_owned());
        assert_gemini_model(&lite, MODEL_GEMINI_31_FLASH_LITE);
    }

    #[test]
    fn test_pro_factory_creates_pro_provider() {
        let pro = GeminiProvider::pro(TEST_KEY.to_owned());
        assert_gemini_model(&pro, MODEL_GEMINI_31_PRO);
    }

    #[test]
    fn test_pro_31_factory_creates_pro_provider() {
        let pro = GeminiProvider::pro_31(TEST_KEY.to_owned());
        assert_gemini_model(&pro, MODEL_GEMINI_31_PRO);
    }

    #[test]
    fn test_model_constants_have_expected_values() {
        // (constant, wire identifier) pairs, newest series first.
        let expected = [
            (MODEL_GEMINI_31_PRO, "gemini-3.1-pro-preview"),
            (MODEL_GEMINI_31_FLASH_LITE, "gemini-3.1-flash-lite-preview"),
            (MODEL_GEMINI_3_FLASH, "gemini-3-flash-preview"),
            (MODEL_GEMINI_3_PRO, "gemini-3.0-pro"),
            (MODEL_GEMINI_25_FLASH, "gemini-2.5-flash"),
            (MODEL_GEMINI_25_PRO, "gemini-2.5-pro"),
            (MODEL_GEMINI_2_FLASH, "gemini-2.0-flash"),
            (MODEL_GEMINI_2_FLASH_LITE, "gemini-2.0-flash-lite"),
        ];
        for (constant, identifier) in expected {
            assert_eq!(constant, identifier);
        }
    }

    #[test]
    fn test_gemini_20_models_reject_thinking() {
        let lite = GeminiProvider::flash_lite(TEST_KEY.to_owned());
        let requested = ThinkingConfig::new(10_000);
        let message = lite
            .validate_thinking_config(Some(&requested))
            .unwrap_err()
            .to_string();
        assert!(message.contains("thinking is not supported"));
    }

    #[test]
    fn test_default_uses_header_auth() {
        let fresh = GeminiProvider::new("test-key", "model");
        assert!(
            fresh.use_header_auth,
            "Default should use header auth for security"
        );
    }

    #[test]
    fn test_provider_is_cloneable() {
        let original = GeminiProvider::new(TEST_KEY, "test-model");
        let duplicate = original.clone();

        assert_eq!(original.model(), duplicate.model());
        assert_eq!(original.provider(), duplicate.provider());
    }
}
533}