1use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::entity::{Entity, EntityFilter};
9use khive_storage::error::StorageError;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::EntityStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
/// Wraps a `rusqlite::Error` in a driver-level [`StorageError`] tagged with the
/// `Entities` capability and the name of the operation (`op`) that failed.
fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
    StorageError::driver(StorageCapability::Entities, op, e)
}
20
/// Wraps a crate-level [`SqliteError`] (e.g. a failure to obtain a pooled
/// connection) in a driver-level [`StorageError`] tagged with the `Entities`
/// capability and the name of the operation (`op`) that failed.
fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
    StorageError::driver(StorageCapability::Entities, op, e)
}
24
/// SQLite-backed implementation of [`EntityStore`] built on a shared
/// [`ConnectionPool`].
pub struct SqlEntityStore {
    /// Pool that supplies the writer connection and, when the store is not
    /// file-backed, the reader connections as well.
    pool: Arc<ConnectionPool>,
    /// When `true`, reads open a standalone read-only connection to the
    /// database file instead of borrowing a reader from the pool.
    is_file_backed: bool,
}
33
34impl SqlEntityStore {
35 pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37 Self {
38 pool,
39 is_file_backed,
40 }
41 }
42
43 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44 let config = self.pool.config();
45 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46 operation: "entity_reader".into(),
47 message: "in-memory databases do not support standalone connections".into(),
48 })?;
49
50 let conn = rusqlite::Connection::open_with_flags(
51 path,
52 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55 )
56 .map_err(|e| map_err(e, "open_entity_reader"))?;
57
58 conn.busy_timeout(config.busy_timeout)
59 .map_err(|e| map_err(e, "open_entity_reader"))?;
60 conn.pragma_update(None, "foreign_keys", "ON")
61 .map_err(|e| map_err(e, "open_entity_reader"))?;
62 conn.pragma_update(None, "synchronous", "NORMAL")
63 .map_err(|e| map_err(e, "open_entity_reader"))?;
64
65 Ok(conn)
66 }
67
68 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
69 where
70 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
71 R: Send + 'static,
72 {
73 let pool = Arc::clone(&self.pool);
74 tokio::task::spawn_blocking(move || {
75 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
76 f(guard.conn()).map_err(|e| map_err(e, op))
77 })
78 .await
79 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
80 }
81
82 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
83 where
84 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
85 R: Send + 'static,
86 {
87 if self.is_file_backed {
88 let conn = self.open_standalone_reader()?;
89 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
90 .await
91 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
92 } else {
93 let pool = Arc::clone(&self.pool);
94 tokio::task::spawn_blocking(move || {
95 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
96 f(guard.conn()).map_err(|e| map_err(e, op))
97 })
98 .await
99 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
100 }
101 }
102}
103
104fn read_entity(row: &rusqlite::Row<'_>) -> Result<Entity, rusqlite::Error> {
109 let id_str: String = row.get(0)?;
110 let namespace: String = row.get(1)?;
111 let kind: String = row.get(2)?;
112 let entity_type: Option<String> = row.get(3)?;
113 let name: String = row.get(4)?;
114 let description: Option<String> = row.get(5)?;
115 let properties_str: Option<String> = row.get(6)?;
116 let tags_str: String = row.get(7)?;
117 let created_at: i64 = row.get(8)?;
118 let updated_at: i64 = row.get(9)?;
119 let deleted_at: Option<i64> = row.get(10)?;
120 let merged_into_str: Option<String> = row.get(11)?;
121 let merge_event_id_str: Option<String> = row.get(12)?;
122
123 let id = parse_uuid(&id_str)?;
124
125 let properties = properties_str
126 .map(|s| {
127 serde_json::from_str(&s).map_err(|e| {
128 rusqlite::Error::FromSqlConversionFailure(
129 6,
130 rusqlite::types::Type::Text,
131 Box::new(e),
132 )
133 })
134 })
135 .transpose()?;
136
137 let tags: Vec<String> = serde_json::from_str(&tags_str).map_err(|e| {
138 rusqlite::Error::FromSqlConversionFailure(7, rusqlite::types::Type::Text, Box::new(e))
139 })?;
140
141 let merged_into = merged_into_str
142 .as_deref()
143 .map(Uuid::parse_str)
144 .transpose()
145 .map_err(|e| {
146 rusqlite::Error::FromSqlConversionFailure(10, rusqlite::types::Type::Text, Box::new(e))
147 })?;
148
149 let merge_event_id = merge_event_id_str
150 .as_deref()
151 .map(Uuid::parse_str)
152 .transpose()
153 .map_err(|e| {
154 rusqlite::Error::FromSqlConversionFailure(11, rusqlite::types::Type::Text, Box::new(e))
155 })?;
156
157 Ok(Entity {
158 id,
159 namespace,
160 kind,
161 entity_type,
162 name,
163 description,
164 properties,
165 tags,
166 created_at,
167 updated_at,
168 deleted_at,
169 merged_into,
170 merge_event_id,
171 })
172}
173
174fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
175 Uuid::parse_str(s).map_err(|e| {
176 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
177 })
178}
179
180fn build_entity_where(
181 namespace: &str,
182 filter: &EntityFilter,
183) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
184 let mut conditions: Vec<String> = vec![
185 "namespace = ?1".to_string(),
186 "deleted_at IS NULL".to_string(),
187 ];
188 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
189
190 if !filter.ids.is_empty() {
191 let placeholders: Vec<String> = filter
192 .ids
193 .iter()
194 .map(|id| {
195 params.push(Box::new(id.to_string()));
196 format!("?{}", params.len())
197 })
198 .collect();
199 conditions.push(format!("id IN ({})", placeholders.join(", ")));
200 }
201
202 if !filter.kinds.is_empty() {
203 let placeholders: Vec<String> = filter
204 .kinds
205 .iter()
206 .map(|k| {
207 params.push(Box::new(k.clone()));
208 format!("?{}", params.len())
209 })
210 .collect();
211 conditions.push(format!("kind IN ({})", placeholders.join(", ")));
212 }
213
214 if !filter.entity_types.is_empty() {
215 let placeholders: Vec<String> = filter
216 .entity_types
217 .iter()
218 .map(|t| {
219 params.push(Box::new(t.clone()));
220 format!("?{}", params.len())
221 })
222 .collect();
223 conditions.push(format!("entity_type IN ({})", placeholders.join(", ")));
224 }
225
226 if let Some(ref prefix) = filter.name_prefix {
227 params.push(Box::new(format!("{}%", prefix)));
228 conditions.push(format!("name LIKE ?{}", params.len()));
229 }
230
231 if !filter.tags_any.is_empty() {
232 let placeholders: Vec<String> = filter
233 .tags_any
234 .iter()
235 .map(|t| {
236 params.push(Box::new(t.to_lowercase()));
239 format!("?{}", params.len())
240 })
241 .collect();
242 conditions.push(format!(
243 "EXISTS (SELECT 1 FROM json_each(tags) WHERE LOWER(json_each.value) IN ({}))",
244 placeholders.join(", ")
245 ));
246 }
247
248 let clause = format!(" WHERE {}", conditions.join(" AND "));
249 (clause, params)
250}
251
252#[async_trait]
257impl EntityStore for SqlEntityStore {
258 async fn upsert_entity(&self, entity: Entity) -> Result<(), StorageError> {
259 let namespace = entity.namespace.clone();
260 let id_str = entity.id.to_string();
261 let properties_str = entity
262 .properties
263 .as_ref()
264 .map(|v| serde_json::to_string(v).unwrap_or_default());
265 let tags_str = serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
266
267 let merged_into_str = entity.merged_into.map(|u| u.to_string());
268 let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
269
270 self.with_writer("upsert_entity", move |conn| {
271 conn.execute(
272 "INSERT OR REPLACE INTO entities \
273 (id, namespace, kind, entity_type, name, description, properties, tags, \
274 created_at, updated_at, deleted_at, merged_into, merge_event_id) \
275 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
276 rusqlite::params![
277 id_str,
278 namespace,
279 entity.kind,
280 entity.entity_type,
281 entity.name,
282 entity.description,
283 properties_str,
284 tags_str,
285 entity.created_at,
286 entity.updated_at,
287 entity.deleted_at,
288 merged_into_str,
289 merge_event_id_str,
290 ],
291 )?;
292 Ok(())
293 })
294 .await
295 }
296
297 async fn upsert_entities(
298 &self,
299 entities: Vec<Entity>,
300 ) -> Result<BatchWriteSummary, StorageError> {
301 let attempted = entities.len() as u64;
302
303 self.with_writer("upsert_entities", move |conn| {
304 conn.execute_batch("BEGIN IMMEDIATE")?;
305 let mut affected = 0u64;
306 let mut failed = 0u64;
307 let mut first_error = String::new();
308
309 for entity in &entities {
310 let id_str = entity.id.to_string();
311 let properties_str = entity
312 .properties
313 .as_ref()
314 .map(|v| serde_json::to_string(v).unwrap_or_default());
315 let tags_str =
316 serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
317
318 let merged_into_str = entity.merged_into.map(|u| u.to_string());
319 let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
320 match conn.execute(
321 "INSERT OR REPLACE INTO entities \
322 (id, namespace, kind, entity_type, name, description, properties, tags, \
323 created_at, updated_at, deleted_at, merged_into, merge_event_id) \
324 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
325 rusqlite::params![
326 id_str,
327 &entity.namespace,
328 entity.kind,
329 entity.entity_type,
330 entity.name,
331 entity.description,
332 properties_str,
333 tags_str,
334 entity.created_at,
335 entity.updated_at,
336 entity.deleted_at,
337 merged_into_str,
338 merge_event_id_str,
339 ],
340 ) {
341 Ok(_) => affected += 1,
342 Err(e) => {
343 if first_error.is_empty() {
344 first_error = e.to_string();
345 }
346 failed += 1;
347 }
348 }
349 }
350
351 if let Err(e) = conn.execute_batch("COMMIT") {
352 let _ = conn.execute_batch("ROLLBACK");
353 return Err(e);
354 }
355 Ok(BatchWriteSummary {
356 attempted,
357 affected,
358 failed,
359 first_error,
360 })
361 })
362 .await
363 }
364
365 async fn get_entity(&self, id: Uuid) -> Result<Option<Entity>, StorageError> {
366 let id_str = id.to_string();
367
368 self.with_reader("get_entity", move |conn| {
369 let mut stmt = conn.prepare(
370 "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
371 created_at, updated_at, deleted_at, merged_into, merge_event_id \
372 FROM entities WHERE id = ?1 AND deleted_at IS NULL",
373 )?;
374 let mut rows = stmt.query(rusqlite::params![id_str])?;
375 match rows.next()? {
376 Some(row) => Ok(Some(read_entity(row)?)),
377 None => Ok(None),
378 }
379 })
380 .await
381 }
382
383 async fn delete_entity(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
384 let id_str = id.to_string();
385
386 match mode {
387 DeleteMode::Soft => {
388 self.with_writer("delete_entity_soft", move |conn| {
389 let now = chrono::Utc::now().timestamp_micros();
390 let deleted = conn.execute(
391 "UPDATE entities SET deleted_at = ?1 \
392 WHERE id = ?2 AND deleted_at IS NULL",
393 rusqlite::params![now, id_str],
394 )?;
395 Ok(deleted > 0)
396 })
397 .await
398 }
399 DeleteMode::Hard => {
400 self.with_writer("delete_entity_hard", move |conn| {
401 let deleted = conn.execute(
402 "DELETE FROM entities WHERE id = ?1",
403 rusqlite::params![id_str],
404 )?;
405 Ok(deleted > 0)
406 })
407 .await
408 }
409 }
410 }
411
412 async fn query_entities(
413 &self,
414 namespace: &str,
415 filter: EntityFilter,
416 page: PageRequest,
417 ) -> Result<Page<Entity>, StorageError> {
418 let namespace = namespace.to_string();
419
420 self.with_reader("query_entities", move |conn| {
421 let (count_sql, count_params) = build_entity_where(&namespace, &filter);
422 let total: i64 = {
423 let sql = format!("SELECT COUNT(*) FROM entities{}", count_sql);
424 let mut stmt = conn.prepare(&sql)?;
425 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
426 count_params.iter().map(|p| p.as_ref()).collect();
427 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
428 };
429
430 let (where_sql, mut data_params) = build_entity_where(&namespace, &filter);
431 data_params.push(Box::new(page.limit as i64));
432 data_params.push(Box::new(page.offset as i64));
433
434 let limit_idx = data_params.len() - 1;
435 let offset_idx = data_params.len();
436
437 let data_sql = format!(
438 "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
439 created_at, updated_at, deleted_at, merged_into, merge_event_id \
440 FROM entities{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
441 where_sql, limit_idx, offset_idx,
442 );
443
444 let mut stmt = conn.prepare(&data_sql)?;
445 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
446 data_params.iter().map(|p| p.as_ref()).collect();
447 let rows = stmt.query_map(param_refs.as_slice(), read_entity)?;
448
449 let mut items = Vec::new();
450 for row in rows {
451 items.push(row?);
452 }
453
454 Ok(Page {
455 items,
456 total: Some(total as u64),
457 })
458 })
459 .await
460 }
461
462 async fn count_entities(
463 &self,
464 namespace: &str,
465 filter: EntityFilter,
466 ) -> Result<u64, StorageError> {
467 let namespace = namespace.to_string();
468
469 self.with_reader("count_entities", move |conn| {
470 let (where_sql, params) = build_entity_where(&namespace, &filter);
471 let sql = format!("SELECT COUNT(*) FROM entities{}", where_sql);
472 let mut stmt = conn.prepare(&sql)?;
473 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
474 params.iter().map(|p| p.as_ref()).collect();
475 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
476 Ok(count as u64)
477 })
478 .await
479 }
480}
481
/// DDL for the `entities` table and its indexes.
///
/// Every statement uses `IF NOT EXISTS`, so the batch can be executed
/// repeatedly. The column order here matches the positional reads in
/// `read_entity`. UUIDs (`id`, `merged_into`, `merge_event_id`) and JSON
/// (`properties`, `tags`) are stored as `TEXT`; timestamps are `INTEGER`.
const ENTITIES_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS entities (\
    id TEXT PRIMARY KEY,\
    namespace TEXT NOT NULL,\
    kind TEXT NOT NULL,\
    entity_type TEXT,\
    name TEXT NOT NULL,\
    description TEXT,\
    properties TEXT,\
    tags TEXT NOT NULL DEFAULT '[]',\
    created_at INTEGER NOT NULL,\
    updated_at INTEGER NOT NULL,\
    deleted_at INTEGER,\
    merged_into TEXT,\
    merge_event_id TEXT\
    );\
    CREATE INDEX IF NOT EXISTS idx_entities_namespace ON entities(namespace);\
    CREATE INDEX IF NOT EXISTS idx_entities_kind ON entities(namespace, kind);\
    CREATE INDEX IF NOT EXISTS idx_entities_kind_entity_type ON entities(namespace, kind, entity_type);\
    CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(namespace, name);\
    CREATE INDEX IF NOT EXISTS idx_entities_created ON entities(created_at DESC);\
    CREATE INDEX IF NOT EXISTS idx_entities_merged_into ON entities(namespace, merged_into);\
";
509
/// Creates the `entities` table and its indexes on `conn` if they do not
/// already exist, by executing [`ENTITIES_DDL`] as one batch.
///
/// # Errors
/// Returns the `rusqlite::Error` raised by the batch execution.
pub(crate) fn ensure_entities_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
    conn.execute_batch(ENTITIES_DDL)
}
513
// Integration-style tests that exercise `SqlEntityStore` through the
// `EntityStore` trait against a pool configured without a file path.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// Builds a pool with no file path and creates the `entities` schema on it.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            // Scope the writer guard so it is released before the pool is returned.
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(ENTITIES_DDL).unwrap();
        }
        pool
    }

    /// Store over a fresh pool, using pooled readers (`is_file_backed = false`).
    fn setup_memory_store() -> SqlEntityStore {
        SqlEntityStore::new(setup_pool(), false)
    }

    /// Same as `setup_memory_store`; the namespace argument is unused and only
    /// labels the call site.
    fn setup_memory_store_ns(_ns: &str) -> SqlEntityStore {
        SqlEntityStore::new(setup_pool(), false)
    }

    /// Builds a minimal live entity with a random id and both timestamps set
    /// to the current time in microseconds.
    fn make_entity(namespace: &str, kind: &str, name: &str) -> Entity {
        let now = chrono::Utc::now().timestamp_micros();
        Entity {
            id: Uuid::new_v4(),
            namespace: namespace.to_string(),
            kind: kind.to_string(),
            entity_type: None,
            name: name.to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
            merged_into: None,
            merge_event_id: None,
        }
    }

    // An upserted entity can be fetched back by id with its fields intact.
    #[tokio::test]
    async fn test_upsert_and_get_entity() {
        let store = setup_memory_store();

        let entity = make_entity("default", "concept", "LoRA");
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.name, "LoRA");
        assert_eq!(fetched.kind, "concept");
    }

    // Description, JSON properties and tags survive a write/read round trip.
    #[tokio::test]
    async fn test_upsert_with_builder() {
        let store = setup_memory_store();

        let props = serde_json::json!({"domain": "fine-tuning", "type": "technique"});
        let entity = Entity::new("default", "concept", "QLoRA")
            .with_description("Quantized LoRA")
            .with_properties(props.clone())
            .with_tags(vec!["fine-tuning".to_string(), "quantization".to_string()]);
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap().unwrap();
        assert_eq!(fetched.description.as_deref(), Some("Quantized LoRA"));
        assert_eq!(fetched.properties, Some(props));
        assert_eq!(fetched.tags, vec!["fine-tuning", "quantization"]);
    }

    // A soft-deleted entity is reported as deleted and is hidden from `get_entity`.
    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", "concept", "to-delete");
        let id = entity.id;
        store.upsert_entity(entity).await.unwrap();

        let deleted = store.delete_entity(id, DeleteMode::Soft).await.unwrap();
        assert!(deleted);

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_none());
    }

    // A hard-deleted entity is reported as deleted and can no longer be fetched.
    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", "concept", "to-hard-delete");
        let id = entity.id;
        store.upsert_entity(entity).await.unwrap();

        let deleted = store.delete_entity(id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_none());
    }

    // An empty filter returns every entity in the namespace; a `kinds` filter
    // narrows the result to that kind.
    #[tokio::test]
    async fn test_query_entities_basic() {
        let store = setup_memory_store_ns("ns1");

        for name in &["Alpha", "Beta", "Gamma"] {
            store
                .upsert_entity(make_entity("ns1", "concept", name))
                .await
                .unwrap();
        }
        store
            .upsert_entity(make_entity("ns1", "document", "Paper1"))
            .await
            .unwrap();

        let page = store
            .query_entities(
                "ns1",
                EntityFilter::default(),
                PageRequest {
                    offset: 0,
                    limit: 10,
                },
            )
            .await
            .unwrap();
        assert_eq!(page.items.len(), 4);
        assert_eq!(page.total, Some(4));

        let concepts = store
            .query_entities(
                "ns1",
                EntityFilter {
                    kinds: vec!["concept".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(concepts.items.len(), 3);
    }

    // `name_prefix` matches names starting with the prefix and nothing else.
    #[tokio::test]
    async fn test_query_by_name_prefix() {
        let store = setup_memory_store_ns("ns1");

        for &name in &["Alpha", "AlphaGo", "Beta"] {
            store
                .upsert_entity(make_entity("ns1", "concept", name))
                .await
                .unwrap();
        }

        let result = store
            .query_entities(
                "ns1",
                EntityFilter {
                    name_prefix: Some("Alpha".to_string()),
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"Alpha"), "Alpha not found in {names:?}");
        assert!(names.contains(&"AlphaGo"), "AlphaGo not found in {names:?}");
        assert!(!names.contains(&"Beta"));
    }

    // Counting is scoped to the requested namespace.
    #[tokio::test]
    async fn test_count_entities() {
        let store = setup_memory_store_ns("ns1");

        for _ in 0..5 {
            store
                .upsert_entity(make_entity("ns1", "concept", "X"))
                .await
                .unwrap();
        }

        let count = store
            .count_entities("ns1", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 5);

        // A namespace with no rows counts as zero.
        let count_other = store
            .count_entities("ns2", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_other, 0);
    }

    // A batch upsert reports every row as affected and persists all of them.
    #[tokio::test]
    async fn test_batch_upsert() {
        let store = setup_memory_store_ns("batch_ns");

        let entities: Vec<Entity> = (0..10)
            .map(|i| make_entity("batch_ns", "concept", &format!("entity_{i}")))
            .collect();

        let summary = store.upsert_entities(entities).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        let count = store
            .count_entities("batch_ns", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 10);
    }

    // Entities written to different namespaces through the same store are
    // only visible when querying their own namespace.
    #[tokio::test]
    async fn test_namespace_isolation() {
        let pool = setup_pool();
        let store = SqlEntityStore::new(Arc::clone(&pool), false);

        store
            .upsert_entity(make_entity("ns_a", "concept", "EntityA"))
            .await
            .unwrap();
        store
            .upsert_entity(make_entity("ns_b", "concept", "EntityB"))
            .await
            .unwrap();

        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();

        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);

        let page_a = store
            .query_entities("ns_a", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items[0].name, "EntityA");

        let page_b = store
            .query_entities("ns_b", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items[0].name, "EntityB");
    }

    // `tags_any` matches entities carrying at least one of the requested tags.
    #[tokio::test]
    async fn test_query_by_tags() {
        let store = setup_memory_store_ns("tags_ns");

        let mut e1 = make_entity("tags_ns", "concept", "Tagged1");
        e1.tags = vec!["rust".to_string(), "systems".to_string()];
        let mut e2 = make_entity("tags_ns", "concept", "Tagged2");
        e2.tags = vec!["python".to_string(), "ml".to_string()];
        let mut e3 = make_entity("tags_ns", "concept", "Tagged3");
        e3.tags = vec!["rust".to_string(), "ml".to_string()];

        store.upsert_entity(e1).await.unwrap();
        store.upsert_entity(e2).await.unwrap();
        store.upsert_entity(e3).await.unwrap();

        // Single tag "rust": Tagged1 and Tagged3.
        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["rust".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"Tagged1"));
        assert!(names.contains(&"Tagged3"));
        assert!(!names.contains(&"Tagged2"));

        // Single tag "ml": Tagged2 and Tagged3.
        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["ml".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);

        // Either "rust" or "python": all three entities.
        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["rust".to_string(), "python".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 3);
    }

    // An `ids` filter returns exactly the requested entities.
    #[tokio::test]
    async fn test_query_by_ids() {
        let store = setup_memory_store_ns("ns1");

        let e1 = make_entity("ns1", "concept", "E1");
        let e2 = make_entity("ns1", "concept", "E2");
        let e3 = make_entity("ns1", "concept", "E3");
        let ids = vec![e1.id, e3.id];

        store.upsert_entity(e1).await.unwrap();
        store.upsert_entity(e2).await.unwrap();
        store.upsert_entity(e3).await.unwrap();

        let result = store
            .query_entities(
                "ns1",
                EntityFilter {
                    ids,
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"E1"));
        assert!(names.contains(&"E3"));
        assert!(!names.contains(&"E2"));
    }

    // The optional `entity_type` column survives a write/read round trip.
    #[tokio::test]
    async fn test_entity_type_roundtrip() {
        let store = setup_memory_store();

        let entity =
            Entity::new("default", "document", "ResearchPaper").with_entity_type(Some("paper"));
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap().unwrap();
        assert_eq!(fetched.entity_type, Some("paper".to_string()));
        assert_eq!(fetched.kind, "document");
        assert_eq!(fetched.name, "ResearchPaper");
    }

    // An `entity_types` filter excludes entities whose `entity_type` is unset.
    #[tokio::test]
    async fn test_query_by_kind_and_entity_type() {
        let store = setup_memory_store_ns("et_ns");

        let typed =
            Entity::new("et_ns", "person", "Researcher").with_entity_type(Some("researcher"));
        let untyped = make_entity("et_ns", "person", "Generic");

        store.upsert_entity(typed).await.unwrap();
        store.upsert_entity(untyped).await.unwrap();

        let result = store
            .query_entities(
                "et_ns",
                EntityFilter {
                    entity_types: vec!["researcher".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.items[0].name, "Researcher");
        assert_eq!(result.items[0].entity_type, Some("researcher".to_string()));
    }

    // `id` is the sole primary key: upserting the same id under a different
    // namespace replaces the existing row instead of creating a second one.
    #[tokio::test]
    async fn test_same_id_upsert_replaces_row() {
        let pool = setup_pool();
        let store = SqlEntityStore::new(Arc::clone(&pool), false);

        let shared_id = Uuid::new_v4();
        let now = chrono::Utc::now().timestamp_micros();

        let entity_a = Entity {
            id: shared_id,
            namespace: "ns_a".to_string(),
            kind: "concept".to_string(),
            entity_type: None,
            name: "SharedInA".to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
            merged_into: None,
            merge_event_id: None,
        };
        store.upsert_entity(entity_a).await.unwrap();

        // After the first write the row belongs to ns_a.
        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(fetched.namespace, "ns_a");
        assert_eq!(fetched.name, "SharedInA");

        // Second write reuses the id under ns_b.
        let entity_b = Entity {
            id: shared_id,
            namespace: "ns_b".to_string(),
            kind: "concept".to_string(),
            entity_type: None,
            name: "SharedInB".to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
            merged_into: None,
            merge_event_id: None,
        };
        store.upsert_entity(entity_b).await.unwrap();

        // The row now carries the second write's namespace and name.
        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(fetched.namespace, "ns_b");
        assert_eq!(fetched.name, "SharedInB");

        // Exactly one row remains, and it is counted under ns_b only.
        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_a, 0);
        assert_eq!(count_b, 1);
    }
}