1use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::entity::{Entity, EntityFilter};
9use khive_storage::error::StorageError;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::EntityStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18 StorageError::driver(StorageCapability::Entities, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22 StorageError::driver(StorageCapability::Entities, op, e)
23}
24
25pub struct SqlEntityStore {
30 pool: Arc<ConnectionPool>,
31 is_file_backed: bool,
32}
33
34impl SqlEntityStore {
35 pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37 Self {
38 pool,
39 is_file_backed,
40 }
41 }
42
43 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44 let config = self.pool.config();
45 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46 operation: "entity_reader".into(),
47 message: "in-memory databases do not support standalone connections".into(),
48 })?;
49
50 let conn = rusqlite::Connection::open_with_flags(
51 path,
52 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55 )
56 .map_err(|e| map_err(e, "open_entity_reader"))?;
57
58 conn.busy_timeout(config.busy_timeout)
59 .map_err(|e| map_err(e, "open_entity_reader"))?;
60 conn.pragma_update(None, "foreign_keys", "ON")
61 .map_err(|e| map_err(e, "open_entity_reader"))?;
62 conn.pragma_update(None, "synchronous", "NORMAL")
63 .map_err(|e| map_err(e, "open_entity_reader"))?;
64
65 Ok(conn)
66 }
67
68 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
69 where
70 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
71 R: Send + 'static,
72 {
73 let pool = Arc::clone(&self.pool);
74 tokio::task::spawn_blocking(move || {
75 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
76 f(guard.conn()).map_err(|e| map_err(e, op))
77 })
78 .await
79 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
80 }
81
82 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
83 where
84 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
85 R: Send + 'static,
86 {
87 if self.is_file_backed {
88 let conn = self.open_standalone_reader()?;
89 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
90 .await
91 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
92 } else {
93 let pool = Arc::clone(&self.pool);
94 tokio::task::spawn_blocking(move || {
95 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
96 f(guard.conn()).map_err(|e| map_err(e, op))
97 })
98 .await
99 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
100 }
101 }
102}
103
104fn read_entity(row: &rusqlite::Row<'_>) -> Result<Entity, rusqlite::Error> {
109 let id_str: String = row.get(0)?;
110 let namespace: String = row.get(1)?;
111 let kind: String = row.get(2)?;
112 let entity_type: Option<String> = row.get(3)?;
113 let name: String = row.get(4)?;
114 let description: Option<String> = row.get(5)?;
115 let properties_str: Option<String> = row.get(6)?;
116 let tags_str: String = row.get(7)?;
117 let created_at: i64 = row.get(8)?;
118 let updated_at: i64 = row.get(9)?;
119 let deleted_at: Option<i64> = row.get(10)?;
120 let merged_into_str: Option<String> = row.get(11)?;
121 let merge_event_id_str: Option<String> = row.get(12)?;
122
123 let id = parse_uuid(&id_str)?;
124
125 let properties = properties_str
126 .map(|s| {
127 serde_json::from_str(&s).map_err(|e| {
128 rusqlite::Error::FromSqlConversionFailure(
129 6,
130 rusqlite::types::Type::Text,
131 Box::new(e),
132 )
133 })
134 })
135 .transpose()?;
136
137 let tags: Vec<String> = serde_json::from_str(&tags_str).map_err(|e| {
138 rusqlite::Error::FromSqlConversionFailure(7, rusqlite::types::Type::Text, Box::new(e))
139 })?;
140
141 let merged_into = merged_into_str
142 .as_deref()
143 .map(Uuid::parse_str)
144 .transpose()
145 .map_err(|e| {
146 rusqlite::Error::FromSqlConversionFailure(10, rusqlite::types::Type::Text, Box::new(e))
147 })?;
148
149 let merge_event_id = merge_event_id_str
150 .as_deref()
151 .map(Uuid::parse_str)
152 .transpose()
153 .map_err(|e| {
154 rusqlite::Error::FromSqlConversionFailure(11, rusqlite::types::Type::Text, Box::new(e))
155 })?;
156
157 Ok(Entity {
158 id,
159 namespace,
160 kind,
161 entity_type,
162 name,
163 description,
164 properties,
165 tags,
166 created_at,
167 updated_at,
168 deleted_at,
169 merged_into,
170 merge_event_id,
171 })
172}
173
174fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
175 Uuid::parse_str(s).map_err(|e| {
176 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
177 })
178}
179
180fn build_entity_where(
181 namespace: &str,
182 filter: &EntityFilter,
183) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
184 let mut conditions: Vec<String> = vec![
185 "namespace = ?1".to_string(),
186 "deleted_at IS NULL".to_string(),
187 ];
188 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
189
190 if !filter.ids.is_empty() {
191 let placeholders: Vec<String> = filter
192 .ids
193 .iter()
194 .map(|id| {
195 params.push(Box::new(id.to_string()));
196 format!("?{}", params.len())
197 })
198 .collect();
199 conditions.push(format!("id IN ({})", placeholders.join(", ")));
200 }
201
202 if !filter.kinds.is_empty() {
203 let placeholders: Vec<String> = filter
204 .kinds
205 .iter()
206 .map(|k| {
207 params.push(Box::new(k.clone()));
208 format!("?{}", params.len())
209 })
210 .collect();
211 conditions.push(format!("kind IN ({})", placeholders.join(", ")));
212 }
213
214 if !filter.entity_types.is_empty() {
215 let placeholders: Vec<String> = filter
216 .entity_types
217 .iter()
218 .map(|t| {
219 params.push(Box::new(t.clone()));
220 format!("?{}", params.len())
221 })
222 .collect();
223 conditions.push(format!("entity_type IN ({})", placeholders.join(", ")));
224 }
225
226 if let Some(ref prefix) = filter.name_prefix {
227 params.push(Box::new(format!("{}%", prefix)));
228 conditions.push(format!("name LIKE ?{}", params.len()));
229 }
230
231 if !filter.tags_any.is_empty() {
232 let placeholders: Vec<String> = filter
233 .tags_any
234 .iter()
235 .map(|t| {
236 params.push(Box::new(t.clone()));
237 format!("?{}", params.len())
238 })
239 .collect();
240 conditions.push(format!(
241 "EXISTS (SELECT 1 FROM json_each(tags) WHERE json_each.value IN ({}))",
242 placeholders.join(", ")
243 ));
244 }
245
246 let clause = format!(" WHERE {}", conditions.join(" AND "));
247 (clause, params)
248}
249
250#[async_trait]
255impl EntityStore for SqlEntityStore {
256 async fn upsert_entity(&self, entity: Entity) -> Result<(), StorageError> {
257 let namespace = entity.namespace.clone();
258 let id_str = entity.id.to_string();
259 let properties_str = entity
260 .properties
261 .as_ref()
262 .map(|v| serde_json::to_string(v).unwrap_or_default());
263 let tags_str = serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
264
265 let merged_into_str = entity.merged_into.map(|u| u.to_string());
266 let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
267
268 self.with_writer("upsert_entity", move |conn| {
269 conn.execute(
270 "INSERT OR REPLACE INTO entities \
271 (id, namespace, kind, entity_type, name, description, properties, tags, \
272 created_at, updated_at, deleted_at, merged_into, merge_event_id) \
273 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
274 rusqlite::params![
275 id_str,
276 namespace,
277 entity.kind,
278 entity.entity_type,
279 entity.name,
280 entity.description,
281 properties_str,
282 tags_str,
283 entity.created_at,
284 entity.updated_at,
285 entity.deleted_at,
286 merged_into_str,
287 merge_event_id_str,
288 ],
289 )?;
290 Ok(())
291 })
292 .await
293 }
294
295 async fn upsert_entities(
296 &self,
297 entities: Vec<Entity>,
298 ) -> Result<BatchWriteSummary, StorageError> {
299 let attempted = entities.len() as u64;
300
301 self.with_writer("upsert_entities", move |conn| {
302 conn.execute_batch("BEGIN IMMEDIATE")?;
303 let mut affected = 0u64;
304 let mut failed = 0u64;
305 let mut first_error = String::new();
306
307 for entity in &entities {
308 let id_str = entity.id.to_string();
309 let properties_str = entity
310 .properties
311 .as_ref()
312 .map(|v| serde_json::to_string(v).unwrap_or_default());
313 let tags_str =
314 serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
315
316 let merged_into_str = entity.merged_into.map(|u| u.to_string());
317 let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
318 match conn.execute(
319 "INSERT OR REPLACE INTO entities \
320 (id, namespace, kind, entity_type, name, description, properties, tags, \
321 created_at, updated_at, deleted_at, merged_into, merge_event_id) \
322 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
323 rusqlite::params![
324 id_str,
325 &entity.namespace,
326 entity.kind,
327 entity.entity_type,
328 entity.name,
329 entity.description,
330 properties_str,
331 tags_str,
332 entity.created_at,
333 entity.updated_at,
334 entity.deleted_at,
335 merged_into_str,
336 merge_event_id_str,
337 ],
338 ) {
339 Ok(_) => affected += 1,
340 Err(e) => {
341 if first_error.is_empty() {
342 first_error = e.to_string();
343 }
344 failed += 1;
345 }
346 }
347 }
348
349 if let Err(e) = conn.execute_batch("COMMIT") {
350 let _ = conn.execute_batch("ROLLBACK");
351 return Err(e);
352 }
353 Ok(BatchWriteSummary {
354 attempted,
355 affected,
356 failed,
357 first_error,
358 })
359 })
360 .await
361 }
362
363 async fn get_entity(&self, id: Uuid) -> Result<Option<Entity>, StorageError> {
364 let id_str = id.to_string();
365
366 self.with_reader("get_entity", move |conn| {
367 let mut stmt = conn.prepare(
368 "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
369 created_at, updated_at, deleted_at, merged_into, merge_event_id \
370 FROM entities WHERE id = ?1 AND deleted_at IS NULL",
371 )?;
372 let mut rows = stmt.query(rusqlite::params![id_str])?;
373 match rows.next()? {
374 Some(row) => Ok(Some(read_entity(row)?)),
375 None => Ok(None),
376 }
377 })
378 .await
379 }
380
381 async fn delete_entity(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
382 let id_str = id.to_string();
383
384 match mode {
385 DeleteMode::Soft => {
386 self.with_writer("delete_entity_soft", move |conn| {
387 let now = chrono::Utc::now().timestamp_micros();
388 let deleted = conn.execute(
389 "UPDATE entities SET deleted_at = ?1 \
390 WHERE id = ?2 AND deleted_at IS NULL",
391 rusqlite::params![now, id_str],
392 )?;
393 Ok(deleted > 0)
394 })
395 .await
396 }
397 DeleteMode::Hard => {
398 self.with_writer("delete_entity_hard", move |conn| {
399 let deleted = conn.execute(
400 "DELETE FROM entities WHERE id = ?1",
401 rusqlite::params![id_str],
402 )?;
403 Ok(deleted > 0)
404 })
405 .await
406 }
407 }
408 }
409
410 async fn query_entities(
411 &self,
412 namespace: &str,
413 filter: EntityFilter,
414 page: PageRequest,
415 ) -> Result<Page<Entity>, StorageError> {
416 let namespace = namespace.to_string();
417
418 self.with_reader("query_entities", move |conn| {
419 let (count_sql, count_params) = build_entity_where(&namespace, &filter);
420 let total: i64 = {
421 let sql = format!("SELECT COUNT(*) FROM entities{}", count_sql);
422 let mut stmt = conn.prepare(&sql)?;
423 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
424 count_params.iter().map(|p| p.as_ref()).collect();
425 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
426 };
427
428 let (where_sql, mut data_params) = build_entity_where(&namespace, &filter);
429 data_params.push(Box::new(page.limit as i64));
430 data_params.push(Box::new(page.offset as i64));
431
432 let limit_idx = data_params.len() - 1;
433 let offset_idx = data_params.len();
434
435 let data_sql = format!(
436 "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
437 created_at, updated_at, deleted_at, merged_into, merge_event_id \
438 FROM entities{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
439 where_sql, limit_idx, offset_idx,
440 );
441
442 let mut stmt = conn.prepare(&data_sql)?;
443 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
444 data_params.iter().map(|p| p.as_ref()).collect();
445 let rows = stmt.query_map(param_refs.as_slice(), read_entity)?;
446
447 let mut items = Vec::new();
448 for row in rows {
449 items.push(row?);
450 }
451
452 Ok(Page {
453 items,
454 total: Some(total as u64),
455 })
456 })
457 .await
458 }
459
460 async fn count_entities(
461 &self,
462 namespace: &str,
463 filter: EntityFilter,
464 ) -> Result<u64, StorageError> {
465 let namespace = namespace.to_string();
466
467 self.with_reader("count_entities", move |conn| {
468 let (where_sql, params) = build_entity_where(&namespace, &filter);
469 let sql = format!("SELECT COUNT(*) FROM entities{}", where_sql);
470 let mut stmt = conn.prepare(&sql)?;
471 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
472 params.iter().map(|p| p.as_ref()).collect();
473 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
474 Ok(count as u64)
475 })
476 .await
477 }
478}
479
/// Schema for the `entities` table and its secondary indexes.
///
/// Every statement uses `IF NOT EXISTS`, so executing the batch against an
/// already-initialised database is a no-op. The `\` line continuations drop
/// the newline and leading indentation, so the SQL is a single line.
///
/// Column notes:
/// - `id` (UUID text) is the sole primary key. Re-upserting an id under a
///   different `namespace` replaces the row (see `INSERT OR REPLACE` in the
///   upsert paths).
/// - `properties` holds serialized JSON (nullable). `tags` holds a JSON array
///   of strings and defaults to `'[]'`. Tag filters query it via `json_each`.
/// - `deleted_at` non-NULL marks a soft delete. Soft delete writes
///   `chrono::Utc::now().timestamp_micros()`; `created_at`/`updated_at` are
///   presumably also microseconds since the epoch — TODO confirm with callers.
///
/// NOTE(review): `idx_entities_namespace` is a left prefix of
/// `idx_entities_kind (namespace, kind)`, which is itself a prefix of
/// `idx_entities_kind_entity_type`; the shorter indexes look redundant.
const ENTITIES_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS entities (\
        id TEXT PRIMARY KEY,\
        namespace TEXT NOT NULL,\
        kind TEXT NOT NULL,\
        entity_type TEXT,\
        name TEXT NOT NULL,\
        description TEXT,\
        properties TEXT,\
        tags TEXT NOT NULL DEFAULT '[]',\
        created_at INTEGER NOT NULL,\
        updated_at INTEGER NOT NULL,\
        deleted_at INTEGER,\
        merged_into TEXT,\
        merge_event_id TEXT\
    );\
    CREATE INDEX IF NOT EXISTS idx_entities_namespace ON entities(namespace);\
    CREATE INDEX IF NOT EXISTS idx_entities_kind ON entities(namespace, kind);\
    CREATE INDEX IF NOT EXISTS idx_entities_kind_entity_type ON entities(namespace, kind, entity_type);\
    CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(namespace, name);\
    CREATE INDEX IF NOT EXISTS idx_entities_created ON entities(created_at DESC);\
    CREATE INDEX IF NOT EXISTS idx_entities_merged_into ON entities(namespace, merged_into);\
";
507
508pub(crate) fn ensure_entities_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
509 conn.execute_batch(ENTITIES_DDL)
510}
511
// Unit tests for `SqlEntityStore`, run against in-memory pools (`path: None`)
// with `is_file_backed = false`, so reads go through the pooled reader.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// Builds an in-memory pool and creates the `entities` schema through its
    /// writer connection.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            // Inner scope so the writer guard is dropped before `pool` is returned.
            let writer = pool.writer().unwrap();
            writer.conn().execute_batch(ENTITIES_DDL).unwrap();
        }
        pool
    }

    /// Store over a fresh in-memory pool.
    fn setup_memory_store() -> SqlEntityStore {
        SqlEntityStore::new(setup_pool(), false)
    }

    /// Same as `setup_memory_store`; the namespace argument is ignored and only
    /// documents which namespace the calling test uses.
    fn setup_memory_store_ns(_ns: &str) -> SqlEntityStore {
        SqlEntityStore::new(setup_pool(), false)
    }

    /// Minimal live entity with a fresh random id and `created_at` /
    /// `updated_at` set to the current time in microseconds.
    fn make_entity(namespace: &str, kind: &str, name: &str) -> Entity {
        let now = chrono::Utc::now().timestamp_micros();
        Entity {
            id: Uuid::new_v4(),
            namespace: namespace.to_string(),
            kind: kind.to_string(),
            entity_type: None,
            name: name.to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
            merged_into: None,
            merge_event_id: None,
        }
    }

    #[tokio::test]
    async fn test_upsert_and_get_entity() {
        let store = setup_memory_store();

        let entity = make_entity("default", "concept", "LoRA");
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.name, "LoRA");
        assert_eq!(fetched.kind, "concept");
    }

    // Description, JSON properties and tags must survive the round trip.
    #[tokio::test]
    async fn test_upsert_with_builder() {
        let store = setup_memory_store();

        let props = serde_json::json!({"domain": "fine-tuning", "type": "technique"});
        let entity = Entity::new("default", "concept", "QLoRA")
            .with_description("Quantized LoRA")
            .with_properties(props.clone())
            .with_tags(vec!["fine-tuning".to_string(), "quantization".to_string()]);
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap().unwrap();
        assert_eq!(fetched.description.as_deref(), Some("Quantized LoRA"));
        assert_eq!(fetched.properties, Some(props));
        assert_eq!(fetched.tags, vec!["fine-tuning", "quantization"]);
    }

    // A soft-deleted entity is hidden from `get_entity`.
    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", "concept", "to-delete");
        let id = entity.id;
        store.upsert_entity(entity).await.unwrap();

        let deleted = store.delete_entity(id, DeleteMode::Soft).await.unwrap();
        assert!(deleted);

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", "concept", "to-hard-delete");
        let id = entity.id;
        store.upsert_entity(entity).await.unwrap();

        let deleted = store.delete_entity(id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_none());
    }

    // Unfiltered listing returns everything in the namespace; a `kinds`
    // filter narrows it.
    #[tokio::test]
    async fn test_query_entities_basic() {
        let store = setup_memory_store_ns("ns1");

        for name in &["Alpha", "Beta", "Gamma"] {
            store
                .upsert_entity(make_entity("ns1", "concept", name))
                .await
                .unwrap();
        }
        store
            .upsert_entity(make_entity("ns1", "document", "Paper1"))
            .await
            .unwrap();

        let page = store
            .query_entities(
                "ns1",
                EntityFilter::default(),
                PageRequest {
                    offset: 0,
                    limit: 10,
                },
            )
            .await
            .unwrap();
        assert_eq!(page.items.len(), 4);
        assert_eq!(page.total, Some(4));

        // Only the three "concept" entities should match.
        let concepts = store
            .query_entities(
                "ns1",
                EntityFilter {
                    kinds: vec!["concept".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(concepts.items.len(), 3);
    }

    #[tokio::test]
    async fn test_query_by_name_prefix() {
        let store = setup_memory_store_ns("ns1");

        // "Alpha" and "AlphaGo" share the prefix; "Beta" does not.
        for &name in &["Alpha", "AlphaGo", "Beta"] {
            store
                .upsert_entity(make_entity("ns1", "concept", name))
                .await
                .unwrap();
        }

        let result = store
            .query_entities(
                "ns1",
                EntityFilter {
                    name_prefix: Some("Alpha".to_string()),
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"Alpha"), "Alpha not found in {names:?}");
        assert!(names.contains(&"AlphaGo"), "AlphaGo not found in {names:?}");
        assert!(!names.contains(&"Beta"));
    }

    #[tokio::test]
    async fn test_count_entities() {
        let store = setup_memory_store_ns("ns1");

        for _ in 0..5 {
            store
                .upsert_entity(make_entity("ns1", "concept", "X"))
                .await
                .unwrap();
        }

        let count = store
            .count_entities("ns1", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 5);

        // A namespace with no rows counts as zero rather than erroring.
        let count_other = store
            .count_entities("ns2", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_other, 0);
    }

    #[tokio::test]
    async fn test_batch_upsert() {
        let store = setup_memory_store_ns("batch_ns");

        let entities: Vec<Entity> = (0..10)
            .map(|i| make_entity("batch_ns", "concept", &format!("entity_{i}")))
            .collect();

        let summary = store.upsert_entities(entities).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        let count = store
            .count_entities("batch_ns", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 10);
    }

    // Entities in one namespace must not be visible from another, in both
    // counts and listings, even when both share one pool.
    #[tokio::test]
    async fn test_namespace_isolation() {
        let pool = setup_pool();
        let store = SqlEntityStore::new(Arc::clone(&pool), false);

        store
            .upsert_entity(make_entity("ns_a", "concept", "EntityA"))
            .await
            .unwrap();
        store
            .upsert_entity(make_entity("ns_b", "concept", "EntityB"))
            .await
            .unwrap();

        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();

        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);

        let page_a = store
            .query_entities("ns_a", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items[0].name, "EntityA");

        let page_b = store
            .query_entities("ns_b", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items[0].name, "EntityB");
    }

    // `tags_any` matches an entity if at least one of its tags is listed.
    #[tokio::test]
    async fn test_query_by_tags() {
        let store = setup_memory_store_ns("tags_ns");

        let mut e1 = make_entity("tags_ns", "concept", "Tagged1");
        e1.tags = vec!["rust".to_string(), "systems".to_string()];
        let mut e2 = make_entity("tags_ns", "concept", "Tagged2");
        e2.tags = vec!["python".to_string(), "ml".to_string()];
        let mut e3 = make_entity("tags_ns", "concept", "Tagged3");
        e3.tags = vec!["rust".to_string(), "ml".to_string()];

        store.upsert_entity(e1).await.unwrap();
        store.upsert_entity(e2).await.unwrap();
        store.upsert_entity(e3).await.unwrap();

        // Single tag: "rust" is on Tagged1 and Tagged3.
        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["rust".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"Tagged1"));
        assert!(names.contains(&"Tagged3"));
        assert!(!names.contains(&"Tagged2"));

        // Single tag: "ml" is on Tagged2 and Tagged3.
        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["ml".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);

        // Two tags are ORed: together they cover all three entities.
        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["rust".to_string(), "python".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 3);
    }

    #[tokio::test]
    async fn test_query_by_ids() {
        let store = setup_memory_store_ns("ns1");

        let e1 = make_entity("ns1", "concept", "E1");
        let e2 = make_entity("ns1", "concept", "E2");
        let e3 = make_entity("ns1", "concept", "E3");
        let ids = vec![e1.id, e3.id];

        store.upsert_entity(e1).await.unwrap();
        store.upsert_entity(e2).await.unwrap();
        store.upsert_entity(e3).await.unwrap();

        let result = store
            .query_entities(
                "ns1",
                EntityFilter {
                    ids,
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"E1"));
        assert!(names.contains(&"E3"));
        assert!(!names.contains(&"E2"));
    }

    #[tokio::test]
    async fn test_entity_type_roundtrip() {
        let store = setup_memory_store();

        let entity =
            Entity::new("default", "document", "ResearchPaper").with_entity_type(Some("paper"));
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap().unwrap();
        assert_eq!(fetched.entity_type, Some("paper".to_string()));
        assert_eq!(fetched.kind, "document");
        assert_eq!(fetched.name, "ResearchPaper");
    }

    // Filtering on `entity_types` excludes rows whose `entity_type` is NULL.
    #[tokio::test]
    async fn test_query_by_kind_and_entity_type() {
        let store = setup_memory_store_ns("et_ns");

        let typed =
            Entity::new("et_ns", "person", "Researcher").with_entity_type(Some("researcher"));
        let untyped = make_entity("et_ns", "person", "Generic");

        store.upsert_entity(typed).await.unwrap();
        store.upsert_entity(untyped).await.unwrap();

        let result = store
            .query_entities(
                "et_ns",
                EntityFilter {
                    entity_types: vec!["researcher".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.items[0].name, "Researcher");
        assert_eq!(result.items[0].entity_type, Some("researcher".to_string()));
    }

    // `id` is the table's only primary key, so upserting an existing id under a
    // different namespace replaces the row (moving it) instead of adding a
    // second row.
    #[tokio::test]
    async fn test_same_id_upsert_replaces_row() {
        let pool = setup_pool();
        let store = SqlEntityStore::new(Arc::clone(&pool), false);

        let shared_id = Uuid::new_v4();
        let now = chrono::Utc::now().timestamp_micros();

        let entity_a = Entity {
            id: shared_id,
            namespace: "ns_a".to_string(),
            kind: "concept".to_string(),
            entity_type: None,
            name: "SharedInA".to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
            merged_into: None,
            merge_event_id: None,
        };
        store.upsert_entity(entity_a).await.unwrap();

        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(fetched.namespace, "ns_a");
        assert_eq!(fetched.name, "SharedInA");

        // Same id, different namespace and name.
        let entity_b = Entity {
            id: shared_id,
            namespace: "ns_b".to_string(),
            kind: "concept".to_string(),
            entity_type: None,
            name: "SharedInB".to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
            merged_into: None,
            merge_event_id: None,
        };
        store.upsert_entity(entity_b).await.unwrap();

        // The second upsert overwrote the first row.
        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(fetched.namespace, "ns_b");
        assert_eq!(fetched.name, "SharedInB");

        // Nothing remains in the original namespace.
        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_a, 0);
        assert_eq!(count_b, 1);
    }
}