1use super::{OutputEvent, OutputRecord, OutputRef, OutputStore, OutputStoreError};
15use async_trait::async_trait;
16use rusqlite::{params, OptionalExtension};
17use rusqlite_isle::{AsyncIsle, AsyncIsleDriver, IsleError};
18use std::path::Path;
19
/// DDL run by `open` / `open_in_memory` on every connection they create.
///
/// Every statement uses `IF NOT EXISTS`, so re-running it against an existing
/// database file is a no-op (and adds any index an older file is missing).
///
/// Indexes:
/// - `ix_outputs_attempt` serves `list_for_attempt` (filter on task/attempt, ordered by `seq`).
/// - `ix_outputs_producer` serves `get_latest_by_name` (filter on producer, newest `seq`).
/// - `ix_outputs_seq` lets `append`'s `MAX(seq)` lookup be answered from the index;
///   without it, neither other index has `seq` as its leftmost column, so every append
///   had to scan the whole table.
const SCHEMA_SQL: &str = "\
CREATE TABLE IF NOT EXISTS outputs (\
    id TEXT PRIMARY KEY, \
    task_id TEXT NOT NULL, \
    attempt INTEGER NOT NULL, \
    producer_agent TEXT NOT NULL, \
    event_json TEXT NOT NULL, \
    parent_refs_json TEXT NOT NULL, \
    seq INTEGER NOT NULL\
);\
CREATE INDEX IF NOT EXISTS ix_outputs_attempt ON outputs(task_id, attempt, seq);\
CREATE INDEX IF NOT EXISTS ix_outputs_producer ON outputs(producer_agent, seq);\
CREATE INDEX IF NOT EXISTS ix_outputs_seq ON outputs(seq);\
";
33
34pub struct SqliteOutputStore {
36 isle: AsyncIsle,
37}
38
39impl SqliteOutputStore {
40 pub async fn open(path: impl AsRef<Path>) -> Result<(Self, AsyncIsleDriver), OutputStoreError> {
42 let (isle, driver) = AsyncIsle::spawn(path.as_ref().to_path_buf(), |conn| {
43 conn.execute_batch(SCHEMA_SQL)
44 })
45 .await
46 .map_err(map_isle_err)?;
47 Ok((Self { isle }, driver))
48 }
49
50 pub async fn open_in_memory() -> Result<(Self, AsyncIsleDriver), OutputStoreError> {
52 let (isle, driver) = AsyncIsle::open_in_memory(|conn| conn.execute_batch(SCHEMA_SQL))
53 .await
54 .map_err(map_isle_err)?;
55 Ok((Self { isle }, driver))
56 }
57}
58
59fn map_isle_err(e: IsleError) -> OutputStoreError {
60 OutputStoreError::Internal(format!("sqlite: {e}"))
61}
62
63fn decode_record(
64 id: String,
65 task_id: String,
66 attempt: i64,
67 producer_agent: String,
68 event_json: String,
69 parent_refs_json: String,
70) -> Result<OutputRecord, OutputStoreError> {
71 let event: OutputEvent = serde_json::from_str(&event_json)
72 .map_err(|e| OutputStoreError::Internal(format!("decode event: {e}")))?;
73 let parent_refs: Vec<OutputRef> = serde_json::from_str(&parent_refs_json)
74 .map_err(|e| OutputStoreError::Internal(format!("decode parent_refs: {e}")))?;
75 Ok(OutputRecord {
76 id: OutputRef(id),
77 task_id,
78 attempt: attempt as u32,
79 producer_agent,
80 event,
81 parent_refs,
82 })
83}
84
85#[async_trait]
86impl OutputStore for SqliteOutputStore {
87 async fn append(
88 &self,
89 task_id: &str,
90 attempt: u32,
91 producer_agent: &str,
92 event: OutputEvent,
93 parent_refs: Vec<OutputRef>,
94 ) -> Result<OutputRef, OutputStoreError> {
95 let id = OutputRef::new();
96 let id_str = id.0.clone();
97 let task_id = task_id.to_string();
98 let attempt = attempt as i64;
99 let producer_agent = producer_agent.to_string();
100 let event_json = serde_json::to_string(&event)
101 .map_err(|e| OutputStoreError::Internal(format!("encode event: {e}")))?;
102 let parent_refs_json = serde_json::to_string(&parent_refs)
103 .map_err(|e| OutputStoreError::Internal(format!("encode parent_refs: {e}")))?;
104
105 self.isle
106 .call(move |conn| {
107 let tx = conn.transaction()?;
108 let seq: i64 = tx.query_row(
109 "SELECT COALESCE(MAX(seq), 0) + 1 FROM outputs",
110 [],
111 |row| row.get(0),
112 )?;
113 tx.execute(
114 "INSERT INTO outputs (id, task_id, attempt, producer_agent, event_json, \
115 parent_refs_json, seq) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
116 params![
117 id_str,
118 task_id,
119 attempt,
120 producer_agent,
121 event_json,
122 parent_refs_json,
123 seq,
124 ],
125 )?;
126 tx.commit()?;
127 Ok(())
128 })
129 .await
130 .map_err(map_isle_err)?;
131 Ok(id)
132 }
133
134 async fn get(&self, id: &OutputRef) -> Result<OutputRecord, OutputStoreError> {
135 let id_str = id.0.clone();
136 let id_for_notfound = id.0.clone();
137 let row = self
138 .isle
139 .call(move |conn| {
140 conn.query_row(
141 "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
142 FROM outputs WHERE id = ?1",
143 params![id_str],
144 |row| {
145 Ok((
146 row.get::<_, String>(0)?,
147 row.get::<_, String>(1)?,
148 row.get::<_, i64>(2)?,
149 row.get::<_, String>(3)?,
150 row.get::<_, String>(4)?,
151 row.get::<_, String>(5)?,
152 ))
153 },
154 )
155 .optional()
156 })
157 .await
158 .map_err(map_isle_err)?;
159 match row {
160 Some((id, task_id, attempt, producer, event, parent)) => {
161 decode_record(id, task_id, attempt, producer, event, parent)
162 }
163 None => Err(OutputStoreError::NotFound(id_for_notfound)),
164 }
165 }
166
167 async fn get_latest_by_name(&self, name: &str) -> Result<OutputRecord, OutputStoreError> {
168 let name_str = name.to_string();
169 let name_for_notfound = name.to_string();
170 let row = self
171 .isle
172 .call(move |conn| {
173 conn.query_row(
174 "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
175 FROM outputs WHERE producer_agent = ?1 ORDER BY seq DESC LIMIT 1",
176 params![name_str],
177 |row| {
178 Ok((
179 row.get::<_, String>(0)?,
180 row.get::<_, String>(1)?,
181 row.get::<_, i64>(2)?,
182 row.get::<_, String>(3)?,
183 row.get::<_, String>(4)?,
184 row.get::<_, String>(5)?,
185 ))
186 },
187 )
188 .optional()
189 })
190 .await
191 .map_err(map_isle_err)?;
192 match row {
193 Some((id, task_id, attempt, producer, event, parent)) => {
194 decode_record(id, task_id, attempt, producer, event, parent)
195 }
196 None => Err(OutputStoreError::NotFound(name_for_notfound)),
197 }
198 }
199
200 async fn list_for_attempt(
201 &self,
202 task_id: &str,
203 attempt: u32,
204 ) -> Result<Vec<OutputRecord>, OutputStoreError> {
205 let task_id = task_id.to_string();
206 let attempt = attempt as i64;
207 let rows = self
208 .isle
209 .call(move |conn| {
210 let mut stmt = conn.prepare(
211 "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
212 FROM outputs WHERE task_id = ?1 AND attempt = ?2 ORDER BY seq ASC",
213 )?;
214 let iter = stmt.query_map(params![task_id, attempt], |row| {
215 Ok((
216 row.get::<_, String>(0)?,
217 row.get::<_, String>(1)?,
218 row.get::<_, i64>(2)?,
219 row.get::<_, String>(3)?,
220 row.get::<_, String>(4)?,
221 row.get::<_, String>(5)?,
222 ))
223 })?;
224 let mut out = Vec::new();
225 for r in iter {
226 out.push(r?);
227 }
228 Ok(out)
229 })
230 .await
231 .map_err(map_isle_err)?;
232 rows.into_iter()
233 .map(|(id, task_id, attempt, producer, event, parent)| {
234 decode_record(id, task_id, attempt, producer, event, parent)
235 })
236 .collect()
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::output::ContentRef;

    /// Builds a minimal `Final` event carrying inline `text`.
    fn mk_final(text: &str, ok: bool) -> OutputEvent {
        OutputEvent::Final {
            content: ContentRef::inline_text(text),
            ok,
        }
    }

    #[tokio::test]
    async fn append_then_get_roundtrip() {
        let (s, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let id = s
            .append("task-1", 1, "producer-a", mk_final("hello", true), vec![])
            .await
            .unwrap();
        let got = s.get(&id).await.unwrap();
        assert_eq!(got.id, id);
        assert_eq!(got.task_id, "task-1");
        assert_eq!(got.attempt, 1);
        assert_eq!(got.producer_agent, "producer-a");
        match got.event {
            OutputEvent::Final { ok, .. } => assert!(ok),
            other => panic!("unexpected: {other:?}"),
        }
        drop(s);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn get_not_found_returns_error() {
        let (s, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let err = s.get(&OutputRef("missing".into())).await.unwrap_err();
        assert!(matches!(err, OutputStoreError::NotFound(_)));
        drop(s);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn list_for_attempt_orders_by_insertion() {
        let (s, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let a = s
            .append("t", 1, "p1", mk_final("a", true), vec![])
            .await
            .unwrap();
        let b = s
            .append("t", 1, "p2", mk_final("b", true), vec![])
            .await
            .unwrap();
        // Same task, different attempt: must not show up in the attempt-1 listing.
        let _ = s
            .append("t", 2, "p1", mk_final("other-attempt", true), vec![])
            .await
            .unwrap();
        let c = s
            .append("t", 1, "p3", mk_final("c", true), vec![])
            .await
            .unwrap();

        let listed = s.list_for_attempt("t", 1).await.unwrap();
        let ids: Vec<_> = listed.iter().map(|r| r.id.clone()).collect();
        assert_eq!(ids, vec![a, b, c]);
        drop(s);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn list_for_attempt_filters_by_task_id() {
        let (s, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let mine = s
            .append("t1", 1, "p", mk_final("mine", true), vec![])
            .await
            .unwrap();
        // Same attempt number, different task: must be excluded.
        let _ = s
            .append("t2", 1, "p", mk_final("theirs", true), vec![])
            .await
            .unwrap();

        let listed = s.list_for_attempt("t1", 1).await.unwrap();
        let ids: Vec<_> = listed.iter().map(|r| r.id.clone()).collect();
        assert_eq!(ids, vec![mine]);
        drop(s);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn list_for_attempt_empty_when_nothing_matches() {
        let (s, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        assert!(s.list_for_attempt("t", 1).await.unwrap().is_empty());
        let _ = s
            .append("t", 1, "p", mk_final("only", true), vec![])
            .await
            .unwrap();
        assert!(s.list_for_attempt("t", 2).await.unwrap().is_empty());
        assert!(s.list_for_attempt("other", 1).await.unwrap().is_empty());
        drop(s);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn get_latest_by_name_returns_newest_emit() {
        let (s, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let _ = s
            .append("t", 1, "same-producer", mk_final("v1", true), vec![])
            .await
            .unwrap();
        let _ = s
            .append("t", 1, "other-producer", mk_final("unrelated", true), vec![])
            .await
            .unwrap();
        let latest_id = s
            .append("t", 2, "same-producer", mk_final("v2", true), vec![])
            .await
            .unwrap();
        let got = s.get_latest_by_name("same-producer").await.unwrap();
        assert_eq!(got.id, latest_id);
        drop(s);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn get_latest_by_name_unknown_returns_not_found() {
        let (s, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let err = s.get_latest_by_name("nobody").await.unwrap_err();
        assert!(matches!(err, OutputStoreError::NotFound(_)));
        drop(s);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn parent_refs_are_persisted() {
        let (s, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let a = s
            .append("t", 1, "p", mk_final("parent", true), vec![])
            .await
            .unwrap();
        let b = s
            .append("t", 1, "p", mk_final("child", true), vec![a.clone()])
            .await
            .unwrap();
        let got = s.get(&b).await.unwrap();
        assert_eq!(got.parent_refs, vec![a]);
        drop(s);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn persists_across_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("outputs.db");
        let id;
        {
            let (s, driver) = SqliteOutputStore::open(&path).await.unwrap();
            id = s
                .append("keep", 1, "p", mk_final("body", true), vec![])
                .await
                .unwrap();
            drop(s);
            driver.shutdown().await.unwrap();
        }
        let (s, driver) = SqliteOutputStore::open(&path).await.unwrap();
        let got = s.get(&id).await.unwrap();
        assert_eq!(got.task_id, "keep");
        drop(s);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn seq_continues_after_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("outputs.db");
        {
            let (s, driver) = SqliteOutputStore::open(&path).await.unwrap();
            let _ = s
                .append("t", 1, "p", mk_final("old", true), vec![])
                .await
                .unwrap();
            drop(s);
            driver.shutdown().await.unwrap();
        }
        // A row appended after reopening must sort after the persisted one.
        let (s, driver) = SqliteOutputStore::open(&path).await.unwrap();
        let newer = s
            .append("t", 1, "p", mk_final("new", true), vec![])
            .await
            .unwrap();
        let latest = s.get_latest_by_name("p").await.unwrap();
        assert_eq!(latest.id, newer);
        let listed = s.list_for_attempt("t", 1).await.unwrap();
        assert_eq!(listed.len(), 2);
        assert_eq!(listed.last().map(|r| r.id.clone()), Some(newer));
        drop(s);
        driver.shutdown().await.unwrap();
    }
}