use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
    UsageStatus,
};
use futures::stream::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

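/// Adapter for Google's Gemini API.
///
/// Translates the library's provider-agnostic `ChatCompletionRequest` into
/// Gemini's `generateContent`/`streamGenerateContent` wire format and maps
/// the responses back into `ChatCompletionResponse` values.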
pub struct GeminiAdapter {
    #[allow(dead_code)]
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl GeminiAdapter {
    /// Request timeout in seconds, read from `AI_HTTP_TIMEOUT_SECS` (default: 30).
    #[allow(dead_code)]
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

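    /// Builds the default HTTP transport.
    ///
    /// With the `unified_transport` feature enabled this reuses the shared
    /// reqwest client from `client_factory`; otherwise it falls back to a
    /// plain `HttpTransport`.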
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client()
                .map_err(|e| {
                    AiLibError::NetworkError(format!("Failed to build http client: {}", e))
                })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

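    /// Creates an adapter configured from the environment.
    ///
    /// Requires `GEMINI_API_KEY`; the base URL defaults to the public Gemini
    /// endpoint.
    ///
    /// # Example (sketch; assumes `GEMINI_API_KEY` is set)
    ///
    /// ```ignore
    /// let adapter = GeminiAdapter::new()?;
    /// ```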
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

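    /// Creates an adapter with an explicit API key and an optional base URL
    /// override (e.g. for a proxy); `None` keeps the public Gemini endpoint.
    ///
    /// # Example (sketch; the key and URL are placeholders)
    ///
    /// ```ignore
    /// let adapter = GeminiAdapter::new_with_overrides(
    ///     "my-api-key".to_string(),
    ///     Some("https://proxy.example.com/v1beta".to_string()),
    /// )?;
    /// ```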
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

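    /// Creates an adapter with an injected transport, e.g. for testing or
    /// custom connection pooling.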
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

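    /// Like [`Self::with_transport_ref`], but also installs a custom metrics
    /// implementation instead of the no-op default.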
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

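    /// Converts a provider-agnostic request into the Gemini JSON body.
    ///
    /// The output shape is roughly:
    ///
    /// ```ignore
    /// {
    ///   "contents": [{"role": "user", "parts": [{"text": "..."}]}],
    ///   "generationConfig": {"temperature": 0.7, "maxOutputTokens": 256, "topP": 0.9}
    /// }
    /// ```
    ///
    /// `generationConfig` is omitted when no sampling parameters are set.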
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => "user",
                    Role::Assistant => "model",
                    // Gemini has no dedicated system role; fold system
                    // messages into the user turn.
                    Role::System => "user",
                };

                serde_json::json!({
                    "role": role,
                    "parts": [{"text": msg.content.as_text()}]
                })
            })
            .collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        let mut generation_config = serde_json::json!({});

        if let Some(temp) = request.temperature {
            generation_config["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            generation_config["topP"] =
                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }

        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        gemini_request
    }

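    /// Maps a Gemini `generateContent` response onto `ChatCompletionResponse`,
    /// extracting candidate text, any function call, the finish reason, and
    /// token usage from `usageMetadata`.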
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array().ok_or_else(|| {
            AiLibError::ProviderError("No candidates in Gemini response".to_string())
        })?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates
            .iter()
            .enumerate()
            .map(|(index, candidate)| {
                let content = candidate["content"]["parts"][0]["text"]
                    .as_str()
                    .ok_or_else(|| {
                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
                    })?;

                // A function call may appear at the candidate level or nested
                // under `content`; accept either location.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
                    candidate
                        .get("content")
                        .and_then(|c| c.get("function_call"))
                        .cloned()
                }) {
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else if let Some(name) = fc_val
                        .get("name")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string())
                    {
                        // Fall back to manual extraction: `arguments` may be
                        // either a JSON string or an inline JSON object.
                        let args = fc_val.get("arguments").and_then(|a| {
                            if a.is_string() {
                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
                            } else {
                                Some(a.clone())
                            }
                        });
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name,
                            arguments: args,
                        });
                    }
                }

                // Normalize Gemini finish reasons to OpenAI-style values.
                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                    "STOP" => "stop".to_string(),
                    "MAX_TOKENS" => "length".to_string(),
                    _ => r.to_string(),
                });

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role: Role::Assistant,
                        content: crate::types::common::Content::Text(content.to_string()),
                        function_call,
                    },
                    finish_reason,
                })
            })
            .collect();

        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
            usage_status: UsageStatus::Finalized,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GeminiAdapter {
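    /// Sends a non-streaming `generateContent` request and records request
    /// count and latency metrics.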
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("gemini.requests", 1).await;
        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;

        let gemini_request = self.convert_to_gemini_request(&request);

        let url = format!("{}/models/{}:generateContent", self.base_url, request.model);

        let headers = HashMap::from([
            ("Content-Type".to_string(), "application/json".to_string()),
            ("x-goog-api-key".to_string(), self.api_key.clone()),
        ]);

        let response_json = self
            .transport
            .post_json(&url, Some(headers), gemini_request)
            .await?;
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_gemini_response(response_json, &request.model)
    }

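    /// Streams a completion via `streamGenerateContent`.
    ///
    /// If the transport supports streaming, SSE events are parsed into
    /// incremental chunks; otherwise the full response is fetched and
    /// re-emitted as simulated chunks.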
    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let url = format!(
            "{}/models/{}:streamGenerateContent",
            self.base_url, request.model
        );
        let gemini_request = self.convert_to_gemini_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        headers.insert("x-goog-api-key".to_string(), self.api_key.clone());

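        // Preferred path: consume the raw byte stream and split it into SSE
        // events. With `unified_sse` the shared parser finds event boundaries;
        // otherwise a local scanner looks for blank-line separators.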
        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers), gemini_request)
            .await
        {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(data) {
                                                    Ok(json) => {
                                                        // First candidate's first text part, if any.
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| part.get("text"))
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            // The finish reason may sit on the candidate
                                                            // or at the top level; normalize it to
                                                            // OpenAI-style values.
                                                            let finish_reason = json
                                                                .get("candidates")
                                                                .and_then(|c| c.as_array())
                                                                .and_then(|arr| arr.first())
                                                                .and_then(|cand| {
                                                                    cand.get("finishReason")
                                                                        .or_else(|| json.get("finishReason"))
                                                                })
                                                                .and_then(|v| v.as_str())
                                                                .map(|r| match r {
                                                                    "STOP" => "stop".to_string(),
                                                                    "MAX_TOKENS" => "length".to_string(),
                                                                    other => other.to_string(),
                                                                });
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk".to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(AiLibError::ProviderError(
                                                            format!("Gemini SSE JSON parse error: {}", e),
                                                        )));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
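                            // Fallback parser used when `unified_sse` is
                            // disabled: scan for the blank line (\n\n or
                            // \r\n\r\n) terminating each SSE event, then
                            // parse events the same way.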
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
                                    let mut i = 0;
                                    while i + 1 < buffer.len() {
                                        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
                                            return Some(i + 2);
                                        }
                                        if i + 3 < buffer.len()
                                            && buffer[i] == b'\r'
                                            && buffer[i + 1] == b'\n'
                                            && buffer[i + 2] == b'\r'
                                            && buffer[i + 3] == b'\n'
                                        {
                                            return Some(i + 4);
                                        }
                                        i += 1;
                                    }
                                    None
                                }
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(data) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| part.get("text"))
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason: None,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk".to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(AiLibError::ProviderError(
                                                            format!("Gemini SSE JSON parse error: {}", e),
                                                        )));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }
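
        // Fallback when the transport cannot stream: perform a regular
        // completion and replay the text as small simulated chunks.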
        fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
            let mut chunks = Vec::new();
            let mut start = 0;
            let bytes = text.as_bytes();
            while start < bytes.len() {
                let end = std::cmp::min(start + max_len, bytes.len());
                let mut cut = end;
                if end < bytes.len() {
                    // Prefer to cut at a space. Search on bytes rather than a
                    // `&str` slice so a window that lands inside a multi-byte
                    // character cannot panic on a char-boundary check.
                    if let Some(pos) = bytes[start..end].iter().rposition(|&b| b == b' ') {
                        cut = start + pos;
                    }
                }
                if cut == start {
                    cut = end;
                }
                // `from_utf8_lossy` tolerates cuts inside multi-byte characters.
                chunks.push(String::from_utf8_lossy(&bytes[start..cut]).to_string());
                start = cut;
                if start < bytes.len() && bytes[start] == b' ' {
                    start += 1;
                }
            }
            chunks
        }

514
515 let finished = self.chat_completion(request).await?;
516 let text = finished
517 .choices
518 .first()
519 .map(|c| c.message.content.as_text())
520 .unwrap_or_default();
521 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
522 tokio::spawn(async move {
523 let chunks = split_text_into_chunks(&text, 80);
524 for chunk in chunks {
525 let delta = crate::api::ChoiceDelta {
526 index: 0,
527 delta: crate::api::MessageDelta {
528 role: Some(crate::types::Role::Assistant),
529 content: Some(chunk.clone()),
530 },
531 finish_reason: None,
532 };
533 let chunk_obj = ChatCompletionChunk {
534 id: "simulated".to_string(),
535 object: "chat.completion.chunk".to_string(),
536 created: 0,
537 model: finished.model.clone(),
538 choices: vec![delta],
539 };
540 if tx.send(Ok(chunk_obj)).is_err() {
541 return;
542 }
543 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
544 }
545 });
546 let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
547 Ok(Box::new(Box::pin(stream)))
548 }
549
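    /// Returns a static list of commonly available Gemini models; this does
    /// not query the provider's models endpoint.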
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

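    /// Synthesizes model metadata locally with default permissions rather
    /// than querying the provider.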
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}