// ai_lib/provider/gemini.rs

use crate::api::{ChatCompletionChunk, ChatProvider, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
    UsageStatus,
};
use futures::stream::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

/// Google Gemini independent adapter, supporting multimodal AI services.
///
/// The Gemini API differs from the OpenAI wire format, so it needs its own adapter:
/// - Endpoint: `/v1beta/models/{model}:generateContent`
/// - Request body: `contents` array instead of `messages`
/// - Response: `candidates[0].content.parts[0].text`
/// - Authentication: `x-goog-api-key` header (used by this adapter) or `?key=<API_KEY>` URL parameter
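///
/// Illustrative request mapping (a sketch, not the exact serialized payload): one user
/// message plus a temperature setting becomes roughly the following, with sampling
/// parameters gathered under `generationConfig`:
///
/// ```ignore
/// // Hypothetical example values; see convert_to_gemini_request below for the real mapping.
/// let _gemini_body = serde_json::json!({
///     "contents": [
///         { "role": "user", "parts": [{ "text": "Hello" }] }
///     ],
///     "generationConfig": { "temperature": 0.7 }
/// });
/// ```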
pub struct GeminiAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl GeminiAdapter {
    #[allow(dead_code)]
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
            })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            Ok(t.boxed())
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct with an explicit `api_key` and optional `base_url` (these take precedence over environment variables)
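    ///
    /// Illustrative usage (a sketch; the key and proxy URL are placeholders, and the
    /// import path depends on how the crate re-exports this type):
    ///
    /// ```ignore
    /// let adapter = GeminiAdapter::new_with_overrides(
    ///     "my-api-key".to_string(),
    ///     Some("https://gemini-proxy.example.com/v1beta".to_string()),
    /// )?;
    /// ```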
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct using an object-safe transport reference
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Construct with an injected transport and metrics implementation
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

    /// Convert generic request to Gemini format
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => "user",
                    Role::Assistant => "model", // Gemini uses "model" instead of "assistant"
                    Role::System => "user",     // Gemini has no system role, convert to user
                };

                serde_json::json!({
                    "role": role,
                    "parts": [{"text": msg.content.as_text()}]
                })
            })
            .collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        // Gemini generation configuration
        let mut generation_config = serde_json::json!({});

        if let Some(temp) = request.temperature {
            generation_config["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            generation_config["topP"] =
                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }

        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        request.apply_extensions(&mut gemini_request);

        gemini_request
    }

    /// Parse Gemini response to generic format
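    ///
    /// Illustrative input (a sketch of the provider payload this parser expects; real
    /// responses carry additional fields):
    ///
    /// ```ignore
    /// let _gemini_response = serde_json::json!({
    ///     "candidates": [
    ///         { "content": { "parts": [{ "text": "Hi there" }] }, "finishReason": "STOP" }
    ///     ],
    ///     "usageMetadata": {
    ///         "promptTokenCount": 3,
    ///         "candidatesTokenCount": 2,
    ///         "totalTokenCount": 5
    ///     }
    /// });
    /// ```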
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array().ok_or_else(|| {
            AiLibError::ProviderError("No candidates in Gemini response".to_string())
        })?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates
            .iter()
            .enumerate()
            .map(|(index, candidate)| {
                let content = candidate["content"]["parts"][0]["text"]
                    .as_str()
                    .ok_or_else(|| {
                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
                    })?;

                // Try to parse a function_call if the provider returned one. Gemini's
                // response shape may place structured data under candidate["function_call"]
                // or nested inside candidate["content"]["function_call"]. We try both.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
                    candidate
                        .get("content")
                        .and_then(|c| c.get("function_call"))
                        .cloned()
                }) {
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else {
                        // Fallback: extract name + arguments (arguments may be a JSON string)
                        if let Some(name) = fc_val
                            .get("name")
                            .and_then(|v| v.as_str())
                            .map(|s| s.to_string())
                        {
                            let args = fc_val.get("arguments").and_then(|a| {
                                if a.is_string() {
                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
                                        .ok()
                                } else {
                                    Some(a.clone())
                                }
                            });
                            function_call = Some(crate::types::function_call::FunctionCall {
                                name,
                                arguments: args,
                            });
                        }
                    }
                }

                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                    "STOP" => "stop".to_string(),
                    "MAX_TOKENS" => "length".to_string(),
                    _ => r.to_string(),
                });

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role: Role::Assistant,
                        content: crate::types::common::Content::Text(content.to_string()),
                        function_call,
                    },
                    finish_reason,
                })
            })
            .collect();

        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
            usage_status: UsageStatus::Finalized, // Gemini provides accurate usage data
        })
    }
}

#[async_trait::async_trait]
impl ChatProvider for GeminiAdapter {
    fn name(&self) -> &str {
        "Gemini"
    }

    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("gemini.requests", 1).await;
        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;

        let gemini_request = self.convert_to_gemini_request(&request);

        // Gemini accepts the API key via the x-goog-api-key header (an alternative to the ?key= URL parameter)
        let url = format!("{}/models/{}:generateContent", self.base_url, request.model);

        let headers = HashMap::from([
            ("Content-Type".to_string(), "application/json".to_string()),
            ("x-goog-api-key".to_string(), self.api_key.clone()),
        ]);

        // Use unified transport
        let response_json = self
            .transport
            .post_json(&url, Some(headers), gemini_request)
            .await
            .map_err(|e| e.with_context(&format!("Gemini chat request to {}", url)))?;
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_gemini_response(response_json, &request.model)
    }

    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Try native streaming first via the Gemini streamGenerateContent endpoint.
        // `alt=sse` asks the API for server-sent events, which is the format the
        // parser below expects ("data: "-prefixed JSON events).
        let url = format!(
            "{}/models/{}:streamGenerateContent?alt=sse",
            self.base_url, request.model
        );
        let gemini_request = self.convert_to_gemini_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        headers.insert("x-goog-api-key".to_string(), self.api_key.clone());

        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers), gemini_request)
            .await
        {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
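            // Parse the byte stream on a background task: each decoded SSE event is
            // converted into a ChatCompletionChunk and forwarded through the channel,
            // which is handed back to the caller as a Stream below.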
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(
                                                    data,
                                                ) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| part.get("text"))
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let finish_reason = json
                                                                .get("candidates")
                                                                .and_then(|c| c.as_array())
                                                                .and_then(|arr| arr.first())
                                                                .and_then(|cand| {
                                                                    cand.get("finishReason")
                                                                        .or_else(|| json.get("finishReason"))
                                                                })
                                                                .and_then(|v| v.as_str())
                                                                .map(|r| match r {
                                                                    "STOP" => "stop".to_string(),
                                                                    "MAX_TOKENS" => "length".to_string(),
                                                                    other => other.to_string(),
                                                                });
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
                                    let mut i = 0;
                                    while i + 1 < buffer.len() {
                                        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
                                            return Some(i + 2);
                                        }
                                        if i + 3 < buffer.len()
                                            && buffer[i] == b'\r'
                                            && buffer[i + 1] == b'\n'
                                            && buffer[i + 2] == b'\r'
                                            && buffer[i + 3] == b'\n'
                                        {
                                            return Some(i + 4);
                                        }
                                        i += 1;
                                    }
                                    None
                                }
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(
                                                    data,
                                                ) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| part.get("text"))
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason: None,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // Fallback to non-streaming + simulated chunks
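        // The full completion is fetched via chat(), then its text is re-emitted as small
        // deltas (split into ~80-byte pieces below, with a short delay between sends) so
        // callers still observe an incremental stream.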
        fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
            let mut chunks = Vec::new();
            let mut start = 0;
            while start < text.len() {
                // Aim for max_len bytes, then move forward to the next char boundary so
                // slicing never lands inside a multi-byte UTF-8 character.
                let mut end = std::cmp::min(start + max_len, text.len());
                while end < text.len() && !text.is_char_boundary(end) {
                    end += 1;
                }
                // Prefer breaking at the last space inside the window to avoid splitting words.
                let mut cut = end;
                if end < text.len() {
                    if let Some(pos) = text[start..end].rfind(' ') {
                        cut = start + pos;
                    }
                }
                if cut == start {
                    cut = end;
                }
                chunks.push(text[start..cut].to_string());
                start = cut;
                if start < text.len() && text[start..].starts_with(' ') {
                    start += 1;
                }
            }
            chunks
        }

        let finished = self.chat(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Common Gemini models
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}