oxios_memory/memory/
embedding.rs1#![allow(missing_docs)]
2use std::collections::HashMap;
12
13use anyhow::Result;
14
15#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
17pub enum EmbeddingVector {
18 Dense(Vec<f64>),
20 DenseF32(Vec<f32>),
22 Sparse(HashMap<String, f64>),
24}
25
26impl EmbeddingVector {
27 pub fn cosine_similarity(&self, other: &Self) -> f64 {
29 match (self, other) {
30 (EmbeddingVector::Dense(a), EmbeddingVector::Dense(b)) => {
31 if a.len() != b.len() || a.is_empty() {
32 return 0.0;
33 }
34 let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
35 let na: f64 = a.iter().map(|v| v * v).sum::<f64>().sqrt();
36 let nb: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
37 if na == 0.0 || nb == 0.0 {
38 return 0.0;
39 }
40 dot / (na * nb)
41 }
42 (EmbeddingVector::DenseF32(a), EmbeddingVector::DenseF32(b)) => {
43 crate::memory::cosine_similarity_f32(a, b) as f64
44 }
45 (EmbeddingVector::Dense(a), EmbeddingVector::DenseF32(b))
46 | (EmbeddingVector::DenseF32(b), EmbeddingVector::Dense(a)) => {
47 let b_f64: Vec<f64> = b.iter().map(|&v| v as f64).collect();
49 let (aa, bb) = if matches!(self, EmbeddingVector::Dense(_)) {
50 (a, &b_f64)
51 } else {
52 (&b_f64, a)
53 };
54 if aa.is_empty() || bb.is_empty() || aa.len() != bb.len() {
55 return 0.0;
56 }
57 let dot: f64 = aa.iter().zip(bb).map(|(x, y)| x * y).sum();
58 let na: f64 = aa.iter().map(|v| v * v).sum::<f64>().sqrt();
59 let nb: f64 = bb.iter().map(|v| v * v).sum::<f64>().sqrt();
60 if na == 0.0 || nb == 0.0 {
61 return 0.0;
62 }
63 dot / (na * nb)
64 }
65 (EmbeddingVector::Sparse(a), EmbeddingVector::Sparse(b)) => {
66 if a.is_empty() || b.is_empty() {
67 return 0.0;
68 }
69 let mut dot = 0.0;
70 for (term, w) in a {
71 if let Some(w2) = b.get(term) {
72 dot += w * w2;
73 }
74 }
75 let na: f64 = a.values().map(|v| v * v).sum::<f64>().sqrt();
76 let nb: f64 = b.values().map(|v| v * v).sum::<f64>().sqrt();
77 if na == 0.0 || nb == 0.0 {
78 return 0.0;
79 }
80 dot / (na * nb)
81 }
82 _ => 0.0, }
84 }
85
86 pub fn is_empty(&self) -> bool {
88 match self {
89 EmbeddingVector::Dense(v) => v.is_empty(),
90 EmbeddingVector::DenseF32(v) => v.is_empty(),
91 EmbeddingVector::Sparse(m) => m.is_empty(),
92 }
93 }
94
95 pub fn to_f32_dense(&self) -> Option<Vec<f32>> {
101 match self {
102 EmbeddingVector::DenseF32(v) => Some(v.clone()),
103 EmbeddingVector::Dense(v) => Some(v.iter().map(|&x| x as f32).collect()),
104 EmbeddingVector::Sparse(_) => None,
105 }
106 }
107
108 pub fn dimensions(&self) -> usize {
110 match self {
111 EmbeddingVector::Dense(v) => v.len(),
112 EmbeddingVector::DenseF32(v) => v.len(),
113 EmbeddingVector::Sparse(m) => m.len(),
114 }
115 }
116}
117
118#[async_trait::async_trait]
120pub trait EmbeddingProvider: Send + Sync {
121 async fn embed(&self, text: &str) -> Result<EmbeddingVector>;
123 fn name(&self) -> &str;
125}
126
/// Embedding provider that turns text into a sparse term-weight vector.
///
/// It is a stateless unit struct, so it is trivially cheap to construct, copy and share.
#[derive(Debug, Clone, Copy, Default)]
pub struct TfIdfEmbeddingProvider;
130#[async_trait::async_trait]
131impl EmbeddingProvider for TfIdfEmbeddingProvider {
132 async fn embed(&self, text: &str) -> Result<EmbeddingVector> {
133 let tv = crate::memory::TextVector::from_text(text);
134 Ok(EmbeddingVector::Sparse(tv.tf_map().clone()))
135 }
136 fn name(&self) -> &str {
137 "tfidf"
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparisons computed entirely in `f64`.
    const EPS: f64 = 1e-9;

    /// Builds a sparse vector from `(term, weight)` pairs.
    fn sparse(pairs: &[(&str, f64)]) -> EmbeddingVector {
        EmbeddingVector::Sparse(pairs.iter().map(|&(t, w)| (t.to_string(), w)).collect())
    }

    #[test]
    fn test_dense_f32_similarity() {
        let a = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let b = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b);
        assert!((sim - 1.0).abs() < 1e-6, "identical should be 1.0");
    }

    #[test]
    fn test_cross_dense_similarity() {
        let a = EmbeddingVector::Dense(vec![1.0, 0.0, 0.0]);
        let b = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b);
        assert!((sim - 1.0).abs() < 1e-6, "cross-dense should be 1.0");
    }

    #[test]
    fn test_cross_dense_is_symmetric() {
        let a = EmbeddingVector::Dense(vec![1.0, 2.0, 3.0]);
        let b = EmbeddingVector::DenseF32(vec![3.0, 2.0, 1.0]);
        let ab = a.cosine_similarity(&b);
        let ba = b.cosine_similarity(&a);
        assert_eq!(ab, ba, "mixed dense similarity must not depend on argument order");
        // dot = 10, both norms = sqrt(14)
        assert!((ab - 10.0 / 14.0).abs() < EPS);
    }

    #[test]
    fn test_dense_length_mismatch_is_zero() {
        let a = EmbeddingVector::Dense(vec![1.0, 0.0]);
        let b = EmbeddingVector::Dense(vec![1.0, 0.0, 0.0]);
        assert_eq!(a.cosine_similarity(&b), 0.0);
        let c = EmbeddingVector::DenseF32(vec![1.0, 0.0, 0.0]);
        assert_eq!(a.cosine_similarity(&c), 0.0);
        assert_eq!(c.cosine_similarity(&a), 0.0);
    }

    #[test]
    fn test_empty_and_zero_norm_are_zero() {
        let empty = EmbeddingVector::Dense(vec![]);
        assert_eq!(empty.cosine_similarity(&empty), 0.0);
        let zero = EmbeddingVector::Dense(vec![0.0, 0.0]);
        let unit = EmbeddingVector::Dense(vec![1.0, 0.0]);
        assert_eq!(zero.cosine_similarity(&unit), 0.0);
        let unit_f32 = EmbeddingVector::DenseF32(vec![1.0, 0.0]);
        assert_eq!(zero.cosine_similarity(&unit_f32), 0.0);
        let empty_sparse = EmbeddingVector::Sparse(HashMap::new());
        assert_eq!(empty_sparse.cosine_similarity(&sparse(&[("x", 1.0)])), 0.0);
    }

    #[test]
    fn test_sparse_similarity() {
        let a = sparse(&[("x", 1.0), ("y", 1.0)]);
        let b = sparse(&[("x", 1.0)]);
        let ab = a.cosine_similarity(&b);
        assert!((ab - std::f64::consts::FRAC_1_SQRT_2).abs() < EPS);
        assert_eq!(ab, b.cosine_similarity(&a));
        assert!((a.cosine_similarity(&a) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_sparse_disjoint_is_zero() {
        let a = sparse(&[("x", 1.0)]);
        let b = sparse(&[("y", 1.0)]);
        assert_eq!(a.cosine_similarity(&b), 0.0);
    }

    #[test]
    fn test_dense_vs_sparse_is_zero() {
        let s = sparse(&[("x", 1.0)]);
        let d = EmbeddingVector::Dense(vec![1.0]);
        let f = EmbeddingVector::DenseF32(vec![1.0]);
        assert_eq!(s.cosine_similarity(&d), 0.0);
        assert_eq!(d.cosine_similarity(&s), 0.0);
        assert_eq!(s.cosine_similarity(&f), 0.0);
        assert_eq!(f.cosine_similarity(&s), 0.0);
    }

    #[test]
    fn test_is_empty() {
        assert!(EmbeddingVector::Dense(vec![]).is_empty());
        assert!(EmbeddingVector::DenseF32(vec![]).is_empty());
        assert!(EmbeddingVector::Sparse(HashMap::new()).is_empty());
        assert!(!EmbeddingVector::Dense(vec![0.0]).is_empty());
    }

    #[test]
    fn test_to_f32_dense_from_dense() {
        let v = EmbeddingVector::Dense(vec![1.0, 2.0]);
        let converted = v.to_f32_dense().unwrap();
        assert_eq!(converted, vec![1.0f32, 2.0]);
    }

    #[test]
    fn test_to_f32_dense_from_f32_passes_through() {
        let v = EmbeddingVector::DenseF32(vec![0.5, -3.0]);
        assert_eq!(v.to_f32_dense(), Some(vec![0.5f32, -3.0]));
    }

    #[test]
    fn test_to_f32_dense_from_sparse_returns_none() {
        let v = EmbeddingVector::Sparse(HashMap::from([("a".to_string(), 1.0)]));
        assert!(v.to_f32_dense().is_none());
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingVector::Dense(vec![1.0; 10]).dimensions(), 10);
        assert_eq!(EmbeddingVector::DenseF32(vec![1.0; 5]).dimensions(), 5);
        assert_eq!(EmbeddingVector::Sparse(HashMap::new()).dimensions(), 0);
    }
}