oxios_kernel/
embedding.rs1use std::collections::HashMap;
11
12use anyhow::Result;
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
16pub enum EmbeddingVector {
17 Dense(Vec<f64>),
19 DenseF32(Vec<f32>),
21 Sparse(HashMap<String, f64>),
23}
24
25impl EmbeddingVector {
26 pub fn cosine_similarity(&self, other: &Self) -> f64 {
28 match (self, other) {
29 (EmbeddingVector::Dense(a), EmbeddingVector::Dense(b)) => {
30 if a.len() != b.len() || a.is_empty() {
31 return 0.0;
32 }
33 let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
34 let na: f64 = a.iter().map(|v| v * v).sum::<f64>().sqrt();
35 let nb: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
36 if na == 0.0 || nb == 0.0 {
37 return 0.0;
38 }
39 dot / (na * nb)
40 }
41 (EmbeddingVector::DenseF32(a), EmbeddingVector::DenseF32(b)) => {
42 crate::memory::normalizer::cosine_similarity_f32(a, b) as f64
43 }
44 (EmbeddingVector::Dense(a), EmbeddingVector::DenseF32(b))
45 | (EmbeddingVector::DenseF32(b), EmbeddingVector::Dense(a)) => {
46 let b_f64: Vec<f64> = b.iter().map(|&v| v as f64).collect();
48 let (aa, bb) = if matches!(self, EmbeddingVector::Dense(_)) {
49 (a, &b_f64)
50 } else {
51 (&b_f64, a)
52 };
53 if aa.is_empty() || bb.is_empty() || aa.len() != bb.len() {
54 return 0.0;
55 }
56 let dot: f64 = aa.iter().zip(bb).map(|(x, y)| x * y).sum();
57 let na: f64 = aa.iter().map(|v| v * v).sum::<f64>().sqrt();
58 let nb: f64 = bb.iter().map(|v| v * v).sum::<f64>().sqrt();
59 if na == 0.0 || nb == 0.0 {
60 return 0.0;
61 }
62 dot / (na * nb)
63 }
64 (EmbeddingVector::Sparse(a), EmbeddingVector::Sparse(b)) => {
65 if a.is_empty() || b.is_empty() {
66 return 0.0;
67 }
68 let mut dot = 0.0;
69 for (term, w) in a {
70 if let Some(w2) = b.get(term) {
71 dot += w * w2;
72 }
73 }
74 let na: f64 = a.values().map(|v| v * v).sum::<f64>().sqrt();
75 let nb: f64 = b.values().map(|v| v * v).sum::<f64>().sqrt();
76 if na == 0.0 || nb == 0.0 {
77 return 0.0;
78 }
79 dot / (na * nb)
80 }
81 _ => 0.0, }
83 }
84
85 pub fn is_empty(&self) -> bool {
87 match self {
88 EmbeddingVector::Dense(v) => v.is_empty(),
89 EmbeddingVector::DenseF32(v) => v.is_empty(),
90 EmbeddingVector::Sparse(m) => m.is_empty(),
91 }
92 }
93
94 pub fn to_f32_dense(&self) -> Option<Vec<f32>> {
100 match self {
101 EmbeddingVector::DenseF32(v) => Some(v.clone()),
102 EmbeddingVector::Dense(v) => Some(v.iter().map(|&x| x as f32).collect()),
103 EmbeddingVector::Sparse(_) => None,
104 }
105 }
106
107 pub fn dimensions(&self) -> usize {
109 match self {
110 EmbeddingVector::Dense(v) => v.len(),
111 EmbeddingVector::DenseF32(v) => v.len(),
112 EmbeddingVector::Sparse(m) => m.len(),
113 }
114 }
115}
116
117#[async_trait::async_trait]
119pub trait EmbeddingProvider: Send + Sync {
120 async fn embed(&self, text: &str) -> Result<EmbeddingVector>;
122 fn name(&self) -> &str;
124}
125
/// Stateless embedding provider identified as `"tfidf"`.
///
/// It carries no data, so it is cheap to copy and can be created with
/// `TfIdfEmbeddingProvider` or `TfIdfEmbeddingProvider::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TfIdfEmbeddingProvider;
128
129#[async_trait::async_trait]
130impl EmbeddingProvider for TfIdfEmbeddingProvider {
131 async fn embed(&self, text: &str) -> Result<EmbeddingVector> {
132 let tv = crate::memory::TextVector::from_text(text);
133 Ok(EmbeddingVector::Sparse(tv.tf_map().clone()))
134 }
135 fn name(&self) -> &str {
136 "tfidf"
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sparse embedding from `(term, weight)` pairs.
    fn sparse(entries: &[(&str, f64)]) -> EmbeddingVector {
        EmbeddingVector::Sparse(entries.iter().map(|(k, v)| (k.to_string(), *v)).collect())
    }

    #[test]
    fn test_dense_f32_similarity() {
        let a = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let b = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b);
        assert!((sim - 1.0).abs() < 1e-6, "identical should be 1.0");
    }

    #[test]
    fn test_cross_dense_similarity() {
        let a = EmbeddingVector::Dense(vec![1.0, 0.0, 0.0]);
        let b = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b);
        assert!((sim - 1.0).abs() < 1e-6, "cross-dense should be 1.0");
    }

    #[test]
    fn test_cross_dense_similarity_is_symmetric() {
        let wide = EmbeddingVector::Dense(vec![1.0, 2.0, 2.0]);
        let narrow = EmbeddingVector::DenseF32(vec![2.0, 4.0, 4.0]);
        let forward = wide.cosine_similarity(&narrow);
        let backward = narrow.cosine_similarity(&wide);
        assert!((forward - 1.0).abs() < 1e-9, "parallel vectors should be 1.0");
        assert!((forward - backward).abs() < 1e-12, "operand order must not matter");
    }

    #[test]
    fn test_mismatched_or_empty_dense_is_zero() {
        let two = EmbeddingVector::Dense(vec![1.0, 2.0]);
        let three = EmbeddingVector::Dense(vec![1.0, 2.0, 3.0]);
        assert_eq!(two.cosine_similarity(&three), 0.0);

        let empty = EmbeddingVector::Dense(vec![]);
        assert_eq!(empty.cosine_similarity(&empty.clone()), 0.0);

        let narrow = EmbeddingVector::DenseF32(vec![1.0]);
        assert_eq!(two.cosine_similarity(&narrow), 0.0);
        assert_eq!(narrow.cosine_similarity(&two), 0.0);
    }

    #[test]
    fn test_zero_norm_is_zero() {
        let zero = EmbeddingVector::Dense(vec![0.0, 0.0]);
        let unit = EmbeddingVector::Dense(vec![1.0, 0.0]);
        assert_eq!(zero.cosine_similarity(&unit), 0.0);
        assert_eq!(unit.cosine_similarity(&zero), 0.0);
    }

    #[test]
    fn test_sparse_similarity() {
        let ab = sparse(&[("a", 1.0), ("b", 1.0)]);
        let a = sparse(&[("a", 1.0)]);
        let c = sparse(&[("c", 2.0)]);
        let expected = 1.0 / 2.0_f64.sqrt();
        assert!((ab.cosine_similarity(&a) - expected).abs() < 1e-12);
        assert!((a.cosine_similarity(&ab) - expected).abs() < 1e-12);
        assert_eq!(a.cosine_similarity(&c), 0.0, "disjoint terms should be 0.0");
        assert_eq!(a.cosine_similarity(&sparse(&[])), 0.0, "empty map should be 0.0");
    }

    #[test]
    fn test_sparse_vs_dense_is_zero() {
        let s = sparse(&[("a", 1.0)]);
        let d = EmbeddingVector::Dense(vec![1.0]);
        let f = EmbeddingVector::DenseF32(vec![1.0]);
        assert_eq!(s.cosine_similarity(&d), 0.0);
        assert_eq!(d.cosine_similarity(&s), 0.0);
        assert_eq!(s.cosine_similarity(&f), 0.0);
        assert_eq!(f.cosine_similarity(&s), 0.0);
    }

    #[test]
    fn test_to_f32_dense_from_dense() {
        let v = EmbeddingVector::Dense(vec![1.0, 2.0]);
        // Named `narrowed` rather than `f32` so the primitive type name is not shadowed.
        let narrowed = v.to_f32_dense().unwrap();
        assert_eq!(narrowed, vec![1.0f32, 2.0]);
    }

    #[test]
    fn test_to_f32_dense_from_sparse_returns_none() {
        let v = EmbeddingVector::Sparse(HashMap::from([("a".to_string(), 1.0)]));
        assert!(v.to_f32_dense().is_none());
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingVector::Dense(vec![1.0; 10]).dimensions(), 10);
        assert_eq!(EmbeddingVector::DenseF32(vec![1.0; 5]).dimensions(), 5);
        assert_eq!(EmbeddingVector::Sparse(HashMap::new()).dimensions(), 0);
    }

    #[test]
    fn test_is_empty() {
        assert!(EmbeddingVector::Dense(vec![]).is_empty());
        assert!(EmbeddingVector::DenseF32(vec![]).is_empty());
        assert!(EmbeddingVector::Sparse(HashMap::new()).is_empty());
        assert!(!EmbeddingVector::Dense(vec![0.0]).is_empty());
        assert!(!sparse(&[("a", 1.0)]).is_empty());
    }
}