// ai_lib/provider/gemini.rs
1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::transport::{DynHttpTransportRef, HttpTransport};
4use crate::types::{
5    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
6};
7use futures::stream::{self, Stream};
8use std::collections::HashMap;
9use std::env;
10use std::sync::Arc;
11
/// Google Gemini standalone adapter supporting multimodal AI services.
///
/// Gemini API is completely different from OpenAI format, requires independent adapter:
/// - Endpoint: /v1beta/models/{model}:generateContent
/// - Request body: contents array instead of messages
/// - Response: candidates[0].content.parts[0].text
/// - Authentication: URL parameter ?key=<API_KEY>
pub struct GeminiAdapter {
    // Object-safe HTTP transport used for all outbound requests.
    transport: DynHttpTransportRef,
    // API key, read from GEMINI_API_KEY in `new()` or injected by the caller.
    api_key: String,
    // Base URL; defaults to the public v1beta Gemini endpoint in `new()`.
    base_url: String,
    // Metrics sink; defaults to NoopMetrics unless injected.
    metrics: Arc<dyn Metrics>,
}
27
28impl GeminiAdapter {
29    pub fn new() -> Result<Self, AiLibError> {
30        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
31            AiLibError::AuthenticationError(
32                "GEMINI_API_KEY environment variable not set".to_string(),
33            )
34        })?;
35
36        Ok(Self {
37            transport: HttpTransport::new().boxed(),
38            api_key,
39            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
40            metrics: Arc::new(NoopMetrics::new()),
41        })
42    }
43
44    /// Construct using object-safe transport reference
45    pub fn with_transport_ref(
46        transport: DynHttpTransportRef,
47        api_key: String,
48        base_url: String,
49    ) -> Result<Self, AiLibError> {
50        Ok(Self {
51            transport,
52            api_key,
53            base_url,
54            metrics: Arc::new(NoopMetrics::new()),
55        })
56    }
57
58    /// Construct with an injected transport and metrics implementation
59    pub fn with_transport_ref_and_metrics(
60        transport: DynHttpTransportRef,
61        api_key: String,
62        base_url: String,
63        metrics: Arc<dyn Metrics>,
64    ) -> Result<Self, AiLibError> {
65        Ok(Self {
66            transport,
67            api_key,
68            base_url,
69            metrics,
70        })
71    }
72
73    /// Convert generic request to Gemini format
74    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
75        let contents: Vec<serde_json::Value> = request
76            .messages
77            .iter()
78            .map(|msg| {
79                let role = match msg.role {
80                    Role::User => "user",
81                    Role::Assistant => "model", // Gemini uses "model" instead of "assistant"
82                    Role::System => "user",     // Gemini has no system role, convert to user
83                };
84
85                serde_json::json!({
86                    "role": role,
87                    "parts": [{"text": msg.content.as_text()}]
88                })
89            })
90            .collect();
91
92        let mut gemini_request = serde_json::json!({
93            "contents": contents
94        });
95
96        // Gemini generation configuration
97        let mut generation_config = serde_json::json!({});
98
99        if let Some(temp) = request.temperature {
100            generation_config["temperature"] =
101                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
102        }
103        if let Some(max_tokens) = request.max_tokens {
104            generation_config["maxOutputTokens"] =
105                serde_json::Value::Number(serde_json::Number::from(max_tokens));
106        }
107        if let Some(top_p) = request.top_p {
108            generation_config["topP"] =
109                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
110        }
111
112        if !generation_config.as_object().unwrap().is_empty() {
113            gemini_request["generationConfig"] = generation_config;
114        }
115
116        gemini_request
117    }
118
119    /// Parse Gemini response to generic format
120    fn parse_gemini_response(
121        &self,
122        response: serde_json::Value,
123        model: &str,
124    ) -> Result<ChatCompletionResponse, AiLibError> {
125        let candidates = response["candidates"].as_array().ok_or_else(|| {
126            AiLibError::ProviderError("No candidates in Gemini response".to_string())
127        })?;
128
129        let choices: Result<Vec<Choice>, AiLibError> = candidates
130            .iter()
131            .enumerate()
132            .map(|(index, candidate)| {
133                let content = candidate["content"]["parts"][0]["text"]
134                    .as_str()
135                    .ok_or_else(|| {
136                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
137                    })?;
138
139                // Try to parse a function_call if the provider returned one. Gemini's
140                // response shape may place structured data under candidate["function_call"]
141                // or nested inside candidate["content"]["function_call"]. We try both.
142                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
143                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
144                    candidate
145                        .get("content")
146                        .and_then(|c| c.get("function_call"))
147                        .cloned()
148                }) {
149                    if let Ok(fc) = serde_json::from_value::<
150                        crate::types::function_call::FunctionCall,
151                    >(fc_val.clone())
152                    {
153                        function_call = Some(fc);
154                    } else {
155                        // Fallback: extract name + arguments (arguments may be a JSON string)
156                        if let Some(name) = fc_val
157                            .get("name")
158                            .and_then(|v| v.as_str())
159                            .map(|s| s.to_string())
160                        {
161                            let args = fc_val.get("arguments").and_then(|a| {
162                                if a.is_string() {
163                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
164                                        .ok()
165                                } else {
166                                    Some(a.clone())
167                                }
168                            });
169                            function_call = Some(crate::types::function_call::FunctionCall {
170                                name,
171                                arguments: args,
172                            });
173                        }
174                    }
175                }
176
177                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
178                    "STOP" => "stop".to_string(),
179                    "MAX_TOKENS" => "length".to_string(),
180                    _ => r.to_string(),
181                });
182
183                Ok(Choice {
184                    index: index as u32,
185                    message: Message {
186                        role: Role::Assistant,
187                        content: crate::types::common::Content::Text(content.to_string()),
188                        function_call,
189                    },
190                    finish_reason,
191                })
192            })
193            .collect();
194
195        let usage = Usage {
196            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
197                .as_u64()
198                .unwrap_or(0) as u32,
199            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
200                .as_u64()
201                .unwrap_or(0) as u32,
202            total_tokens: response["usageMetadata"]["totalTokenCount"]
203                .as_u64()
204                .unwrap_or(0) as u32,
205        };
206
207        Ok(ChatCompletionResponse {
208            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
209            object: "chat.completion".to_string(),
210            created: chrono::Utc::now().timestamp() as u64,
211            model: model.to_string(),
212            choices: choices?,
213            usage,
214        })
215    }
216}
217
218#[async_trait::async_trait]
219impl ChatApi for GeminiAdapter {
220    async fn chat_completion(
221        &self,
222        request: ChatCompletionRequest,
223    ) -> Result<ChatCompletionResponse, AiLibError> {
224    self.metrics.incr_counter("gemini.requests", 1).await;
225    let timer = self.metrics.start_timer("gemini.request_duration_ms").await;
226
227        let gemini_request = self.convert_to_gemini_request(&request);
228
229        // Gemini uses URL parameter authentication, not headers
230        let url = format!(
231            "{}/models/{}:generateContent?key={}",
232            self.base_url, request.model, self.api_key
233        );
234
235        let headers = HashMap::from([("Content-Type".to_string(), "application/json".to_string())]);
236
237        let response = match self
238            .transport
239            .post_json(&url, Some(headers), gemini_request)
240            .await
241        {
242            Ok(v) => {
243                if let Some(t) = timer {
244                    t.stop();
245                }
246                v
247            }
248            Err(e) => {
249                if let Some(t) = timer {
250                    t.stop();
251                }
252                return Err(e);
253            }
254        };
255
256        self.parse_gemini_response(response, &request.model)
257    }
258
259    async fn chat_completion_stream(
260        &self,
261        _request: ChatCompletionRequest,
262    ) -> Result<
263        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
264        AiLibError,
265    > {
266        // Gemini streaming response requires special handling, return empty stream for now
267        let stream = stream::empty();
268        Ok(Box::new(Box::pin(stream)))
269    }
270
271    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
272        // Common Gemini models
273        Ok(vec![
274            "gemini-1.5-pro".to_string(),
275            "gemini-1.5-flash".to_string(),
276            "gemini-1.0-pro".to_string(),
277        ])
278    }
279
280    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
281        Ok(ModelInfo {
282            id: model_id.to_string(),
283            object: "model".to_string(),
284            created: 0,
285            owned_by: "google".to_string(),
286            permission: vec![ModelPermission {
287                id: "default".to_string(),
288                object: "model_permission".to_string(),
289                created: 0,
290                allow_create_engine: false,
291                allow_sampling: true,
292                allow_logprobs: false,
293                allow_search_indices: false,
294                allow_view: true,
295                allow_fine_tuning: false,
296                organization: "*".to_string(),
297                group: None,
298                is_blocking: false,
299            }],
300        })
301    }
302}