1use ceres_core::HttpConfig;
35use ceres_core::error::{AppError, GeminiErrorDetails, GeminiErrorKind};
36use reqwest::Client;
37use serde::{Deserialize, Serialize};
38
39#[derive(Clone)]
63pub struct GeminiClient {
64 client: Client,
65 api_key: String,
66}
67
68#[derive(Serialize)]
70struct EmbeddingRequest {
71 model: String,
72 content: Content,
73}
74
75#[derive(Serialize)]
76struct Content {
77 parts: Vec<Part>,
78}
79
80#[derive(Serialize)]
81struct Part {
82 text: String,
83}
84
85#[derive(Deserialize)]
87struct EmbeddingResponse {
88 embedding: EmbeddingData,
89}
90
91#[derive(Deserialize)]
92struct EmbeddingData {
93 values: Vec<f32>,
94}
95
96#[derive(Deserialize)]
98struct GeminiError {
99 error: GeminiErrorDetail,
100}
101
102#[derive(Deserialize)]
103struct GeminiErrorDetail {
104 message: String,
105 #[allow(dead_code)]
106 status: Option<String>,
107}
108
109fn classify_gemini_error(status_code: u16, message: &str) -> GeminiErrorKind {
111 match status_code {
112 401 => GeminiErrorKind::Authentication,
113 429 => {
114 if message.contains("insufficient_quota") || message.contains("quota") {
116 GeminiErrorKind::QuotaExceeded
117 } else {
118 GeminiErrorKind::RateLimit
119 }
120 }
121 500..=599 => GeminiErrorKind::ServerError,
122 _ => {
123 if message.contains("API key") || message.contains("Unauthorized") {
125 GeminiErrorKind::Authentication
126 } else if message.contains("rate") {
127 GeminiErrorKind::RateLimit
128 } else if message.contains("quota") {
129 GeminiErrorKind::QuotaExceeded
130 } else {
131 GeminiErrorKind::Unknown
132 }
133 }
134 }
135}
136
137impl GeminiClient {
138 pub fn new(api_key: &str) -> Result<Self, AppError> {
140 let http_config = HttpConfig::default();
141 let client = Client::builder()
142 .timeout(http_config.timeout)
143 .build()
144 .map_err(|e| AppError::ClientError(e.to_string()))?;
145
146 Ok(Self {
147 client,
148 api_key: api_key.to_string(),
149 })
150 }
151
152 pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
170 let sanitized_text = text.replace('\n', " ");
172
173 let url = "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent";
176
177 let request_body = EmbeddingRequest {
182 model: "models/text-embedding-004".to_string(),
183 content: Content {
184 parts: vec![Part {
185 text: sanitized_text,
186 }],
187 },
188 };
189
190 let response = self
191 .client
192 .post(url)
193 .header("x-goog-api-key", self.api_key.clone())
194 .json(&request_body)
195 .send()
196 .await
197 .map_err(|e| {
198 if e.is_timeout() {
199 AppError::Timeout(30)
200 } else if e.is_connect() {
201 AppError::GeminiError(GeminiErrorDetails::new(
202 GeminiErrorKind::NetworkError,
203 format!("Connection failed: {}", e),
204 0, ))
206 } else {
207 AppError::ClientError(e.to_string())
208 }
209 })?;
210
211 let status = response.status();
212
213 if !status.is_success() {
214 let status_code = status.as_u16();
215 let error_text = response.text().await.unwrap_or_default();
216
217 let message = if let Ok(gemini_error) = serde_json::from_str::<GeminiError>(&error_text)
219 {
220 gemini_error.error.message
221 } else {
222 format!("HTTP {}: {}", status_code, error_text)
223 };
224
225 let kind = classify_gemini_error(status_code, &message);
227
228 return Err(AppError::GeminiError(GeminiErrorDetails::new(
230 kind,
231 message,
232 status_code,
233 )));
234 }
235
236 let embedding_response: EmbeddingResponse = response
237 .json()
238 .await
239 .map_err(|e| AppError::ClientError(format!("Failed to parse response: {}", e)))?;
240
241 Ok(embedding_response.embedding.values)
242 }
243}
244
245impl ceres_core::traits::EmbeddingProvider for GeminiClient {
250 async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
251 self.get_embeddings(text).await
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a client from a plain key succeeds.
    #[test]
    fn test_new_client() {
        assert!(GeminiClient::new("test-api-key").is_ok());
    }

    /// Mirrors the newline replacement applied before a request is sent.
    #[test]
    fn test_text_sanitization() {
        let flattened = "Line 1\nLine 2\nLine 3".replace('\n', " ");
        assert_eq!(flattened, "Line 1 Line 2 Line 3");
    }

    /// The serialized request carries both the model name and the text.
    #[test]
    fn test_request_serialization() {
        let part = Part {
            text: "Hello world".to_string(),
        };
        let payload = EmbeddingRequest {
            model: "models/text-embedding-004".to_string(),
            content: Content { parts: vec![part] },
        };

        let encoded = serde_json::to_string(&payload).unwrap();
        assert!(encoded.contains("text-embedding-004"));
        assert!(encoded.contains("Hello world"));
    }

    /// A 401 status is an authentication failure.
    #[test]
    fn test_classify_gemini_error_auth() {
        assert_eq!(
            classify_gemini_error(401, "Invalid API key"),
            GeminiErrorKind::Authentication
        );
    }

    /// A non-401 status is still an authentication failure when the message
    /// mentions the API key.
    #[test]
    fn test_classify_gemini_error_auth_from_message() {
        assert_eq!(
            classify_gemini_error(400, "API key not valid"),
            GeminiErrorKind::Authentication
        );
    }

    /// A 429 without a quota hint is a rate limit.
    #[test]
    fn test_classify_gemini_error_rate_limit() {
        assert_eq!(
            classify_gemini_error(429, "Rate limit exceeded"),
            GeminiErrorKind::RateLimit
        );
    }

    /// A 429 that mentions quota is a quota failure.
    #[test]
    fn test_classify_gemini_error_quota() {
        assert_eq!(
            classify_gemini_error(429, "insufficient_quota"),
            GeminiErrorKind::QuotaExceeded
        );
    }

    /// A 500 status is a server error.
    #[test]
    fn test_classify_gemini_error_server() {
        assert_eq!(
            classify_gemini_error(500, "Internal server error"),
            GeminiErrorKind::ServerError
        );
    }

    /// Other 5xx statuses are server errors as well.
    #[test]
    fn test_classify_gemini_error_server_503() {
        assert_eq!(
            classify_gemini_error(503, "Service unavailable"),
            GeminiErrorKind::ServerError
        );
    }

    /// An unrecognized status with an unrecognized message is unknown.
    #[test]
    fn test_classify_gemini_error_unknown() {
        assert_eq!(
            classify_gemini_error(400, "Bad request"),
            GeminiErrorKind::Unknown
        );
    }
}