1use ceres_core::error::{AppError, GeminiErrorDetails, GeminiErrorKind};
22use ceres_core::HttpConfig;
23use reqwest::Client;
24use serde::{Deserialize, Serialize};
25
26#[derive(Clone)]
50pub struct GeminiClient {
51 client: Client,
52 api_key: String,
53}
54
/// JSON request body for the `embedContent` endpoint.
///
/// Field names are serialized as JSON keys as-is (no `serde(rename)`), so
/// renaming a field changes the wire format.
#[derive(Serialize)]
struct EmbeddingRequest {
    /// Model resource name, e.g. `models/text-embedding-004`.
    model: String,
    /// The content to embed.
    content: Content,
}

/// Container for the parts that make up the content to embed.
#[derive(Serialize)]
struct Content {
    /// Content parts; `GeminiClient::get_embeddings` sends exactly one text part.
    parts: Vec<Part>,
}

/// A single text part of the content.
#[derive(Serialize)]
struct Part {
    /// Text to embed.
    text: String,
}
71
/// JSON body of a successful `embedContent` response.
///
/// Only the fields this client reads are declared. No `deny_unknown_fields` is
/// set, so serde ignores any additional fields in the response.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// The embedding produced for the request's content.
    embedding: EmbeddingData,
}

/// The embedding payload inside [`EmbeddingResponse`].
#[derive(Deserialize)]
struct EmbeddingData {
    /// The embedding vector components.
    values: Vec<f32>,
}
82
/// JSON error body returned by the API on a non-success status, shaped as
/// `{"error": {"message": ..., "status": ...}}`.
#[derive(Deserialize)]
struct GeminiError {
    /// The error details.
    error: GeminiErrorDetail,
}

/// Details of an API error.
#[derive(Deserialize)]
struct GeminiErrorDetail {
    /// Human-readable error message; used both for classification and in the
    /// error returned to callers.
    message: String,
    /// Status string reported by the API (presumably a gRPC-style code such as
    /// `RESOURCE_EXHAUSTED` — confirm). It is accepted during deserialization
    /// but never read, which is why `dead_code` is allowed here.
    #[allow(dead_code)]
    status: Option<String>,
}
95
96fn classify_gemini_error(status_code: u16, message: &str) -> GeminiErrorKind {
98 match status_code {
99 401 => GeminiErrorKind::Authentication,
100 429 => {
101 if message.contains("insufficient_quota") || message.contains("quota") {
103 GeminiErrorKind::QuotaExceeded
104 } else {
105 GeminiErrorKind::RateLimit
106 }
107 }
108 500..=599 => GeminiErrorKind::ServerError,
109 _ => {
110 if message.contains("API key") || message.contains("Unauthorized") {
112 GeminiErrorKind::Authentication
113 } else if message.contains("rate") {
114 GeminiErrorKind::RateLimit
115 } else if message.contains("quota") {
116 GeminiErrorKind::QuotaExceeded
117 } else {
118 GeminiErrorKind::Unknown
119 }
120 }
121 }
122}
123
124impl GeminiClient {
125 pub fn new(api_key: &str) -> Result<Self, AppError> {
127 let http_config = HttpConfig::default();
128 let client = Client::builder()
129 .timeout(http_config.timeout)
130 .build()
131 .map_err(|e| AppError::ClientError(e.to_string()))?;
132
133 Ok(Self {
134 client,
135 api_key: api_key.to_string(),
136 })
137 }
138
139 pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
157 let sanitized_text = text.replace('\n', " ");
159
160 let url = "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent";
163
164 let request_body = EmbeddingRequest {
169 model: "models/text-embedding-004".to_string(),
170 content: Content {
171 parts: vec![Part {
172 text: sanitized_text,
173 }],
174 },
175 };
176
177 let response = self
178 .client
179 .post(url)
180 .header("x-goog-api-key", self.api_key.clone())
181 .json(&request_body)
182 .send()
183 .await
184 .map_err(|e| {
185 if e.is_timeout() {
186 AppError::Timeout(30)
187 } else if e.is_connect() {
188 AppError::GeminiError(GeminiErrorDetails::new(
189 GeminiErrorKind::NetworkError,
190 format!("Connection failed: {}", e),
191 0, ))
193 } else {
194 AppError::ClientError(e.to_string())
195 }
196 })?;
197
198 let status = response.status();
199
200 if !status.is_success() {
201 let status_code = status.as_u16();
202 let error_text = response.text().await.unwrap_or_default();
203
204 let message = if let Ok(gemini_error) = serde_json::from_str::<GeminiError>(&error_text)
206 {
207 gemini_error.error.message
208 } else {
209 format!("HTTP {}: {}", status_code, error_text)
210 };
211
212 let kind = classify_gemini_error(status_code, &message);
214
215 return Err(AppError::GeminiError(GeminiErrorDetails::new(
217 kind,
218 message,
219 status_code,
220 )));
221 }
222
223 let embedding_response: EmbeddingResponse = response
224 .json()
225 .await
226 .map_err(|e| AppError::ClientError(format!("Failed to parse response: {}", e)))?;
227
228 Ok(embedding_response.embedding.values)
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `(status_code, message)` is classified as `expected`.
    fn assert_kind(status_code: u16, message: &str, expected: GeminiErrorKind) {
        assert_eq!(classify_gemini_error(status_code, message), expected);
    }

    #[test]
    fn test_new_client() {
        assert!(GeminiClient::new("test-api-key").is_ok());
    }

    #[test]
    fn test_text_sanitization() {
        let input = "Line 1\nLine 2\nLine 3";
        assert_eq!(input.replace('\n', " "), "Line 1 Line 2 Line 3");
    }

    #[test]
    fn test_request_serialization() {
        let body = serde_json::to_string(&EmbeddingRequest {
            model: "models/text-embedding-004".to_string(),
            content: Content {
                parts: vec![Part {
                    text: "Hello world".to_string(),
                }],
            },
        })
        .unwrap();

        for needle in ["text-embedding-004", "Hello world"] {
            assert!(body.contains(needle));
        }
    }

    #[test]
    fn test_classify_gemini_error_auth() {
        assert_kind(401, "Invalid API key", GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_auth_from_message() {
        assert_kind(400, "API key not valid", GeminiErrorKind::Authentication);
    }

    #[test]
    fn test_classify_gemini_error_rate_limit() {
        assert_kind(429, "Rate limit exceeded", GeminiErrorKind::RateLimit);
    }

    #[test]
    fn test_classify_gemini_error_quota() {
        assert_kind(429, "insufficient_quota", GeminiErrorKind::QuotaExceeded);
    }

    #[test]
    fn test_classify_gemini_error_server() {
        assert_kind(500, "Internal server error", GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_server_503() {
        assert_kind(503, "Service unavailable", GeminiErrorKind::ServerError);
    }

    #[test]
    fn test_classify_gemini_error_unknown() {
        assert_kind(400, "Bad request", GeminiErrorKind::Unknown);
    }
}