Skip to main content

agentzero_core/
embedding.rs

1//! Embedding provider trait and cosine similarity for semantic memory recall.
2
3use async_trait::async_trait;
4
5/// Provider for text embeddings.
6#[async_trait]
7pub trait EmbeddingProvider: Send + Sync {
8    /// Embed a text string into a vector of floats.
9    async fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>>;
10
11    /// The dimensionality of embeddings produced by this provider.
12    fn dimensions(&self) -> usize;
13}
14
/// Compute the cosine similarity between two equal-length vectors.
///
/// The dot product and squared magnitudes are accumulated in `f64`, so
/// squaring a finite `f32` component can neither overflow to infinity nor
/// underflow to zero. The final ratio is clamped so that floating-point
/// rounding cannot push it outside the documented range.
///
/// # Returns
///
/// A value in `[-1.0, 1.0]`, where `1.0` means the vectors point in the same
/// direction and `-1.0` means they point in opposite directions.
///
/// Returns `0.0` when:
/// - either slice is empty,
/// - the slices differ in length, or
/// - either vector has zero magnitude.
///
/// Inputs containing NaN or infinite components are not sanitized and may
/// produce NaN.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Widen to f64 before multiplying: the square of any finite f32 fits
    // comfortably in f64, whereas in f32 it can become `inf` or `0.0`.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a_sq, norm_b_sq), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a_sq + x * x, norm_b_sq + y * y)
        },
    );

    // Take the square roots separately so the product of the two squared
    // magnitudes is never formed (it could overflow even in f64).
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    // The clamped value lies in [-1, 1] (or is NaN), so narrowing to f32
    // only loses precision, never range.
    (dot / denom).clamp(-1.0, 1.0) as f32
}
41
/// Serialize an embedding into a flat byte buffer (for SQLite BLOB storage).
///
/// Each `f32` is written as its four little-endian bytes, in slice order, so
/// the result is exactly `embedding.len() * 4` bytes long.
pub fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    const BYTES_PER_FLOAT: usize = std::mem::size_of::<f32>();

    let mut out = Vec::with_capacity(embedding.len() * BYTES_PER_FLOAT);
    embedding
        .iter()
        .for_each(|value| out.extend_from_slice(&value.to_le_bytes()));
    out
}
50
/// Deserialize a little-endian byte buffer back into an embedding.
///
/// The input is consumed four bytes at a time, each group becoming one `f32`.
/// If the length is not a multiple of four, the trailing one to three bytes
/// are silently ignored.
pub fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(bytes.len() / 4);
    for quad in bytes.chunks_exact(4) {
        // `chunks_exact(4)` only yields 4-byte slices, so this cannot fail.
        let raw = <[u8; 4]>::try_from(quad).expect("chunk is exactly 4 bytes");
        floats.push(f32::from_le_bytes(raw));
    }
    floats
}
61
#[cfg(test)]
mod tests {
    use super::{bytes_to_embedding, cosine_similarity, embedding_to_bytes};

    /// Tolerance for comparing a similarity score against an exact value.
    const EPSILON: f32 = 1e-6;

    // --- cosine_similarity -------------------------------------------------

    #[test]
    fn cosine_similarity_identical_vectors() {
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&lhs, &rhs);
        assert!((sim - 1.0).abs() < EPSILON, "identical vectors: {sim}");
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let sim = cosine_similarity(&x_axis, &y_axis);
        assert!(sim.abs() < EPSILON, "orthogonal vectors: {sim}");
    }

    #[test]
    fn cosine_similarity_opposite_vectors() {
        let forward = [1.0_f32, 0.0];
        let backward = [-1.0_f32, 0.0];
        let sim = cosine_similarity(&forward, &backward);
        assert!((sim + 1.0).abs() < EPSILON, "opposite vectors: {sim}");
    }

    #[test]
    fn cosine_similarity_empty_returns_zero() {
        let sim = cosine_similarity(&[], &[]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_similarity_mismatched_lengths_returns_zero() {
        let short = [1.0_f32, 2.0];
        let long = [1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&short, &long);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_similarity_zero_vector_returns_zero() {
        let zero = [0.0_f32, 0.0];
        let other = [1.0_f32, 2.0];
        let sim = cosine_similarity(&zero, &other);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn cosine_similarity_similar_vectors_high() {
        let base = [1.0_f32, 2.0, 3.0];
        // Each component nudged by 0.1, so the direction barely changes.
        let nudged = [1.1_f32, 2.1, 3.1];
        let sim = cosine_similarity(&base, &nudged);
        assert!(sim > 0.99, "similar vectors should be > 0.99: {sim}");
    }

    // --- byte encoding round-trips -----------------------------------------

    #[test]
    fn embedding_roundtrip() {
        let original = vec![1.0_f32, -2.5, std::f32::consts::PI, 0.0, f32::MAX, f32::MIN];
        let decoded = bytes_to_embedding(&embedding_to_bytes(&original));
        assert_eq!(original, decoded);
    }

    #[test]
    fn embedding_roundtrip_empty() {
        let original = Vec::<f32>::new();
        let decoded = bytes_to_embedding(&embedding_to_bytes(&original));
        assert_eq!(original, decoded);
    }
}