Skip to main content

cognee_models/
entity.rs

//! Entity - Storage-layer entity model.
//!
//! Mirrors Python's `cognee/modules/engine/models/Entity.py`.
//! Represents an entity extracted from text and stored in the graph database.
5
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9use crate::DataPoint;
10use crate::has_datapoint::HasDataPoint;
11
/// Storage-layer entity model.
///
/// Represents an entity (e.g., "TechCorp", "Alice", "London") extracted
/// from text. Each entity has a name, a description, and an optional
/// reference to its EntityType (e.g., "Organization", "Person", "Location").
///
/// # Serialization
/// `base` is marked `#[serde(flatten)]`, so the [`DataPoint`] fields are
/// emitted in the same JSON object as `name`, `is_a` and `description`
/// rather than nested under a `base` key. Field order here is also the
/// serialization order and the comparison order of the derived `PartialEq`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Entity {
    /// Base data point fields (id, timestamps, metadata, etc.).
    #[serde(flatten)]
    pub base: DataPoint,

    /// Entity name (e.g., "TechCorp"). This is the field listed in
    /// `Entity::INDEX_FIELDS` and returned by `get_embeddable_text`.
    pub name: String,

    /// UUID of the referenced EntityType (e.g., the UUID of the
    /// "Organization" type), or `None` when no type is attached.
    /// The type is referenced by id, not owned. No `skip_serializing_if`
    /// is set, so `None` serializes as `null`.
    pub is_a: Option<Uuid>,

    /// Entity description from LLM extraction.
    pub description: String,
}
32
33impl Entity {
34    /// Index fields to embed for vector search.
35    pub const INDEX_FIELDS: &'static [&'static str] = &["name"];
36
37    /// Create a new Entity.
38    ///
39    /// # Arguments
40    /// * `name` - Entity name
41    /// * `entity_type_id` - Optional reference to EntityType
42    /// * `description` - Entity description
43    /// * `dataset_id` - Dataset UUID
44    pub fn new(
45        name: impl Into<String>,
46        entity_type_id: Option<Uuid>,
47        description: impl Into<String>,
48        dataset_id: Option<Uuid>,
49    ) -> Self {
50        let mut metadata = std::collections::HashMap::new();
51        metadata.insert(
52            "index_fields".to_string(),
53            serde_json::json!(Self::INDEX_FIELDS),
54        );
55
56        Self {
57            base: DataPoint::with_metadata("Entity", dataset_id, metadata),
58            name: name.into(),
59            is_a: entity_type_id,
60            description: description.into(),
61        }
62    }
63
64    /// Create Entity from LLM-extracted Node.
65    ///
66    /// # Arguments
67    /// * `node_id` - Original node ID from LLM extraction
68    /// * `node_name` - Node name
69    /// * `node_description` - Node description
70    /// * `entity_type_id` - EntityType UUID
71    /// * `dataset_id` - Dataset UUID
72    pub fn from_node(
73        node_id: impl Into<String>,
74        node_name: impl Into<String>,
75        node_description: impl Into<String>,
76        entity_type_id: Uuid,
77        dataset_id: Option<Uuid>,
78    ) -> Self {
79        let mut entity = Self::new(
80            node_name,
81            Some(entity_type_id),
82            node_description,
83            dataset_id,
84        );
85
86        entity
87            .base
88            .set_metadata("original_node_id", serde_json::json!(node_id.into()));
89
90        entity
91    }
92
93    /// Get the entity name (for embedding).
94    pub fn get_embeddable_text(&self) -> String {
95        self.name.clone()
96    }
97
98    /// Update entity description.
99    pub fn set_description(&mut self, description: impl Into<String>) {
100        self.description = description.into();
101        self.base.touch();
102    }
103
104    /// Update entity type reference.
105    pub fn set_entity_type(&mut self, entity_type_id: Uuid) {
106        self.is_a = Some(entity_type_id);
107        self.base.touch();
108    }
109}
110
111impl HasDataPoint for Entity {
112    fn data_point(&self) -> &DataPoint {
113        &self.base
114    }
115    fn data_point_mut(&mut self) -> &mut DataPoint {
116        &mut self.base
117    }
118    // for_each_child_mut: default no-op — Entity references its EntityType
119    // by UUID (`is_a: Option<Uuid>`), not by ownership. If a future variant
120    // owns an `entity_type: Box<EntityType>` field, override here to recurse.
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entity_creation() {
        let entity = Entity::new("TechCorp", None, "A technology company", None);

        assert_eq!(entity.name, "TechCorp");
        assert_eq!(entity.description, "A technology company");
        assert_eq!(entity.base.data_type, "Entity");
        assert!(entity.is_a.is_none());
    }

    #[test]
    fn test_entity_with_type() {
        let type_id = Uuid::new_v4();
        let entity = Entity::new("TechCorp", Some(type_id), "A technology company", None);

        assert_eq!(entity.is_a, Some(type_id));
    }

    #[test]
    fn test_entity_from_node() {
        let type_id = Uuid::new_v4();
        let entity = Entity::from_node(
            "techcorp_1",
            "TechCorp",
            "A technology company",
            type_id,
            None,
        );

        assert_eq!(entity.name, "TechCorp");
        assert_eq!(entity.is_a, Some(type_id));
        assert_eq!(
            entity.base.get_metadata("original_node_id"),
            Some(&serde_json::json!("techcorp_1"))
        );
    }

    #[test]
    fn test_entity_from_node_keeps_index_fields() {
        // from_node adds "original_node_id" on top of the metadata seeded by
        // new(); the "index_fields" entry must survive that.
        let entity = Entity::from_node("n1", "TechCorp", "A company", Uuid::new_v4(), None);

        assert_eq!(
            entity.base.get_metadata("index_fields"),
            Some(&serde_json::json!(["name"]))
        );
    }

    #[test]
    fn test_entity_index_fields() {
        let entity = Entity::new("TechCorp", None, "A company", None);
        let index_fields = entity.base.get_metadata("index_fields");

        assert_eq!(index_fields, Some(&serde_json::json!(["name"])));
    }

    #[test]
    fn test_entity_embeddable_text() {
        let entity = Entity::new("TechCorp", None, "A company", None);
        assert_eq!(entity.get_embeddable_text(), "TechCorp");
    }

    #[test]
    fn test_entity_serializes_flat() {
        // `#[serde(flatten)]` on `base` must put Entity's own fields at the top
        // level of the JSON object, with no nested "base" key.
        let entity = Entity::new("TechCorp", None, "A company", None);
        let value = serde_json::to_value(&entity).expect("Entity should serialize to JSON");

        assert_eq!(value.get("name"), Some(&serde_json::json!("TechCorp")));
        assert_eq!(value.get("description"), Some(&serde_json::json!("A company")));
        // No skip_serializing_if on `is_a`, so None is emitted as an explicit null.
        assert_eq!(value.get("is_a"), Some(&serde_json::Value::Null));
        assert!(value.get("base").is_none());
    }

    #[test]
    fn test_entity_set_description() {
        let mut entity = Entity::new("TechCorp", None, "Old desc", None);
        let old_time = entity.base.updated_at;

        // Let the clock move between creation and update.
        std::thread::sleep(std::time::Duration::from_millis(10));
        entity.set_description("New description");

        assert_eq!(entity.description, "New description");
        // updated_at is i64 (millis since epoch). `>=` rather than `>` keeps the
        // test robust on coarse or adjusted wall clocks: it verifies that
        // touch() never moves the timestamp backwards, not that it strictly advances.
        assert!(entity.base.updated_at >= old_time);
    }

    #[test]
    fn test_entity_set_type() {
        let mut entity = Entity::new("TechCorp", None, "A company", None);
        let old_time = entity.base.updated_at;
        let type_id = Uuid::new_v4();

        entity.set_entity_type(type_id);

        assert_eq!(entity.is_a, Some(type_id));
        // set_entity_type also calls touch(); the timestamp must not go backwards.
        assert!(entity.base.updated_at >= old_time);
    }

    #[test]
    fn entity_implements_has_datapoint() {
        let e = Entity::new("Foo", None, "desc", None);
        let dp_id = e.base.id;
        assert_eq!(e.data_point().id, dp_id);
        let mut e2 = e;
        assert_eq!(e2.data_point_mut().id, dp_id);
    }
}
208}