1use std::path::PathBuf;
2
3use anyhow::{bail, Context, Result};
4
/// Which embedding implementation the caller asked for.
///
/// `AutoEmbedder::new` attempts to load the local model for every variant
/// except [`EmbedBackend::Api`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EmbedBackend {
    /// Try the local model and use the HTTP API when it is unavailable.
    /// Also the value `EmbedConfig::from_env` picks for unrecognised input.
    Auto,
    /// Requested via `SIFT_EMBED_BACKEND=local`.
    Local,
    /// Requested via `SIFT_EMBED_BACKEND=api`; the local model is never loaded.
    Api,
}
15
16#[derive(Debug, Clone)]
18pub struct EmbedConfig {
19 pub backend: EmbedBackend,
21 pub model_path: Option<PathBuf>,
23 pub api_key: Option<String>,
25 pub api_url: Option<String>,
27 pub api_model: Option<String>,
29}
30
31impl EmbedConfig {
32 pub fn from_env() -> Self {
33 Self {
34 backend: match std::env::var("SIFT_EMBED_BACKEND")
35 .as_deref()
36 .unwrap_or("auto")
37 {
38 "local" => EmbedBackend::Local,
39 "api" => EmbedBackend::Api,
40 _ => EmbedBackend::Auto,
41 },
42 model_path: std::env::var("SIFT_EMBED_MODEL_PATH").ok().map(Into::into),
43 api_key: std::env::var("SIFT_EMBED_API_KEY").ok(),
44 api_url: std::env::var("SIFT_EMBED_API_URL").ok(),
45 api_model: Some(
46 std::env::var("SIFT_EMBED_API_MODEL")
47 .unwrap_or_else(|_| "text-embedding-3-small".into()),
48 ),
49 }
50 }
51}
52
53pub trait Embedder {
58 fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
59}
60
61impl Embedder for AutoEmbedder {
62 fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
63 #[cfg(feature = "candle")]
64 if let Some(wrapper) = &self.local {
65 return wrapper.inner.embed_texts(texts);
66 }
67 self.api.embed(texts)
68 }
69}
70
71pub struct ApiEmbedder {
76 api_key: String,
77 api_url: String,
78 model: String,
79 client: reqwest::blocking::Client,
80}
81
82impl ApiEmbedder {
83 pub fn new(config: &EmbedConfig) -> Result<Self> {
84 let api_key = config
85 .api_key
86 .clone()
87 .or_else(|| std::env::var("OPENAI_API_KEY").ok())
88 .unwrap_or_default();
89 let api_url = config
90 .api_url
91 .clone()
92 .unwrap_or_else(|| "https://api.openai.com/v1/embeddings".into());
93 let model = config.api_model.clone().unwrap_or_else(|| "text-embedding-3-small".into());
94 let client = reqwest::blocking::Client::builder()
95 .timeout(std::time::Duration::from_secs(60))
96 .build()?;
97 Ok(Self { api_key, api_url, model, client })
98 }
99}
100
101impl Embedder for ApiEmbedder {
102 fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
103 #[derive(serde::Serialize)]
104 struct Request<'a> {
105 input: Vec<&'a str>,
106 model: &'a str,
107 }
108 #[derive(serde::Deserialize)]
109 struct Response {
110 data: Vec<Data>,
111 }
112 #[derive(serde::Deserialize)]
113 struct Data {
114 embedding: Vec<f32>,
115 }
116
117 let mut req = self.client.post(&self.api_url);
118 if !self.api_key.is_empty() {
119 req = req.header("Authorization", format!("Bearer {}", self.api_key));
120 }
121 let resp = req
122 .json(&Request { input: texts.to_vec(), model: &self.model })
123 .send()
124 .context("API embedding request failed")?;
125
126 if !resp.status().is_success() {
127 let status = resp.status();
128 let body = resp.text().unwrap_or_default();
129 bail!("API embedding error ({}): {}", status, body);
130 }
131
132 let body: Response = resp.json().context("Failed to parse API embedding response")?;
133 if body.data.len() != texts.len() {
134 bail!(
135 "API returned {} embeddings for {} texts",
136 body.data.len(),
137 texts.len()
138 );
139 }
140 Ok(body.data.into_iter().map(|d| d.embedding).collect())
141 }
142}
143
/// Local embedding backend built on candle; compiled only with the `candle`
/// feature.
#[cfg(feature = "candle")]
pub mod local {
    use std::path::Path;

    use anyhow::{Context, Result};
    use candle_core::{Device, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::bert::{BertModel, Config, DTYPE};
    use hf_hub::api::sync::Api;
    use tokenizers::Tokenizer;

    /// Hub repository used when no local model directory is supplied.
    const DEFAULT_REPO: &str = "sentence-transformers/all-MiniLM-L6-v2";

    /// Upper bound on the number of token ids fed to the model per text.
    const MAX_TOKENS: usize = 128;

    /// BERT sentence embedder that runs on the local machine.
    pub struct LocalEmbedder {
        model: BertModel,
        tokenizer: Tokenizer,
        device: Device,
    }

    impl LocalEmbedder {
        /// Loads the model and tokenizer.
        ///
        /// With `Some(dir)`, reads `tokenizer.json`, `model.safetensors` and
        /// `config.json` from `dir`. With `None`, obtains the same three files
        /// for `sentence-transformers/all-MiniLM-L6-v2` through hf-hub.
        ///
        /// CUDA device 0 is used when candle reports it available, otherwise
        /// the CPU.
        ///
        /// # Errors
        ///
        /// Fails when a file cannot be fetched, read or parsed, or when the
        /// weights cannot be loaded into the model.
        pub fn new(model_path: Option<&Path>) -> Result<Self> {
            let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);

            let (model, tokenizer) = match model_path {
                Some(dir) => Self::load(
                    &dir.join("tokenizer.json"),
                    &dir.join("model.safetensors"),
                    &dir.join("config.json"),
                    &device,
                )?,
                None => {
                    let api = Api::new().context("Failed to init hf-hub API")?;
                    let repo = api.model(DEFAULT_REPO.into());
                    let tokenizer_path = repo.get("tokenizer.json")?;
                    let weights_path = repo.get("model.safetensors")?;
                    let config_path = repo.get("config.json")?;
                    Self::load(&tokenizer_path, &weights_path, &config_path, &device)?
                }
            };

            Ok(Self {
                model,
                tokenizer,
                device,
            })
        }

        /// Builds the tokenizer and BERT model from files already on disk.
        fn load(
            tokenizer_path: &Path,
            weights_path: &Path,
            config_path: &Path,
            device: &Device,
        ) -> Result<(BertModel, Tokenizer)> {
            let tokenizer = Tokenizer::from_file(tokenizer_path)
                .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
            let config_s =
                std::fs::read_to_string(config_path).context("Failed to read config.json")?;
            let config: Config =
                serde_json::from_str(&config_s).context("Failed to parse config.json")?;
            // SAFETY: the weights file is memory-mapped, which is only sound
            // while no other process modifies or truncates it during the
            // mapping's lifetime. NOTE(review): nothing here enforces that;
            // it is assumed the model files are not rewritten while in use.
            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device) }
                .context("Failed to load model.safetensors")?;
            let model = BertModel::load(vb, &config)?;
            Ok((model, tokenizer))
        }

        /// Averages token embeddings over the positions selected by the mask.
        ///
        /// `token_embeddings` is indexed as (batch, token, hidden) and
        /// `attention_mask` as (batch, token); the result is (batch, hidden).
        fn mean_pool(&self, token_embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
            // (batch, token) -> (batch, token, 1), in the embeddings' dtype.
            let mask = attention_mask
                .unsqueeze(2)?
                .to_dtype(token_embeddings.dtype())?;
            // candle's plain `*` and `/` require identical shapes, so the
            // trailing size-1 axis must be broadcast explicitly.
            let summed = token_embeddings.broadcast_mul(&mask)?.sum(1)?;
            let counts = mask.sum(1)?;
            Ok(summed.broadcast_div(&counts)?)
        }

        /// Scales each row of `v` (batch, hidden) to unit L2 norm.
        ///
        /// A row whose norm is zero is divided by zero and comes out
        /// non-finite.
        fn normalize(&self, v: &Tensor) -> Result<Tensor> {
            // (batch, 1) norms, broadcast across the hidden axis.
            let norm = v.sqr()?.sum_keepdim(1)?.sqrt()?;
            Ok(v.broadcast_div(&norm)?)
        }

        /// Embeds each text independently and returns one L2-normalised vector
        /// per input, in input order.
        ///
        /// Texts whose encoding exceeds 128 ids are truncated to exactly 128,
        /// keeping the first and last id (presumably the special tokens added
        /// by the tokenizer) and the ids in between from the front.
        ///
        /// # Errors
        ///
        /// Fails when tokenization or any tensor operation fails.
        pub fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            let mut all_embeddings = Vec::with_capacity(texts.len());

            for text in texts {
                let encoding = self
                    .tokenizer
                    .encode(*text, true)
                    .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

                let ids = encoding.get_ids();
                let token_ids: Vec<u32> = if ids.len() > MAX_TOKENS {
                    let mut kept = Vec::with_capacity(MAX_TOKENS);
                    kept.push(ids[0]);
                    kept.extend_from_slice(&ids[1..MAX_TOKENS - 1]);
                    kept.push(ids[ids.len() - 1]);
                    kept
                } else {
                    ids.to_vec()
                };

                // Keep the mask the same length as the (possibly truncated) ids.
                let attention = encoding.get_attention_mask();
                let attention = &attention[..attention.len().min(token_ids.len())];

                // Add a leading batch axis of size 1.
                let input = Tensor::new(token_ids.as_slice(), &self.device)?.unsqueeze(0)?;
                let mask = Tensor::new(attention, &self.device)?.unsqueeze(0)?;
                let type_ids = input.zeros_like()?;

                let output = self.model.forward(&input, &type_ids, Some(&mask))?;
                let pooled = self.mean_pool(&output, &mask)?;
                let normalized = self.normalize(&pooled)?;

                // Drop the batch axis: `to_vec1` only accepts rank-1 tensors.
                all_embeddings.push(normalized.squeeze(0)?.to_vec1::<f32>()?);
            }

            Ok(all_embeddings)
        }
    }
}
273
274#[allow(dead_code)]
279pub struct AutoEmbedder {
280 local: Option<LocalWrapper>,
281 api: ApiEmbedder,
282}
283
/// Holder for the local embedder that exists in every build configuration.
///
/// Without the `candle` feature this struct has no fields, which lets
/// `AutoEmbedder` name the type without referring to the `local` module.
struct LocalWrapper {
    /// The candle-backed embedder.
    #[cfg(feature = "candle")]
    inner: local::LocalEmbedder,
}
288
289impl AutoEmbedder {
290 pub fn new(config: &EmbedConfig) -> Result<Self> {
291 let local = match config.backend {
292 EmbedBackend::Api => None,
293 _ => Self::try_local(config),
294 };
295 let api = ApiEmbedder::new(config)?;
296 Ok(Self { local, api })
297 }
298
299 #[cfg(feature = "candle")]
300 fn try_local(config: &EmbedConfig) -> Option<LocalWrapper> {
301 match local::LocalEmbedder::new(config.model_path.as_deref()) {
302 Ok(inner) => {
303 eprintln!("[sift] using local embedding model (candle)");
304 Some(LocalWrapper { inner })
305 }
306 Err(e) => {
307 eprintln!("[sift] local model unavailable ({}), falling back to API", e);
308 None
309 }
310 }
311 }
312
313 #[cfg(not(feature = "candle"))]
314 fn try_local(_config: &EmbedConfig) -> Option<LocalWrapper> {
315 None
316 }
317
318}
319
/// Cosine similarity of `a` and `b`, accumulated in `f64`.
///
/// Returns `0.0` when either vector has zero norm (including empty input).
/// If the slices differ in length, the dot product covers only the common
/// prefix while each norm is taken over its full slice.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    let l2_norm = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&c| f64::from(c) * f64::from(c))
            .sum::<f64>()
            .sqrt()
    };

    let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    dot / (norm_a * norm_b)
}
334
335pub fn top_k_similar(query: &[f32], candidates: &[(usize, &[f32])], k: usize) -> Vec<(usize, f64)> {
336 let mut scores: Vec<(usize, f64)> = candidates
337 .iter()
338 .map(|(id, vec)| (*id, cosine_similarity(query, vec)))
339 .collect();
340 scores.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
341 scores.truncate(k);
342 scores
343}