ai_lib/provider/gemini.rs

use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
use crate::transport::{HttpClient, HttpTransport};
use std::env;
use std::collections::HashMap;
use futures::stream::{self, Stream};

/// Google Gemini independent adapter for multimodal AI services
///
/// The Gemini API is completely different from the OpenAI format and requires an independent adapter:
/// - Endpoint: /v1beta/models/{model}:generateContent
/// - Request body: `contents` array instead of `messages`
/// - Response: candidates[0].content.parts[0].text
/// - Authentication: URL parameter ?key=<API_KEY>
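///
/// A minimal usage sketch (the module path is assumed from the file location;
/// `request` is a prebuilt ChatCompletionRequest):
///
/// ```ignore
/// use ai_lib::provider::gemini::GeminiAdapter;
///
/// // Requires GEMINI_API_KEY in the environment.
/// let adapter = GeminiAdapter::new()?;
/// let response = adapter.chat_completion(request).await?;
/// println!("{}", response.choices[0].message.content);
/// ```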
pub struct GeminiAdapter {
    transport: HttpTransport,
    api_key: String,
    base_url: String,
}

impl GeminiAdapter {
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY")
            .map_err(|_| AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string()
            ))?;

        Ok(Self {
            transport: HttpTransport::new(),
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
        })
    }

    /// Convert a generic request to the Gemini format
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
            let role = match msg.role {
                Role::User => "user",
                Role::Assistant => "model", // Gemini uses "model" instead of "assistant"
                // Gemini contents have no system role, so convert to user
                // (v1beta also offers a top-level systemInstruction field)
                Role::System => "user",
            };

            serde_json::json!({
                "role": role,
                "parts": [{"text": msg.content}]
            })
        }).collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        // Gemini generation configuration
        let mut generation_config = serde_json::json!({});

        if let Some(temp) = request.temperature {
            // from_f64 returns None for NaN/infinity; skip the field rather than panic
            if let Some(n) = serde_json::Number::from_f64(temp.into()) {
                generation_config["temperature"] = serde_json::Value::Number(n);
            }
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] = serde_json::Value::Number(
                serde_json::Number::from(max_tokens)
            );
        }
        if let Some(top_p) = request.top_p {
            if let Some(n) = serde_json::Number::from_f64(top_p.into()) {
                generation_config["topP"] = serde_json::Value::Number(n);
            }
        }

        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        gemini_request
    }

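    // For illustration, a single user message with temperature 0.7 and
    // max_tokens 256 converts to a body of this shape (an assumed example,
    // following the mapping above):
    //
    //   {
    //     "contents": [{"role": "user", "parts": [{"text": "Hi"}]}],
    //     "generationConfig": {"temperature": 0.7, "maxOutputTokens": 256}
    //   }
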
    /// Parse Gemini response to generic format
    fn parse_gemini_response(&self, response: serde_json::Value, model: &str) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array()
            .ok_or_else(|| AiLibError::ProviderError("No candidates in Gemini response".to_string()))?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates.iter().enumerate().map(|(index, candidate)| {
            let content = candidate["content"]["parts"][0]["text"].as_str()
                .ok_or_else(|| AiLibError::ProviderError("No text in Gemini candidate".to_string()))?;

            let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                "STOP" => "stop".to_string(),
                "MAX_TOKENS" => "length".to_string(),
                _ => r.to_string(),
            });

            Ok(Choice {
                index: index as u32,
                message: Message {
                    role: Role::Assistant,
                    content: content.to_string(),
                },
                finish_reason,
            })
        }).collect();

        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"].as_u64().unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"].as_u64().unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"].as_u64().unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
        })
    }
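
    // A typical (assumed) response body that this parser accepts:
    //
    //   {
    //     "candidates": [
    //       {"content": {"parts": [{"text": "Hello!"}]}, "finishReason": "STOP"}
    //     ],
    //     "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 3, "totalTokenCount": 8}
    //   }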
}

#[async_trait::async_trait]
impl ChatApi for GeminiAdapter {
    async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
        let gemini_request = self.convert_to_gemini_request(&request);

        // Gemini uses URL parameter authentication, not headers
        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let headers = HashMap::from([
            ("Content-Type".to_string(), "application/json".to_string()),
        ]);

        let response: serde_json::Value = self.transport
            .post(&url, Some(headers), &gemini_request)
            .await?;

        self.parse_gemini_response(response, &request.model)
    }

    async fn chat_completion_stream(&self, _request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
        // Gemini streams via the separate `:streamGenerateContent` endpoint with
        // its own wire format; until that is implemented, return an empty stream.
        // `stream::empty()` is already Unpin, so no extra pinning is needed.
        Ok(Box::new(stream::empty()))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Hardcoded list of common Gemini models; the v1beta API also exposes
        // a live model list via GET /models.
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}
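
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of the request conversion. It assumes that
    // ChatCompletionRequest and Message can be built as struct literals with
    // exactly the fields this adapter reads; adapt to the real constructors
    // if the types carry additional fields.
    #[test]
    fn maps_roles_and_generation_config() {
        let adapter = GeminiAdapter {
            transport: HttpTransport::new(),
            api_key: "test-key".to_string(),
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
        };

        let request = ChatCompletionRequest {
            model: "gemini-1.5-flash".to_string(),
            messages: vec![Message {
                role: Role::Assistant,
                content: "Hi".to_string(),
            }],
            temperature: Some(0.7),
            max_tokens: Some(256),
            top_p: None,
        };

        let body = adapter.convert_to_gemini_request(&request);

        // Assistant must be mapped to Gemini's "model" role.
        assert_eq!(body["contents"][0]["role"], "model");
        assert_eq!(body["generationConfig"]["maxOutputTokens"], 256);
        // top_p was not set, so topP must be absent.
        assert!(body["generationConfig"].get("topP").is_none());
    }
}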