// normalize_semantic/embedder.rs

1//! Embedding generation via fastembed (ONNX-backed, no server required).
2//!
3//! The embedder wraps a fastembed `TextEmbedding` model and serializes/
4//! deserializes raw f32 vectors for storage in SQLite BLOBs.
5
6use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
7use std::path::Path;
8
/// Default embedding model — nomic-embed-text-v1.5 gives 768 dimensions and
/// good code+text mixed-domain performance with Matryoshka support.
pub const DEFAULT_MODEL: &str = "nomic-embed-text-v1.5";

/// Wraps the fastembed model and provides encode/decode helpers.
pub struct Embedder {
    // Loaded ONNX model; embedding needs `&mut` access (see `embed_batch`).
    model: TextEmbedding,
    /// Name the model was loaded under (the string passed to `Embedder::load`).
    pub model_name: String,
    /// Output vector length, probed at load time (defaults to 768 if the
    /// probe returns no vectors — see `Embedder::load`).
    pub dimensions: usize,
}
19
20impl Embedder {
21    /// Load the model, downloading it if necessary.
22    ///
23    /// `cache_dir` is the directory used for ONNX model caching (typically
24    /// `~/.cache/huggingface` or similar); if `None` fastembed uses its default.
25    pub fn load(model_name: &str, cache_dir: Option<&Path>) -> anyhow::Result<Self> {
26        let embedding_model = resolve_model(model_name)?;
27        let mut opts = InitOptions::new(embedding_model);
28        if let Some(dir) = cache_dir {
29            opts = opts.with_cache_dir(dir.to_path_buf());
30        }
31        let mut model = TextEmbedding::try_new(opts).map_err(|e| {
32            anyhow::anyhow!("Failed to load embedding model '{}': {}", model_name, e)
33        })?;
34
35        // Probe dimensions by embedding an empty string.
36        let probe = model
37            .embed(vec![""], None)
38            .map_err(|e| anyhow::anyhow!("Failed to probe embedding dimensions: {}", e))?;
39        let dimensions = probe.first().map(|v| v.len()).unwrap_or(768);
40
41        Ok(Self {
42            model,
43            model_name: model_name.to_string(),
44            dimensions,
45        })
46    }
47
48    /// Embed a batch of texts. Returns one vector per input, in order.
49    pub fn embed_batch(&mut self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
50        self.model
51            .embed(texts, None)
52            .map_err(|e| anyhow::anyhow!("Embedding failed: {}", e))
53    }
54
55    /// Embed a single text.
56    pub fn embed_one(&mut self, text: &str) -> anyhow::Result<Vec<f32>> {
57        let batch = self.embed_batch(&[text])?;
58        batch
59            .into_iter()
60            .next()
61            .ok_or_else(|| anyhow::anyhow!("Embedder returned empty result for single text"))
62    }
63}
64
65/// Convert a slice of f32 to a little-endian byte blob for SQLite storage.
/// Convert a slice of f32 to a little-endian byte blob for SQLite storage.
pub fn encode_vector(v: &[f32]) -> Vec<u8> {
    // Each f32 contributes exactly 4 little-endian bytes, in element order.
    v.iter().flat_map(|x| x.to_le_bytes()).collect()
}
73
74/// Decode a little-endian byte blob back to f32 values.
/// Decode a little-endian byte blob back to f32 values.
///
/// Trailing bytes (when `bytes.len()` is not a multiple of 4) are silently
/// discarded, matching `chunks_exact` semantics.
pub fn decode_vector(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        // `chunks_exact(4)` guarantees every yielded chunk is exactly 4 bytes,
        // so this conversion cannot fail. The previous `unwrap_or([0u8; 4])`
        // fallback was unreachable dead code that could only have silently
        // masked a logic error; `expect` states the invariant instead.
        .map(|b| f32::from_le_bytes(b.try_into().expect("chunks_exact(4) yields 4-byte chunks")))
        .collect()
}
81
82/// Cosine similarity between two equal-length vectors.
83/// Returns 0.0 if either vector has zero magnitude.
/// Cosine similarity between two equal-length vectors.
/// Returns 0.0 if either vector has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "cosine_similarity: vector length mismatch"
    );
    // Euclidean norm of a vector; applied to each input over its full length.
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
    let (mag_a, mag_b) = (norm(a), norm(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        0.0
    } else {
        // Clamp guards against floating-point drift pushing past ±1.
        (dot / (mag_a * mag_b)).clamp(-1.0, 1.0)
    }
}
98
99/// Return the known output dimensionality for a model without loading it.
100///
101/// Returns `None` for unknown models (caller should default to 768 or probe at
102/// load time via [`Embedder::dimensions`]).
/// Return the known output dimensionality for a model without loading it.
///
/// Returns `None` for unknown models (caller should default to 768 or probe at
/// load time via [`Embedder::dimensions`]).
pub fn dims_for_model(name: &str) -> Option<usize> {
    // Both MiniLM variants emit 384-dim vectors; nomic emits 768.
    match name {
        "nomic-embed-text-v1.5" => Some(768),
        "all-MiniLM-L6-v2" | "all-MiniLM-L12-v2" => Some(384),
        _ => None,
    }
}
111
112/// Resolve a model name string to a fastembed `EmbeddingModel`.
113fn resolve_model(name: &str) -> anyhow::Result<EmbeddingModel> {
114    match name {
115        "nomic-embed-text-v1.5" => Ok(EmbeddingModel::NomicEmbedTextV15),
116        "all-MiniLM-L6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
117        "all-MiniLM-L12-v2" => Ok(EmbeddingModel::AllMiniLML12V2),
118        other => Err(anyhow::anyhow!(
119            "Unknown embedding model '{}'. Supported: nomic-embed-text-v1.5, all-MiniLM-L6-v2, all-MiniLM-L12-v2",
120            other
121        )),
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_roundtrip() {
        let original = vec![1.0_f32, -0.5, 0.25, 0.0];
        let decoded = decode_vector(&encode_vector(&original));
        for (a, b) in original.iter().zip(&decoded) {
            assert!((a - b).abs() < 1e-7, "roundtrip mismatch: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0_f32, 2.0, 3.0];
        assert!(
            (cosine_similarity(&v, &v) - 1.0).abs() < 1e-6,
            "identical vectors should have sim=1"
        );
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![0.0_f32, 1.0, 0.0];
        assert!(
            cosine_similarity(&a, &b).abs() < 1e-6,
            "orthogonal vectors should have sim=0"
        );
    }

    #[test]
    fn test_cosine_zero_vector() {
        let zero = vec![0.0_f32, 0.0, 0.0];
        let other = vec![1.0_f32, 2.0, 3.0];
        assert_eq!(cosine_similarity(&zero, &other), 0.0);
    }
}
164}