1use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::entity::{Entity, EntityFilter};
9use khive_storage::error::StorageError;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::EntityStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18 StorageError::driver(StorageCapability::Entities, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22 StorageError::driver(StorageCapability::Entities, op, e)
23}
24
/// [`EntityStore`] implementation that runs its SQL through a shared
/// [`ConnectionPool`].
pub struct SqlEntityStore {
    /// Pool that supplies the writer connection and pooled reader connections.
    pool: Arc<ConnectionPool>,
    /// When `true`, reads open a standalone read-only connection to the
    /// pool's configured path instead of borrowing a reader from the pool.
    is_file_backed: bool,
}
33
34impl SqlEntityStore {
35 pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37 Self {
38 pool,
39 is_file_backed,
40 }
41 }
42
43 fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44 let config = self.pool.config();
45 let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46 operation: "entity_reader".into(),
47 message: "in-memory databases do not support standalone connections".into(),
48 })?;
49
50 let conn = rusqlite::Connection::open_with_flags(
51 path,
52 rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53 | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54 | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55 )
56 .map_err(|e| map_err(e, "open_entity_reader"))?;
57
58 conn.busy_timeout(config.busy_timeout)
59 .map_err(|e| map_err(e, "open_entity_reader"))?;
60 conn.pragma_update(None, "foreign_keys", "ON")
61 .map_err(|e| map_err(e, "open_entity_reader"))?;
62 conn.pragma_update(None, "synchronous", "NORMAL")
63 .map_err(|e| map_err(e, "open_entity_reader"))?;
64
65 Ok(conn)
66 }
67
68 async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
69 where
70 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
71 R: Send + 'static,
72 {
73 let pool = Arc::clone(&self.pool);
74 tokio::task::spawn_blocking(move || {
75 let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
76 f(guard.conn()).map_err(|e| map_err(e, op))
77 })
78 .await
79 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
80 }
81
82 async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
83 where
84 F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
85 R: Send + 'static,
86 {
87 if self.is_file_backed {
88 let conn = self.open_standalone_reader()?;
89 tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
90 .await
91 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
92 } else {
93 let pool = Arc::clone(&self.pool);
94 tokio::task::spawn_blocking(move || {
95 let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
96 f(guard.conn()).map_err(|e| map_err(e, op))
97 })
98 .await
99 .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
100 }
101 }
102}
103
104fn read_entity(row: &rusqlite::Row<'_>) -> Result<Entity, rusqlite::Error> {
109 let id_str: String = row.get(0)?;
110 let namespace: String = row.get(1)?;
111 let kind: String = row.get(2)?;
112 let name: String = row.get(3)?;
113 let description: Option<String> = row.get(4)?;
114 let properties_str: Option<String> = row.get(5)?;
115 let tags_str: String = row.get(6)?;
116 let created_at: i64 = row.get(7)?;
117 let updated_at: i64 = row.get(8)?;
118 let deleted_at: Option<i64> = row.get(9)?;
119
120 let id = parse_uuid(&id_str)?;
121
122 let properties = properties_str
123 .map(|s| {
124 serde_json::from_str(&s).map_err(|e| {
125 rusqlite::Error::FromSqlConversionFailure(
126 5,
127 rusqlite::types::Type::Text,
128 Box::new(e),
129 )
130 })
131 })
132 .transpose()?;
133
134 let tags: Vec<String> = serde_json::from_str(&tags_str).map_err(|e| {
135 rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(e))
136 })?;
137
138 Ok(Entity {
139 id,
140 namespace,
141 kind,
142 name,
143 description,
144 properties,
145 tags,
146 created_at,
147 updated_at,
148 deleted_at,
149 })
150}
151
152fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
153 Uuid::parse_str(s).map_err(|e| {
154 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
155 })
156}
157
158fn build_entity_where(
159 namespace: &str,
160 filter: &EntityFilter,
161) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
162 let mut conditions: Vec<String> = vec![
163 "namespace = ?1".to_string(),
164 "deleted_at IS NULL".to_string(),
165 ];
166 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
167
168 if !filter.ids.is_empty() {
169 let placeholders: Vec<String> = filter
170 .ids
171 .iter()
172 .map(|id| {
173 params.push(Box::new(id.to_string()));
174 format!("?{}", params.len())
175 })
176 .collect();
177 conditions.push(format!("id IN ({})", placeholders.join(", ")));
178 }
179
180 if !filter.kinds.is_empty() {
181 let placeholders: Vec<String> = filter
182 .kinds
183 .iter()
184 .map(|k| {
185 params.push(Box::new(k.clone()));
186 format!("?{}", params.len())
187 })
188 .collect();
189 conditions.push(format!("kind IN ({})", placeholders.join(", ")));
190 }
191
192 if let Some(ref prefix) = filter.name_prefix {
193 params.push(Box::new(format!("{}%", prefix)));
194 conditions.push(format!("name LIKE ?{}", params.len()));
195 }
196
197 if !filter.tags_any.is_empty() {
198 let placeholders: Vec<String> = filter
199 .tags_any
200 .iter()
201 .map(|t| {
202 params.push(Box::new(t.clone()));
203 format!("?{}", params.len())
204 })
205 .collect();
206 conditions.push(format!(
207 "EXISTS (SELECT 1 FROM json_each(tags) WHERE json_each.value IN ({}))",
208 placeholders.join(", ")
209 ));
210 }
211
212 let clause = format!(" WHERE {}", conditions.join(" AND "));
213 (clause, params)
214}
215
216#[async_trait]
221impl EntityStore for SqlEntityStore {
222 async fn upsert_entity(&self, entity: Entity) -> Result<(), StorageError> {
223 let namespace = entity.namespace.clone();
224 let id_str = entity.id.to_string();
225 let properties_str = entity
226 .properties
227 .as_ref()
228 .map(|v| serde_json::to_string(v).unwrap_or_default());
229 let tags_str = serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
230
231 self.with_writer("upsert_entity", move |conn| {
232 conn.execute(
233 "INSERT OR REPLACE INTO entities \
234 (id, namespace, kind, name, description, properties, tags, \
235 created_at, updated_at, deleted_at) \
236 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
237 rusqlite::params![
238 id_str,
239 namespace,
240 entity.kind,
241 entity.name,
242 entity.description,
243 properties_str,
244 tags_str,
245 entity.created_at,
246 entity.updated_at,
247 entity.deleted_at,
248 ],
249 )?;
250 Ok(())
251 })
252 .await
253 }
254
255 async fn upsert_entities(
256 &self,
257 entities: Vec<Entity>,
258 ) -> Result<BatchWriteSummary, StorageError> {
259 let attempted = entities.len() as u64;
260
261 self.with_writer("upsert_entities", move |conn| {
262 conn.execute_batch("BEGIN IMMEDIATE")?;
263 let mut affected = 0u64;
264 let mut failed = 0u64;
265 let mut first_error = String::new();
266
267 for entity in &entities {
268 let id_str = entity.id.to_string();
269 let properties_str = entity
270 .properties
271 .as_ref()
272 .map(|v| serde_json::to_string(v).unwrap_or_default());
273 let tags_str =
274 serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
275
276 match conn.execute(
277 "INSERT OR REPLACE INTO entities \
278 (id, namespace, kind, name, description, properties, tags, \
279 created_at, updated_at, deleted_at) \
280 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
281 rusqlite::params![
282 id_str,
283 &entity.namespace,
284 entity.kind,
285 entity.name,
286 entity.description,
287 properties_str,
288 tags_str,
289 entity.created_at,
290 entity.updated_at,
291 entity.deleted_at,
292 ],
293 ) {
294 Ok(_) => affected += 1,
295 Err(e) => {
296 if first_error.is_empty() {
297 first_error = e.to_string();
298 }
299 failed += 1;
300 }
301 }
302 }
303
304 if let Err(e) = conn.execute_batch("COMMIT") {
305 let _ = conn.execute_batch("ROLLBACK");
306 return Err(e);
307 }
308 Ok(BatchWriteSummary {
309 attempted,
310 affected,
311 failed,
312 first_error,
313 })
314 })
315 .await
316 }
317
318 async fn get_entity(&self, id: Uuid) -> Result<Option<Entity>, StorageError> {
319 let id_str = id.to_string();
320
321 self.with_reader("get_entity", move |conn| {
322 let mut stmt = conn.prepare(
323 "SELECT id, namespace, kind, name, description, properties, tags, \
324 created_at, updated_at, deleted_at \
325 FROM entities WHERE id = ?1 AND deleted_at IS NULL",
326 )?;
327 let mut rows = stmt.query(rusqlite::params![id_str])?;
328 match rows.next()? {
329 Some(row) => Ok(Some(read_entity(row)?)),
330 None => Ok(None),
331 }
332 })
333 .await
334 }
335
336 async fn delete_entity(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
337 let id_str = id.to_string();
338
339 match mode {
340 DeleteMode::Soft => {
341 self.with_writer("delete_entity_soft", move |conn| {
342 let now = chrono::Utc::now().timestamp_micros();
343 let deleted = conn.execute(
344 "UPDATE entities SET deleted_at = ?1 \
345 WHERE id = ?2 AND deleted_at IS NULL",
346 rusqlite::params![now, id_str],
347 )?;
348 Ok(deleted > 0)
349 })
350 .await
351 }
352 DeleteMode::Hard => {
353 self.with_writer("delete_entity_hard", move |conn| {
354 let deleted = conn.execute(
355 "DELETE FROM entities WHERE id = ?1",
356 rusqlite::params![id_str],
357 )?;
358 Ok(deleted > 0)
359 })
360 .await
361 }
362 }
363 }
364
365 async fn query_entities(
366 &self,
367 namespace: &str,
368 filter: EntityFilter,
369 page: PageRequest,
370 ) -> Result<Page<Entity>, StorageError> {
371 let namespace = namespace.to_string();
372
373 self.with_reader("query_entities", move |conn| {
374 let (count_sql, count_params) = build_entity_where(&namespace, &filter);
375 let total: i64 = {
376 let sql = format!("SELECT COUNT(*) FROM entities{}", count_sql);
377 let mut stmt = conn.prepare(&sql)?;
378 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
379 count_params.iter().map(|p| p.as_ref()).collect();
380 stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
381 };
382
383 let (where_sql, mut data_params) = build_entity_where(&namespace, &filter);
384 data_params.push(Box::new(page.limit as i64));
385 data_params.push(Box::new(page.offset as i64));
386
387 let limit_idx = data_params.len() - 1;
388 let offset_idx = data_params.len();
389
390 let data_sql = format!(
391 "SELECT id, namespace, kind, name, description, properties, tags, \
392 created_at, updated_at, deleted_at \
393 FROM entities{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
394 where_sql, limit_idx, offset_idx,
395 );
396
397 let mut stmt = conn.prepare(&data_sql)?;
398 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
399 data_params.iter().map(|p| p.as_ref()).collect();
400 let rows = stmt.query_map(param_refs.as_slice(), read_entity)?;
401
402 let mut items = Vec::new();
403 for row in rows {
404 items.push(row?);
405 }
406
407 Ok(Page {
408 items,
409 total: Some(total as u64),
410 })
411 })
412 .await
413 }
414
415 async fn count_entities(
416 &self,
417 namespace: &str,
418 filter: EntityFilter,
419 ) -> Result<u64, StorageError> {
420 let namespace = namespace.to_string();
421
422 self.with_reader("count_entities", move |conn| {
423 let (where_sql, params) = build_entity_where(&namespace, &filter);
424 let sql = format!("SELECT COUNT(*) FROM entities{}", where_sql);
425 let mut stmt = conn.prepare(&sql)?;
426 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
427 params.iter().map(|p| p.as_ref()).collect();
428 let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
429 Ok(count as u64)
430 })
431 .await
432 }
433}
434
/// DDL that creates the `entities` table and its indexes on `namespace`,
/// `(namespace, kind)`, `(namespace, name)` and `created_at DESC`.
///
/// Every statement uses `IF NOT EXISTS`, so the batch can be executed against
/// a database that already has these objects.
const ENTITIES_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS entities (\
    id TEXT PRIMARY KEY,\
    namespace TEXT NOT NULL,\
    kind TEXT NOT NULL,\
    name TEXT NOT NULL,\
    description TEXT,\
    properties TEXT,\
    tags TEXT NOT NULL DEFAULT '[]',\
    created_at INTEGER NOT NULL,\
    updated_at INTEGER NOT NULL,\
    deleted_at INTEGER\
    );\
    CREATE INDEX IF NOT EXISTS idx_entities_namespace ON entities(namespace);\
    CREATE INDEX IF NOT EXISTS idx_entities_kind ON entities(namespace, kind);\
    CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(namespace, name);\
    CREATE INDEX IF NOT EXISTS idx_entities_created ON entities(created_at DESC);\
";
457
458pub(crate) fn ensure_entities_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
459 conn.execute_batch(ENTITIES_DDL)
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pool::PoolConfig;

    /// Builds a pool with no database path and creates the entities schema
    /// through its writer connection.
    fn setup_pool() -> Arc<ConnectionPool> {
        let config = PoolConfig {
            path: None,
            ..PoolConfig::default()
        };
        let pool = Arc::new(ConnectionPool::new(config).unwrap());
        {
            // Scope the writer guard so it is released before the pool is used.
            let guard = pool.writer().unwrap();
            guard.conn().execute_batch(ENTITIES_DDL).unwrap();
        }
        pool
    }

    /// Store over a fresh pool that reads through pooled connections.
    fn setup_memory_store() -> SqlEntityStore {
        SqlEntityStore::new(setup_pool(), false)
    }

    /// Same as [`setup_memory_store`]; the namespace argument is not used.
    fn setup_memory_store_ns(_ns: &str) -> SqlEntityStore {
        setup_memory_store()
    }

    /// Entity with a random id, the current timestamp and no optional data.
    fn make_entity(namespace: &str, kind: &str, name: &str) -> Entity {
        let timestamp = chrono::Utc::now().timestamp_micros();
        Entity {
            id: Uuid::new_v4(),
            namespace: namespace.to_owned(),
            kind: kind.to_owned(),
            name: name.to_owned(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: timestamp,
            updated_at: timestamp,
            deleted_at: None,
        }
    }

    /// Upserts `entity` and panics if the write fails.
    async fn put(store: &SqlEntityStore, entity: Entity) {
        store.upsert_entity(entity).await.unwrap();
    }

    /// Names of the entities on `page`, in page order.
    fn names_of(page: &Page<Entity>) -> Vec<&str> {
        page.items.iter().map(|e| e.name.as_str()).collect()
    }

    #[tokio::test]
    async fn test_upsert_and_get_entity() {
        let store = setup_memory_store();

        let entity = make_entity("default", "concept", "LoRA");
        let id = entity.id;
        put(&store, entity).await;

        let lookup = store.get_entity(id).await.unwrap();
        assert!(lookup.is_some());
        let found = lookup.unwrap();
        assert_eq!(found.id, id);
        assert_eq!(found.name, "LoRA");
        assert_eq!(found.kind, "concept");
    }

    #[tokio::test]
    async fn test_upsert_with_builder() {
        let store = setup_memory_store();

        let props = serde_json::json!({"domain": "fine-tuning", "type": "technique"});
        let tags = vec!["fine-tuning".to_string(), "quantization".to_string()];
        let entity = Entity::new("default", "concept", "QLoRA")
            .with_description("Quantized LoRA")
            .with_properties(props.clone())
            .with_tags(tags);
        let id = entity.id;
        put(&store, entity).await;

        let found = store.get_entity(id).await.unwrap().unwrap();
        assert_eq!(found.description.as_deref(), Some("Quantized LoRA"));
        assert_eq!(found.properties, Some(props));
        assert_eq!(found.tags, vec!["fine-tuning", "quantization"]);
    }

    #[tokio::test]
    async fn test_soft_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", "concept", "to-delete");
        let id = entity.id;
        put(&store, entity).await;

        let removed = store.delete_entity(id, DeleteMode::Soft).await.unwrap();
        assert!(removed);

        // A soft-deleted row is no longer visible to lookups.
        let lookup = store.get_entity(id).await.unwrap();
        assert!(lookup.is_none());
    }

    #[tokio::test]
    async fn test_hard_delete() {
        let store = setup_memory_store();

        let entity = make_entity("default", "concept", "to-hard-delete");
        let id = entity.id;
        put(&store, entity).await;

        let removed = store.delete_entity(id, DeleteMode::Hard).await.unwrap();
        assert!(removed);

        let lookup = store.get_entity(id).await.unwrap();
        assert!(lookup.is_none());
    }

    #[tokio::test]
    async fn test_query_entities_basic() {
        let store = setup_memory_store_ns("ns1");

        for name in ["Alpha", "Beta", "Gamma"] {
            put(&store, make_entity("ns1", "concept", name)).await;
        }
        put(&store, make_entity("ns1", "document", "Paper1")).await;

        let first_ten = PageRequest {
            offset: 0,
            limit: 10,
        };
        let everything = store
            .query_entities("ns1", EntityFilter::default(), first_ten)
            .await
            .unwrap();
        assert_eq!(everything.items.len(), 4);
        assert_eq!(everything.total, Some(4));

        let only_concepts = EntityFilter {
            kinds: vec!["concept".to_string()],
            ..Default::default()
        };
        let concepts = store
            .query_entities("ns1", only_concepts, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(concepts.items.len(), 3);
    }

    #[tokio::test]
    async fn test_query_by_name_prefix() {
        let store = setup_memory_store_ns("ns1");

        for name in ["Alpha", "AlphaGo", "Beta"] {
            put(&store, make_entity("ns1", "concept", name)).await;
        }

        let by_prefix = EntityFilter {
            name_prefix: Some("Alpha".to_string()),
            ..Default::default()
        };
        let result = store
            .query_entities("ns1", by_prefix, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names = names_of(&result);
        assert!(names.contains(&"Alpha"), "Alpha not found in {names:?}");
        assert!(names.contains(&"AlphaGo"), "AlphaGo not found in {names:?}");
        assert!(!names.contains(&"Beta"));
    }

    #[tokio::test]
    async fn test_count_entities() {
        let store = setup_memory_store_ns("ns1");

        for _ in 0..5 {
            put(&store, make_entity("ns1", "concept", "X")).await;
        }

        let count = store
            .count_entities("ns1", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 5);

        // Nothing was written to the other namespace.
        let count_other = store
            .count_entities("ns2", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_other, 0);
    }

    #[tokio::test]
    async fn test_batch_upsert() {
        let store = setup_memory_store_ns("batch_ns");

        let mut entities: Vec<Entity> = Vec::new();
        for i in 0..10 {
            entities.push(make_entity("batch_ns", "concept", &format!("entity_{i}")));
        }

        let summary = store.upsert_entities(entities).await.unwrap();
        assert_eq!(summary.attempted, 10);
        assert_eq!(summary.affected, 10);
        assert_eq!(summary.failed, 0);

        let count = store
            .count_entities("batch_ns", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count, 10);
    }

    #[tokio::test]
    async fn test_namespace_isolation() {
        let pool = setup_pool();
        let store = SqlEntityStore::new(Arc::clone(&pool), false);

        put(&store, make_entity("ns_a", "concept", "EntityA")).await;
        put(&store, make_entity("ns_b", "concept", "EntityB")).await;

        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_a, 1);
        assert_eq!(count_b, 1);

        let page_a = store
            .query_entities("ns_a", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_a.items[0].name, "EntityA");

        let page_b = store
            .query_entities("ns_b", EntityFilter::default(), PageRequest::default())
            .await
            .unwrap();
        assert_eq!(page_b.items[0].name, "EntityB");
    }

    #[tokio::test]
    async fn test_query_by_tags() {
        let store = setup_memory_store_ns("tags_ns");

        let tagged = |name: &str, tags: [&str; 2]| {
            let mut entity = make_entity("tags_ns", "concept", name);
            entity.tags = tags.iter().map(|t| t.to_string()).collect();
            entity
        };
        put(&store, tagged("Tagged1", ["rust", "systems"])).await;
        put(&store, tagged("Tagged2", ["python", "ml"])).await;
        put(&store, tagged("Tagged3", ["rust", "ml"])).await;

        let rust_only = EntityFilter {
            tags_any: vec!["rust".to_string()],
            ..Default::default()
        };
        let result = store
            .query_entities("tags_ns", rust_only, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names = names_of(&result);
        assert!(names.contains(&"Tagged1"));
        assert!(names.contains(&"Tagged3"));
        assert!(!names.contains(&"Tagged2"));

        let ml_only = EntityFilter {
            tags_any: vec!["ml".to_string()],
            ..Default::default()
        };
        let result = store
            .query_entities("tags_ns", ml_only, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);

        // Any-of semantics: every entity carries "rust" or "python".
        let rust_or_python = EntityFilter {
            tags_any: vec!["rust".to_string(), "python".to_string()],
            ..Default::default()
        };
        let result = store
            .query_entities("tags_ns", rust_or_python, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(result.items.len(), 3);
    }

    #[tokio::test]
    async fn test_query_by_ids() {
        let store = setup_memory_store_ns("ns1");

        let e1 = make_entity("ns1", "concept", "E1");
        let e2 = make_entity("ns1", "concept", "E2");
        let e3 = make_entity("ns1", "concept", "E3");
        let ids = vec![e1.id, e3.id];

        for entity in [e1, e2, e3] {
            put(&store, entity).await;
        }

        let by_ids = EntityFilter {
            ids,
            ..Default::default()
        };
        let result = store
            .query_entities("ns1", by_ids, PageRequest::default())
            .await
            .unwrap();
        assert_eq!(result.items.len(), 2);
        let names = names_of(&result);
        assert!(names.contains(&"E1"));
        assert!(names.contains(&"E3"));
        assert!(!names.contains(&"E2"));
    }

    #[tokio::test]
    async fn test_same_id_upsert_replaces_row() {
        let pool = setup_pool();
        let store = SqlEntityStore::new(Arc::clone(&pool), false);

        let shared_id = Uuid::new_v4();
        let now = chrono::Utc::now().timestamp_micros();

        // Two entities that differ only in namespace and name but share an id.
        let with_shared_id = |namespace: &str, name: &str| Entity {
            id: shared_id,
            namespace: namespace.to_string(),
            kind: "concept".to_string(),
            name: name.to_string(),
            description: None,
            properties: None,
            tags: Vec::new(),
            created_at: now,
            updated_at: now,
            deleted_at: None,
        };

        put(&store, with_shared_id("ns_a", "SharedInA")).await;

        let found = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(found.namespace, "ns_a");
        assert_eq!(found.name, "SharedInA");

        put(&store, with_shared_id("ns_b", "SharedInB")).await;

        // The second write replaced the first row rather than adding another.
        let found = store.get_entity(shared_id).await.unwrap().unwrap();
        assert_eq!(found.namespace, "ns_b");
        assert_eq!(found.name, "SharedInB");

        let count_a = store
            .count_entities("ns_a", EntityFilter::default())
            .await
            .unwrap();
        let count_b = store
            .count_entities("ns_b", EntityFilter::default())
            .await
            .unwrap();
        assert_eq!(count_a, 0);
        assert_eq!(count_b, 1);
    }
}