use lru::LruCache;
use ndarray::Array2;
use once_cell::sync::OnceCell;
use ort::ep::ExecutionProvider as OrtExecutionProvider;
use ort::session::Session;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use thiserror::Error;
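
// Hugging Face Hub coordinates for the embedding model and its tokenizer.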
const MODEL_REPO: &str = "intfloat/e5-base-v2";
const MODEL_FILE: &str = "onnx/model.onnx";
const TOKENIZER_FILE: &str = "onnx/tokenizer.json";
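
// Optional BLAKE3 checksums for the downloaded files; an empty string skips
// verification (see `ensure_model`).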
const MODEL_BLAKE3: &str = "";
const TOKENIZER_BLAKE3: &str = "";
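
/// Errors produced by the embedding pipeline, from model download through inference.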
#[derive(Error, Debug)]
pub enum EmbedderError {
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    #[error("Tokenizer error: {0}")]
    TokenizerError(String),
    #[error("Inference failed: {0}")]
    InferenceFailed(String),
    #[error("Checksum mismatch for {path}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        path: String,
        expected: String,
        actual: String,
    },
    #[error("Query cannot be empty")]
    EmptyQuery,
    #[error("HuggingFace Hub error: {0}")]
    HfHubError(String),
}

impl From<ort::Error> for EmbedderError {
    fn from(e: ort::Error) -> Self {
        EmbedderError::InferenceFailed(e.to_string())
    }
}
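
/// A dense embedding vector produced by the model.
///
/// A fresh embedding has `MODEL_DIM` components; `with_sentiment` appends one
/// extra component, growing it to `EMBEDDING_DIM`. Illustrative sketch (not
/// run as a doc-test):
///
/// ```ignore
/// let e = Embedding::new(vec![0.0; MODEL_DIM]).with_sentiment(0.5);
/// assert_eq!(e.len(), EMBEDDING_DIM);
/// assert_eq!(e.sentiment(), Some(0.5));
/// ```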
#[derive(Debug, Clone)]
pub struct Embedding(Vec<f32>);
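
/// Raw output dimension of the e5-base-v2 model.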
pub const MODEL_DIM: usize = 768;
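/// Stored embedding dimension: the model output plus one appended sentiment component.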
pub const EMBEDDING_DIM: usize = MODEL_DIM + 1;

impl Embedding {
    pub fn new(data: Vec<f32>) -> Self {
        Self(data)
    }

    pub fn with_sentiment(mut self, sentiment: f32) -> Self {
        debug_assert_eq!(self.0.len(), MODEL_DIM, "Expected 768-dim embedding");
        self.0.push(sentiment.clamp(-1.0, 1.0));
        self
    }

    pub fn sentiment(&self) -> Option<f32> {
        if self.0.len() == EMBEDDING_DIM {
            Some(self.0[MODEL_DIM])
        } else {
            None
        }
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }

    pub fn as_vec(&self) -> &Vec<f32> {
        &self.0
    }

    pub fn into_inner(self) -> Vec<f32> {
        self.0
    }

    pub fn len(&self) -> usize {
        self.0.len()
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
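
/// Hardware backend used for ONNX Runtime inference.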
#[derive(Debug, Clone, Copy)]
pub enum ExecutionProvider {
    CUDA { device_id: i32 },
    TensorRT { device_id: i32 },
    CPU,
}

impl std::fmt::Display for ExecutionProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExecutionProvider::CUDA { device_id } => write!(f, "CUDA (device {})", device_id),
            ExecutionProvider::TensorRT { device_id } => {
                write!(f, "TensorRT (device {})", device_id)
            }
            ExecutionProvider::CPU => write!(f, "CPU"),
        }
    }
}
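
/// Text embedder backed by an ONNX export of `intfloat/e5-base-v2`.
///
/// The session and tokenizer are initialized lazily on first use, repeated
/// queries are served from an LRU cache, and batch outputs are mean-pooled
/// and L2-normalized. Minimal usage sketch (downloads the model on first run,
/// so not run as a doc-test):
///
/// ```ignore
/// let mut embedder = Embedder::new()?;
/// let query = embedder.embed_query("how do I reset my password?")?;
/// let docs = embedder.embed_documents(&["To reset your password, ..."])?;
/// assert_eq!(query.len(), EMBEDDING_DIM); // query vectors carry a sentiment slot
/// assert_eq!(docs[0].len(), MODEL_DIM);
/// ```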
pub struct Embedder {
    session: OnceCell<Mutex<Session>>,
    tokenizer: OnceCell<tokenizers::Tokenizer>,
    model_path: PathBuf,
    tokenizer_path: PathBuf,
    provider: ExecutionProvider,
    max_length: usize,
    batch_size: usize,
    query_cache: Mutex<LruCache<String, Embedding>>,
}

impl Embedder {
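    /// Builds an embedder on the best available execution provider, downloading
    /// the model and tokenizer if they are not already cached. GPU providers get
    /// a larger default batch size (16) than CPU (4).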
    pub fn new() -> Result<Self, EmbedderError> {
        let (model_path, tokenizer_path) = ensure_model()?;
        let provider = select_provider();

        let batch_size = match provider {
            ExecutionProvider::CPU => 4,
            _ => 16,
        };

        let query_cache = Mutex::new(LruCache::new(
            NonZeroUsize::new(100).expect("100 is non-zero"),
        ));

        Ok(Self {
            session: OnceCell::new(),
            tokenizer: OnceCell::new(),
            model_path,
            tokenizer_path,
            provider,
            max_length: 512,
            batch_size,
            query_cache,
        })
    }
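
    /// Builds an embedder pinned to the CPU provider, skipping GPU detection.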
    pub fn new_cpu() -> Result<Self, EmbedderError> {
        let (model_path, tokenizer_path) = ensure_model()?;

        let query_cache = Mutex::new(LruCache::new(
            NonZeroUsize::new(100).expect("100 is non-zero"),
        ));

        Ok(Self {
            session: OnceCell::new(),
            tokenizer: OnceCell::new(),
            model_path,
            tokenizer_path,
            provider: ExecutionProvider::CPU,
            max_length: 512,
            batch_size: 4,
            query_cache,
        })
    }
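
    /// Returns the ONNX Runtime session, creating it lazily on first use;
    /// a poisoned lock is recovered rather than propagated.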
    fn session(&self) -> Result<std::sync::MutexGuard<'_, Session>, EmbedderError> {
        let session = self
            .session
            .get_or_try_init(|| create_session(&self.model_path, self.provider).map(Mutex::new))?;
        Ok(session.lock().unwrap_or_else(|p| p.into_inner()))
    }
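
    /// Returns the tokenizer, loading it lazily from disk on first use.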
    fn tokenizer(&self) -> Result<&tokenizers::Tokenizer, EmbedderError> {
        self.tokenizer.get_or_try_init(|| {
            tokenizers::Tokenizer::from_file(&self.tokenizer_path)
                .map_err(|e| EmbedderError::TokenizerError(e.to_string()))
        })
    }
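
    /// Counts the tokens in `text`, without adding special tokens.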
    pub fn token_count(&self, text: &str) -> Result<usize, EmbedderError> {
        let encoding = self
            .tokenizer()?
            .encode(text, false)
            .map_err(|e| EmbedderError::TokenizerError(e.to_string()))?;
        Ok(encoding.get_ids().len())
    }
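
    /// Splits `text` into overlapping token windows of at most `max_tokens`
    /// tokens each, paired with their window index. Consecutive windows start
    /// `max_tokens - overlap` tokens apart (at least 1); e.g. with
    /// `max_tokens = 512` and `overlap = 64`, the step is 448 tokens. Text
    /// that already fits is returned as a single window.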
    pub fn split_into_windows(
        &self,
        text: &str,
        max_tokens: usize,
        overlap: usize,
    ) -> Result<Vec<(String, u32)>, EmbedderError> {
        let tokenizer = self.tokenizer()?;
        let encoding = tokenizer
            .encode(text, false)
            .map_err(|e| EmbedderError::TokenizerError(e.to_string()))?;

        let ids = encoding.get_ids();
        if ids.len() <= max_tokens {
            return Ok(vec![(text.to_string(), 0)]);
        }

        let mut windows = Vec::new();
        let step = max_tokens.saturating_sub(overlap).max(1);
        let mut start = 0;
        let mut window_idx = 0u32;

        while start < ids.len() {
            let end = (start + max_tokens).min(ids.len());
            let window_ids: Vec<u32> = ids[start..end].to_vec();

            let window_text = tokenizer
                .decode(&window_ids, true)
                .map_err(|e| EmbedderError::TokenizerError(e.to_string()))?;

            windows.push((window_text, window_idx));
            window_idx += 1;

            if end >= ids.len() {
                break;
            }
            start += step;
        }

        Ok(windows)
    }
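
    /// Embeds documents with the `"passage: "` prefix that E5 models expect
    /// for passage-side inputs.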
    pub fn embed_documents(&mut self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedderError> {
        let prefixed: Vec<String> = texts.iter().map(|t| format!("passage: {}", t)).collect();
        self.embed_batch(&prefixed)
    }
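
    /// Embeds a search query with the `"query: "` prefix, serving repeats from
    /// the LRU cache. A neutral sentiment component (0.0) is appended so query
    /// vectors match the stored `EMBEDDING_DIM`.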
    pub fn embed_query(&mut self, text: &str) -> Result<Embedding, EmbedderError> {
        let text = text.trim();
        if text.is_empty() {
            return Err(EmbedderError::EmptyQuery);
        }

        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::debug!("Query cache lock poisoned, recovering");
                poisoned.into_inner()
            });
            if let Some(cached) = cache.get(text) {
                return Ok(cached.clone());
            }
        }

        let prefixed = format!("query: {}", text);
        let results = self.embed_batch(&[prefixed])?;
        let base_embedding = results
            .into_iter()
            .next()
            .expect("embed_batch with single item always returns one result");

        let embedding = base_embedding.with_sentiment(0.0);

        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::debug!("Query cache lock poisoned, recovering");
                poisoned.into_inner()
            });
            cache.put(text.to_string(), embedding.clone());
        }

        Ok(embedding)
    }

    pub fn provider(&self) -> ExecutionProvider {
        self.provider
    }

    pub fn batch_size(&self) -> usize {
        self.batch_size
    }
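
    /// Forces lazy initialization of the session and tokenizer by embedding a
    /// throwaway query.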
    pub fn warm(&mut self) -> Result<(), EmbedderError> {
        let _ = self.embed_query("warmup")?;
        Ok(())
    }
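
    /// Tokenizes `texts`, runs the model, mean-pools the last hidden state over
    /// non-padding tokens, and L2-normalizes each pooled vector.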
    fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Embedding>, EmbedderError> {
        use ort::value::Tensor;

        let _span = tracing::info_span!("embed_batch", count = texts.len()).entered();

        if texts.is_empty() {
            return Ok(vec![]);
        }

        let encodings = self
            .tokenizer()?
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| EmbedderError::TokenizerError(e.to_string()))?;

        let input_ids: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_ids().iter().map(|&id| id as i64).collect())
            .collect();
        let attention_mask: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().iter().map(|&m| m as i64).collect())
            .collect();

        let max_len = input_ids
            .iter()
            .map(|v| v.len())
            .max()
            .unwrap_or(0)
            .min(self.max_length);

        let input_ids_arr = pad_2d_i64(&input_ids, max_len, 0);
        let attention_mask_arr = pad_2d_i64(&attention_mask, max_len, 0);
        let token_type_ids_arr = Array2::<i64>::zeros((texts.len(), max_len));

        let input_ids_tensor = Tensor::from_array(input_ids_arr)?;
        let attention_mask_tensor = Tensor::from_array(attention_mask_arr)?;
        let token_type_ids_tensor = Tensor::from_array(token_type_ids_arr)?;

        let mut session = self.session()?;
        let outputs = session.run(ort::inputs![
            "input_ids" => input_ids_tensor,
            "attention_mask" => attention_mask_tensor,
            "token_type_ids" => token_type_ids_tensor,
        ])?;

        let (_shape, data) = outputs["last_hidden_state"].try_extract_tensor::<f32>()?;

        let batch_size = texts.len();
        let seq_len = max_len;
        let embedding_dim = MODEL_DIM;
        let mut results = Vec::with_capacity(batch_size);

        // Mean-pool each sequence over its non-padding tokens, weighting by the
        // attention mask, then L2-normalize the pooled vector.
        for (i, mask_vec) in attention_mask.iter().enumerate().take(batch_size) {
            let mut sum = vec![0.0f32; embedding_dim];
            let mut count = 0.0f32;

            for j in 0..seq_len {
                let mask = mask_vec.get(j).copied().unwrap_or(0) as f32;
                if mask > 0.0 {
                    count += mask;
                    // Row-major offset of token j in sequence i within the flat output.
                    let offset = i * seq_len * embedding_dim + j * embedding_dim;
                    for (k, sum_val) in sum.iter_mut().enumerate() {
                        *sum_val += data[offset + k] * mask;
                    }
                }
            }

            if count > 0.0 {
                for sum_val in &mut sum {
                    *sum_val /= count;
                }
            }

            results.push(Embedding::new(normalize_l2(sum)));
        }

        Ok(results)
    }
}
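
/// Fetches the model and tokenizer via the Hugging Face Hub API (reusing the
/// local cache when present) and verifies checksums when they are configured.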
fn ensure_model() -> Result<(PathBuf, PathBuf), EmbedderError> {
    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| EmbedderError::HfHubError(e.to_string()))?;
    let repo = api.model(MODEL_REPO.to_string());

    let model_path = repo
        .get(MODEL_FILE)
        .map_err(|e| EmbedderError::HfHubError(e.to_string()))?;
    let tokenizer_path = repo
        .get(TOKENIZER_FILE)
        .map_err(|e| EmbedderError::HfHubError(e.to_string()))?;

    if !MODEL_BLAKE3.is_empty() {
        verify_checksum(&model_path, MODEL_BLAKE3)?;
    }
    if !TOKENIZER_BLAKE3.is_empty() {
        verify_checksum(&tokenizer_path, TOKENIZER_BLAKE3)?;
    }

    Ok((model_path, tokenizer_path))
}
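
/// Streams `path` through BLAKE3 and compares the hex digest to `expected`.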
fn verify_checksum(path: &Path, expected: &str) -> Result<(), EmbedderError> {
    let mut file =
        std::fs::File::open(path).map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
    let mut hasher = blake3::Hasher::new();
    std::io::copy(&mut file, &mut hasher)
        .map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
    let actual = hasher.finalize().to_hex().to_string();

    if actual != expected {
        return Err(EmbedderError::ChecksumMismatch {
            path: path.display().to_string(),
            expected: expected.to_string(),
            actual,
        });
    }
    Ok(())
}
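
/// Best-effort fix for dynamic loading of ONNX Runtime provider libraries:
/// symlinks the CUDA/TensorRT provider shared objects from the `ort` download
/// cache into a directory already on `LD_LIBRARY_PATH`, so the loader can find
/// them when the session is created. Returns silently if any step fails;
/// Linux/x86_64-only by construction of the cache path.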
fn ensure_ort_provider_libs() {
    let home = match std::env::var("HOME") {
        Ok(h) => std::path::PathBuf::from(h),
        Err(_) => return,
    };
    let ort_cache = home.join(".cache/ort.pyke.io/dfbin/x86_64-unknown-linux-gnu");

    let ort_lib_dir = match std::fs::read_dir(&ort_cache) {
        Ok(entries) => entries
            .filter_map(|e| e.ok())
            .filter(|e| e.path().is_dir())
            .map(|e| e.path())
            .next(),
        Err(_) => return,
    };

    let ort_lib_dir = match ort_lib_dir {
        Some(d) => d,
        None => return,
    };

    let ld_path = std::env::var("LD_LIBRARY_PATH").unwrap_or_default();
    let ort_cache_str = ort_cache.to_string_lossy();
    let target_dir = ld_path
        .split(':')
        .find(|p| {
            !p.is_empty() && std::path::Path::new(p).is_dir() && !p.contains(ort_cache_str.as_ref())
        })
        .map(std::path::PathBuf::from);

    let target_dir = match target_dir {
        Some(d) => d,
        None => return,
    };

    let provider_libs = [
        "libonnxruntime_providers_shared.so",
        "libonnxruntime_providers_cuda.so",
        "libonnxruntime_providers_tensorrt.so",
    ];

    for lib in &provider_libs {
        let src = ort_lib_dir.join(lib);
        let dst = target_dir.join(lib);

        if !src.exists() {
            continue;
        }

        if dst.symlink_metadata().is_ok() {
            if let Ok(target) = std::fs::read_link(&dst) {
                if target == src {
                    continue;
                }
            }
            let _ = std::fs::remove_file(&dst);
        }

        if let Err(e) = std::os::unix::fs::symlink(&src, &dst) {
            tracing::debug!("Failed to symlink {}: {}", lib, e);
        } else {
            tracing::info!("Created symlink: {} -> {}", dst.display(), src.display());
        }
    }
}
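
/// Probes GPU execution providers (CUDA first, then TensorRT) and falls back
/// to CPU when neither is available.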
fn select_provider() -> ExecutionProvider {
    use ort::ep::{TensorRT, CUDA};

    ensure_ort_provider_libs();

    let cuda = CUDA::default();
    if cuda.is_available().unwrap_or(false) {
        return ExecutionProvider::CUDA { device_id: 0 };
    }

    let tensorrt = TensorRT::default();
    if tensorrt.is_available().unwrap_or(false) {
        return ExecutionProvider::TensorRT { device_id: 0 };
    }

    ExecutionProvider::CPU
}
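
/// Builds an ONNX Runtime session for `model_path` with the requested
/// execution provider; TensorRT is registered with CUDA as a fallback.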
fn create_session(
    model_path: &Path,
    provider: ExecutionProvider,
) -> Result<Session, EmbedderError> {
    use ort::ep::{TensorRT, CUDA};

    let builder = Session::builder()?;

    let session = match provider {
        ExecutionProvider::CUDA { device_id } => builder
            .with_execution_providers([CUDA::default().with_device_id(device_id).build()])?
            .commit_from_file(model_path)?,
        ExecutionProvider::TensorRT { device_id } => {
            builder
                .with_execution_providers([
                    TensorRT::default().with_device_id(device_id).build(),
                    CUDA::default().with_device_id(device_id).build(),
                ])?
                .commit_from_file(model_path)?
        }
        ExecutionProvider::CPU => builder.commit_from_file(model_path)?,
    };

    Ok(session)
}
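
/// Packs ragged sequences into a dense `(batch, max_len)` array, truncating
/// longer sequences and right-padding shorter ones with `pad_value`.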
fn pad_2d_i64(inputs: &[Vec<i64>], max_len: usize, pad_value: i64) -> Array2<i64> {
    let batch_size = inputs.len();
    let mut arr = Array2::from_elem((batch_size, max_len), pad_value);
    for (i, seq) in inputs.iter().enumerate() {
        for (j, &val) in seq.iter().take(max_len).enumerate() {
            arr[[i, j]] = val;
        }
    }
    arr
}
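
/// Scales `v` to unit L2 norm; an all-zero vector is returned unchanged.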
fn normalize_l2(mut v: Vec<f32>) -> Vec<f32> {
    let norm_sq: f32 = v.iter().fold(0.0, |acc, &x| acc + x * x);
    if norm_sq > 0.0 {
        let inv_norm = 1.0 / norm_sq.sqrt();
        v.iter_mut().for_each(|x| *x *= inv_norm);
    }
    v
}