ai_lib/provider/
gemini.rs

use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage, UsageStatus,
};
use futures::stream::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

/// Google Gemini independent adapter for multimodal AI services
///
/// The Gemini API differs from the OpenAI chat format and requires its own adapter:
/// - Endpoint: /v1beta/models/{model}:generateContent
/// - Request body: `contents` array instead of `messages`
/// - Response: candidates[0].content.parts[0].text
/// - Authentication: `x-goog-api-key` header (the API also accepts a `?key=<API_KEY>` URL parameter)
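///
/// # Example
///
/// A minimal usage sketch (illustrative only, not compiled): it assumes `GEMINI_API_KEY`
/// is set and that `ChatCompletionRequest` can be built as shown; adjust field names and
/// construction to the crate's actual request API.
///
/// ```rust,ignore
/// // Hypothetical construction; the real `ChatCompletionRequest` may expose a
/// // builder or additional required fields.
/// let adapter = GeminiAdapter::new()?;
/// let request = ChatCompletionRequest {
///     model: "gemini-1.5-flash".to_string(),
///     messages: vec![Message {
///         role: Role::User,
///         content: Content::Text("Hello, Gemini!".to_string()),
///         function_call: None,
///     }],
///     ..Default::default()
/// };
/// let response = adapter.chat_completion(request).await?;
/// println!("{}", response.choices[0].message.content.as_text());
/// ```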
pub struct GeminiAdapter {
    /// Unified transport used for both JSON and streaming requests
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl GeminiAdapter {
    #[allow(dead_code)]
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client()
                .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Explicit overrides for api_key and optional base_url (takes precedence over env vars)
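    ///
    /// A brief sketch of explicit configuration (illustrative; error handling elided):
    ///
    /// ```rust,ignore
    /// let adapter = GeminiAdapter::new_with_overrides(
    ///     "my-api-key".to_string(),
    ///     Some("https://generativelanguage.googleapis.com/v1beta".to_string()),
    /// )?;
    /// ```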
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct using object-safe transport reference
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct with an injected transport and metrics implementation
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

    /// Convert generic request to Gemini format
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => "user",
                    Role::Assistant => "model", // Gemini uses "model" instead of "assistant"
                    Role::System => "user",     // Gemini has no system role, convert to user
                };

                serde_json::json!({
                    "role": role,
                    "parts": [{"text": msg.content.as_text()}]
                })
            })
            .collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        // Gemini generation configuration
        let mut generation_config = serde_json::json!({});

        if let Some(temp) = request.temperature {
            generation_config["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            generation_config["topP"] =
                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }

        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        gemini_request
    }
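
    // Illustrative mapping produced by `convert_to_gemini_request` (not executed):
    // a request with a system message and a user message, temperature 0.7 and
    // max_tokens 256, becomes roughly:
    //
    // {
    //   "contents": [
    //     {"role": "user", "parts": [{"text": "<system prompt>"}]},
    //     {"role": "user", "parts": [{"text": "<user prompt>"}]}
    //   ],
    //   "generationConfig": {"temperature": 0.7, "maxOutputTokens": 256}
    // }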

    /// Parse Gemini response to generic format
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array().ok_or_else(|| {
            AiLibError::ProviderError("No candidates in Gemini response".to_string())
        })?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates
            .iter()
            .enumerate()
            .map(|(index, candidate)| {
                let content = candidate["content"]["parts"][0]["text"]
                    .as_str()
                    .ok_or_else(|| {
                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
                    })?;

                // Try to parse a function_call if the provider returned one. Gemini's
                // response shape may place structured data under candidate["function_call"]
                // or nested inside candidate["content"]["function_call"]. We try both.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
                    candidate
                        .get("content")
                        .and_then(|c| c.get("function_call"))
                        .cloned()
                }) {
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else {
                        // Fallback: extract name + arguments (arguments may be a JSON string)
                        if let Some(name) = fc_val
                            .get("name")
                            .and_then(|v| v.as_str())
                            .map(|s| s.to_string())
                        {
                            let args = fc_val.get("arguments").and_then(|a| {
                                if a.is_string() {
                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
                                        .ok()
                                } else {
                                    Some(a.clone())
                                }
                            });
                            function_call = Some(crate::types::function_call::FunctionCall {
                                name,
                                arguments: args,
                            });
                        }
                    }
                }

                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                    "STOP" => "stop".to_string(),
                    "MAX_TOKENS" => "length".to_string(),
                    _ => r.to_string(),
                });

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role: Role::Assistant,
                        content: crate::types::common::Content::Text(content.to_string()),
                        function_call,
                    },
                    finish_reason,
                })
            })
            .collect();

        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
            usage_status: UsageStatus::Finalized, // Gemini provides accurate usage data
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GeminiAdapter {
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("gemini.requests", 1).await;
        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;

        let gemini_request = self.convert_to_gemini_request(&request);

        // Authenticate via the x-goog-api-key header (the API also accepts a ?key= URL parameter)
        let url = format!("{}/models/{}:generateContent", self.base_url, request.model);

        let headers = HashMap::from([
            ("Content-Type".to_string(), "application/json".to_string()),
            ("x-goog-api-key".to_string(), self.api_key.clone()),
        ]);

        // Use unified transport
        let response_json = self
            .transport
            .post_json(&url, Some(headers), gemini_request)
            .await?;
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_gemini_response(response_json, &request.model)
    }

    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Try native streaming first via the Gemini streamGenerateContent endpoint.
        // `alt=sse` requests Server-Sent Events framing, which is what the parser
        // below expects (the default response is a chunked JSON array).
        let url = format!(
            "{}/models/{}:streamGenerateContent?alt=sse",
            self.base_url, request.model
        );
        let gemini_request = self.convert_to_gemini_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        headers.insert("x-goog-api-key".to_string(), self.api_key.clone());

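        // The SSE parser below consumes frames separated by a blank line, e.g.:
        //
        //   data: {"candidates":[{"content":{"parts":[{"text":"Hel"}]}}]}
        //   <blank line>
        //
        // The first candidate's text in each frame is forwarded as one chunk delta.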
        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers), gemini_request)
            .await
        {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(
                                                    data,
                                                ) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| {
                                                                        part.get("text")
                                                                    })
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let finish_reason = json
                                                                .get("candidates")
                                                                .and_then(|c| c.as_array())
                                                                .and_then(|arr| arr.first())
                                                                .and_then(|cand| {
                                                                    cand.get("finishReason")
                                                                        .or_else(|| json.get("finishReason"))
                                                                })
                                                                .and_then(|v| v.as_str())
                                                                .map(|r| match r {
                                                                    "STOP" => "stop".to_string(),
                                                                    "MAX_TOKENS" => "length".to_string(),
                                                                    other => other.to_string(),
                                                                });
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
                                    let mut i = 0;
                                    while i + 1 < buffer.len() {
                                        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
                                            return Some(i + 2);
                                        }
                                        if i + 3 < buffer.len()
                                            && buffer[i] == b'\r'
                                            && buffer[i + 1] == b'\n'
                                            && buffer[i + 2] == b'\r'
                                            && buffer[i + 3] == b'\n'
                                        {
                                            return Some(i + 4);
                                        }
                                        i += 1;
                                    }
                                    None
                                }
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(
                                                    data,
                                                ) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| {
                                                                        part.get("text")
                                                                    })
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason: None,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // Fallback to non-streaming + simulated chunks
        fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
            let mut chunks = Vec::new();
            let mut start = 0;
            while start < text.len() {
                // Advance roughly `max_len` bytes, then move forward to the next
                // char boundary so slicing never panics on multi-byte UTF-8.
                let mut end = std::cmp::min(start + max_len, text.len());
                while !text.is_char_boundary(end) {
                    end += 1;
                }
                let mut cut = end;
                if end < text.len() {
                    // Prefer to cut at the last space inside the window.
                    if let Some(pos) = text[start..end].rfind(' ') {
                        cut = start + pos;
                    }
                }
                if cut == start {
                    cut = end;
                }
                chunks.push(text[start..cut].to_string());
                start = cut;
                if start < text.len() && text.as_bytes()[start] == b' ' {
                    start += 1;
                }
            }
            chunks
        }
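        // For example, with max_len = 10, "hello wonderful world" is split into
        // ["hello", "wonderful", "world"]: cuts prefer the last space in the
        // window and the separating space is consumed.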

        let finished = self.chat_completion(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Common Gemini models
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}
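
#[cfg(test)]
mod tests {
    use super::*;

    // Illustrative unit-test sketch for the response mapping. It assumes that
    // `new_with_overrides` succeeds offline (it only builds an HTTP client and
    // performs no network I/O); adjust if transport construction changes.
    #[test]
    fn parses_minimal_gemini_response() {
        let adapter = GeminiAdapter::new_with_overrides("test-key".to_string(), None)
            .expect("adapter construction should not require network access");
        let payload = serde_json::json!({
            "candidates": [{
                "content": {"parts": [{"text": "Hello!"}]},
                "finishReason": "STOP"
            }],
            "usageMetadata": {
                "promptTokenCount": 3,
                "candidatesTokenCount": 2,
                "totalTokenCount": 5
            }
        });
        let response = adapter
            .parse_gemini_response(payload, "gemini-1.5-flash")
            .expect("well-formed payload should parse");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].finish_reason.as_deref(), Some("stop"));
        assert_eq!(response.usage.total_tokens, 5);
    }
}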