memoir_core/embedding/
onnx.rs1use std::sync::{Arc, Mutex};
4
5use super::{EmbeddingError, EmbeddingModel};
6
/// Length of every embedding vector, as reported by `OnnxEmbedding::dimensions`.
// NOTE(review): presumably the output width of the BGE-small-en-v1.5 model that
// `OnnxEmbedding::new` loads — confirm against the fastembed model metadata.
const ONNX_DIMENSIONS: usize = 384;
8
/// Text-embedding backend that wraps a `fastembed::TextEmbedding` instance.
///
/// The model sits behind `Arc<Mutex<_>>` so that a handle to it can be cloned
/// and moved into a blocking worker task, with the lock serializing access.
pub struct OnnxEmbedding {
    /// Shared, lock-guarded fastembed model.
    model: Arc<Mutex<fastembed::TextEmbedding>>,
}
16
17impl std::fmt::Debug for OnnxEmbedding {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("OnnxEmbedding").finish_non_exhaustive()
20 }
21}
22
23impl OnnxEmbedding {
24 pub fn new() -> Result<Self, EmbeddingError> {
31 let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::BGESmallENV15);
32 let model = fastembed::TextEmbedding::try_new(options)
33 .map_err(|e| EmbeddingError::Init(e.to_string()))?;
34 Ok(Self {
35 model: Arc::new(Mutex::new(model)),
36 })
37 }
38}
39
40impl EmbeddingModel for OnnxEmbedding {
41 fn embed(&self, text: &str) -> impl std::future::Future<Output = Result<Vec<f32>, EmbeddingError>> + Send {
42 let model = self.model.clone();
43 let text = text.to_owned();
44 async move {
45 tokio::task::spawn_blocking(move || {
46 let mut guard = model
47 .lock()
48 .map_err(|e| EmbeddingError::Embed(format!("model lock poisoned: {e}")))?;
49 let mut results = guard
50 .embed(vec![&text], None)
51 .map_err(|e| EmbeddingError::Embed(e.to_string()))?;
52 results
53 .pop()
54 .ok_or_else(|| EmbeddingError::Embed("empty result from model".into()))
55 })
56 .await
57 .map_err(|e| EmbeddingError::Embed(format!("join error: {e}")))?
58 }
59 }
60
61 fn dimensions(&self) -> usize {
62 ONNX_DIMENSIONS
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the advertised embedding width so an accidental edit to the
    /// constant fails the suite.
    #[test]
    fn should_report_onnx_dimensions_as_384() {
        let expected: usize = 384;
        assert_eq!(ONNX_DIMENSIONS, expected);
    }
}