Skip to main content

oxios_kernel/
embedding.rs

//! Embedding abstraction for semantic similarity.
//!
//! Supports two embedding modes:
//! - **Sparse (TF-IDF):** Zero-dependency, works for any language.
//! - **Dense (`f64` or `f32`):** Produced by ONNX or API-based models (OpenAI, etc.).
//!
//! Dense vectors are used by the HNSW index (as `f32`) for fast ANN search.
//! Sparse vectors serve as a fallback when no embedding model is available.
9
10use std::collections::HashMap;
11
12use anyhow::Result;
13
14/// An embedding vector for semantic similarity comparison.
15#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
16pub enum EmbeddingVector {
17    /// Dense vector from API-based embeddings (f64 for precision).
18    Dense(Vec<f64>),
19    /// Dense f32 vector (for HNSW index compatibility).
20    DenseF32(Vec<f32>),
21    /// Sparse TF-IDF vector (term → weight).
22    Sparse(HashMap<String, f64>),
23}
24
25impl EmbeddingVector {
26    /// Compute cosine similarity between two vectors.
27    pub fn cosine_similarity(&self, other: &Self) -> f64 {
28        match (self, other) {
29            (EmbeddingVector::Dense(a), EmbeddingVector::Dense(b)) => {
30                if a.len() != b.len() || a.is_empty() {
31                    return 0.0;
32                }
33                let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
34                let na: f64 = a.iter().map(|v| v * v).sum::<f64>().sqrt();
35                let nb: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
36                if na == 0.0 || nb == 0.0 {
37                    return 0.0;
38                }
39                dot / (na * nb)
40            }
41            (EmbeddingVector::DenseF32(a), EmbeddingVector::DenseF32(b)) => {
42                crate::memory::normalizer::cosine_similarity_f32(a, b) as f64
43            }
44            (EmbeddingVector::Dense(a), EmbeddingVector::DenseF32(b))
45            | (EmbeddingVector::DenseF32(b), EmbeddingVector::Dense(a)) => {
46                // Cross-dense: convert f32 to f64 for comparison
47                let b_f64: Vec<f64> = b.iter().map(|&v| v as f64).collect();
48                let (aa, bb) = if matches!(self, EmbeddingVector::Dense(_)) {
49                    (a, &b_f64)
50                } else {
51                    (&b_f64, a)
52                };
53                if aa.is_empty() || bb.is_empty() || aa.len() != bb.len() {
54                    return 0.0;
55                }
56                let dot: f64 = aa.iter().zip(bb).map(|(x, y)| x * y).sum();
57                let na: f64 = aa.iter().map(|v| v * v).sum::<f64>().sqrt();
58                let nb: f64 = bb.iter().map(|v| v * v).sum::<f64>().sqrt();
59                if na == 0.0 || nb == 0.0 {
60                    return 0.0;
61                }
62                dot / (na * nb)
63            }
64            (EmbeddingVector::Sparse(a), EmbeddingVector::Sparse(b)) => {
65                if a.is_empty() || b.is_empty() {
66                    return 0.0;
67                }
68                let mut dot = 0.0;
69                for (term, w) in a {
70                    if let Some(w2) = b.get(term) {
71                        dot += w * w2;
72                    }
73                }
74                let na: f64 = a.values().map(|v| v * v).sum::<f64>().sqrt();
75                let nb: f64 = b.values().map(|v| v * v).sum::<f64>().sqrt();
76                if na == 0.0 || nb == 0.0 {
77                    return 0.0;
78                }
79                dot / (na * nb)
80            }
81            _ => 0.0, // Cross-type comparison not supported
82        }
83    }
84
85    /// Returns true if this vector is empty/zero.
86    pub fn is_empty(&self) -> bool {
87        match self {
88            EmbeddingVector::Dense(v) => v.is_empty(),
89            EmbeddingVector::DenseF32(v) => v.is_empty(),
90            EmbeddingVector::Sparse(m) => m.is_empty(),
91        }
92    }
93
94    /// Convert to f32 dense vector (for HNSW index).
95    ///
96    /// - `DenseF32` → clone
97    /// - `Dense` → cast f64 to f32
98    /// - `Sparse` → returns None (not convertible)
99    pub fn to_f32_dense(&self) -> Option<Vec<f32>> {
100        match self {
101            EmbeddingVector::DenseF32(v) => Some(v.clone()),
102            EmbeddingVector::Dense(v) => Some(v.iter().map(|&x| x as f32).collect()),
103            EmbeddingVector::Sparse(_) => None,
104        }
105    }
106
107    /// Get the dimensionality of the vector.
108    pub fn dimensions(&self) -> usize {
109        match self {
110            EmbeddingVector::Dense(v) => v.len(),
111            EmbeddingVector::DenseF32(v) => v.len(),
112            EmbeddingVector::Sparse(m) => m.len(),
113        }
114    }
115}
116
117/// Provider for generating text embeddings.
118#[async_trait::async_trait]
119pub trait EmbeddingProvider: Send + Sync {
120    /// Generate an embedding vector for the given text.
121    async fn embed(&self, text: &str) -> Result<EmbeddingVector>;
122    /// Name of this provider.
123    fn name(&self) -> &str;
124}
125
/// TF-IDF based embedding provider (zero dependencies).
///
/// This is a stateless unit type, so it is cheap to copy and can be created
/// with [`Default::default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct TfIdfEmbeddingProvider;
128
129#[async_trait::async_trait]
130impl EmbeddingProvider for TfIdfEmbeddingProvider {
131    async fn embed(&self, text: &str) -> Result<EmbeddingVector> {
132        let tv = crate::memory::TextVector::from_text(text);
133        Ok(EmbeddingVector::Sparse(tv.tf_map().clone()))
134    }
135    fn name(&self) -> &str {
136        "tfidf"
137    }
138}
139
140// ---------------------------------------------------------------------------
141// Tests
142// ---------------------------------------------------------------------------
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing similarity scores.
    const EPS: f64 = 1e-6;

    #[test]
    fn test_dense_f32_similarity() {
        let a = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let b = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b);
        assert!((sim - 1.0).abs() < EPS, "identical should be 1.0");
    }

    #[test]
    fn test_dense_f64_similarity() {
        let a = EmbeddingVector::Dense(vec![3.0, 4.0]);
        let same = EmbeddingVector::Dense(vec![3.0, 4.0]);
        let orthogonal = EmbeddingVector::Dense(vec![-4.0, 3.0]);
        assert!((a.cosine_similarity(&same) - 1.0).abs() < EPS, "identical should be 1.0");
        assert!(a.cosine_similarity(&orthogonal).abs() < EPS, "orthogonal should be 0.0");
    }

    #[test]
    fn test_cross_dense_similarity() {
        let a = EmbeddingVector::Dense(vec![1.0, 0.0, 0.0]);
        let b = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b);
        assert!((sim - 1.0).abs() < EPS, "cross-dense should be 1.0");
        // The operand order must not matter.
        let reversed = b.cosine_similarity(&a);
        assert!((reversed - 1.0).abs() < EPS, "reversed cross-dense should be 1.0");
    }

    #[test]
    fn test_dense_length_mismatch_is_zero() {
        let short = EmbeddingVector::Dense(vec![1.0, 0.0]);
        let long = EmbeddingVector::Dense(vec![1.0, 0.0, 0.0]);
        let long_f32 = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        assert_eq!(short.cosine_similarity(&long), 0.0);
        assert_eq!(short.cosine_similarity(&long_f32), 0.0);
        assert_eq!(long_f32.cosine_similarity(&short), 0.0);
    }

    #[test]
    fn test_empty_or_zero_dense_is_zero() {
        let empty = EmbeddingVector::Dense(vec![]);
        let zeros = EmbeddingVector::Dense(vec![0.0, 0.0]);
        let unit = EmbeddingVector::Dense(vec![1.0, 0.0]);
        assert_eq!(empty.cosine_similarity(&empty), 0.0);
        assert_eq!(zeros.cosine_similarity(&unit), 0.0);
    }

    #[test]
    fn test_sparse_similarity() {
        let a = EmbeddingVector::Sparse(HashMap::from([
            ("rust".to_string(), 1.0),
            ("embed".to_string(), 2.0),
        ]));
        let same = a.clone();
        let disjoint = EmbeddingVector::Sparse(HashMap::from([("other".to_string(), 3.0)]));
        let empty = EmbeddingVector::Sparse(HashMap::new());
        assert!((a.cosine_similarity(&same) - 1.0).abs() < EPS, "identical should be 1.0");
        assert!(a.cosine_similarity(&disjoint).abs() < EPS, "disjoint should be 0.0");
        assert_eq!(a.cosine_similarity(&empty), 0.0);
    }

    #[test]
    fn test_sparse_vs_dense_is_zero() {
        let sparse = EmbeddingVector::Sparse(HashMap::from([("a".to_string(), 1.0)]));
        let dense = EmbeddingVector::Dense(vec![1.0]);
        let dense_f32 = EmbeddingVector::DenseF32(vec![1.0]);
        assert_eq!(sparse.cosine_similarity(&dense), 0.0);
        assert_eq!(dense.cosine_similarity(&sparse), 0.0);
        assert_eq!(sparse.cosine_similarity(&dense_f32), 0.0);
        assert_eq!(dense_f32.cosine_similarity(&sparse), 0.0);
    }

    #[test]
    fn test_is_empty() {
        assert!(EmbeddingVector::Dense(vec![]).is_empty());
        assert!(EmbeddingVector::Sparse(HashMap::new()).is_empty());
        // Zero-valued but non-empty vectors are not "empty".
        assert!(!EmbeddingVector::DenseF32(vec![0.0]).is_empty());
    }

    #[test]
    fn test_to_f32_dense_from_dense() {
        let v = EmbeddingVector::Dense(vec![1.0, 2.0]);
        let converted = v.to_f32_dense().unwrap();
        assert_eq!(converted, vec![1.0f32, 2.0]);
    }

    #[test]
    fn test_to_f32_dense_from_sparse_returns_none() {
        let v = EmbeddingVector::Sparse(HashMap::from([("a".to_string(), 1.0)]));
        assert!(v.to_f32_dense().is_none());
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingVector::Dense(vec![1.0; 10]).dimensions(), 10);
        assert_eq!(EmbeddingVector::DenseF32(vec![1.0; 5]).dimensions(), 5);
        assert_eq!(EmbeddingVector::Sparse(HashMap::new()).dimensions(), 0);
    }
}
183}