1use super::{OutputEvent, OutputRecord, OutputRef, OutputStore, OutputStoreError};
15use async_trait::async_trait;
16use rusqlite::{params, OptionalExtension};
17use rusqlite_isle::{AsyncIsle, AsyncIsleDriver, IsleError};
18use std::path::Path;
19
/// DDL run on every open. Every statement is `IF NOT EXISTS`, so re-running it against an
/// existing database is a no-op apart from creating any index that is missing.
///
/// `seq` is a store-wide insertion counter: `append` assigns `MAX(seq) + 1`. The
/// `ix_outputs_seq` index gives `seq` a leftmost-column index, so SQLite can answer that
/// `MAX(seq)` with a single index probe instead of scanning the whole table on every append.
/// The other two indexes only carry `seq` as a trailing column, which cannot serve a
/// table-wide `MAX`. The index is deliberately non-unique so that opening a database that
/// already holds rows can never fail on index creation.
const SCHEMA_SQL: &str = "\
CREATE TABLE IF NOT EXISTS outputs (\
    id TEXT PRIMARY KEY, \
    task_id TEXT NOT NULL, \
    attempt INTEGER NOT NULL, \
    producer_agent TEXT NOT NULL, \
    event_json TEXT NOT NULL, \
    parent_refs_json TEXT NOT NULL, \
    seq INTEGER NOT NULL\
);\
CREATE INDEX IF NOT EXISTS ix_outputs_attempt ON outputs(task_id, attempt, seq);\
CREATE INDEX IF NOT EXISTS ix_outputs_producer ON outputs(producer_agent, seq);\
CREATE INDEX IF NOT EXISTS ix_outputs_seq ON outputs(seq);\
";
33
34pub struct SqliteOutputStore {
36 isle: AsyncIsle,
37}
38
39impl SqliteOutputStore {
40 pub async fn open(path: impl AsRef<Path>) -> Result<(Self, AsyncIsleDriver), OutputStoreError> {
42 let (isle, driver) = AsyncIsle::spawn(path.as_ref().to_path_buf(), |conn| {
43 conn.execute_batch(SCHEMA_SQL)
44 })
45 .await
46 .map_err(map_isle_err)?;
47 Ok((Self { isle }, driver))
48 }
49
50 pub async fn open_in_memory() -> Result<(Self, AsyncIsleDriver), OutputStoreError> {
52 let (isle, driver) = AsyncIsle::open_in_memory(|conn| conn.execute_batch(SCHEMA_SQL))
53 .await
54 .map_err(map_isle_err)?;
55 Ok((Self { isle }, driver))
56 }
57}
58
59fn map_isle_err(e: IsleError) -> OutputStoreError {
60 OutputStoreError::Internal(format!("sqlite: {e}"))
61}
62
63fn decode_record(
64 id: String,
65 task_id: String,
66 attempt: i64,
67 producer_agent: String,
68 event_json: String,
69 parent_refs_json: String,
70) -> Result<OutputRecord, OutputStoreError> {
71 let event: OutputEvent = serde_json::from_str(&event_json)
72 .map_err(|e| OutputStoreError::Internal(format!("decode event: {e}")))?;
73 let parent_refs: Vec<OutputRef> = serde_json::from_str(&parent_refs_json)
74 .map_err(|e| OutputStoreError::Internal(format!("decode parent_refs: {e}")))?;
75 Ok(OutputRecord {
76 id: OutputRef(id),
77 task_id,
78 attempt: attempt as u32,
79 producer_agent,
80 event,
81 parent_refs,
82 })
83}
84
85#[async_trait]
86impl OutputStore for SqliteOutputStore {
87 async fn append(
88 &self,
89 task_id: &str,
90 attempt: u32,
91 producer_agent: &str,
92 event: OutputEvent,
93 parent_refs: Vec<OutputRef>,
94 ) -> Result<OutputRef, OutputStoreError> {
95 let id = OutputRef::new();
96 let id_str = id.0.clone();
97 let task_id = task_id.to_string();
98 let attempt = attempt as i64;
99 let producer_agent = producer_agent.to_string();
100 let event_json = serde_json::to_string(&event)
101 .map_err(|e| OutputStoreError::Internal(format!("encode event: {e}")))?;
102 let parent_refs_json = serde_json::to_string(&parent_refs)
103 .map_err(|e| OutputStoreError::Internal(format!("encode parent_refs: {e}")))?;
104
105 self.isle
106 .call(move |conn| {
107 let tx = conn.transaction()?;
108 let seq: i64 =
109 tx.query_row("SELECT COALESCE(MAX(seq), 0) + 1 FROM outputs", [], |row| {
110 row.get(0)
111 })?;
112 tx.execute(
113 "INSERT INTO outputs (id, task_id, attempt, producer_agent, event_json, \
114 parent_refs_json, seq) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
115 params![
116 id_str,
117 task_id,
118 attempt,
119 producer_agent,
120 event_json,
121 parent_refs_json,
122 seq,
123 ],
124 )?;
125 tx.commit()?;
126 Ok(())
127 })
128 .await
129 .map_err(map_isle_err)?;
130 Ok(id)
131 }
132
133 async fn get(&self, id: &OutputRef) -> Result<OutputRecord, OutputStoreError> {
134 let id_str = id.0.clone();
135 let id_for_notfound = id.0.clone();
136 let row = self
137 .isle
138 .call(move |conn| {
139 conn.query_row(
140 "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
141 FROM outputs WHERE id = ?1",
142 params![id_str],
143 |row| {
144 Ok((
145 row.get::<_, String>(0)?,
146 row.get::<_, String>(1)?,
147 row.get::<_, i64>(2)?,
148 row.get::<_, String>(3)?,
149 row.get::<_, String>(4)?,
150 row.get::<_, String>(5)?,
151 ))
152 },
153 )
154 .optional()
155 })
156 .await
157 .map_err(map_isle_err)?;
158 match row {
159 Some((id, task_id, attempt, producer, event, parent)) => {
160 decode_record(id, task_id, attempt, producer, event, parent)
161 }
162 None => Err(OutputStoreError::NotFound(id_for_notfound)),
163 }
164 }
165
166 async fn get_latest_by_name(&self, name: &str) -> Result<OutputRecord, OutputStoreError> {
167 let name_str = name.to_string();
168 let name_for_notfound = name.to_string();
169 let row = self
170 .isle
171 .call(move |conn| {
172 conn.query_row(
173 "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
174 FROM outputs WHERE producer_agent = ?1 ORDER BY seq DESC LIMIT 1",
175 params![name_str],
176 |row| {
177 Ok((
178 row.get::<_, String>(0)?,
179 row.get::<_, String>(1)?,
180 row.get::<_, i64>(2)?,
181 row.get::<_, String>(3)?,
182 row.get::<_, String>(4)?,
183 row.get::<_, String>(5)?,
184 ))
185 },
186 )
187 .optional()
188 })
189 .await
190 .map_err(map_isle_err)?;
191 match row {
192 Some((id, task_id, attempt, producer, event, parent)) => {
193 decode_record(id, task_id, attempt, producer, event, parent)
194 }
195 None => Err(OutputStoreError::NotFound(name_for_notfound)),
196 }
197 }
198
199 async fn list_for_attempt(
200 &self,
201 task_id: &str,
202 attempt: u32,
203 ) -> Result<Vec<OutputRecord>, OutputStoreError> {
204 let task_id = task_id.to_string();
205 let attempt = attempt as i64;
206 let rows = self
207 .isle
208 .call(move |conn| {
209 let mut stmt = conn.prepare(
210 "SELECT id, task_id, attempt, producer_agent, event_json, parent_refs_json \
211 FROM outputs WHERE task_id = ?1 AND attempt = ?2 ORDER BY seq ASC",
212 )?;
213 let iter = stmt.query_map(params![task_id, attempt], |row| {
214 Ok((
215 row.get::<_, String>(0)?,
216 row.get::<_, String>(1)?,
217 row.get::<_, i64>(2)?,
218 row.get::<_, String>(3)?,
219 row.get::<_, String>(4)?,
220 row.get::<_, String>(5)?,
221 ))
222 })?;
223 let mut out = Vec::new();
224 for r in iter {
225 out.push(r?);
226 }
227 Ok(out)
228 })
229 .await
230 .map_err(map_isle_err)?;
231 rows.into_iter()
232 .map(|(id, task_id, attempt, producer, event, parent)| {
233 decode_record(id, task_id, attempt, producer, event, parent)
234 })
235 .collect()
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::output::ContentRef;

    /// Builds a `Final` event whose content is the inline text `text`.
    fn final_event(text: &str, ok: bool) -> OutputEvent {
        let content = ContentRef::inline_text(text);
        OutputEvent::Final { content, ok }
    }

    /// Drops the store handle first, then waits for the isle driver to shut down.
    async fn close(store: SqliteOutputStore, driver: AsyncIsleDriver) {
        drop(store);
        driver.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn append_then_get_roundtrip() {
        let (store, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let appended = store
            .append("task-1", 1, "producer-a", final_event("hello", true), vec![])
            .await
            .unwrap();

        let record = store.get(&appended).await.unwrap();
        assert_eq!(record.id, appended);
        assert_eq!(record.task_id, "task-1");
        assert_eq!(record.attempt, 1);
        assert_eq!(record.producer_agent, "producer-a");
        if let OutputEvent::Final { ok, .. } = record.event {
            assert!(ok);
        } else {
            panic!("unexpected: {:?}", record.event);
        }

        close(store, driver).await;
    }

    #[tokio::test]
    async fn get_not_found_returns_error() {
        let (store, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let missing = OutputRef("missing".into());
        let err = store.get(&missing).await.unwrap_err();
        assert!(matches!(err, OutputStoreError::NotFound(_)));
        close(store, driver).await;
    }

    #[tokio::test]
    async fn list_for_attempt_orders_by_insertion() {
        let (store, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let first = store
            .append("t", 1, "p1", final_event("a", true), vec![])
            .await
            .unwrap();
        let second = store
            .append("t", 1, "p2", final_event("b", true), vec![])
            .await
            .unwrap();
        // Same task, different attempt: must not appear in the attempt-1 listing.
        store
            .append("t", 2, "p1", final_event("other-attempt", true), vec![])
            .await
            .unwrap();
        let third = store
            .append("t", 1, "p3", final_event("c", true), vec![])
            .await
            .unwrap();

        let listed: Vec<OutputRef> = store
            .list_for_attempt("t", 1)
            .await
            .unwrap()
            .into_iter()
            .map(|record| record.id)
            .collect();
        assert_eq!(listed, vec![first, second, third]);

        close(store, driver).await;
    }

    #[tokio::test]
    async fn get_latest_by_name_returns_newest_emit() {
        let (store, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        store
            .append("t", 1, "same-producer", final_event("v1", true), vec![])
            .await
            .unwrap();
        store
            .append("t", 1, "other-producer", final_event("unrelated", true), vec![])
            .await
            .unwrap();
        let newest = store
            .append("t", 2, "same-producer", final_event("v2", true), vec![])
            .await
            .unwrap();

        let latest = store.get_latest_by_name("same-producer").await.unwrap();
        assert_eq!(latest.id, newest);

        close(store, driver).await;
    }

    #[tokio::test]
    async fn get_latest_by_name_unknown_returns_not_found() {
        let (store, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let err = store.get_latest_by_name("nobody").await.unwrap_err();
        assert!(matches!(err, OutputStoreError::NotFound(_)));
        close(store, driver).await;
    }

    #[tokio::test]
    async fn parent_refs_are_persisted() {
        let (store, driver) = SqliteOutputStore::open_in_memory().await.unwrap();
        let parent = store
            .append("t", 1, "p", final_event("parent", true), vec![])
            .await
            .unwrap();
        let child = store
            .append("t", 1, "p", final_event("child", true), vec![parent.clone()])
            .await
            .unwrap();

        let record = store.get(&child).await.unwrap();
        assert_eq!(record.parent_refs, vec![parent]);

        close(store, driver).await;
    }

    #[tokio::test]
    async fn persists_across_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("outputs.db");

        let kept = {
            let (store, driver) = SqliteOutputStore::open(&path).await.unwrap();
            let appended = store
                .append("keep", 1, "p", final_event("body", true), vec![])
                .await
                .unwrap();
            close(store, driver).await;
            appended
        };

        let (store, driver) = SqliteOutputStore::open(&path).await.unwrap();
        let record = store.get(&kept).await.unwrap();
        assert_eq!(record.task_id, "keep");
        close(store, driver).await;
    }
}