cognee_models/
entity_type.rs1use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9use crate::DataPoint;
10use crate::has_datapoint::HasDataPoint;
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17pub struct EntityType {
18 #[serde(flatten)]
20 pub base: DataPoint,
21
22 pub name: String,
24
25 pub description: String,
27}
28
29impl EntityType {
30 pub const INDEX_FIELDS: &'static [&'static str] = &["name"];
32
33 pub fn new(
40 name: impl Into<String>,
41 description: impl Into<String>,
42 dataset_id: Option<Uuid>,
43 ) -> Self {
44 let mut metadata = std::collections::HashMap::new();
45 metadata.insert(
46 "index_fields".to_string(),
47 serde_json::json!(Self::INDEX_FIELDS),
48 );
49
50 let name_str = name.into();
51 let description_str = description.into();
52
53 Self {
54 base: DataPoint::with_metadata("EntityType", dataset_id, metadata),
55 name: name_str.clone(),
56 description: if description_str.is_empty() {
57 format!("Entity type: {name_str}")
58 } else {
59 description_str
60 },
61 }
62 }
63
64 pub fn from_node_type(type_name: impl Into<String>, dataset_id: Option<Uuid>) -> Self {
70 let type_str = type_name.into();
71 Self::new(
72 type_str.clone(),
73 format!("Entity type: {type_str}"),
74 dataset_id,
75 )
76 }
77
78 pub fn get_embeddable_text(&self) -> String {
80 self.name.clone()
81 }
82
83 pub fn set_description(&mut self, description: impl Into<String>) {
85 self.description = description.into();
86 self.base.touch();
87 }
88
89 pub fn is_ontology_valid(&self) -> bool {
91 self.base.ontology_valid
92 }
93
94 pub fn mark_ontology_valid(&mut self, canonical_name: Option<String>) {
99 self.base.set_ontology_valid(true);
100
101 if let Some(canonical) = canonical_name
102 && canonical != self.name
103 {
104 self.base
105 .set_metadata("original_name", serde_json::json!(self.name.clone()));
106 self.name = canonical;
107 }
108 }
109}
110
111impl HasDataPoint for EntityType {
112 fn data_point(&self) -> &DataPoint {
113 &self.base
114 }
115 fn data_point_mut(&mut self) -> &mut DataPoint {
116 &mut self.base
117 }
118 }
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entity_type_creation() {
        let et = EntityType::new("Organization", "A company or institution", None);

        assert_eq!(et.name, "Organization");
        assert_eq!(et.description, "A company or institution");
        assert_eq!(et.base.data_type, "EntityType");
    }

    #[test]
    fn test_entity_type_accepts_owned_strings() {
        // `new` takes `impl Into<String>`, so owned Strings must work as well as &str.
        let et = EntityType::new(String::from("Organization"), String::from("desc"), None);

        assert_eq!(et.name, "Organization");
        assert_eq!(et.description, "desc");
    }

    #[test]
    fn test_entity_type_empty_description() {
        let et = EntityType::new("Person", "", None);

        assert_eq!(et.name, "Person");
        assert_eq!(et.description, "Entity type: Person");
    }

    #[test]
    fn test_entity_type_from_node_type() {
        let et = EntityType::from_node_type("Location", None);

        assert_eq!(et.name, "Location");
        assert_eq!(et.description, "Entity type: Location");
        assert_eq!(et.base.data_type, "EntityType");
    }

    #[test]
    fn test_entity_type_index_fields() {
        let expected = serde_json::json!(["name"]);

        let et = EntityType::new("Organization", "A company", None);
        assert_eq!(et.base.get_metadata("index_fields"), Some(&expected));

        // `from_node_type` goes through `new`, so it must carry the same metadata.
        let from_node = EntityType::from_node_type("Location", None);
        assert_eq!(from_node.base.get_metadata("index_fields"), Some(&expected));
    }

    #[test]
    fn test_entity_type_embeddable_text() {
        let et = EntityType::new("Organization", "A company", None);
        assert_eq!(et.get_embeddable_text(), "Organization");
    }

    #[test]
    fn test_entity_type_set_description() {
        let mut et = EntityType::new("Organization", "Old desc", None);
        et.set_description("New description");

        assert_eq!(et.description, "New description");
        // Only the description should change.
        assert_eq!(et.name, "Organization");
        assert_eq!(et.base.data_type, "EntityType");
    }

    #[test]
    fn test_ontology_validation() {
        let mut et = EntityType::new("Mathematician", "", None);
        assert!(!et.is_ontology_valid());

        et.mark_ontology_valid(Some("Person".to_string()));

        assert!(et.is_ontology_valid());
        assert_eq!(et.name, "Person");
        assert_eq!(
            et.base.get_metadata("original_name"),
            Some(&serde_json::json!("Mathematician"))
        );
    }

    #[test]
    fn test_ontology_validation_same_name() {
        let mut et = EntityType::new("Person", "", None);
        et.mark_ontology_valid(Some("Person".to_string()));

        assert!(et.is_ontology_valid());
        assert_eq!(et.name, "Person");
        assert_eq!(et.base.get_metadata("original_name"), None);
    }

    #[test]
    fn test_ontology_validation_no_canonical() {
        let mut et = EntityType::new("Person", "", None);
        et.mark_ontology_valid(None);

        assert!(et.is_ontology_valid());
        assert_eq!(et.name, "Person");
        // No rename happened, so no original name should be recorded.
        assert_eq!(et.base.get_metadata("original_name"), None);
    }

    #[test]
    fn entity_type_implements_has_datapoint() {
        let et = EntityType::new("Org", "desc", None);
        let dp_id = et.base.id;
        assert_eq!(et.data_point().id, dp_id);
        let mut et2 = et;
        assert_eq!(et2.data_point_mut().id, dp_id);
    }
}
215}