Skip to main content

khive_db/stores/
entity.rs

1//! SQL-backed `EntityStore` implementation.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::entity::{Entity, EntityFilter};
9use khive_storage::error::StorageError;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::EntityStore;
12use khive_storage::StorageCapability;
13use khive_types::EntityKind;
14
15use crate::error::SqliteError;
16use crate::pool::ConnectionPool;
17
18fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
19    StorageError::driver(StorageCapability::Entities, op, e)
20}
21
22fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
23    StorageError::driver(StorageCapability::Entities, op, e)
24}
25
26/// An EntityStore backed by SQLite. Namespace is the caller's responsibility.
27///
28/// UUID is globally unique — get/delete by ID alone. Query/count use the
29/// namespace parameter as passed. The store is just a pool + is_file_backed.
30pub struct SqlEntityStore {
31    pool: Arc<ConnectionPool>,
32    is_file_backed: bool,
33}
34
35impl SqlEntityStore {
36    /// Create a new store.
37    pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
38        Self {
39            pool,
40            is_file_backed,
41        }
42    }
43
44    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
45        let config = self.pool.config();
46        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
47            operation: "entity_reader".into(),
48            message: "in-memory databases do not support standalone connections".into(),
49        })?;
50
51        let conn = rusqlite::Connection::open_with_flags(
52            path,
53            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
54                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
55                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
56        )
57        .map_err(|e| map_err(e, "open_entity_reader"))?;
58
59        conn.busy_timeout(config.busy_timeout)
60            .map_err(|e| map_err(e, "open_entity_reader"))?;
61        conn.pragma_update(None, "foreign_keys", "ON")
62            .map_err(|e| map_err(e, "open_entity_reader"))?;
63        conn.pragma_update(None, "synchronous", "NORMAL")
64            .map_err(|e| map_err(e, "open_entity_reader"))?;
65
66        Ok(conn)
67    }
68
69    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
70    where
71        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
72        R: Send + 'static,
73    {
74        let pool = Arc::clone(&self.pool);
75        tokio::task::spawn_blocking(move || {
76            let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
77            f(guard.conn()).map_err(|e| map_err(e, op))
78        })
79        .await
80        .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
81    }
82
83    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
84    where
85        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
86        R: Send + 'static,
87    {
88        if self.is_file_backed {
89            let conn = self.open_standalone_reader()?;
90            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
91                .await
92                .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
93        } else {
94            let pool = Arc::clone(&self.pool);
95            tokio::task::spawn_blocking(move || {
96                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
97                f(guard.conn()).map_err(|e| map_err(e, op))
98            })
99            .await
100            .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
101        }
102    }
103}
104
105// =============================================================================
106// Helpers
107// =============================================================================
108
109fn read_entity(row: &rusqlite::Row<'_>) -> Result<Entity, rusqlite::Error> {
110    let id_str: String = row.get(0)?;
111    let namespace: String = row.get(1)?;
112    let kind: String = row.get(2)?;
113    let name: String = row.get(3)?;
114    let description: Option<String> = row.get(4)?;
115    let properties_str: Option<String> = row.get(5)?;
116    let tags_str: String = row.get(6)?;
117    let created_at: i64 = row.get(7)?;
118    let updated_at: i64 = row.get(8)?;
119    let deleted_at: Option<i64> = row.get(9)?;
120
121    let id = parse_uuid(&id_str)?;
122
123    let parsed_kind = kind.parse::<EntityKind>().map_err(|e| {
124        rusqlite::Error::FromSqlConversionFailure(
125            2,
126            rusqlite::types::Type::Text,
127            Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
128        )
129    })?;
130
131    let properties = properties_str
132        .map(|s| {
133            serde_json::from_str(&s).map_err(|e| {
134                rusqlite::Error::FromSqlConversionFailure(
135                    5,
136                    rusqlite::types::Type::Text,
137                    Box::new(e),
138                )
139            })
140        })
141        .transpose()?;
142
143    let tags: Vec<String> = serde_json::from_str(&tags_str).map_err(|e| {
144        rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(e))
145    })?;
146
147    Ok(Entity {
148        id,
149        namespace,
150        kind: parsed_kind,
151        name,
152        description,
153        properties,
154        tags,
155        created_at,
156        updated_at,
157        deleted_at,
158    })
159}
160
161fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
162    Uuid::parse_str(s).map_err(|e| {
163        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
164    })
165}
166
167fn build_entity_where(
168    namespace: &str,
169    filter: &EntityFilter,
170) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
171    let mut conditions: Vec<String> = vec![
172        "namespace = ?1".to_string(),
173        "deleted_at IS NULL".to_string(),
174    ];
175    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
176
177    if !filter.ids.is_empty() {
178        let placeholders: Vec<String> = filter
179            .ids
180            .iter()
181            .map(|id| {
182                params.push(Box::new(id.to_string()));
183                format!("?{}", params.len())
184            })
185            .collect();
186        conditions.push(format!("id IN ({})", placeholders.join(", ")));
187    }
188
189    if !filter.kinds.is_empty() {
190        let placeholders: Vec<String> = filter
191            .kinds
192            .iter()
193            .map(|k| {
194                params.push(Box::new(k.to_string()));
195                format!("?{}", params.len())
196            })
197            .collect();
198        conditions.push(format!("kind IN ({})", placeholders.join(", ")));
199    }
200
201    if let Some(ref prefix) = filter.name_prefix {
202        params.push(Box::new(format!("{}%", prefix)));
203        conditions.push(format!("name LIKE ?{}", params.len()));
204    }
205
206    if !filter.tags_any.is_empty() {
207        let placeholders: Vec<String> = filter
208            .tags_any
209            .iter()
210            .map(|t| {
211                params.push(Box::new(t.clone()));
212                format!("?{}", params.len())
213            })
214            .collect();
215        conditions.push(format!(
216            "EXISTS (SELECT 1 FROM json_each(tags) WHERE json_each.value IN ({}))",
217            placeholders.join(", ")
218        ));
219    }
220
221    let clause = format!(" WHERE {}", conditions.join(" AND "));
222    (clause, params)
223}
224
225// =============================================================================
226// EntityStore implementation
227// =============================================================================
228
229#[async_trait]
230impl EntityStore for SqlEntityStore {
231    async fn upsert_entity(&self, entity: Entity) -> Result<(), StorageError> {
232        let namespace = entity.namespace.clone();
233        let id_str = entity.id.to_string();
234        let properties_str = entity
235            .properties
236            .as_ref()
237            .map(|v| serde_json::to_string(v).unwrap_or_default());
238        let tags_str = serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
239
240        self.with_writer("upsert_entity", move |conn| {
241            conn.execute(
242                "INSERT OR REPLACE INTO entities \
243                 (id, namespace, kind, name, description, properties, tags, \
244                  created_at, updated_at, deleted_at) \
245                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
246                rusqlite::params![
247                    id_str,
248                    namespace,
249                    entity.kind.to_string(),
250                    entity.name,
251                    entity.description,
252                    properties_str,
253                    tags_str,
254                    entity.created_at,
255                    entity.updated_at,
256                    entity.deleted_at,
257                ],
258            )?;
259            Ok(())
260        })
261        .await
262    }
263
264    async fn upsert_entities(
265        &self,
266        entities: Vec<Entity>,
267    ) -> Result<BatchWriteSummary, StorageError> {
268        let attempted = entities.len() as u64;
269
270        self.with_writer("upsert_entities", move |conn| {
271            conn.execute_batch("BEGIN IMMEDIATE")?;
272            let mut affected = 0u64;
273            let mut failed = 0u64;
274            let mut first_error = String::new();
275
276            for entity in &entities {
277                let id_str = entity.id.to_string();
278                let properties_str = entity
279                    .properties
280                    .as_ref()
281                    .map(|v| serde_json::to_string(v).unwrap_or_default());
282                let tags_str =
283                    serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
284
285                match conn.execute(
286                    "INSERT OR REPLACE INTO entities \
287                     (id, namespace, kind, name, description, properties, tags, \
288                      created_at, updated_at, deleted_at) \
289                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
290                    rusqlite::params![
291                        id_str,
292                        &entity.namespace,
293                        entity.kind.to_string(),
294                        entity.name,
295                        entity.description,
296                        properties_str,
297                        tags_str,
298                        entity.created_at,
299                        entity.updated_at,
300                        entity.deleted_at,
301                    ],
302                ) {
303                    Ok(_) => affected += 1,
304                    Err(e) => {
305                        if first_error.is_empty() {
306                            first_error = e.to_string();
307                        }
308                        failed += 1;
309                    }
310                }
311            }
312
313            if let Err(e) = conn.execute_batch("COMMIT") {
314                let _ = conn.execute_batch("ROLLBACK");
315                return Err(e);
316            }
317            Ok(BatchWriteSummary {
318                attempted,
319                affected,
320                failed,
321                first_error,
322            })
323        })
324        .await
325    }
326
327    async fn get_entity(&self, id: Uuid) -> Result<Option<Entity>, StorageError> {
328        let id_str = id.to_string();
329
330        self.with_reader("get_entity", move |conn| {
331            let mut stmt = conn.prepare(
332                "SELECT id, namespace, kind, name, description, properties, tags, \
333                 created_at, updated_at, deleted_at \
334                 FROM entities WHERE id = ?1 AND deleted_at IS NULL",
335            )?;
336            let mut rows = stmt.query(rusqlite::params![id_str])?;
337            match rows.next()? {
338                Some(row) => Ok(Some(read_entity(row)?)),
339                None => Ok(None),
340            }
341        })
342        .await
343    }
344
345    async fn delete_entity(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
346        let id_str = id.to_string();
347
348        match mode {
349            DeleteMode::Soft => {
350                self.with_writer("delete_entity_soft", move |conn| {
351                    let now = chrono::Utc::now().timestamp_micros();
352                    let deleted = conn.execute(
353                        "UPDATE entities SET deleted_at = ?1 \
354                         WHERE id = ?2 AND deleted_at IS NULL",
355                        rusqlite::params![now, id_str],
356                    )?;
357                    Ok(deleted > 0)
358                })
359                .await
360            }
361            DeleteMode::Hard => {
362                self.with_writer("delete_entity_hard", move |conn| {
363                    let deleted = conn.execute(
364                        "DELETE FROM entities WHERE id = ?1",
365                        rusqlite::params![id_str],
366                    )?;
367                    Ok(deleted > 0)
368                })
369                .await
370            }
371        }
372    }
373
374    async fn query_entities(
375        &self,
376        namespace: &str,
377        filter: EntityFilter,
378        page: PageRequest,
379    ) -> Result<Page<Entity>, StorageError> {
380        let namespace = namespace.to_string();
381
382        self.with_reader("query_entities", move |conn| {
383            let (count_sql, count_params) = build_entity_where(&namespace, &filter);
384            let total: i64 = {
385                let sql = format!("SELECT COUNT(*) FROM entities{}", count_sql);
386                let mut stmt = conn.prepare(&sql)?;
387                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
388                    count_params.iter().map(|p| p.as_ref()).collect();
389                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
390            };
391
392            let (where_sql, mut data_params) = build_entity_where(&namespace, &filter);
393            data_params.push(Box::new(page.limit as i64));
394            data_params.push(Box::new(page.offset as i64));
395
396            let limit_idx = data_params.len() - 1;
397            let offset_idx = data_params.len();
398
399            let data_sql = format!(
400                "SELECT id, namespace, kind, name, description, properties, tags, \
401                 created_at, updated_at, deleted_at \
402                 FROM entities{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
403                where_sql, limit_idx, offset_idx,
404            );
405
406            let mut stmt = conn.prepare(&data_sql)?;
407            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
408                data_params.iter().map(|p| p.as_ref()).collect();
409            let rows = stmt.query_map(param_refs.as_slice(), read_entity)?;
410
411            let mut items = Vec::new();
412            for row in rows {
413                items.push(row?);
414            }
415
416            Ok(Page {
417                items,
418                total: Some(total as u64),
419            })
420        })
421        .await
422    }
423
424    async fn count_entities(
425        &self,
426        namespace: &str,
427        filter: EntityFilter,
428    ) -> Result<u64, StorageError> {
429        let namespace = namespace.to_string();
430
431        self.with_reader("count_entities", move |conn| {
432            let (where_sql, params) = build_entity_where(&namespace, &filter);
433            let sql = format!("SELECT COUNT(*) FROM entities{}", where_sql);
434            let mut stmt = conn.prepare(&sql)?;
435            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
436                params.iter().map(|p| p.as_ref()).collect();
437            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
438            Ok(count as u64)
439        })
440        .await
441    }
442}
443
444// =============================================================================
445// DDL
446// =============================================================================
447
/// Schema for the `entities` table and its indexes, as one batch of statements.
///
/// `properties` and `tags` hold JSON text, with `tags` defaulting to an empty
/// array. `deleted_at` is NULL for live rows. The composite indexes lead with
/// `namespace`, matching the namespace-scoped filters used by the queries.
const ENTITIES_DDL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS entities (",
    "id TEXT PRIMARY KEY,",
    "namespace TEXT NOT NULL,",
    "kind TEXT NOT NULL,",
    "name TEXT NOT NULL,",
    "description TEXT,",
    "properties TEXT,",
    "tags TEXT NOT NULL DEFAULT '[]',",
    "created_at INTEGER NOT NULL,",
    "updated_at INTEGER NOT NULL,",
    "deleted_at INTEGER",
    ");",
    "CREATE INDEX IF NOT EXISTS idx_entities_namespace ON entities(namespace);",
    "CREATE INDEX IF NOT EXISTS idx_entities_kind ON entities(namespace, kind);",
    "CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(namespace, name);",
    "CREATE INDEX IF NOT EXISTS idx_entities_created ON entities(created_at DESC);",
);
466
467pub(crate) fn ensure_entities_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
468    conn.execute_batch(ENTITIES_DDL)
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// In-memory pool with the `entities` schema applied through its writer.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(ENTITIES_DDL).unwrap();
        }
        pool
    }

    /// Store over a fresh in-memory pool. The store is namespace-agnostic, so
    /// tests choose namespaces per entity and per query.
    fn setup_memory_store() -> SqlEntityStore {
        SqlEntityStore::new(setup_pool(), false)
    }

    fn make_entity(namespace: &str, kind: EntityKind, name: &str) -> Entity {
        let now = chrono::Utc::now().timestamp_micros();
        Entity {
            id: Uuid::new_v4(),
            namespace: namespace.to_string(),
            kind,
            name: name.to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
        }
    }

    /// Entity names in result order.
    fn item_names(items: &[Entity]) -> Vec<&str> {
        items.iter().map(|e| e.name.as_str()).collect()
    }

    #[tokio::test]
    async fn test_upsert_and_get_entity() {
        let store = setup_memory_store();

        let entity = make_entity("default", EntityKind::Concept, "LoRA");
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.name, "LoRA");
        assert_eq!(fetched.kind, EntityKind::Concept);
    }

    #[tokio::test]
    async fn test_upsert_with_builder() {
        let store = setup_memory_store();

        let props = serde_json::json!({"domain": "fine-tuning", "type": "technique"});
        let entity = Entity::new("default", EntityKind::Concept, "QLoRA")
            .with_description("Quantized LoRA")
            .with_properties(props.clone())
            .with_tags(vec!["fine-tuning".to_string(), "quantization".to_string()]);
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap().unwrap();
        assert_eq!(fetched.description.as_deref(), Some("Quantized LoRA"));
        assert_eq!(fetched.properties, Some(props));
        assert_eq!(fetched.tags, vec!["fine-tuning", "quantization"]);
    }

    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", EntityKind::Concept, "to-delete");
        let id = entity.id;
        store.upsert_entity(entity).await.unwrap();

        let deleted = store.delete_entity(id, DeleteMode::Soft).await.unwrap();
        assert!(deleted);

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_none());

        // Soft-deleted rows are excluded from counts as well.
        let count = store
            .count_entities("default", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", EntityKind::Concept, "to-hard-delete");
        let id = entity.id;
        store.upsert_entity(entity).await.unwrap();

        let deleted = store.delete_entity(id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_repeated_delete_reports_no_change() {
        let store = setup_memory_store();

        let entity = make_entity("default", EntityKind::Concept, "once");
        let id = entity.id;
        store.upsert_entity(entity).await.unwrap();

        assert!(store.delete_entity(id, DeleteMode::Soft).await.unwrap());
        // Already soft-deleted: no live row matches.
        assert!(!store.delete_entity(id, DeleteMode::Soft).await.unwrap());
        // A hard delete still removes the soft-deleted row, exactly once.
        assert!(store.delete_entity(id, DeleteMode::Hard).await.unwrap());
        assert!(!store.delete_entity(id, DeleteMode::Hard).await.unwrap());
    }

    #[tokio::test]
    async fn test_query_entities_basic() {
        let store = setup_memory_store();

        for name in &["Alpha", "Beta", "Gamma"] {
            store
                .upsert_entity(make_entity("ns1", EntityKind::Concept, name))
                .await
                .unwrap();
        }
        store
            .upsert_entity(make_entity("ns1", EntityKind::Document, "Paper1"))
            .await
            .unwrap();

        let page = store
            .query_entities(
                "ns1",
                EntityFilter::default(),
                PageRequest {
                    offset: 0,
                    limit: 10,
                },
            )
            .await
            .unwrap();
        assert_eq!(page.items.len(), 4);
        assert_eq!(page.total, Some(4));

        // Filter by kind
        let concepts = store
            .query_entities(
                "ns1",
                EntityFilter {
                    kinds: vec![EntityKind::Concept],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(concepts.items.len(), 3);
    }

    #[tokio::test]
    async fn test_query_pagination_offset_and_limit() {
        let store = setup_memory_store();

        // Distinct, increasing created_at values give a fully determined order.
        let base = chrono::Utc::now().timestamp_micros();
        for i in 0..5 {
            let mut entity = make_entity("page_ns", EntityKind::Concept, &format!("item_{i}"));
            entity.created_at = base + i;
            entity.updated_at = base + i;
            store.upsert_entity(entity).await.unwrap();
        }

        let first = store
            .query_entities(
                "page_ns",
                EntityFilter::default(),
                PageRequest {
                    offset: 0,
                    limit: 2,
                },
            )
            .await
            .unwrap();
        assert_eq!(item_names(&first.items), vec!["item_4", "item_3"]);
        assert_eq!(first.total, Some(5));

        let last = store
            .query_entities(
                "page_ns",
                EntityFilter::default(),
                PageRequest {
                    offset: 4,
                    limit: 2,
                },
            )
            .await
            .unwrap();
        assert_eq!(item_names(&last.items), vec!["item_0"]);
        assert_eq!(last.total, Some(5));
    }

    #[tokio::test]
    async fn test_query_by_name_prefix() {
        let store = setup_memory_store();

        // "Alpha" and "AlphaGo" both start with "Alpha"; "Beta" does not
        for &name in &["Alpha", "AlphaGo", "Beta"] {
            store
                .upsert_entity(make_entity("ns1", EntityKind::Concept, name))
                .await
                .unwrap();
        }

        let result = store
            .query_entities(
                "ns1",
                EntityFilter {
                    name_prefix: Some("Alpha".to_string()),
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names = item_names(&result.items);
        assert!(names.contains(&"Alpha"), "Alpha not found in {names:?}");
        assert!(names.contains(&"AlphaGo"), "AlphaGo not found in {names:?}");
        assert!(!names.contains(&"Beta"));
    }

    #[tokio::test]
    async fn test_count_entities() {
        let store = setup_memory_store();

        for _ in 0..5 {
            store
                .upsert_entity(make_entity("ns1", EntityKind::Concept, "X"))
                .await
                .unwrap();
        }

        let count = store
            .count_entities("ns1", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 5);

        // Namespace is the caller's responsibility — querying "ns2" returns 0
        // because no entities were inserted in that namespace.
        let count_other = store
            .count_entities("ns2", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_other, 0);
    }

    #[tokio::test]
    async fn test_batch_upsert() {
        let store = setup_memory_store();

        let entities: Vec<Entity> = (0..10)
            .map(|i| make_entity("batch_ns", EntityKind::Concept, &format!("entity_{i}")))
            .collect();

        let summary = store.upsert_entities(entities).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        let count = store
            .count_entities("batch_ns", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 10);
    }

    #[tokio::test]
    async fn test_batch_upsert_empty() {
        let store = setup_memory_store();

        let summary = store.upsert_entities(Vec::new()).await.unwrap();
        assert_eq!(summary.attempted, 0);
        assert_eq!(summary.affected, 0);
        assert_eq!(summary.failed, 0);
        assert!(summary.first_error.is_empty());
    }

    /// One store, two namespaces — each query sees only its own.
    #[tokio::test]
    async fn test_namespace_isolation() {
        let store = setup_memory_store();

        store
            .upsert_entity(make_entity("ns_a", EntityKind::Concept, "EntityA"))
            .await
            .unwrap();
        store
            .upsert_entity(make_entity("ns_b", EntityKind::Concept, "EntityB"))
            .await
            .unwrap();

        // Namespace is the caller's responsibility — pass it in the query.
        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();

        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);

        let page_a = store
            .query_entities("ns_a", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(item_names(&page_a.items), vec!["EntityA"]);

        let page_b = store
            .query_entities("ns_b", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(item_names(&page_b.items), vec!["EntityB"]);
    }

    #[tokio::test]
    async fn test_query_by_tags() {
        let store = setup_memory_store();

        let mut e1 = make_entity("tags_ns", EntityKind::Concept, "Tagged1");
        e1.tags = vec!["rust".to_string(), "systems".to_string()];
        let mut e2 = make_entity("tags_ns", EntityKind::Concept, "Tagged2");
        e2.tags = vec!["python".to_string(), "ml".to_string()];
        let mut e3 = make_entity("tags_ns", EntityKind::Concept, "Tagged3");
        e3.tags = vec!["rust".to_string(), "ml".to_string()];

        store.upsert_entity(e1).await.unwrap();
        store.upsert_entity(e2).await.unwrap();
        store.upsert_entity(e3).await.unwrap();

        // Filter by "rust" tag — should match Tagged1 and Tagged3
        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["rust".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names = item_names(&result.items);
        assert!(names.contains(&"Tagged1"));
        assert!(names.contains(&"Tagged3"));
        assert!(!names.contains(&"Tagged2"));

        // Filter by "ml" tag — should match Tagged2 and Tagged3
        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["ml".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);

        // Filter by both "rust" and "python" (union) — should match all three
        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["rust".to_string(), "python".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 3);
    }

    #[tokio::test]
    async fn test_query_by_ids() {
        let store = setup_memory_store();

        let e1 = make_entity("ns1", EntityKind::Concept, "E1");
        let e2 = make_entity("ns1", EntityKind::Concept, "E2");
        let e3 = make_entity("ns1", EntityKind::Concept, "E3");
        let ids = vec![e1.id, e3.id];

        store.upsert_entity(e1).await.unwrap();
        store.upsert_entity(e2).await.unwrap();
        store.upsert_entity(e3).await.unwrap();

        let result = store
            .query_entities(
                "ns1",
                EntityFilter {
                    ids,
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names = item_names(&result.items);
        assert!(names.contains(&"E1"));
        assert!(names.contains(&"E3"));
        assert!(!names.contains(&"E2"));
    }

    /// UUID is globally unique (id TEXT PRIMARY KEY). Upserting the same UUID in a
    /// different namespace overwrites the row (INSERT OR REPLACE). get_entity by ID
    /// returns whichever namespace currently owns that UUID.
    #[tokio::test]
    async fn test_same_id_upsert_replaces_row() {
        let store = setup_memory_store();

        let shared_id = Uuid::new_v4();

        let mut entity_a = make_entity("ns_a", EntityKind::Concept, "SharedInA");
        entity_a.id = shared_id;
        store.upsert_entity(entity_a).await.unwrap();

        // At this point the row is in ns_a.
        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(fetched.namespace, "ns_a");
        assert_eq!(fetched.name, "SharedInA");

        // Upsert same UUID into ns_b — INSERT OR REPLACE replaces the row.
        let mut entity_b = make_entity("ns_b", EntityKind::Concept, "SharedInB");
        entity_b.id = shared_id;
        store.upsert_entity(entity_b).await.unwrap();

        // Now the row is in ns_b — get_entity returns ns_b regardless of which namespace
        // you query from (namespace is caller's responsibility).
        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(fetched.namespace, "ns_b");
        assert_eq!(fetched.name, "SharedInB");

        // ns_a now has 0 entities; ns_b has 1.
        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_a, 0);
        assert_eq!(count_b, 1);
    }
}