Skip to main content

llm/providers/gemini/
provider.rs

1use crate::provider::get_context_window;
2use crate::providers::openai_compatible::{build_chat_request, create_custom_stream_generic};
3use crate::{
4    Context, LlmError, LlmResponseStream, ProviderFactory, Result, StreamingModelProvider,
5};
6use async_stream::stream;
7use futures::StreamExt;
8use std::env::var;
9
/// Base URL of Gemini's OpenAI-compatible REST API.
///
/// Used as the `api_base` of the `async_openai` client built by this provider, so
/// requests are routed to Google instead of OpenAI.
pub const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai/";
11
/// Streaming chat provider for Google Gemini, reached through its
/// OpenAI-compatible endpoint.
///
/// NOTE(review): `Debug` is not derived; if one is ever needed, implement it by
/// hand and redact `api_key` so the secret cannot leak into logs.
#[derive(Clone)]
pub struct GeminiProvider {
    /// Model name sent with each request (e.g. `gemini-2.0-flash`); empty until
    /// a model is configured.
    model: String,
    /// Explicitly supplied API key; when absent, the key is looked up from the
    /// `GEMINI_API_KEY` environment variable at request time.
    api_key: Option<String>,
}
17
18impl GeminiProvider {
19    pub fn new(api_key: Option<String>) -> Self {
20        Self {
21            api_key,
22            model: String::new(),
23        }
24    }
25
26    fn get_api_key(&self) -> Result<String> {
27        if let Some(key) = &self.api_key {
28            return Ok(key.clone());
29        }
30
31        if let Ok(api_key) = var("GEMINI_API_KEY") {
32            return Ok(api_key);
33        }
34
35        Err(LlmError::MissingApiKey(
36            "GEMINI_API_KEY not set. Set the environment variable or provide an API key."
37                .to_string(),
38        ))
39    }
40
41    fn build_openai_client(
42        api_key: &str,
43    ) -> async_openai::Client<async_openai::config::OpenAIConfig> {
44        let config = async_openai::config::OpenAIConfig::new()
45            .with_api_key(api_key)
46            .with_api_base(GEMINI_API_BASE);
47        async_openai::Client::with_config(config)
48    }
49}
50
51impl ProviderFactory for GeminiProvider {
52    fn from_env() -> Result<Self> {
53        Ok(Self::new(None))
54    }
55
56    fn with_model(mut self, model: &str) -> Self {
57        self.model = model.to_string();
58        self
59    }
60}
61
62impl StreamingModelProvider for GeminiProvider {
63    fn model(&self) -> Option<crate::LlmModel> {
64        format!("gemini:{}", self.model).parse().ok()
65    }
66
67    fn context_window(&self) -> Option<u32> {
68        get_context_window("gemini", &self.model)
69    }
70
71    fn stream_response(&self, context: &Context) -> LlmResponseStream {
72        let provider = self.clone();
73        let context = context.clone();
74
75        Box::pin(stream! {
76            let api_key = match provider.get_api_key() {
77                Ok(key) => key,
78                Err(e) => {
79                    yield Err(e);
80                    return;
81                }
82            };
83
84            tracing::info!("Using Gemini API with API key (OpenAI-compatible endpoint)");
85            let client = Self::build_openai_client(&api_key);
86            let request = match build_chat_request(&provider.model, &context) {
87                Ok(req) => req,
88                Err(e) => {
89                    yield Err(e);
90                    return;
91                }
92            };
93            let mut inner_stream =
94                create_custom_stream_generic(&client, request);
95
96            while let Some(result) = inner_stream.next().await {
97                yield result;
98            }
99        })
100    }
101
102    fn display_name(&self) -> String {
103        format!("Gemini ({})", self.model)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_display_name() {
        let provider = GeminiProvider::new(None).with_model("gemini-2.0-flash");
        assert_eq!(provider.display_name(), "Gemini (gemini-2.0-flash)");
    }

    #[test]
    fn test_with_model_sets_model_and_keeps_key() {
        let provider =
            GeminiProvider::new(Some("explicit-key".to_string())).with_model("gemini-2.0-flash");
        assert_eq!(provider.model, "gemini-2.0-flash");
        assert_eq!(provider.api_key.as_deref(), Some("explicit-key"));
    }

    #[test]
    fn test_explicit_api_key_takes_precedence() {
        // An explicit key is returned without consulting the environment.
        let provider = GeminiProvider::new(Some("explicit-key".to_string()));
        assert!(matches!(provider.get_api_key(), Ok(key) if key == "explicit-key"));
    }

    #[test]
    fn test_from_env_defers_key_lookup() {
        // Construction never fails; key resolution happens when streaming.
        assert!(GeminiProvider::from_env().is_ok());
    }
}