Skip to main content

search_semantically/
embedder.rs

1use anyhow::{Context, Result};
2use ort::session::Session;
3use std::path::{Path, PathBuf};
4use std::str::FromStr;
5use std::sync::Arc;
6
/// Hook invoked with each model-file URL just before the embedder fetches it.
pub type DownloadCallback = Arc<dyn Fn(&str) + Sync + Send>;

/// Hugging Face repository id that `Embedder::new` downloads and loads.
const DEFAULT_MODEL_NAME: &str = "Xenova/all-MiniLM-L6-v2";
/// Embedding width assumed until a loaded model reports its own.
const DEFAULT_DIMENSION: usize = 384;
11
12pub struct Embedder {
13    model_cache_dir: PathBuf,
14    model_name: String,
15    session: Option<Session>,
16    tokenizer: Option<tokenizers::Tokenizer>,
17    dimension: usize,
18    download_callback: Option<DownloadCallback>,
19}
20
21impl Embedder {
22    pub fn new(model_cache_dir: PathBuf) -> Self {
23        Self {
24            model_cache_dir,
25            model_name: DEFAULT_MODEL_NAME.to_string(),
26            session: None,
27            tokenizer: None,
28            dimension: DEFAULT_DIMENSION,
29            download_callback: None,
30        }
31    }
32
33    pub fn with_download_callback(mut self, callback: DownloadCallback) -> Self {
34        self.download_callback = Some(callback);
35        self
36    }
37
38    pub fn set_download_callback(&mut self, callback: DownloadCallback) {
39        self.download_callback = Some(callback);
40    }
41
42    pub fn initialize(&mut self) -> Result<()> {
43        if self.session.is_some() {
44            return Ok(());
45        }
46
47        let model_dir = self.model_cache_dir.join(&self.model_name);
48        std::fs::create_dir_all(&model_dir)
49            .with_context(|| format!("Creating model cache dir: {}", model_dir.display()))?;
50
51        let onnx_path = model_dir.join("model.onnx");
52        let tokenizer_path = model_dir.join("tokenizer.json");
53
54        if !onnx_path.exists() || !tokenizer_path.exists() {
55            download_model(&self.model_name, &model_dir, self.download_callback.clone())?;
56        }
57
58        let session = Session::builder()
59            .context("Creating ONNX session builder")?
60            .commit_from_file(&onnx_path)
61            .with_context(|| format!("Loading ONNX model from {}", onnx_path.display()))?;
62
63        let tokenizer_data = std::fs::read_to_string(&tokenizer_path)
64            .with_context(|| format!("Reading tokenizer from {}", tokenizer_path.display()))?;
65        let tokenizer = tokenizers::Tokenizer::from_str(&tokenizer_data)
66            .map_err(|e| anyhow::anyhow!("Parsing tokenizer JSON: {e}"))?;
67
68        self.dimension = detect_dimension(&session);
69        self.session = Some(session);
70        self.tokenizer = Some(tokenizer);
71
72        Ok(())
73    }
74
75    pub fn embed(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
76        let tokenizer = self.tokenizer.as_ref().expect("Embedder not initialized");
77
78        let mut results = Vec::with_capacity(texts.len());
79
80        for text in texts {
81            let encoding = tokenizer
82                .encode(*text, true)
83                .map_err(|e| anyhow::anyhow!("Tokenization failed: {e}"))?;
84
85            let ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
86            let attention_mask: Vec<i64> = encoding
87                .get_attention_mask()
88                .iter()
89                .map(|&m| m as i64)
90                .collect();
91            let type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&t| t as i64).collect();
92
93            let len = ids.len();
94            let input_ids = ndarray::Array2::from_shape_vec((1, len), ids)
95                .context("Creating input_ids array")?;
96            let attn_mask = ndarray::Array2::from_shape_vec((1, len), attention_mask)
97                .context("Creating attention_mask array")?;
98            let token_types = ndarray::Array2::from_shape_vec((1, len), type_ids)
99                .context("Creating token_type_ids array")?;
100
101            let session = self.session.as_mut().expect("Embedder not session");
102
103            let input_ids_val =
104                ort::value::Tensor::from_array(input_ids).context("Creating input_ids tensor")?;
105            let attn_mask_val = ort::value::Tensor::from_array(attn_mask)
106                .context("Creating attention_mask tensor")?;
107            let token_types_val = ort::value::Tensor::from_array(token_types)
108                .context("Creating token_type_ids tensor")?;
109
110            let outputs = session
111                .run(ort::inputs! {
112                    "input_ids" => input_ids_val,
113                    "attention_mask" => attn_mask_val,
114                    "token_type_ids" => token_types_val,
115                })
116                .context("Running ONNX inference")?;
117
118            let output = outputs.iter().next().context("No output from model")?.1;
119
120            let (_, data) = output
121                .try_extract_tensor::<f32>()
122                .context("Extracting tensor")?;
123
124            let mask_f32: Vec<f32> = encoding
125                .get_attention_mask()
126                .iter()
127                .map(|&m| m as f32)
128                .collect();
129            let embedding = mean_pool_normalize(data, len, self.dimension, &mask_f32);
130
131            results.push(embedding);
132        }
133
134        Ok(results)
135    }
136
137    pub fn dimension(&self) -> usize {
138        self.dimension
139    }
140}
141
142fn detect_dimension(session: &Session) -> usize {
143    session
144        .outputs()
145        .first()
146        .and_then(|outlet| outlet.dtype().tensor_shape())
147        .and_then(|shape| shape.last().copied())
148        .filter(|&d| d > 0)
149        .map(|d| d as usize)
150        .unwrap_or(DEFAULT_DIMENSION)
151}
152
/// Computes the `mask`-weighted mean of the token rows in `data`, then scales it
/// to unit L2 norm.
///
/// `data` holds `seq_len` consecutive rows of `dim` values each. The division by
/// the total weight is skipped when that weight is not positive. A vector with
/// zero norm is returned without normalization.
///
/// # Panics
/// Panics if `mask` has fewer than `seq_len` entries, or (when `dim > 0`) if
/// `data` has fewer than `seq_len * dim` values.
fn mean_pool_normalize(data: &[f32], seq_len: usize, dim: usize, mask: &[f32]) -> Vec<f32> {
    let mut acc = vec![0.0_f32; dim];
    let mut total_weight = 0.0_f32;

    for (token, &weight) in mask[..seq_len].iter().enumerate() {
        total_weight += weight;
        let row = &data[token * dim..(token + 1) * dim];
        for (slot, &value) in acc.iter_mut().zip(row) {
            *slot += value * weight;
        }
    }

    if total_weight > 0.0 {
        acc.iter_mut().for_each(|slot| *slot /= total_weight);
    }

    let norm = acc.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        acc.iter_mut().for_each(|slot| *slot /= norm);
    }

    acc
}
180
181fn download_model(
182    model_name: &str,
183    target_dir: &Path,
184    callback: Option<DownloadCallback>,
185) -> Result<()> {
186    let files = ["model.onnx", "tokenizer.json"];
187
188    let model_name_owned = model_name.to_string();
189    let target_dir_owned = target_dir.to_path_buf();
190
191    let handle = std::thread::spawn(move || -> Result<()> {
192        for file in &files {
193            let url = format!("https://huggingface.co/{model_name_owned}/resolve/main/{file}");
194            let dest = target_dir_owned.join(file);
195
196            if let Some(ref cb) = callback {
197                cb(&url);
198            }
199
200            let response = reqwest::blocking::get(&url)
201                .with_context(|| format!("HTTP request to {url}"))?
202                .error_for_status()
203                .context("HTTP request failed")?;
204            let buf = response.bytes().context("Reading response body")?;
205            std::fs::write(&dest, &buf).with_context(|| format!("Writing {}", dest.display()))?;
206        }
207        Ok(())
208    });
209
210    handle
211        .join()
212        .map_err(|e| anyhow::anyhow!("Model download thread panicked: {e:?}"))?
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an uninitialized embedder; no files are touched until `initialize`.
    fn fresh_embedder() -> Embedder {
        Embedder::new(std::env::temp_dir())
    }

    /// Asserts that two float slices have equal length and match element-wise
    /// within 1e-5.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-5, "index {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn embedder_new_has_no_session() {
        let embedder = fresh_embedder();
        assert!(embedder.session.is_none());
        assert!(embedder.tokenizer.is_none());
    }

    #[test]
    fn embedder_default_dimension() {
        assert_eq!(fresh_embedder().dimension(), 384);
    }

    #[test]
    fn mean_pool_normalize_produces_unit_vector() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2 tokens x 3 dim
        let mask = vec![1.0_f32, 1.0];
        let result = mean_pool_normalize(&data, 2, 3, &mask);

        let norm: f32 = result.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "Should be unit vector, got norm {norm}"
        );
    }

    #[test]
    fn mean_pool_normalize_averages_tokens() {
        // Mean of rows [1,2,3] and [4,5,6] is [2.5,3.5,4.5]; check its direction,
        // not just its norm.
        let result = mean_pool_normalize(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3, &[1.0, 1.0]);
        let norm = (2.5_f32 * 2.5 + 3.5 * 3.5 + 4.5 * 4.5).sqrt();
        assert_close(&result, &[2.5 / norm, 3.5 / norm, 4.5 / norm]);
    }

    #[test]
    fn mean_pool_normalize_ignores_masked_tokens() {
        // The second (padding) token must not influence the result at all.
        let result =
            mean_pool_normalize(&[3.0, 0.0, 4.0, 100.0, 100.0, 100.0], 2, 3, &[1.0, 0.0]);
        assert_close(&result, &[0.6, 0.0, 0.8]);
    }

    #[test]
    fn mean_pool_normalize_with_zero_mask() {
        let data = vec![1.0_f32, 2.0, 3.0];
        let mask = vec![0.0_f32];
        let result = mean_pool_normalize(&data, 1, 3, &mask);
        assert!(result.iter().all(|&v| v == 0.0));
    }
}
252}