// ctxgraph_embed/encoder.rs

1use std::path::PathBuf;
2
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4
5use crate::EmbedError;
6
7pub struct EmbedEngine {
8    model: TextEmbedding,
9}
10
11impl EmbedEngine {
12    /// Initialize with default cache directory (~/.cache/fastembed).
13    pub fn new() -> Result<Self, EmbedError> {
14        let model = TextEmbedding::try_new(
15            InitOptions::new(EmbeddingModel::AllMiniLML6V2),
16        )
17        .map_err(|e| EmbedError::ModelInit(e.to_string()))?;
18        Ok(Self { model })
19    }
20
21    /// Initialize with a custom model cache directory.
22    pub fn new_with_cache(cache_dir: PathBuf) -> Result<Self, EmbedError> {
23        let model = TextEmbedding::try_new(
24            InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_cache_dir(cache_dir),
25        )
26        .map_err(|e| EmbedError::ModelInit(e.to_string()))?;
27        Ok(Self { model })
28    }
29
30    /// Embed a single text string. Returns a 384-dimensional vector.
31    pub fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedError> {
32        let mut batch = self
33            .model
34            .embed(vec![text], None)
35            .map_err(|e| EmbedError::Encoding(e.to_string()))?;
36        batch
37            .pop()
38            .ok_or_else(|| EmbedError::Encoding("empty embedding result".to_string()))
39    }
40
41    /// Embed a batch of texts. Returns one vector per input text.
42    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError> {
43        self.model
44            .embed(texts.to_vec(), None)
45            .map_err(|e| EmbedError::Encoding(e.to_string()))
46    }
47
48    /// Compute cosine similarity between two f32 vectors.
49    /// Returns 0.0 if either vector has zero magnitude.
50    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
51        if a.len() != b.len() || a.is_empty() {
52            return 0.0;
53        }
54        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
55        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
56        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
57        if mag_a == 0.0 || mag_b == 0.0 {
58            0.0
59        } else {
60            dot / (mag_a * mag_b)
61        }
62    }
63}