// ctxgraph_embed/encoder.rs

use std::path::PathBuf;

use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

use crate::EmbedError;
/// Text-embedding engine wrapping a fastembed [`TextEmbedding`] model
/// (AllMiniLML6V2). Construct with [`EmbedEngine::new`] or
/// [`EmbedEngine::new_with_cache`]; model loading is fallible, so there
/// is no `Default` impl.
pub struct EmbedEngine {
    // Loaded fastembed model; all embed calls delegate to it.
    model: TextEmbedding,
}
10
11impl EmbedEngine {
12 pub fn new() -> Result<Self, EmbedError> {
14 let model = TextEmbedding::try_new(
15 InitOptions::new(EmbeddingModel::AllMiniLML6V2),
16 )
17 .map_err(|e| EmbedError::ModelInit(e.to_string()))?;
18 Ok(Self { model })
19 }
20
21 pub fn new_with_cache(cache_dir: PathBuf) -> Result<Self, EmbedError> {
23 let model = TextEmbedding::try_new(
24 InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_cache_dir(cache_dir),
25 )
26 .map_err(|e| EmbedError::ModelInit(e.to_string()))?;
27 Ok(Self { model })
28 }
29
30 pub fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedError> {
32 let mut batch = self
33 .model
34 .embed(vec![text], None)
35 .map_err(|e| EmbedError::Encoding(e.to_string()))?;
36 batch
37 .pop()
38 .ok_or_else(|| EmbedError::Encoding("empty embedding result".to_string()))
39 }
40
41 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError> {
43 self.model
44 .embed(texts.to_vec(), None)
45 .map_err(|e| EmbedError::Encoding(e.to_string()))
46 }
47
48 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
51 if a.len() != b.len() || a.is_empty() {
52 return 0.0;
53 }
54 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
55 let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
56 let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
57 if mag_a == 0.0 || mag_b == 0.0 {
58 0.0
59 } else {
60 dot / (mag_a * mag_b)
61 }
62 }
63}