do_memory_core/embeddings/
similarity.rs1use serde::{Deserialize, Serialize};
4
/// One entry of a similarity search: an item paired with its similarity score
/// and the metadata recorded for that score.
///
/// Generic over the item type `T`; the struct can be serialized and
/// deserialized with serde whenever `T` can.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilaritySearchResult<T> {
    /// The item this result refers to.
    pub item: T,
    /// Similarity score for `item`.
    ///
    /// NOTE(review): presumably produced by [`cosine_similarity`], in which
    /// case it lies in `[0, 1]` with higher meaning more similar — confirm
    /// against the code that builds these results.
    pub similarity: f32,
    /// Additional information about how the score was obtained.
    pub metadata: SimilarityMetadata,
}
15
/// Metadata attached to a [`SimilaritySearchResult`].
///
/// The derived `Default` yields an empty model name, no timestamp and a
/// `serde_json::Value::Null` context.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SimilarityMetadata {
    /// Name of the embedding model; falls back to an empty string when the
    /// field is absent from the deserialized input.
    #[serde(default)]
    pub embedding_model: String,
    /// UTC timestamp associated with the embedding, if known.
    ///
    /// NOTE(review): presumably the time the embedding was generated —
    /// confirm against the code that populates it.
    pub embedding_timestamp: Option<chrono::DateTime<chrono::Utc>>,
    /// Free-form JSON context; falls back to `null` when the field is absent
    /// from the deserialized input.
    #[serde(default)]
    pub context: serde_json::Value,
}
28
/// Computes the cosine similarity of two vectors, rescaled from `[-1, 1]` to
/// `[0, 1]`.
///
/// * `1.0` — the vectors point in the same direction
/// * `0.5` — the vectors are orthogonal
/// * `0.0` — the vectors point in opposite directions
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when either vector consists only of zeros (the cosine is undefined there).
///
/// The sums are accumulated in `f64`, so finite `f32` components neither
/// overflow to infinity nor underflow to a spurious zero norm, and the raw
/// cosine is clamped to `[-1, 1]` so that rounding error can never push the
/// result outside `[0, 1]`. Non-finite components (NaN, ±infinity) are not
/// sanitized and may yield NaN.
#[must_use]
// The final narrowing cast is applied to a value already confined to [0, 1].
#[allow(clippy::cast_possible_truncation)]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // One pass over both slices: dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a_sq, norm_b_sq), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a_sq + x * x, norm_b_sq + y * y)
        },
    );

    // In f64 a squared norm is zero only if every component is exactly zero.
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Take the square roots separately so the denominator cannot overflow.
    let cosine = (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())).clamp(-1.0, 1.0);

    // Map [-1, 1] onto [0, 1].
    ((cosine + 1.0) / 2.0) as f32
}
56
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point similarity scores.
    const EPSILON: f32 = 0.001;

    #[test]
    fn test_cosine_similarity() {
        // (left, right, expected score) for equal-length inputs:
        // identical, orthogonal and opposite vectors.
        let cases: [(&[f32], &[f32], f32); 3] = [
            (&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], 1.0),
            (&[1.0, 0.0], &[0.0, 1.0], 0.5),
            (&[1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0], 0.0),
        ];

        for &(left, right, expected) in &cases {
            let score = cosine_similarity(left, right);
            assert!((score - expected).abs() < EPSILON);
        }

        // Vectors of different lengths are scored as exactly zero.
        let short = vec![1.0, 2.0];
        let long = vec![1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&short, &long), 0.0);
    }
}