Skip to main content

cognee_models/
entity_type.rs

1//! EntityType - Storage-layer entity type model.
2//!
3//! Mirrors Python's `cognee/modules/engine/models/EntityType.py`
4//! Represents a category/type of entities (e.g., "Organization", "Person", "Location").
5
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9use crate::DataPoint;
10use crate::has_datapoint::HasDataPoint;
11
12/// Storage-layer entity type model.
13///
14/// Represents a category of entities (e.g., "Organization", "Person", "Location").
15/// Entity instances reference their EntityType via the `is_a` field.
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17pub struct EntityType {
18    /// Base data point fields (id, timestamps, metadata, etc.)
19    #[serde(flatten)]
20    pub base: DataPoint,
21
22    /// Type name (e.g., "Organization", "Person", "Location")
23    pub name: String,
24
25    /// Type description
26    pub description: String,
27}
28
29impl EntityType {
30    /// Index fields to embed for vector search.
31    pub const INDEX_FIELDS: &'static [&'static str] = &["name"];
32
33    /// Create a new EntityType.
34    ///
35    /// # Arguments
36    /// * `name` - Type name (e.g., "Organization")
37    /// * `description` - Type description
38    /// * `dataset_id` - Dataset UUID
39    pub fn new(
40        name: impl Into<String>,
41        description: impl Into<String>,
42        dataset_id: Option<Uuid>,
43    ) -> Self {
44        let mut metadata = std::collections::HashMap::new();
45        metadata.insert(
46            "index_fields".to_string(),
47            serde_json::json!(Self::INDEX_FIELDS),
48        );
49
50        let name_str = name.into();
51        let description_str = description.into();
52
53        Self {
54            base: DataPoint::with_metadata("EntityType", dataset_id, metadata),
55            name: name_str.clone(),
56            description: if description_str.is_empty() {
57                format!("Entity type: {name_str}")
58            } else {
59                description_str
60            },
61        }
62    }
63
64    /// Create EntityType from LLM-extracted node type string.
65    ///
66    /// # Arguments
67    /// * `type_name` - Node type from LLM (e.g., "Organization")
68    /// * `dataset_id` - Dataset UUID
69    pub fn from_node_type(type_name: impl Into<String>, dataset_id: Option<Uuid>) -> Self {
70        let type_str = type_name.into();
71        Self::new(
72            type_str.clone(),
73            format!("Entity type: {type_str}"),
74            dataset_id,
75        )
76    }
77
78    /// Get the type name (for embedding).
79    pub fn get_embeddable_text(&self) -> String {
80        self.name.clone()
81    }
82
83    /// Update type description.
84    pub fn set_description(&mut self, description: impl Into<String>) {
85        self.description = description.into();
86        self.base.touch();
87    }
88
89    /// Check if this type has been validated against an ontology.
90    pub fn is_ontology_valid(&self) -> bool {
91        self.base.ontology_valid
92    }
93
94    /// Mark as ontology-validated with canonical name.
95    ///
96    /// # Arguments
97    /// * `canonical_name` - Canonical name from ontology
98    pub fn mark_ontology_valid(&mut self, canonical_name: Option<String>) {
99        self.base.set_ontology_valid(true);
100
101        if let Some(canonical) = canonical_name
102            && canonical != self.name
103        {
104            self.base
105                .set_metadata("original_name", serde_json::json!(self.name.clone()));
106            self.name = canonical;
107        }
108    }
109}
110
111impl HasDataPoint for EntityType {
112    fn data_point(&self) -> &DataPoint {
113        &self.base
114    }
115    fn data_point_mut(&mut self) -> &mut DataPoint {
116        &mut self.base
117    }
118    // for_each_child_mut: default no-op — EntityType is a leaf in the
119    // model graph (no owned `HasDataPoint` children).
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Explicit name and description are stored verbatim; the base data point
    /// is tagged with the "EntityType" data type.
    #[test]
    fn test_entity_type_creation() {
        let organization = EntityType::new("Organization", "A company or institution", None);

        assert_eq!(organization.name, "Organization");
        assert_eq!(organization.description, "A company or institution");
        assert_eq!(organization.base.data_type, "EntityType");
    }

    /// An empty description falls back to the generated "Entity type: ..." text.
    #[test]
    fn test_entity_type_empty_description() {
        let person = EntityType::new("Person", "", None);

        assert_eq!(person.name, "Person");
        assert_eq!(person.description, "Entity type: Person");
    }

    /// `from_node_type` derives the description from the type name.
    #[test]
    fn test_entity_type_from_node_type() {
        let location = EntityType::from_node_type("Location", None);

        assert_eq!(location.name, "Location");
        assert_eq!(location.description, "Entity type: Location");
    }

    /// The constructor records the index fields in the base metadata.
    #[test]
    fn test_entity_type_index_fields() {
        let organization = EntityType::new("Organization", "A company", None);
        let expected = serde_json::json!(["name"]);

        assert_eq!(
            organization.base.get_metadata("index_fields"),
            Some(&expected)
        );
    }

    /// The embeddable text is exactly the type name.
    #[test]
    fn test_entity_type_embeddable_text() {
        let organization = EntityType::new("Organization", "A company", None);

        assert_eq!(organization.get_embeddable_text(), "Organization");
    }

    /// `set_description` overwrites the stored description.
    #[test]
    fn test_entity_type_set_description() {
        let mut organization = EntityType::new("Organization", "Old desc", None);

        organization.set_description("New description");

        assert_eq!(organization.description, "New description");
    }

    /// A differing canonical name renames the type and keeps the old name in
    /// the "original_name" metadata entry.
    #[test]
    fn test_ontology_validation() {
        let mut entity_type = EntityType::new("Mathematician", "", None);
        assert!(!entity_type.is_ontology_valid());

        // Validate against the ontology, supplying a different canonical name.
        entity_type.mark_ontology_valid(Some("Person".to_string()));

        assert!(entity_type.is_ontology_valid());
        assert_eq!(entity_type.name, "Person");

        let expected_original = serde_json::json!("Mathematician");
        assert_eq!(
            entity_type.base.get_metadata("original_name"),
            Some(&expected_original)
        );
    }

    /// A canonical name equal to the current one changes nothing but the flag.
    #[test]
    fn test_ontology_validation_same_name() {
        let mut person = EntityType::new("Person", "", None);

        person.mark_ontology_valid(Some("Person".to_string()));

        assert!(person.is_ontology_valid());
        assert_eq!(person.name, "Person");
        assert_eq!(person.base.get_metadata("original_name"), None);
    }

    /// Without a canonical name only the validity flag is set.
    #[test]
    fn test_ontology_validation_no_canonical() {
        let mut person = EntityType::new("Person", "", None);

        person.mark_ontology_valid(None);

        assert!(person.is_ontology_valid());
        assert_eq!(person.name, "Person");
    }

    /// Both trait accessors hand back the embedded base data point.
    #[test]
    fn entity_type_implements_has_datapoint() {
        let entity_type = EntityType::new("Org", "desc", None);
        let expected_id = entity_type.base.id;
        assert_eq!(entity_type.data_point().id, expected_id);

        // Rebind mutably to exercise the `&mut` accessor.
        let mut entity_type = entity_type;
        assert_eq!(entity_type.data_point_mut().id, expected_id);
    }
}