1use std::sync::Mutex;
2
3use tracing;
4
5use crate::config::EmbedConfig;
6use crate::error::{EmbedError, Result};
7use crate::model_files::ModelFiles;
8use crate::tokenize::BertTokenizer;
9
10const BGE_SMALL_MODEL_ID: &str = "bge-small-en-v1.5";
11const BGE_SMALL_HIDDEN_SIZE: usize = 384;
12const BGE_SMALL_MAX_SEQ_LEN: usize = 512;
13
14pub struct LocalEmbedder {
15 session: Mutex<ort::session::Session>,
16 tokenizer: BertTokenizer,
17 config: EmbedConfig,
18 dimension: usize,
19 model_id: String,
20}
21
22impl LocalEmbedder {
23 pub async fn new(config: EmbedConfig) -> Result<Self> {
24 let model_files =
25 ModelFiles::ensure_available(BGE_SMALL_MODEL_ID, config.cache_dir.as_deref()).await?;
26
27 let tokenizer = BertTokenizer::from_file(&model_files.tokenizer_path)?;
28
29 ort::init().with_name("argyph-embed").commit();
30
31 let session = ort::session::Session::builder()
32 .map_err(|e| EmbedError::Config(format!("ONNX session builder: {e}")))?
33 .commit_from_file(model_files.onnx_path)
34 .map_err(|e| EmbedError::Config(format!("failed to load ONNX model: {e}")))?;
35
36 tracing::info!(
37 model_id = BGE_SMALL_MODEL_ID,
38 dimension = BGE_SMALL_HIDDEN_SIZE,
39 "local embedder ready"
40 );
41
42 Ok(Self {
43 session: Mutex::new(session),
44 tokenizer,
45 config,
46 dimension: BGE_SMALL_HIDDEN_SIZE,
47 model_id: BGE_SMALL_MODEL_ID.to_string(),
48 })
49 }
50
51 fn do_embed(
52 session: &mut ort::session::Session,
53 tokenizer: &BertTokenizer,
54 texts: &[String],
55 batch_size: usize,
56 seq_len: usize,
57 dimension: usize,
58 ) -> Result<Vec<Vec<f32>>> {
59 let batch = tokenizer.encode_batch(texts, seq_len)?;
60
61 use ort::value::Tensor;
62
63 let attention_mask_data = batch.attention_mask.clone();
64
65 let input_ids_tensor = Tensor::from_array((
66 [batch_size, batch.seq_len],
67 batch.input_ids.into_boxed_slice(),
68 ))
69 .map_err(|e| EmbedError::Config(format!("ONNX input_ids tensor: {e}")))?;
70
71 let attention_mask_tensor = Tensor::from_array((
72 [batch_size, batch.seq_len],
73 batch.attention_mask.into_boxed_slice(),
74 ))
75 .map_err(|e| EmbedError::Config(format!("ONNX attention_mask tensor: {e}")))?;
76
77 let token_type_ids = vec![0_i64; batch_size * batch.seq_len];
78 let token_type_ids_tensor = Tensor::from_array((
79 [batch_size, batch.seq_len],
80 token_type_ids.into_boxed_slice(),
81 ))
82 .map_err(|e| EmbedError::Config(format!("ONNX token_type_ids tensor: {e}")))?;
83
84 let inputs = ort::inputs![
85 "input_ids" => input_ids_tensor.view(),
86 "attention_mask" => attention_mask_tensor.view(),
87 "token_type_ids" => token_type_ids_tensor.view(),
88 ];
89
90 let outputs = session
91 .run(inputs)
92 .map_err(|e| EmbedError::Config(format!("ONNX inference failed: {e}")))?;
93
94 let last_hidden_value = outputs
95 .get("last_hidden_state")
96 .ok_or_else(|| EmbedError::Config("ONNX output missing 'last_hidden_state'".into()))?;
97
98 let (_out_shape, last_hidden_data): (_, &[f32]) = last_hidden_value
99 .try_extract_tensor::<f32>()
100 .map_err(|e| EmbedError::Config(format!("ONNX output extraction: {e}")))?;
101
102 let owned_data = last_hidden_data.to_vec();
103
104 drop(outputs);
105
106 Ok(BertTokenizer::mean_pool(
107 &owned_data,
108 &attention_mask_data,
109 batch_size,
110 batch.seq_len,
111 dimension,
112 ))
113 }
114}
115
116#[async_trait::async_trait]
117impl crate::Embedder for LocalEmbedder {
118 fn dimension(&self) -> usize {
119 self.dimension
120 }
121
122 fn model_id(&self) -> &str {
123 &self.model_id
124 }
125
126 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
127 if texts.is_empty() {
128 return Err(EmbedError::EmptyInput);
129 }
130
131 let chunk_size = self.config.batch_size.min(128);
132 let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
133
134 for chunk in texts.chunks(chunk_size) {
135 let batch_texts: Vec<String> = chunk.to_vec();
136 let n = batch_texts.len();
137
138 let embeddings = {
139 let mut session = self.session.lock().unwrap_or_else(|e| e.into_inner());
140 Self::do_embed(
141 &mut session,
142 &self.tokenizer,
143 &batch_texts,
144 n,
145 BGE_SMALL_MAX_SEQ_LEN,
146 self.dimension,
147 )?
148 };
149
150 all_embeddings.extend(embeddings);
151 }
152
153 Ok(all_embeddings)
154 }
155}
156
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use crate::config::EmbedConfig;
    use crate::Embedder;

    /// Directory the tests look in for cached models:
    /// `$HOME/.cache/argyph/models`, with `.` standing in for `$HOME` when the
    /// variable is unset. Shared by every test so they agree on the location.
    fn models_cache_root() -> PathBuf {
        let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
        PathBuf::from(home).join(".cache/argyph/models")
    }

    /// True when both the ONNX model and the tokenizer file are present in the
    /// model's directory under [`models_cache_root`].
    fn model_dir_exists() -> bool {
        let model_dir = models_cache_root().join("bge-small-en-v1.5");
        ["model.onnx", "tokenizer.json"]
            .iter()
            .all(|file| model_dir.join(file).exists())
    }

    /// Configuration pointing the embedder at [`models_cache_root`].
    fn cached_config() -> EmbedConfig {
        EmbedConfig {
            cache_dir: Some(models_cache_root()),
            ..EmbedConfig::default()
        }
    }

    #[tokio::test]
    async fn local_embedder_succeeds_even_if_cache_empty() {
        if model_dir_exists() {
            eprintln!("model already cached, test would re-download (slow); skipping");
            return;
        }
        let config = EmbedConfig {
            cache_dir: None,
            ..EmbedConfig::default()
        };
        // Construction may succeed or fail with a `Config` error; any other
        // error variant is a test failure.
        match LocalEmbedder::new(config).await {
            Ok(_) | Err(EmbedError::Config(_)) => {}
            Err(other) => panic!("unexpected error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn local_embedder_works_if_model_cached() {
        if !model_dir_exists() {
            eprintln!("model not cached, skipping integration test");
            return;
        }

        let embedder = LocalEmbedder::new(cached_config()).await.unwrap();
        assert_eq!(embedder.dimension(), 384);
        assert_eq!(embedder.model_id(), "bge-small-en-v1.5");

        let texts: Vec<String> = vec!["hello world".into(), "goodbye world".into()];
        let embeddings = embedder.embed(&texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
        for v in &embeddings {
            assert_eq!(v.len(), 384);
            let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (norm - 1.0).abs() < 0.01,
                "L2 norm should be approx 1.0, got {norm}"
            );
        }
    }

    #[tokio::test]
    async fn local_embedder_empty_input_error() {
        if !model_dir_exists() {
            eprintln!("model not cached, skipping integration test");
            return;
        }

        let embedder = LocalEmbedder::new(cached_config()).await.unwrap();
        match embedder.embed(&[]).await {
            Err(EmbedError::EmptyInput) => {}
            other => panic!("expected EmptyInput, got: {other:?}"),
        }
    }
}