Skip to main content

conch_core/
embed.rs

1use std::sync::{Arc, Mutex};
2
/// A single embedding vector: one `f32` per dimension.
pub type Embedding = Vec<f32>;
4
5pub trait Embedder: Send + Sync {
6    fn embed(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError>;
7
8    fn embed_one(&self, text: &str) -> Result<Embedding, EmbedError> {
9        let mut results = self.embed(&[text])?;
10        results
11            .pop()
12            .ok_or_else(|| EmbedError::Other("empty embedding result".into()))
13    }
14
15    fn dimension(&self) -> usize;
16}
17
/// Errors returned by [`Embedder`] operations.
#[derive(Debug, thiserror::Error)]
pub enum EmbedError {
    /// The embedding model itself failed; carries the model's error message.
    /// `FastEmbedder` uses this for both model initialisation and inference failures.
    #[error("embedding model error: {0}")]
    Model(String),
    /// Any other failure (in this file: an empty embedding result or a poisoned
    /// model lock); the message is displayed verbatim.
    #[error("{0}")]
    Other(String),
}
25
/// [`Embedder`] implementation backed by a `fastembed::TextEmbedding` model.
pub struct FastEmbedder {
    /// The wrapped model. It is guarded by a `Mutex`, so concurrent `embed`
    /// calls on one `FastEmbedder` are serialized.
    model: Mutex<fastembed::TextEmbedding>,
    /// Vector length reported by `Embedder::dimension`.
    dimension: usize,
}
30
31impl FastEmbedder {
32    pub fn new() -> Result<Self, EmbedError> {
33        let model = fastembed::TextEmbedding::try_new(Default::default())
34            .map_err(|e| EmbedError::Model(e.to_string()))?;
35        Ok(Self {
36            model: Mutex::new(model),
37            dimension: 384,
38        })
39    }
40}
41
42impl Embedder for FastEmbedder {
43    fn embed(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
44        let owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
45        self.model
46            .lock()
47            .map_err(|e| EmbedError::Other(e.to_string()))?
48            .embed(owned, None)
49            .map_err(|e| EmbedError::Model(e.to_string()))
50    }
51
52    fn dimension(&self) -> usize {
53        self.dimension
54    }
55}
56
/// Cosine similarity between two vectors, always within `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when either vector has zero magnitude (the angle is undefined there).
///
/// The result is clamped to `[-1.0, 1.0]`. With `f32` rounding, the ratio of
/// the dot product to the product of the norms can land just outside that
/// range: comparing `[1.0, 1.0]` with itself gives `1.0000001` unclamped.
/// Such a value breaks callers that treat `1.0` as the maximum or pass it to
/// `acos`. NaN inputs still propagate as NaN.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // One pass over both slices, accumulating the dot product and both
    // squared norms. The iterator form avoids per-element bounds checks.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    // Take the square roots separately rather than sqrt(na * nb), so large
    // norms do not overflow the intermediate product.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        (dot / denom).clamp(-1.0, 1.0)
    }
}
76
/// Shared, reference-counted handle to any [`Embedder`], for passing one
/// embedder instance across modules.
///
/// Because the trait requires `Send + Sync`, this `Arc` can also be cloned
/// and sent between threads.
pub type SharedEmbedder = Arc<dyn Embedder>;