use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
    UsageStatus,
};
use futures::stream::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;

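/// Adapter for Google's Gemini `generateContent` API.
///
/// Holds the HTTP transport, the API key (sent via the `x-goog-api-key`
/// header), the base URL, and a metrics sink.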
pub struct GeminiAdapter {
    #[allow(dead_code)]
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl GeminiAdapter {
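    /// HTTP timeout in seconds from `AI_HTTP_TIMEOUT_SECS`, defaulting to
    /// 30 when the variable is unset or unparsable.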
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

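    /// Builds the default transport: with `unified_transport`, a shared
    /// reqwest-backed client with the configured timeout; otherwise a plain
    /// `HttpTransport`.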
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
            })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

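    /// Creates an adapter from the `GEMINI_API_KEY` environment variable,
    /// using the default base URL and transport.
    ///
    /// Minimal usage sketch (request construction elided, since it depends
    /// on how this crate's `ChatCompletionRequest` is built):
    ///
    /// ```ignore
    /// let adapter = GeminiAdapter::new()?;
    /// // let request: ChatCompletionRequest = /* build a request */;
    /// let response = adapter.chat_completion(request).await?;
    /// ```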
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

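    /// Creates an adapter with an explicit API key and optional base URL
    /// override (useful for proxies or regional endpoints).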
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: base_url
                .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string()),
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

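    /// Creates an adapter around a caller-supplied transport, e.g. a mock
    /// in tests or a shared connection pool.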
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

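    /// Like [`Self::with_transport_ref`], but with an injected metrics sink.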
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        Ok(Self {
            transport,
            api_key,
            base_url,
            metrics,
        })
    }

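    /// Converts a generic `ChatCompletionRequest` into Gemini's
    /// `generateContent` JSON: messages become `contents` entries with
    /// `user`/`model` roles, and sampling options move into
    /// `generationConfig`.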
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::User => "user",
                    Role::Assistant => "model",
                    // Gemini has no system role; fold system messages into
                    // the user turn.
                    Role::System => "user",
                };

                serde_json::json!({
                    "role": role,
                    "parts": [{"text": msg.content.as_text()}]
                })
            })
            .collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        let mut generation_config = serde_json::json!({});

        // `Number::from_f64` returns None for NaN/infinity; skip such values
        // instead of panicking on unwrap.
        if let Some(temp) = request.temperature {
            if let Some(n) = serde_json::Number::from_f64(temp.into()) {
                generation_config["temperature"] = serde_json::Value::Number(n);
            }
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            if let Some(n) = serde_json::Number::from_f64(top_p.into()) {
                generation_config["topP"] = serde_json::Value::Number(n);
            }
        }

        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        gemini_request
    }

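    /// Maps a raw Gemini response onto the OpenAI-style
    /// `ChatCompletionResponse`: one `Choice` per candidate, with
    /// `finishReason` values and `usageMetadata` counters translated.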
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array().ok_or_else(|| {
            AiLibError::ProviderError("No candidates in Gemini response".to_string())
        })?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates
            .iter()
            .enumerate()
            .map(|(index, candidate)| {
                let content = candidate["content"]["parts"][0]["text"]
                    .as_str()
                    .ok_or_else(|| {
                        AiLibError::ProviderError("No text in Gemini candidate".to_string())
                    })?;

                // A function call may appear at the candidate root or under
                // `content`; try direct deserialization first, then fall back
                // to manual extraction of `name`/`arguments`.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = candidate.get("function_call").cloned().or_else(|| {
                    candidate
                        .get("content")
                        .and_then(|c| c.get("function_call"))
                        .cloned()
                }) {
                    if let Ok(fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        function_call = Some(fc);
                    } else if let Some(name) = fc_val
                        .get("name")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string())
                    {
                        // `arguments` may arrive as a JSON string or as an
                        // inline object.
                        let args = fc_val.get("arguments").and_then(|a| {
                            if a.is_string() {
                                serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
                            } else {
                                Some(a.clone())
                            }
                        });
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name,
                            arguments: args,
                        });
                    }
                }

                let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                    "STOP" => "stop".to_string(),
                    "MAX_TOKENS" => "length".to_string(),
                    _ => r.to_string(),
                });

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role: Role::Assistant,
                        content: crate::types::common::Content::Text(content.to_string()),
                        function_call,
                    },
                    finish_reason,
                })
            })
            .collect();

        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"]
                .as_u64()
                .unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
            usage_status: UsageStatus::Finalized,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GeminiAdapter {
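    /// Non-streaming completion via `:generateContent`, with request count
    /// and latency metrics recorded around the call.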
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("gemini.requests", 1).await;
        let timer = self.metrics.start_timer("gemini.request_duration_ms").await;

        let gemini_request = self.convert_to_gemini_request(&request);

        let url = format!("{}/models/{}:generateContent", self.base_url, request.model);

        let headers = HashMap::from([
            ("Content-Type".to_string(), "application/json".to_string()),
            ("x-goog-api-key".to_string(), self.api_key.clone()),
        ]);

        let response_json = self
            .transport
            .post_json(&url, Some(headers), gemini_request)
            .await?;
        if let Some(t) = timer {
            t.stop();
        }
        self.parse_gemini_response(response_json, &request.model)
    }

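    /// Streaming completion via `:streamGenerateContent`. If the transport
    /// exposes a byte stream, SSE events are parsed into incremental
    /// chunks; otherwise the adapter falls back to one blocking completion
    /// replayed as simulated chunks.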
    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let url = format!(
            "{}/models/{}:streamGenerateContent",
            self.base_url, request.model
        );
        let gemini_request = self.convert_to_gemini_request(&request);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "text/event-stream".to_string());
        headers.insert("x-goog-api-key".to_string(), self.api_key.clone());

        if let Ok(mut byte_stream) = self
            .transport
            .post_stream(&url, Some(headers), gemini_request)
            .await
        {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                // Drain complete SSE events from the buffer as they
                                // arrive; partial events stay buffered for the next read.
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes =
                                        buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(data)
                                                {
                                                    Ok(json) => {
                                                        // First candidate, first text part.
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| part.get("text"))
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let finish_reason = json
                                                                .get("candidates")
                                                                .and_then(|c| c.as_array())
                                                                .and_then(|arr| arr.first())
                                                                .and_then(|cand| {
                                                                    cand.get("finishReason").or_else(
                                                                        || json.get("finishReason"),
                                                                    )
                                                                })
                                                                .and_then(|v| v.as_str())
                                                                .map(|r| match r {
                                                                    "STOP" => "stop".to_string(),
                                                                    "MAX_TOKENS" => {
                                                                        "length".to_string()
                                                                    }
                                                                    other => other.to_string(),
                                                                });
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(
                                                                        crate::types::Role::Assistant,
                                                                    ),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                // Minimal boundary scanner used when the unified SSE
                                // parser is unavailable: an event ends at a blank line
                                // ("\n\n" or "\r\n\r\n").
                                fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
                                    let mut i = 0;
                                    while i + 1 < buffer.len() {
                                        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
                                            return Some(i + 2);
                                        }
                                        if i + 3 < buffer.len()
                                            && buffer[i] == b'\r'
                                            && buffer[i + 1] == b'\n'
                                            && buffer[i + 2] == b'\r'
                                            && buffer[i + 3] == b'\n'
                                        {
                                            return Some(i + 4);
                                        }
                                        i += 1;
                                    }
                                    None
                                }
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes =
                                        buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        for line in event_text.lines() {
                                            let line = line.trim();
                                            if let Some(data) = line.strip_prefix("data: ") {
                                                if data.is_empty() {
                                                    continue;
                                                }
                                                if data == "[DONE]" {
                                                    return;
                                                }
                                                match serde_json::from_str::<serde_json::Value>(data)
                                                {
                                                    Ok(json) => {
                                                        let text = json
                                                            .get("candidates")
                                                            .and_then(|c| c.as_array())
                                                            .and_then(|arr| arr.first())
                                                            .and_then(|cand| {
                                                                cand.get("content")
                                                                    .and_then(|c| c.get("parts"))
                                                                    .and_then(|p| p.as_array())
                                                                    .and_then(|parts| parts.first())
                                                                    .and_then(|part| part.get("text"))
                                                                    .and_then(|t| t.as_str())
                                                            })
                                                            .map(|s| s.to_string());
                                                        if let Some(tdelta) = text {
                                                            let delta = crate::api::ChoiceDelta {
                                                                index: 0,
                                                                delta: crate::api::MessageDelta {
                                                                    role: Some(
                                                                        crate::types::Role::Assistant,
                                                                    ),
                                                                    content: Some(tdelta),
                                                                },
                                                                finish_reason: None,
                                                            };
                                                            let chunk_obj = ChatCompletionChunk {
                                                                id: json
                                                                    .get("responseId")
                                                                    .and_then(|v| v.as_str())
                                                                    .unwrap_or("")
                                                                    .to_string(),
                                                                object: "chat.completion.chunk"
                                                                    .to_string(),
                                                                created: 0,
                                                                model: request.model.clone(),
                                                                choices: vec![delta],
                                                            };
                                                            if tx.send(Ok(chunk_obj)).is_err() {
                                                                return;
                                                            }
                                                        }
                                                    }
                                                    Err(e) => {
                                                        let _ = tx.send(Err(
                                                            AiLibError::ProviderError(format!(
                                                                "Gemini SSE JSON parse error: {}",
                                                                e
                                                            )),
                                                        ));
                                                        return;
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

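        // No streaming transport available: fall back to a single completion
        // whose text is replayed as small timed chunks.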
        fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
            let mut chunks = Vec::new();
            let mut start = 0;
            let bytes = text.as_bytes();
            while start < bytes.len() {
                let end = std::cmp::min(start + max_len, bytes.len());
                let mut cut = end;
                // Prefer breaking at the last space in the window; scan bytes
                // so a cut inside a multi-byte UTF-8 sequence cannot panic.
                if end < bytes.len() {
                    if let Some(pos) = bytes[start..end].iter().rposition(|&b| b == b' ') {
                        cut = start + pos;
                    }
                }
                if cut == start {
                    cut = end;
                }
                chunks.push(String::from_utf8_lossy(&bytes[start..cut]).to_string());
                start = cut;
                if start < bytes.len() && bytes[start] == b' ' {
                    start += 1;
                }
            }
            chunks
        }

        let finished = self.chat_completion(request).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };
                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });
        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

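    /// Returns a static list of commonly used Gemini models; this adapter
    /// does not query a model-listing endpoint.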
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

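    /// Synthesizes `ModelInfo` with permissive defaults instead of querying
    /// the API.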
    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}