Skip to main content

khive_db/stores/
entity.rs

1//! SQL-backed `EntityStore` implementation.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use uuid::Uuid;
7
8use khive_storage::entity::{Entity, EntityFilter};
9use khive_storage::error::StorageError;
10use khive_storage::types::{BatchWriteSummary, DeleteMode, Page, PageRequest};
11use khive_storage::EntityStore;
12use khive_storage::StorageCapability;
13
14use crate::error::SqliteError;
15use crate::pool::ConnectionPool;
16
17fn map_err(e: rusqlite::Error, op: &'static str) -> StorageError {
18    StorageError::driver(StorageCapability::Entities, op, e)
19}
20
21fn map_sqlite_err(e: SqliteError, op: &'static str) -> StorageError {
22    StorageError::driver(StorageCapability::Entities, op, e)
23}
24
/// An `EntityStore` backed by SQLite. Namespace is the caller's responsibility.
///
/// Entity ids are treated as globally unique: lookups and deletes address a
/// row by id alone, while query/count filter on the namespace exactly as it
/// is passed in. The store holds no state of its own beyond the connection
/// pool and the `is_file_backed` flag.
pub struct SqlEntityStore {
    /// Shared connection pool; writes take the pool's writer connection.
    pool: Arc<ConnectionPool>,
    /// When `true`, reads open a dedicated read-only connection to the
    /// database file; when `false`, reads borrow a reader from the pool.
    is_file_backed: bool,
}
33
34impl SqlEntityStore {
35    /// Create a new store.
36    pub fn new(pool: Arc<ConnectionPool>, is_file_backed: bool) -> Self {
37        Self {
38            pool,
39            is_file_backed,
40        }
41    }
42
43    fn open_standalone_reader(&self) -> Result<rusqlite::Connection, StorageError> {
44        let config = self.pool.config();
45        let path = config.path.as_ref().ok_or_else(|| StorageError::Pool {
46            operation: "entity_reader".into(),
47            message: "in-memory databases do not support standalone connections".into(),
48        })?;
49
50        let conn = rusqlite::Connection::open_with_flags(
51            path,
52            rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
53                | rusqlite::OpenFlags::SQLITE_OPEN_NO_MUTEX
54                | rusqlite::OpenFlags::SQLITE_OPEN_URI,
55        )
56        .map_err(|e| map_err(e, "open_entity_reader"))?;
57
58        conn.busy_timeout(config.busy_timeout)
59            .map_err(|e| map_err(e, "open_entity_reader"))?;
60        conn.pragma_update(None, "foreign_keys", "ON")
61            .map_err(|e| map_err(e, "open_entity_reader"))?;
62        conn.pragma_update(None, "synchronous", "NORMAL")
63            .map_err(|e| map_err(e, "open_entity_reader"))?;
64
65        Ok(conn)
66    }
67
68    async fn with_writer<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
69    where
70        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
71        R: Send + 'static,
72    {
73        let pool = Arc::clone(&self.pool);
74        tokio::task::spawn_blocking(move || {
75            let guard = pool.try_writer().map_err(|e| map_sqlite_err(e, op))?;
76            f(guard.conn()).map_err(|e| map_err(e, op))
77        })
78        .await
79        .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
80    }
81
82    async fn with_reader<F, R>(&self, op: &'static str, f: F) -> Result<R, StorageError>
83    where
84        F: FnOnce(&rusqlite::Connection) -> Result<R, rusqlite::Error> + Send + 'static,
85        R: Send + 'static,
86    {
87        if self.is_file_backed {
88            let conn = self.open_standalone_reader()?;
89            tokio::task::spawn_blocking(move || f(&conn).map_err(|e| map_err(e, op)))
90                .await
91                .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
92        } else {
93            let pool = Arc::clone(&self.pool);
94            tokio::task::spawn_blocking(move || {
95                let guard = pool.reader().map_err(|e| map_sqlite_err(e, op))?;
96                f(guard.conn()).map_err(|e| map_err(e, op))
97            })
98            .await
99            .map_err(|e| StorageError::driver(StorageCapability::Entities, op, e))?
100        }
101    }
102}
103
104// =============================================================================
105// Helpers
106// =============================================================================
107
108fn read_entity(row: &rusqlite::Row<'_>) -> Result<Entity, rusqlite::Error> {
109    let id_str: String = row.get(0)?;
110    let namespace: String = row.get(1)?;
111    let kind: String = row.get(2)?;
112    let entity_type: Option<String> = row.get(3)?;
113    let name: String = row.get(4)?;
114    let description: Option<String> = row.get(5)?;
115    let properties_str: Option<String> = row.get(6)?;
116    let tags_str: String = row.get(7)?;
117    let created_at: i64 = row.get(8)?;
118    let updated_at: i64 = row.get(9)?;
119    let deleted_at: Option<i64> = row.get(10)?;
120    let merged_into_str: Option<String> = row.get(11)?;
121    let merge_event_id_str: Option<String> = row.get(12)?;
122
123    let id = parse_uuid(&id_str)?;
124
125    let properties = properties_str
126        .map(|s| {
127            serde_json::from_str(&s).map_err(|e| {
128                rusqlite::Error::FromSqlConversionFailure(
129                    6,
130                    rusqlite::types::Type::Text,
131                    Box::new(e),
132                )
133            })
134        })
135        .transpose()?;
136
137    let tags: Vec<String> = serde_json::from_str(&tags_str).map_err(|e| {
138        rusqlite::Error::FromSqlConversionFailure(7, rusqlite::types::Type::Text, Box::new(e))
139    })?;
140
141    let merged_into = merged_into_str
142        .as_deref()
143        .map(Uuid::parse_str)
144        .transpose()
145        .map_err(|e| {
146            rusqlite::Error::FromSqlConversionFailure(10, rusqlite::types::Type::Text, Box::new(e))
147        })?;
148
149    let merge_event_id = merge_event_id_str
150        .as_deref()
151        .map(Uuid::parse_str)
152        .transpose()
153        .map_err(|e| {
154            rusqlite::Error::FromSqlConversionFailure(11, rusqlite::types::Type::Text, Box::new(e))
155        })?;
156
157    Ok(Entity {
158        id,
159        namespace,
160        kind,
161        entity_type,
162        name,
163        description,
164        properties,
165        tags,
166        created_at,
167        updated_at,
168        deleted_at,
169        merged_into,
170        merge_event_id,
171    })
172}
173
174fn parse_uuid(s: &str) -> Result<Uuid, rusqlite::Error> {
175    Uuid::parse_str(s).map_err(|e| {
176        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
177    })
178}
179
180fn build_entity_where(
181    namespace: &str,
182    filter: &EntityFilter,
183) -> (String, Vec<Box<dyn rusqlite::types::ToSql>>) {
184    let mut conditions: Vec<String> = vec![
185        "namespace = ?1".to_string(),
186        "deleted_at IS NULL".to_string(),
187    ];
188    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(namespace.to_string())];
189
190    if !filter.ids.is_empty() {
191        let placeholders: Vec<String> = filter
192            .ids
193            .iter()
194            .map(|id| {
195                params.push(Box::new(id.to_string()));
196                format!("?{}", params.len())
197            })
198            .collect();
199        conditions.push(format!("id IN ({})", placeholders.join(", ")));
200    }
201
202    if !filter.kinds.is_empty() {
203        let placeholders: Vec<String> = filter
204            .kinds
205            .iter()
206            .map(|k| {
207                params.push(Box::new(k.clone()));
208                format!("?{}", params.len())
209            })
210            .collect();
211        conditions.push(format!("kind IN ({})", placeholders.join(", ")));
212    }
213
214    if !filter.entity_types.is_empty() {
215        let placeholders: Vec<String> = filter
216            .entity_types
217            .iter()
218            .map(|t| {
219                params.push(Box::new(t.clone()));
220                format!("?{}", params.len())
221            })
222            .collect();
223        conditions.push(format!("entity_type IN ({})", placeholders.join(", ")));
224    }
225
226    if let Some(ref prefix) = filter.name_prefix {
227        params.push(Box::new(format!("{}%", prefix)));
228        conditions.push(format!("name LIKE ?{}", params.len()));
229    }
230
231    if !filter.tags_any.is_empty() {
232        let placeholders: Vec<String> = filter
233            .tags_any
234            .iter()
235            .map(|t| {
236                // Normalise to lowercase so the comparison is case-insensitive
237                // domain filter must be case-insensitive.
238                params.push(Box::new(t.to_lowercase()));
239                format!("?{}", params.len())
240            })
241            .collect();
242        conditions.push(format!(
243            "EXISTS (SELECT 1 FROM json_each(tags) WHERE LOWER(json_each.value) IN ({}))",
244            placeholders.join(", ")
245        ));
246    }
247
248    let clause = format!(" WHERE {}", conditions.join(" AND "));
249    (clause, params)
250}
251
252// =============================================================================
253// EntityStore implementation
254// =============================================================================
255
256#[async_trait]
257impl EntityStore for SqlEntityStore {
258    async fn upsert_entity(&self, entity: Entity) -> Result<(), StorageError> {
259        let namespace = entity.namespace.clone();
260        let id_str = entity.id.to_string();
261        let properties_str = entity
262            .properties
263            .as_ref()
264            .map(|v| serde_json::to_string(v).unwrap_or_default());
265        let tags_str = serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
266
267        let merged_into_str = entity.merged_into.map(|u| u.to_string());
268        let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
269
270        self.with_writer("upsert_entity", move |conn| {
271            conn.execute(
272                "INSERT OR REPLACE INTO entities \
273                 (id, namespace, kind, entity_type, name, description, properties, tags, \
274                  created_at, updated_at, deleted_at, merged_into, merge_event_id) \
275                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
276                rusqlite::params![
277                    id_str,
278                    namespace,
279                    entity.kind,
280                    entity.entity_type,
281                    entity.name,
282                    entity.description,
283                    properties_str,
284                    tags_str,
285                    entity.created_at,
286                    entity.updated_at,
287                    entity.deleted_at,
288                    merged_into_str,
289                    merge_event_id_str,
290                ],
291            )?;
292            Ok(())
293        })
294        .await
295    }
296
297    async fn upsert_entities(
298        &self,
299        entities: Vec<Entity>,
300    ) -> Result<BatchWriteSummary, StorageError> {
301        let attempted = entities.len() as u64;
302
303        self.with_writer("upsert_entities", move |conn| {
304            conn.execute_batch("BEGIN IMMEDIATE")?;
305            let mut affected = 0u64;
306            let mut failed = 0u64;
307            let mut first_error = String::new();
308
309            for entity in &entities {
310                let id_str = entity.id.to_string();
311                let properties_str = entity
312                    .properties
313                    .as_ref()
314                    .map(|v| serde_json::to_string(v).unwrap_or_default());
315                let tags_str =
316                    serde_json::to_string(&entity.tags).unwrap_or_else(|_| "[]".to_string());
317
318                let merged_into_str = entity.merged_into.map(|u| u.to_string());
319                let merge_event_id_str = entity.merge_event_id.map(|u| u.to_string());
320                match conn.execute(
321                    "INSERT OR REPLACE INTO entities \
322                     (id, namespace, kind, entity_type, name, description, properties, tags, \
323                      created_at, updated_at, deleted_at, merged_into, merge_event_id) \
324                     VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
325                    rusqlite::params![
326                        id_str,
327                        &entity.namespace,
328                        entity.kind,
329                        entity.entity_type,
330                        entity.name,
331                        entity.description,
332                        properties_str,
333                        tags_str,
334                        entity.created_at,
335                        entity.updated_at,
336                        entity.deleted_at,
337                        merged_into_str,
338                        merge_event_id_str,
339                    ],
340                ) {
341                    Ok(_) => affected += 1,
342                    Err(e) => {
343                        if first_error.is_empty() {
344                            first_error = e.to_string();
345                        }
346                        failed += 1;
347                    }
348                }
349            }
350
351            if let Err(e) = conn.execute_batch("COMMIT") {
352                let _ = conn.execute_batch("ROLLBACK");
353                return Err(e);
354            }
355            Ok(BatchWriteSummary {
356                attempted,
357                affected,
358                failed,
359                first_error,
360            })
361        })
362        .await
363    }
364
365    async fn get_entity(&self, id: Uuid) -> Result<Option<Entity>, StorageError> {
366        let id_str = id.to_string();
367
368        self.with_reader("get_entity", move |conn| {
369            let mut stmt = conn.prepare(
370                "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
371                 created_at, updated_at, deleted_at, merged_into, merge_event_id \
372                 FROM entities WHERE id = ?1 AND deleted_at IS NULL",
373            )?;
374            let mut rows = stmt.query(rusqlite::params![id_str])?;
375            match rows.next()? {
376                Some(row) => Ok(Some(read_entity(row)?)),
377                None => Ok(None),
378            }
379        })
380        .await
381    }
382
383    async fn delete_entity(&self, id: Uuid, mode: DeleteMode) -> Result<bool, StorageError> {
384        let id_str = id.to_string();
385
386        match mode {
387            DeleteMode::Soft => {
388                self.with_writer("delete_entity_soft", move |conn| {
389                    let now = chrono::Utc::now().timestamp_micros();
390                    let deleted = conn.execute(
391                        "UPDATE entities SET deleted_at = ?1 \
392                         WHERE id = ?2 AND deleted_at IS NULL",
393                        rusqlite::params![now, id_str],
394                    )?;
395                    Ok(deleted > 0)
396                })
397                .await
398            }
399            DeleteMode::Hard => {
400                self.with_writer("delete_entity_hard", move |conn| {
401                    let deleted = conn.execute(
402                        "DELETE FROM entities WHERE id = ?1",
403                        rusqlite::params![id_str],
404                    )?;
405                    Ok(deleted > 0)
406                })
407                .await
408            }
409        }
410    }
411
412    async fn query_entities(
413        &self,
414        namespace: &str,
415        filter: EntityFilter,
416        page: PageRequest,
417    ) -> Result<Page<Entity>, StorageError> {
418        let namespace = namespace.to_string();
419
420        self.with_reader("query_entities", move |conn| {
421            let (count_sql, count_params) = build_entity_where(&namespace, &filter);
422            let total: i64 = {
423                let sql = format!("SELECT COUNT(*) FROM entities{}", count_sql);
424                let mut stmt = conn.prepare(&sql)?;
425                let param_refs: Vec<&dyn rusqlite::types::ToSql> =
426                    count_params.iter().map(|p| p.as_ref()).collect();
427                stmt.query_row(param_refs.as_slice(), |row| row.get(0))?
428            };
429
430            let (where_sql, mut data_params) = build_entity_where(&namespace, &filter);
431            data_params.push(Box::new(page.limit as i64));
432            data_params.push(Box::new(page.offset as i64));
433
434            let limit_idx = data_params.len() - 1;
435            let offset_idx = data_params.len();
436
437            let data_sql = format!(
438                "SELECT id, namespace, kind, entity_type, name, description, properties, tags, \
439                 created_at, updated_at, deleted_at, merged_into, merge_event_id \
440                 FROM entities{} ORDER BY created_at DESC LIMIT ?{} OFFSET ?{}",
441                where_sql, limit_idx, offset_idx,
442            );
443
444            let mut stmt = conn.prepare(&data_sql)?;
445            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
446                data_params.iter().map(|p| p.as_ref()).collect();
447            let rows = stmt.query_map(param_refs.as_slice(), read_entity)?;
448
449            let mut items = Vec::new();
450            for row in rows {
451                items.push(row?);
452            }
453
454            Ok(Page {
455                items,
456                total: Some(total as u64),
457            })
458        })
459        .await
460    }
461
462    async fn count_entities(
463        &self,
464        namespace: &str,
465        filter: EntityFilter,
466    ) -> Result<u64, StorageError> {
467        let namespace = namespace.to_string();
468
469        self.with_reader("count_entities", move |conn| {
470            let (where_sql, params) = build_entity_where(&namespace, &filter);
471            let sql = format!("SELECT COUNT(*) FROM entities{}", where_sql);
472            let mut stmt = conn.prepare(&sql)?;
473            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
474                params.iter().map(|p| p.as_ref()).collect();
475            let count: i64 = stmt.query_row(param_refs.as_slice(), |row| row.get(0))?;
476            Ok(count as u64)
477        })
478        .await
479    }
480}
481
482// =============================================================================
483// DDL
484// =============================================================================
485
/// Idempotent schema for the `entities` table and its secondary indexes.
///
/// Every statement uses `IF NOT EXISTS`, so the batch can be executed on each
/// start-up. Column notes, as used by this module:
///
/// - `id`, `merged_into`, `merge_event_id`: UUIDs stored as text.
/// - `properties`: optional JSON document stored as text.
/// - `tags`: JSON array of strings stored as text (defaults to `'[]'`).
/// - `deleted_at`: NULL for live rows; soft deletion writes a UTC timestamp
///   in microseconds.
/// - `created_at`, `updated_at`: integer timestamps supplied by the caller
///   (presumably the same microsecond unit — TODO confirm).
const ENTITIES_DDL: &str = "\
    CREATE TABLE IF NOT EXISTS entities (\
        id TEXT PRIMARY KEY,\
        namespace TEXT NOT NULL,\
        kind TEXT NOT NULL,\
        entity_type TEXT,\
        name TEXT NOT NULL,\
        description TEXT,\
        properties TEXT,\
        tags TEXT NOT NULL DEFAULT '[]',\
        created_at INTEGER NOT NULL,\
        updated_at INTEGER NOT NULL,\
        deleted_at INTEGER,\
        merged_into TEXT,\
        merge_event_id TEXT\
    );\
    CREATE INDEX IF NOT EXISTS idx_entities_namespace ON entities(namespace);\
    CREATE INDEX IF NOT EXISTS idx_entities_kind ON entities(namespace, kind);\
    CREATE INDEX IF NOT EXISTS idx_entities_kind_entity_type ON entities(namespace, kind, entity_type);\
    CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(namespace, name);\
    CREATE INDEX IF NOT EXISTS idx_entities_created ON entities(created_at DESC);\
    CREATE INDEX IF NOT EXISTS idx_entities_merged_into ON entities(namespace, merged_into);\
";
509
510pub(crate) fn ensure_entities_schema(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
511    conn.execute_batch(ENTITIES_DDL)
512}
513
// Unit tests live in the sibling `entity_tests.rs` file and are only
// compiled for test builds.
#[cfg(test)]
#[path = "entity_tests.rs"]
mod tests;