1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
3use crate::transport::{HttpTransport, DynHttpTransportRef};
4use std::env;
5use std::collections::HashMap;
6use futures::stream::{self, Stream};
7
8pub struct GeminiAdapter {
18 transport: DynHttpTransportRef,
19 api_key: String,
20 base_url: String,
21}
22
23impl GeminiAdapter {
24 pub fn new() -> Result<Self, AiLibError> {
25 let api_key = env::var("GEMINI_API_KEY")
26 .map_err(|_| AiLibError::AuthenticationError(
27 "GEMINI_API_KEY environment variable not set".to_string()
28 ))?;
29
30 Ok(Self {
31 transport: HttpTransport::new().boxed(),
32 api_key,
33 base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
34 })
35 }
36
37 pub fn with_transport_ref(transport: DynHttpTransportRef, api_key: String, base_url: String) -> Result<Self, AiLibError> {
39 Ok(Self { transport, api_key, base_url })
40 }
41
42 fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
44 let contents: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
45 let role = match msg.role {
46 Role::User => "user",
47 Role::Assistant => "model", Role::System => "user", };
50
51 serde_json::json!({
52 "role": role,
53 "parts": [{"text": msg.content}]
54 })
55 }).collect();
56
57 let mut gemini_request = serde_json::json!({
58 "contents": contents
59 });
60
61 let mut generation_config = serde_json::json!({});
63
64 if let Some(temp) = request.temperature {
65 generation_config["temperature"] = serde_json::Value::Number(
66 serde_json::Number::from_f64(temp.into()).unwrap()
67 );
68 }
69 if let Some(max_tokens) = request.max_tokens {
70 generation_config["maxOutputTokens"] = serde_json::Value::Number(
71 serde_json::Number::from(max_tokens)
72 );
73 }
74 if let Some(top_p) = request.top_p {
75 generation_config["topP"] = serde_json::Value::Number(
76 serde_json::Number::from_f64(top_p.into()).unwrap()
77 );
78 }
79
80 if !generation_config.as_object().unwrap().is_empty() {
81 gemini_request["generationConfig"] = generation_config;
82 }
83
84 gemini_request
85 }
86
87 fn parse_gemini_response(&self, response: serde_json::Value, model: &str) -> Result<ChatCompletionResponse, AiLibError> {
89 let candidates = response["candidates"].as_array()
90 .ok_or_else(|| AiLibError::ProviderError("No candidates in Gemini response".to_string()))?;
91
92 let choices: Result<Vec<Choice>, AiLibError> = candidates.iter().enumerate().map(|(index, candidate)| {
93 let content = candidate["content"]["parts"][0]["text"].as_str()
94 .ok_or_else(|| AiLibError::ProviderError("No text in Gemini candidate".to_string()))?;
95
96 let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
97 "STOP" => "stop".to_string(),
98 "MAX_TOKENS" => "length".to_string(),
99 _ => r.to_string(),
100 });
101
102 Ok(Choice {
103 index: index as u32,
104 message: Message {
105 role: Role::Assistant,
106 content: content.to_string(),
107 },
108 finish_reason,
109 })
110 }).collect();
111
112 let usage = Usage {
113 prompt_tokens: response["usageMetadata"]["promptTokenCount"].as_u64().unwrap_or(0) as u32,
114 completion_tokens: response["usageMetadata"]["candidatesTokenCount"].as_u64().unwrap_or(0) as u32,
115 total_tokens: response["usageMetadata"]["totalTokenCount"].as_u64().unwrap_or(0) as u32,
116 };
117
118 Ok(ChatCompletionResponse {
119 id: format!("gemini-{}", chrono::Utc::now().timestamp()),
120 object: "chat.completion".to_string(),
121 created: chrono::Utc::now().timestamp() as u64,
122 model: model.to_string(),
123 choices: choices?,
124 usage,
125 })
126 }
127}
128
129#[async_trait::async_trait]
130impl ChatApi for GeminiAdapter {
131 async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
132 let gemini_request = self.convert_to_gemini_request(&request);
133
134 let url = format!(
136 "{}/models/{}:generateContent?key={}",
137 self.base_url, request.model, self.api_key
138 );
139
140 let headers = HashMap::from([
141 ("Content-Type".to_string(), "application/json".to_string()),
142 ]);
143
144 let response: serde_json::Value = self.transport
145 .post_json(&url, Some(headers), gemini_request)
146 .await?;
147
148 self.parse_gemini_response(response, &request.model)
149 }
150
151 async fn chat_completion_stream(&self, _request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
152 let stream = stream::empty();
154 Ok(Box::new(Box::pin(stream)))
155 }
156
157 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
158 Ok(vec![
160 "gemini-1.5-pro".to_string(),
161 "gemini-1.5-flash".to_string(),
162 "gemini-1.0-pro".to_string(),
163 ])
164 }
165
166 async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
167 Ok(ModelInfo {
168 id: model_id.to_string(),
169 object: "model".to_string(),
170 created: 0,
171 owned_by: "google".to_string(),
172 permission: vec![ModelPermission {
173 id: "default".to_string(),
174 object: "model_permission".to_string(),
175 created: 0,
176 allow_create_engine: false,
177 allow_sampling: true,
178 allow_logprobs: false,
179 allow_search_indices: false,
180 allow_view: true,
181 allow_fine_tuning: false,
182 organization: "*".to_string(),
183 group: None,
184 is_blocking: false,
185 }],
186 })
187 }
188}