agcodex_core/embeddings/providers/
gemini.rs1use super::super::EmbeddingError;
9use super::super::EmbeddingProvider;
10use super::super::EmbeddingVector;
11use reqwest::Client;
12use serde::Deserialize;
13use serde::Serialize;
14use tokio::time::Duration;
15use tokio::time::sleep;
16
17pub struct GeminiProvider {
19 client: Client,
20 api_key: String,
21 model: String,
22 api_endpoint: Option<String>,
23}
24
25impl GeminiProvider {
26 pub fn new(api_key: String, model: String) -> Self {
28 Self {
29 client: Client::new(),
30 api_key,
31 model,
32 api_endpoint: None,
33 }
34 }
35
36 pub fn new_with_endpoint(api_key: String, model: String, api_endpoint: String) -> Self {
38 Self {
39 client: Client::new(),
40 api_key,
41 model,
42 api_endpoint: Some(api_endpoint),
43 }
44 }
45}
46
47#[derive(Debug, Serialize)]
48struct GeminiRequest {
49 content: GeminiContent,
50}
51
52#[derive(Debug, Serialize)]
53struct GeminiContent {
54 parts: Vec<GeminiPart>,
55}
56
57#[derive(Debug, Serialize)]
58struct GeminiPart {
59 text: String,
60}
61
62#[derive(Debug, Deserialize)]
63struct GeminiResponse {
64 embedding: GeminiEmbedding,
65}
66
67#[derive(Debug, Deserialize)]
68struct GeminiEmbedding {
69 values: Vec<f32>,
70}
71
72#[derive(Debug, Deserialize)]
73struct GeminiError {
74 error: GeminiErrorDetail,
75}
76
77#[derive(Debug, Deserialize)]
78struct GeminiErrorDetail {
79 message: String,
80 _code: Option<u32>,
81 status: Option<String>,
82}
83
84#[async_trait::async_trait]
85impl EmbeddingProvider for GeminiProvider {
86 fn model_id(&self) -> String {
87 format!("gemini:{}", self.model)
88 }
89
90 fn dimensions(&self) -> usize {
91 match self.model.as_str() {
93 "gemini-embedding-001" => 768,
94 "text-embedding-004" => 768,
95 "embedding-001" => 768,
96 "textembedding-gecko@001" => 768,
97 "textembedding-gecko@003" => 768,
98 _ => 768, }
100 }
101
102 async fn embed(&self, text: &str) -> Result<EmbeddingVector, EmbeddingError> {
103 let request = GeminiRequest {
104 content: GeminiContent {
105 parts: vec![GeminiPart {
106 text: text.to_string(),
107 }],
108 },
109 };
110
111 let endpoint = self
112 .api_endpoint
113 .as_deref()
114 .unwrap_or("https://generativelanguage.googleapis.com");
115 let url = format!(
116 "{}/v1/models/{}:embedContent?key={}",
117 endpoint, self.model, self.api_key
118 );
119
120 let response = self
121 .client
122 .post(&url)
123 .header("Content-Type", "application/json")
124 .json(&request)
125 .send()
126 .await
127 .map_err(|e| EmbeddingError::ApiError(format!("Request failed: {}", e)))?;
128
129 let status = response.status();
130 if !status.is_success() {
131 let error_text = response
132 .text()
133 .await
134 .unwrap_or_else(|_| "Unknown error".to_string());
135
136 if let Ok(error) = serde_json::from_str::<GeminiError>(&error_text) {
138 return Err(EmbeddingError::ApiError(format!(
139 "Gemini API error ({}): {} - {}",
140 status,
141 error.error.status.unwrap_or_else(|| "Unknown".to_string()),
142 error.error.message
143 )));
144 }
145
146 return Err(EmbeddingError::ApiError(format!(
147 "Gemini API error ({}): {}",
148 status, error_text
149 )));
150 }
151
152 let gemini_response: GeminiResponse = response
153 .json()
154 .await
155 .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {}", e)))?;
156
157 let expected_dims = self.dimensions();
159 if gemini_response.embedding.values.len() != expected_dims {
160 return Err(EmbeddingError::DimensionMismatch {
161 expected: expected_dims,
162 actual: gemini_response.embedding.values.len(),
163 });
164 }
165
166 Ok(gemini_response.embedding.values)
167 }
168
169 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
170 if texts.is_empty() {
171 return Ok(vec![]);
172 }
173
174 let mut embeddings = Vec::with_capacity(texts.len());
177
178 for text in texts {
179 let embedding = self.embed(text).await?;
180 embeddings.push(embedding);
181
182 sleep(Duration::from_millis(100)).await;
184 }
185
186 Ok(embeddings)
187 }
188
189 fn is_available(&self) -> bool {
190 !self.api_key.is_empty()
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider with a placeholder key for `model`.
    fn provider_for(model: &str) -> GeminiProvider {
        GeminiProvider::new("test-key".to_string(), model.to_string())
    }

    #[test]
    fn test_model_id() {
        let id = provider_for("embedding-001").model_id();
        assert_eq!(id, "gemini:embedding-001");
    }

    #[test]
    fn test_dimensions() {
        // Current model, legacy model, and the unknown-model fallback all report 768.
        for model in ["gemini-embedding-001", "text-embedding-004", "unknown-model"] {
            assert_eq!(provider_for(model).dimensions(), 768, "model {model}");
        }
    }

    #[test]
    fn test_is_available() {
        assert!(provider_for("embedding-001").is_available());

        let keyless = GeminiProvider::new(String::new(), "embedding-001".to_string());
        assert!(!keyless.is_available());
    }
}