agentzero_core/
embedding.rs1use async_trait::async_trait;
4
5#[async_trait]
7pub trait EmbeddingProvider: Send + Sync {
8 async fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>>;
10
11 fn dimensions(&self) -> usize;
13}
14
/// Cosine similarity of two equal-length vectors, in the range `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either
/// vector has zero magnitude (the angle is undefined in that case). `NaN`
/// components propagate to a `NaN` result.
///
/// The dot product and squared magnitudes are accumulated in `f64`. With an
/// `f32` accumulator, squaring large components (e.g. `1e20`) overflows to
/// infinity and produces `NaN`, while squaring tiny ones (e.g. `1e-30`)
/// underflows to zero and wrongly reports zero similarity.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0_f64;
    let mut mag_a = 0.0_f64;
    let mut mag_b = 0.0_f64;

    for (&x, &y) in a.iter().zip(b) {
        // Widen before multiplying so neither overflow nor underflow can occur
        // for any finite f32 input.
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        mag_a += x * x;
        mag_b += y * y;
    }

    // Taking each root separately keeps the intermediate product small.
    let denom = mag_a.sqrt() * mag_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio marginally outside [-1, 1]; clamp it. The
    // final narrowing cast only rounds, since the value is within [-1, 1].
    (dot / denom).clamp(-1.0, 1.0) as f32
}

/// Serializes an embedding as consecutive little-endian `f32` values.
///
/// The output is exactly `4 * embedding.len()` bytes long; element `i`
/// occupies bytes `4*i .. 4*i + 4`. An empty slice yields an empty vector.
pub fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();

    let mut out = Vec::with_capacity(embedding.len() * F32_BYTES);
    out.extend(embedding.iter().flat_map(|value| value.to_le_bytes()));
    out
}

/// Decodes a buffer of consecutive little-endian `f32` values.
///
/// This is the inverse of `embedding_to_bytes`. The input is consumed in
/// 4-byte groups; if `bytes.len()` is not a multiple of 4, the trailing
/// partial group is silently ignored (no error or panic is raised).
pub fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();

    let mut values = Vec::with_capacity(bytes.len() / F32_BYTES);
    for group in bytes.chunks_exact(F32_BYTES) {
        // `chunks_exact` only yields full 4-byte groups, so these indices are
        // always in bounds.
        values.push(f32::from_le_bytes([group[0], group[1], group[2], group[3]]));
    }
    values
}

#[cfg(test)]
mod tests {
    // Unit tests for cosine similarity and for the little-endian byte
    // encoding/decoding of embeddings.
    use super::*;

    #[test]
    fn cosine_similarity_identical_vectors() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6, "identical vectors: {sim}");
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6, "orthogonal vectors: {sim}");
    }

    #[test]
    fn cosine_similarity_opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6, "opposite vectors: {sim}");
    }

    #[test]
    fn cosine_similarity_empty_returns_zero() {
        let sim = cosine_similarity(&[], &[]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_similarity_mismatched_lengths_returns_zero() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_similarity_zero_vector_returns_zero() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_similarity_similar_vectors_high() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.1, 2.1, 3.1];
        let sim = cosine_similarity(&a, &b);
        assert!(sim > 0.99, "similar vectors should be > 0.99: {sim}");
    }

    #[test]
    fn embedding_roundtrip() {
        let original = vec![1.0_f32, -2.5, std::f32::consts::PI, 0.0, f32::MAX, f32::MIN];
        let bytes = embedding_to_bytes(&original);
        let decoded = bytes_to_embedding(&bytes);
        assert_eq!(original, decoded);
    }

    #[test]
    fn embedding_roundtrip_empty() {
        let original: Vec<f32> = vec![];
        let bytes = embedding_to_bytes(&original);
        let decoded = bytes_to_embedding(&bytes);
        assert_eq!(original, decoded);
    }

    #[test]
    fn embedding_to_bytes_is_little_endian() {
        // Pins the on-disk layout: 1.0 = 0x3F80_0000 and -2.5 = 0xC020_0000
        // (IEEE 754 binary32), each stored least-significant byte first.
        let bytes = embedding_to_bytes(&[1.0, -2.5]);
        assert_eq!(bytes, vec![0x00, 0x00, 0x80, 0x3F, 0x00, 0x00, 0x20, 0xC0]);
    }

    #[test]
    fn bytes_to_embedding_ignores_trailing_partial_chunk() {
        // Documents current behavior: a trailing group shorter than 4 bytes is
        // silently dropped rather than reported.
        let mut bytes = embedding_to_bytes(&[1.0, -2.5]);
        bytes.extend_from_slice(&[0xAB, 0xCD, 0xEF]);
        assert_eq!(bytes_to_embedding(&bytes), vec![1.0, -2.5]);
        assert!(bytes_to_embedding(&[0x01, 0x02, 0x03]).is_empty());
    }
}