ai_lib/provider/gemini.rs

use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage, UsageStatus,
};
use futures::stream::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

/// Google Gemini independent adapter for multimodal AI services.
///
/// The Gemini API differs from the OpenAI wire format, so it needs its own adapter:
/// - Endpoint: /v1beta/models/{model}:generateContent
/// - Request body: a `contents` array instead of `messages`
/// - Response: candidates[0].content.parts[0].text
/// - Authentication: `x-goog-api-key` header (a `?key=<API_KEY>` URL parameter also works)
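///
/// Illustrative wire format: a simplified sketch of what `convert_to_gemini_request`
/// produces and what `parse_gemini_response` expects (optional fields omitted). Request:
///
/// ```json
/// {
///   "contents": [
///     { "role": "user", "parts": [{ "text": "Hello" }] }
///   ],
///   "generationConfig": { "temperature": 0.7, "maxOutputTokens": 256 }
/// }
/// ```
///
/// Response (abridged):
///
/// ```json
/// {
///   "candidates": [
///     { "content": { "parts": [{ "text": "Hi there" }] }, "finishReason": "STOP" }
///   ],
///   "usageMetadata": { "promptTokenCount": 3, "candidatesTokenCount": 4, "totalTokenCount": 7 }
/// }
/// ```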
pub struct GeminiAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl GeminiAdapter {
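    /// Default HTTP timeout in seconds, read from the `AI_HTTP_TIMEOUT_SECS`
    /// environment variable (for example `AI_HTTP_TIMEOUT_SECS=60`); falls back
    /// to 30 seconds when the variable is unset or unparsable.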
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client()
                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Explicit overrides for `api_key` and an optional `base_url` (these take precedence over environment variables).
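    ///
    /// A minimal construction sketch (module path and re-exports are assumed, hence
    /// the `ignore` block; the key below is a placeholder):
    ///
    /// ```ignore
    /// let adapter = GeminiAdapter::new_with_overrides(
    ///     "your-api-key".to_string(),
    ///     None, // defaults to https://generativelanguage.googleapis.com/v1beta
    /// )?;
    /// ```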
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct using an object-safe transport reference
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct with an injected transport and metrics implementation
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

    /// Convert generic request to Gemini format
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => "user",
                    Role::Assistant => "model", // Gemini uses "model" instead of "assistant"
                    Role::System => "user",     // Gemini has no system role, convert to user
                };

                serde_json::json!({
                    "role": role,
                    "parts": [{"text": msg.content.as_text()}]
                })
            })
            .collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        // Gemini generation configuration
        let mut generation_config = serde_json::json!({});

        if let Some(temp) = request.temperature {
            generation_config["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            generation_config["topP"] =
                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }

        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        gemini_request
    }

    /// Parse Gemini response to generic format
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array().ok_or_else(|| {
            AiLibError::ProviderError("No candidates in Gemini response".to_string())
        })?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates
            .iter()
            .enumerate()
            .map(|(index, candidate)| {
                let content = candidate["content"]["parts"][0]["text"]
                    .as_str()
                    .ok_or_else(|| {
                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
                    })?;

                // Try to parse a function_call if the provider returned one. Gemini's
                // response shape may place structured data under candidate["function_call"]
                // or nested inside candidate["content"]["function_call"]. We try both.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
                    candidate
                        .get("content")
                        .and_then(|c| c.get("function_call"))
                        .cloned()
                }) {
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else {
                        // Fallback: extract name + arguments (arguments may be a JSON string)
                        if let Some(name) = fc_val
                            .get("name")
                            .and_then(|v| v.as_str())
                            .map(|s| s.to_string())
                        {
                            let args = fc_val.get("arguments").and_then(|a| {
                                if a.is_string() {
                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
                                        .ok()
                                } else {
                                    Some(a.clone())
                                }
                            });
                            function_call = Some(crate::types::function_call::FunctionCall {
                                name,
                                arguments: args,
                            });
                        }
                    }
                }

                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                    "STOP" => "stop".to_string(),
                    "MAX_TOKENS" => "length".to_string(),
                    _ => r.to_string(),
                });

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role: Role::Assistant,
                        content: crate::types::common::Content::Text(content.to_string()),
                        function_call,
                    },
                    finish_reason,
                })
            })
            .collect();

        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
            usage_status: UsageStatus::Finalized, // Gemini provides accurate usage data
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GeminiAdapter {
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("gemini.requests", 1).await;
        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;

        let gemini_request = self.convert_to_gemini_request(&request);

        // Gemini accepts the API key via the `x-goog-api-key` header (a `?key=<API_KEY>`
        // URL parameter would also work); the model is addressed in the URL path.
        let url = format!("{}/models/{}:generateContent", self.base_url, request.model);

        let headers = HashMap::from([
            ("Content-Type".to_string(), "application/json".to_string()),
            ("x-goog-api-key".to_string(), self.api_key.clone()),
        ]);

        // Use unified transport
        let response_json = self
            .transport
            .post_json(&url, Some(headers), gemini_request)
            .await?;
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_gemini_response(response_json, &request.model)
    }

    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Try native streaming first via the Gemini streamGenerateContent endpoint.
        // `alt=sse` requests SSE framing (the default response is a streamed JSON array,
        // which the `data:`-line parser below would not understand).
        let url = format!(
            "{}/models/{}:streamGenerateContent?alt=sse",
            self.base_url, request.model
        );
        let gemini_request = self.convert_to_gemini_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        headers.insert("x-goog-api-key".to_string(), self.api_key.clone());

        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers), gemini_request)
            .await
        {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(
                                                    data,
                                                ) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| {
                                                                        part.get("text")
                                                                    })
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let finish_reason = json
                                                                .get("candidates")
                                                                .and_then(|c| c.as_array())
                                                                .and_then(|arr| arr.first())
                                                                .and_then(|cand| {
                                                                    cand.get("finishReason")
                                                                        .or_else(|| json.get("finishReason"))
                                                                })
                                                                .and_then(|v| v.as_str())
                                                                .map(|r| match r {
                                                                    "STOP" => "stop".to_string(),
                                                                    "MAX_TOKENS" => "length".to_string(),
                                                                    other => other.to_string(),
                                                                });
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
                                    let mut i = 0;
                                    while i + 1 < buffer.len() {
                                        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
                                            return Some(i + 2);
                                        }
                                        if i + 3 < buffer.len()
                                            && buffer[i] == b'\r'
                                            && buffer[i + 1] == b'\n'
                                            && buffer[i + 2] == b'\r'
                                            && buffer[i + 3] == b'\n'
                                        {
                                            return Some(i + 4);
                                        }
                                        i += 1;
                                    }
                                    None
                                }
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(
                                                    data,
                                                ) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| {
                                                                        part.get("text")
                                                                    })
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason: None,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // Fallback to non-streaming + simulated chunks
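        // For example (ASCII input), split_text_into_chunks("hello world foo", 8)
        // yields ["hello", "world", "foo"]: the cut prefers the last space inside
        // the 8-byte window, and the space itself is skipped.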
        fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
            let mut chunks = Vec::new();
            let mut start = 0;
            while start < text.len() {
                // Clamp the window to `max_len` bytes, then back up to a char boundary
                // so slicing cannot panic on multi-byte UTF-8 characters.
                let mut end = std::cmp::min(start + max_len, text.len());
                while !text.is_char_boundary(end) {
                    end -= 1;
                }
                let mut cut = end;
                // Prefer to break at the last space inside the window.
                if end < text.len() {
                    if let Some(pos) = text[start..end].rfind(' ') {
                        cut = start + pos;
                    }
                }
                if cut == start {
                    cut = end;
                }
                chunks.push(text[start..cut].to_string());
                start = cut;
                // Skip the space we just broke on, if any.
                if text[start..].starts_with(' ') {
                    start += 1;
                }
            }
            chunks
        }

        let finished = self.chat_completion(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Static list of commonly used Gemini models; this does not query the API.
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}