1use ceres_core::HttpConfig;
35use ceres_core::error::{AppError, GeminiErrorDetails, GeminiErrorKind};
36use reqwest::Client;
37use serde::{Deserialize, Serialize};
38
/// Client for the Gemini `embedContent` REST endpoint.
///
/// Holds the HTTP client used for requests and the API key sent with each one.
/// Construct it with `GeminiClient::new`.
#[derive(Clone)]
pub struct GeminiClient {
    /// HTTP client; built in `GeminiClient::new` with the timeout from `HttpConfig::default()`.
    client: Client,
    /// API key sent in the `x-goog-api-key` header on every request.
    api_key: String,
}
67
/// JSON request body for the `embedContent` call made by `GeminiClient::get_embeddings`.
#[derive(Serialize)]
struct EmbeddingRequest {
    /// Model resource name, e.g. `"models/gemini-embedding-001"`.
    model: String,
    /// The text to embed, wrapped in Gemini's content/parts structure.
    content: Content,
    /// Requested length of the returned vector. The key is omitted from the JSON
    /// entirely when this is `None`.
    ///
    /// NOTE(review): this is serialized as snake_case `output_dimensionality`, while the
    /// Gemini REST reference spells it `outputDimensionality`. It presumably works
    /// because Google's proto3 JSON parsing accepts both spellings. Confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_dimensionality: Option<usize>,
}
78
/// Gemini content object: an ordered list of parts.
///
/// `GeminiClient::get_embeddings` always sends exactly one text part.
#[derive(Serialize)]
struct Content {
    parts: Vec<Part>,
}
83
/// A single text part inside a [`Content`].
#[derive(Serialize)]
struct Part {
    /// Text to embed (newlines already replaced with spaces by `get_embeddings`).
    text: String,
}
88
/// Body of a successful `embedContent` response; only `embedding` is read.
#[derive(Deserialize)]
struct EmbeddingResponse {
    embedding: EmbeddingData,
}
94
/// The embedding payload of an [`EmbeddingResponse`].
#[derive(Deserialize)]
struct EmbeddingData {
    /// The embedding vector; `get_embeddings` returns it to the caller unchanged.
    values: Vec<f32>,
}
99
/// Error envelope (`{"error": {...}}`) parsed from non-2xx response bodies.
///
/// Parsing is best effort: if the body does not match this shape,
/// `get_embeddings` falls back to an `"HTTP <code>: <body>"` message.
#[derive(Deserialize)]
struct GeminiError {
    error: GeminiErrorDetail,
}
105
/// Inner detail object of a [`GeminiError`].
#[derive(Deserialize)]
struct GeminiErrorDetail {
    /// Human-readable message; used as the error message and fed to `classify_gemini_error`.
    message: String,
    /// Optional status string from the API. Nothing in this file's non-test code
    /// reads it, hence the `dead_code` allowance.
    #[allow(dead_code)]
    status: Option<String>,
}
112
113fn classify_gemini_error(status_code: u16, message: &str) -> GeminiErrorKind {
115 match status_code {
116 401 => GeminiErrorKind::Authentication,
117 429 => {
118 if message.contains("insufficient_quota") || message.contains("quota") {
120 GeminiErrorKind::QuotaExceeded
121 } else {
122 GeminiErrorKind::RateLimit
123 }
124 }
125 500..=599 => GeminiErrorKind::ServerError,
126 _ => {
127 if message.contains("API key") || message.contains("Unauthorized") {
129 GeminiErrorKind::Authentication
130 } else if message.contains("rate") {
131 GeminiErrorKind::RateLimit
132 } else if message.contains("quota") {
133 GeminiErrorKind::QuotaExceeded
134 } else {
135 GeminiErrorKind::Unknown
136 }
137 }
138 }
139}
140
141impl GeminiClient {
142 pub fn new(api_key: &str) -> Result<Self, AppError> {
144 let http_config = HttpConfig::default();
145 let client = Client::builder()
146 .timeout(http_config.timeout)
147 .build()
148 .map_err(|e| AppError::ClientError(e.to_string()))?;
149
150 Ok(Self {
151 client,
152 api_key: api_key.to_string(),
153 })
154 }
155
156 pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
174 let sanitized_text = text.replace('\n', " ");
176
177 let url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-001:embedContent";
178
179 let request_body = EmbeddingRequest {
180 model: "models/gemini-embedding-001".to_string(),
181 content: Content {
182 parts: vec![Part {
183 text: sanitized_text,
184 }],
185 },
186 output_dimensionality: Some(768),
187 };
188
189 let response = self
190 .client
191 .post(url)
192 .header("x-goog-api-key", self.api_key.clone())
193 .json(&request_body)
194 .send()
195 .await
196 .map_err(|e| {
197 if e.is_timeout() {
198 AppError::Timeout(30)
199 } else if e.is_connect() {
200 AppError::GeminiError(GeminiErrorDetails::new(
201 GeminiErrorKind::NetworkError,
202 format!("Connection failed: {}", e),
203 0, ))
205 } else {
206 AppError::ClientError(e.to_string())
207 }
208 })?;
209
210 let status = response.status();
211
212 if !status.is_success() {
213 let status_code = status.as_u16();
214 let error_text = response.text().await.unwrap_or_default();
215
216 let message = if let Ok(gemini_error) = serde_json::from_str::<GeminiError>(&error_text)
218 {
219 gemini_error.error.message
220 } else {
221 format!("HTTP {}: {}", status_code, error_text)
222 };
223
224 let kind = classify_gemini_error(status_code, &message);
226
227 return Err(AppError::GeminiError(GeminiErrorDetails::new(
229 kind,
230 message,
231 status_code,
232 )));
233 }
234
235 let embedding_response: EmbeddingResponse = response
236 .json()
237 .await
238 .map_err(|e| AppError::ClientError(format!("Failed to parse response: {}", e)))?;
239
240 Ok(embedding_response.embedding.values)
241 }
242}
243
244impl ceres_core::traits::EmbeddingProvider for GeminiClient {
249 fn name(&self) -> &'static str {
250 "gemini"
251 }
252
253 fn dimension(&self) -> usize {
254 768
256 }
257
258 async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
259 self.get_embeddings(text).await
260 }
261
262 }
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_client() {
        let client = GeminiClient::new("test-api-key");
        assert!(client.is_ok());
    }

    // NOTE: this checks a copy of the newline handling in
    // `GeminiClient::get_embeddings`, not the method itself (which needs a live API
    // call). Keep the two in sync.
    #[test]
    fn test_text_sanitization() {
        let text_with_newlines = "Line 1\nLine 2\nLine 3";
        let sanitized = text_with_newlines.replace('\n', " ");
        assert_eq!(sanitized, "Line 1 Line 2 Line 3");
    }

    #[test]
    fn test_request_serialization() {
        let request = EmbeddingRequest {
            model: "models/gemini-embedding-001".to_string(),
            content: Content {
                parts: vec![Part {
                    text: "Hello world".to_string(),
                }],
            },
            output_dimensionality: Some(768),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("gemini-embedding-001"));
        assert!(json.contains("Hello world"));
        assert!(json.contains("output_dimensionality"));
    }

    /// `skip_serializing_if` must drop the key entirely rather than emit `null`.
    #[test]
    fn test_request_serialization_omits_missing_dimensionality() {
        let request = EmbeddingRequest {
            model: "models/gemini-embedding-001".to_string(),
            content: Content {
                parts: vec![Part {
                    text: "Hello world".to_string(),
                }],
            },
            output_dimensionality: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("output_dimensionality"));
        assert!(!json.contains("null"));
    }

    #[test]
    fn test_response_deserialization() {
        let body = r#"{"embedding":{"values":[0.5,-1.0,2.25]}}"#;
        let parsed: EmbeddingResponse = serde_json::from_str(body).unwrap();
        let expected: Vec<f32> = vec![0.5, -1.0, 2.25];
        assert_eq!(parsed.embedding.values, expected);
    }

    #[test]
    fn test_error_deserialization() {
        let body = r#"{"error":{"code":400,"message":"API key not valid. Please pass a valid API key.","status":"INVALID_ARGUMENT"}}"#;
        let parsed: GeminiError = serde_json::from_str(body).unwrap();
        assert_eq!(
            parsed.error.message,
            "API key not valid. Please pass a valid API key."
        );
        assert_eq!(parsed.error.status.as_deref(), Some("INVALID_ARGUMENT"));
    }

    /// `status` is optional; its absence must not make the envelope unparseable.
    #[test]
    fn test_error_deserialization_without_status() {
        let parsed: GeminiError = serde_json::from_str(r#"{"error":{"message":"boom"}}"#).unwrap();
        assert_eq!(parsed.error.message, "boom");
        assert!(parsed.error.status.is_none());
    }

    #[test]
    fn test_classify_gemini_error_auth() {
        let kind = classify_gemini_error(401, "Invalid API key");
        assert_eq!(kind, GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_auth_from_message() {
        let kind = classify_gemini_error(400, "API key not valid");
        assert_eq!(kind, GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_unauthorized_from_message() {
        let kind = classify_gemini_error(403, "Unauthorized");
        assert_eq!(kind, GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_rate_limit() {
        let kind = classify_gemini_error(429, "Rate limit exceeded");
        assert_eq!(kind, GeminiErrorKind::RateLimit);
    }

    #[test]
    fn test_classify_gemini_error_quota() {
        let kind = classify_gemini_error(429, "insufficient_quota");
        assert_eq!(kind, GeminiErrorKind::QuotaExceeded);
    }

    #[test]
    fn test_classify_gemini_error_quota_from_message() {
        let kind = classify_gemini_error(400, "quota exceeded for this project");
        assert_eq!(kind, GeminiErrorKind::QuotaExceeded);
    }

    #[test]
    fn test_classify_gemini_error_server() {
        let kind = classify_gemini_error(500, "Internal server error");
        assert_eq!(kind, GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_server_503() {
        let kind = classify_gemini_error(503, "Service unavailable");
        assert_eq!(kind, GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_unknown() {
        let kind = classify_gemini_error(400, "Bad request");
        assert_eq!(kind, GeminiErrorKind::Unknown);
    }
}