Skip to main content

sift/
embed.rs

1use std::path::PathBuf;
2
3use anyhow::{bail, Context, Result};
4
/// Selects where text embeddings are computed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EmbedBackend {
    /// Prefer the local candle model; use the HTTP embedding API when the
    /// local model cannot be loaded.
    Auto,
    /// Run inference with the local candle model only.
    Local,
    /// Use the HTTP embedding API only; no local model is loaded.
    Api,
}
15
16/// Configuration for the embedding system.
17#[derive(Debug, Clone)]
18pub struct EmbedConfig {
19    /// Which backend(s) to use.
20    pub backend: EmbedBackend,
21    /// Path to a local model directory (e.g. downloaded all-MiniLM-L6-v2).
22    pub model_path: Option<PathBuf>,
23    /// API key for the fallback embedding API.
24    pub api_key: Option<String>,
25    /// API endpoint URL (default: OpenAI-compatible).
26    pub api_url: Option<String>,
27    /// Model name for the API (e.g. "text-embedding-3-small").
28    pub api_model: Option<String>,
29}
30
31impl EmbedConfig {
32    pub fn from_env() -> Self {
33        Self {
34            backend: match std::env::var("SIFT_EMBED_BACKEND")
35                .as_deref()
36                .unwrap_or("auto")
37            {
38                "local" => EmbedBackend::Local,
39                "api" => EmbedBackend::Api,
40                _ => EmbedBackend::Auto,
41            },
42            model_path: std::env::var("SIFT_EMBED_MODEL_PATH").ok().map(Into::into),
43            api_key: std::env::var("SIFT_EMBED_API_KEY").ok(),
44            api_url: std::env::var("SIFT_EMBED_API_URL").ok(),
45            api_model: Some(
46                std::env::var("SIFT_EMBED_API_MODEL")
47                    .unwrap_or_else(|_| "text-embedding-3-small".into()),
48            ),
49        }
50    }
51}
52
53// ---------------------------------------------------------------------------
54// Embedder trait
55// ---------------------------------------------------------------------------
56
57pub trait Embedder {
58    fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
59}
60
61impl Embedder for AutoEmbedder {
62    fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
63        #[cfg(feature = "candle")]
64        if let Some(wrapper) = &self.local {
65            return wrapper.inner.embed_texts(texts);
66        }
67        self.api.embed(texts)
68    }
69}
70
71// ---------------------------------------------------------------------------
72// API embedder (always available)
73// ---------------------------------------------------------------------------
74
75pub struct ApiEmbedder {
76    api_key: String,
77    api_url: String,
78    model: String,
79    client: reqwest::blocking::Client,
80}
81
82impl ApiEmbedder {
83    pub fn new(config: &EmbedConfig) -> Result<Self> {
84        let api_key = config
85            .api_key
86            .clone()
87            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
88            .unwrap_or_default();
89        let api_url = config
90            .api_url
91            .clone()
92            .unwrap_or_else(|| "https://api.openai.com/v1/embeddings".into());
93        let model = config.api_model.clone().unwrap_or_else(|| "text-embedding-3-small".into());
94        let client = reqwest::blocking::Client::builder()
95            .timeout(std::time::Duration::from_secs(60))
96            .build()?;
97        Ok(Self { api_key, api_url, model, client })
98    }
99}
100
101impl Embedder for ApiEmbedder {
102    fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
103        #[derive(serde::Serialize)]
104        struct Request<'a> {
105            input: Vec<&'a str>,
106            model: &'a str,
107        }
108        #[derive(serde::Deserialize)]
109        struct Response {
110            data: Vec<Data>,
111        }
112        #[derive(serde::Deserialize)]
113        struct Data {
114            embedding: Vec<f32>,
115        }
116
117        let mut req = self.client.post(&self.api_url);
118        if !self.api_key.is_empty() {
119            req = req.header("Authorization", format!("Bearer {}", self.api_key));
120        }
121        let resp = req
122            .json(&Request { input: texts.to_vec(), model: &self.model })
123            .send()
124            .context("API embedding request failed")?;
125
126        if !resp.status().is_success() {
127            let status = resp.status();
128            let body = resp.text().unwrap_or_default();
129            bail!("API embedding error ({}): {}", status, body);
130        }
131
132        let body: Response = resp.json().context("Failed to parse API embedding response")?;
133        if body.data.len() != texts.len() {
134            bail!(
135                "API returned {} embeddings for {} texts",
136                body.data.len(),
137                texts.len()
138            );
139        }
140        Ok(body.data.into_iter().map(|d| d.embedding).collect())
141    }
142}
143
144// ---------------------------------------------------------------------------
145// Candle-based local embedder (feature-gated)
146// ---------------------------------------------------------------------------
147
#[cfg(feature = "candle")]
pub mod local {
    //! Local sentence-embedding inference with candle: BERT forward pass, then
    //! attention-masked mean pooling and L2 normalization.

    use std::path::{Path, PathBuf};

    use anyhow::{Context, Result};
    use candle_core::{Device, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::bert::{BertModel, Config, DTYPE};
    use hf_hub::api::sync::Api;
    use tokenizers::Tokenizer;

    /// Hub repository fetched when no local model directory is supplied.
    const DEFAULT_MODEL_REPO: &str = "sentence-transformers/all-MiniLM-L6-v2";

    /// Maximum number of token ids (special tokens included) fed to the model
    /// per text. Longer inputs are truncated.
    const MAX_LENGTH: usize = 128;

    /// BERT-style embedder. For each text it tokenizes, runs the model,
    /// mean-pools the token embeddings under the attention mask, and
    /// L2-normalizes the result.
    pub struct LocalEmbedder {
        model: BertModel,
        tokenizer: Tokenizer,
        device: Device,
    }

    /// Locations of the three files a BERT embedding model needs.
    struct ModelFiles {
        tokenizer: PathBuf,
        weights: PathBuf,
        config: PathBuf,
    }

    impl ModelFiles {
        /// Expects `tokenizer.json`, `model.safetensors` and `config.json` in `dir`.
        fn in_dir(dir: &Path) -> Self {
            Self {
                tokenizer: dir.join("tokenizer.json"),
                weights: dir.join("model.safetensors"),
                config: dir.join("config.json"),
            }
        }

        /// Fetches the same three files from a Hub model repository via hf-hub.
        fn download(repo_id: &str) -> Result<Self> {
            let api = Api::new().context("Failed to init hf-hub API")?;
            let repo = api.model(repo_id.to_string());
            // Field initializers run in the order written: tokenizer, weights, config.
            Ok(Self {
                tokenizer: repo.get("tokenizer.json")?,
                weights: repo.get("model.safetensors")?,
                config: repo.get("config.json")?,
            })
        }

        /// Loads the tokenizer, parses the config and memory-maps the weights.
        fn load(&self, device: &Device) -> Result<(BertModel, Tokenizer)> {
            let tokenizer = Tokenizer::from_file(&self.tokenizer)
                .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
            let config_s =
                std::fs::read_to_string(&self.config).context("Failed to read config.json")?;
            let config: Config =
                serde_json::from_str(&config_s).context("Failed to parse config.json")?;
            // SAFETY: the safetensors file is memory-mapped, so it must not be
            // modified or truncated while the mapping is alive. We assume model
            // files are not rewritten while sift runs.
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[&self.weights], DTYPE, device)
                    .context("Failed to load model.safetensors")?
            };
            let model = BertModel::load(vb, &config)?;
            Ok((model, tokenizer))
        }
    }

    /// Caps `items` at `max_len` entries by dropping tokens from the end while
    /// keeping the final entry (presumably the closing special token).
    ///
    /// Inputs that already fit are returned unchanged. `max_len` must be at
    /// least 1 whenever truncation is needed.
    fn truncate_keep_last(items: &[u32], max_len: usize) -> Vec<u32> {
        match items.split_last() {
            Some((last, rest)) if items.len() > max_len => {
                let mut out = Vec::with_capacity(max_len);
                out.extend_from_slice(&rest[..max_len - 1]);
                out.push(*last);
                out
            }
            _ => items.to_vec(),
        }
    }

    impl LocalEmbedder {
        /// Loads the tokenizer, config and weights.
        ///
        /// With `Some(dir)`, the files are read from `dir`. With `None`, they are
        /// fetched from the `sentence-transformers/all-MiniLM-L6-v2` Hub
        /// repository. Uses CUDA device 0 when available, otherwise the CPU.
        ///
        /// # Errors
        /// Fails if any file cannot be fetched, read or parsed, or if the
        /// weights cannot be loaded into the configured model.
        pub fn new(model_path: Option<&Path>) -> Result<Self> {
            let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
            let files = match model_path {
                Some(dir) => ModelFiles::in_dir(dir),
                None => ModelFiles::download(DEFAULT_MODEL_REPO)?,
            };
            let (model, tokenizer) = files.load(&device)?;
            Ok(Self { model, tokenizer, device })
        }

        /// Averages `token_embeddings` (batch, seq, hidden) over the sequence
        /// axis, counting only positions whose `attention_mask` (batch, seq) is 1.
        fn mean_pool(&self, token_embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
            // (batch, seq) -> (batch, seq, 1) so it broadcasts over `hidden`.
            let mask = attention_mask
                .unsqueeze(2)?
                .to_dtype(token_embeddings.dtype())?;
            // candle's `*` and `/` require identical shapes; the broadcast_*
            // variants are needed for (b, s, h) x (b, s, 1) and (b, h) / (b, 1).
            let summed = token_embeddings.broadcast_mul(&mask)?.sum(1)?;
            let counts = mask.sum(1)?;
            Ok(summed.broadcast_div(&counts)?)
        }

        /// Scales each row of `v` (batch, hidden) to unit L2 norm.
        fn normalize(&self, v: &Tensor) -> Result<Tensor> {
            let norm = v.sqr()?.sum_keepdim(1)?.sqrt()?;
            Ok(v.broadcast_div(&norm)?)
        }

        /// Embeds one text: tokenize, truncate, forward pass, pool, normalize.
        fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
            let encoding = self
                .tokenizer
                .encode(text, true)
                .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

            // Truncate ids and mask with the same rule so positions stay aligned.
            let ids = truncate_keep_last(encoding.get_ids(), MAX_LENGTH);
            let attention = truncate_keep_last(encoding.get_attention_mask(), MAX_LENGTH);

            let input = Tensor::new(ids.as_slice(), &self.device)?.unsqueeze(0)?;
            let mask = Tensor::new(attention.as_slice(), &self.device)?.unsqueeze(0)?;
            let type_ids = input.zeros_like()?;

            let output = self.model.forward(&input, &type_ids, Some(&mask))?;
            let pooled = self.mean_pool(&output, &mask)?;
            let normalized = self.normalize(&pooled)?;
            // `normalized` is (1, hidden); drop the batch axis because `to_vec1`
            // only accepts rank-1 tensors.
            Ok(normalized.squeeze(0)?.to_vec1::<f32>()?)
        }

        /// Embeds each text independently (batch size 1) and returns the
        /// normalized vectors in input order.
        ///
        /// # Errors
        /// Stops at the first text that fails to tokenize or run through the model.
        pub fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            texts.iter().map(|text| self.embed_one(text)).collect()
        }
    }
}
273
274// ---------------------------------------------------------------------------
275// Auto embedder: try local, fall back to API
276// ---------------------------------------------------------------------------
277
278#[allow(dead_code)]
279pub struct AutoEmbedder {
280    local: Option<LocalWrapper>,
281    api: ApiEmbedder,
282}
283
/// Holder for the optional local model.
///
/// Without the `candle` feature this struct has no fields and carries no model.
struct LocalWrapper {
    /// The loaded candle-based embedder.
    #[cfg(feature = "candle")]
    inner: local::LocalEmbedder,
}
288
289impl AutoEmbedder {
290    pub fn new(config: &EmbedConfig) -> Result<Self> {
291        let local = match config.backend {
292            EmbedBackend::Api => None,
293            _ => Self::try_local(config),
294        };
295        let api = ApiEmbedder::new(config)?;
296        Ok(Self { local, api })
297    }
298
299    #[cfg(feature = "candle")]
300    fn try_local(config: &EmbedConfig) -> Option<LocalWrapper> {
301        match local::LocalEmbedder::new(config.model_path.as_deref()) {
302            Ok(inner) => {
303                eprintln!("[sift] using local embedding model (candle)");
304                Some(LocalWrapper { inner })
305            }
306            Err(e) => {
307                eprintln!("[sift] local model unavailable ({}), falling back to API", e);
308                None
309            }
310        }
311    }
312
313    #[cfg(not(feature = "candle"))]
314    fn try_local(_config: &EmbedConfig) -> Option<LocalWrapper> {
315        None
316    }
317
318}
319
320// ---------------------------------------------------------------------------
321// Cosine similarity
322// ---------------------------------------------------------------------------
323
/// Cosine similarity of two embedding vectors, accumulated in `f64`.
///
/// For well-formed input the result lies (up to rounding) in `[-1.0, 1.0]`.
/// Returns `0.0` in these cases:
/// - the vectors have different lengths, e.g. embeddings produced by different
///   models, whose dimensions cannot be meaningfully compared;
/// - either vector is empty;
/// - either vector has zero norm.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    // Zipping unequal lengths would compare only a prefix while the norms span
    // the full vectors, producing a meaningless score.
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    let (na, nb) = (norm_a_sq.sqrt(), norm_b_sq.sqrt());
    if na == 0.0 || nb == 0.0 {
        0.0
    } else {
        dot / (na * nb)
    }
}
334
335pub fn top_k_similar(query: &[f32], candidates: &[(usize, &[f32])], k: usize) -> Vec<(usize, f64)> {
336    let mut scores: Vec<(usize, f64)> = candidates
337        .iter()
338        .map(|(id, vec)| (*id, cosine_similarity(query, vec)))
339        .collect();
340    scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
341    scores.truncate(k);
342    scores
343}