Skip to main content

khive_db/stores/
entity.rs

1//! SQL-backed `EntityStore` implementation.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::entity::{Entity, EntityFilter};
9use khive_storage::error::StorageError;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::EntityStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18    StorageError::driver(StorageCapability::Entities, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22    StorageError::driver(StorageCapability::Entities, op, e)
23}
24
25/// An EntityStore backed by SQLite. Namespace is the caller's responsibility.
26///
27/// UUID is globally unique — get/delete by ID alone. Query/count use the
28/// namespace parameter as passed. The store is just a pool + is_file_backed.
29pub struct SqlEntityStore {
30    pool: Arc<ConnectionPool>,
31    is_file_backed: bool,
32}
33
34impl SqlEntityStore {
35    /// Create a new store.
36    pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37        Self {
38            pool,
39            is_file_backed,
40        }
41    }
42
43    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44        let config = self.pool.config();
45        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46            operation: "entity_reader".into(),
47            message: "in-memory databases do not support standalone connections".into(),
48        })?;
49
50        let conn = rusqlite::Connection::open_with_flags(
51            path,
52            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55        )
56        .map_err(|e| map_err(e, "open_entity_reader"))?;
57
58        conn.busy_timeout(config.busy_timeout)
59            .map_err(|e| map_err(e, "open_entity_reader"))?;
60        conn.pragma_update(None, "foreign_keys", "ON")
61            .map_err(|e| map_err(e, "open_entity_reader"))?;
62        conn.pragma_update(None, "synchronous", "NORMAL")
63            .map_err(|e| map_err(e, "open_entity_reader"))?;
64
65        Ok(conn)
66    }
67
68    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
69    where
70        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
71        R: Send + 'static,
72    {
73        let pool = Arc::clone(&self.pool);
74        tokio::task::spawn_blocking(move || {
75            let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
76            f(guard.conn()).map_err(|e| map_err(e, op))
77        })
78        .await
79        .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
80    }
81
82    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
83    where
84        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
85        R: Send + 'static,
86    {
87        if self.is_file_backed {
88            let conn = self.open_standalone_reader()?;
89            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
90                .await
91                .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
92        } else {
93            let pool = Arc::clone(&self.pool);
94            tokio::task::spawn_blocking(move || {
95                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
96                f(guard.conn()).map_err(|e| map_err(e, op))
97            })
98            .await
99            .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
100        }
101    }
102}
103
104// =============================================================================
105// Helpers
106// =============================================================================
107
108fn read_entity(row: &rusqlite::Row<'_>) -> Result<Entity, rusqlite::Error> {
109    let id_str: String = row.get(0)?;
110    let namespace: String = row.get(1)?;
111    let kind: String = row.get(2)?;
112    let entity_type: Option<String> = row.get(3)?;
113    let name: String = row.get(4)?;
114    let description: Option<String> = row.get(5)?;
115    let properties_str: Option<String> = row.get(6)?;
116    let tags_str: String = row.get(7)?;
117    let created_at: i64 = row.get(8)?;
118    let updated_at: i64 = row.get(9)?;
119    let deleted_at: Option<i64> = row.get(10)?;
120    let merged_into_str: Option<String> = row.get(11)?;
121    let merge_event_id_str: Option<String> = row.get(12)?;
122
123    let id = parse_uuid(&id_str)?;
124
125    let properties = properties_str
126        .map(|s| {
127            serde_json::from_str(&s).map_err(|e| {
128                rusqlite::Error::FromSqlConversionFailure(
129                    6,
130                    rusqlite::types::Type::Text,
131                    Box::new(e),
132                )
133            })
134        })
135        .transpose()?;
136
137    let tags: Vec<String> = serde_json::from_str(&tags_str).map_err(|e| {
138        rusqlite::Error::FromSqlConversionFailure(7, rusqlite::types::Type::Text, Box::new(e))
139    })?;
140
141    let merged_into = merged_into_str
142        .as_deref()
143        .map(Uuid::parse_str)
144        .transpose()
145        .map_err(|e| {
146            rusqlite::Error::FromSqlConversionFailure(10, rusqlite::types::Type::Text, Box::new(e))
147        })?;
148
149    let merge_event_id = merge_event_id_str
150        .as_deref()
151        .map(Uuid::parse_str)
152        .transpose()
153        .map_err(|e| {
154            rusqlite::Error::FromSqlConversionFailure(11, rusqlite::types::Type::Text, Box::new(e))
155        })?;
156
157    Ok(Entity {
158        id,
159        namespace,
160        kind,
161        entity_type,
162        name,
163        description,
164        properties,
165        tags,
166        created_at,
167        updated_at,
168        deleted_at,
169        merged_into,
170        merge_event_id,
171    })
172}
173
174fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
175    Uuid::parse_str(s).map_err(|e| {
176        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
177    })
178}
179
180fn build_entity_where(
181    namespace: &str,
182    filter: &EntityFilter,
183) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
184    let mut conditions: Vec<String> = vec![
185        "namespace = ?1".to_string(),
186        "deleted_at IS NULL".to_string(),
187    ];
188    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
189
190    if !filter.ids.is_empty() {
191        let placeholders: Vec<String> = filter
192            .ids
193            .iter()
194            .map(|id| {
195                params.push(Box::new(id.to_string()));
196                format!("?{}", params.len())
197            })
198            .collect();
199        conditions.push(format!("id IN ({})", placeholders.join(", ")));
200    }
201
202    if !filter.kinds.is_empty() {
203        let placeholders: Vec<String> = filter
204            .kinds
205            .iter()
206            .map(|k| {
207                params.push(Box::new(k.clone()));
208                format!("?{}", params.len())
209            })
210            .collect();
211        conditions.push(format!("kind IN ({})", placeholders.join(", ")));
212    }
213
214    if !filter.entity_types.is_empty() {
215        let placeholders: Vec<String> = filter
216            .entity_types
217            .iter()
218            .map(|t| {
219                params.push(Box::new(t.clone()));
220                format!("?{}", params.len())
221            })
222            .collect();
223        conditions.push(format!("entity_type IN ({})", placeholders.join(", ")));
224    }
225
226    if let Some(ref prefix) = filter.name_prefix {
227        params.push(Box::new(format!("{}%", prefix)));
228        conditions.push(format!("name LIKE ?{}", params.len()));
229    }
230
231    if !filter.tags_any.is_empty() {
232        let placeholders: Vec<String> = filter
233            .tags_any
234            .iter()
235            .map(|t| {
236                params.push(Box::new(t.clone()));
237                format!("?{}", params.len())
238            })
239            .collect();
240        conditions.push(format!(
241            "EXISTS (SELECT 1 FROM json_each(tags) WHERE json_each.value IN ({}))",
242            placeholders.join(", ")
243        ));
244    }
245
246    let clause = format!(" WHERE {}", conditions.join(" AND "));
247    (clause, params)
248}
249
250// =============================================================================
251// EntityStore implementation
252// =============================================================================
253
254#[async_trait]
255impl EntityStore for SqlEntityStore {
256    async fn upsert_entity(&self, entity: Entity) -> Result<(), StorageError> {
257        let namespace = entity.namespace.clone();
258        let id_str = entity.id.to_string();
259        let properties_str = entity
260            .properties
261            .as_ref()
262            .map(|v| serde_json::to_string(v).unwrap_or_default());
263        let tags_str = serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
264
265        let merged_into_str = entity.merged_into.map(|u| u.to_string());
266        let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
267
268        self.with_writer("upsert_entity", move |conn| {
269            conn.execute(
270                "INSERT OR REPLACE INTO entities \
271                 (id, namespace, kind, entity_type, name, description, properties, tags, \
272                  created_at, updated_at, deleted_at, merged_into, merge_event_id) \
273                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
274                rusqlite::params![
275                    id_str,
276                    namespace,
277                    entity.kind,
278                    entity.entity_type,
279                    entity.name,
280                    entity.description,
281                    properties_str,
282                    tags_str,
283                    entity.created_at,
284                    entity.updated_at,
285                    entity.deleted_at,
286                    merged_into_str,
287                    merge_event_id_str,
288                ],
289            )?;
290            Ok(())
291        })
292        .await
293    }
294
295    async fn upsert_entities(
296        &self,
297        entities: Vec<Entity>,
298    ) -> Result<BatchWriteSummary, StorageError> {
299        let attempted = entities.len() as u64;
300
301        self.with_writer("upsert_entities", move |conn| {
302            conn.execute_batch("BEGIN IMMEDIATE")?;
303            let mut affected = 0u64;
304            let mut failed = 0u64;
305            let mut first_error = String::new();
306
307            for entity in &entities {
308                let id_str = entity.id.to_string();
309                let properties_str = entity
310                    .properties
311                    .as_ref()
312                    .map(|v| serde_json::to_string(v).unwrap_or_default());
313                let tags_str =
314                    serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
315
316                let merged_into_str = entity.merged_into.map(|u| u.to_string());
317                let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
318                match conn.execute(
319                    "INSERT OR REPLACE INTO entities \
320                     (id, namespace, kind, entity_type, name, description, properties, tags, \
321                      created_at, updated_at, deleted_at, merged_into, merge_event_id) \
322                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
323                    rusqlite::params![
324                        id_str,
325                        &entity.namespace,
326                        entity.kind,
327                        entity.entity_type,
328                        entity.name,
329                        entity.description,
330                        properties_str,
331                        tags_str,
332                        entity.created_at,
333                        entity.updated_at,
334                        entity.deleted_at,
335                        merged_into_str,
336                        merge_event_id_str,
337                    ],
338                ) {
339                    Ok(_) => affected += 1,
340                    Err(e) => {
341                        if first_error.is_empty() {
342                            first_error = e.to_string();
343                        }
344                        failed += 1;
345                    }
346                }
347            }
348
349            if let Err(e) = conn.execute_batch("COMMIT") {
350                let _ = conn.execute_batch("ROLLBACK");
351                return Err(e);
352            }
353            Ok(BatchWriteSummary {
354                attempted,
355                affected,
356                failed,
357                first_error,
358            })
359        })
360        .await
361    }
362
363    async fn get_entity(&self, id: Uuid) -> Result<Option<Entity>, StorageError> {
364        let id_str = id.to_string();
365
366        self.with_reader("get_entity", move |conn| {
367            let mut stmt = conn.prepare(
368                "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
369                 created_at, updated_at, deleted_at, merged_into, merge_event_id \
370                 FROM entities WHERE id = ?1 AND deleted_at IS NULL",
371            )?;
372            let mut rows = stmt.query(rusqlite::params![id_str])?;
373            match rows.next()? {
374                Some(row) => Ok(Some(read_entity(row)?)),
375                None => Ok(None),
376            }
377        })
378        .await
379    }
380
381    async fn delete_entity(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
382        let id_str = id.to_string();
383
384        match mode {
385            DeleteMode::Soft => {
386                self.with_writer("delete_entity_soft", move |conn| {
387                    let now = chrono::Utc::now().timestamp_micros();
388                    let deleted = conn.execute(
389                        "UPDATE entities SET deleted_at = ?1 \
390                         WHERE id = ?2 AND deleted_at IS NULL",
391                        rusqlite::params![now, id_str],
392                    )?;
393                    Ok(deleted > 0)
394                })
395                .await
396            }
397            DeleteMode::Hard => {
398                self.with_writer("delete_entity_hard", move |conn| {
399                    let deleted = conn.execute(
400                        "DELETE FROM entities WHERE id = ?1",
401                        rusqlite::params![id_str],
402                    )?;
403                    Ok(deleted > 0)
404                })
405                .await
406            }
407        }
408    }
409
410    async fn query_entities(
411        &self,
412        namespace: &str,
413        filter: EntityFilter,
414        page: PageRequest,
415    ) -> Result<Page<Entity>, StorageError> {
416        let namespace = namespace.to_string();
417
418        self.with_reader("query_entities", move |conn| {
419            let (count_sql, count_params) = build_entity_where(&namespace, &filter);
420            let total: i64 = {
421                let sql = format!("SELECT COUNT(*) FROM entities{}", count_sql);
422                let mut stmt = conn.prepare(&sql)?;
423                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
424                    count_params.iter().map(|p| p.as_ref()).collect();
425                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
426            };
427
428            let (where_sql, mut data_params) = build_entity_where(&namespace, &filter);
429            data_params.push(Box::new(page.limit as i64));
430            data_params.push(Box::new(page.offset as i64));
431
432            let limit_idx = data_params.len() - 1;
433            let offset_idx = data_params.len();
434
435            let data_sql = format!(
436                "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
437                 created_at, updated_at, deleted_at, merged_into, merge_event_id \
438                 FROM entities{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
439                where_sql, limit_idx, offset_idx,
440            );
441
442            let mut stmt = conn.prepare(&data_sql)?;
443            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
444                data_params.iter().map(|p| p.as_ref()).collect();
445            let rows = stmt.query_map(param_refs.as_slice(), read_entity)?;
446
447            let mut items = Vec::new();
448            for row in rows {
449                items.push(row?);
450            }
451
452            Ok(Page {
453                items,
454                total: Some(total as u64),
455            })
456        })
457        .await
458    }
459
460    async fn count_entities(
461        &self,
462        namespace: &str,
463        filter: EntityFilter,
464    ) -> Result<u64, StorageError> {
465        let namespace = namespace.to_string();
466
467        self.with_reader("count_entities", move |conn| {
468            let (where_sql, params) = build_entity_where(&namespace, &filter);
469            let sql = format!("SELECT COUNT(*) FROM entities{}", where_sql);
470            let mut stmt = conn.prepare(&sql)?;
471            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
472                params.iter().map(|p| p.as_ref()).collect();
473            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
474            Ok(count as u64)
475        })
476        .await
477    }
478}
479
480// =============================================================================
481// DDL
482// =============================================================================
483
/// Idempotent DDL for the `entities` table and its indexes, as a single
/// statement batch with no embedded newlines.
const ENTITIES_DDL: &str = concat!(
    // Table
    "CREATE TABLE IF NOT EXISTS entities (",
    "id TEXT PRIMARY KEY,",
    "namespace TEXT NOT NULL,",
    "kind TEXT NOT NULL,",
    "entity_type TEXT,",
    "name TEXT NOT NULL,",
    "description TEXT,",
    "properties TEXT,",
    "tags TEXT NOT NULL DEFAULT '[]',",
    "created_at INTEGER NOT NULL,",
    "updated_at INTEGER NOT NULL,",
    "deleted_at INTEGER,",
    "merged_into TEXT,",
    "merge_event_id TEXT",
    ");",
    // Indexes
    "CREATE INDEX IF NOT EXISTS idx_entities_namespace ON entities(namespace);",
    "CREATE INDEX IF NOT EXISTS idx_entities_kind ON entities(namespace, kind);",
    "CREATE INDEX IF NOT EXISTS idx_entities_kind_entity_type ON entities(namespace, kind, entity_type);",
    "CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(namespace, name);",
    "CREATE INDEX IF NOT EXISTS idx_entities_created ON entities(created_at DESC);",
    "CREATE INDEX IF NOT EXISTS idx_entities_merged_into ON entities(namespace, merged_into);",
);
507
508pub(crate) fn ensure_entities_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
509    conn.execute_batch(ENTITIES_DDL)
510}
511
512#[cfg(test)]
513mod tests {
514    use super::*;
515    use crate::pool::PoolConfig;
516
517    fn setup_pool() -> Arc<ConnectionPool> {
518        let config = PoolConfig {
519            path: None,
520            ..PoolConfig::default()
521        };
522        let pool = Arc::new(ConnectionPool::new(config).unwrap());
523        {
524            let writer = pool.writer().unwrap();
525            writer.conn().execute_batch(ENTITIES_DDL).unwrap();
526        }
527        pool
528    }
529
530    fn setup_memory_store() -> SqlEntityStore {
531        SqlEntityStore::new(setup_pool(), false)
532    }
533
534    fn setup_memory_store_ns(_ns: &str) -> SqlEntityStore {
535        SqlEntityStore::new(setup_pool(), false)
536    }
537
538    fn make_entity(namespace: &str, kind: &str, name: &str) -> Entity {
539        let now = chrono::Utc::now().timestamp_micros();
540        Entity {
541            id: Uuid::new_v4(),
542            namespace: namespace.to_string(),
543            kind: kind.to_string(),
544            entity_type: None,
545            name: name.to_string(),
546            description: None,
547            properties: None,
548            tags: Vec::new(),
549            created_at: now,
550            updated_at: now,
551            deleted_at: None,
552            merged_into: None,
553            merge_event_id: None,
554        }
555    }
556
557    #[tokio::test]
558    async fn test_upsert_and_get_entity() {
559        let store = setup_memory_store();
560
561        let entity = make_entity("default", "concept", "LoRA");
562        let id = entity.id;
563
564        store.upsert_entity(entity).await.unwrap();
565
566        let fetched = store.get_entity(id).await.unwrap();
567        assert!(fetched.is_some());
568        let fetched = fetched.unwrap();
569        assert_eq!(fetched.id, id);
570        assert_eq!(fetched.name, "LoRA");
571        assert_eq!(fetched.kind, "concept");
572    }
573
574    #[tokio::test]
575    async fn test_upsert_with_builder() {
576        let store = setup_memory_store();
577
578        let props = serde_json::json!({"domain": "fine-tuning", "type": "technique"});
579        let entity = Entity::new("default", "concept", "QLoRA")
580            .with_description("Quantized LoRA")
581            .with_properties(props.clone())
582            .with_tags(vec!["fine-tuning".to_string(), "quantization".to_string()]);
583        let id = entity.id;
584
585        store.upsert_entity(entity).await.unwrap();
586
587        let fetched = store.get_entity(id).await.unwrap().unwrap();
588        assert_eq!(fetched.description.as_deref(), Some("Quantized LoRA"));
589        assert_eq!(fetched.properties, Some(props));
590        assert_eq!(fetched.tags, vec!["fine-tuning", "quantization"]);
591    }
592
593    #[tokio::test]
594    async fn test_soft_delete() {
595        let store = setup_memory_store();
596
597        let entity = make_entity("default", "concept", "to-delete");
598        let id = entity.id;
599        store.upsert_entity(entity).await.unwrap();
600
601        let deleted = store.delete_entity(id, DeleteMode::Soft).await.unwrap();
602        assert!(deleted);
603
604        let fetched = store.get_entity(id).await.unwrap();
605        assert!(fetched.is_none());
606    }
607
608    #[tokio::test]
609    async fn test_hard_delete() {
610        let store = setup_memory_store();
611
612        let entity = make_entity("default", "concept", "to-hard-delete");
613        let id = entity.id;
614        store.upsert_entity(entity).await.unwrap();
615
616        let deleted = store.delete_entity(id, DeleteMode::Hard).await.unwrap();
617        assert!(deleted);
618
619        let fetched = store.get_entity(id).await.unwrap();
620        assert!(fetched.is_none());
621    }
622
623    #[tokio::test]
624    async fn test_query_entities_basic() {
625        let store = setup_memory_store_ns("ns1");
626
627        for name in &["Alpha", "Beta", "Gamma"] {
628            store
629                .upsert_entity(make_entity("ns1", "concept", name))
630                .await
631                .unwrap();
632        }
633        store
634            .upsert_entity(make_entity("ns1", "document", "Paper1"))
635            .await
636            .unwrap();
637
638        let page = store
639            .query_entities(
640                "ns1",
641                EntityFilter::default(),
642                PageRequest {
643                    offset: 0,
644                    limit: 10,
645                },
646            )
647            .await
648            .unwrap();
649        assert_eq!(page.items.len(), 4);
650        assert_eq!(page.total, Some(4));
651
652        // Filter by kind
653        let concepts = store
654            .query_entities(
655                "ns1",
656                EntityFilter {
657                    kinds: vec!["concept".to_string()],
658                    ..Default::default()
659                },
660                PageRequest::default(),
661            )
662            .await
663            .unwrap();
664        assert_eq!(concepts.items.len(), 3);
665    }
666
667    #[tokio::test]
668    async fn test_query_by_name_prefix() {
669        let store = setup_memory_store_ns("ns1");
670
671        // "Alpha" and "AlphaGo" both start with "Alpha"; "Beta" does not
672        for &name in &["Alpha", "AlphaGo", "Beta"] {
673            store
674                .upsert_entity(make_entity("ns1", "concept", name))
675                .await
676                .unwrap();
677        }
678
679        let result = store
680            .query_entities(
681                "ns1",
682                EntityFilter {
683                    name_prefix: Some("Alpha".to_string()),
684                    ..Default::default()
685                },
686                PageRequest::default(),
687            )
688            .await
689            .unwrap();
690        assert_eq!(result.items.len(), 2);
691        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
692        assert!(names.contains(&"Alpha"), "Alpha not found in {names:?}");
693        assert!(names.contains(&"AlphaGo"), "AlphaGo not found in {names:?}");
694        assert!(!names.contains(&"Beta"));
695    }
696
697    #[tokio::test]
698    async fn test_count_entities() {
699        let store = setup_memory_store_ns("ns1");
700
701        for _ in 0..5 {
702            store
703                .upsert_entity(make_entity("ns1", "concept", "X"))
704                .await
705                .unwrap();
706        }
707
708        let count = store
709            .count_entities("ns1", EntityFilter::default())
710            .await
711            .unwrap();
712        assert_eq!(count, 5);
713
714        // Namespace is the caller's responsibility — querying "ns2" returns 0
715        // because no entities were inserted in that namespace.
716        let count_other = store
717            .count_entities("ns2", EntityFilter::default())
718            .await
719            .unwrap();
720        assert_eq!(count_other, 0);
721    }
722
723    #[tokio::test]
724    async fn test_batch_upsert() {
725        let store = setup_memory_store_ns("batch_ns");
726
727        let entities: Vec<Entity> = (0..10)
728            .map(|i| make_entity("batch_ns", "concept", &format!("entity_{i}")))
729            .collect();
730
731        let summary = store.upsert_entities(entities).await.unwrap();
732        assert_eq!(summary.attempted, 10);
733        assert_eq!(summary.affected, 10);
734        assert_eq!(summary.failed, 0);
735
736        let count = store
737            .count_entities("batch_ns", EntityFilter::default())
738            .await
739            .unwrap();
740        assert_eq!(count, 10);
741    }
742
743    /// One store, two namespaces — each query sees only its own.
744    #[tokio::test]
745    async fn test_namespace_isolation() {
746        let pool = setup_pool();
747        let store = SqlEntityStore::new(Arc::clone(&pool), false);
748
749        store
750            .upsert_entity(make_entity("ns_a", "concept", "EntityA"))
751            .await
752            .unwrap();
753        store
754            .upsert_entity(make_entity("ns_b", "concept", "EntityB"))
755            .await
756            .unwrap();
757
758        // Namespace is the caller's responsibility — pass it in the query.
759        let count_a = store
760            .count_entities("ns_a", EntityFilter::default())
761            .await
762            .unwrap();
763        let count_b = store
764            .count_entities("ns_b", EntityFilter::default())
765            .await
766            .unwrap();
767
768        assert_eq!(count_a, 1);
769        assert_eq!(count_b, 1);
770
771        let page_a = store
772            .query_entities("ns_a", EntityFilter::default(), PageRequest::default())
773            .await
774            .unwrap();
775        assert_eq!(page_a.items[0].name, "EntityA");
776
777        let page_b = store
778            .query_entities("ns_b", EntityFilter::default(), PageRequest::default())
779            .await
780            .unwrap();
781        assert_eq!(page_b.items[0].name, "EntityB");
782    }
783
784    #[tokio::test]
785    async fn test_query_by_tags() {
786        let store = setup_memory_store_ns("tags_ns");
787
788        let mut e1 = make_entity("tags_ns", "concept", "Tagged1");
789        e1.tags = vec!["rust".to_string(), "systems".to_string()];
790        let mut e2 = make_entity("tags_ns", "concept", "Tagged2");
791        e2.tags = vec!["python".to_string(), "ml".to_string()];
792        let mut e3 = make_entity("tags_ns", "concept", "Tagged3");
793        e3.tags = vec!["rust".to_string(), "ml".to_string()];
794
795        store.upsert_entity(e1).await.unwrap();
796        store.upsert_entity(e2).await.unwrap();
797        store.upsert_entity(e3).await.unwrap();
798
799        // Filter by "rust" tag — should match Tagged1 and Tagged3
800        let result = store
801            .query_entities(
802                "tags_ns",
803                EntityFilter {
804                    tags_any: vec!["rust".to_string()],
805                    ..Default::default()
806                },
807                PageRequest::default(),
808            )
809            .await
810            .unwrap();
811        assert_eq!(result.items.len(), 2);
812        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
813        assert!(names.contains(&"Tagged1"));
814        assert!(names.contains(&"Tagged3"));
815        assert!(!names.contains(&"Tagged2"));
816
817        // Filter by "ml" tag — should match Tagged2 and Tagged3
818        let result = store
819            .query_entities(
820                "tags_ns",
821                EntityFilter {
822                    tags_any: vec!["ml".to_string()],
823                    ..Default::default()
824                },
825                PageRequest::default(),
826            )
827            .await
828            .unwrap();
829        assert_eq!(result.items.len(), 2);
830
831        // Filter by both "rust" and "python" (union) — should match all three
832        let result = store
833            .query_entities(
834                "tags_ns",
835                EntityFilter {
836                    tags_any: vec!["rust".to_string(), "python".to_string()],
837                    ..Default::default()
838                },
839                PageRequest::default(),
840            )
841            .await
842            .unwrap();
843        assert_eq!(result.items.len(), 3);
844    }
845
846    #[tokio::test]
847    async fn test_query_by_ids() {
848        let store = setup_memory_store_ns("ns1");
849
850        let e1 = make_entity("ns1", "concept", "E1");
851        let e2 = make_entity("ns1", "concept", "E2");
852        let e3 = make_entity("ns1", "concept", "E3");
853        let ids = vec![e1.id, e3.id];
854
855        store.upsert_entity(e1).await.unwrap();
856        store.upsert_entity(e2).await.unwrap();
857        store.upsert_entity(e3).await.unwrap();
858
859        let result = store
860            .query_entities(
861                "ns1",
862                EntityFilter {
863                    ids,
864                    ..Default::default()
865                },
866                PageRequest::default(),
867            )
868            .await
869            .unwrap();
870        assert_eq!(result.items.len(), 2);
871        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
872        assert!(names.contains(&"E1"));
873        assert!(names.contains(&"E3"));
874        assert!(!names.contains(&"E2"));
875    }
876
877    #[tokio::test]
878    async fn test_entity_type_roundtrip() {
879        let store = setup_memory_store();
880
881        let entity =
882            Entity::new("default", "document", "ResearchPaper").with_entity_type(Some("paper"));
883        let id = entity.id;
884
885        store.upsert_entity(entity).await.unwrap();
886
887        let fetched = store.get_entity(id).await.unwrap().unwrap();
888        assert_eq!(fetched.entity_type, Some("paper".to_string()));
889        assert_eq!(fetched.kind, "document");
890        assert_eq!(fetched.name, "ResearchPaper");
891    }
892
893    #[tokio::test]
894    async fn test_query_by_kind_and_entity_type() {
895        let store = setup_memory_store_ns("et_ns");
896
897        let typed =
898            Entity::new("et_ns", "person", "Researcher").with_entity_type(Some("researcher"));
899        let untyped = make_entity("et_ns", "person", "Generic");
900
901        store.upsert_entity(typed).await.unwrap();
902        store.upsert_entity(untyped).await.unwrap();
903
904        let result = store
905            .query_entities(
906                "et_ns",
907                EntityFilter {
908                    entity_types: vec!["researcher".to_string()],
909                    ..Default::default()
910                },
911                PageRequest::default(),
912            )
913            .await
914            .unwrap();
915
916        assert_eq!(result.items.len(), 1);
917        assert_eq!(result.items[0].name, "Researcher");
918        assert_eq!(result.items[0].entity_type, Some("researcher".to_string()));
919    }
920
921    /// UUID is globally unique (id TEXT PRIMARY KEY). Upserting the same UUID in a
922    /// different namespace overwrites the row (INSERT OR REPLACE). get_entity by ID
923    /// returns whichever namespace currently owns that UUID.
924    #[tokio::test]
925    async fn test_same_id_upsert_replaces_row() {
926        let pool = setup_pool();
927        let store = SqlEntityStore::new(Arc::clone(&pool), false);
928
929        let shared_id = Uuid::new_v4();
930        let now = chrono::Utc::now().timestamp_micros();
931
932        let entity_a = Entity {
933            id: shared_id,
934            namespace: "ns_a".to_string(),
935            kind: "concept".to_string(),
936            entity_type: None,
937            name: "SharedInA".to_string(),
938            description: None,
939            properties: None,
940            tags: Vec::new(),
941            created_at: now,
942            updated_at: now,
943            deleted_at: None,
944            merged_into: None,
945            merge_event_id: None,
946        };
947        store.upsert_entity(entity_a).await.unwrap();
948
949        // At this point the row is in ns_a.
950        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
951        assert_eq!(fetched.namespace, "ns_a");
952        assert_eq!(fetched.name, "SharedInA");
953
954        // Upsert same UUID into ns_b — INSERT OR REPLACE replaces the row.
955        let entity_b = Entity {
956            id: shared_id,
957            namespace: "ns_b".to_string(),
958            kind: "concept".to_string(),
959            entity_type: None,
960            name: "SharedInB".to_string(),
961            description: None,
962            properties: None,
963            tags: Vec::new(),
964            created_at: now,
965            updated_at: now,
966            deleted_at: None,
967            merged_into: None,
968            merge_event_id: None,
969        };
970        store.upsert_entity(entity_b).await.unwrap();
971
972        // Now the row is in ns_b — get_entity returns ns_b regardless of which namespace
973        // you query from (namespace is caller's responsibility).
974        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
975        assert_eq!(fetched.namespace, "ns_b");
976        assert_eq!(fetched.name, "SharedInB");
977
978        // ns_a now has 0 entities; ns_b has 1.
979        let count_a = store
980            .count_entities("ns_a", EntityFilter::default())
981            .await
982            .unwrap();
983        let count_b = store
984            .count_entities("ns_b", EntityFilter::default())
985            .await
986            .unwrap();
987        assert_eq!(count_a, 0);
988        assert_eq!(count_b, 1);
989    }
990}