//! Embedding providers and vector-similarity helpers.

use std::hash::{Hash, Hasher};

use serde::Serialize;

/// A dense embedding vector: one `f32` per component.
pub type Embedding = Vec<f32>;

9#[async_trait::async_trait]
11pub trait EmbeddingProvider: Send + Sync {
12 async fn embed(&self, text: &str) -> anyhow::Result<Embedding>;
14
15 async fn embed_batch(&self, texts: &[String]) -> anyhow::Result<Vec<Embedding>>;
17
18 fn dimension(&self) -> usize;
20}
21
22pub struct DeterministicEmbedder {
24 dimension: usize,
25}
26
27impl DeterministicEmbedder {
28 pub fn new(dimension: usize) -> Self {
29 Self { dimension }
30 }
31
32 fn hash_to_embedding(&self, text: &str) -> Embedding {
33 use std::collections::hash_map::DefaultHasher;
34
35 let mut hasher = DefaultHasher::new();
36 text.hash(&mut hasher);
37 let hash = hasher.finish();
38
39 let mut embedding = Vec::with_capacity(self.dimension);
41 let mut current_hash = hash;
42
43 for _ in 0..self.dimension {
44 current_hash = current_hash.wrapping_mul(1103515245).wrapping_add(12345);
46 let value = ((current_hash >> 16) & 0xFFFF) as f32 / 65535.0;
47 embedding.push(value * 2.0 - 1.0); }
49
50 let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
52 if magnitude > 0.0 {
53 embedding.iter_mut().for_each(|x| *x /= magnitude);
54 }
55
56 embedding
57 }
58}
59
60#[async_trait::async_trait]
61impl EmbeddingProvider for DeterministicEmbedder {
62 async fn embed(&self, text: &str) -> anyhow::Result<Embedding> {
63 Ok(self.hash_to_embedding(text))
64 }
65
66 async fn embed_batch(&self, texts: &[String]) -> anyhow::Result<Vec<Embedding>> {
67 Ok(texts
68 .iter()
69 .map(|text| self.hash_to_embedding(text))
70 .collect())
71 }
72
73 fn dimension(&self) -> usize {
74 self.dimension
75 }
76}
77
/// Computes the cosine similarity `a·b / (|a| |b|)` of two vectors.
///
/// Returns `0.0` when the slices differ in length, or when either vector has
/// zero magnitude (which includes empty slices).
///
/// The result is clamped to `[-1.0, 1.0]`. Floating-point rounding can
/// otherwise push it marginally outside that range, which breaks callers that
/// pass it to `acos` or range-check it.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // A single pass accumulates the dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let mag_a = norm_a_sq.sqrt();
    let mag_b = norm_b_sq.sqrt();

    if mag_a > 0.0 && mag_b > 0.0 {
        (dot / (mag_a * mag_b)).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}

/// Computes the Euclidean (L2) distance between two vectors.
///
/// Returns `f32::MAX` when the slices have different lengths. Two empty slices
/// are at distance `0.0`.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let squared_sum: f32 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
        squared_sum.sqrt()
    } else {
        f32::MAX
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean norm of `v`; used to check that embeddings are normalized.
    fn norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[tokio::test]
    async fn test_deterministic_embedder() {
        let embedder = DeterministicEmbedder::new(384);

        let text = "test string";
        let embedding1 = embedder.embed(text).await.unwrap();
        let embedding2 = embedder.embed(text).await.unwrap();

        // The same input must always map to the same vector.
        assert_eq!(embedding1, embedding2);
        assert_eq!(embedding1.len(), 384);
        assert_eq!(embedder.dimension(), 384);
    }

    #[tokio::test]
    async fn test_embeddings_are_normalized() {
        let embedder = DeterministicEmbedder::new(256);
        for text in ["", "a", "test string", "a somewhat longer piece of text"] {
            let embedding = embedder.embed(text).await.unwrap();
            assert!((norm(&embedding) - 1.0).abs() < 1e-4, "text = {text:?}");
        }
    }

    #[tokio::test]
    async fn test_distinct_texts_give_distinct_embeddings() {
        let embedder = DeterministicEmbedder::new(128);
        let hello = embedder.embed("hello").await.unwrap();
        let world = embedder.embed("world").await.unwrap();
        assert_ne!(hello, world);
    }

    #[tokio::test]
    async fn test_batch_embedding() {
        let embedder = DeterministicEmbedder::new(128);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed_batch(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 128);
        assert_eq!(embeddings[1].len(), 128);

        // Batch results must match single-text results, in input order.
        for (text, batched) in texts.iter().zip(&embeddings) {
            assert_eq!(&embedder.embed(text).await.unwrap(), batched);
        }
    }

    #[tokio::test]
    async fn test_empty_batch() {
        let embedder = DeterministicEmbedder::new(8);
        assert!(embedder.embed_batch(&[]).await.unwrap().is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];
        let d = vec![-1.0, 0.0, 0.0];

        // Same direction gives 1, orthogonal gives 0, opposite gives -1.
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
        assert!(cosine_similarity(&a, &c).abs() < 0.001);
        assert!((cosine_similarity(&a, &d) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // A length mismatch and a zero-magnitude vector both yield 0.0.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];

        assert!((euclidean_distance(&a, &b) - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_euclidean_distance_edge_cases() {
        // A length mismatch is reported as the maximal distance.
        assert_eq!(euclidean_distance(&[1.0], &[1.0, 2.0]), f32::MAX);
        assert_eq!(euclidean_distance(&[2.0, -3.0], &[2.0, -3.0]), 0.0);
    }
}