// cognee_models/embedding.rs

//! Embedding - storage model for vector embeddings of data points.
//!
//! Represents the embedding vector for one specific field of a data point.
//! Used to store embeddings in vector databases for semantic search.

use serde::{Deserialize, Serialize};
use uuid::Uuid;

9/// Embedding vector for a data point field.
10///
11/// Each embedding represents a specific field (e.g., "text", "name")
12/// of a data point (e.g., DocumentChunk, Entity) as a dense vector.
13///
14/// # Examples
15/// ```
16/// use cognee_models::Embedding;
17/// use uuid::Uuid;
18///
19/// let chunk_id = Uuid::new_v4();
20/// let embedding = Embedding::new(
21///     chunk_id,
22///     "DocumentChunk",
23///     "text",
24///     vec![0.1, 0.2, 0.3], // Dense vector
25/// );
26///
27/// assert_eq!(embedding.dimensions(), 3);
28/// assert_eq!(embedding.field_name, "text");
29/// ```
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub struct Embedding {
32    /// UUID of the data point this embedding belongs to
33    pub data_point_id: Uuid,
34
35    /// Type of the data point (e.g., "DocumentChunk", "Entity", "TextSummary")
36    pub data_type: String,
37
38    /// Name of the field that was embedded (e.g., "text", "name", "content")
39    pub field_name: String,
40
41    /// Dense embedding vector (f32 for compatibility with most vector DBs)
42    pub vector: Vec<f32>,
43}
44
45impl Embedding {
46    /// Create a new embedding.
47    ///
48    /// # Arguments
49    /// * `data_point_id` - UUID of the source data point
50    /// * `data_type` - Type discriminator (e.g., "DocumentChunk")
51    /// * `field_name` - Field that was embedded (e.g., "text")
52    /// * `vector` - Dense embedding vector (usually 384, 768, or 1536 dimensions)
53    pub fn new(
54        data_point_id: Uuid,
55        data_type: impl Into<String>,
56        field_name: impl Into<String>,
57        vector: Vec<f32>,
58    ) -> Self {
59        Self {
60            data_point_id,
61            data_type: data_type.into(),
62            field_name: field_name.into(),
63            vector,
64        }
65    }
66
67    /// Get the dimensionality of the embedding vector.
68    pub fn dimensions(&self) -> usize {
69        self.vector.len()
70    }
71
72    /// Calculate L2 norm of the embedding (should be ~1.0 if normalized).
73    pub fn norm(&self) -> f32 {
74        self.vector.iter().map(|x| x * x).sum::<f32>().sqrt()
75    }
76
77    /// Calculate cosine similarity with another embedding.
78    ///
79    /// Both embeddings must be normalized (L2 norm = 1.0).
80    /// Returns dot product in range [-1.0, 1.0].
81    pub fn cosine_similarity(&self, other: &Embedding) -> Option<f32> {
82        if self.vector.len() != other.vector.len() {
83            return None;
84        }
85
86        let similarity = self
87            .vector
88            .iter()
89            .zip(&other.vector)
90            .map(|(a, b)| a * b)
91            .sum();
92
93        Some(similarity)
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedding under a fresh random id; the id plays no role in
    /// the vector-math checks below.
    fn embedding_of(data_type: &str, field_name: &str, vector: Vec<f32>) -> Embedding {
        Embedding::new(Uuid::new_v4(), data_type, field_name, vector)
    }

    #[test]
    fn test_embedding_creation() {
        let owner = Uuid::new_v4();
        let built = Embedding::new(owner, "DocumentChunk", "text", vec![0.1, 0.2, 0.3]);

        assert_eq!(built.data_point_id, owner);
        assert_eq!(built.data_type, "DocumentChunk");
        assert_eq!(built.field_name, "text");
        assert_eq!(built.dimensions(), 3);
    }

    #[test]
    fn test_norm() {
        // 0.6^2 + 0.8^2 = 1.0, so this vector is already unit length.
        let norm = embedding_of("Entity", "name", vec![0.6, 0.8]).norm();

        assert!((norm - 1.0).abs() < 0.01, "Expected norm ~1.0, got {norm}");
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = embedding_of("Entity", "name", vec![1.0, 0.0, 0.0]);
        let x_axis_again = embedding_of("Entity", "name", vec![1.0, 0.0, 0.0]);
        let y_axis = embedding_of("Entity", "name", vec![0.0, 1.0, 0.0]);

        // Same direction -> 1.0; perpendicular -> 0.0.
        assert_eq!(x_axis.cosine_similarity(&x_axis_again), Some(1.0));
        assert_eq!(x_axis.cosine_similarity(&y_axis), Some(0.0));
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let two_d = embedding_of("Entity", "name", vec![1.0, 0.0]);
        let three_d = embedding_of("Entity", "name", vec![1.0, 0.0, 0.0]);

        assert_eq!(two_d.cosine_similarity(&three_d), None);
    }
}