Skip to main content

khive_db/stores/
entity.rs

1//! SQL-backed `EntityStore` implementation.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::entity::{Entity, EntityFilter};
9use khive_storage::error::StorageError;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::EntityStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18    StorageError::driver(StorageCapability::Entities, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22    StorageError::driver(StorageCapability::Entities, op, e)
23}
24
25/// An EntityStore backed by SQLite. Namespace is the caller's responsibility.
26///
27/// UUID is globally unique — get/delete by ID alone. Query/count use the
28/// namespace parameter as passed. The store is just a pool + is_file_backed.
29pub struct SqlEntityStore {
30    pool: Arc<ConnectionPool>,
31    is_file_backed: bool,
32}
33
34impl SqlEntityStore {
35    /// Create a new store.
36    pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37        Self {
38            pool,
39            is_file_backed,
40        }
41    }
42
43    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44        let config = self.pool.config();
45        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46            operation: "entity_reader".into(),
47            message: "in-memory databases do not support standalone connections".into(),
48        })?;
49
50        let conn = rusqlite::Connection::open_with_flags(
51            path,
52            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55        )
56        .map_err(|e| map_err(e, "open_entity_reader"))?;
57
58        conn.busy_timeout(config.busy_timeout)
59            .map_err(|e| map_err(e, "open_entity_reader"))?;
60        conn.pragma_update(None, "foreign_keys", "ON")
61            .map_err(|e| map_err(e, "open_entity_reader"))?;
62        conn.pragma_update(None, "synchronous", "NORMAL")
63            .map_err(|e| map_err(e, "open_entity_reader"))?;
64
65        Ok(conn)
66    }
67
68    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
69    where
70        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
71        R: Send + 'static,
72    {
73        let pool = Arc::clone(&self.pool);
74        tokio::task::spawn_blocking(move || {
75            let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
76            f(guard.conn()).map_err(|e| map_err(e, op))
77        })
78        .await
79        .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
80    }
81
82    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
83    where
84        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
85        R: Send + 'static,
86    {
87        if self.is_file_backed {
88            let conn = self.open_standalone_reader()?;
89            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
90                .await
91                .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
92        } else {
93            let pool = Arc::clone(&self.pool);
94            tokio::task::spawn_blocking(move || {
95                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
96                f(guard.conn()).map_err(|e| map_err(e, op))
97            })
98            .await
99            .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
100        }
101    }
102}
103
104// =============================================================================
105// Helpers
106// =============================================================================
107
108fn read_entity(row: &rusqlite::Row<'_>) -> Result<Entity, rusqlite::Error> {
109    let id_str: String = row.get(0)?;
110    let namespace: String = row.get(1)?;
111    let kind: String = row.get(2)?;
112    let entity_type: Option<String> = row.get(3)?;
113    let name: String = row.get(4)?;
114    let description: Option<String> = row.get(5)?;
115    let properties_str: Option<String> = row.get(6)?;
116    let tags_str: String = row.get(7)?;
117    let created_at: i64 = row.get(8)?;
118    let updated_at: i64 = row.get(9)?;
119    let deleted_at: Option<i64> = row.get(10)?;
120    let merged_into_str: Option<String> = row.get(11)?;
121    let merge_event_id_str: Option<String> = row.get(12)?;
122
123    let id = parse_uuid(&id_str)?;
124
125    let properties = properties_str
126        .map(|s| {
127            serde_json::from_str(&s).map_err(|e| {
128                rusqlite::Error::FromSqlConversionFailure(
129                    6,
130                    rusqlite::types::Type::Text,
131                    Box::new(e),
132                )
133            })
134        })
135        .transpose()?;
136
137    let tags: Vec<String> = serde_json::from_str(&tags_str).map_err(|e| {
138        rusqlite::Error::FromSqlConversionFailure(7, rusqlite::types::Type::Text, Box::new(e))
139    })?;
140
141    let merged_into = merged_into_str
142        .as_deref()
143        .map(Uuid::parse_str)
144        .transpose()
145        .map_err(|e| {
146            rusqlite::Error::FromSqlConversionFailure(10, rusqlite::types::Type::Text, Box::new(e))
147        })?;
148
149    let merge_event_id = merge_event_id_str
150        .as_deref()
151        .map(Uuid::parse_str)
152        .transpose()
153        .map_err(|e| {
154            rusqlite::Error::FromSqlConversionFailure(11, rusqlite::types::Type::Text, Box::new(e))
155        })?;
156
157    Ok(Entity {
158        id,
159        namespace,
160        kind,
161        entity_type,
162        name,
163        description,
164        properties,
165        tags,
166        created_at,
167        updated_at,
168        deleted_at,
169        merged_into,
170        merge_event_id,
171    })
172}
173
174fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
175    Uuid::parse_str(s).map_err(|e| {
176        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
177    })
178}
179
180fn build_entity_where(
181    namespace: &str,
182    filter: &EntityFilter,
183) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
184    let mut conditions: Vec<String> = vec![
185        "namespace = ?1".to_string(),
186        "deleted_at IS NULL".to_string(),
187    ];
188    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
189
190    if !filter.ids.is_empty() {
191        let placeholders: Vec<String> = filter
192            .ids
193            .iter()
194            .map(|id| {
195                params.push(Box::new(id.to_string()));
196                format!("?{}", params.len())
197            })
198            .collect();
199        conditions.push(format!("id IN ({})", placeholders.join(", ")));
200    }
201
202    if !filter.kinds.is_empty() {
203        let placeholders: Vec<String> = filter
204            .kinds
205            .iter()
206            .map(|k| {
207                params.push(Box::new(k.clone()));
208                format!("?{}", params.len())
209            })
210            .collect();
211        conditions.push(format!("kind IN ({})", placeholders.join(", ")));
212    }
213
214    if !filter.entity_types.is_empty() {
215        let placeholders: Vec<String> = filter
216            .entity_types
217            .iter()
218            .map(|t| {
219                params.push(Box::new(t.clone()));
220                format!("?{}", params.len())
221            })
222            .collect();
223        conditions.push(format!("entity_type IN ({})", placeholders.join(", ")));
224    }
225
226    if let Some(ref prefix) = filter.name_prefix {
227        params.push(Box::new(format!("{}%", prefix)));
228        conditions.push(format!("name LIKE ?{}", params.len()));
229    }
230
231    if !filter.tags_any.is_empty() {
232        let placeholders: Vec<String> = filter
233            .tags_any
234            .iter()
235            .map(|t| {
236                // Normalise to lowercase so the comparison is case-insensitive
237                // (ADR-047 §91: domain filter must be case-insensitive).
238                params.push(Box::new(t.to_lowercase()));
239                format!("?{}", params.len())
240            })
241            .collect();
242        conditions.push(format!(
243            "EXISTS (SELECT 1 FROM json_each(tags) WHERE LOWER(json_each.value) IN ({}))",
244            placeholders.join(", ")
245        ));
246    }
247
248    let clause = format!(" WHERE {}", conditions.join(" AND "));
249    (clause, params)
250}
251
252// =============================================================================
253// EntityStore implementation
254// =============================================================================
255
256#[async_trait]
257impl EntityStore for SqlEntityStore {
258    async fn upsert_entity(&self, entity: Entity) -> Result<(), StorageError> {
259        let namespace = entity.namespace.clone();
260        let id_str = entity.id.to_string();
261        let properties_str = entity
262            .properties
263            .as_ref()
264            .map(|v| serde_json::to_string(v).unwrap_or_default());
265        let tags_str = serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
266
267        let merged_into_str = entity.merged_into.map(|u| u.to_string());
268        let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
269
270        self.with_writer("upsert_entity", move |conn| {
271            conn.execute(
272                "INSERT OR REPLACE INTO entities \
273                 (id, namespace, kind, entity_type, name, description, properties, tags, \
274                  created_at, updated_at, deleted_at, merged_into, merge_event_id) \
275                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
276                rusqlite::params![
277                    id_str,
278                    namespace,
279                    entity.kind,
280                    entity.entity_type,
281                    entity.name,
282                    entity.description,
283                    properties_str,
284                    tags_str,
285                    entity.created_at,
286                    entity.updated_at,
287                    entity.deleted_at,
288                    merged_into_str,
289                    merge_event_id_str,
290                ],
291            )?;
292            Ok(())
293        })
294        .await
295    }
296
297    async fn upsert_entities(
298        &self,
299        entities: Vec<Entity>,
300    ) -> Result<BatchWriteSummary, StorageError> {
301        let attempted = entities.len() as u64;
302
303        self.with_writer("upsert_entities", move |conn| {
304            conn.execute_batch("BEGIN IMMEDIATE")?;
305            let mut affected = 0u64;
306            let mut failed = 0u64;
307            let mut first_error = String::new();
308
309            for entity in &entities {
310                let id_str = entity.id.to_string();
311                let properties_str = entity
312                    .properties
313                    .as_ref()
314                    .map(|v| serde_json::to_string(v).unwrap_or_default());
315                let tags_str =
316                    serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
317
318                let merged_into_str = entity.merged_into.map(|u| u.to_string());
319                let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
320                match conn.execute(
321                    "INSERT OR REPLACE INTO entities \
322                     (id, namespace, kind, entity_type, name, description, properties, tags, \
323                      created_at, updated_at, deleted_at, merged_into, merge_event_id) \
324                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
325                    rusqlite::params![
326                        id_str,
327                        &entity.namespace,
328                        entity.kind,
329                        entity.entity_type,
330                        entity.name,
331                        entity.description,
332                        properties_str,
333                        tags_str,
334                        entity.created_at,
335                        entity.updated_at,
336                        entity.deleted_at,
337                        merged_into_str,
338                        merge_event_id_str,
339                    ],
340                ) {
341                    Ok(_) => affected += 1,
342                    Err(e) => {
343                        if first_error.is_empty() {
344                            first_error = e.to_string();
345                        }
346                        failed += 1;
347                    }
348                }
349            }
350
351            if let Err(e) = conn.execute_batch("COMMIT") {
352                let _ = conn.execute_batch("ROLLBACK");
353                return Err(e);
354            }
355            Ok(BatchWriteSummary {
356                attempted,
357                affected,
358                failed,
359                first_error,
360            })
361        })
362        .await
363    }
364
365    async fn get_entity(&self, id: Uuid) -> Result<Option<Entity>, StorageError> {
366        let id_str = id.to_string();
367
368        self.with_reader("get_entity", move |conn| {
369            let mut stmt = conn.prepare(
370                "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
371                 created_at, updated_at, deleted_at, merged_into, merge_event_id \
372                 FROM entities WHERE id = ?1 AND deleted_at IS NULL",
373            )?;
374            let mut rows = stmt.query(rusqlite::params![id_str])?;
375            match rows.next()? {
376                Some(row) => Ok(Some(read_entity(row)?)),
377                None => Ok(None),
378            }
379        })
380        .await
381    }
382
383    async fn delete_entity(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
384        let id_str = id.to_string();
385
386        match mode {
387            DeleteMode::Soft => {
388                self.with_writer("delete_entity_soft", move |conn| {
389                    let now = chrono::Utc::now().timestamp_micros();
390                    let deleted = conn.execute(
391                        "UPDATE entities SET deleted_at = ?1 \
392                         WHERE id = ?2 AND deleted_at IS NULL",
393                        rusqlite::params![now, id_str],
394                    )?;
395                    Ok(deleted > 0)
396                })
397                .await
398            }
399            DeleteMode::Hard => {
400                self.with_writer("delete_entity_hard", move |conn| {
401                    let deleted = conn.execute(
402                        "DELETE FROM entities WHERE id = ?1",
403                        rusqlite::params![id_str],
404                    )?;
405                    Ok(deleted > 0)
406                })
407                .await
408            }
409        }
410    }
411
412    async fn query_entities(
413        &self,
414        namespace: &str,
415        filter: EntityFilter,
416        page: PageRequest,
417    ) -> Result<Page<Entity>, StorageError> {
418        let namespace = namespace.to_string();
419
420        self.with_reader("query_entities", move |conn| {
421            let (count_sql, count_params) = build_entity_where(&namespace, &filter);
422            let total: i64 = {
423                let sql = format!("SELECT COUNT(*) FROM entities{}", count_sql);
424                let mut stmt = conn.prepare(&sql)?;
425                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
426                    count_params.iter().map(|p| p.as_ref()).collect();
427                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
428            };
429
430            let (where_sql, mut data_params) = build_entity_where(&namespace, &filter);
431            data_params.push(Box::new(page.limit as i64));
432            data_params.push(Box::new(page.offset as i64));
433
434            let limit_idx = data_params.len() - 1;
435            let offset_idx = data_params.len();
436
437            let data_sql = format!(
438                "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
439                 created_at, updated_at, deleted_at, merged_into, merge_event_id \
440                 FROM entities{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
441                where_sql, limit_idx, offset_idx,
442            );
443
444            let mut stmt = conn.prepare(&data_sql)?;
445            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
446                data_params.iter().map(|p| p.as_ref()).collect();
447            let rows = stmt.query_map(param_refs.as_slice(), read_entity)?;
448
449            let mut items = Vec::new();
450            for row in rows {
451                items.push(row?);
452            }
453
454            Ok(Page {
455                items,
456                total: Some(total as u64),
457            })
458        })
459        .await
460    }
461
462    async fn count_entities(
463        &self,
464        namespace: &str,
465        filter: EntityFilter,
466    ) -> Result<u64, StorageError> {
467        let namespace = namespace.to_string();
468
469        self.with_reader("count_entities", move |conn| {
470            let (where_sql, params) = build_entity_where(&namespace, &filter);
471            let sql = format!("SELECT COUNT(*) FROM entities{}", where_sql);
472            let mut stmt = conn.prepare(&sql)?;
473            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
474                params.iter().map(|p| p.as_ref()).collect();
475            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
476            Ok(count as u64)
477        })
478        .await
479    }
480}
481
482// =============================================================================
483// DDL
484// =============================================================================
485
/// Idempotent schema for the `entities` table and its indexes.
///
/// The statements are concatenated without separators other than `;`, so the
/// constant is a single line of SQL suitable for `execute_batch`.
const ENTITIES_DDL: &str = concat!(
    // Table
    "CREATE TABLE IF NOT EXISTS entities (",
    "id TEXT PRIMARY KEY,",
    "namespace TEXT NOT NULL,",
    "kind TEXT NOT NULL,",
    "entity_type TEXT,",
    "name TEXT NOT NULL,",
    "description TEXT,",
    "properties TEXT,",
    "tags TEXT NOT NULL DEFAULT '[]',",
    "created_at INTEGER NOT NULL,",
    "updated_at INTEGER NOT NULL,",
    "deleted_at INTEGER,",
    "merged_into TEXT,",
    "merge_event_id TEXT",
    ");",
    // Indexes
    "CREATE INDEX IF NOT EXISTS idx_entities_namespace ON entities(namespace);",
    "CREATE INDEX IF NOT EXISTS idx_entities_kind ON entities(namespace, kind);",
    "CREATE INDEX IF NOT EXISTS idx_entities_kind_entity_type ON entities(namespace, kind, entity_type);",
    "CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(namespace, name);",
    "CREATE INDEX IF NOT EXISTS idx_entities_created ON entities(created_at DESC);",
    "CREATE INDEX IF NOT EXISTS idx_entities_merged_into ON entities(namespace, merged_into);",
);
509
510pub(crate) fn ensure_entities_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
511    conn.execute_batch(ENTITIES_DDL)
512}
513
514#[cfg(test)]
515mod tests {
516    use super::*;
517    use crate::pool::PoolConfig;
518
519    fn setup_pool() -> Arc<ConnectionPool> {
520        let config = PoolConfig {
521            path: None,
522            ..PoolConfig::default()
523        };
524        let pool = Arc::new(ConnectionPool::new(config).unwrap());
525        {
526            let writer = pool.writer().unwrap();
527            writer.conn().execute_batch(ENTITIES_DDL).unwrap();
528        }
529        pool
530    }
531
532    fn setup_memory_store() -> SqlEntityStore {
533        SqlEntityStore::new(setup_pool(), false)
534    }
535
536    fn setup_memory_store_ns(_ns: &str) -> SqlEntityStore {
537        SqlEntityStore::new(setup_pool(), false)
538    }
539
540    fn make_entity(namespace: &str, kind: &str, name: &str) -> Entity {
541        let now = chrono::Utc::now().timestamp_micros();
542        Entity {
543            id: Uuid::new_v4(),
544            namespace: namespace.to_string(),
545            kind: kind.to_string(),
546            entity_type: None,
547            name: name.to_string(),
548            description: None,
549            properties: None,
550            tags: Vec::new(),
551            created_at: now,
552            updated_at: now,
553            deleted_at: None,
554            merged_into: None,
555            merge_event_id: None,
556        }
557    }
558
559    #[tokio::test]
560    async fn test_upsert_and_get_entity() {
561        let store = setup_memory_store();
562
563        let entity = make_entity("default", "concept", "LoRA");
564        let id = entity.id;
565
566        store.upsert_entity(entity).await.unwrap();
567
568        let fetched = store.get_entity(id).await.unwrap();
569        assert!(fetched.is_some());
570        let fetched = fetched.unwrap();
571        assert_eq!(fetched.id, id);
572        assert_eq!(fetched.name, "LoRA");
573        assert_eq!(fetched.kind, "concept");
574    }
575
576    #[tokio::test]
577    async fn test_upsert_with_builder() {
578        let store = setup_memory_store();
579
580        let props = serde_json::json!({"domain": "fine-tuning", "type": "technique"});
581        let entity = Entity::new("default", "concept", "QLoRA")
582            .with_description("Quantized LoRA")
583            .with_properties(props.clone())
584            .with_tags(vec!["fine-tuning".to_string(), "quantization".to_string()]);
585        let id = entity.id;
586
587        store.upsert_entity(entity).await.unwrap();
588
589        let fetched = store.get_entity(id).await.unwrap().unwrap();
590        assert_eq!(fetched.description.as_deref(), Some("Quantized LoRA"));
591        assert_eq!(fetched.properties, Some(props));
592        assert_eq!(fetched.tags, vec!["fine-tuning", "quantization"]);
593    }
594
595    #[tokio::test]
596    async fn test_soft_delete() {
597        let store = setup_memory_store();
598
599        let entity = make_entity("default", "concept", "to-delete");
600        let id = entity.id;
601        store.upsert_entity(entity).await.unwrap();
602
603        let deleted = store.delete_entity(id, DeleteMode::Soft).await.unwrap();
604        assert!(deleted);
605
606        let fetched = store.get_entity(id).await.unwrap();
607        assert!(fetched.is_none());
608    }
609
610    #[tokio::test]
611    async fn test_hard_delete() {
612        let store = setup_memory_store();
613
614        let entity = make_entity("default", "concept", "to-hard-delete");
615        let id = entity.id;
616        store.upsert_entity(entity).await.unwrap();
617
618        let deleted = store.delete_entity(id, DeleteMode::Hard).await.unwrap();
619        assert!(deleted);
620
621        let fetched = store.get_entity(id).await.unwrap();
622        assert!(fetched.is_none());
623    }
624
625    #[tokio::test]
626    async fn test_query_entities_basic() {
627        let store = setup_memory_store_ns("ns1");
628
629        for name in &["Alpha", "Beta", "Gamma"] {
630            store
631                .upsert_entity(make_entity("ns1", "concept", name))
632                .await
633                .unwrap();
634        }
635        store
636            .upsert_entity(make_entity("ns1", "document", "Paper1"))
637            .await
638            .unwrap();
639
640        let page = store
641            .query_entities(
642                "ns1",
643                EntityFilter::default(),
644                PageRequest {
645                    offset: 0,
646                    limit: 10,
647                },
648            )
649            .await
650            .unwrap();
651        assert_eq!(page.items.len(), 4);
652        assert_eq!(page.total, Some(4));
653
654        // Filter by kind
655        let concepts = store
656            .query_entities(
657                "ns1",
658                EntityFilter {
659                    kinds: vec!["concept".to_string()],
660                    ..Default::default()
661                },
662                PageRequest::default(),
663            )
664            .await
665            .unwrap();
666        assert_eq!(concepts.items.len(), 3);
667    }
668
669    #[tokio::test]
670    async fn test_query_by_name_prefix() {
671        let store = setup_memory_store_ns("ns1");
672
673        // "Alpha" and "AlphaGo" both start with "Alpha"; "Beta" does not
674        for &name in &["Alpha", "AlphaGo", "Beta"] {
675            store
676                .upsert_entity(make_entity("ns1", "concept", name))
677                .await
678                .unwrap();
679        }
680
681        let result = store
682            .query_entities(
683                "ns1",
684                EntityFilter {
685                    name_prefix: Some("Alpha".to_string()),
686                    ..Default::default()
687                },
688                PageRequest::default(),
689            )
690            .await
691            .unwrap();
692        assert_eq!(result.items.len(), 2);
693        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
694        assert!(names.contains(&"Alpha"), "Alpha not found in {names:?}");
695        assert!(names.contains(&"AlphaGo"), "AlphaGo not found in {names:?}");
696        assert!(!names.contains(&"Beta"));
697    }
698
699    #[tokio::test]
700    async fn test_count_entities() {
701        let store = setup_memory_store_ns("ns1");
702
703        for _ in 0..5 {
704            store
705                .upsert_entity(make_entity("ns1", "concept", "X"))
706                .await
707                .unwrap();
708        }
709
710        let count = store
711            .count_entities("ns1", EntityFilter::default())
712            .await
713            .unwrap();
714        assert_eq!(count, 5);
715
716        // Namespace is the caller's responsibility — querying "ns2" returns 0
717        // because no entities were inserted in that namespace.
718        let count_other = store
719            .count_entities("ns2", EntityFilter::default())
720            .await
721            .unwrap();
722        assert_eq!(count_other, 0);
723    }
724
725    #[tokio::test]
726    async fn test_batch_upsert() {
727        let store = setup_memory_store_ns("batch_ns");
728
729        let entities: Vec<Entity> = (0..10)
730            .map(|i| make_entity("batch_ns", "concept", &format!("entity_{i}")))
731            .collect();
732
733        let summary = store.upsert_entities(entities).await.unwrap();
734        assert_eq!(summary.attempted, 10);
735        assert_eq!(summary.affected, 10);
736        assert_eq!(summary.failed, 0);
737
738        let count = store
739            .count_entities("batch_ns", EntityFilter::default())
740            .await
741            .unwrap();
742        assert_eq!(count, 10);
743    }
744
745    /// One store, two namespaces — each query sees only its own.
746    #[tokio::test]
747    async fn test_namespace_isolation() {
748        let pool = setup_pool();
749        let store = SqlEntityStore::new(Arc::clone(&pool), false);
750
751        store
752            .upsert_entity(make_entity("ns_a", "concept", "EntityA"))
753            .await
754            .unwrap();
755        store
756            .upsert_entity(make_entity("ns_b", "concept", "EntityB"))
757            .await
758            .unwrap();
759
760        // Namespace is the caller's responsibility — pass it in the query.
761        let count_a = store
762            .count_entities("ns_a", EntityFilter::default())
763            .await
764            .unwrap();
765        let count_b = store
766            .count_entities("ns_b", EntityFilter::default())
767            .await
768            .unwrap();
769
770        assert_eq!(count_a, 1);
771        assert_eq!(count_b, 1);
772
773        let page_a = store
774            .query_entities("ns_a", EntityFilter::default(), PageRequest::default())
775            .await
776            .unwrap();
777        assert_eq!(page_a.items[0].name, "EntityA");
778
779        let page_b = store
780            .query_entities("ns_b", EntityFilter::default(), PageRequest::default())
781            .await
782            .unwrap();
783        assert_eq!(page_b.items[0].name, "EntityB");
784    }
785
786    #[tokio::test]
787    async fn test_query_by_tags() {
788        let store = setup_memory_store_ns("tags_ns");
789
790        let mut e1 = make_entity("tags_ns", "concept", "Tagged1");
791        e1.tags = vec!["rust".to_string(), "systems".to_string()];
792        let mut e2 = make_entity("tags_ns", "concept", "Tagged2");
793        e2.tags = vec!["python".to_string(), "ml".to_string()];
794        let mut e3 = make_entity("tags_ns", "concept", "Tagged3");
795        e3.tags = vec!["rust".to_string(), "ml".to_string()];
796
797        store.upsert_entity(e1).await.unwrap();
798        store.upsert_entity(e2).await.unwrap();
799        store.upsert_entity(e3).await.unwrap();
800
801        // Filter by "rust" tag — should match Tagged1 and Tagged3
802        let result = store
803            .query_entities(
804                "tags_ns",
805                EntityFilter {
806                    tags_any: vec!["rust".to_string()],
807                    ..Default::default()
808                },
809                PageRequest::default(),
810            )
811            .await
812            .unwrap();
813        assert_eq!(result.items.len(), 2);
814        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
815        assert!(names.contains(&"Tagged1"));
816        assert!(names.contains(&"Tagged3"));
817        assert!(!names.contains(&"Tagged2"));
818
819        // Filter by "ml" tag — should match Tagged2 and Tagged3
820        let result = store
821            .query_entities(
822                "tags_ns",
823                EntityFilter {
824                    tags_any: vec!["ml".to_string()],
825                    ..Default::default()
826                },
827                PageRequest::default(),
828            )
829            .await
830            .unwrap();
831        assert_eq!(result.items.len(), 2);
832
833        // Filter by both "rust" and "python" (union) — should match all three
834        let result = store
835            .query_entities(
836                "tags_ns",
837                EntityFilter {
838                    tags_any: vec!["rust".to_string(), "python".to_string()],
839                    ..Default::default()
840                },
841                PageRequest::default(),
842            )
843            .await
844            .unwrap();
845        assert_eq!(result.items.len(), 3);
846    }
847
848    #[tokio::test]
849    async fn test_query_by_ids() {
850        let store = setup_memory_store_ns("ns1");
851
852        let e1 = make_entity("ns1", "concept", "E1");
853        let e2 = make_entity("ns1", "concept", "E2");
854        let e3 = make_entity("ns1", "concept", "E3");
855        let ids = vec![e1.id, e3.id];
856
857        store.upsert_entity(e1).await.unwrap();
858        store.upsert_entity(e2).await.unwrap();
859        store.upsert_entity(e3).await.unwrap();
860
861        let result = store
862            .query_entities(
863                "ns1",
864                EntityFilter {
865                    ids,
866                    ..Default::default()
867                },
868                PageRequest::default(),
869            )
870            .await
871            .unwrap();
872        assert_eq!(result.items.len(), 2);
873        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
874        assert!(names.contains(&"E1"));
875        assert!(names.contains(&"E3"));
876        assert!(!names.contains(&"E2"));
877    }
878
879    #[tokio::test]
880    async fn test_entity_type_roundtrip() {
881        let store = setup_memory_store();
882
883        let entity =
884            Entity::new("default", "document", "ResearchPaper").with_entity_type(Some("paper"));
885        let id = entity.id;
886
887        store.upsert_entity(entity).await.unwrap();
888
889        let fetched = store.get_entity(id).await.unwrap().unwrap();
890        assert_eq!(fetched.entity_type, Some("paper".to_string()));
891        assert_eq!(fetched.kind, "document");
892        assert_eq!(fetched.name, "ResearchPaper");
893    }
894
895    #[tokio::test]
896    async fn test_query_by_kind_and_entity_type() {
897        let store = setup_memory_store_ns("et_ns");
898
899        let typed =
900            Entity::new("et_ns", "person", "Researcher").with_entity_type(Some("researcher"));
901        let untyped = make_entity("et_ns", "person", "Generic");
902
903        store.upsert_entity(typed).await.unwrap();
904        store.upsert_entity(untyped).await.unwrap();
905
906        let result = store
907            .query_entities(
908                "et_ns",
909                EntityFilter {
910                    entity_types: vec!["researcher".to_string()],
911                    ..Default::default()
912                },
913                PageRequest::default(),
914            )
915            .await
916            .unwrap();
917
918        assert_eq!(result.items.len(), 1);
919        assert_eq!(result.items[0].name, "Researcher");
920        assert_eq!(result.items[0].entity_type, Some("researcher".to_string()));
921    }
922
923    /// UUID is globally unique (id TEXT PRIMARY KEY). Upserting the same UUID in a
924    /// different namespace overwrites the row (INSERT OR REPLACE). get_entity by ID
925    /// returns whichever namespace currently owns that UUID.
926    #[tokio::test]
927    async fn test_same_id_upsert_replaces_row() {
928        let pool = setup_pool();
929        let store = SqlEntityStore::new(Arc::clone(&pool), false);
930
931        let shared_id = Uuid::new_v4();
932        let now = chrono::Utc::now().timestamp_micros();
933
934        let entity_a = Entity {
935            id: shared_id,
936            namespace: "ns_a".to_string(),
937            kind: "concept".to_string(),
938            entity_type: None,
939            name: "SharedInA".to_string(),
940            description: None,
941            properties: None,
942            tags: Vec::new(),
943            created_at: now,
944            updated_at: now,
945            deleted_at: None,
946            merged_into: None,
947            merge_event_id: None,
948        };
949        store.upsert_entity(entity_a).await.unwrap();
950
951        // At this point the row is in ns_a.
952        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
953        assert_eq!(fetched.namespace, "ns_a");
954        assert_eq!(fetched.name, "SharedInA");
955
956        // Upsert same UUID into ns_b — INSERT OR REPLACE replaces the row.
957        let entity_b = Entity {
958            id: shared_id,
959            namespace: "ns_b".to_string(),
960            kind: "concept".to_string(),
961            entity_type: None,
962            name: "SharedInB".to_string(),
963            description: None,
964            properties: None,
965            tags: Vec::new(),
966            created_at: now,
967            updated_at: now,
968            deleted_at: None,
969            merged_into: None,
970            merge_event_id: None,
971        };
972        store.upsert_entity(entity_b).await.unwrap();
973
974        // Now the row is in ns_b — get_entity returns ns_b regardless of which namespace
975        // you query from (namespace is caller's responsibility).
976        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
977        assert_eq!(fetched.namespace, "ns_b");
978        assert_eq!(fetched.name, "SharedInB");
979
980        // ns_a now has 0 entities; ns_b has 1.
981        let count_a = store
982            .count_entities("ns_a", EntityFilter::default())
983            .await
984            .unwrap();
985        let count_b = store
986            .count_entities("ns_b", EntityFilter::default())
987            .await
988            .unwrap();
989        assert_eq!(count_a, 0);
990        assert_eq!(count_b, 1);
991    }
992}