use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::{self, Stream};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
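
/// Adapter for Google's Gemini `generateContent` REST API (v1beta).
///
/// Translates the crate's provider-agnostic chat types into Gemini's
/// `contents`/`generationConfig` request shape and maps responses back into
/// OpenAI-style completion types.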
pub struct GeminiAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl GeminiAdapter {
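    /// Creates an adapter with the default HTTP transport, reading the API
    /// key from the `GEMINI_API_KEY` environment variable.
    ///
    /// A minimal usage sketch, assuming `GEMINI_API_KEY` is set and a
    /// `ChatCompletionRequest` built elsewhere (construction details depend
    /// on the caller):
    ///
    /// ```ignore
    /// let adapter = GeminiAdapter::new()?;
    /// let response = adapter.chat_completion(request).await?;
    /// ```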
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: HttpTransport::new().boxed(),
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }
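    /// Creates an adapter from an injected transport, API key, and base URL;
    /// metrics default to a no-op recorder.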
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }
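    /// Like [`Self::with_transport_ref`], but with an explicit metrics
    /// recorder.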
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }
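    /// Builds the Gemini `generateContent` request body. The resulting JSON
    /// has roughly this shape (field values here are illustrative):
    ///
    /// ```json
    /// {
    ///   "contents": [
    ///     {"role": "user", "parts": [{"text": "Hello"}]}
    ///   ],
    ///   "generationConfig": {"temperature": 0.7, "maxOutputTokens": 256}
    /// }
    /// ```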
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => "user",
                    Role::Assistant => "model",
                    // Gemini's `contents` accepts only "user" and "model"
                    // roles; system messages are folded in as user turns.
                    Role::System => "user",
                };

                serde_json::json!({
                    "role": role,
                    "parts": [{"text": msg.content.as_text()}]
                })
            })
            .collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        let mut generation_config = serde_json::json!({});

        if let Some(temp) = request.temperature {
            // `from_f64` returns None for NaN or infinite values; skip the
            // field rather than panicking on an unwrap.
            if let Some(n) = serde_json::Number::from_f64(temp.into()) {
                generation_config["temperature"] = serde_json::Value::Number(n);
            }
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            if let Some(n) = serde_json::Number::from_f64(top_p.into()) {
                generation_config["topP"] = serde_json::Value::Number(n);
            }
        }

        // Only attach generationConfig when at least one option was set.
        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        gemini_request
    }
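    /// Maps a Gemini response (`candidates` plus `usageMetadata`) onto the
    /// crate's OpenAI-style `ChatCompletionResponse`.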
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array().ok_or_else(|| {
            AiLibError::ProviderError("No candidates in Gemini response".to_string())
        })?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates
            .iter()
            .enumerate()
            .map(|(index, candidate)| {
                let content = candidate["content"]["parts"][0]["text"]
                    .as_str()
                    .ok_or_else(|| {
                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
                    })?;

                // Gemini's REST API nests tool invocations as `functionCall`
                // objects inside `content.parts`; the flat `function_call`
                // spellings are also checked for compatibility.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                let fc_val = candidate
                    .get("function_call")
                    .cloned()
                    .or_else(|| {
                        candidate
                            .get("content")
                            .and_then(|c| c.get("function_call"))
                            .cloned()
                    })
                    .or_else(|| {
                        candidate["content"]["parts"].as_array().and_then(|parts| {
                            parts.iter().find_map(|p| p.get("functionCall").cloned())
                        })
                    });
                if let Some(fc_val) = fc_val {
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else if let Some(name) = fc_val
                        .get("name")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string())
                    {
                        // Arguments may arrive as a JSON object (`args`) or a
                        // JSON-encoded string (`arguments`); accept both.
                        let args = fc_val
                            .get("args")
                            .or_else(|| fc_val.get("arguments"))
                            .and_then(|a| {
                                if a.is_string() {
                                    serde_json::from_str::<serde_json::Value>(
                                        a.as_str().unwrap(),
                                    )
                                    .ok()
                                } else {
                                    Some(a.clone())
                                }
                            });
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name,
                            arguments: args,
                        });
                    }
                }

                // Map Gemini finish reasons onto OpenAI-style values.
                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                    "STOP" => "stop".to_string(),
                    "MAX_TOKENS" => "length".to_string(),
                    _ => r.to_string(),
                });

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role: Role::Assistant,
                        content: crate::types::common::Content::Text(content.to_string()),
                        function_call,
                    },
                    finish_reason,
                })
            })
            .collect();

        // Token counts default to 0 when usageMetadata is absent.
        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            // Gemini does not return a response id, so one is synthesized.
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GeminiAdapter {
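    // Sends a single, non-streaming generateContent request and maps the
    // response into a ChatCompletionResponse.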
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("gemini.requests", 1).await;
        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;

        let gemini_request = self.convert_to_gemini_request(&request);

        // Gemini authenticates with an API key passed as a query parameter
        // rather than an Authorization header.
        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let headers =
            HashMap::from([("Content-Type".to_string(), "application/json".to_string())]);

        // Stop the request timer on both the success and error paths.
        let response = match self
            .transport
            .post_json(&url, Some(headers), gemini_request)
            .await
        {
            Ok(v) => {
                if let Some(t) = timer {
                    t.stop();
                }
                v
            }
            Err(e) => {
                if let Some(t) = timer {
                    t.stop();
                }
                return Err(e);
            }
        };

        self.parse_gemini_response(response, &request.model)
    }

    async fn chat_completion_stream(
        &self,
        _request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        // Streaming is not implemented for this adapter yet; callers receive
        // an empty stream rather than an error.
        let stream = stream::empty();
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Static list; the provider's models endpoint is not queried.
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        // Gemini exposes no permission metadata here, so defaults are used.
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}