Skip to main content

mlua_swarm/store/output/
sqlite.rs

//! `SqliteOutputStore` — SQLite-backed [`OutputStore`].
//!
//! One row per emit (`OutputRecord`). `event` and `parent_refs` are stored
//! as JSON blobs (both types already carry `Serialize + Deserialize`), so
//! adding a new [`OutputEvent`] variant does not require a schema migration.
//!
//! Ordering guarantees:
//!
//! - `list_for_attempt` returns rows in insertion order (per the trait
//!   contract) by sorting on a monotonically increasing `seq` column, which
//!   `append` assigns as `MAX(seq) + 1` inside its insert transaction.
//! - `get_latest_by_name` picks the row with the largest `seq` for a given
//!   `producer_agent`.
13
14use super::{OutputEvent, OutputRecord, OutputRef, OutputStore, OutputStoreError};
15use async_trait::async_trait;
16use rusqlite::{params, OptionalExtension};
17use rusqlite_isle::{AsyncIsle, AsyncIsleDriver, IsleError};
18use std::path::Path;
19
/// DDL batch that creates the `outputs` table and its two lookup indexes
/// when they do not already exist.
///
/// The pieces are concatenated without newlines, so the whole batch is a
/// single-line string of three `;`-terminated statements.
const SCHEMA_SQL: &str = concat!(
    "CREATE TABLE IF NOT EXISTS outputs (",
    "id             TEXT PRIMARY KEY, ",
    "task_id        TEXT NOT NULL, ",
    "attempt        INTEGER NOT NULL, ",
    "producer_agent TEXT NOT NULL, ",
    "event_json     TEXT NOT NULL, ",
    "parent_refs_json TEXT NOT NULL, ",
    "seq            INTEGER NOT NULL",
    ");",
    // Index over (task_id, attempt, seq).
    "CREATE INDEX IF NOT EXISTS ix_outputs_attempt ON outputs(task_id, attempt, seq);",
    // Index over (producer_agent, seq).
    "CREATE INDEX IF NOT EXISTS ix_outputs_producer ON outputs(producer_agent, seq);",
);
33
34/// SQLite-backed [`OutputStore`].
35pub struct SqliteOutputStore {
36    isle: AsyncIsle,
37}
38
39impl SqliteOutputStore {
40    /// Open (or create) a SQLite file and apply the schema.
41    pub async fn open(path: impl AsRef<Path>) -> Result<(Self, AsyncIsleDriver), OutputStoreError> {
42        let (isle, driver) = AsyncIsle::spawn(path.as_ref().to_path_buf(), |conn| {
43            conn.execute_batch(SCHEMA_SQL)
44        })
45        .await
46        .map_err(map_isle_err)?;
47        Ok((Self { isle }, driver))
48    }
49
50    /// Open an ephemeral in-memory database (tests).
51    pub async fn open_in_memory() -> Result<(Self, AsyncIsleDriver), OutputStoreError> {
52        let (isle, driver) = AsyncIsle::open_in_memory(|conn| conn.execute_batch(SCHEMA_SQL))
53            .await
54            .map_err(map_isle_err)?;
55        Ok((Self { isle }, driver))
56    }
57}
58
59fn map_isle_err(e: IsleError) -> OutputStoreError {
60    OutputStoreError::Internal(format!("sqlite: {e}"))
61}
62
63fn decode_record(
64    id: String,
65    task_id: String,
66    attempt: i64,
67    producer_agent: String,
68    event_json: String,
69    parent_refs_json: String,
70) -> Result<OutputRecord, OutputStoreError> {
71    let event: OutputEvent = serde_json::from_str(&event_json)
72        .map_err(|e| OutputStoreError::Internal(format!("decode event: {e}")))?;
73    let parent_refs: Vec<OutputRef> = serde_json::from_str(&parent_refs_json)
74        .map_err(|e| OutputStoreError::Internal(format!("decode parent_refs: {e}")))?;
75    Ok(OutputRecord {
76        id: OutputRef(id),
77        task_id,
78        attempt: attempt as u32,
79        producer_agent,
80        event,
81        parent_refs,
82    })
83}
84
85#[async_trait]
86impl OutputStore for SqliteOutputStore {
87    async fn append(
88        &self,
89        task_id: &str,
90        attempt: u32,
91        producer_agent: &str,
92        event: OutputEvent,
93        parent_refs: Vec<OutputRef>,
94    ) -> Result<OutputRef, OutputStoreError> {
95        let id = OutputRef::new();
96        let id_str = id.0.clone();
97        let task_id = task_id.to_string();
98        let attempt = attempt as i64;
99        let producer_agent = producer_agent.to_string();
100        let event_json = serde_json::to_string(&event)
101            .map_err(|e| OutputStoreError::Internal(format!("encode event: {e}")))?;
102        let parent_refs_json = serde_json::to_string(&parent_refs)
103            .map_err(|e| OutputStoreError::Internal(format!("encode parent_refs: {e}")))?;
104
105        self.isle
106            .call(move |conn| {
107                let tx = conn.transaction()?;
108                let seq: i64 = tx.query_row(
109                    "SELECT COALESCE(MAX(seq), 0) + 1 FROM outputs",
110                    [],
111                    |row| row.get(0),
112                )?;
113                tx.execute(
114                    "INSERT INTO outputs (id, task_id, attempt, producer_agent, event_json, \
115                     parent_refs_json, seq) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
116                    params![
117                        id_str,
118                        task_id,
119                        attempt,
120                        producer_agent,
121                        event_json,
122                        parent_refs_json,
123                        seq,
124                    ],
125                )?;
126                tx.commit()?;
127                Ok(())
128            })
129            .await
130            .map_err(map_isle_err)?;
131        Ok(id)
132    }
133
134    async fn get(&self, id: &OutputRef) -> Result<OutputRecord, OutputStoreError> {
135        let id_str = id.0.clone();
136        let id_for_notfound = id.0.clone();
137        let row = self
138            .isle
139            .call(move |conn| {
140                conn.query_row(
141                    "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
142                     FROM outputs WHERE id = ?1",
143                    params![id_str],
144                    |row| {
145                        Ok((
146                            row.get::<_, String>(0)?,
147                            row.get::<_, String>(1)?,
148                            row.get::<_, i64>(2)?,
149                            row.get::<_, String>(3)?,
150                            row.get::<_, String>(4)?,
151                            row.get::<_, String>(5)?,
152                        ))
153                    },
154                )
155                .optional()
156            })
157            .await
158            .map_err(map_isle_err)?;
159        match row {
160            Some((id, task_id, attempt, producer, event, parent)) => {
161                decode_record(id, task_id, attempt, producer, event, parent)
162            }
163            None => Err(OutputStoreError::NotFound(id_for_notfound)),
164        }
165    }
166
167    async fn get_latest_by_name(&self, name: &str) -> Result<OutputRecord, OutputStoreError> {
168        let name_str = name.to_string();
169        let name_for_notfound = name.to_string();
170        let row = self
171            .isle
172            .call(move |conn| {
173                conn.query_row(
174                    "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
175                     FROM outputs WHERE producer_agent = ?1 ORDER BY seq DESC LIMIT 1",
176                    params![name_str],
177                    |row| {
178                        Ok((
179                            row.get::<_, String>(0)?,
180                            row.get::<_, String>(1)?,
181                            row.get::<_, i64>(2)?,
182                            row.get::<_, String>(3)?,
183                            row.get::<_, String>(4)?,
184                            row.get::<_, String>(5)?,
185                        ))
186                    },
187                )
188                .optional()
189            })
190            .await
191            .map_err(map_isle_err)?;
192        match row {
193            Some((id, task_id, attempt, producer, event, parent)) => {
194                decode_record(id, task_id, attempt, producer, event, parent)
195            }
196            None => Err(OutputStoreError::NotFound(name_for_notfound)),
197        }
198    }
199
200    async fn list_for_attempt(
201        &self,
202        task_id: &str,
203        attempt: u32,
204    ) -> Result<Vec<OutputRecord>, OutputStoreError> {
205        let task_id = task_id.to_string();
206        let attempt = attempt as i64;
207        let rows = self
208            .isle
209            .call(move |conn| {
210                let mut stmt = conn.prepare(
211                    "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
212                     FROM outputs WHERE task_id = ?1 AND attempt = ?2 ORDER BY seq ASC",
213                )?;
214                let iter = stmt.query_map(params![task_id, attempt], |row| {
215                    Ok((
216                        row.get::<_, String>(0)?,
217                        row.get::<_, String>(1)?,
218                        row.get::<_, i64>(2)?,
219                        row.get::<_, String>(3)?,
220                        row.get::<_, String>(4)?,
221                        row.get::<_, String>(5)?,
222                    ))
223                })?;
224                let mut out = Vec::new();
225                for r in iter {
226                    out.push(r?);
227                }
228                Ok(out)
229            })
230            .await
231            .map_err(map_isle_err)?;
232        rows.into_iter()
233            .map(|(id, task_id, attempt, producer, event, parent)| {
234                decode_record(id, task_id, attempt, producer, event, parent)
235            })
236            .collect()
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::output::ContentRef;

    /// Builds a `Final` event whose content is `text` as inline text.
    fn mk_final(text: &str, ok: bool) -> OutputEvent {
        let content = ContentRef::inline_text(text);
        OutputEvent::Final { content, ok }
    }

    /// Opens a fresh in-memory store together with its driver.
    async fn memory_store() -> (SqliteOutputStore, AsyncIsleDriver) {
        SqliteOutputStore::open_in_memory().await.unwrap()
    }

    /// Appends a successful `Final` event carrying `text` and returns its id.
    async fn emit(
        store: &SqliteOutputStore,
        task_id: &str,
        attempt: u32,
        producer: &str,
        text: &str,
        parents: Vec<OutputRef>,
    ) -> OutputRef {
        store
            .append(task_id, attempt, producer, mk_final(text, true), parents)
            .await
            .unwrap()
    }

    /// Drops the store first, then shuts the driver down.
    async fn close(store: SqliteOutputStore, driver: AsyncIsleDriver) {
        drop(store);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn append_then_get_roundtrip() {
        let (store, driver) = memory_store().await;
        let id = emit(&store, "task-1", 1, "producer-a", "hello", vec![]).await;

        let got = store.get(&id).await.unwrap();
        assert_eq!(got.id, id);
        assert_eq!(got.task_id, "task-1");
        assert_eq!(got.attempt, 1);
        assert_eq!(got.producer_agent, "producer-a");
        assert!(
            matches!(got.event, OutputEvent::Final { ok: true, .. }),
            "unexpected: {:?}",
            got.event
        );

        close(store, driver).await;
    }

    #[tokio::test]
    async fn get_not_found_returns_error() {
        let (store, driver) = memory_store().await;

        let missing = OutputRef("missing".into());
        let err = store.get(&missing).await.unwrap_err();
        assert!(matches!(err, OutputStoreError::NotFound(_)));

        close(store, driver).await;
    }

    #[tokio::test]
    async fn list_for_attempt_orders_by_insertion() {
        let (store, driver) = memory_store().await;
        let a = emit(&store, "t", 1, "p1", "a", vec![]).await;
        let b = emit(&store, "t", 1, "p2", "b", vec![]).await;
        // Belongs to a different attempt, so the filter must leave it out.
        let _ = emit(&store, "t", 2, "p1", "other-attempt", vec![]).await;
        let c = emit(&store, "t", 1, "p3", "c", vec![]).await;

        let ids: Vec<_> = store
            .list_for_attempt("t", 1)
            .await
            .unwrap()
            .iter()
            .map(|record| record.id.clone())
            .collect();
        assert_eq!(ids, vec![a, b, c]);

        close(store, driver).await;
    }

    #[tokio::test]
    async fn get_latest_by_name_returns_newest_emit() {
        let (store, driver) = memory_store().await;
        let _ = emit(&store, "t", 1, "same-producer", "v1", vec![]).await;
        let _ = emit(&store, "t", 1, "other-producer", "unrelated", vec![]).await;
        let latest_id = emit(&store, "t", 2, "same-producer", "v2", vec![]).await;

        let got = store.get_latest_by_name("same-producer").await.unwrap();
        assert_eq!(got.id, latest_id);

        close(store, driver).await;
    }

    #[tokio::test]
    async fn get_latest_by_name_unknown_returns_not_found() {
        let (store, driver) = memory_store().await;

        let err = store.get_latest_by_name("nobody").await.unwrap_err();
        assert!(matches!(err, OutputStoreError::NotFound(_)));

        close(store, driver).await;
    }

    #[tokio::test]
    async fn parent_refs_are_persisted() {
        let (store, driver) = memory_store().await;
        let parent = emit(&store, "t", 1, "p", "parent", vec![]).await;
        let child = emit(&store, "t", 1, "p", "child", vec![parent.clone()]).await;

        let got = store.get(&child).await.unwrap();
        assert_eq!(got.parent_refs, vec![parent]);

        close(store, driver).await;
    }

    #[tokio::test]
    async fn persists_across_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("outputs.db");

        // First session: write one record, then shut down completely.
        let (store, driver) = SqliteOutputStore::open(&path).await.unwrap();
        let id = emit(&store, "keep", 1, "p", "body", vec![]).await;
        close(store, driver).await;

        // Second session: the record must still be readable from the file.
        let (store, driver) = SqliteOutputStore::open(&path).await.unwrap();
        let got = store.get(&id).await.unwrap();
        assert_eq!(got.task_id, "keep");
        close(store, driver).await;
    }
}
376}