1use std::path::{Path, PathBuf};
2
3use anyhow::{bail, Context, Result};
4use serde::Deserialize;
5
/// Which embedding backend to use. Selected by the `backend` key of the
/// `[embed]` config table or the `SIFT_EMBED_BACKEND` environment variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbedBackend {
    /// Try the local candle model first and use the HTTP API if it cannot be loaded.
    Auto,
    /// Request the local candle model.
    Local,
    /// Use the HTTP embeddings API; the local model is never loaded.
    Api,
}
16
17#[derive(Debug, Clone)]
19pub struct EmbedConfig {
20 pub backend: EmbedBackend,
22 pub model_path: Option<PathBuf>,
24 pub api_key: Option<String>,
26 pub api_url: Option<String>,
28 pub api_model: Option<String>,
30}
31
/// On-disk shape of a `config.toml` file. Only the `[embed]` table is used;
/// other tables are ignored by deserialization.
#[derive(Debug, Default, Deserialize)]
struct TomlConfig {
    // A file without an `[embed]` table yields an all-`None` `TomlEmbed`.
    #[serde(default)]
    embed: TomlEmbed,
}
38
/// Keys of the `[embed]` table. Every key is optional; a present key overrides
/// the corresponding `EmbedConfig` field.
#[derive(Debug, Default, Deserialize)]
struct TomlEmbed {
    /// Backend name (`"local"`, `"api"` or `"auto"`).
    backend: Option<String>,
    /// Directory containing the local model files.
    model_path: Option<String>,
    /// Bearer token for the embeddings API.
    api_key: Option<String>,
    /// Embeddings endpoint URL.
    api_url: Option<String>,
    /// Model name sent with API requests.
    api_model: Option<String>,
}
47
48impl EmbedConfig {
49 pub fn load() -> Self {
57 let mut config = Self::defaults();
58
59 if let Some(cfg) = Self::load_toml(&Self::user_config_path()) {
60 config.apply_toml(&cfg);
61 }
62 if let Some(cfg) = Self::load_toml(&Self::project_config_path()) {
63 config.apply_toml(&cfg);
64 }
65
66 config.apply_env();
67 config
68 }
69
70 pub fn from_env() -> Self {
72 let mut config = Self::defaults();
73 config.apply_env();
74 config
75 }
76
77 fn defaults() -> Self {
78 Self {
79 backend: EmbedBackend::Auto,
80 model_path: None,
81 api_key: None,
82 api_url: None,
83 api_model: Some("text-embedding-3-small".into()),
84 }
85 }
86
87 fn user_config_path() -> PathBuf {
88 let base = std::env::var("XDG_CONFIG_HOME")
89 .map(PathBuf::from)
90 .ok()
91 .unwrap_or_else(|| {
92 let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
93 PathBuf::from(home).join(".config")
94 });
95 base.join("sift").join("config.toml")
96 }
97
98 fn project_config_path() -> PathBuf {
99 std::env::current_dir()
100 .unwrap_or_else(|_| PathBuf::from("."))
101 .join(".sift")
102 .join("config.toml")
103 }
104
105 fn load_toml(path: &Path) -> Option<TomlConfig> {
106 if !path.exists() {
107 return None;
108 }
109 let content = std::fs::read_to_string(path).ok()?;
110 toml::from_str(&content).ok()
111 }
112
113 fn apply_toml(&mut self, cfg: &TomlConfig) {
114 if let Some(ref backend) = cfg.embed.backend {
115 self.backend = match backend.as_str() {
116 "local" => EmbedBackend::Local,
117 "api" => EmbedBackend::Api,
118 _ => EmbedBackend::Auto,
119 };
120 }
121 if let Some(ref v) = cfg.embed.model_path {
122 self.model_path = Some(v.into());
123 }
124 if let Some(ref v) = cfg.embed.api_key {
125 self.api_key = Some(v.clone());
126 }
127 if let Some(ref v) = cfg.embed.api_url {
128 self.api_url = Some(v.clone());
129 }
130 if let Some(ref v) = cfg.embed.api_model {
131 self.api_model = Some(v.clone());
132 }
133 }
134
135 fn apply_env(&mut self) {
136 if let Ok(v) = std::env::var("SIFT_EMBED_BACKEND") {
137 self.backend = match v.as_str() {
138 "local" => EmbedBackend::Local,
139 "api" => EmbedBackend::Api,
140 _ => EmbedBackend::Auto,
141 };
142 }
143 if let Ok(v) = std::env::var("SIFT_EMBED_MODEL_PATH") {
144 self.model_path = Some(v.into());
145 }
146 if let Ok(v) = std::env::var("SIFT_EMBED_API_KEY") {
147 self.api_key = Some(v);
148 }
149 if let Ok(v) = std::env::var("SIFT_EMBED_API_URL") {
150 self.api_url = Some(v);
151 }
152 if let Ok(v) = std::env::var("SIFT_EMBED_API_MODEL") {
153 self.api_model = Some(v);
154 }
155 }
156}
157
/// A source of text embeddings.
pub trait Embedder {
    /// Embeds every string in `texts`. The implementations in this module return
    /// one vector per input, or an error.
    fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
}
165
166impl Embedder for AutoEmbedder {
167 fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
168 #[cfg(feature = "candle")]
169 if let Some(wrapper) = &self.local {
170 return wrapper.inner.embed_texts(texts);
171 }
172 self.api.embed(texts)
173 }
174}
175
/// Embedder backed by an OpenAI-style HTTP embeddings endpoint.
pub struct ApiEmbedder {
    /// Bearer token; when empty, no `Authorization` header is sent.
    api_key: String,
    /// URL that embedding requests are POSTed to.
    api_url: String,
    /// Model name included in each request body.
    model: String,
    /// Blocking HTTP client reused across requests.
    client: reqwest::blocking::Client,
}
186
187impl ApiEmbedder {
188 pub fn new(config: &EmbedConfig) -> Result<Self> {
189 let api_key = config
190 .api_key
191 .clone()
192 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
193 .unwrap_or_default();
194 let api_url = config
195 .api_url
196 .clone()
197 .unwrap_or_else(|| "https://api.openai.com/v1/embeddings".into());
198 let model = config.api_model.clone().unwrap_or_else(|| "text-embedding-3-small".into());
199 let client = reqwest::blocking::Client::builder()
200 .timeout(std::time::Duration::from_secs(60))
201 .build()?;
202 Ok(Self { api_key, api_url, model, client })
203 }
204}
205
206impl Embedder for ApiEmbedder {
207 fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
208 #[derive(serde::Serialize)]
209 struct Request<'a> {
210 input: Vec<&'a str>,
211 model: &'a str,
212 }
213 #[derive(serde::Deserialize)]
214 struct Response {
215 data: Vec<Data>,
216 }
217 #[derive(serde::Deserialize)]
218 struct Data {
219 embedding: Vec<f32>,
220 }
221
222 let mut req = self.client.post(&self.api_url);
223 if !self.api_key.is_empty() {
224 req = req.header("Authorization", format!("Bearer {}", self.api_key));
225 }
226 let resp = req
227 .json(&Request { input: texts.to_vec(), model: &self.model })
228 .send()
229 .context("API embedding request failed")?;
230
231 if !resp.status().is_success() {
232 let status = resp.status();
233 let body = resp.text().unwrap_or_default();
234 bail!("API embedding error ({}): {}", status, body);
235 }
236
237 let body: Response = resp.json().context("Failed to parse API embedding response")?;
238 if body.data.len() != texts.len() {
239 bail!(
240 "API returned {} embeddings for {} texts",
241 body.data.len(),
242 texts.len()
243 );
244 }
245 Ok(body.data.into_iter().map(|d| d.embedding).collect())
246 }
247}
248
#[cfg(feature = "candle")]
pub mod local {
    //! Local embedding backend: a BERT-family sentence-embedding model run with
    //! candle, mean-pooled over tokens and L2-normalized.

    use std::path::Path;
    use anyhow::{Context, Result};
    use candle_core::{Device, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::bert::{BertModel, Config, DTYPE};
    use hf_hub::api::sync::Api;
    use tokenizers::Tokenizer;

    /// Hugging Face Hub repository used when no model directory is configured.
    const DEFAULT_MODEL_REPO: &str = "sentence-transformers/all-MiniLM-L6-v2";

    /// Maximum number of token ids fed to the model for one text. Longer inputs
    /// keep their first `MAX_TOKENS - 1` ids plus their final id.
    const MAX_TOKENS: usize = 128;

    /// Sentence embedder running a BERT model locally, on CUDA device 0 when
    /// available and otherwise on the CPU.
    pub struct LocalEmbedder {
        model: BertModel,
        tokenizer: Tokenizer,
        device: Device,
    }

    impl LocalEmbedder {
        /// Loads the tokenizer, model config and weights.
        ///
        /// With `Some(dir)`, reads `tokenizer.json`, `config.json` and
        /// `model.safetensors` from `dir`. With `None`, obtains the same files
        /// from [`DEFAULT_MODEL_REPO`] through `hf_hub`, which may download them.
        ///
        /// # Errors
        /// Fails if any file cannot be obtained, read or parsed, or if the model
        /// cannot be built from them.
        pub fn new(model_path: Option<&Path>) -> Result<Self> {
            let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);

            let (model, tokenizer) = match model_path {
                Some(dir) => Self::load_from_files(
                    &dir.join("tokenizer.json"),
                    &dir.join("model.safetensors"),
                    &dir.join("config.json"),
                    &device,
                )?,
                None => {
                    let api = Api::new().context("Failed to init hf-hub API")?;
                    let repo = api.model(DEFAULT_MODEL_REPO.into());
                    let fetch = |file: &str| {
                        repo.get(file).with_context(|| {
                            format!("Failed to fetch {} from {}", file, DEFAULT_MODEL_REPO)
                        })
                    };
                    let tokenizer_path = fetch("tokenizer.json")?;
                    let weights_path = fetch("model.safetensors")?;
                    let config_path = fetch("config.json")?;
                    Self::load_from_files(&tokenizer_path, &weights_path, &config_path, &device)?
                }
            };

            Ok(Self { model, tokenizer, device })
        }

        /// Builds the tokenizer and BERT model from files already on disk.
        /// This is shared by the local-directory and hub-download paths of `new`.
        fn load_from_files(
            tokenizer_path: &Path,
            weights_path: &Path,
            config_path: &Path,
            device: &Device,
        ) -> Result<(BertModel, Tokenizer)> {
            let tokenizer = Tokenizer::from_file(tokenizer_path)
                .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
            let config_s =
                std::fs::read_to_string(config_path).context("Failed to read config.json")?;
            let config: Config =
                serde_json::from_str(&config_s).context("Failed to parse config.json")?;
            // SAFETY: the weights file is memory-mapped. This is sound only if no
            // other process modifies or truncates the file while the mapping is
            // alive; model files are assumed to be read-only for the process lifetime.
            let vb = unsafe {
                VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)
                    .context("Failed to load model.safetensors")?
            };
            let model = BertModel::load(vb, &config)?;
            Ok((model, tokenizer))
        }

        /// Averages token embeddings over positions whose mask value is 1.
        ///
        /// Returns a `(1, hidden)` tensor, or zeros if the mask is all zero. The
        /// scalar extraction of the token count assumes a batch size of 1, which is
        /// what `embed_texts` passes.
        fn mean_pool(&self, token_embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
            let (_b, _s, h) = token_embeddings.shape().dims3()?;
            let mask = attention_mask.to_dtype(candle_core::DType::F32)?;
            // (1, 1, seq) x (1, seq, hidden) -> (1, 1, hidden): masked sum over tokens.
            let sum_emb = mask.unsqueeze(1)?.matmul(token_embeddings)?.squeeze(1)?;
            let count = mask.sum(1)?;
            let count_val = count.squeeze(0)?.to_vec0::<f32>()?;
            if count_val == 0.0 {
                return Ok(Tensor::zeros((1, h), candle_core::DType::F32, &self.device)?);
            }
            let result = (&sum_emb / &Tensor::full(count_val, (1, h), &self.device)?)?;
            Ok(result)
        }

        /// Scales a `(1, hidden)` vector to unit L2 norm.
        /// A zero vector is returned unchanged.
        fn normalize(&self, v: &Tensor) -> Result<Tensor> {
            let (_b, h) = v.shape().dims2()?;
            let sq_sum: f32 = v.sqr()?.sum(1)?.squeeze(0)?.to_vec0::<f32>()?;
            let norm = sq_sum.sqrt();
            if norm == 0.0 {
                return Ok(v.clone());
            }
            Ok((v / &Tensor::full(norm, (1, h), &self.device)?)?)
        }

        /// Embeds each text independently. Returns one mean-pooled, L2-normalized
        /// vector per input, in input order.
        ///
        /// # Errors
        /// Fails if tokenization or any tensor operation fails.
        pub fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            let mut all_embeddings = Vec::with_capacity(texts.len());

            for text in texts {
                let encoding = self
                    .tokenizer
                    .encode(*text, true)
                    .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
                let ids = encoding.get_ids();
                let attention = encoding.get_attention_mask();

                // Over-long inputs keep the head and the final id. With special tokens
                // enabled, that final id is presumably the closing special token.
                let token_ids: Vec<u32> = if ids.len() > MAX_TOKENS {
                    let mut t = Vec::with_capacity(MAX_TOKENS);
                    t.extend_from_slice(&ids[..MAX_TOKENS - 1]);
                    t.push(ids[ids.len() - 1]);
                    t
                } else {
                    ids.to_vec()
                };
                let seq_len = token_ids.len();
                let mask_ids = &attention[..seq_len.min(attention.len())];

                let input = Tensor::new(token_ids.as_slice(), &self.device)?.unsqueeze(0)?;
                let mask = Tensor::new(mask_ids, &self.device)?.unsqueeze(0)?;
                let type_ids = input.zeros_like()?;

                let output = self.model.forward(&input, &type_ids, Some(&mask))?;
                let pooled = self.mean_pool(&output, &mask)?;
                let normalized = self.normalize(&pooled)?;

                let vec: Vec<f32> = normalized.squeeze(0)?.to_vec1()?;
                all_embeddings.push(vec);
            }

            Ok(all_embeddings)
        }
    }
}
386
/// Embedder that uses the local model when one was loaded and the HTTP API otherwise.
// Without the `candle` feature, `local` is never read: the only read, in
// `Embedder::embed`, is cfg-gated. Hence the dead_code allow.
#[allow(dead_code)]
pub struct AutoEmbedder {
    /// Loaded local model, if any.
    local: Option<LocalWrapper>,
    /// API embedder, used whenever no local model is in use.
    api: ApiEmbedder,
}
396
/// Holder for the local embedder. Without the `candle` feature it has no fields,
/// which lets `AutoEmbedder` keep an `Option<LocalWrapper>` in every build.
struct LocalWrapper {
    #[cfg(feature = "candle")]
    inner: local::LocalEmbedder,
}
401
402impl AutoEmbedder {
403 pub fn new(config: &EmbedConfig) -> Result<Self> {
404 let local = match config.backend {
405 EmbedBackend::Api => None,
406 _ => Self::try_local(config),
407 };
408 let api = ApiEmbedder::new(config)?;
409 Ok(Self { local, api })
410 }
411
412 #[cfg(feature = "candle")]
413 fn try_local(config: &EmbedConfig) -> Option<LocalWrapper> {
414 match local::LocalEmbedder::new(config.model_path.as_deref()) {
415 Ok(inner) => {
416 eprintln!("[sift] using local embedding model (candle)");
417 Some(LocalWrapper { inner })
418 }
419 Err(e) => {
420 eprintln!("[sift] local model unavailable ({}), falling back to API", e);
421 None
422 }
423 }
424 }
425
426 #[cfg(not(feature = "candle"))]
427 fn try_local(_config: &EmbedConfig) -> Option<LocalWrapper> {
428 None
429 }
430
431}
432
/// Cosine similarity of `a` and `b`, accumulated in `f64` for precision.
///
/// Returns `0.0` when no meaningful angle exists:
/// - either vector has zero magnitude, or
/// - the lengths differ (e.g. embeddings from models with different
///   dimensions), where zipping would silently compare only a prefix.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f64, 0.0f64, 0.0f64), |(dot, sa, sb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sa + x * x, sb + y * y)
        });
    let denom = sq_a.sqrt() * sq_b.sqrt();
    // Also covers underflow of the product of two tiny but non-zero norms.
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
447
448pub fn top_k_similar(query: &[f32], candidates: &[(usize, &[f32])], k: usize) -> Vec<(usize, f64)> {
449 let mut scores: Vec<(usize, f64)> = candidates
450 .iter()
451 .map(|(id, vec)| (*id, cosine_similarity(query, vec)))
452 .collect();
453 scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
454 scores.truncate(k);
455 scores
456}