Skip to main content

sift/
embed.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::{bail, Context, Result};
4use serde::Deserialize;
5
/// Selects where embeddings are computed.
///
/// The variant is chosen from config files or the `SIFT_EMBED_BACKEND`
/// environment variable when an [`EmbedConfig`] is loaded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EmbedBackend {
    /// Prefer the local candle model; use the HTTP API if the local model
    /// cannot be loaded.
    Auto,
    /// Run inference locally with candle only.
    Local,
    /// Always use the HTTP embeddings API.
    Api,
}
16
17/// Configuration for the embedding system.
18#[derive(Debug, Clone)]
19pub struct EmbedConfig {
20    /// Which backend(s) to use.
21    pub backend: EmbedBackend,
22    /// Path to a local model directory (e.g. downloaded all-MiniLM-L6-v2).
23    pub model_path: Option<PathBuf>,
24    /// API key for the fallback embedding API.
25    pub api_key: Option<String>,
26    /// API endpoint URL (default: OpenAI-compatible).
27    pub api_url: Option<String>,
28    /// Model name for the API (e.g. "text-embedding-3-small").
29    pub api_model: Option<String>,
30}
31
32/// TOML config file shape (all fields optional for merge semantics).
33#[derive(Debug, Default, Deserialize)]
34struct TomlConfig {
35    #[serde(default)]
36    embed: TomlEmbed,
37}
38
39#[derive(Debug, Default, Deserialize)]
40struct TomlEmbed {
41    backend: Option<String>,
42    model_path: Option<String>,
43    api_key: Option<String>,
44    api_url: Option<String>,
45    api_model: Option<String>,
46}
47
48impl EmbedConfig {
49    /// Load config from TOML files + env vars.
50    ///
51    /// Precedence (later wins):
52    ///   1. Hardcoded defaults
53    ///   2. `~/.config/sift/config.toml`
54    ///   3. `.sift/config.toml` (project-level, relative to cwd)
55    ///   4. `SIFT_EMBED_*` environment variables
56    pub fn load() -> Self {
57        let mut config = Self::defaults();
58
59        if let Some(cfg) = Self::load_toml(&Self::user_config_path()) {
60            config.apply_toml(&cfg);
61        }
62        if let Some(cfg) = Self::load_toml(&Self::project_config_path()) {
63            config.apply_toml(&cfg);
64        }
65
66        config.apply_env();
67        config
68    }
69
70    /// Read config from env vars only (legacy / programmatic use).
71    pub fn from_env() -> Self {
72        let mut config = Self::defaults();
73        config.apply_env();
74        config
75    }
76
77    fn defaults() -> Self {
78        Self {
79            backend: EmbedBackend::Auto,
80            model_path: None,
81            api_key: None,
82            api_url: None,
83            api_model: Some("text-embedding-3-small".into()),
84        }
85    }
86
87    fn user_config_path() -> PathBuf {
88        let base = std::env::var("XDG_CONFIG_HOME")
89            .map(PathBuf::from)
90            .ok()
91            .unwrap_or_else(|| {
92                let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
93                PathBuf::from(home).join(".config")
94            });
95        base.join("sift").join("config.toml")
96    }
97
98    fn project_config_path() -> PathBuf {
99        std::env::current_dir()
100            .unwrap_or_else(|_| PathBuf::from("."))
101            .join(".sift")
102            .join("config.toml")
103    }
104
105    fn load_toml(path: &Path) -> Option<TomlConfig> {
106        if !path.exists() {
107            return None;
108        }
109        let content = std::fs::read_to_string(path).ok()?;
110        toml::from_str(&content).ok()
111    }
112
113    fn apply_toml(&mut self, cfg: &TomlConfig) {
114        if let Some(ref backend) = cfg.embed.backend {
115            self.backend = match backend.as_str() {
116                "local" => EmbedBackend::Local,
117                "api" => EmbedBackend::Api,
118                _ => EmbedBackend::Auto,
119            };
120        }
121        if let Some(ref v) = cfg.embed.model_path {
122            self.model_path = Some(v.into());
123        }
124        if let Some(ref v) = cfg.embed.api_key {
125            self.api_key = Some(v.clone());
126        }
127        if let Some(ref v) = cfg.embed.api_url {
128            self.api_url = Some(v.clone());
129        }
130        if let Some(ref v) = cfg.embed.api_model {
131            self.api_model = Some(v.clone());
132        }
133    }
134
135    fn apply_env(&mut self) {
136        if let Ok(v) = std::env::var("SIFT_EMBED_BACKEND") {
137            self.backend = match v.as_str() {
138                "local" => EmbedBackend::Local,
139                "api" => EmbedBackend::Api,
140                _ => EmbedBackend::Auto,
141            };
142        }
143        if let Ok(v) = std::env::var("SIFT_EMBED_MODEL_PATH") {
144            self.model_path = Some(v.into());
145        }
146        if let Ok(v) = std::env::var("SIFT_EMBED_API_KEY") {
147            self.api_key = Some(v);
148        }
149        if let Ok(v) = std::env::var("SIFT_EMBED_API_URL") {
150            self.api_url = Some(v);
151        }
152        if let Ok(v) = std::env::var("SIFT_EMBED_API_MODEL") {
153            self.api_model = Some(v);
154        }
155    }
156}
157
158// ---------------------------------------------------------------------------
159// Embedder trait
160// ---------------------------------------------------------------------------
161
/// A backend that turns text into dense embedding vectors.
pub trait Embedder {
    /// Embed each string in `texts`.
    ///
    /// The implementations in this file return one vector per input text, or
    /// an error if the backend fails.
    fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
}
165
166impl Embedder for AutoEmbedder {
167    fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
168        #[cfg(feature = "candle")]
169        if let Some(wrapper) = &self.local {
170            return wrapper.inner.embed_texts(texts);
171        }
172        self.api.embed(texts)
173    }
174}
175
176// ---------------------------------------------------------------------------
177// API embedder (always available)
178// ---------------------------------------------------------------------------
179
180pub struct ApiEmbedder {
181    api_key: String,
182    api_url: String,
183    model: String,
184    client: reqwest::blocking::Client,
185}
186
187impl ApiEmbedder {
188    pub fn new(config: &EmbedConfig) -> Result<Self> {
189        let api_key = config
190            .api_key
191            .clone()
192            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
193            .unwrap_or_default();
194        let api_url = config
195            .api_url
196            .clone()
197            .unwrap_or_else(|| "https://api.openai.com/v1/embeddings".into());
198        let model = config.api_model.clone().unwrap_or_else(|| "text-embedding-3-small".into());
199        let client = reqwest::blocking::Client::builder()
200            .timeout(std::time::Duration::from_secs(60))
201            .build()?;
202        Ok(Self { api_key, api_url, model, client })
203    }
204}
205
206impl Embedder for ApiEmbedder {
207    fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
208        #[derive(serde::Serialize)]
209        struct Request<'a> {
210            input: Vec<&'a str>,
211            model: &'a str,
212        }
213        #[derive(serde::Deserialize)]
214        struct Response {
215            data: Vec<Data>,
216        }
217        #[derive(serde::Deserialize)]
218        struct Data {
219            embedding: Vec<f32>,
220        }
221
222        let mut req = self.client.post(&self.api_url);
223        if !self.api_key.is_empty() {
224            req = req.header("Authorization", format!("Bearer {}", self.api_key));
225        }
226        let resp = req
227            .json(&Request { input: texts.to_vec(), model: &self.model })
228            .send()
229            .context("API embedding request failed")?;
230
231        if !resp.status().is_success() {
232            let status = resp.status();
233            let body = resp.text().unwrap_or_default();
234            bail!("API embedding error ({}): {}", status, body);
235        }
236
237        let body: Response = resp.json().context("Failed to parse API embedding response")?;
238        if body.data.len() != texts.len() {
239            bail!(
240                "API returned {} embeddings for {} texts",
241                body.data.len(),
242                texts.len()
243            );
244        }
245        Ok(body.data.into_iter().map(|d| d.embedding).collect())
246    }
247}
248
249// ---------------------------------------------------------------------------
250// Candle-based local embedder (feature-gated)
251// ---------------------------------------------------------------------------
252
#[cfg(feature = "candle")]
pub mod local {
    use std::path::Path;
    use anyhow::{Context, Result};
    use candle_core::{Device, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::bert::{BertModel, Config, DTYPE};
    use hf_hub::api::sync::Api;
    use tokenizers::Tokenizer;

    /// Hugging Face Hub repository downloaded when no local model directory is given.
    const DEFAULT_MODEL_REPO: &str = "sentence-transformers/all-MiniLM-L6-v2";

    /// Maximum number of token ids (including special tokens) fed to the model per text.
    const MAX_TOKENS: usize = 128;

    /// BERT sentence embedder running locally via candle.
    pub struct LocalEmbedder {
        model: BertModel,
        tokenizer: Tokenizer,
        device: Device,
    }

    impl LocalEmbedder {
        /// Load the model onto CUDA device 0 if available, else the CPU.
        ///
        /// With `Some(dir)`, reads `tokenizer.json`, `config.json` and
        /// `model.safetensors` from `dir`. With `None`, fetches those files for
        /// the default sentence-transformers model through hf-hub.
        pub fn new(model_path: Option<&Path>) -> Result<Self> {
            let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);

            let (model, tokenizer) = match model_path {
                Some(dir) => Self::load_files(
                    &dir.join("tokenizer.json"),
                    &dir.join("model.safetensors"),
                    &dir.join("config.json"),
                    &device,
                )?,
                None => {
                    let api = Api::new().context("Failed to init hf-hub API")?;
                    let repo = api.model(DEFAULT_MODEL_REPO.into());
                    let tokenizer_path = repo.get("tokenizer.json")?;
                    let weights_path = repo.get("model.safetensors")?;
                    let config_path = repo.get("config.json")?;
                    Self::load_files(&tokenizer_path, &weights_path, &config_path, &device)?
                }
            };

            Ok(Self { model, tokenizer, device })
        }

        /// Load tokenizer, BERT config and safetensors weights from explicit file paths.
        fn load_files(
            tokenizer_path: &Path,
            weights_path: &Path,
            config_path: &Path,
            device: &Device,
        ) -> Result<(BertModel, Tokenizer)> {
            let tokenizer = Tokenizer::from_file(tokenizer_path)
                .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
            let config_s = std::fs::read_to_string(config_path)
                .context("Failed to read config.json")?;
            let config: Config = serde_json::from_str(&config_s)
                .context("Failed to parse config.json")?;
            // SAFETY: the weights file is memory-mapped. This is sound only while
            // the file is not modified or truncated for the lifetime of the
            // model. We assume model files are not rewritten in place during a run.
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)
                    .context("Failed to load model.safetensors")?
            };
            let model = BertModel::load(vb, &config)?;
            Ok((model, tokenizer))
        }

        /// Masked mean over the sequence axis.
        ///
        /// Sums token embeddings weighted by the mask (a `(1, 1, seq)` x
        /// `(1, seq, hidden)` matmul) and divides by the number of unmasked
        /// tokens. Reads the count back as a scalar, so this assumes batch size 1.
        fn mean_pool(&self, token_embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
            let (_b, _s, h) = token_embeddings.shape().dims3()?;
            let mask = attention_mask.to_dtype(candle_core::DType::F32)?;
            let sum_emb = mask.unsqueeze(1)?.matmul(token_embeddings)?.squeeze(1)?;
            let count = mask.sum(1)?;
            let count_val = count.squeeze(0)?.to_vec0::<f32>()?;
            if count_val == 0.0 {
                return Ok(Tensor::zeros((1, h), candle_core::DType::F32, &self.device)?);
            }
            let result = (&sum_emb / &Tensor::full(count_val, (1, h), &self.device)?)?;
            Ok(result)
        }

        /// L2-normalize a `(1, hidden)` tensor; a zero vector is returned unchanged.
        fn normalize(&self, v: &Tensor) -> Result<Tensor> {
            let (_b, h) = v.shape().dims2()?;
            let sq_sum: f32 = v.sqr()?.sum(1)?.squeeze(0)?.to_vec0::<f32>()?;
            let norm = sq_sum.sqrt();
            if norm == 0.0 {
                return Ok(v.clone());
            }
            Ok((v / &Tensor::full(norm, (1, h), &self.device)?)?)
        }

        /// Embed each text independently (one forward pass per text).
        pub fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            texts.iter().map(|text| self.embed_one(text)).collect()
        }

        /// Tokenize, truncate to [`MAX_TOKENS`], run BERT, mean-pool and L2-normalize.
        fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
            let encoding = self
                .tokenizer
                .encode(text, true)
                .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

            let token_ids = truncate_keep_ends(encoding.get_ids(), MAX_TOKENS);
            let seq_len = token_ids.len();
            // Keep the mask aligned with the (possibly truncated) ids.
            let attention = encoding.get_attention_mask();
            let attention = &attention[..attention.len().min(seq_len)];

            let input = Tensor::new(token_ids.as_slice(), &self.device)?.unsqueeze(0)?;
            let mask = Tensor::new(attention, &self.device)?.unsqueeze(0)?;
            let type_ids = input.zeros_like()?;

            let output = self.model.forward(&input, &type_ids, Some(&mask))?;
            let pooled = self.mean_pool(&output, &mask)?;
            let normalized = self.normalize(&pooled)?;

            Ok(normalized.squeeze(0)?.to_vec1::<f32>()?)
        }
    }

    /// Truncate `ids` to at most `max_len` entries.
    ///
    /// Keeps the leading `max_len - 1` ids and the final id (typically the
    /// closing special token), dropping ids from the end of the middle.
    fn truncate_keep_ends(ids: &[u32], max_len: usize) -> Vec<u32> {
        if ids.len() <= max_len {
            return ids.to_vec();
        }
        if max_len == 0 {
            return Vec::new();
        }
        let mut out = Vec::with_capacity(max_len);
        out.extend_from_slice(&ids[..max_len - 1]);
        out.push(ids[ids.len() - 1]);
        out
    }
}
386
387// ---------------------------------------------------------------------------
388// Auto embedder: try local, fall back to API
389// ---------------------------------------------------------------------------
390
391#[allow(dead_code)]
392pub struct AutoEmbedder {
393    local: Option<LocalWrapper>,
394    api: ApiEmbedder,
395}
396
/// Feature-gated holder for the local model.
///
/// Has no fields when the `candle` feature is disabled.
struct LocalWrapper {
    /// The candle-backed embedder.
    #[cfg(feature = "candle")]
    inner: local::LocalEmbedder,
}
401
402impl AutoEmbedder {
403    pub fn new(config: &EmbedConfig) -> Result<Self> {
404        let local = match config.backend {
405            EmbedBackend::Api => None,
406            _ => Self::try_local(config),
407        };
408        let api = ApiEmbedder::new(config)?;
409        Ok(Self { local, api })
410    }
411
412    #[cfg(feature = "candle")]
413    fn try_local(config: &EmbedConfig) -> Option<LocalWrapper> {
414        match local::LocalEmbedder::new(config.model_path.as_deref()) {
415            Ok(inner) => {
416                eprintln!("[sift] using local embedding model (candle)");
417                Some(LocalWrapper { inner })
418            }
419            Err(e) => {
420                eprintln!("[sift] local model unavailable ({}), falling back to API", e);
421                None
422            }
423        }
424    }
425
426    #[cfg(not(feature = "candle"))]
427    fn try_local(_config: &EmbedConfig) -> Option<LocalWrapper> {
428        None
429    }
430
431}
432
433// ---------------------------------------------------------------------------
434// Cosine similarity
435// ---------------------------------------------------------------------------
436
/// Cosine similarity of two vectors, accumulated in `f64`.
///
/// Returns `0.0` when either vector has zero magnitude (including empty
/// slices) or when the lengths differ. Vectors of different dimensionality,
/// e.g. from different embedding models, are not comparable. Zipping them
/// would compare a prefix against full-length norms and yield a meaningless
/// score. NaN components propagate to a NaN result.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f64, 0.0f64, 0.0f64), |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });
    let (na, nb) = (sq_a.sqrt(), sq_b.sqrt());
    if na == 0.0 || nb == 0.0 {
        0.0
    } else {
        dot / (na * nb)
    }
}

/// Return up to `k` `(id, score)` pairs most similar to `query`, best first.
///
/// Scores are [`cosine_similarity`] values. NaN scores rank below every real
/// score. If `k` exceeds the number of candidates, all of them are returned;
/// `k == 0` returns an empty vector.
pub fn top_k_similar(query: &[f32], candidates: &[(usize, &[f32])], k: usize) -> Vec<(usize, f64)> {
    /// Descending by score with NaN last.
    ///
    /// Unlike `partial_cmp(..).unwrap_or(Equal)`, this is a total order, which
    /// the std sort/select routines require (newer std versions may panic on
    /// an inconsistent comparator).
    fn best_first(a: &(usize, f64), b: &(usize, f64)) -> std::cmp::Ordering {
        match (a.1.is_nan(), b.1.is_nan()) {
            (false, false) => b.1.total_cmp(&a.1),
            // NaN sorts after real numbers; two NaNs compare equal.
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        }
    }

    if k == 0 {
        return Vec::new();
    }
    let mut scores: Vec<(usize, f64)> = candidates
        .iter()
        .map(|&(id, vec)| (id, cosine_similarity(query, vec)))
        .collect();
    if k < scores.len() {
        // Move the k best to the front in O(n), then only those need sorting.
        scores.select_nth_unstable_by(k - 1, best_first);
        scores.truncate(k);
    }
    scores.sort_unstable_by(best_first);
    scores
}