Skip to main content

argyph_embed/
local.rs

1use std::sync::Mutex;
2
3use tracing;
4
5use crate::config::EmbedConfig;
6use crate::error::{EmbedError, Result};
7use crate::model_files::ModelFiles;
8use crate::tokenize::BertTokenizer;
9
/// Identifier of the model this embedder loads; also reported by `model_id()`.
const BGE_SMALL_MODEL_ID: &str = "bge-small-en-v1.5";

/// Hidden size of the model, used as the length of each embedding vector.
const BGE_SMALL_HIDDEN_SIZE: usize = 384;

/// Sequence length handed to the tokenizer when encoding a batch.
const BGE_SMALL_MAX_SEQ_LEN: usize = 512;
13
14pub struct LocalEmbedder {
15    session: Mutex<ort::session::Session>,
16    tokenizer: BertTokenizer,
17    config: EmbedConfig,
18    dimension: usize,
19    model_id: String,
20}
21
22impl LocalEmbedder {
23    pub async fn new(config: EmbedConfig) -> Result<Self> {
24        let model_files =
25            ModelFiles::ensure_available(BGE_SMALL_MODEL_ID, config.cache_dir.as_deref()).await?;
26
27        let tokenizer = BertTokenizer::from_file(&model_files.tokenizer_path)?;
28
29        ort::init().with_name("argyph-embed").commit();
30
31        let session = ort::session::Session::builder()
32            .map_err(|e| EmbedError::Config(format!("ONNX session builder: {e}")))?
33            .commit_from_file(model_files.onnx_path)
34            .map_err(|e| EmbedError::Config(format!("failed to load ONNX model: {e}")))?;
35
36        tracing::info!(
37            model_id = BGE_SMALL_MODEL_ID,
38            dimension = BGE_SMALL_HIDDEN_SIZE,
39            "local embedder ready"
40        );
41
42        Ok(Self {
43            session: Mutex::new(session),
44            tokenizer,
45            config,
46            dimension: BGE_SMALL_HIDDEN_SIZE,
47            model_id: BGE_SMALL_MODEL_ID.to_string(),
48        })
49    }
50
51    fn do_embed(
52        session: &mut ort::session::Session,
53        tokenizer: &BertTokenizer,
54        texts: &[String],
55        batch_size: usize,
56        seq_len: usize,
57        dimension: usize,
58    ) -> Result<Vec<Vec<f32>>> {
59        let batch = tokenizer.encode_batch(texts, seq_len)?;
60
61        use ort::value::Tensor;
62
63        let attention_mask_data = batch.attention_mask.clone();
64
65        let input_ids_tensor = Tensor::from_array((
66            [batch_size, batch.seq_len],
67            batch.input_ids.into_boxed_slice(),
68        ))
69        .map_err(|e| EmbedError::Config(format!("ONNX input_ids tensor: {e}")))?;
70
71        let attention_mask_tensor = Tensor::from_array((
72            [batch_size, batch.seq_len],
73            batch.attention_mask.into_boxed_slice(),
74        ))
75        .map_err(|e| EmbedError::Config(format!("ONNX attention_mask tensor: {e}")))?;
76
77        let token_type_ids = vec![0_i64; batch_size * batch.seq_len];
78        let token_type_ids_tensor = Tensor::from_array((
79            [batch_size, batch.seq_len],
80            token_type_ids.into_boxed_slice(),
81        ))
82        .map_err(|e| EmbedError::Config(format!("ONNX token_type_ids tensor: {e}")))?;
83
84        let inputs = ort::inputs![
85            "input_ids" => input_ids_tensor.view(),
86            "attention_mask" => attention_mask_tensor.view(),
87            "token_type_ids" => token_type_ids_tensor.view(),
88        ];
89
90        let outputs = session
91            .run(inputs)
92            .map_err(|e| EmbedError::Config(format!("ONNX inference failed: {e}")))?;
93
94        let last_hidden_value = outputs
95            .get("last_hidden_state")
96            .ok_or_else(|| EmbedError::Config("ONNX output missing 'last_hidden_state'".into()))?;
97
98        let (_out_shape, last_hidden_data): (_, &[f32]) = last_hidden_value
99            .try_extract_tensor::<f32>()
100            .map_err(|e| EmbedError::Config(format!("ONNX output extraction: {e}")))?;
101
102        let owned_data = last_hidden_data.to_vec();
103
104        drop(outputs);
105
106        Ok(BertTokenizer::mean_pool(
107            &owned_data,
108            &attention_mask_data,
109            batch_size,
110            batch.seq_len,
111            dimension,
112        ))
113    }
114}
115
116#[async_trait::async_trait]
117impl crate::Embedder for LocalEmbedder {
118    fn dimension(&self) -> usize {
119        self.dimension
120    }
121
122    fn model_id(&self) -> &str {
123        &self.model_id
124    }
125
126    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
127        if texts.is_empty() {
128            return Err(EmbedError::EmptyInput);
129        }
130
131        let chunk_size = self.config.batch_size.min(128);
132        let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
133
134        for chunk in texts.chunks(chunk_size) {
135            let batch_texts: Vec<String> = chunk.to_vec();
136            let n = batch_texts.len();
137
138            let embeddings = {
139                let mut session = self.session.lock().unwrap_or_else(|e| e.into_inner());
140                Self::do_embed(
141                    &mut session,
142                    &self.tokenizer,
143                    &batch_texts,
144                    n,
145                    BGE_SMALL_MAX_SEQ_LEN,
146                    self.dimension,
147                )?
148            };
149
150            all_embeddings.extend(embeddings);
151        }
152
153        Ok(all_embeddings)
154    }
155}
156
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::config::EmbedConfig;
    use crate::Embedder;

    use std::path::PathBuf;

    /// Model cache root used by these tests: `$HOME/.cache/argyph/models`,
    /// with `.` standing in for `HOME` when the variable is not set.
    ///
    /// Every test goes through this helper so the "is the model cached?" check
    /// and the directory handed to the embedder can never disagree.
    fn models_cache_root() -> PathBuf {
        let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
        PathBuf::from(home).join(".cache/argyph/models")
    }

    /// Returns `true` when both the ONNX model and the tokenizer file are
    /// present in the cache directory of the model under test.
    fn model_dir_exists() -> bool {
        let model_dir = models_cache_root().join("bge-small-en-v1.5");
        model_dir.join("model.onnx").exists() && model_dir.join("tokenizer.json").exists()
    }

    /// Builds an embedder that reads its model from the cache root.
    /// Only call this after `model_dir_exists()` returned `true`.
    async fn cached_embedder() -> LocalEmbedder {
        let config = EmbedConfig {
            cache_dir: Some(models_cache_root()),
            ..EmbedConfig::default()
        };
        LocalEmbedder::new(config).await.unwrap()
    }

    #[tokio::test]
    async fn local_embedder_succeeds_even_if_cache_empty() {
        if model_dir_exists() {
            eprintln!("model already cached, test would re-download (slow); skipping");
            return;
        }
        let config = EmbedConfig {
            cache_dir: None,
            ..EmbedConfig::default()
        };
        let result = LocalEmbedder::new(config).await;
        // The test passes if the embedder either successfully downloads
        // the model or returns a Config error (network unreachable,
        // download failed, rename failed because the tmp file is
        // missing, etc.). The point is that the code path never
        // panics — any downstream IO failure surfaces as a Config error
        // string, which is what we accept here.
        match result {
            Ok(_) | Err(EmbedError::Config(_)) => {}
            Err(other) => panic!("unexpected error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn local_embedder_works_if_model_cached() {
        if !model_dir_exists() {
            eprintln!("model not cached, skipping integration test");
            return;
        }

        let embedder = cached_embedder().await;
        assert_eq!(embedder.dimension(), 384);
        assert_eq!(embedder.model_id(), "bge-small-en-v1.5");

        let texts: Vec<String> = vec!["hello world".into(), "goodbye world".into()];
        let embeddings = embedder.embed(&texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        for v in &embeddings {
            assert_eq!(v.len(), 384);
            let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (norm - 1.0).abs() < 0.01,
                "L2 norm should be approx 1.0, got {norm}"
            );
        }
    }

    #[tokio::test]
    async fn local_embedder_empty_input_error() {
        if !model_dir_exists() {
            eprintln!("model not cached, skipping integration test");
            return;
        }

        let embedder = cached_embedder().await;
        match embedder.embed(&[]).await {
            Err(EmbedError::EmptyInput) => {}
            other => panic!("expected EmptyInput, got: {other:?}"),
        }
    }
}