// ai_lib/provider/gemini.rs

1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::transport::{DynHttpTransportRef, HttpTransport};
4use crate::types::{
5    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
6};
7use futures::stream::{self, Stream};
8use std::collections::HashMap;
9use std::env;
10use std::sync::Arc;
11
/// Google Gemini independent adapter, supporting multimodal AI services.
///
/// Gemini's API is completely different from the OpenAI format and requires
/// an independent adapter:
/// - Endpoint: /v1beta/models/{model}:generateContent
/// - Request body: `contents` array instead of `messages`
/// - Response: candidates[0].content.parts[0].text
/// - Authentication: URL parameter ?key=<API_KEY>
pub struct GeminiAdapter {
    // Object-safe HTTP transport used for all outbound requests.
    transport: DynHttpTransportRef,
    // API key appended as the `key` query parameter on every request.
    api_key: String,
    // Base URL; defaults to Google's public v1beta endpoint.
    base_url: String,
    // Metrics sink; `NoopMetrics` unless explicitly injected.
    metrics: Arc<dyn Metrics>,
}
27
28impl GeminiAdapter {
29    pub fn new() -> Result<Self, AiLibError> {
30        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
31            AiLibError::AuthenticationError(
32                "GEMINI_API_KEY environment variable not set".to_string(),
33            )
34        })?;
35
36        Ok(Self {
37            transport: HttpTransport::new().boxed(),
38            api_key,
39            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
40            metrics: Arc::new(NoopMetrics::new()),
41        })
42    }
43
44    /// Explicit overrides for api_key and optional base_url (takes precedence over env vars)
45    pub fn new_with_overrides(
46        api_key: String,
47        base_url: Option<String>,
48    ) -> Result<Self, AiLibError> {
49        Ok(Self {
50            transport: HttpTransport::new().boxed(),
51            api_key,
52            base_url: base_url
53                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
54            metrics: Arc::new(NoopMetrics::new()),
55        })
56    }
57
58    /// Construct using object-safe transport reference
59    pub fn with_transport_ref(
60        transport: DynHttpTransportRef,
61        api_key: String,
62        base_url: String,
63    ) -> Result<Self, AiLibError> {
64        Ok(Self {
65            transport,
66            api_key,
67            base_url,
68            metrics: Arc::new(NoopMetrics::new()),
69        })
70    }
71
72    /// Construct with an injected transport and metrics implementation
73    pub fn with_transport_ref_and_metrics(
74        transport: DynHttpTransportRef,
75        api_key: String,
76        base_url: String,
77        metrics: Arc<dyn Metrics>,
78    ) -> Result<Self, AiLibError> {
79        Ok(Self {
80            transport,
81            api_key,
82            base_url,
83            metrics,
84        })
85    }
86
87    /// Convert generic request to Gemini format
88    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
89        let contents: Vec<serde_json::Value> = request
90            .messages
91            .iter()
92            .map(|msg| {
93                let role = match msg.role {
94                    Role::User => "user",
95                    Role::Assistant => "model", // Gemini uses "model" instead of "assistant"
96                    Role::System => "user",     // Gemini has no system role, convert to user
97                };
98
99                serde_json::json!({
100                    "role": role,
101                    "parts": [{"text": msg.content.as_text()}]
102                })
103            })
104            .collect();
105
106        let mut gemini_request = serde_json::json!({
107            "contents": contents
108        });
109
110        // Gemini generation configuration
111        let mut generation_config = serde_json::json!({});
112
113        if let Some(temp) = request.temperature {
114            generation_config["temperature"] =
115                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
116        }
117        if let Some(max_tokens) = request.max_tokens {
118            generation_config["maxOutputTokens"] =
119                serde_json::Value::Number(serde_json::Number::from(max_tokens));
120        }
121        if let Some(top_p) = request.top_p {
122            generation_config["topP"] =
123                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
124        }
125
126        if !generation_config.as_object().unwrap().is_empty() {
127            gemini_request["generationConfig"] = generation_config;
128        }
129
130        gemini_request
131    }
132
133    /// Parse Gemini response to generic format
134    fn parse_gemini_response(
135        &self,
136        response: serde_json::Value,
137        model: &str,
138    ) -> Result<ChatCompletionResponse, AiLibError> {
139        let candidates = response["candidates"].as_array().ok_or_else(|| {
140            AiLibError::ProviderError("No candidates in Gemini response".to_string())
141        })?;
142
143        let choices: Result<Vec<Choice>, AiLibError> = candidates
144            .iter()
145            .enumerate()
146            .map(|(index, candidate)| {
147                let content = candidate["content"]["parts"][0]["text"]
148                    .as_str()
149                    .ok_or_else(|| {
150                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
151                    })?;
152
153                // Try to parse a function_call if the provider returned one. Gemini's
154                // response shape may place structured data under candidate["function_call"]
155                // or nested inside candidate["content"]["function_call"]. We try both.
156                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
157                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
158                    candidate
159                        .get("content")
160                        .and_then(|c| c.get("function_call"))
161                        .cloned()
162                }) {
163                    if let Ok(fc) = serde_json::from_value::<
164                        crate::types::function_call::FunctionCall,
165                    >(fc_val.clone())
166                    {
167                        function_call = Some(fc);
168                    } else {
169                        // Fallback: extract name + arguments (arguments may be a JSON string)
170                        if let Some(name) = fc_val
171                            .get("name")
172                            .and_then(|v| v.as_str())
173                            .map(|s| s.to_string())
174                        {
175                            let args = fc_val.get("arguments").and_then(|a| {
176                                if a.is_string() {
177                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
178                                        .ok()
179                                } else {
180                                    Some(a.clone())
181                                }
182                            });
183                            function_call = Some(crate::types::function_call::FunctionCall {
184                                name,
185                                arguments: args,
186                            });
187                        }
188                    }
189                }
190
191                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
192                    "STOP" => "stop".to_string(),
193                    "MAX_TOKENS" => "length".to_string(),
194                    _ => r.to_string(),
195                });
196
197                Ok(Choice {
198                    index: index as u32,
199                    message: Message {
200                        role: Role::Assistant,
201                        content: crate::types::common::Content::Text(content.to_string()),
202                        function_call,
203                    },
204                    finish_reason,
205                })
206            })
207            .collect();
208
209        let usage = Usage {
210            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
211                .as_u64()
212                .unwrap_or(0) as u32,
213            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
214                .as_u64()
215                .unwrap_or(0) as u32,
216            total_tokens: response["usageMetadata"]["totalTokenCount"]
217                .as_u64()
218                .unwrap_or(0) as u32,
219        };
220
221        Ok(ChatCompletionResponse {
222            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
223            object: "chat.completion".to_string(),
224            created: chrono::Utc::now().timestamp() as u64,
225            model: model.to_string(),
226            choices: choices?,
227            usage,
228        })
229    }
230}
231
232#[async_trait::async_trait]
233impl ChatApi for GeminiAdapter {
234    async fn chat_completion(
235        &self,
236        request: ChatCompletionRequest,
237    ) -> Result<ChatCompletionResponse, AiLibError> {
238        self.metrics.incr_counter("gemini.requests", 1).await;
239        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;
240
241        let gemini_request = self.convert_to_gemini_request(&request);
242
243        // Gemini uses URL parameter authentication, not headers
244        let url = format!(
245            "{}/models/{}:generateContent?key={}",
246            self.base_url, request.model, self.api_key
247        );
248
249        let headers = HashMap::from([("Content-Type".to_string(), "application/json".to_string())]);
250
251        let response = match self
252            .transport
253            .post_json(&url, Some(headers), gemini_request)
254            .await
255        {
256            Ok(v) => {
257                if let Some(t) = timer {
258                    t.stop();
259                }
260                v
261            }
262            Err(e) => {
263                if let Some(t) = timer {
264                    t.stop();
265                }
266                return Err(e);
267            }
268        };
269
270        self.parse_gemini_response(response, &request.model)
271    }
272
273    async fn chat_completion_stream(
274        &self,
275        _request: ChatCompletionRequest,
276    ) -> Result<
277        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
278        AiLibError,
279    > {
280        // Gemini streaming response requires special handling, return empty stream for now
281        let stream = stream::empty();
282        Ok(Box::new(Box::pin(stream)))
283    }
284
285    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
286        // Common Gemini models
287        Ok(vec![
288            "gemini-1.5-pro".to_string(),
289            "gemini-1.5-flash".to_string(),
290            "gemini-1.0-pro".to_string(),
291        ])
292    }
293
294    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
295        Ok(ModelInfo {
296            id: model_id.to_string(),
297            object: "model".to_string(),
298            created: 0,
299            owned_by: "google".to_string(),
300            permission: vec![ModelPermission {
301                id: "default".to_string(),
302                object: "model_permission".to_string(),
303                created: 0,
304                allow_create_engine: false,
305                allow_sampling: true,
306                allow_logprobs: false,
307                allow_search_indices: false,
308                allow_view: true,
309                allow_fine_tuning: false,
310                organization: "*".to_string(),
311                group: None,
312                is_blocking: false,
313            }],
314        })
315    }
316}