use crate::api::{ChatCompletionChunk, ChatProvider, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
    UsageStatus,
};
use futures::stream::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

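/// Adapter for Google's Gemini API (`generateContent` / `streamGenerateContent`).
///
/// Holds the HTTP transport, API key, base URL, and a metrics sink.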
pub struct GeminiAdapter {
    #[allow(dead_code)]
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl GeminiAdapter {
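    /// Read the request timeout from `AI_HTTP_TIMEOUT_SECS`, defaulting to 30 seconds.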
    #[allow(dead_code)]
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

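    /// Build the default HTTP transport: a shared client with the configured
    /// timeout when `unified_transport` is enabled, otherwise a plain transport.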
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
            })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            Ok(t.boxed())
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            Ok(t.boxed())
        }
    }

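    /// Create an adapter using the `GEMINI_API_KEY` environment variable and the
    /// default Gemini base URL.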
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

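    /// Create an adapter with an explicit API key and an optional base URL override.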
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

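    /// Create an adapter with an injected transport (useful for testing or custom clients).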
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

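    /// Like [`Self::with_transport_ref`], but also installs a custom metrics sink.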
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

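    /// Convert a provider-agnostic request into the Gemini JSON body: messages
    /// become `contents` (system messages are sent as user turns, since the
    /// `contents` role field only accepts user/model here), and sampling
    /// parameters map onto `generationConfig`.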
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => "user",
                    Role::Assistant => "model",
                    Role::System => "user",
                };

                serde_json::json!({
                    "role": role,
                    "parts": [{"text": msg.content.as_text()}]
                })
            })
            .collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        let mut generation_config = serde_json::json!({});

        // `Number::from_f64` returns None for NaN/infinite values; skip those
        // instead of panicking on `unwrap`.
        if let Some(temp) = request.temperature {
            if let Some(n) = serde_json::Number::from_f64(temp.into()) {
                generation_config["temperature"] = serde_json::Value::Number(n);
            }
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            if let Some(n) = serde_json::Number::from_f64(top_p.into()) {
                generation_config["topP"] = serde_json::Value::Number(n);
            }
        }

        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        request.apply_extensions(&mut gemini_request);

        gemini_request
    }

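    /// Parse a Gemini `generateContent` response into a `ChatCompletionResponse`,
    /// extracting candidate text, any function call, the finish reason, and usage
    /// metadata.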
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array().ok_or_else(|| {
            AiLibError::ProviderError("No candidates in Gemini response".to_string())
        })?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates
            .iter()
            .enumerate()
            .map(|(index, candidate)| {
                let content = candidate["content"]["parts"][0]["text"]
                    .as_str()
                    .ok_or_else(|| {
                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
                    })?;

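                // A function call may appear at the candidate root or nested under
                // `content`. Try typed deserialization first, then fall back to
                // manual extraction (arguments may arrive as a JSON string or as an
                // already-parsed object).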
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
                    candidate
                        .get("content")
                        .and_then(|c| c.get("function_call"))
                        .cloned()
                }) {
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else if let Some(name) = fc_val
                        .get("name")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string())
                    {
                        let args = fc_val.get("arguments").and_then(|a| {
                            if a.is_string() {
                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
                                    .ok()
                            } else {
                                Some(a.clone())
                            }
                        });
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name,
                            arguments: args,
                        });
                    }
                }

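                // Normalize Gemini finish reasons to OpenAI-style values.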
                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                    "STOP" => "stop".to_string(),
                    "MAX_TOKENS" => "length".to_string(),
                    _ => r.to_string(),
                });

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role: Role::Assistant,
                        content: crate::types::common::Content::Text(content.to_string()),
                        function_call,
                    },
                    finish_reason,
                })
            })
            .collect();

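        // Token counts come from the response's `usageMetadata` block.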
        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
            usage_status: UsageStatus::Finalized,
        })
    }
}

#[async_trait::async_trait]
impl ChatProvider for GeminiAdapter {
    fn name(&self) -> &str {
        "Gemini"
    }

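    /// Send a non-streaming `generateContent` request, recording request count
    /// and latency metrics around the call.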
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("gemini.requests", 1).await;
        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;

        let gemini_request = self.convert_to_gemini_request(&request);

        let url = format!("{}/models/{}:generateContent", self.base_url, request.model);

        let headers = HashMap::from([
            ("Content-Type".to_string(), "application/json".to_string()),
            ("x-goog-api-key".to_string(), self.api_key.clone()),
        ]);

        let response_json = self
            .transport
            .post_json(&url, Some(headers), gemini_request)
            .await
            .map_err(|e| e.with_context(&format!("Gemini chat request to {}", url)))?;
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_gemini_response(response_json, &request.model)
    }

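    /// Stream a completion via `streamGenerateContent`. When the transport
    /// provides a byte stream, SSE events are parsed incrementally; otherwise the
    /// adapter falls back to one blocking request replayed as simulated chunks.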
    async fn stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let url = format!(
            "{}/models/{}:streamGenerateContent",
            self.base_url, request.model
        );
        let gemini_request = self.convert_to_gemini_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        headers.insert("x-goog-api-key".to_string(), self.api_key.clone());

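        // Preferred path: parse the provider's SSE byte stream on a background
        // task and forward chunks through an unbounded channel.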
        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers), gemini_request)
            .await
        {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(
                                                    data,
                                                ) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| {
                                                                        part.get("text")
                                                                    })
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason: json
                                                                    .get("candidates")
                                                                    .and_then(|c| c.as_array())
                                                                    .and_then(|arr| arr.first())
                                                                    .and_then(|cand| {
                                                                        cand.get("finishReason")
                                                                            .or_else(|| json.get("finishReason"))
                                                                    })
                                                                    .and_then(|v| v.as_str())
                                                                    .map(|r| match r {
                                                                        "STOP" => "stop".to_string(),
                                                                        "MAX_TOKENS" => "length".to_string(),
                                                                        other => other.to_string(),
                                                                    }),
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
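                                // Without the unified SSE parser, find event
                                // boundaries (a blank line: "\n\n" or "\r\n\r\n") by hand.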
                                fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
                                    let mut i = 0;
                                    while i + 1 < buffer.len() {
                                        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
                                            return Some(i + 2);
                                        }
                                        if i + 3 < buffer.len()
                                            && buffer[i] == b'\r'
                                            && buffer[i + 1] == b'\n'
                                            && buffer[i + 2] == b'\r'
                                            && buffer[i + 3] == b'\n'
                                        {
                                            return Some(i + 4);
                                        }
                                        i += 1;
                                    }
                                    None
                                }
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(
                                                    data,
                                                ) {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| {
                                                                        part.get("text")
                                                                    })
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(crate::types::Role::Assistant),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason: None,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

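        // Fallback path: the transport offers no byte stream. Run one blocking
        // request and replay its text as simulated streaming chunks. The helper
        // splits on byte length, preferring to break at a space.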
        fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
            let mut chunks = Vec::new();
            let mut start = 0;
            let bytes = text.as_bytes();
            while start < bytes.len() {
                let end = std::cmp::min(start + max_len, bytes.len());
                let mut cut = end;
                if end < bytes.len() {
                    // Search the raw bytes for a space so we never slice the &str at
                    // a non-char boundary (`text[start..end]` would panic on
                    // multi-byte UTF-8).
                    if let Some(pos) = bytes[start..end].iter().rposition(|&b| b == b' ') {
                        cut = start + pos;
                    }
                }
                if cut == start {
                    cut = end;
                }
                chunks.push(String::from_utf8_lossy(&bytes[start..cut]).to_string());
                start = cut;
                if start < bytes.len() && bytes[start] == b' ' {
                    start += 1;
                }
            }
            chunks
        }

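        // Emit ~80-byte chunks through the channel with a short delay so the
        // consumer observes incremental output.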
        let finished = self.chat(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

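    /// Return a static list of common Gemini model names; the models endpoint is
    /// not queried.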
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

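    /// Return static metadata for `model_id`; the permission entry is a default
    /// placeholder rather than data fetched from the API.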
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}