cognee_models/
embedding.rs1use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub struct Embedding {
32 pub data_point_id: Uuid,
34
35 pub data_type: String,
37
38 pub field_name: String,
40
41 pub vector: Vec<f32>,
43}
44
45impl Embedding {
46 pub fn new(
54 data_point_id: Uuid,
55 data_type: impl Into<String>,
56 field_name: impl Into<String>,
57 vector: Vec<f32>,
58 ) -> Self {
59 Self {
60 data_point_id,
61 data_type: data_type.into(),
62 field_name: field_name.into(),
63 vector,
64 }
65 }
66
67 pub fn dimensions(&self) -> usize {
69 self.vector.len()
70 }
71
72 pub fn norm(&self) -> f32 {
74 self.vector.iter().map(|x| x * x).sum::<f32>().sqrt()
75 }
76
77 pub fn cosine_similarity(&self, other: &Embedding) -> Option<f32> {
82 if self.vector.len() != other.vector.len() {
83 return None;
84 }
85
86 let similarity = self
87 .vector
88 .iter()
89 .zip(&other.vector)
90 .map(|(a, b)| a * b)
91 .sum();
92
93 Some(similarity)
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedding labelled `"Entity"` / `"name"` for the given id and vector.
    fn entity(id: Uuid, vector: Vec<f32>) -> Embedding {
        Embedding::new(id, "Entity", "name", vector)
    }

    /// The constructor stores its arguments and `dimensions` reports the vector length.
    #[test]
    fn test_embedding_creation() {
        let data_point_id = Uuid::new_v4();
        let created = Embedding::new(data_point_id, "DocumentChunk", "text", vec![0.1, 0.2, 0.3]);

        assert_eq!(created.data_point_id, data_point_id);
        assert_eq!(created.data_type, "DocumentChunk");
        assert_eq!(created.field_name, "text");
        assert_eq!(created.dimensions(), 3);
    }

    /// 0.6^2 + 0.8^2 = 1, so the norm should be (approximately) one.
    #[test]
    fn test_norm() {
        let unit_like = entity(Uuid::new_v4(), vec![0.6, 0.8]);

        let norm = unit_like.norm();
        assert!((norm - 1.0).abs() < 0.01, "Expected norm ~1.0, got {norm}");
    }

    /// Identical unit vectors score 1.0; orthogonal unit vectors score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let first_id = Uuid::new_v4();
        let second_id = Uuid::new_v4();

        let x_axis = entity(first_id, vec![1.0, 0.0, 0.0]);
        let x_axis_twin = entity(second_id, vec![1.0, 0.0, 0.0]);
        let y_axis = entity(second_id, vec![0.0, 1.0, 0.0]);

        let same_direction = x_axis.cosine_similarity(&x_axis_twin);
        assert_eq!(same_direction, Some(1.0));

        let orthogonal = x_axis.cosine_similarity(&y_axis);
        assert_eq!(orthogonal, Some(0.0));
    }

    /// Vectors of different lengths cannot be compared, so the result is `None`.
    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let two_dims = entity(Uuid::new_v4(), vec![1.0, 0.0]);
        let three_dims = entity(Uuid::new_v4(), vec![1.0, 0.0, 0.0]);

        assert_eq!(two_dims.cosine_similarity(&three_dims), None);
    }
}