1use std::sync::{Arc, Mutex};
2
/// A single embedding: a growable vector of `f32` components.
pub type Embedding = Vec<f32>;
4
5pub trait Embedder: Send + Sync {
6 fn embed(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError>;
7
8 fn embed_one(&self, text: &str) -> Result<Embedding, EmbedError> {
9 let mut results = self.embed(&[text])?;
10 results
11 .pop()
12 .ok_or_else(|| EmbedError::Other("empty embedding result".into()))
13 }
14
15 fn dimension(&self) -> usize;
16}
17
18#[derive(Debug, thiserror::Error)]
19pub enum EmbedError {
20 #[error("embedding model error: {0}")]
21 Model(String),
22 #[error("{0}")]
23 Other(String),
24}
25
26pub struct FastEmbedder {
27 model: Mutex<fastembed::TextEmbedding>,
28 dimension: usize,
29}
30
31impl FastEmbedder {
32 pub fn new() -> Result<Self, EmbedError> {
33 let model = fastembed::TextEmbedding::try_new(Default::default())
34 .map_err(|e| EmbedError::Model(e.to_string()))?;
35 Ok(Self {
36 model: Mutex::new(model),
37 dimension: 384,
38 })
39 }
40}
41
42impl Embedder for FastEmbedder {
43 fn embed(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
44 let owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
45 self.model
46 .lock()
47 .map_err(|e| EmbedError::Other(e.to_string()))?
48 .embed(owned, None)
49 .map_err(|e| EmbedError::Model(e.to_string()))
50 }
51
52 fn dimension(&self) -> usize {
53 self.dimension
54 }
55}
56
/// Computes the cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or when
/// either vector has zero magnitude. Otherwise the result lies in `[-1.0, 1.0]`.
///
/// The dot product and squared norms are accumulated in `f64`. Squaring an
/// `f32` can overflow to infinity (giving `NaN`) or underflow to zero (giving a
/// bogus `0.0`) for perfectly valid finite inputs; the square of any finite
/// `f32` fits comfortably in an `f64`.
///
/// Non-finite components (`NaN`, `±inf`) are not special-cased and generally
/// produce `NaN`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass over both slices; `zip` avoids per-element bounds checks.
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        // Rounding can push the ratio a hair outside [-1, 1]; clamp so callers
        // can feed the result to `acos` and friends. The value is within
        // [-1, 1], so narrowing to f32 only rounds.
        (dot / denom).clamp(-1.0, 1.0) as f32
    }
}
76
77pub type SharedEmbedder = Arc<dyn Embedder>;