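//! `rectilinear_core/embedding/mod.rs`
//!
//! Embedding backends (Gemini API and optional local), text chunking, and
//! helpers for storing and comparing embedding vectors.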

1#[cfg(feature = "local-embeddings")]
2mod local;
3
4use anyhow::{Context, Result};
5use reqwest::StatusCode;
6use serde::Deserialize;
7
8use crate::config::{Config, EmbeddingBackend};
9
10enum Backend {
11    Gemini(GeminiBackend),
12    #[cfg(feature = "local-embeddings")]
13    Local(local::LocalBackend),
14}
15
16pub struct Embedder {
17    backend: Backend,
18    dimensions: usize,
19}
20
21// --- Gemini API backend ---
22
23struct GeminiBackend {
24    client: reqwest::Client,
25    api_key: String,
26}
27
28#[derive(Deserialize)]
29struct GeminiModelsResponse {
30    models: Vec<GeminiModel>,
31}
32
33#[derive(Deserialize)]
34struct GeminiModel {
35    #[allow(dead_code)]
36    name: String,
37}
38
39#[derive(Deserialize)]
40struct GeminiErrorEnvelope {
41    error: GeminiErrorBody,
42}
43
44#[derive(Deserialize)]
45struct GeminiErrorBody {
46    #[allow(dead_code)]
47    code: Option<u16>,
48    message: Option<String>,
49    status: Option<String>,
50}
51
52impl GeminiBackend {
53    fn new(api_key: &str) -> Self {
54        Self {
55            client: reqwest::Client::new(),
56            api_key: api_key.to_string(),
57        }
58    }
59
60    fn with_http_client(client: reqwest::Client, api_key: &str) -> Self {
61        Self {
62            client,
63            api_key: api_key.to_string(),
64        }
65    }
66
67    async fn test_api_key(&self) -> Result<()> {
68        let resp = self
69            .client
70            .get("https://generativelanguage.googleapis.com/v1beta/models?pageSize=1")
71            .header("x-goog-api-key", &self.api_key)
72            .send()
73            .await
74            .context("Failed to call Gemini models API")?;
75
76        let status = resp.status();
77        let body = resp
78            .text()
79            .await
80            .context("Failed to read Gemini models response")?;
81
82        if !status.is_success() {
83            anyhow::bail!("{}", summarize_gemini_error(status, &body));
84        }
85
86        let response: GeminiModelsResponse =
87            serde_json::from_str(&body).context("Failed to parse Gemini models response")?;
88        if response.models.is_empty() {
89            anyhow::bail!("Gemini returned no models");
90        }
91
92        Ok(())
93    }
94
95    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
96        let url = format!(
97            "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2-preview:batchEmbedContents?key={}",
98            self.api_key
99        );
100
101        let mut all_embeddings = Vec::new();
102        for batch in texts.chunks(100) {
103            let requests: Vec<_> = batch
104                .iter()
105                .map(|text| {
106                    serde_json::json!({
107                        "model": "models/gemini-embedding-2-preview",
108                        "content": {
109                            "parts": [{"text": text}]
110                        },
111                        "outputDimensionality": 768
112                    })
113                })
114                .collect();
115
116            let body = serde_json::json!({ "requests": requests });
117
118            let resp = self
119                .client
120                .post(&url)
121                .json(&body)
122                .send()
123                .await
124                .context("Failed to call Gemini embedding API")?;
125
126            let status = resp.status();
127            if !status.is_success() {
128                let text = resp.text().await.unwrap_or_default();
129                anyhow::bail!("Gemini API returned {}: {}", status, text);
130            }
131
132            let data: serde_json::Value = resp.json().await?;
133            let embeddings = data["embeddings"]
134                .as_array()
135                .context("No embeddings in response")?;
136
137            for emb in embeddings {
138                let values: Vec<f32> = emb["values"]
139                    .as_array()
140                    .context("No values in embedding")?
141                    .iter()
142                    .map(|v| v.as_f64().unwrap_or(0.0) as f32)
143                    .collect();
144                all_embeddings.push(values);
145            }
146        }
147
148        Ok(all_embeddings)
149    }
150}
151
152fn summarize_gemini_error(status: StatusCode, body: &str) -> String {
153    if let Ok(error) = serde_json::from_str::<GeminiErrorEnvelope>(body) {
154        return compact_gemini_error(
155            status,
156            error.error.message.as_deref(),
157            error.error.status.as_deref(),
158        );
159    }
160
161    compact_gemini_error(status, None, None)
162}
163
164fn compact_gemini_error(
165    status: StatusCode,
166    message: Option<&str>,
167    api_status: Option<&str>,
168) -> String {
169    let normalized = message.unwrap_or("").trim().to_lowercase();
170
171    if normalized.contains("api key not valid") || normalized.contains("invalid api key") {
172        return "Invalid API key".into();
173    }
174
175    if normalized.contains("reported as leaked") || normalized.contains("disabled") {
176        return "Key blocked".into();
177    }
178
179    if normalized.contains("billing")
180        || matches!(api_status, Some("FAILED_PRECONDITION" | "SERVICE_DISABLED"))
181    {
182        return "Setup required".into();
183    }
184
185    if status == StatusCode::UNAUTHORIZED || matches!(api_status, Some("UNAUTHENTICATED")) {
186        return "Unauthorized".into();
187    }
188
189    if status == StatusCode::FORBIDDEN || matches!(api_status, Some("PERMISSION_DENIED")) {
190        return "Access denied".into();
191    }
192
193    if status == StatusCode::TOO_MANY_REQUESTS || matches!(api_status, Some("RESOURCE_EXHAUSTED")) {
194        return "Rate limited".into();
195    }
196
197    if status.is_server_error() {
198        return "Gemini unavailable".into();
199    }
200
201    if let Some(message) = message {
202        let trimmed = message.trim();
203        if !trimmed.is_empty() && trimmed.len() <= 48 {
204            return trimmed.to_string();
205        }
206    }
207
208    status
209        .canonical_reason()
210        .unwrap_or("Request failed")
211        .to_string()
212}
213
214// --- Embedder (main interface) ---
215
216impl Embedder {
217    pub fn new(config: &Config) -> Result<Self> {
218        let gemini_key = std::env::var("GEMINI_API_KEY")
219            .ok()
220            .or_else(|| config.embedding.gemini_api_key.clone());
221
222        match config.embedding.backend {
223            EmbeddingBackend::Api => {
224                let key = gemini_key.context(
225                    "Gemini API key required for API backend. Set GEMINI_API_KEY or configure in config.",
226                )?;
227                Self::new_api(&key)
228            }
229            #[cfg(feature = "local-embeddings")]
230            EmbeddingBackend::Local => {
231                let backend = local::LocalBackend::new(config)?;
232                let dimensions = backend.dimensions();
233                Ok(Self {
234                    dimensions,
235                    backend: Backend::Local(backend),
236                })
237            }
238            #[cfg(not(feature = "local-embeddings"))]
239            EmbeddingBackend::Local => {
240                anyhow::bail!(
241                    "Local embeddings not available — compile with `local-embeddings` feature"
242                )
243            }
244        }
245    }
246
247    /// Create an embedder using the Gemini API backend.
248    pub fn new_api(api_key: &str) -> Result<Self> {
249        Ok(Self {
250            dimensions: 768,
251            backend: Backend::Gemini(GeminiBackend::new(api_key)),
252        })
253    }
254
255    /// Create an embedder using the Gemini API backend with a pre-built HTTP client.
256    pub fn new_api_with_http_client(client: reqwest::Client, api_key: &str) -> Result<Self> {
257        Ok(Self {
258            dimensions: 768,
259            backend: Backend::Gemini(GeminiBackend::with_http_client(client, api_key)),
260        })
261    }
262
263    pub async fn test_api_key(&self) -> Result<()> {
264        match &self.backend {
265            Backend::Gemini(b) => b.test_api_key().await,
266            #[cfg(feature = "local-embeddings")]
267            Backend::Local(_) => anyhow::bail!("Gemini API key not in use"),
268        }
269    }
270
271    /// Create an embedder using the local GGUF backend.
272    #[cfg(feature = "local-embeddings")]
273    pub fn new_local(_models_dir: &std::path::Path) -> Result<Self> {
274        // TODO: pass models_dir through to LocalBackend instead of using Config default
275        let config = Config {
276            embedding: crate::config::EmbeddingConfig {
277                backend: EmbeddingBackend::Local,
278                gemini_api_key: None,
279            },
280            ..Config::default()
281        };
282        let backend = local::LocalBackend::new(&config)?;
283        let dimensions = backend.dimensions();
284        Ok(Self {
285            dimensions,
286            backend: Backend::Local(backend),
287        })
288    }
289
290    pub async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
291        match &self.backend {
292            Backend::Gemini(b) => b.embed_batch(texts).await,
293            #[cfg(feature = "local-embeddings")]
294            Backend::Local(b) => b.embed_batch(texts),
295        }
296    }
297
298    pub async fn embed_single(&self, text: &str) -> Result<Vec<f32>> {
299        let results: Vec<Vec<f32>> = self.embed_batch(&[text.to_string()]).await?;
300        results.into_iter().next().context("No embedding returned")
301    }
302
303    pub fn dimensions(&self) -> usize {
304        self.dimensions
305    }
306
307    pub fn backend_name(&self) -> &str {
308        match &self.backend {
309            Backend::Gemini(_) => "gemini-api",
310            #[cfg(feature = "local-embeddings")]
311            Backend::Local(_) => "local-gguf",
312        }
313    }
314}
315
316// --- Text chunking ---
317
/// Chunk text into segments of approximately `max_tokens` tokens with `overlap` token overlap.
///
/// Token counts are estimated at 4 bytes per token. Every chunk is prefixed
/// with `"title: {title}\n\n"`. Empty `text` yields a single chunk whose body
/// is `(no description)`, and text that fits the budget yields one chunk.
///
/// Longer text is cut at the last paragraph break, line break, sentence end
/// (`". "`), or space inside the current window, in that order of preference,
/// falling back to a hard cut on a char boundary. The next chunk starts
/// `overlap` tokens before the cut when that still moves forward.
///
/// `max_tokens == 0` is treated as `1`: a zero-sized window could never
/// advance and the loop would not terminate.
pub fn chunk_text(title: &str, text: &str, max_tokens: usize, overlap: usize) -> Vec<String> {
    let prefix = format!("title: {}\n\n", title);

    if text.is_empty() {
        return vec![format!("{}(no description)", prefix)];
    }

    // At least 4 bytes per window: that is the longest UTF-8 encoding of a
    // char, so every window holds at least one whole character.
    // Saturating math keeps huge budgets from overflowing.
    let max_chars = max_tokens.max(1).saturating_mul(4);
    let overlap_chars = overlap.saturating_mul(4);

    if text.len() <= max_chars {
        return vec![format!("{}{}", prefix, text)];
    }

    // Snap a byte offset to the nearest char boundary (rounding down)
    let floor_char = |pos: usize| {
        let mut i = pos.min(text.len());
        while i > 0 && !text.is_char_boundary(i) {
            i -= 1;
        }
        i
    };

    let mut chunks = Vec::new();
    let mut start = 0;

    while start < text.len() {
        let end = floor_char(start.saturating_add(max_chars));

        let window = &text[start..end];
        let break_at = if end < text.len() {
            // Every separator starts with a one-byte ASCII char, so `p + 1`
            // is always a valid char boundary just past it.
            window
                .rfind("\n\n")
                .or_else(|| window.rfind('\n'))
                .or_else(|| window.rfind(". "))
                .or_else(|| window.rfind(' '))
                .map(|p| start + p + 1)
                .unwrap_or(end)
        } else {
            end
        };

        chunks.push(format!("{}{}", prefix, &text[start..break_at]));

        if break_at >= text.len() {
            break;
        }

        let overlapped = floor_char(break_at.saturating_sub(overlap_chars));
        // Ensure forward progress — overlap must never push start backwards
        start = if overlapped <= start {
            break_at
        } else {
            overlapped
        };
    }

    chunks
}
386
/// Serializes an embedding for storage: each `f32` becomes its 4
/// little-endian bytes, concatenated in order.
pub fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * std::mem::size_of::<f32>());
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
391
/// Deserializes stored bytes into an embedding, reading each group of 4
/// bytes as a little-endian `f32`. Trailing bytes that do not form a full
/// group are ignored.
pub fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut embedding = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        embedding.push(f32::from_le_bytes(raw));
    }
    embedding
}
399
/// Cosine similarity of two embeddings.
///
/// Returns `0.0` when the slices are empty, differ in length, or when the
/// product of their magnitudes is zero.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating the dot product and both squared magnitudes.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let magnitude = sq_a.sqrt() * sq_b.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    dot / magnitude
}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// A recognised "invalid key" message wins over the generic 400 status.
    #[test]
    fn compacts_invalid_api_key_errors() {
        let body = r#"{
            "error": {
                "code": 400,
                "message": "API key not valid. Please pass a valid API key.",
                "status": "INVALID_ARGUMENT"
            }
        }"#;

        let summary = summarize_gemini_error(StatusCode::BAD_REQUEST, body);

        assert_eq!(summary, "Invalid API key");
    }

    /// A 429 / `RESOURCE_EXHAUSTED` response is reported as rate limiting.
    #[test]
    fn compacts_rate_limit_errors() {
        let body = r#"{
            "error": {
                "code": 429,
                "message": "Quota exceeded.",
                "status": "RESOURCE_EXHAUSTED"
            }
        }"#;

        let summary = summarize_gemini_error(StatusCode::TOO_MANY_REQUESTS, body);

        assert_eq!(summary, "Rate limited");
    }

    /// A body that is not JSON falls back to classification by HTTP status.
    #[test]
    fn falls_back_to_status_for_unknown_errors() {
        let summary = summarize_gemini_error(StatusCode::SERVICE_UNAVAILABLE, "not-json");

        assert_eq!(summary, "Gemini unavailable");
    }
}