1use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::entity::{Entity, EntityFilter};
9use khive_storage::error::StorageError;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::EntityStore;
12use khive_storage::StorageCapability;
13use khive_types::EntityKind;
14
15use crate::error::SqliteError;
16use crate::pool::ConnectionPool;
17
18fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
19 StorageError::driver(StorageCapability::Entities, op, e)
20}
21
22fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
23 StorageError::driver(StorageCapability::Entities, op, e)
24}
25
/// SQLite-backed implementation of [`EntityStore`].
pub struct SqlEntityStore {
    /// Pool that supplies the writer connection and, when the store is not file-backed,
    /// the reader connections.
    pool: Arc<ConnectionPool>,
    /// When `true`, reads open a standalone read-only connection to the pool's database
    /// file instead of borrowing a pooled reader.
    is_file_backed: bool,
}
34
35impl SqlEntityStore {
36 pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
38 Self {
39 pool,
40 is_file_backed,
41 }
42 }
43
44 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
45 let config = self.pool.config();
46 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
47 operation: "entity_reader".into(),
48 message: "in-memory databases do not support standalone connections".into(),
49 })?;
50
51 let conn = rusqlite::Connection::open_with_flags(
52 path,
53 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
54 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
55 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
56 )
57 .map_err(|e| map_err(e, "open_entity_reader"))?;
58
59 conn.busy_timeout(config.busy_timeout)
60 .map_err(|e| map_err(e, "open_entity_reader"))?;
61 conn.pragma_update(None, "foreign_keys", "ON")
62 .map_err(|e| map_err(e, "open_entity_reader"))?;
63 conn.pragma_update(None, "synchronous", "NORMAL")
64 .map_err(|e| map_err(e, "open_entity_reader"))?;
65
66 Ok(conn)
67 }
68
69 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
70 where
71 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
72 R: Send + 'static,
73 {
74 let pool = Arc::clone(&self.pool);
75 tokio::task::spawn_blocking(move || {
76 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
77 f(guard.conn()).map_err(|e| map_err(e, op))
78 })
79 .await
80 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
81 }
82
83 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
84 where
85 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
86 R: Send + 'static,
87 {
88 if self.is_file_backed {
89 let conn = self.open_standalone_reader()?;
90 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
91 .await
92 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
93 } else {
94 let pool = Arc::clone(&self.pool);
95 tokio::task::spawn_blocking(move || {
96 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
97 f(guard.conn()).map_err(|e| map_err(e, op))
98 })
99 .await
100 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
101 }
102 }
103}
104
105fn read_entity(row: &rusqlite::Row<'_>) -> Result<Entity, rusqlite::Error> {
110 let id_str: String = row.get(0)?;
111 let namespace: String = row.get(1)?;
112 let kind: String = row.get(2)?;
113 let name: String = row.get(3)?;
114 let description: Option<String> = row.get(4)?;
115 let properties_str: Option<String> = row.get(5)?;
116 let tags_str: String = row.get(6)?;
117 let created_at: i64 = row.get(7)?;
118 let updated_at: i64 = row.get(8)?;
119 let deleted_at: Option<i64> = row.get(9)?;
120
121 let id = parse_uuid(&id_str)?;
122
123 let parsed_kind = kind.parse::<EntityKind>().map_err(|e| {
124 rusqlite::Error::FromSqlConversionFailure(
125 2,
126 rusqlite::types::Type::Text,
127 Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
128 )
129 })?;
130
131 let properties = properties_str
132 .map(|s| {
133 serde_json::from_str(&s).map_err(|e| {
134 rusqlite::Error::FromSqlConversionFailure(
135 5,
136 rusqlite::types::Type::Text,
137 Box::new(e),
138 )
139 })
140 })
141 .transpose()?;
142
143 let tags: Vec<String> = serde_json::from_str(&tags_str).map_err(|e| {
144 rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(e))
145 })?;
146
147 Ok(Entity {
148 id,
149 namespace,
150 kind: parsed_kind,
151 name,
152 description,
153 properties,
154 tags,
155 created_at,
156 updated_at,
157 deleted_at,
158 })
159}
160
161fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
162 Uuid::parse_str(s).map_err(|e| {
163 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
164 })
165}
166
167fn build_entity_where(
168 namespace: &str,
169 filter: &EntityFilter,
170) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
171 let mut conditions: Vec<String> = vec![
172 "namespace = ?1".to_string(),
173 "deleted_at IS NULL".to_string(),
174 ];
175 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
176
177 if !filter.ids.is_empty() {
178 let placeholders: Vec<String> = filter
179 .ids
180 .iter()
181 .map(|id| {
182 params.push(Box::new(id.to_string()));
183 format!("?{}", params.len())
184 })
185 .collect();
186 conditions.push(format!("id IN ({})", placeholders.join(", ")));
187 }
188
189 if !filter.kinds.is_empty() {
190 let placeholders: Vec<String> = filter
191 .kinds
192 .iter()
193 .map(|k| {
194 params.push(Box::new(k.to_string()));
195 format!("?{}", params.len())
196 })
197 .collect();
198 conditions.push(format!("kind IN ({})", placeholders.join(", ")));
199 }
200
201 if let Some(ref prefix) = filter.name_prefix {
202 params.push(Box::new(format!("{}%", prefix)));
203 conditions.push(format!("name LIKE ?{}", params.len()));
204 }
205
206 if !filter.tags_any.is_empty() {
207 let placeholders: Vec<String> = filter
208 .tags_any
209 .iter()
210 .map(|t| {
211 params.push(Box::new(t.clone()));
212 format!("?{}", params.len())
213 })
214 .collect();
215 conditions.push(format!(
216 "EXISTS (SELECT 1 FROM json_each(tags) WHERE json_each.value IN ({}))",
217 placeholders.join(", ")
218 ));
219 }
220
221 let clause = format!(" WHERE {}", conditions.join(" AND "));
222 (clause, params)
223}
224
225#[async_trait]
230impl EntityStore for SqlEntityStore {
231 async fn upsert_entity(&self, entity: Entity) -> Result<(), StorageError> {
232 let namespace = entity.namespace.clone();
233 let id_str = entity.id.to_string();
234 let properties_str = entity
235 .properties
236 .as_ref()
237 .map(|v| serde_json::to_string(v).unwrap_or_default());
238 let tags_str = serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
239
240 self.with_writer("upsert_entity", move |conn| {
241 conn.execute(
242 "INSERT OR REPLACE INTO entities \
243 (id, namespace, kind, name, description, properties, tags, \
244 created_at, updated_at, deleted_at) \
245 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
246 rusqlite::params![
247 id_str,
248 namespace,
249 entity.kind.to_string(),
250 entity.name,
251 entity.description,
252 properties_str,
253 tags_str,
254 entity.created_at,
255 entity.updated_at,
256 entity.deleted_at,
257 ],
258 )?;
259 Ok(())
260 })
261 .await
262 }
263
264 async fn upsert_entities(
265 &self,
266 entities: Vec<Entity>,
267 ) -> Result<BatchWriteSummary, StorageError> {
268 let attempted = entities.len() as u64;
269
270 self.with_writer("upsert_entities", move |conn| {
271 conn.execute_batch("BEGIN IMMEDIATE")?;
272 let mut affected = 0u64;
273 let mut failed = 0u64;
274 let mut first_error = String::new();
275
276 for entity in &entities {
277 let id_str = entity.id.to_string();
278 let properties_str = entity
279 .properties
280 .as_ref()
281 .map(|v| serde_json::to_string(v).unwrap_or_default());
282 let tags_str =
283 serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
284
285 match conn.execute(
286 "INSERT OR REPLACE INTO entities \
287 (id, namespace, kind, name, description, properties, tags, \
288 created_at, updated_at, deleted_at) \
289 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
290 rusqlite::params![
291 id_str,
292 &entity.namespace,
293 entity.kind.to_string(),
294 entity.name,
295 entity.description,
296 properties_str,
297 tags_str,
298 entity.created_at,
299 entity.updated_at,
300 entity.deleted_at,
301 ],
302 ) {
303 Ok(_) => affected += 1,
304 Err(e) => {
305 if first_error.is_empty() {
306 first_error = e.to_string();
307 }
308 failed += 1;
309 }
310 }
311 }
312
313 if let Err(e) = conn.execute_batch("COMMIT") {
314 let _ = conn.execute_batch("ROLLBACK");
315 return Err(e);
316 }
317 Ok(BatchWriteSummary {
318 attempted,
319 affected,
320 failed,
321 first_error,
322 })
323 })
324 .await
325 }
326
327 async fn get_entity(&self, id: Uuid) -> Result<Option<Entity>, StorageError> {
328 let id_str = id.to_string();
329
330 self.with_reader("get_entity", move |conn| {
331 let mut stmt = conn.prepare(
332 "SELECT id, namespace, kind, name, description, properties, tags, \
333 created_at, updated_at, deleted_at \
334 FROM entities WHERE id = ?1 AND deleted_at IS NULL",
335 )?;
336 let mut rows = stmt.query(rusqlite::params![id_str])?;
337 match rows.next()? {
338 Some(row) => Ok(Some(read_entity(row)?)),
339 None => Ok(None),
340 }
341 })
342 .await
343 }
344
345 async fn delete_entity(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
346 let id_str = id.to_string();
347
348 match mode {
349 DeleteMode::Soft => {
350 self.with_writer("delete_entity_soft", move |conn| {
351 let now = chrono::Utc::now().timestamp_micros();
352 let deleted = conn.execute(
353 "UPDATE entities SET deleted_at = ?1 \
354 WHERE id = ?2 AND deleted_at IS NULL",
355 rusqlite::params![now, id_str],
356 )?;
357 Ok(deleted > 0)
358 })
359 .await
360 }
361 DeleteMode::Hard => {
362 self.with_writer("delete_entity_hard", move |conn| {
363 let deleted = conn.execute(
364 "DELETE FROM entities WHERE id = ?1",
365 rusqlite::params![id_str],
366 )?;
367 Ok(deleted > 0)
368 })
369 .await
370 }
371 }
372 }
373
374 async fn query_entities(
375 &self,
376 namespace: &str,
377 filter: EntityFilter,
378 page: PageRequest,
379 ) -> Result<Page<Entity>, StorageError> {
380 let namespace = namespace.to_string();
381
382 self.with_reader("query_entities", move |conn| {
383 let (count_sql, count_params) = build_entity_where(&namespace, &filter);
384 let total: i64 = {
385 let sql = format!("SELECT COUNT(*) FROM entities{}", count_sql);
386 let mut stmt = conn.prepare(&sql)?;
387 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
388 count_params.iter().map(|p| p.as_ref()).collect();
389 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
390 };
391
392 let (where_sql, mut data_params) = build_entity_where(&namespace, &filter);
393 data_params.push(Box::new(page.limit as i64));
394 data_params.push(Box::new(page.offset as i64));
395
396 let limit_idx = data_params.len() - 1;
397 let offset_idx = data_params.len();
398
399 let data_sql = format!(
400 "SELECT id, namespace, kind, name, description, properties, tags, \
401 created_at, updated_at, deleted_at \
402 FROM entities{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
403 where_sql, limit_idx, offset_idx,
404 );
405
406 let mut stmt = conn.prepare(&data_sql)?;
407 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
408 data_params.iter().map(|p| p.as_ref()).collect();
409 let rows = stmt.query_map(param_refs.as_slice(), read_entity)?;
410
411 let mut items = Vec::new();
412 for row in rows {
413 items.push(row?);
414 }
415
416 Ok(Page {
417 items,
418 total: Some(total as u64),
419 })
420 })
421 .await
422 }
423
424 async fn count_entities(
425 &self,
426 namespace: &str,
427 filter: EntityFilter,
428 ) -> Result<u64, StorageError> {
429 let namespace = namespace.to_string();
430
431 self.with_reader("count_entities", move |conn| {
432 let (where_sql, params) = build_entity_where(&namespace, &filter);
433 let sql = format!("SELECT COUNT(*) FROM entities{}", where_sql);
434 let mut stmt = conn.prepare(&sql)?;
435 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
436 params.iter().map(|p| p.as_ref()).collect();
437 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
438 Ok(count as u64)
439 })
440 .await
441 }
442}
443
/// Schema for the `entities` table and its secondary indexes.
///
/// Every statement uses `IF NOT EXISTS`, so the batch is safe to re-run. The text is one line
/// of `;`-separated statements, suitable for `execute_batch`.
const ENTITIES_DDL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS entities (",
    "id TEXT PRIMARY KEY,",
    "namespace TEXT NOT NULL,",
    "kind TEXT NOT NULL,",
    "name TEXT NOT NULL,",
    "description TEXT,",
    "properties TEXT,",
    "tags TEXT NOT NULL DEFAULT '[]',",
    "created_at INTEGER NOT NULL,",
    "updated_at INTEGER NOT NULL,",
    "deleted_at INTEGER",
    ");",
    "CREATE INDEX IF NOT EXISTS idx_entities_namespace ON entities(namespace);",
    "CREATE INDEX IF NOT EXISTS idx_entities_kind ON entities(namespace, kind);",
    "CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(namespace, name);",
    "CREATE INDEX IF NOT EXISTS idx_entities_created ON entities(created_at DESC);",
);
466
467pub(crate) fn ensure_entities_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
468 conn.execute_batch(ENTITIES_DDL)
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// Builds a fresh in-memory pool with the `entities` schema applied.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            let writer = pool.writer().unwrap();
            ensure_entities_schema(writer.conn()).unwrap();
        }
        pool
    }

    /// Store over a fresh in-memory pool.
    ///
    /// Stores are not scoped to a namespace. Namespaces are chosen per entity and per query.
    fn setup_memory_store() -> SqlEntityStore {
        SqlEntityStore::new(setup_pool(), false)
    }

    fn make_entity(namespace: &str, kind: EntityKind, name: &str) -> Entity {
        let now = chrono::Utc::now().timestamp_micros();
        Entity {
            id: Uuid::new_v4(),
            namespace: namespace.to_string(),
            kind,
            name: name.to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
        }
    }

    #[tokio::test]
    async fn test_upsert_and_get_entity() {
        let store = setup_memory_store();

        let entity = make_entity("default", EntityKind::Concept, "LoRA");
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.id, id);
        assert_eq!(fetched.name, "LoRA");
        assert_eq!(fetched.kind, EntityKind::Concept);
    }

    #[tokio::test]
    async fn test_get_missing_entity_returns_none() {
        let store = setup_memory_store();
        assert!(store.get_entity(Uuid::new_v4()).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_upsert_with_builder() {
        let store = setup_memory_store();

        let props = serde_json::json!({"domain": "fine-tuning", "type": "technique"});
        let entity = Entity::new("default", EntityKind::Concept, "QLoRA")
            .with_description("Quantized LoRA")
            .with_properties(props.clone())
            .with_tags(vec!["fine-tuning".to_string(), "quantization".to_string()]);
        let id = entity.id;

        store.upsert_entity(entity).await.unwrap();

        let fetched = store.get_entity(id).await.unwrap().unwrap();
        assert_eq!(fetched.description.as_deref(), Some("Quantized LoRA"));
        assert_eq!(fetched.properties, Some(props));
        assert_eq!(fetched.tags, vec!["fine-tuning", "quantization"]);
    }

    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", EntityKind::Concept, "to-delete");
        let id = entity.id;
        store.upsert_entity(entity).await.unwrap();

        let deleted = store.delete_entity(id, DeleteMode::Soft).await.unwrap();
        assert!(deleted);

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_soft_delete_hides_entity_and_is_not_repeated() {
        let store = setup_memory_store();

        let keep = make_entity("sd_ns", EntityKind::Concept, "keep");
        let gone = make_entity("sd_ns", EntityKind::Concept, "gone");
        let gone_id = gone.id;
        store.upsert_entity(keep).await.unwrap();
        store.upsert_entity(gone).await.unwrap();

        assert!(store.delete_entity(gone_id, DeleteMode::Soft).await.unwrap());
        // The row is no longer live, so a second soft delete updates nothing.
        assert!(!store.delete_entity(gone_id, DeleteMode::Soft).await.unwrap());

        let count = store
            .count_entities("sd_ns", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 1);

        let page = store
            .query_entities("sd_ns", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page.total, Some(1));
        let names: Vec<&str> = page.items.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, vec!["keep"]);
    }

    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", EntityKind::Concept, "to-hard-delete");
        let id = entity.id;
        store.upsert_entity(entity).await.unwrap();

        let deleted = store.delete_entity(id, DeleteMode::Hard).await.unwrap();
        assert!(deleted);

        let fetched = store.get_entity(id).await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn test_delete_missing_entity_returns_false() {
        let store = setup_memory_store();
        let missing = Uuid::new_v4();

        assert!(!store.delete_entity(missing, DeleteMode::Soft).await.unwrap());
        assert!(!store.delete_entity(missing, DeleteMode::Hard).await.unwrap());
    }

    #[tokio::test]
    async fn test_query_entities_basic() {
        let store = setup_memory_store();

        for name in &["Alpha", "Beta", "Gamma"] {
            store
                .upsert_entity(make_entity("ns1", EntityKind::Concept, name))
                .await
                .unwrap();
        }
        store
            .upsert_entity(make_entity("ns1", EntityKind::Document, "Paper1"))
            .await
            .unwrap();

        let page = store
            .query_entities(
                "ns1",
                EntityFilter::default(),
                PageRequest {
                    offset: 0,
                    limit: 10,
                },
            )
            .await
            .unwrap();
        assert_eq!(page.items.len(), 4);
        assert_eq!(page.total, Some(4));

        let concepts = store
            .query_entities(
                "ns1",
                EntityFilter {
                    kinds: vec![EntityKind::Concept],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(concepts.items.len(), 3);
    }

    #[tokio::test]
    async fn test_query_pagination_pages_in_created_at_order() {
        let store = setup_memory_store();

        // Explicit, distinct timestamps make the expected order independent of wall-clock time.
        for i in 1..=5_i64 {
            let mut entity = make_entity("page_ns", EntityKind::Concept, &format!("e{i}"));
            entity.created_at = i;
            entity.updated_at = i;
            store.upsert_entity(entity).await.unwrap();
        }

        let mut seen = Vec::new();
        for (offset, expected_len) in [(0, 2), (2, 2), (4, 1)] {
            let page = store
                .query_entities(
                    "page_ns",
                    EntityFilter::default(),
                    PageRequest { offset, limit: 2 },
                )
                .await
                .unwrap();
            assert_eq!(page.total, Some(5));
            assert_eq!(page.items.len(), expected_len);
            seen.extend(page.items.into_iter().map(|e| e.name));
        }
        assert_eq!(seen, vec!["e5", "e4", "e3", "e2", "e1"]);
    }

    #[tokio::test]
    async fn test_query_by_name_prefix() {
        let store = setup_memory_store();

        for &name in &["Alpha", "AlphaGo", "Beta"] {
            store
                .upsert_entity(make_entity("ns1", EntityKind::Concept, name))
                .await
                .unwrap();
        }

        let result = store
            .query_entities(
                "ns1",
                EntityFilter {
                    name_prefix: Some("Alpha".to_string()),
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"Alpha"), "Alpha not found in {names:?}");
        assert!(names.contains(&"AlphaGo"), "AlphaGo not found in {names:?}");
        assert!(!names.contains(&"Beta"));
    }

    #[tokio::test]
    async fn test_count_entities() {
        let store = setup_memory_store();

        for _ in 0..5 {
            store
                .upsert_entity(make_entity("ns1", EntityKind::Concept, "X"))
                .await
                .unwrap();
        }

        let count = store
            .count_entities("ns1", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 5);

        let count_other = store
            .count_entities("ns2", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_other, 0);
    }

    #[tokio::test]
    async fn test_batch_upsert() {
        let store = setup_memory_store();

        let entities: Vec<Entity> = (0..10)
            .map(|i| make_entity("batch_ns", EntityKind::Concept, &format!("entity_{i}")))
            .collect();

        let summary = store.upsert_entities(entities).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        let count = store
            .count_entities("batch_ns", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 10);
    }

    #[tokio::test]
    async fn test_batch_upsert_empty() {
        let store = setup_memory_store();

        let summary = store.upsert_entities(Vec::new()).await.unwrap();
        assert_eq!(summary.attempted, 0);
        assert_eq!(summary.affected, 0);
        assert_eq!(summary.failed, 0);
        assert!(summary.first_error.is_empty());
    }

    #[tokio::test]
    async fn test_namespace_isolation() {
        let pool = setup_pool();
        let store = SqlEntityStore::new(Arc::clone(&pool), false);

        store
            .upsert_entity(make_entity("ns_a", EntityKind::Concept, "EntityA"))
            .await
            .unwrap();
        store
            .upsert_entity(make_entity("ns_b", EntityKind::Concept, "EntityB"))
            .await
            .unwrap();

        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();

        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);

        let page_a = store
            .query_entities("ns_a", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items[0].name, "EntityA");

        let page_b = store
            .query_entities("ns_b", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items[0].name, "EntityB");
    }

    #[tokio::test]
    async fn test_query_by_tags() {
        let store = setup_memory_store();

        let mut e1 = make_entity("tags_ns", EntityKind::Concept, "Tagged1");
        e1.tags = vec!["rust".to_string(), "systems".to_string()];
        let mut e2 = make_entity("tags_ns", EntityKind::Concept, "Tagged2");
        e2.tags = vec!["python".to_string(), "ml".to_string()];
        let mut e3 = make_entity("tags_ns", EntityKind::Concept, "Tagged3");
        e3.tags = vec!["rust".to_string(), "ml".to_string()];

        store.upsert_entity(e1).await.unwrap();
        store.upsert_entity(e2).await.unwrap();
        store.upsert_entity(e3).await.unwrap();

        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["rust".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"Tagged1"));
        assert!(names.contains(&"Tagged3"));
        assert!(!names.contains(&"Tagged2"));

        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["ml".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);

        let result = store
            .query_entities(
                "tags_ns",
                EntityFilter {
                    tags_any: vec!["rust".to_string(), "python".to_string()],
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 3);
    }

    #[tokio::test]
    async fn test_query_by_ids() {
        let store = setup_memory_store();

        let e1 = make_entity("ns1", EntityKind::Concept, "E1");
        let e2 = make_entity("ns1", EntityKind::Concept, "E2");
        let e3 = make_entity("ns1", EntityKind::Concept, "E3");
        let ids = vec![e1.id, e3.id];

        store.upsert_entity(e1).await.unwrap();
        store.upsert_entity(e2).await.unwrap();
        store.upsert_entity(e3).await.unwrap();

        let result = store
            .query_entities(
                "ns1",
                EntityFilter {
                    ids,
                    ..Default::default()
                },
                PageRequest::default(),
            )
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names: Vec<&str> = result.items.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"E1"));
        assert!(names.contains(&"E3"));
        assert!(!names.contains(&"E2"));
    }

    #[tokio::test]
    async fn test_same_id_upsert_replaces_row() {
        let pool = setup_pool();
        let store = SqlEntityStore::new(Arc::clone(&pool), false);

        let shared_id = Uuid::new_v4();
        let now = chrono::Utc::now().timestamp_micros();

        let entity_a = Entity {
            id: shared_id,
            namespace: "ns_a".to_string(),
            kind: EntityKind::Concept,
            name: "SharedInA".to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
        };
        store.upsert_entity(entity_a).await.unwrap();

        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(fetched.namespace, "ns_a");
        assert_eq!(fetched.name, "SharedInA");

        // The id is the primary key, so upserting the same id into another namespace moves
        // the row rather than creating a second one.
        let entity_b = Entity {
            id: shared_id,
            namespace: "ns_b".to_string(),
            kind: EntityKind::Concept,
            name: "SharedInB".to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
        };
        store.upsert_entity(entity_b).await.unwrap();

        let fetched = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(fetched.namespace, "ns_b");
        assert_eq!(fetched.name, "SharedInB");

        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_a, 0);
        assert_eq!(count_b, 1);
    }
}