use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::{self, Stream};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;

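/// Adapter for Google's Gemini chat API (`generativelanguage.googleapis.com`,
/// `v1beta`). Holds the HTTP transport, API key, base URL, and a metrics sink.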
pub struct GeminiAdapter {
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl GeminiAdapter {
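    /// Builds an adapter from the `GEMINI_API_KEY` environment variable,
    /// using the default transport, base URL, and no-op metrics.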
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: HttpTransport::new().boxed(),
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

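    /// Builds an adapter with an explicit API key and an optional base URL
    /// override, falling back to the public `v1beta` endpoint.
    ///
    /// A minimal usage sketch (crate path and error handling assumed):
    ///
    /// ```ignore
    /// let adapter = GeminiAdapter::new_with_overrides(
    ///     "my-api-key".to_string(),
    ///     None, // use the default base URL
    /// )?;
    /// ```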
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: HttpTransport::new().boxed(),
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

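    /// Builds an adapter around a caller-supplied transport, API key, and
    /// base URL, with no-op metrics.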
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

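    /// Like [`Self::with_transport_ref`], but also accepts a metrics
    /// implementation for request counters and timers.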
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

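    /// Maps an OpenAI-style `ChatCompletionRequest` onto Gemini's
    /// `generateContent` body: messages become `contents` entries and
    /// sampling parameters are collected under `generationConfig`.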
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => "user",
                    Role::Assistant => "model",
                    // Gemini's contents API has no system role; fold system
                    // messages into the user turn.
                    Role::System => "user",
                };

                serde_json::json!({
                    "role": role,
                    "parts": [{"text": msg.content.as_text()}]
                })
            })
            .collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        let mut generation_config = serde_json::json!({});

        // Skip non-finite floats rather than panicking: `Number::from_f64`
        // returns None for NaN and infinities.
        if let Some(temp) = request.temperature {
            if let Some(num) = serde_json::Number::from_f64(temp.into()) {
                generation_config["temperature"] = serde_json::Value::Number(num);
            }
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            if let Some(num) = serde_json::Number::from_f64(top_p.into()) {
                generation_config["topP"] = serde_json::Value::Number(num);
            }
        }

        if generation_config.as_object().map_or(false, |o| !o.is_empty()) {
            gemini_request["generationConfig"] = generation_config;
        }

        gemini_request
    }

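    /// Converts a `generateContent` response into the library's OpenAI-style
    /// `ChatCompletionResponse`, including a best-effort extraction of any
    /// function call the model emitted.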
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array().ok_or_else(|| {
            AiLibError::ProviderError("No candidates in Gemini response".to_string())
        })?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates
            .iter()
            .enumerate()
            .map(|(index, candidate)| {
                let content = candidate["content"]["parts"][0]["text"]
                    .as_str()
                    .ok_or_else(|| {
                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
                    })?;

                // A function call may appear on the candidate itself or nested
                // under `content`; try a direct deserialization first, then
                // fall back to manual field extraction.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
                    candidate
                        .get("content")
                        .and_then(|c| c.get("function_call"))
                        .cloned()
                }) {
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else if let Some(name) = fc_val
                        .get("name")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string())
                    {
                        // `arguments` may arrive either as a JSON string or as
                        // an already-parsed object.
                        let args = fc_val.get("arguments").and_then(|a| match a.as_str() {
                            Some(s) => serde_json::from_str::<serde_json::Value>(s).ok(),
                            None => Some(a.clone()),
                        });
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name,
                            arguments: args,
                        });
                    }
                }

                // Map Gemini finish reasons onto OpenAI-style values.
                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                    "STOP" => "stop".to_string(),
                    "MAX_TOKENS" => "length".to_string(),
                    _ => r.to_string(),
                });

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role: Role::Assistant,
                        content: crate::types::common::Content::Text(content.to_string()),
                        function_call,
                    },
                    finish_reason,
                })
            })
            .collect();

        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            // Gemini does not return a response id, so synthesize one from
            // the current timestamp.
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GeminiAdapter {
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("gemini.requests", 1).await;
        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;

        let gemini_request = self.convert_to_gemini_request(&request);

        // Gemini authenticates via a `key` query parameter rather than an
        // Authorization header.
        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let headers =
            HashMap::from([("Content-Type".to_string(), "application/json".to_string())]);

        // Stop the request timer on both the success and error paths.
        let response = match self
            .transport
            .post_json(&url, Some(headers), gemini_request)
            .await
        {
            Ok(v) => {
                if let Some(t) = timer {
                    t.stop();
                }
                v
            }
            Err(e) => {
                if let Some(t) = timer {
                    t.stop();
                }
                return Err(e);
            }
        };

        self.parse_gemini_response(response, &request.model)
    }

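    // Streaming is not implemented for this adapter; an empty stream is
    // returned so that consumers terminate without error.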
    async fn chat_completion_stream(
        &self,
        _request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let stream = stream::empty();
        Ok(Box::new(Box::pin(stream)))
    }

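    // Static fallback list; Gemini's ListModels endpoint is not queried here.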
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

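    // Gemini has no OpenAI-style permission metadata, so return a permissive
    // stub for the requested id rather than querying the API.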
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}