rectilinear_core/embedding/
mod.rs1#[cfg(feature = "local-embeddings")]
2mod local;
3
4use anyhow::{Context, Result};
5use reqwest::StatusCode;
6use serde::Deserialize;
7
8use crate::config::{Config, EmbeddingBackend};
9
/// The concrete embedding implementation an [`Embedder`] dispatches to.
enum Backend {
    /// Remote embeddings served by the Gemini HTTP API.
    Gemini(GeminiBackend),
    /// In-process embeddings; only compiled with the `local-embeddings` feature.
    #[cfg(feature = "local-embeddings")]
    Local(local::LocalBackend),
}
15
/// Backend-agnostic entry point for turning text into embedding vectors.
pub struct Embedder {
    /// Implementation every embedding call is forwarded to.
    backend: Backend,
    /// Value reported by [`Embedder::dimensions`]: 768 for the Gemini backend,
    /// or whatever the local backend reports.
    dimensions: usize,
}
20
/// Client state for the Gemini REST API.
struct GeminiBackend {
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// API key sent with each request.
    api_key: String,
}
27
/// Success body of the Gemini "list models" call issued by `test_api_key`.
#[derive(Deserialize)]
struct GeminiModelsResponse {
    /// Models returned by the API; `test_api_key` fails when this is empty.
    models: Vec<GeminiModel>,
}
32
/// One entry of [`GeminiModelsResponse::models`].
#[derive(Deserialize)]
struct GeminiModel {
    /// Deserialized but never read in this module, hence the `dead_code` allowance.
    #[allow(dead_code)]
    name: String,
}
38
/// Top-level JSON object of a Gemini error response: `{ "error": { ... } }`.
#[derive(Deserialize)]
struct GeminiErrorEnvelope {
    /// The error details consumed by `summarize_gemini_error`.
    error: GeminiErrorBody,
}
43
/// Contents of the `error` object in a Gemini error response.
///
/// Every field is optional, so a partially populated error still deserializes.
#[derive(Deserialize)]
struct GeminiErrorBody {
    /// Deserialized but never read in this module, hence the `dead_code` allowance.
    #[allow(dead_code)]
    code: Option<u16>,
    /// Human-readable message; `compact_gemini_error` searches it for known phrases.
    message: Option<String>,
    /// Status string such as `RESOURCE_EXHAUSTED`, matched by `compact_gemini_error`.
    status: Option<String>,
}
51
52impl GeminiBackend {
53 fn new(api_key: &str) -> Self {
54 Self {
55 client: reqwest::Client::new(),
56 api_key: api_key.to_string(),
57 }
58 }
59
60 fn with_http_client(client: reqwest::Client, api_key: &str) -> Self {
61 Self {
62 client,
63 api_key: api_key.to_string(),
64 }
65 }
66
67 async fn test_api_key(&self) -> Result<()> {
68 let resp = self
69 .client
70 .get("https://generativelanguage.googleapis.com/v1beta/models?pageSize=1")
71 .header("x-goog-api-key", &self.api_key)
72 .send()
73 .await
74 .context("Failed to call Gemini models API")?;
75
76 let status = resp.status();
77 let body = resp
78 .text()
79 .await
80 .context("Failed to read Gemini models response")?;
81
82 if !status.is_success() {
83 anyhow::bail!("{}", summarize_gemini_error(status, &body));
84 }
85
86 let response: GeminiModelsResponse =
87 serde_json::from_str(&body).context("Failed to parse Gemini models response")?;
88 if response.models.is_empty() {
89 anyhow::bail!("Gemini returned no models");
90 }
91
92 Ok(())
93 }
94
95 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
96 let url = format!(
97 "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2-preview:batchEmbedContents?key={}",
98 self.api_key
99 );
100
101 let mut all_embeddings = Vec::new();
102 for batch in texts.chunks(100) {
103 let requests: Vec<_> = batch
104 .iter()
105 .map(|text| {
106 serde_json::json!({
107 "model": "models/gemini-embedding-2-preview",
108 "content": {
109 "parts": [{"text": text}]
110 },
111 "outputDimensionality": 768
112 })
113 })
114 .collect();
115
116 let body = serde_json::json!({ "requests": requests });
117
118 let resp = self
119 .client
120 .post(&url)
121 .json(&body)
122 .send()
123 .await
124 .context("Failed to call Gemini embedding API")?;
125
126 let status = resp.status();
127 if !status.is_success() {
128 let text = resp.text().await.unwrap_or_default();
129 anyhow::bail!("Gemini API returned {}: {}", status, text);
130 }
131
132 let data: serde_json::Value = resp.json().await?;
133 let embeddings = data["embeddings"]
134 .as_array()
135 .context("No embeddings in response")?;
136
137 for emb in embeddings {
138 let values: Vec<f32> = emb["values"]
139 .as_array()
140 .context("No values in embedding")?
141 .iter()
142 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
143 .collect();
144 all_embeddings.push(values);
145 }
146 }
147
148 Ok(all_embeddings)
149 }
150}
151
152fn summarize_gemini_error(status: StatusCode, body: &str) -> String {
153 if let Ok(error) = serde_json::from_str::<GeminiErrorEnvelope>(body) {
154 return compact_gemini_error(
155 status,
156 error.error.message.as_deref(),
157 error.error.status.as_deref(),
158 );
159 }
160
161 compact_gemini_error(status, None, None)
162}
163
164fn compact_gemini_error(
165 status: StatusCode,
166 message: Option<&str>,
167 api_status: Option<&str>,
168) -> String {
169 let normalized = message.unwrap_or("").trim().to_lowercase();
170
171 if normalized.contains("api key not valid") || normalized.contains("invalid api key") {
172 return "Invalid API key".into();
173 }
174
175 if normalized.contains("reported as leaked") || normalized.contains("disabled") {
176 return "Key blocked".into();
177 }
178
179 if normalized.contains("billing")
180 || matches!(api_status, Some("FAILED_PRECONDITION" | "SERVICE_DISABLED"))
181 {
182 return "Setup required".into();
183 }
184
185 if status == StatusCode::UNAUTHORIZED || matches!(api_status, Some("UNAUTHENTICATED")) {
186 return "Unauthorized".into();
187 }
188
189 if status == StatusCode::FORBIDDEN || matches!(api_status, Some("PERMISSION_DENIED")) {
190 return "Access denied".into();
191 }
192
193 if status == StatusCode::TOO_MANY_REQUESTS || matches!(api_status, Some("RESOURCE_EXHAUSTED")) {
194 return "Rate limited".into();
195 }
196
197 if status.is_server_error() {
198 return "Gemini unavailable".into();
199 }
200
201 if let Some(message) = message {
202 let trimmed = message.trim();
203 if !trimmed.is_empty() && trimmed.len() <= 48 {
204 return trimmed.to_string();
205 }
206 }
207
208 status
209 .canonical_reason()
210 .unwrap_or("Request failed")
211 .to_string()
212}
213
214impl Embedder {
217 pub fn new(config: &Config) -> Result<Self> {
218 let gemini_key = std::env::var("GEMINI_API_KEY")
219 .ok()
220 .or_else(|| config.embedding.gemini_api_key.clone());
221
222 match config.embedding.backend {
223 EmbeddingBackend::Api => {
224 let key = gemini_key.context(
225 "Gemini API key required for API backend. Set GEMINI_API_KEY or configure in config.",
226 )?;
227 Self::new_api(&key)
228 }
229 #[cfg(feature = "local-embeddings")]
230 EmbeddingBackend::Local => {
231 let backend = local::LocalBackend::new(config)?;
232 let dimensions = backend.dimensions();
233 Ok(Self {
234 dimensions,
235 backend: Backend::Local(backend),
236 })
237 }
238 #[cfg(not(feature = "local-embeddings"))]
239 EmbeddingBackend::Local => {
240 anyhow::bail!(
241 "Local embeddings not available — compile with `local-embeddings` feature"
242 )
243 }
244 }
245 }
246
247 pub fn new_api(api_key: &str) -> Result<Self> {
249 Ok(Self {
250 dimensions: 768,
251 backend: Backend::Gemini(GeminiBackend::new(api_key)),
252 })
253 }
254
255 pub fn new_api_with_http_client(client: reqwest::Client, api_key: &str) -> Result<Self> {
257 Ok(Self {
258 dimensions: 768,
259 backend: Backend::Gemini(GeminiBackend::with_http_client(client, api_key)),
260 })
261 }
262
263 pub async fn test_api_key(&self) -> Result<()> {
264 match &self.backend {
265 Backend::Gemini(b) => b.test_api_key().await,
266 #[cfg(feature = "local-embeddings")]
267 Backend::Local(_) => anyhow::bail!("Gemini API key not in use"),
268 }
269 }
270
271 #[cfg(feature = "local-embeddings")]
273 pub fn new_local(_models_dir: &std::path::Path) -> Result<Self> {
274 let config = Config {
276 embedding: crate::config::EmbeddingConfig {
277 backend: EmbeddingBackend::Local,
278 gemini_api_key: None,
279 },
280 ..Config::default()
281 };
282 let backend = local::LocalBackend::new(&config)?;
283 let dimensions = backend.dimensions();
284 Ok(Self {
285 dimensions,
286 backend: Backend::Local(backend),
287 })
288 }
289
290 pub async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
291 match &self.backend {
292 Backend::Gemini(b) => b.embed_batch(texts).await,
293 #[cfg(feature = "local-embeddings")]
294 Backend::Local(b) => b.embed_batch(texts),
295 }
296 }
297
298 pub async fn embed_single(&self, text: &str) -> Result<Vec<f32>> {
299 let results: Vec<Vec<f32>> = self.embed_batch(&[text.to_string()]).await?;
300 results.into_iter().next().context("No embedding returned")
301 }
302
303 pub fn dimensions(&self) -> usize {
304 self.dimensions
305 }
306
307 pub fn backend_name(&self) -> &str {
308 match &self.backend {
309 Backend::Gemini(_) => "gemini-api",
310 #[cfg(feature = "local-embeddings")]
311 Backend::Local(_) => "local-gguf",
312 }
313 }
314}
315
/// Splits `text` into chunks for embedding, each prefixed with `title: {title}\n\n`.
///
/// Sizes are expressed in tokens and converted to bytes with a fixed factor
/// of 4 bytes per token. Each chunk holds at most `max_tokens * 4` bytes of
/// `text`; when more text follows, the chunk is cut just after the start of
/// the last paragraph break, newline, sentence end (`". "`) or space found in
/// that window (tried in that order). Consecutive chunks overlap by up to
/// `overlap * 4` bytes when that still moves the start forward. Cuts never
/// split a UTF-8 character.
///
/// * Empty `text` yields a single chunk `"{prefix}(no description)"`.
/// * `max_tokens == 0` is treated as one token, so the function always
///   terminates instead of emitting empty chunks forever.
pub fn chunk_text(title: &str, text: &str, max_tokens: usize, overlap: usize) -> Vec<String> {
    /// Largest char boundary of `s` that is `<= pos`, with `pos` clamped to `s.len()`.
    fn floor_char_boundary(s: &str, pos: usize) -> usize {
        let mut idx = pos.min(s.len());
        while idx > 0 && !s.is_char_boundary(idx) {
            idx -= 1;
        }
        idx
    }

    let prefix = format!("title: {}\n\n", title);

    if text.is_empty() {
        return vec![format!("{}(no description)", prefix)];
    }

    // At least 4 bytes per window: a UTF-8 character is at most 4 bytes, so
    // every window contains one whole character and the loop below always
    // advances. Saturating arithmetic keeps huge arguments from overflowing.
    let max_chars = max_tokens.max(1).saturating_mul(4);
    let overlap_chars = overlap.saturating_mul(4);

    if text.len() <= max_chars {
        return vec![format!("{}{}", prefix, text)];
    }

    let mut chunks = Vec::new();
    let mut start = 0;

    while start < text.len() {
        let end = floor_char_boundary(text, start.saturating_add(max_chars));
        let window = &text[start..end];

        // Only look for a natural break when more text follows this window.
        // `p + 1` lands right after an ASCII byte, so it is a char boundary
        // and strictly greater than `start`.
        let break_at = if end < text.len() {
            window
                .rfind("\n\n")
                .or_else(|| window.rfind('\n'))
                .or_else(|| window.rfind(". "))
                .or_else(|| window.rfind(' '))
                .map_or(end, |p| start + p + 1)
        } else {
            end
        };

        chunks.push(format!("{}{}", prefix, &text[start..break_at]));

        if break_at >= text.len() {
            break;
        }

        let overlapped = floor_char_boundary(
            text,
            if break_at > overlap_chars {
                break_at - overlap_chars
            } else {
                break_at
            },
        );
        // Drop the overlap whenever it would not move the start forward.
        start = if overlapped <= start {
            break_at
        } else {
            overlapped
        };
    }

    chunks
}
386
/// Serializes an embedding as consecutive little-endian `f32` values
/// (4 bytes per component, in order).
pub fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * 4);
    for component in embedding {
        bytes.extend_from_slice(&component.to_le_bytes());
    }
    bytes
}
391
/// Decodes consecutive little-endian `f32` values from `bytes`.
///
/// Input is consumed four bytes at a time; a trailing remainder of fewer
/// than four bytes is ignored.
pub fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
399
/// Cosine similarity of two vectors, computed in `f32`.
///
/// Returns `0.0` when the slices are empty, differ in length, or when
/// either vector has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (x, y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    dot / magnitude
}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// A 400 response whose message names an invalid key maps to "Invalid API key".
    #[test]
    fn compacts_invalid_api_key_errors() {
        let body = r#"{
            "error": {
                "code": 400,
                "message": "API key not valid. Please pass a valid API key.",
                "status": "INVALID_ARGUMENT"
            }
        }"#;

        let summary = summarize_gemini_error(StatusCode::BAD_REQUEST, body);

        assert_eq!(summary, "Invalid API key");
    }

    /// A 429 / RESOURCE_EXHAUSTED response maps to "Rate limited".
    #[test]
    fn compacts_rate_limit_errors() {
        let body = r#"{
            "error": {
                "code": 429,
                "message": "Quota exceeded.",
                "status": "RESOURCE_EXHAUSTED"
            }
        }"#;

        let summary = summarize_gemini_error(StatusCode::TOO_MANY_REQUESTS, body);

        assert_eq!(summary, "Rate limited");
    }

    /// A body that is not JSON is summarized from the HTTP status alone.
    #[test]
    fn falls_back_to_status_for_unknown_errors() {
        let summary = summarize_gemini_error(StatusCode::SERVICE_UNAVAILABLE, "not-json");

        assert_eq!(summary, "Gemini unavailable");
    }
}