Skip to main content

engram/storage/
entity_queries.rs

1//! Database queries for entity operations (RML-925)
2
3use chrono::{DateTime, Utc};
4use rusqlite::{params, Connection, Row};
5use std::collections::HashMap;
6
7use crate::error::{EngramError, Result};
8use crate::intelligence::{Entity, EntityRelation, EntityType, ExtractedEntity};
9use crate::types::MemoryId;
10
11// =============================================================================
12// Entity Queries
13// =============================================================================
14
15/// Parse an entity from a database row
16fn entity_from_row(row: &Row) -> rusqlite::Result<Entity> {
17    let id: i64 = row.get("id")?;
18    let name: String = row.get("name")?;
19    let normalized_name: String = row.get("normalized_name")?;
20    let entity_type_str: String = row.get("entity_type")?;
21    let aliases_str: String = row.get("aliases")?;
22    let metadata_str: String = row.get("metadata")?;
23    let created_at: String = row.get("created_at")?;
24    let updated_at: String = row.get("updated_at")?;
25    let mention_count: i32 = row.get("mention_count")?;
26
27    let entity_type = entity_type_str.parse().unwrap_or(EntityType::Other);
28    let aliases: Vec<String> = serde_json::from_str(&aliases_str).unwrap_or_default();
29    let metadata: HashMap<String, serde_json::Value> =
30        serde_json::from_str(&metadata_str).unwrap_or_default();
31
32    Ok(Entity {
33        id,
34        name,
35        normalized_name,
36        entity_type,
37        aliases,
38        metadata,
39        created_at: DateTime::parse_from_rfc3339(&created_at)
40            .map(|dt| dt.with_timezone(&Utc))
41            .unwrap_or_else(|_| Utc::now()),
42        updated_at: DateTime::parse_from_rfc3339(&updated_at)
43            .map(|dt| dt.with_timezone(&Utc))
44            .unwrap_or_else(|_| Utc::now()),
45        mention_count,
46    })
47}
48
49/// Create or update an entity, returning its ID
50pub fn upsert_entity(conn: &Connection, extracted: &ExtractedEntity) -> Result<i64> {
51    let now = Utc::now().to_rfc3339();
52
53    // Try to find existing entity
54    let existing: Option<i64> = conn
55        .query_row(
56            "SELECT id FROM entities WHERE normalized_name = ? AND entity_type = ?",
57            params![extracted.normalized, extracted.entity_type.as_str()],
58            |row| row.get(0),
59        )
60        .ok();
61
62    if let Some(id) = existing {
63        // Update timestamp only; mention_count is incremented when a new link is created
64        conn.execute(
65            "UPDATE entities SET updated_at = ? WHERE id = ?",
66            params![now, id],
67        )?;
68        Ok(id)
69    } else {
70        // Insert new entity with zero mentions; links drive mention_count
71        conn.execute(
72            "INSERT INTO entities (name, normalized_name, entity_type, created_at, updated_at, mention_count)
73             VALUES (?, ?, ?, ?, ?, 0)",
74            params![
75                extracted.text,
76                extracted.normalized,
77                extracted.entity_type.as_str(),
78                now,
79                now,
80            ],
81        )?;
82        Ok(conn.last_insert_rowid())
83    }
84}
85
86/// Link an entity to a memory
87pub fn link_entity_to_memory(
88    conn: &Connection,
89    memory_id: MemoryId,
90    entity_id: i64,
91    relation: EntityRelation,
92    confidence: f32,
93    offset: Option<usize>,
94) -> Result<bool> {
95    let now = Utc::now().to_rfc3339();
96
97    let inserted = conn.execute(
98        "INSERT OR IGNORE INTO memory_entities (memory_id, entity_id, relation, confidence, char_offset, created_at)
99         VALUES (?, ?, ?, ?, ?, ?)",
100        params![
101            memory_id,
102            entity_id,
103            relation.as_str(),
104            confidence,
105            offset.map(|o| o as i64),
106            now,
107        ],
108    )? > 0;
109
110    if inserted {
111        conn.execute(
112            "UPDATE entities SET mention_count = mention_count + 1, updated_at = ? WHERE id = ?",
113            params![now, entity_id],
114        )?;
115    }
116
117    Ok(inserted)
118}
119
120/// Get an entity by ID
121pub fn get_entity(conn: &Connection, id: i64) -> Result<Entity> {
122    let mut stmt = conn.prepare_cached(
123        "SELECT id, name, normalized_name, entity_type, aliases, metadata,
124                created_at, updated_at, mention_count
125         FROM entities WHERE id = ?",
126    )?;
127
128    stmt.query_row([id], entity_from_row)
129        .map_err(|_| EngramError::NotFound(id))
130}
131
132/// Find entity by name and type
133pub fn find_entity(
134    conn: &Connection,
135    name: &str,
136    entity_type: Option<EntityType>,
137) -> Result<Option<Entity>> {
138    let normalized = name.trim().to_lowercase();
139
140    let sql = if entity_type.is_some() {
141        "SELECT id, name, normalized_name, entity_type, aliases, metadata,
142                created_at, updated_at, mention_count
143         FROM entities WHERE normalized_name = ? AND entity_type = ?"
144    } else {
145        "SELECT id, name, normalized_name, entity_type, aliases, metadata,
146                created_at, updated_at, mention_count
147         FROM entities WHERE normalized_name = ?"
148    };
149
150    let mut stmt = conn.prepare(sql)?;
151
152    let result = if let Some(et) = entity_type {
153        stmt.query_row(params![normalized, et.as_str()], entity_from_row)
154    } else {
155        stmt.query_row(params![normalized], entity_from_row)
156    };
157
158    match result {
159        Ok(entity) => Ok(Some(entity)),
160        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
161        Err(e) => Err(EngramError::from(e)),
162    }
163}
164
165/// List entities with optional filtering
166pub fn list_entities(
167    conn: &Connection,
168    entity_type: Option<EntityType>,
169    limit: i64,
170    offset: i64,
171) -> Result<Vec<Entity>> {
172    let sql = if entity_type.is_some() {
173        "SELECT id, name, normalized_name, entity_type, aliases, metadata,
174                created_at, updated_at, mention_count
175         FROM entities WHERE entity_type = ?
176         ORDER BY mention_count DESC, updated_at DESC
177         LIMIT ? OFFSET ?"
178    } else {
179        "SELECT id, name, normalized_name, entity_type, aliases, metadata,
180                created_at, updated_at, mention_count
181         FROM entities
182         ORDER BY mention_count DESC, updated_at DESC
183         LIMIT ? OFFSET ?"
184    };
185
186    let mut stmt = conn.prepare(sql)?;
187
188    let entities = if let Some(et) = entity_type {
189        stmt.query_map(params![et.as_str(), limit, offset], entity_from_row)?
190            .filter_map(|r| r.ok())
191            .collect()
192    } else {
193        stmt.query_map(params![limit, offset], entity_from_row)?
194            .filter_map(|r| r.ok())
195            .collect()
196    };
197
198    Ok(entities)
199}
200
201/// Get all entities linked to a memory
202pub fn get_entities_for_memory(
203    conn: &Connection,
204    memory_id: MemoryId,
205) -> Result<Vec<(Entity, EntityRelation, f32)>> {
206    let mut stmt = conn.prepare(
207        "SELECT e.id, e.name, e.normalized_name, e.entity_type, e.aliases, e.metadata,
208                e.created_at, e.updated_at, e.mention_count,
209                me.relation, me.confidence
210         FROM entities e
211         JOIN memory_entities me ON e.id = me.entity_id
212         WHERE me.memory_id = ?
213         ORDER BY me.confidence DESC",
214    )?;
215
216    let results: Vec<(Entity, EntityRelation, f32)> = stmt
217        .query_map([memory_id], |row| {
218            let entity = entity_from_row(row)?;
219            let relation_str: String = row.get("relation")?;
220            let confidence: f32 = row.get("confidence")?;
221            let relation = relation_str.parse().unwrap_or(EntityRelation::Mentions);
222            Ok((entity, relation, confidence))
223        })?
224        .filter_map(|r| r.ok())
225        .collect();
226
227    Ok(results)
228}
229
230/// Get all memories that mention an entity
231pub fn get_memories_for_entity(
232    conn: &Connection,
233    entity_id: i64,
234) -> Result<Vec<(MemoryId, EntityRelation, f32)>> {
235    let mut stmt = conn.prepare(
236        "SELECT memory_id, relation, confidence
237         FROM memory_entities
238         WHERE entity_id = ?
239         ORDER BY confidence DESC",
240    )?;
241
242    let results: Vec<(MemoryId, EntityRelation, f32)> = stmt
243        .query_map([entity_id], |row| {
244            let memory_id: MemoryId = row.get("memory_id")?;
245            let relation_str: String = row.get("relation")?;
246            let confidence: f32 = row.get("confidence")?;
247            let relation = relation_str.parse().unwrap_or(EntityRelation::Mentions);
248            Ok((memory_id, relation, confidence))
249        })?
250        .filter_map(|r| r.ok())
251        .collect();
252
253    Ok(results)
254}
255
256/// Search entities by name prefix
257pub fn search_entities(
258    conn: &Connection,
259    query: &str,
260    entity_type: Option<EntityType>,
261    limit: i64,
262) -> Result<Vec<Entity>> {
263    let pattern = format!("{}%", query.to_lowercase());
264
265    let sql = if entity_type.is_some() {
266        "SELECT id, name, normalized_name, entity_type, aliases, metadata,
267                created_at, updated_at, mention_count
268         FROM entities
269         WHERE normalized_name LIKE ? AND entity_type = ?
270         ORDER BY mention_count DESC
271         LIMIT ?"
272    } else {
273        "SELECT id, name, normalized_name, entity_type, aliases, metadata,
274                created_at, updated_at, mention_count
275         FROM entities
276         WHERE normalized_name LIKE ?
277         ORDER BY mention_count DESC
278         LIMIT ?"
279    };
280
281    let mut stmt = conn.prepare(sql)?;
282
283    let entities = if let Some(et) = entity_type {
284        stmt.query_map(params![pattern, et.as_str(), limit], entity_from_row)?
285            .filter_map(|r| r.ok())
286            .collect()
287    } else {
288        stmt.query_map(params![pattern, limit], entity_from_row)?
289            .filter_map(|r| r.ok())
290            .collect()
291    };
292
293    Ok(entities)
294}
295
296/// Delete an entity and its links
297pub fn delete_entity(conn: &Connection, id: i64) -> Result<()> {
298    // Links are deleted by CASCADE
299    let affected = conn.execute("DELETE FROM entities WHERE id = ?", params![id])?;
300
301    if affected == 0 {
302        return Err(EngramError::NotFound(id));
303    }
304
305    Ok(())
306}
307
308/// Remove entity link from a memory
309pub fn unlink_entity_from_memory(
310    conn: &Connection,
311    memory_id: MemoryId,
312    entity_id: i64,
313) -> Result<()> {
314    conn.execute(
315        "DELETE FROM memory_entities WHERE memory_id = ? AND entity_id = ?",
316        params![memory_id, entity_id],
317    )?;
318
319    Ok(())
320}
321
322/// Get entity statistics
323pub fn get_entity_stats(conn: &Connection) -> Result<EntityStats> {
324    let total_entities: i64 =
325        conn.query_row("SELECT COUNT(*) FROM entities", [], |row| row.get(0))?;
326
327    let total_links: i64 =
328        conn.query_row("SELECT COUNT(*) FROM memory_entities", [], |row| row.get(0))?;
329
330    let by_type: HashMap<String, i64> = {
331        let mut stmt =
332            conn.prepare("SELECT entity_type, COUNT(*) FROM entities GROUP BY entity_type")?;
333        let results: Vec<(String, i64)> = stmt
334            .query_map([], |row| {
335                let entity_type: String = row.get(0)?;
336                let count: i64 = row.get(1)?;
337                Ok((entity_type, count))
338            })?
339            .filter_map(|r| r.ok())
340            .collect();
341        results.into_iter().collect()
342    };
343
344    Ok(EntityStats {
345        total_entities,
346        total_links,
347        by_type,
348    })
349}
350
351/// Entity statistics
352#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
353pub struct EntityStats {
354    pub total_entities: i64,
355    pub total_links: i64,
356    pub by_type: HashMap<String, i64>,
357}
358
359// =============================================================================
360// Tests
361// =============================================================================
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::queries::create_memory;
    use crate::storage::Storage;
    use crate::types::{CreateMemoryInput, MemoryType};

    /// Fixture: an organization mention at offset 0 whose normalized form is the
    /// lowercased text and whose length is the text's byte length.
    fn organization(text: &str) -> ExtractedEntity {
        ExtractedEntity {
            text: text.to_string(),
            normalized: text.to_lowercase(),
            entity_type: EntityType::Organization,
            confidence: 0.9,
            offset: 0,
            length: text.len(),
            suggested_relation: EntityRelation::Mentions,
        }
    }

    #[test]
    fn test_upsert_and_find_entity() {
        let storage = Storage::open_in_memory().unwrap();

        let outcome = storage.with_connection(|conn| {
            let anthropic = organization("Anthropic");

            // The first call inserts a fresh row.
            let first_id = upsert_entity(conn, &anthropic)?;
            assert!(first_id > 0);

            // The second call must resolve to the same row instead of creating one.
            let second_id = upsert_entity(conn, &anthropic)?;
            assert_eq!(first_id, second_id);

            // Upserting alone never touches mention_count (links drive it).
            let stored = get_entity(conn, first_id)?;
            assert_eq!(stored.mention_count, 0);
            assert_eq!(stored.name, "Anthropic");

            // Lookup by normalized name and type finds the same row.
            let found = find_entity(conn, "anthropic", Some(EntityType::Organization))?;
            assert_eq!(found.map(|entity| entity.id), Some(first_id));

            Ok(())
        });

        outcome.unwrap();
    }

    #[test]
    fn test_link_entity_to_memory() {
        let storage = Storage::open_in_memory().unwrap();

        let outcome = storage.with_transaction(|conn| {
            // A memory to attach the entity to.
            let input = CreateMemoryInput {
                content: "Testing Anthropic's Claude model".to_string(),
                memory_type: MemoryType::Note,
                tags: vec![],
                metadata: HashMap::new(),
                importance: None,
                scope: Default::default(),
                workspace: None,
                tier: Default::default(),
                defer_embedding: true,
                ttl_seconds: None,
                dedup_mode: Default::default(),
                dedup_threshold: None,
                event_time: None,
                event_duration_seconds: None,
                trigger_pattern: None,
                summary_of_id: None,
            };
            let memory = create_memory(conn, &input)?;

            // The entity mentioned inside that memory's content.
            let mention = ExtractedEntity {
                offset: 8,
                ..organization("Anthropic")
            };
            let entity_id = upsert_entity(conn, &mention)?;

            // Both link attempts use exactly the same arguments.
            let link_once = || {
                link_entity_to_memory(
                    conn,
                    memory.id,
                    entity_id,
                    EntityRelation::Mentions,
                    0.9,
                    Some(8),
                )
            };

            // First link is created and counted.
            assert!(link_once()?);

            let linked = get_entities_for_memory(conn, memory.id)?;
            assert_eq!(linked.len(), 1);
            let (entity, relation, _confidence) = &linked[0];
            assert_eq!(entity.name, "Anthropic");
            assert_eq!(*relation, EntityRelation::Mentions);
            assert_eq!(entity.mention_count, 1);

            // A duplicate link is ignored and must not inflate mention_count.
            assert!(!link_once()?);
            assert_eq!(get_entity(conn, entity_id)?.mention_count, 1);

            // The reverse lookup sees the single link as well.
            let memories = get_memories_for_entity(conn, entity_id)?;
            assert_eq!(memories.len(), 1);
            assert_eq!(memories[0].0, memory.id);

            Ok(())
        });

        outcome.unwrap();
    }

    #[test]
    fn test_entity_search() {
        let storage = Storage::open_in_memory().unwrap();

        let outcome = storage.with_connection(|conn| {
            // Seed four organizations, three of which start with "a".
            for name in ["Anthropic", "Apple", "Amazon", "Microsoft"].iter() {
                upsert_entity(conn, &organization(name))?;
            }

            // Prefix "a" with a type filter: Anthropic, Apple, Amazon.
            let a_matches = search_entities(conn, "a", Some(EntityType::Organization), 10)?;
            assert_eq!(a_matches.len(), 3);

            // Prefix "mi" without a type filter: Microsoft only.
            let mi_matches = search_entities(conn, "mi", None, 10)?;
            assert_eq!(mi_matches.len(), 1);

            Ok(())
        });

        outcome.unwrap();
    }
}