// ai_lib/provider/gemini.rs
1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::transport::{DynHttpTransportRef, HttpTransport};
4use crate::types::{
5    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
6};
7use futures::stream::Stream;
8use futures::StreamExt;
9use std::collections::HashMap;
10use std::env;
11use std::sync::Arc;
12#[cfg(feature = "unified_transport")]
13use std::time::Duration;
14
/// Google Gemini independent adapter, supporting multimodal AI services.
///
/// The Gemini API is completely different from the OpenAI wire format and
/// therefore requires its own adapter:
/// - Endpoint: /v1beta/models/{model}:generateContent
/// - Request body: `contents` array instead of `messages`
/// - Response: candidates[0].content.parts[0].text
/// - Authentication: `x-goog-api-key` request header
pub struct GeminiAdapter {
    // Injected HTTP transport used for all Gemini requests.
    #[allow(dead_code)] // Kept for backward compatibility, now using direct reqwest
    transport: DynHttpTransportRef,
    // API key sent via the `x-goog-api-key` header.
    api_key: String,
    // Base endpoint; defaults to https://generativelanguage.googleapis.com/v1beta
    base_url: String,
    // Metrics sink; a `NoopMetrics` instance unless injected.
    metrics: Arc<dyn Metrics>,
}
31
32impl GeminiAdapter {
33    fn build_default_timeout_secs() -> u64 {
34        std::env::var("AI_HTTP_TIMEOUT_SECS")
35            .ok()
36            .and_then(|s| s.parse::<u64>().ok())
37            .unwrap_or(30)
38    }
39
40    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
41        #[cfg(feature = "unified_transport")]
42        {
43            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
44            let client = crate::transport::client_factory::build_shared_client()
45                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
46            let t = HttpTransport::with_reqwest_client(client, timeout);
47            return Ok(t.boxed());
48        }
49        #[cfg(not(feature = "unified_transport"))]
50        {
51            let t = HttpTransport::new();
52            return Ok(t.boxed());
53        }
54    }
55
56    pub fn new() -> Result<Self, AiLibError> {
57        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
58            AiLibError::AuthenticationError(
59                "GEMINI_API_KEY environment variable not set".to_string(),
60            )
61        })?;
62
63        Ok(Self {
64            transport: Self::build_default_transport()?,
65            api_key,
66            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
67            metrics: Arc::new(NoopMetrics::new()),
68        })
69    }
70
71    /// Explicit overrides for api_key and optional base_url (takes precedence over env vars)
72    pub fn new_with_overrides(
73        api_key: String,
74        base_url: Option<String>,
75    ) -> Result<Self, AiLibError> {
76        Ok(Self {
77            transport: Self::build_default_transport()?,
78            api_key,
79            base_url: base_url
80                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
81            metrics: Arc::new(NoopMetrics::new()),
82        })
83    }
84
85    /// Construct using object-safe transport reference
86    pub fn with_transport_ref(
87        transport: DynHttpTransportRef,
88        api_key: String,
89        base_url: String,
90    ) -> Result<Self, AiLibError> {
91        Ok(Self {
92            transport,
93            api_key,
94            base_url,
95            metrics: Arc::new(NoopMetrics::new()),
96        })
97    }
98
99    /// Construct with an injected transport and metrics implementation
100    pub fn with_transport_ref_and_metrics(
101        transport: DynHttpTransportRef,
102        api_key: String,
103        base_url: String,
104        metrics: Arc<dyn Metrics>,
105    ) -> Result<Self, AiLibError> {
106        Ok(Self {
107            transport,
108            api_key,
109            base_url,
110            metrics,
111        })
112    }
113
114    /// Convert generic request to Gemini format
115    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
116        let contents: Vec<serde_json::Value> = request
117            .messages
118            .iter()
119            .map(|msg| {
120                let role = match msg.role {
121                    Role::User => "user",
122                    Role::Assistant => "model", // Gemini uses "model" instead of "assistant"
123                    Role::System => "user",     // Gemini has no system role, convert to user
124                };
125
126                serde_json::json!({
127                    "role": role,
128                    "parts": [{"text": msg.content.as_text()}]
129                })
130            })
131            .collect();
132
133        let mut gemini_request = serde_json::json!({
134            "contents": contents
135        });
136
137        // Gemini generation configuration
138        let mut generation_config = serde_json::json!({});
139
140        if let Some(temp) = request.temperature {
141            generation_config["temperature"] =
142                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
143        }
144        if let Some(max_tokens) = request.max_tokens {
145            generation_config["maxOutputTokens"] =
146                serde_json::Value::Number(serde_json::Number::from(max_tokens));
147        }
148        if let Some(top_p) = request.top_p {
149            generation_config["topP"] =
150                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
151        }
152
153        if !generation_config.as_object().unwrap().is_empty() {
154            gemini_request["generationConfig"] = generation_config;
155        }
156
157        gemini_request
158    }
159
160    /// Parse Gemini response to generic format
161    fn parse_gemini_response(
162        &self,
163        response: serde_json::Value,
164        model: &str,
165    ) -> Result<ChatCompletionResponse, AiLibError> {
166        let candidates = response["candidates"].as_array().ok_or_else(|| {
167            AiLibError::ProviderError("No candidates in Gemini response".to_string())
168        })?;
169
170        let choices: Result<Vec<Choice>, AiLibError> = candidates
171            .iter()
172            .enumerate()
173            .map(|(index, candidate)| {
174                let content = candidate["content"]["parts"][0]["text"]
175                    .as_str()
176                    .ok_or_else(|| {
177                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
178                    })?;
179
180                // Try to parse a function_call if the provider returned one. Gemini's
181                // response shape may place structured data under candidate["function_call"]
182                // or nested inside candidate["content"]["function_call"]. We try both.
183                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
184                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
185                    candidate
186                        .get("content")
187                        .and_then(|c| c.get("function_call"))
188                        .cloned()
189                }) {
190                    if let Ok(fc) = serde_json::from_value::<
191                        crate::types::function_call::FunctionCall,
192                    >(fc_val.clone())
193                    {
194                        function_call = Some(fc);
195                    } else {
196                        // Fallback: extract name + arguments (arguments may be a JSON string)
197                        if let Some(name) = fc_val
198                            .get("name")
199                            .and_then(|v| v.as_str())
200                            .map(|s| s.to_string())
201                        {
202                            let args = fc_val.get("arguments").and_then(|a| {
203                                if a.is_string() {
204                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
205                                        .ok()
206                                } else {
207                                    Some(a.clone())
208                                }
209                            });
210                            function_call = Some(crate::types::function_call::FunctionCall {
211                                name,
212                                arguments: args,
213                            });
214                        }
215                    }
216                }
217
218                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
219                    "STOP" => "stop".to_string(),
220                    "MAX_TOKENS" => "length".to_string(),
221                    _ => r.to_string(),
222                });
223
224                Ok(Choice {
225                    index: index as u32,
226                    message: Message {
227                        role: Role::Assistant,
228                        content: crate::types::common::Content::Text(content.to_string()),
229                        function_call,
230                    },
231                    finish_reason,
232                })
233            })
234            .collect();
235
236        let usage = Usage {
237            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
238                .as_u64()
239                .unwrap_or(0) as u32,
240            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
241                .as_u64()
242                .unwrap_or(0) as u32,
243            total_tokens: response["usageMetadata"]["totalTokenCount"]
244                .as_u64()
245                .unwrap_or(0) as u32,
246        };
247
248        Ok(ChatCompletionResponse {
249            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
250            object: "chat.completion".to_string(),
251            created: chrono::Utc::now().timestamp() as u64,
252            model: model.to_string(),
253            choices: choices?,
254            usage,
255        })
256    }
257}
258
259#[async_trait::async_trait]
260impl ChatApi for GeminiAdapter {
261    async fn chat_completion(
262        &self,
263        request: ChatCompletionRequest,
264    ) -> Result<ChatCompletionResponse, AiLibError> {
265        self.metrics.incr_counter("gemini.requests", 1).await;
266        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;
267
268        let gemini_request = self.convert_to_gemini_request(&request);
269
270        // Gemini uses URL parameter authentication, not headers
271        let url = format!("{}/models/{}:generateContent", self.base_url, request.model);
272
273        let headers = HashMap::from([
274            ("Content-Type".to_string(), "application/json".to_string()),
275            ("x-goog-api-key".to_string(), self.api_key.clone()),
276        ]);
277
278        // Use unified transport
279        let response_json = self
280            .transport
281            .post_json(&url, Some(headers), gemini_request)
282            .await?;
283        if let Some(t) = timer {
284            t.stop();
285        }
286        self.parse_gemini_response(response_json, &request.model)
287    }
288
289    async fn chat_completion_stream(
290        &self,
291        request: ChatCompletionRequest,
292    ) -> Result<
293        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
294        AiLibError,
295    > {
296        // Try native SSE first per Gemini API streamGenerateContent
297        let url = format!(
298            "{}/models/{}:streamGenerateContent",
299            self.base_url, request.model
300        );
301        let gemini_request = self.convert_to_gemini_request(&request);
302        let mut headers = HashMap::new();
303        headers.insert("Content-Type".to_string(), "application/json".to_string());
304        headers.insert("Accept".to_string(), "text/event-stream".to_string());
305        headers.insert("x-goog-api-key".to_string(), self.api_key.clone());
306
307        if let Ok(mut byte_stream) = self
308            .transport
309            .post_stream(&url, Some(headers), gemini_request)
310            .await
311        {
312            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
313            tokio::spawn(async move {
314                let mut buffer = Vec::new();
315                while let Some(item) = byte_stream.next().await {
316                    match item {
317                        Ok(bytes) => {
318                            buffer.extend_from_slice(&bytes);
319                            #[cfg(feature = "unified_sse")]
320                            {
321                                while let Some(boundary) =
322                                    crate::sse::parser::find_event_boundary(&buffer)
323                                {
324                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
325                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
326                                        for line in event_text.lines() {
327                                            let line = line.trim();
328                                            if let Some(data) = line.strip_prefix("data: ") {
329                                                if data.is_empty() {
330                                                    continue;
331                                                }
332                                                if data == "[DONE]" {
333                                                    return;
334                                                }
335                                                match serde_json::from_str::<serde_json::Value>(
336                                                    data,
337                                                ) {
338                                                    Ok(json) => {
339                                                        let text = json
340                                                            .get("candidates")
341                                                            .and_then(|c| c.as_array())
342                                                            .and_then(|arr| arr.first())
343                                                            .and_then(|cand| {
344                                                                cand.get("content")
345                                                                    .and_then(|c| c.get("parts"))
346                                                                    .and_then(|p| p.as_array())
347                                                                    .and_then(|parts| parts.first())
348                                                                    .and_then(|part| {
349                                                                        part.get("text")
350                                                                    })
351                                                                    .and_then(|t| t.as_str())
352                                                            })
353                                                            .map(|s| s.to_string());
354                                                        if let Some(tdelta) = text {
355                                                            let delta = crate::api::ChoiceDelta { index: 0, delta: crate::api::MessageDelta { role: Some(crate::types::Role::Assistant), content: Some(tdelta) }, finish_reason: json.get("candidates").and_then(|c| c.as_array()).and_then(|arr| arr.first()).and_then(|cand| cand.get("finishReason").or_else(|| json.get("finishReason"))).and_then(|v| v.as_str()).map(|r| match r { "STOP" => "stop".to_string(), "MAX_TOKENS" => "length".to_string(), other => other.to_string() }) };
356                                                            let chunk_obj = ChatCompletionChunk {
357                                                                id: json
358                                                                    .get("responseId")
359                                                                    .and_then(|v| v.as_str())
360                                                                    .unwrap_or("")
361                                                                    .to_string(),
362                                                                object: "chat.completion.chunk"
363                                                                    .to_string(),
364                                                                created: 0,
365                                                                model: request.model.clone(),
366                                                                choices: vec![delta],
367                                                            };
368                                                            if tx.send(Ok(chunk_obj)).is_err() {
369                                                                return;
370                                                            }
371                                                        }
372                                                    }
373                                                    Err(e) => {
374                                                        let _ = tx.send(Err(
375                                                            AiLibError::ProviderError(format!(
376                                                                "Gemini SSE JSON parse error: {}",
377                                                                e
378                                                            )),
379                                                        ));
380                                                        return;
381                                                    }
382                                                }
383                                            }
384                                        }
385                                    }
386                                }
387                            }
388                            #[cfg(not(feature = "unified_sse"))]
389                            {
390                                fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
391                                    let mut i = 0;
392                                    while i + 1 < buffer.len() {
393                                        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
394                                            return Some(i + 2);
395                                        }
396                                        if i + 3 < buffer.len()
397                                            && buffer[i] == b'\r'
398                                            && buffer[i + 1] == b'\n'
399                                            && buffer[i + 2] == b'\r'
400                                            && buffer[i + 3] == b'\n'
401                                        {
402                                            return Some(i + 4);
403                                        }
404                                        i += 1;
405                                    }
406                                    None
407                                }
408                                while let Some(boundary) = find_event_boundary(&buffer) {
409                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
410                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
411                                        for line in event_text.lines() {
412                                            let line = line.trim();
413                                            if let Some(data) = line.strip_prefix("data: ") {
414                                                if data.is_empty() {
415                                                    continue;
416                                                }
417                                                if data == "[DONE]" {
418                                                    return;
419                                                }
420                                                match serde_json::from_str::<serde_json::Value>(
421                                                    data,
422                                                ) {
423                                                    Ok(json) => {
424                                                        let text = json
425                                                            .get("candidates")
426                                                            .and_then(|c| c.as_array())
427                                                            .and_then(|arr| arr.first())
428                                                            .and_then(|cand| {
429                                                                cand.get("content")
430                                                                    .and_then(|c| c.get("parts"))
431                                                                    .and_then(|p| p.as_array())
432                                                                    .and_then(|parts| parts.first())
433                                                                    .and_then(|part| {
434                                                                        part.get("text")
435                                                                    })
436                                                                    .and_then(|t| t.as_str())
437                                                            })
438                                                            .map(|s| s.to_string());
439                                                        if let Some(tdelta) = text {
440                                                            let delta = crate::api::ChoiceDelta { index: 0, delta: crate::api::MessageDelta { role: Some(crate::types::Role::Assistant), content: Some(tdelta) }, finish_reason: None };
441                                                            let chunk_obj = ChatCompletionChunk {
442                                                                id: json
443                                                                    .get("responseId")
444                                                                    .and_then(|v| v.as_str())
445                                                                    .unwrap_or("")
446                                                                    .to_string(),
447                                                                object: "chat.completion.chunk"
448                                                                    .to_string(),
449                                                                created: 0,
450                                                                model: request.model.clone(),
451                                                                choices: vec![delta],
452                                                            };
453                                                            if tx.send(Ok(chunk_obj)).is_err() {
454                                                                return;
455                                                            }
456                                                        }
457                                                    }
458                                                    Err(e) => {
459                                                        let _ = tx.send(Err(
460                                                            AiLibError::ProviderError(format!(
461                                                                "Gemini SSE JSON parse error: {}",
462                                                                e
463                                                            )),
464                                                        ));
465                                                        return;
466                                                    }
467                                                }
468                                            }
469                                        }
470                                    }
471                                }
472                            }
473                        }
474                        Err(e) => {
475                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
476                                "Stream error: {}",
477                                e
478                            ))));
479                            break;
480                        }
481                    }
482                }
483            });
484            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
485            return Ok(Box::new(Box::pin(stream)));
486        }
487
488        // Fallback to non-streaming + simulated chunks
489        fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
490            let mut chunks = Vec::new();
491            let mut start = 0;
492            let bytes = text.as_bytes();
493            while start < bytes.len() {
494                let end = std::cmp::min(start + max_len, bytes.len());
495                let mut cut = end;
496                if end < bytes.len() {
497                    if let Some(pos) = text[start..end].rfind(' ') {
498                        cut = start + pos;
499                    }
500                }
501                if cut == start {
502                    cut = end;
503                }
504                chunks.push(String::from_utf8_lossy(&bytes[start..cut]).to_string());
505                start = cut;
506                if start < bytes.len() && bytes[start] == b' ' {
507                    start += 1;
508                }
509            }
510            chunks
511        }
512
513        let finished = self.chat_completion(request).await?;
514        let text = finished
515            .choices
516            .first()
517            .map(|c| c.message.content.as_text())
518            .unwrap_or_default();
519        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
520        tokio::spawn(async move {
521            let chunks = split_text_into_chunks(&text, 80);
522            for chunk in chunks {
523                let delta = crate::api::ChoiceDelta {
524                    index: 0,
525                    delta: crate::api::MessageDelta {
526                        role: Some(crate::types::Role::Assistant),
527                        content: Some(chunk.clone()),
528                    },
529                    finish_reason: None,
530                };
531                let chunk_obj = ChatCompletionChunk {
532                    id: "simulated".to_string(),
533                    object: "chat.completion.chunk".to_string(),
534                    created: 0,
535                    model: finished.model.clone(),
536                    choices: vec![delta],
537                };
538                if tx.send(Ok(chunk_obj)).is_err() {
539                    return;
540                }
541                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
542            }
543        });
544        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
545        Ok(Box::new(Box::pin(stream)))
546    }
547
548    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
549        // Common Gemini models
550        Ok(vec![
551            "gemini-1.5-pro".to_string(),
552            "gemini-1.5-flash".to_string(),
553            "gemini-1.0-pro".to_string(),
554        ])
555    }
556
557    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
558        Ok(ModelInfo {
559            id: model_id.to_string(),
560            object: "model".to_string(),
561            created: 0,
562            owned_by: "google".to_string(),
563            permission: vec![ModelPermission {
564                id: "default".to_string(),
565                object: "model_permission".to_string(),
566                created: 0,
567                allow_create_engine: false,
568                allow_sampling: true,
569                allow_logprobs: false,
570                allow_search_indices: false,
571                allow_view: true,
572                allow_fine_tuning: false,
573                organization: "*".to_string(),
574                group: None,
575                is_blocking: false,
576            }],
577        })
578    }
579}