use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;
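
/// Adapter for Google Gemini's generative language API (v1beta).
///
/// Wraps an HTTP transport, the API key, the base URL, and a metrics sink.
///
/// A minimal usage sketch (hypothetical `request` construction; assumes
/// `GEMINI_API_KEY` is set in the environment):
///
/// ```ignore
/// let adapter = GeminiAdapter::new()?;
/// let response = adapter.chat_completion(request).await?;
/// println!("{}", response.choices[0].message.content.as_text());
/// ```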
pub struct GeminiAdapter {
    #[allow(dead_code)]
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl GeminiAdapter {
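    /// Read the request timeout from the `AI_HTTP_TIMEOUT_SECS` environment
    /// variable, defaulting to 30 seconds.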
    fn build_default_timeout_secs() -> u64 {
        env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }
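
    /// Build the default HTTP transport: a shared reqwest-backed client when
    /// the `unified_transport` feature is enabled, a plain transport otherwise.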
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
            })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }
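
    /// Create an adapter that reads the API key from the `GEMINI_API_KEY`
    /// environment variable and targets the public v1beta endpoint.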
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }
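
    /// Create an adapter with an explicit API key and an optional base URL
    /// override.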
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }
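
    /// Create an adapter with a caller-supplied transport (e.g. for tests or a
    /// custom HTTP stack).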
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }
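
    /// Like [`Self::with_transport_ref`], but also installs a custom metrics
    /// sink.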
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }
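
    /// Convert a generic `ChatCompletionRequest` into Gemini's
    /// `generateContent` JSON shape. Gemini's `contents` array accepts only
    /// `user` and `model` roles, so system messages are sent as `user`.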
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => "user",
                    Role::Assistant => "model",
                    Role::System => "user",
                };

                serde_json::json!({
                    "role": role,
                    "parts": [{"text": msg.content.as_text()}]
                })
            })
            .collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        let mut generation_config = serde_json::json!({});

        if let Some(temp) = request.temperature {
            generation_config["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            generation_config["topP"] =
                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }

        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        gemini_request
    }
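
    /// Parse a Gemini `generateContent` response into the generic
    /// `ChatCompletionResponse`, mapping candidates to choices and
    /// `usageMetadata` to token counts.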
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array().ok_or_else(|| {
            AiLibError::ProviderError("No candidates in Gemini response".to_string())
        })?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates
            .iter()
            .enumerate()
            .map(|(index, candidate)| {
                let content = candidate["content"]["parts"][0]["text"]
                    .as_str()
                    .ok_or_else(|| {
                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
                    })?;

                // Accept a function call either at the candidate root or nested
                // under `content`; if the value does not deserialize directly,
                // fall back to extracting `name` and `arguments` by hand.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
                    candidate
                        .get("content")
                        .and_then(|c| c.get("function_call"))
                        .cloned()
                }) {
                    if let Ok(fc) = serde_json::from_value::<crate::types::function_call::FunctionCall>(
                        fc_val.clone(),
                    ) {
                        function_call = Some(fc);
                    } else if let Some(name) = fc_val
                        .get("name")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string())
                    {
                        let args = fc_val.get("arguments").and_then(|a| {
                            if a.is_string() {
                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
                            } else {
                                Some(a.clone())
                            }
                        });
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name,
                            arguments: args,
                        });
                    }
                }

                // Normalize Gemini finish reasons to OpenAI-style values.
                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                    "STOP" => "stop".to_string(),
                    "MAX_TOKENS" => "length".to_string(),
                    _ => r.to_string(),
                });

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role: Role::Assistant,
                        content: crate::types::common::Content::Text(content.to_string()),
                        function_call,
                    },
                    finish_reason,
                })
            })
            .collect();

        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GeminiAdapter {
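    /// Send a non-streaming `generateContent` request, recording request count
    /// and latency metrics, and convert the response.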
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("gemini.requests", 1).await;
        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;

        let gemini_request = self.convert_to_gemini_request(&request);

        let url = format!("{}/models/{}:generateContent", self.base_url, request.model);

        let headers = HashMap::from([
            ("Content-Type".to_string(), "application/json".to_string()),
            ("x-goog-api-key".to_string(), self.api_key.clone()),
        ]);

        let response_json = self
            .transport
            .post_json(&url, Some(headers), gemini_request)
            .await?;
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_gemini_response(response_json, &request.model)
    }
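
    /// Stream a completion via the server's SSE endpoint
    /// (`streamGenerateContent`); if the transport cannot open a byte stream,
    /// fall back to one blocking completion replayed as simulated chunks.
    ///
    /// A consumption sketch (assumes `futures::StreamExt` is in scope and a
    /// hypothetical `request` value):
    ///
    /// ```ignore
    /// let mut stream = adapter.chat_completion_stream(request).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     // each Ok(chunk) carries an incremental content delta
    /// }
    /// ```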
    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let url = format!(
            "{}/models/{}:streamGenerateContent",
            self.base_url, request.model
        );
        let gemini_request = self.convert_to_gemini_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        headers.insert("x-goog-api-key".to_string(), self.api_key.clone());

        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers), gemini_request)
            .await
        {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            // Parse SSE events off the byte stream and forward each content
            // delta through the channel as a ChatCompletionChunk.
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(data) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| part.get("text"))
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                // Map Gemini finish reasons to
                                                                // OpenAI-style values.
                                                                finish_reason: json
                                                                    .get("candidates")
                                                                    .and_then(|c| c.as_array())
                                                                    .and_then(|arr| arr.first())
                                                                    .and_then(|cand| {
                                                                        cand.get("finishReason")
                                                                            .or_else(|| json.get("finishReason"))
                                                                    })
                                                                    .and_then(|v| v.as_str())
                                                                    .map(|r| match r {
                                                                        "STOP" => "stop".to_string(),
                                                                        "MAX_TOKENS" => "length".to_string(),
                                                                        other => other.to_string(),
                                                                    }),
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk".to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(AiLibError::ProviderError(
                                                            format!("Gemini SSE JSON parse error: {}", e),
                                                        )));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                // Minimal SSE event-boundary scanner used when the
                                // `unified_sse` feature is disabled: an event ends
                                // at a blank line ("\n\n" or "\r\n\r\n").
                                fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
                                    let mut i = 0;
                                    while i + 1 < buffer.len() {
                                        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
                                            return Some(i + 2);
                                        }
                                        if i + 3 < buffer.len()
                                            && buffer[i] == b'\r'
                                            && buffer[i + 1] == b'\n'
                                            && buffer[i + 2] == b'\r'
                                            && buffer[i + 3] == b'\n'
                                        {
                                            return Some(i + 4);
                                        }
                                        i += 1;
                                    }
                                    None
                                }
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(data) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| part.get("text"))
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            // Note: unlike the `unified_sse` path,
                                                            // this branch does not map `finishReason`.
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason: None,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk".to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(AiLibError::ProviderError(
                                                            format!("Gemini SSE JSON parse error: {}", e),
                                                        )));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }
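
        // Fallback path: the transport could not open a byte stream, so run a
        // normal completion and replay its text as simulated streaming chunks.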
        fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
            let mut chunks = Vec::new();
            let mut start = 0;
            while start < text.len() {
                let mut end = std::cmp::min(start + max_len, text.len());
                // Back the cut up to a char boundary so slicing cannot split a
                // multi-byte UTF-8 sequence (a raw byte-index cut would panic
                // on non-ASCII text).
                while !text.is_char_boundary(end) {
                    end -= 1;
                }
                let mut cut = end;
                if end < text.len() {
                    // Prefer to break at the last space inside the window.
                    if let Some(pos) = text[start..end].rfind(' ') {
                        cut = start + pos;
                    }
                }
                if cut == start {
                    cut = end;
                }
                chunks.push(text[start..cut].to_string());
                start = cut;
                // Skip the space we broke on.
                if text[start..].starts_with(' ') {
                    start += 1;
                }
            }
            chunks
        }
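
        // Run the full completion, then feed its text back as ~80-byte chunks
        // with a 50 ms delay between sends to approximate streaming.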
        let finished = self.chat_completion(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }
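
    /// Return a static list of commonly used Gemini models; the live model
    /// listing endpoint is not queried.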
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }
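
    /// Return synthesized metadata for a model id; Gemini does not expose an
    /// OpenAI-style model-permission API, so sensible defaults are filled in.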
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}