Skip to main content

mlua_swarm/store/output/
sqlite.rs

//! `SqliteOutputStore` — SQLite-backed [`OutputStore`].
//!
//! One row per emit (`OutputRecord`). `event` and `parent_refs` are stored
//! as JSON blobs (both types already carry `Serialize + Deserialize`), so
//! adding a new [`OutputEvent`] variant does not require a schema migration.
//!
//! Ordering guarantees:
//!
//! - `list_for_attempt` returns rows in insertion order (per the trait
//!   contract) by sorting on the `seq` column, which `append` assigns as
//!   `MAX(seq) + 1` inside the insert transaction.
//! - `get_latest_by_name` picks the row with the largest `seq` for a given
//!   `producer_agent`.
13
14use super::{OutputEvent, OutputRecord, OutputRef, OutputStore, OutputStoreError};
15use async_trait::async_trait;
16use rusqlite::{params, OptionalExtension};
17use rusqlite_isle::{AsyncIsle, AsyncIsleDriver, IsleError};
18use std::path::Path;
19
/// DDL batch executed by both constructors when a database is opened.
///
/// Every statement is idempotent (`IF NOT EXISTS`), so re-running it against an
/// existing database is harmless.
///
/// Indexes:
///
/// - `ix_outputs_attempt` serves `list_for_attempt` (filter on task/attempt,
///   ordered by `seq`).
/// - `ix_outputs_producer` serves `get_latest_by_name` (filter on producer,
///   newest `seq` first).
/// - `ix_outputs_seq` serves the `MAX(seq)` lookup that `append` performs on
///   every insert. SQLite can answer a bare `MAX()` with a single index probe
///   only when the column is the left-most column of an index; in the other two
///   indexes `seq` is the trailing column, so they do not qualify.
const SCHEMA_SQL: &str = "\
CREATE TABLE IF NOT EXISTS outputs (\
  id             TEXT PRIMARY KEY, \
  task_id        TEXT NOT NULL, \
  attempt        INTEGER NOT NULL, \
  producer_agent TEXT NOT NULL, \
  event_json     TEXT NOT NULL, \
  parent_refs_json TEXT NOT NULL, \
  seq            INTEGER NOT NULL\
);\
CREATE INDEX IF NOT EXISTS ix_outputs_attempt ON outputs(task_id, attempt, seq);\
CREATE INDEX IF NOT EXISTS ix_outputs_producer ON outputs(producer_agent, seq);\
CREATE INDEX IF NOT EXISTS ix_outputs_seq ON outputs(seq);\
";
33
34/// SQLite-backed [`OutputStore`].
35pub struct SqliteOutputStore {
36    isle: AsyncIsle,
37}
38
39impl SqliteOutputStore {
40    /// Open (or create) a SQLite file and apply the schema.
41    pub async fn open(path: impl AsRef<Path>) -> Result<(Self, AsyncIsleDriver), OutputStoreError> {
42        let (isle, driver) = AsyncIsle::spawn(path.as_ref().to_path_buf(), |conn| {
43            conn.execute_batch(SCHEMA_SQL)
44        })
45        .await
46        .map_err(map_isle_err)?;
47        Ok((Self { isle }, driver))
48    }
49
50    /// Open an ephemeral in-memory database (tests).
51    pub async fn open_in_memory() -> Result<(Self, AsyncIsleDriver), OutputStoreError> {
52        let (isle, driver) = AsyncIsle::open_in_memory(|conn| conn.execute_batch(SCHEMA_SQL))
53            .await
54            .map_err(map_isle_err)?;
55        Ok((Self { isle }, driver))
56    }
57}
58
59fn map_isle_err(e: IsleError) -> OutputStoreError {
60    OutputStoreError::Internal(format!("sqlite: {e}"))
61}
62
63fn decode_record(
64    id: String,
65    task_id: String,
66    attempt: i64,
67    producer_agent: String,
68    event_json: String,
69    parent_refs_json: String,
70) -> Result<OutputRecord, OutputStoreError> {
71    let event: OutputEvent = serde_json::from_str(&event_json)
72        .map_err(|e| OutputStoreError::Internal(format!("decode event: {e}")))?;
73    let parent_refs: Vec<OutputRef> = serde_json::from_str(&parent_refs_json)
74        .map_err(|e| OutputStoreError::Internal(format!("decode parent_refs: {e}")))?;
75    Ok(OutputRecord {
76        id: OutputRef(id),
77        task_id,
78        attempt: attempt as u32,
79        producer_agent,
80        event,
81        parent_refs,
82    })
83}
84
85#[async_trait]
86impl OutputStore for SqliteOutputStore {
87    async fn append(
88        &self,
89        task_id: &str,
90        attempt: u32,
91        producer_agent: &str,
92        event: OutputEvent,
93        parent_refs: Vec<OutputRef>,
94    ) -> Result<OutputRef, OutputStoreError> {
95        let id = OutputRef::new();
96        let id_str = id.0.clone();
97        let task_id = task_id.to_string();
98        let attempt = attempt as i64;
99        let producer_agent = producer_agent.to_string();
100        let event_json = serde_json::to_string(&event)
101            .map_err(|e| OutputStoreError::Internal(format!("encode event: {e}")))?;
102        let parent_refs_json = serde_json::to_string(&parent_refs)
103            .map_err(|e| OutputStoreError::Internal(format!("encode parent_refs: {e}")))?;
104
105        self.isle
106            .call(move |conn| {
107                let tx = conn.transaction()?;
108                let seq: i64 =
109                    tx.query_row("SELECT COALESCE(MAX(seq), 0) + 1 FROM outputs", [], |row| {
110                        row.get(0)
111                    })?;
112                tx.execute(
113                    "INSERT INTO outputs (id, task_id, attempt, producer_agent, event_json, \
114                     parent_refs_json, seq) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
115                    params![
116                        id_str,
117                        task_id,
118                        attempt,
119                        producer_agent,
120                        event_json,
121                        parent_refs_json,
122                        seq,
123                    ],
124                )?;
125                tx.commit()?;
126                Ok(())
127            })
128            .await
129            .map_err(map_isle_err)?;
130        Ok(id)
131    }
132
133    async fn get(&self, id: &OutputRef) -> Result<OutputRecord, OutputStoreError> {
134        let id_str = id.0.clone();
135        let id_for_notfound = id.0.clone();
136        let row = self
137            .isle
138            .call(move |conn| {
139                conn.query_row(
140                    "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
141                     FROM outputs WHERE id = ?1",
142                    params![id_str],
143                    |row| {
144                        Ok((
145                            row.get::<_, String>(0)?,
146                            row.get::<_, String>(1)?,
147                            row.get::<_, i64>(2)?,
148                            row.get::<_, String>(3)?,
149                            row.get::<_, String>(4)?,
150                            row.get::<_, String>(5)?,
151                        ))
152                    },
153                )
154                .optional()
155            })
156            .await
157            .map_err(map_isle_err)?;
158        match row {
159            Some((id, task_id, attempt, producer, event, parent)) => {
160                decode_record(id, task_id, attempt, producer, event, parent)
161            }
162            None => Err(OutputStoreError::NotFound(id_for_notfound)),
163        }
164    }
165
166    async fn get_latest_by_name(&self, name: &str) -> Result<OutputRecord, OutputStoreError> {
167        let name_str = name.to_string();
168        let name_for_notfound = name.to_string();
169        let row = self
170            .isle
171            .call(move |conn| {
172                conn.query_row(
173                    "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
174                     FROM outputs WHERE producer_agent = ?1 ORDER BY seq DESC LIMIT 1",
175                    params![name_str],
176                    |row| {
177                        Ok((
178                            row.get::<_, String>(0)?,
179                            row.get::<_, String>(1)?,
180                            row.get::<_, i64>(2)?,
181                            row.get::<_, String>(3)?,
182                            row.get::<_, String>(4)?,
183                            row.get::<_, String>(5)?,
184                        ))
185                    },
186                )
187                .optional()
188            })
189            .await
190            .map_err(map_isle_err)?;
191        match row {
192            Some((id, task_id, attempt, producer, event, parent)) => {
193                decode_record(id, task_id, attempt, producer, event, parent)
194            }
195            None => Err(OutputStoreError::NotFound(name_for_notfound)),
196        }
197    }
198
199    async fn list_for_attempt(
200        &self,
201        task_id: &str,
202        attempt: u32,
203    ) -> Result<Vec<OutputRecord>, OutputStoreError> {
204        let task_id = task_id.to_string();
205        let attempt = attempt as i64;
206        let rows = self
207            .isle
208            .call(move |conn| {
209                let mut stmt = conn.prepare(
210                    "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
211                     FROM outputs WHERE task_id = ?1 AND attempt = ?2 ORDER BY seq ASC",
212                )?;
213                let iter = stmt.query_map(params![task_id, attempt], |row| {
214                    Ok((
215                        row.get::<_, String>(0)?,
216                        row.get::<_, String>(1)?,
217                        row.get::<_, i64>(2)?,
218                        row.get::<_, String>(3)?,
219                        row.get::<_, String>(4)?,
220                        row.get::<_, String>(5)?,
221                    ))
222                })?;
223                let mut out = Vec::new();
224                for r in iter {
225                    out.push(r?);
226                }
227                Ok(out)
228            })
229            .await
230            .map_err(map_isle_err)?;
231        rows.into_iter()
232            .map(|(id, task_id, attempt, producer, event, parent)| {
233                decode_record(id, task_id, attempt, producer, event, parent)
234            })
235            .collect()
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::output::ContentRef;

    /// Builds a `Final` event whose content is `text` inlined, with the given
    /// `ok` flag.
    fn mk_final(text: &str, ok: bool) -> OutputEvent {
        let content = ContentRef::inline_text(text);
        OutputEvent::Final { content, ok }
    }

    /// Opens a fresh in-memory store, panicking on failure.
    async fn mem_store() -> (SqliteOutputStore, AsyncIsleDriver) {
        SqliteOutputStore::open_in_memory().await.unwrap()
    }

    /// Drops the store handle first, then shuts the driver down.
    async fn close(store: SqliteOutputStore, driver: AsyncIsleDriver) {
        drop(store);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn append_then_get_roundtrip() {
        let (store, driver) = mem_store().await;
        let event = mk_final("hello", true);
        let id = store
            .append("task-1", 1, "producer-a", event, vec![])
            .await
            .unwrap();

        let record = store.get(&id).await.unwrap();
        assert_eq!(record.id, id);
        assert_eq!(record.task_id, "task-1");
        assert_eq!(record.attempt, 1);
        assert_eq!(record.producer_agent, "producer-a");
        match record.event {
            OutputEvent::Final { ok, .. } => assert!(ok),
            other => panic!("unexpected: {other:?}"),
        }
        close(store, driver).await;
    }

    #[tokio::test]
    async fn get_not_found_returns_error() {
        let (store, driver) = mem_store().await;
        let missing = OutputRef("missing".into());
        let err = store.get(&missing).await.unwrap_err();
        assert!(matches!(err, OutputStoreError::NotFound(_)));
        close(store, driver).await;
    }

    #[tokio::test]
    async fn list_for_attempt_orders_by_insertion() {
        let (store, driver) = mem_store().await;
        let first = store
            .append("t", 1, "p1", mk_final("a", true), vec![])
            .await
            .unwrap();
        let second = store
            .append("t", 1, "p2", mk_final("b", true), vec![])
            .await
            .unwrap();
        // Belongs to attempt 2, so the attempt-1 listing must leave it out.
        let _ = store
            .append("t", 2, "p1", mk_final("other-attempt", true), vec![])
            .await
            .unwrap();
        let third = store
            .append("t", 1, "p3", mk_final("c", true), vec![])
            .await
            .unwrap();

        let records = store.list_for_attempt("t", 1).await.unwrap();
        let mut ids = Vec::new();
        for record in &records {
            ids.push(record.id.clone());
        }
        assert_eq!(ids, vec![first, second, third]);
        close(store, driver).await;
    }

    #[tokio::test]
    async fn get_latest_by_name_returns_newest_emit() {
        let (store, driver) = mem_store().await;
        let _ = store
            .append("t", 1, "same-producer", mk_final("v1", true), vec![])
            .await
            .unwrap();
        let _ = store
            .append("t", 1, "other-producer", mk_final("unrelated", true), vec![])
            .await
            .unwrap();
        let newest = store
            .append("t", 2, "same-producer", mk_final("v2", true), vec![])
            .await
            .unwrap();

        let record = store.get_latest_by_name("same-producer").await.unwrap();
        assert_eq!(record.id, newest);
        close(store, driver).await;
    }

    #[tokio::test]
    async fn get_latest_by_name_unknown_returns_not_found() {
        let (store, driver) = mem_store().await;
        let err = store.get_latest_by_name("nobody").await.unwrap_err();
        assert!(matches!(err, OutputStoreError::NotFound(_)));
        close(store, driver).await;
    }

    #[tokio::test]
    async fn parent_refs_are_persisted() {
        let (store, driver) = mem_store().await;
        let parent = store
            .append("t", 1, "p", mk_final("parent", true), vec![])
            .await
            .unwrap();
        let child = store
            .append("t", 1, "p", mk_final("child", true), vec![parent.clone()])
            .await
            .unwrap();

        let record = store.get(&child).await.unwrap();
        assert_eq!(record.parent_refs, vec![parent]);
        close(store, driver).await;
    }

    #[tokio::test]
    async fn persists_across_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let db_path = dir.path().join("outputs.db");

        // First session: write one record, then shut down completely.
        let id = {
            let (store, driver) = SqliteOutputStore::open(&db_path).await.unwrap();
            let id = store
                .append("keep", 1, "p", mk_final("body", true), vec![])
                .await
                .unwrap();
            close(store, driver).await;
            id
        };

        // Second session: the record must still be readable.
        let (store, driver) = SqliteOutputStore::open(&db_path).await.unwrap();
        let record = store.get(&id).await.unwrap();
        assert_eq!(record.task_id, "keep");
        close(store, driver).await;
    }
}