//! Gemini provider (`llm/providers/gemini/provider.rs`): streams LLM
//! responses via Google's OpenAI-compatible Gemini endpoint.
1use crate::provider::get_context_window;
2use crate::providers::openai_compatible::{build_chat_request, create_custom_stream_generic};
3use crate::{Context, LlmError, LlmResponseStream, ProviderFactory, Result, StreamingModelProvider};
4use async_stream::stream;
5use futures::StreamExt;
6use std::env::var;
7
/// Base URL of Google's OpenAI-compatible Gemini API endpoint.
pub const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai/";
9
/// Streaming LLM provider backed by Google's Gemini models, reached
/// through their OpenAI-compatible endpoint.
#[derive(Clone)]
pub struct GeminiProvider {
    // Explicit API key; when `None`, the `GEMINI_API_KEY` environment
    // variable is consulted at request time.
    api_key: Option<String>,
    // Model identifier (e.g. `gemini-2.0-flash`), set via `with_model`.
    model: String,
}
15
16impl GeminiProvider {
17    pub fn new(api_key: Option<String>) -> Self {
18        Self { api_key, model: String::new() }
19    }
20
21    fn get_api_key(&self) -> Result<String> {
22        if let Some(key) = &self.api_key {
23            return Ok(key.clone());
24        }
25
26        if let Ok(api_key) = var("GEMINI_API_KEY") {
27            return Ok(api_key);
28        }
29
30        Err(LlmError::MissingApiKey(
31            "GEMINI_API_KEY not set. Set the environment variable or provide an API key.".to_string(),
32        ))
33    }
34
35    fn build_openai_client(api_key: &str) -> async_openai::Client<async_openai::config::OpenAIConfig> {
36        let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key).with_api_base(GEMINI_API_BASE);
37        async_openai::Client::with_config(config)
38    }
39}
40
41impl ProviderFactory for GeminiProvider {
42    async fn from_env() -> Result<Self> {
43        Ok(Self::new(None))
44    }
45
46    fn with_model(mut self, model: &str) -> Self {
47        self.model = model.to_string();
48        self
49    }
50}
51
52impl StreamingModelProvider for GeminiProvider {
53    fn model(&self) -> Option<crate::LlmModel> {
54        format!("gemini:{}", self.model).parse().ok()
55    }
56
57    fn context_window(&self) -> Option<u32> {
58        get_context_window("gemini", &self.model)
59    }
60
61    fn stream_response(&self, context: &Context) -> LlmResponseStream {
62        let provider = self.clone();
63        let context = context.clone();
64
65        Box::pin(stream! {
66            let api_key = match provider.get_api_key() {
67                Ok(key) => key,
68                Err(e) => {
69                    yield Err(e);
70                    return;
71                }
72            };
73
74            tracing::info!("Using Gemini API with API key (OpenAI-compatible endpoint)");
75            let client = Self::build_openai_client(&api_key);
76            let request = match build_chat_request(&provider.model, &context) {
77                Ok(req) => req,
78                Err(e) => {
79                    yield Err(e);
80                    return;
81                }
82            };
83            let mut inner_stream =
84                create_custom_stream_generic(&client, request);
85
86            while let Some(result) = inner_stream.next().await {
87                yield result;
88            }
89        })
90    }
91
92    fn display_name(&self) -> String {
93        format!("Gemini ({})", self.model)
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// The display name should embed the configured model identifier.
    #[test]
    fn test_provider_display_name() {
        let gemini = GeminiProvider::new(None).with_model("gemini-2.0-flash");
        assert_eq!("Gemini (gemini-2.0-flash)", gemini.display_name());
    }
}