llm/providers/gemini/
provider.rs1use crate::provider::get_context_window;
2use crate::providers::openai_compatible::{AetherOpenAiConfig, build_chat_request, create_custom_stream_generic};
3use crate::{
4 Context, LlmError, LlmResponseStream, ProviderAuthMode, ProviderConnectionConfig, ProviderFactory, Result,
5 StreamingModelProvider,
6};
7use async_stream::stream;
8use futures::StreamExt;
9use std::env::var;
10
/// Default base URL of Gemini's OpenAI-compatible API, used when no
/// `base_url` override is supplied through `ProviderConnectionConfig`.
pub const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai/";
12
13#[derive(Clone)]
14pub struct GeminiProvider {
15 api_key: Option<String>,
16 base_url: Option<String>,
17 auth_mode: ProviderAuthMode,
18 model: String,
19}
20
21impl GeminiProvider {
22 pub fn new(api_key: Option<String>) -> Self {
23 Self { api_key, base_url: None, auth_mode: ProviderAuthMode::Default, model: String::new() }
24 }
25
26 pub fn with_connection(mut self, connection: ProviderConnectionConfig) -> Self {
27 self.base_url = connection.base_url;
28 self.auth_mode = connection.auth_mode;
29 self
30 }
31
32 fn get_api_key(&self) -> Result<String> {
33 if self.auth_mode == ProviderAuthMode::None {
34 return Ok(String::new());
35 }
36 if let Some(key) = &self.api_key {
37 return Ok(key.clone());
38 }
39
40 if let Ok(api_key) = var("GEMINI_API_KEY") {
41 return Ok(api_key);
42 }
43
44 Err(LlmError::MissingApiKey(
45 "GEMINI_API_KEY not set. Set the environment variable or provide an API key.".to_string(),
46 ))
47 }
48
49 fn build_openai_client(&self, api_key: &str) -> async_openai::Client<AetherOpenAiConfig> {
50 let api_base = self.base_url.as_deref().unwrap_or(GEMINI_API_BASE);
51 let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key).with_api_base(api_base);
52 async_openai::Client::with_config(AetherOpenAiConfig::new(config, self.auth_mode))
53 }
54}
55
56impl ProviderFactory for GeminiProvider {
57 async fn from_env() -> Result<Self> {
58 Ok(Self::new(None))
59 }
60
61 async fn from_env_with_connection(connection: ProviderConnectionConfig) -> Result<Self> {
62 Ok(Self::new(None).with_connection(connection))
63 }
64
65 fn with_model(mut self, model: &str) -> Self {
66 self.model = model.to_string();
67 self
68 }
69}
70
71impl StreamingModelProvider for GeminiProvider {
72 fn model(&self) -> Option<crate::LlmModel> {
73 format!("gemini:{}", self.model).parse().ok()
74 }
75
76 fn context_window(&self) -> Option<u32> {
77 get_context_window("gemini", &self.model)
78 }
79
80 fn stream_response(&self, context: &Context) -> LlmResponseStream {
81 let provider = self.clone();
82 let context = context.clone();
83
84 Box::pin(stream! {
85 let api_key = match provider.get_api_key() {
86 Ok(key) => key,
87 Err(e) => {
88 yield Err(e);
89 return;
90 }
91 };
92
93 tracing::info!("Using Gemini API with API key (OpenAI-compatible endpoint)");
94 let client = provider.build_openai_client(&api_key);
95 let request = match build_chat_request(&provider.model, &context, None) {
96 Ok(req) => req,
97 Err(e) => {
98 yield Err(e);
99 return;
100 }
101 };
102 let mut inner_stream =
103 create_custom_stream_generic(&client, request);
104
105 while let Some(result) = inner_stream.next().await {
106 yield result;
107 }
108 })
109 }
110
111 fn display_name(&self) -> String {
112 format!("Gemini ({})", self.model)
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use async_openai::config::Config;
    use reqwest::header::AUTHORIZATION;

    /// A provider that holds a real key but is configured with `ProviderAuthMode::None`.
    fn keyed_provider_without_auth() -> GeminiProvider {
        let connection = ProviderConnectionConfig { auth_mode: ProviderAuthMode::None, ..Default::default() };
        GeminiProvider::new(Some("real-key".to_string())).with_connection(connection)
    }

    #[test]
    fn test_provider_display_name() {
        let provider = GeminiProvider::new(None).with_model("gemini-2.0-flash");
        let shown = provider.display_name();
        assert_eq!(shown, "Gemini (gemini-2.0-flash)");
    }

    #[test]
    fn get_api_key_returns_empty_when_auth_is_none() {
        let key = keyed_provider_without_auth().get_api_key().unwrap();
        assert_eq!(key, "");
    }

    #[test]
    fn build_openai_client_strips_authorization_when_auth_is_none() {
        let provider = keyed_provider_without_auth();
        let client = provider.build_openai_client(&provider.get_api_key().unwrap());
        let headers = client.config().headers();
        assert!(!headers.contains_key(AUTHORIZATION));
    }
}
143}