Skip to main content

oxios_memory/memory/
embedding.rs

1#![allow(missing_docs)]
//! Embedding abstraction for semantic similarity.
//!
//! Supports two embedding modes:
//! - **Sparse (TF-IDF):** Zero-dependency, works for any language.
//! - **Dense (f32 or f64):** Produced by GGUF models (EmbeddingGemma) or API-based models.
//!
//! Dense vectors are used by the HNSW index for fast ANN search.
//! Sparse vectors serve as a fallback when no embedding model is available.
10
11use std::collections::HashMap;
12
13use anyhow::Result;
14
15/// An embedding vector for semantic similarity comparison.
16#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
17pub enum EmbeddingVector {
18    /// Dense vector from API-based embeddings (f64 for precision).
19    Dense(Vec<f64>),
20    /// Dense f32 vector (for HNSW index compatibility).
21    DenseF32(Vec<f32>),
22    /// Sparse TF-IDF vector (term → weight).
23    Sparse(HashMap<String, f64>),
24}
25
26impl EmbeddingVector {
27    /// Compute cosine similarity between two vectors.
28    pub fn cosine_similarity(&self, other: &Self) -> f64 {
29        match (self, other) {
30            (EmbeddingVector::Dense(a), EmbeddingVector::Dense(b)) => {
31                if a.len() != b.len() || a.is_empty() {
32                    return 0.0;
33                }
34                let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
35                let na: f64 = a.iter().map(|v| v * v).sum::<f64>().sqrt();
36                let nb: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
37                if na == 0.0 || nb == 0.0 {
38                    return 0.0;
39                }
40                dot / (na * nb)
41            }
42            (EmbeddingVector::DenseF32(a), EmbeddingVector::DenseF32(b)) => {
43                crate::memory::cosine_similarity_f32(a, b) as f64
44            }
45            (EmbeddingVector::Dense(a), EmbeddingVector::DenseF32(b))
46            | (EmbeddingVector::DenseF32(b), EmbeddingVector::Dense(a)) => {
47                // Cross-dense: convert f32 to f64 for comparison
48                let b_f64: Vec<f64> = b.iter().map(|&v| v as f64).collect();
49                let (aa, bb) = if matches!(self, EmbeddingVector::Dense(_)) {
50                    (a, &b_f64)
51                } else {
52                    (&b_f64, a)
53                };
54                if aa.is_empty() || bb.is_empty() || aa.len() != bb.len() {
55                    return 0.0;
56                }
57                let dot: f64 = aa.iter().zip(bb).map(|(x, y)| x * y).sum();
58                let na: f64 = aa.iter().map(|v| v * v).sum::<f64>().sqrt();
59                let nb: f64 = bb.iter().map(|v| v * v).sum::<f64>().sqrt();
60                if na == 0.0 || nb == 0.0 {
61                    return 0.0;
62                }
63                dot / (na * nb)
64            }
65            (EmbeddingVector::Sparse(a), EmbeddingVector::Sparse(b)) => {
66                if a.is_empty() || b.is_empty() {
67                    return 0.0;
68                }
69                let mut dot = 0.0;
70                for (term, w) in a {
71                    if let Some(w2) = b.get(term) {
72                        dot += w * w2;
73                    }
74                }
75                let na: f64 = a.values().map(|v| v * v).sum::<f64>().sqrt();
76                let nb: f64 = b.values().map(|v| v * v).sum::<f64>().sqrt();
77                if na == 0.0 || nb == 0.0 {
78                    return 0.0;
79                }
80                dot / (na * nb)
81            }
82            _ => 0.0, // Cross-type comparison not supported
83        }
84    }
85
86    /// Returns true if this vector is empty/zero.
87    pub fn is_empty(&self) -> bool {
88        match self {
89            EmbeddingVector::Dense(v) => v.is_empty(),
90            EmbeddingVector::DenseF32(v) => v.is_empty(),
91            EmbeddingVector::Sparse(m) => m.is_empty(),
92        }
93    }
94
95    /// Convert to f32 dense vector (for HNSW index).
96    ///
97    /// - `DenseF32` → clone
98    /// - `Dense` → cast f64 to f32
99    /// - `Sparse` → returns None (not convertible)
100    pub fn to_f32_dense(&self) -> Option<Vec<f32>> {
101        match self {
102            EmbeddingVector::DenseF32(v) => Some(v.clone()),
103            EmbeddingVector::Dense(v) => Some(v.iter().map(|&x| x as f32).collect()),
104            EmbeddingVector::Sparse(_) => None,
105        }
106    }
107
108    /// Get the dimensionality of the vector.
109    pub fn dimensions(&self) -> usize {
110        match self {
111            EmbeddingVector::Dense(v) => v.len(),
112            EmbeddingVector::DenseF32(v) => v.len(),
113            EmbeddingVector::Sparse(m) => m.len(),
114        }
115    }
116}
117
118/// Provider for generating text embeddings.
119#[async_trait::async_trait]
120pub trait EmbeddingProvider: Send + Sync {
121    /// Generate an embedding vector for the given text.
122    async fn embed(&self, text: &str) -> Result<EmbeddingVector>;
123    /// Name of this provider.
124    fn name(&self) -> &str;
125}
126
127/// TF-IDF based embedding provider (zero dependencies).
128pub struct TfIdfEmbeddingProvider;
129
130#[async_trait::async_trait]
131impl EmbeddingProvider for TfIdfEmbeddingProvider {
132    async fn embed(&self, text: &str) -> Result<EmbeddingVector> {
133        let tv = crate::memory::TextVector::from_text(text);
134        Ok(EmbeddingVector::Sparse(tv.tf_map().clone()))
135    }
136    fn name(&self) -> &str {
137        "tfidf"
138    }
139}
140
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sparse vector from `(term, weight)` pairs.
    fn sparse(terms: &[(&str, f64)]) -> EmbeddingVector {
        EmbeddingVector::Sparse(
            terms
                .iter()
                .map(|&(term, weight)| (term.to_string(), weight))
                .collect(),
        )
    }

    #[test]
    fn test_dense_f32_similarity() {
        let a = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let b = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b);
        assert!((sim - 1.0).abs() < 1e-6, "identical should be 1.0");
    }

    #[test]
    fn test_cross_dense_similarity() {
        let a = EmbeddingVector::Dense(vec![1.0, 0.0, 0.0]);
        let b = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b);
        assert!((sim - 1.0).abs() < 1e-6, "cross-dense should be 1.0");
    }

    #[test]
    fn test_cross_dense_similarity_is_symmetric() {
        let wide = EmbeddingVector::Dense(vec![1.0, 2.0]);
        let narrow = EmbeddingVector::DenseF32(vec![2.0, 1.0]);
        let forward = wide.cosine_similarity(&narrow);
        let backward = narrow.cosine_similarity(&wide);
        assert!((forward - 0.8).abs() < 1e-9, "expected 4 / 5");
        assert!((forward - backward).abs() < 1e-12, "order must not matter");
    }

    #[test]
    fn test_dense_length_mismatch_is_zero() {
        let a = EmbeddingVector::Dense(vec![1.0, 0.0]);
        let b = EmbeddingVector::Dense(vec![1.0]);
        assert_eq!(a.cosine_similarity(&b), 0.0);
    }

    #[test]
    fn test_sparse_similarity() {
        let a = sparse(&[("x", 1.0), ("y", 1.0)]);
        let b = sparse(&[("y", 1.0), ("z", 1.0)]);
        let sim = a.cosine_similarity(&b);
        assert!((sim - 0.5).abs() < 1e-9, "one shared term out of two");
    }

    #[test]
    fn test_dense_vs_sparse_is_zero() {
        let dense = EmbeddingVector::Dense(vec![1.0]);
        let sparse_vec = sparse(&[("a", 1.0)]);
        assert_eq!(dense.cosine_similarity(&sparse_vec), 0.0);
        assert_eq!(sparse_vec.cosine_similarity(&dense), 0.0);
    }

    #[test]
    fn test_is_empty() {
        assert!(EmbeddingVector::Dense(Vec::new()).is_empty());
        assert!(EmbeddingVector::Sparse(HashMap::new()).is_empty());
        assert!(!EmbeddingVector::DenseF32(vec![0.0]).is_empty());
    }

    #[test]
    fn test_to_f32_dense_from_dense() {
        let v = EmbeddingVector::Dense(vec![1.0, 2.0]);
        // Named `converted` rather than `f32` to avoid shadowing the primitive type name.
        let converted = v.to_f32_dense().unwrap();
        assert_eq!(converted, vec![1.0f32, 2.0]);
    }

    #[test]
    fn test_to_f32_dense_from_sparse_returns_none() {
        let v = EmbeddingVector::Sparse(HashMap::from([("a".to_string(), 1.0)]));
        assert!(v.to_f32_dense().is_none());
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingVector::Dense(vec![1.0; 10]).dimensions(), 10);
        assert_eq!(EmbeddingVector::DenseF32(vec![1.0; 5]).dimensions(), 5);
        assert_eq!(EmbeddingVector::Sparse(HashMap::new()).dimensions(), 0);
    }
}