ai_lib/provider/gemini.rs

use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
use crate::transport::{HttpClient, HttpTransport};
use std::env;
use std::collections::HashMap;
use futures::stream::{self, Stream};

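/// Adapter that exposes Google's Gemini `generateContent` API through the
/// crate's provider-agnostic `ChatApi` trait.
///
/// A minimal usage sketch. It is marked `ignore` because the field set of
/// `ChatCompletionRequest` shown here is inferred from this adapter's own
/// usage (model, messages, temperature, max_tokens, top_p); the real struct
/// may have additional fields or a constructor.
///
/// ```ignore
/// use ai_lib::api::ChatApi;
/// use ai_lib::provider::GeminiAdapter;
/// use ai_lib::types::{ChatCompletionRequest, Message, Role};
///
/// async fn demo() -> Result<(), ai_lib::types::AiLibError> {
///     let adapter = GeminiAdapter::new()?; // requires GEMINI_API_KEY
///     let request = ChatCompletionRequest {
///         model: "gemini-1.5-flash".to_string(),
///         messages: vec![Message {
///             role: Role::User,
///             content: "Hello!".to_string(),
///         }],
///         temperature: Some(0.7),
///         max_tokens: Some(256),
///         top_p: None,
///     };
///     let response = adapter.chat_completion(request).await?;
///     println!("{}", response.choices[0].message.content);
///     Ok(())
/// }
/// ```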
pub struct GeminiAdapter {
    transport: HttpTransport,
    api_key: String,
    base_url: String,
}

impl GeminiAdapter {
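    /// Creates an adapter that reads its API key from the `GEMINI_API_KEY`
    /// environment variable and targets the public v1beta endpoint.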
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = env::var("GEMINI_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "GEMINI_API_KEY environment variable not set".to_string(),
            )
        })?;

        Ok(Self {
            transport: HttpTransport::new(),
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
        })
    }

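    /// Converts an OpenAI-style request into the body expected by Gemini's
    /// `generateContent` endpoint, e.g.:
    ///
    /// ```json
    /// {
    ///   "contents": [
    ///     {"role": "user", "parts": [{"text": "Hello!"}]}
    ///   ],
    ///   "generationConfig": {"temperature": 0.7, "maxOutputTokens": 256}
    /// }
    /// ```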
    fn convert_to_gemini_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let contents: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
            let role = match msg.role {
                Role::User => "user",
                Role::Assistant => "model",
                // Gemini's `contents` has no system role, so system messages
                // are passed through as user turns.
                Role::System => "user",
            };

            serde_json::json!({
                "role": role,
                "parts": [{"text": msg.content}]
            })
        }).collect();

        let mut gemini_request = serde_json::json!({
            "contents": contents
        });

        let mut generation_config = serde_json::json!({});

        // Forward only the parameters the caller set, using Gemini's
        // camelCase field names. `json!` avoids the panic that
        // `Number::from_f64(..).unwrap()` would hit on non-finite floats.
        if let Some(temp) = request.temperature {
            generation_config["temperature"] = serde_json::json!(temp);
        }
        if let Some(max_tokens) = request.max_tokens {
            generation_config["maxOutputTokens"] = serde_json::json!(max_tokens);
        }
        if let Some(top_p) = request.top_p {
            generation_config["topP"] = serde_json::json!(top_p);
        }

        if !generation_config.as_object().unwrap().is_empty() {
            gemini_request["generationConfig"] = generation_config;
        }

        gemini_request
    }

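    /// Maps a raw Gemini response onto the crate's OpenAI-style
    /// `ChatCompletionResponse`: text comes from
    /// `candidates[i].content.parts[0].text`, token counts from
    /// `usageMetadata`.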
    fn parse_gemini_response(
        &self,
        response: serde_json::Value,
        model: &str,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let candidates = response["candidates"].as_array()
            .ok_or_else(|| AiLibError::ProviderError("No candidates in Gemini response".to_string()))?;

        let choices: Result<Vec<Choice>, AiLibError> = candidates.iter().enumerate().map(|(index, candidate)| {
            let content = candidate["content"]["parts"][0]["text"].as_str()
                .ok_or_else(|| AiLibError::ProviderError("No text in Gemini candidate".to_string()))?;

            // Translate Gemini finish reasons into their OpenAI-style names.
            let finish_reason = candidate["finishReason"].as_str().map(|r| match r {
                "STOP" => "stop".to_string(),
                "MAX_TOKENS" => "length".to_string(),
                _ => r.to_string(),
            });

            Ok(Choice {
                index: index as u32,
                message: Message {
                    role: Role::Assistant,
                    content: content.to_string(),
                },
                finish_reason,
            })
        }).collect();

        let usage = Usage {
            prompt_tokens: response["usageMetadata"]["promptTokenCount"].as_u64().unwrap_or(0) as u32,
            completion_tokens: response["usageMetadata"]["candidatesTokenCount"].as_u64().unwrap_or(0) as u32,
            total_tokens: response["usageMetadata"]["totalTokenCount"].as_u64().unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            // Gemini responses carry no completion id, so synthesize one.
            id: format!("gemini-{}", chrono::Utc::now().timestamp()),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices: choices?,
            usage,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GeminiAdapter {
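    /// Performs a non-streaming completion via
    /// `POST {base_url}/models/{model}:generateContent`.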
    async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
        let gemini_request = self.convert_to_gemini_request(&request);

        // Gemini authenticates with a `key` query parameter rather than an
        // Authorization header.
        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let headers = HashMap::from([
            ("Content-Type".to_string(), "application/json".to_string()),
        ]);

        let response: serde_json::Value = self.transport
            .post(&url, Some(headers), &gemini_request)
            .await?;

        self.parse_gemini_response(response, &request.model)
    }

    async fn chat_completion_stream(&self, _request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
        // Streaming is not implemented yet; return an empty stream so callers
        // get a well-typed result instead of an error.
        let stream = stream::empty();
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        // Static list of commonly used models; the live catalog is available
        // from the provider's `/models` endpoint.
        Ok(vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.0-pro".to_string(),
        ])
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        // Gemini exposes no OpenAI-style permission metadata, so return
        // static defaults.
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "google".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}