1use chrono::{DateTime, Utc};
4use rusqlite::{params, Connection, Row};
5use std::collections::HashMap;
6
7use crate::error::{EngramError, Result};
8use crate::intelligence::{Entity, EntityRelation, EntityType, ExtractedEntity};
9use crate::types::MemoryId;
10
11fn entity_from_row(row: &Row) -> rusqlite::Result<Entity> {
17 let id: i64 = row.get("id")?;
18 let name: String = row.get("name")?;
19 let normalized_name: String = row.get("normalized_name")?;
20 let entity_type_str: String = row.get("entity_type")?;
21 let aliases_str: String = row.get("aliases")?;
22 let metadata_str: String = row.get("metadata")?;
23 let created_at: String = row.get("created_at")?;
24 let updated_at: String = row.get("updated_at")?;
25 let mention_count: i32 = row.get("mention_count")?;
26
27 let entity_type = entity_type_str.parse().unwrap_or(EntityType::Other);
28 let aliases: Vec<String> = serde_json::from_str(&aliases_str).unwrap_or_default();
29 let metadata: HashMap<String, serde_json::Value> =
30 serde_json::from_str(&metadata_str).unwrap_or_default();
31
32 Ok(Entity {
33 id,
34 name,
35 normalized_name,
36 entity_type,
37 aliases,
38 metadata,
39 created_at: DateTime::parse_from_rfc3339(&created_at)
40 .map(|dt| dt.with_timezone(&Utc))
41 .unwrap_or_else(|_| Utc::now()),
42 updated_at: DateTime::parse_from_rfc3339(&updated_at)
43 .map(|dt| dt.with_timezone(&Utc))
44 .unwrap_or_else(|_| Utc::now()),
45 mention_count,
46 })
47}
48
49pub fn upsert_entity(conn: &Connection, extracted: &ExtractedEntity) -> Result<i64> {
51 let now = Utc::now().to_rfc3339();
52
53 let existing: Option<i64> = conn
55 .query_row(
56 "SELECT id FROM entities WHERE normalized_name = ? AND entity_type = ?",
57 params![extracted.normalized, extracted.entity_type.as_str()],
58 |row| row.get(0),
59 )
60 .ok();
61
62 if let Some(id) = existing {
63 conn.execute(
65 "UPDATE entities SET updated_at = ? WHERE id = ?",
66 params![now, id],
67 )?;
68 Ok(id)
69 } else {
70 conn.execute(
72 "INSERT INTO entities (name, normalized_name, entity_type, created_at, updated_at, mention_count)
73 VALUES (?, ?, ?, ?, ?, 0)",
74 params![
75 extracted.text,
76 extracted.normalized,
77 extracted.entity_type.as_str(),
78 now,
79 now,
80 ],
81 )?;
82 Ok(conn.last_insert_rowid())
83 }
84}
85
86pub fn link_entity_to_memory(
88 conn: &Connection,
89 memory_id: MemoryId,
90 entity_id: i64,
91 relation: EntityRelation,
92 confidence: f32,
93 offset: Option<usize>,
94) -> Result<bool> {
95 let now = Utc::now().to_rfc3339();
96
97 let inserted = conn.execute(
98 "INSERT OR IGNORE INTO memory_entities (memory_id, entity_id, relation, confidence, char_offset, created_at)
99 VALUES (?, ?, ?, ?, ?, ?)",
100 params![
101 memory_id,
102 entity_id,
103 relation.as_str(),
104 confidence,
105 offset.map(|o| o as i64),
106 now,
107 ],
108 )? > 0;
109
110 if inserted {
111 conn.execute(
112 "UPDATE entities SET mention_count = mention_count + 1, updated_at = ? WHERE id = ?",
113 params![now, entity_id],
114 )?;
115 }
116
117 Ok(inserted)
118}
119
120pub fn get_entity(conn: &Connection, id: i64) -> Result<Entity> {
122 let mut stmt = conn.prepare_cached(
123 "SELECT id, name, normalized_name, entity_type, aliases, metadata,
124 created_at, updated_at, mention_count
125 FROM entities WHERE id = ?",
126 )?;
127
128 stmt.query_row([id], entity_from_row)
129 .map_err(|_| EngramError::NotFound(id))
130}
131
132pub fn find_entity(
134 conn: &Connection,
135 name: &str,
136 entity_type: Option<EntityType>,
137) -> Result<Option<Entity>> {
138 let normalized = name.trim().to_lowercase();
139
140 let sql = if entity_type.is_some() {
141 "SELECT id, name, normalized_name, entity_type, aliases, metadata,
142 created_at, updated_at, mention_count
143 FROM entities WHERE normalized_name = ? AND entity_type = ?"
144 } else {
145 "SELECT id, name, normalized_name, entity_type, aliases, metadata,
146 created_at, updated_at, mention_count
147 FROM entities WHERE normalized_name = ?"
148 };
149
150 let mut stmt = conn.prepare(sql)?;
151
152 let result = if let Some(et) = entity_type {
153 stmt.query_row(params![normalized, et.as_str()], entity_from_row)
154 } else {
155 stmt.query_row(params![normalized], entity_from_row)
156 };
157
158 match result {
159 Ok(entity) => Ok(Some(entity)),
160 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
161 Err(e) => Err(EngramError::from(e)),
162 }
163}
164
165pub fn list_entities(
167 conn: &Connection,
168 entity_type: Option<EntityType>,
169 limit: i64,
170 offset: i64,
171) -> Result<Vec<Entity>> {
172 let sql = if entity_type.is_some() {
173 "SELECT id, name, normalized_name, entity_type, aliases, metadata,
174 created_at, updated_at, mention_count
175 FROM entities WHERE entity_type = ?
176 ORDER BY mention_count DESC, updated_at DESC
177 LIMIT ? OFFSET ?"
178 } else {
179 "SELECT id, name, normalized_name, entity_type, aliases, metadata,
180 created_at, updated_at, mention_count
181 FROM entities
182 ORDER BY mention_count DESC, updated_at DESC
183 LIMIT ? OFFSET ?"
184 };
185
186 let mut stmt = conn.prepare(sql)?;
187
188 let entities = if let Some(et) = entity_type {
189 stmt.query_map(params![et.as_str(), limit, offset], entity_from_row)?
190 .filter_map(|r| r.ok())
191 .collect()
192 } else {
193 stmt.query_map(params![limit, offset], entity_from_row)?
194 .filter_map(|r| r.ok())
195 .collect()
196 };
197
198 Ok(entities)
199}
200
201pub fn get_entities_for_memory(
203 conn: &Connection,
204 memory_id: MemoryId,
205) -> Result<Vec<(Entity, EntityRelation, f32)>> {
206 let mut stmt = conn.prepare(
207 "SELECT e.id, e.name, e.normalized_name, e.entity_type, e.aliases, e.metadata,
208 e.created_at, e.updated_at, e.mention_count,
209 me.relation, me.confidence
210 FROM entities e
211 JOIN memory_entities me ON e.id = me.entity_id
212 WHERE me.memory_id = ?
213 ORDER BY me.confidence DESC",
214 )?;
215
216 let results: Vec<(Entity, EntityRelation, f32)> = stmt
217 .query_map([memory_id], |row| {
218 let entity = entity_from_row(row)?;
219 let relation_str: String = row.get("relation")?;
220 let confidence: f32 = row.get("confidence")?;
221 let relation = relation_str.parse().unwrap_or(EntityRelation::Mentions);
222 Ok((entity, relation, confidence))
223 })?
224 .filter_map(|r| r.ok())
225 .collect();
226
227 Ok(results)
228}
229
230pub fn get_memories_for_entity(
232 conn: &Connection,
233 entity_id: i64,
234) -> Result<Vec<(MemoryId, EntityRelation, f32)>> {
235 let mut stmt = conn.prepare(
236 "SELECT memory_id, relation, confidence
237 FROM memory_entities
238 WHERE entity_id = ?
239 ORDER BY confidence DESC",
240 )?;
241
242 let results: Vec<(MemoryId, EntityRelation, f32)> = stmt
243 .query_map([entity_id], |row| {
244 let memory_id: MemoryId = row.get("memory_id")?;
245 let relation_str: String = row.get("relation")?;
246 let confidence: f32 = row.get("confidence")?;
247 let relation = relation_str.parse().unwrap_or(EntityRelation::Mentions);
248 Ok((memory_id, relation, confidence))
249 })?
250 .filter_map(|r| r.ok())
251 .collect();
252
253 Ok(results)
254}
255
256pub fn search_entities(
258 conn: &Connection,
259 query: &str,
260 entity_type: Option<EntityType>,
261 limit: i64,
262) -> Result<Vec<Entity>> {
263 let pattern = format!("{}%", query.to_lowercase());
264
265 let sql = if entity_type.is_some() {
266 "SELECT id, name, normalized_name, entity_type, aliases, metadata,
267 created_at, updated_at, mention_count
268 FROM entities
269 WHERE normalized_name LIKE ? AND entity_type = ?
270 ORDER BY mention_count DESC
271 LIMIT ?"
272 } else {
273 "SELECT id, name, normalized_name, entity_type, aliases, metadata,
274 created_at, updated_at, mention_count
275 FROM entities
276 WHERE normalized_name LIKE ?
277 ORDER BY mention_count DESC
278 LIMIT ?"
279 };
280
281 let mut stmt = conn.prepare(sql)?;
282
283 let entities = if let Some(et) = entity_type {
284 stmt.query_map(params![pattern, et.as_str(), limit], entity_from_row)?
285 .filter_map(|r| r.ok())
286 .collect()
287 } else {
288 stmt.query_map(params![pattern, limit], entity_from_row)?
289 .filter_map(|r| r.ok())
290 .collect()
291 };
292
293 Ok(entities)
294}
295
296pub fn delete_entity(conn: &Connection, id: i64) -> Result<()> {
298 let affected = conn.execute("DELETE FROM entities WHERE id = ?", params![id])?;
300
301 if affected == 0 {
302 return Err(EngramError::NotFound(id));
303 }
304
305 Ok(())
306}
307
308pub fn unlink_entity_from_memory(
310 conn: &Connection,
311 memory_id: MemoryId,
312 entity_id: i64,
313) -> Result<()> {
314 conn.execute(
315 "DELETE FROM memory_entities WHERE memory_id = ? AND entity_id = ?",
316 params![memory_id, entity_id],
317 )?;
318
319 Ok(())
320}
321
322pub fn get_entity_stats(conn: &Connection) -> Result<EntityStats> {
324 let total_entities: i64 =
325 conn.query_row("SELECT COUNT(*) FROM entities", [], |row| row.get(0))?;
326
327 let total_links: i64 =
328 conn.query_row("SELECT COUNT(*) FROM memory_entities", [], |row| row.get(0))?;
329
330 let by_type: HashMap<String, i64> = {
331 let mut stmt =
332 conn.prepare("SELECT entity_type, COUNT(*) FROM entities GROUP BY entity_type")?;
333 let results: Vec<(String, i64)> = stmt
334 .query_map([], |row| {
335 let entity_type: String = row.get(0)?;
336 let count: i64 = row.get(1)?;
337 Ok((entity_type, count))
338 })?
339 .filter_map(|r| r.ok())
340 .collect();
341 results.into_iter().collect()
342 };
343
344 Ok(EntityStats {
345 total_entities,
346 total_links,
347 by_type,
348 })
349}
350
351#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
353pub struct EntityStats {
354 pub total_entities: i64,
355 pub total_links: i64,
356 pub by_type: HashMap<String, i64>,
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Storage;

    /// Builds an organization extraction whose normalized form is the lower-cased
    /// name, located at offset 0 with a confidence of 0.9.
    fn organization(name: &str) -> ExtractedEntity {
        ExtractedEntity {
            text: name.to_string(),
            normalized: name.to_lowercase(),
            entity_type: EntityType::Organization,
            confidence: 0.9,
            offset: 0,
            length: name.len(),
            suggested_relation: EntityRelation::Mentions,
        }
    }

    /// Upserting the same extraction twice yields one row that can be fetched by
    /// id and found by name and type.
    #[test]
    fn test_upsert_and_find_entity() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                let extracted = organization("Anthropic");

                let first_id = upsert_entity(conn, &extracted)?;
                assert!(first_id > 0);

                // A second upsert must resolve to the same row.
                let second_id = upsert_entity(conn, &extracted)?;
                assert_eq!(first_id, second_id);

                // Upserting alone never counts as a mention.
                let stored = get_entity(conn, first_id)?;
                assert_eq!(stored.mention_count, 0);
                assert_eq!(stored.name, "Anthropic");

                let found = find_entity(conn, "anthropic", Some(EntityType::Organization))?;
                assert!(found.is_some());
                assert_eq!(found.map(|entity| entity.id), Some(first_id));

                Ok(())
            })
            .unwrap();
    }

    /// Linking is idempotent: the first link bumps the mention count, a repeated
    /// link is ignored and leaves the count untouched.
    #[test]
    fn test_link_entity_to_memory() {
        use crate::storage::queries::create_memory;
        use crate::types::{CreateMemoryInput, MemoryType};

        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_transaction(|conn| {
                let input = CreateMemoryInput {
                    content: "Testing Anthropic's Claude model".to_string(),
                    memory_type: MemoryType::Note,
                    tags: vec![],
                    metadata: HashMap::new(),
                    importance: None,
                    scope: Default::default(),
                    workspace: None,
                    tier: Default::default(),
                    defer_embedding: true,
                    ttl_seconds: None,
                    dedup_mode: Default::default(),
                    dedup_threshold: None,
                    event_time: None,
                    event_duration_seconds: None,
                    trigger_pattern: None,
                    summary_of_id: None,
                };
                let memory = create_memory(conn, &input)?;

                let extracted = ExtractedEntity {
                    offset: 8,
                    ..organization("Anthropic")
                };
                let entity_id = upsert_entity(conn, &extracted)?;

                // The same link request is issued twice below.
                let link_once = || {
                    link_entity_to_memory(
                        conn,
                        memory.id,
                        entity_id,
                        EntityRelation::Mentions,
                        0.9,
                        Some(8),
                    )
                };

                assert!(link_once()?);

                let linked = get_entities_for_memory(conn, memory.id)?;
                assert_eq!(linked.len(), 1);
                assert_eq!(linked[0].0.name, "Anthropic");
                assert_eq!(linked[0].1, EntityRelation::Mentions);
                assert_eq!(linked[0].0.mention_count, 1);

                // Duplicate link: ignored, counter unchanged.
                assert!(!link_once()?);

                let stored = get_entity(conn, entity_id)?;
                assert_eq!(stored.mention_count, 1);

                let memories = get_memories_for_entity(conn, entity_id)?;
                assert_eq!(memories.len(), 1);
                assert_eq!(memories[0].0, memory.id);

                Ok(())
            })
            .unwrap();
    }

    /// Prefix search honours both the prefix and the optional type filter.
    #[test]
    fn test_entity_search() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                for &name in &["Anthropic", "Apple", "Amazon", "Microsoft"] {
                    upsert_entity(conn, &organization(name))?;
                }

                let starting_with_a =
                    search_entities(conn, "a", Some(EntityType::Organization), 10)?;
                assert_eq!(starting_with_a.len(), 3);

                let starting_with_mi = search_entities(conn, "mi", None, 10)?;
                assert_eq!(starting_with_mi.len(), 1);

                Ok(())
            })
            .unwrap();
    }
}