codemem_embeddings/
gemini.rs1use codemem_core::CodememError;
15
/// Root of the Gemini REST API (v1beta).
///
/// `GeminiProvider::new` falls back to this when no `base_url` override is given.
pub const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

/// Default Gemini embedding model id, without the `models/` prefix.
///
/// Not referenced inside this file; presumably used by callers that build a
/// `GeminiProvider` from configuration — verify against callers.
pub const DEFAULT_MODEL: &str = "text-embedding-004";
21
22pub struct GeminiProvider {
24 api_key: String,
25 model: String,
26 dimensions: usize,
27 pub(crate) base_url: String,
28 client: reqwest::blocking::Client,
29}
30
31impl GeminiProvider {
32 pub fn new(api_key: &str, model: &str, dimensions: usize, base_url: Option<&str>) -> Self {
34 Self {
35 api_key: api_key.to_string(),
36 model: model.to_string(),
37 dimensions,
38 base_url: base_url.unwrap_or(DEFAULT_BASE_URL).to_string(),
39 client: reqwest::blocking::Client::new(),
40 }
41 }
42}
43
44impl super::EmbeddingProvider for GeminiProvider {
45 fn dimensions(&self) -> usize {
46 self.dimensions
47 }
48
49 fn embed(&self, text: &str) -> Result<Vec<f32>, CodememError> {
50 let url = format!("{}/models/{}:embedContent", self.base_url, self.model);
51
52 let mut body = serde_json::json!({
53 "model": format!("models/{}", self.model),
54 "content": {
55 "parts": [{"text": text}]
56 },
57 "taskType": "RETRIEVAL_DOCUMENT",
58 });
59
60 if self.dimensions > 0 {
61 body["outputDimensionality"] = serde_json::json!(self.dimensions);
62 }
63
64 let response = self
65 .client
66 .post(&url)
67 .header("Content-Type", "application/json")
68 .header("x-goog-api-key", &self.api_key)
69 .json(&body)
70 .send()
71 .map_err(|e| CodememError::Embedding(format!("Gemini request failed: {e}")))?;
72
73 if !response.status().is_success() {
74 let status = response.status();
75 let body = response.text().unwrap_or_default();
76 return Err(CodememError::Embedding(format!(
77 "Gemini returned status {status}: {body}",
78 )));
79 }
80
81 let json: serde_json::Value = response
82 .json()
83 .map_err(|e| CodememError::Embedding(format!("Gemini response parse error: {e}")))?;
84
85 let embedding: Vec<f32> = json
86 .get("embedding")
87 .and_then(|v| v.get("values"))
88 .and_then(|v| v.as_array())
89 .ok_or_else(|| {
90 CodememError::Embedding("Missing embedding.values in Gemini response".into())
91 })?
92 .iter()
93 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
94 .collect();
95
96 if self.dimensions > 0 && embedding.len() != self.dimensions {
97 return Err(CodememError::Embedding(format!(
98 "Gemini returned {} dimensions, expected {}",
99 embedding.len(),
100 self.dimensions
101 )));
102 }
103
104 Ok(embedding)
105 }
106
107 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, CodememError> {
108 if texts.is_empty() {
109 return Ok(vec![]);
110 }
111
112 let url = format!("{}/models/{}:batchEmbedContents", self.base_url, self.model);
113
114 const MAX_BATCH: usize = 100;
116 let mut all_embeddings = Vec::with_capacity(texts.len());
117
118 for chunk in texts.chunks(MAX_BATCH) {
119 let requests: Vec<serde_json::Value> = chunk
120 .iter()
121 .map(|text| {
122 let mut req = serde_json::json!({
123 "model": format!("models/{}", self.model),
124 "content": {
125 "parts": [{"text": text}]
126 },
127 "taskType": "RETRIEVAL_DOCUMENT",
128 });
129 if self.dimensions > 0 {
130 req["outputDimensionality"] = serde_json::json!(self.dimensions);
131 }
132 req
133 })
134 .collect();
135
136 let body = serde_json::json!({ "requests": requests });
137
138 let response = self
139 .client
140 .post(&url)
141 .header("Content-Type", "application/json")
142 .header("x-goog-api-key", &self.api_key)
143 .json(&body)
144 .send()
145 .map_err(|e| {
146 CodememError::Embedding(format!("Gemini batch request failed: {e}"))
147 })?;
148
149 if !response.status().is_success() {
150 let status = response.status();
151 let body = response.text().unwrap_or_default();
152 return Err(CodememError::Embedding(format!(
153 "Gemini returned status {status}: {body}",
154 )));
155 }
156
157 let json: serde_json::Value = response.json().map_err(|e| {
158 CodememError::Embedding(format!("Gemini response parse error: {e}"))
159 })?;
160
161 let embeddings = json
162 .get("embeddings")
163 .and_then(|v| v.as_array())
164 .ok_or_else(|| {
165 CodememError::Embedding("Missing 'embeddings' array in Gemini response".into())
166 })?;
167
168 if embeddings.len() != chunk.len() {
169 return Err(CodememError::Embedding(format!(
170 "Gemini returned {} embeddings, expected {}",
171 embeddings.len(),
172 chunk.len()
173 )));
174 }
175
176 for (i, item) in embeddings.iter().enumerate() {
177 let embedding: Vec<f32> = item
178 .get("values")
179 .and_then(|v| v.as_array())
180 .ok_or_else(|| {
181 CodememError::Embedding(format!(
182 "Missing values in Gemini embedding at index {i}"
183 ))
184 })?
185 .iter()
186 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
187 .collect();
188
189 if self.dimensions > 0 && embedding.len() != self.dimensions {
190 return Err(CodememError::Embedding(format!(
191 "Gemini returned {} dimensions at index {i}, expected {}",
192 embedding.len(),
193 self.dimensions
194 )));
195 }
196 all_embeddings.push(embedding);
197 }
198 }
199
200 Ok(all_embeddings)
201 }
202
203 fn name(&self) -> &str {
204 "gemini"
205 }
206}
207
// Unit tests for this provider live in a separate file, `tests/gemini_tests.rs`
// (loaded via `#[path]`), and are compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "tests/gemini_tests.rs"]
mod tests;