// assay_core/storage/store.rs

1use super::now_rfc3339ish;
2use crate::model::{AttemptRow, EvalConfig, LlmResponse, TestResultRow, TestStatus};
3use crate::trace::schema::{EpisodeEnd, EpisodeStart, StepEntry, ToolCallEntry, TraceEvent};
4use anyhow::Context;
5use rusqlite::{params, Connection};
6use std::path::Path;
7use std::sync::{Arc, Mutex};
8
9#[path = "store_internal/mod.rs"]
10mod store_internal;
11
12#[derive(Clone)]
13pub struct Store {
14    pub conn: Arc<Mutex<Connection>>,
15}
16
/// Best-effort summary of what the store contains.
///
/// Each statistic is gathered by an independent query; a query that fails
/// leaves its field as `None` instead of failing the whole summary.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StoreStats {
    /// Row count of the `runs` table.
    pub runs: Option<u64>,
    /// Row count of the `results` table.
    pub results: Option<u64>,
    /// Id of the run with the highest id, if any.
    pub last_run_id: Option<i64>,
    /// `started_at` of that same run.
    pub last_run_at: Option<String>,
    /// SQLite `PRAGMA user_version`, rendered as a decimal string.
    pub version: Option<String>,
}
24
25impl Store {
26    pub fn open(path: &Path) -> anyhow::Result<Self> {
27        let conn = Connection::open(path).context("failed to open sqlite db")?;
28        conn.execute("PRAGMA foreign_keys = ON", [])?;
29        Ok(Self {
30            conn: Arc::new(Mutex::new(conn)),
31        })
32    }
33
34    pub fn memory() -> anyhow::Result<Self> {
35        // SQLite in-memory DB
36        let conn = Connection::open_in_memory().context("failed to open in-memory sqlite db")?;
37        Ok(Self {
38            conn: Arc::new(Mutex::new(conn)),
39        })
40    }
41
42    pub fn init_schema(&self) -> anyhow::Result<()> {
43        let conn = self.conn.lock().unwrap();
44        conn.execute_batch(crate::storage::schema::DDL)?;
45
46        // v0.3.0 Migrations
47        migrate_v030(&conn)?;
48
49        // Ensure attempts table exists (covered by DDL if creating fresh, but good to be explicit if DDL didn't run on existing DB)
50        // DDL handles IF NOT EXISTS for attempts.
51
52        // Index on fingerprint for speed (CREATE INDEX IF NOT EXISTS is valid sqlite)
53        let _ = conn.execute(
54            "CREATE INDEX IF NOT EXISTS idx_results_fingerprint ON results(fingerprint)",
55            [],
56        );
57
58        Ok(())
59    }
60
61    pub fn fetch_recent_results(
62        &self,
63        suite: &str,
64        limit: u32,
65    ) -> anyhow::Result<Vec<crate::model::TestResultRow>> {
66        let conn = self.conn.lock().unwrap();
67        let mut stmt = conn.prepare(
68            "SELECT
69                r.test_id, r.outcome, r.duration_ms, r.score, r.attempts_json,
70                r.fingerprint, r.skip_reason
71             FROM results r
72             JOIN runs ON r.run_id = runs.id
73             WHERE runs.suite = ?1
74             ORDER BY r.id DESC
75             LIMIT ?2",
76        )?;
77
78        let rows = stmt.query_map(rusqlite::params![suite, limit], row_to_test_result)?;
79
80        let mut results = Vec::new();
81        for r in rows {
82            results.push(r?);
83        }
84        Ok(results)
85    }
86
87    pub fn fetch_results_for_last_n_runs(
88        &self,
89        suite: &str,
90        n: u32,
91    ) -> anyhow::Result<Vec<crate::model::TestResultRow>> {
92        let conn = self.conn.lock().unwrap();
93        let mut stmt = conn.prepare(
94            "SELECT
95                r.test_id, r.outcome, r.duration_ms, r.score, r.attempts_json,
96                r.fingerprint, r.skip_reason
97             FROM results r
98             JOIN runs ON r.run_id = runs.id
99             WHERE runs.id IN (
100                 SELECT id FROM runs WHERE suite = ?1 ORDER BY id DESC LIMIT ?2
101             )
102             ORDER BY r.id DESC",
103        )?;
104
105        let rows = stmt.query_map(rusqlite::params![suite, n], row_to_test_result)?;
106
107        let mut results = Vec::new();
108        for r in rows {
109            results.push(r?);
110        }
111        Ok(results)
112    }
113
114    pub fn get_latest_run_id(&self, suite: &str) -> anyhow::Result<Option<i64>> {
115        let conn = self.conn.lock().unwrap();
116        let mut stmt =
117            conn.prepare("SELECT id FROM runs WHERE suite = ?1 ORDER BY id DESC LIMIT 1")?;
118        let mut rows = stmt.query(params![suite])?;
119        if let Some(row) = rows.next()? {
120            Ok(Some(row.get(0)?))
121        } else {
122            Ok(None)
123        }
124    }
125
126    pub fn fetch_results_for_run(
127        &self,
128        run_id: i64,
129    ) -> anyhow::Result<Vec<crate::model::TestResultRow>> {
130        let conn = self.conn.lock().unwrap();
131        let mut stmt = conn.prepare(
132            "SELECT
133                r.test_id, r.outcome, r.duration_ms, r.score, r.attempts_json,
134                r.fingerprint, r.skip_reason
135             FROM results r
136             WHERE r.run_id = ?1
137             ORDER BY r.test_id ASC",
138        )?;
139
140        let rows = stmt.query_map(params![run_id], row_to_test_result)?;
141
142        let mut results = Vec::new();
143        for r in rows {
144            results.push(r?);
145        }
146        Ok(results)
147    }
148
149    pub fn get_last_passing_by_fingerprint(
150        &self,
151        fingerprint: &str,
152    ) -> anyhow::Result<Option<TestResultRow>> {
153        let conn = self.conn.lock().unwrap();
154        // We want the most recent passing result for this fingerprint.
155        // run_id DESC ensures recency.
156        let mut stmt = conn.prepare(
157            "SELECT r.test_id, r.score, r.duration_ms, r.output_json, r.skip_reason, run.id, run.started_at
158             FROM results r
159             JOIN runs run ON r.run_id = run.id
160             WHERE r.fingerprint = ?1 AND r.outcome = 'pass'
161             ORDER BY r.id DESC LIMIT 1"
162        )?;
163
164        let mut rows = stmt.query(params![fingerprint])?;
165        if let Some(row) = rows.next()? {
166            let status = TestStatus::Pass;
167
168            let skip_reason: Option<String> = row.get(4)?;
169            let run_id: i64 = row.get(5)?;
170            let started_at: String = row.get(6)?;
171
172            let details = serde_json::json!({
173                "skip": {
174                    "reason": skip_reason.clone().unwrap_or_else(|| "fingerprint_match".into()),
175                    "fingerprint": fingerprint,
176                    "previous_run_id": run_id,
177                    "previous_at": started_at,
178                    "origin_run_id": run_id,
179                    "previous_score": row.get::<_, Option<f64>>(1)?
180                }
181            });
182
183            Ok(Some(TestResultRow {
184                test_id: row.get(0)?,
185                status,
186                message: skip_reason.unwrap_or_else(|| "fingerprint_match".to_string()),
187                score: row.get(1)?,
188                duration_ms: row.get(2)?,
189                cached: true,
190                details,
191                fingerprint: Some(fingerprint.to_string()),
192                skip_reason: None,
193                attempts: None,
194                error_policy_applied: None,
195            }))
196        } else {
197            Ok(None)
198        }
199    }
200
201    pub fn insert_run(&self, suite: &str) -> anyhow::Result<i64> {
202        let started_at = now_rfc3339ish();
203        let conn = self.conn.lock().unwrap();
204        insert_run_row(&conn, suite, &started_at, "running", None)
205    }
206
207    pub fn create_run(&self, cfg: &EvalConfig) -> anyhow::Result<i64> {
208        let started_at = now_rfc3339ish();
209        let config_json = serde_json::to_string(cfg)?;
210        let conn = self.conn.lock().unwrap();
211        insert_run_row(
212            &conn,
213            &cfg.suite,
214            &started_at,
215            "running",
216            Some(config_json.as_str()),
217        )
218    }
219
220    pub fn finalize_run(&self, run_id: i64, status: &str) -> anyhow::Result<()> {
221        let conn = self.conn.lock().unwrap();
222        conn.execute(
223            "UPDATE runs SET status=?1 WHERE id=?2",
224            params![status, run_id],
225        )?;
226        Ok(())
227    }
228
229    pub fn insert_result_embedded(
230        &self,
231        run_id: i64,
232        row: &TestResultRow,
233        attempts: &[AttemptRow],
234        output: &LlmResponse,
235    ) -> anyhow::Result<()> {
236        let conn = self.conn.lock().unwrap();
237
238        // 1. Insert into results
239        conn.execute(
240            "INSERT INTO results(run_id, test_id, outcome, score, duration_ms, attempts_json, output_json, fingerprint, skip_reason)
241             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
242            params![
243                run_id,
244                row.test_id,
245                status_to_outcome(&row.status),
246                row.score,
247                row.duration_ms.map(|v| v as i64),
248                serde_json::to_string(attempts)?,
249                serde_json::to_string(output)?,
250                row.fingerprint,
251                row.skip_reason
252            ],
253        )?;
254
255        let result_id = conn.last_insert_rowid();
256
257        // 2. Insert individual attempts
258        let mut stmt = conn.prepare(
259            "INSERT INTO attempts(result_id, attempt_number, outcome, score, duration_ms, output_json, error_message)
260             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)"
261        )?;
262
263        for attempt in attempts {
264            stmt.execute(params![
265                result_id,
266                attempt.attempt_no as i64,
267                status_to_outcome(&attempt.status),
268                0.0, // Score not tracked per attempt yet
269                attempt.duration_ms.map(|v| v as i64),
270                serde_json::to_string(&attempt.details)?,
271                Option::<String>::None
272            ])?;
273        }
274
275        Ok(())
276    }
277
278    // ... existing ...
279
280    // quarantine
281    pub fn quarantine_get_reason(
282        &self,
283        suite: &str,
284        test_id: &str,
285    ) -> anyhow::Result<Option<String>> {
286        let conn = self.conn.lock().unwrap();
287        let mut stmt =
288            conn.prepare("SELECT reason FROM quarantine WHERE suite=?1 AND test_id=?2")?;
289        let mut rows = stmt.query(params![suite, test_id])?;
290        if let Some(row) = rows.next()? {
291            Ok(Some(row.get::<_, Option<String>>(0)?.unwrap_or_default()))
292        } else {
293            Ok(None)
294        }
295    }
296
297    pub fn quarantine_add(&self, suite: &str, test_id: &str, reason: &str) -> anyhow::Result<()> {
298        let conn = self.conn.lock().unwrap();
299        conn.execute(
300            "INSERT INTO quarantine(suite, test_id, reason, added_at)
301             VALUES (?1, ?2, ?3, ?4)
302             ON CONFLICT(suite, test_id) DO UPDATE SET reason=excluded.reason, added_at=excluded.added_at",
303            params![suite, test_id, reason, now_rfc3339ish()],
304        )?;
305        Ok(())
306    }
307
308    pub fn quarantine_remove(&self, suite: &str, test_id: &str) -> anyhow::Result<()> {
309        let conn = self.conn.lock().unwrap();
310        conn.execute(
311            "DELETE FROM quarantine WHERE suite=?1 AND test_id=?2",
312            params![suite, test_id],
313        )?;
314        Ok(())
315    }
316
317    // cache
318    pub fn cache_get(&self, key: &str) -> anyhow::Result<Option<LlmResponse>> {
319        let conn = self.conn.lock().unwrap();
320        let mut stmt = conn.prepare("SELECT response_json FROM cache WHERE key=?1")?;
321        let mut rows = stmt.query(params![key])?;
322        if let Some(row) = rows.next()? {
323            let s: String = row.get(0)?;
324            let mut resp: LlmResponse = serde_json::from_str(&s)?;
325            resp.cached = true;
326            Ok(Some(resp))
327        } else {
328            Ok(None)
329        }
330    }
331
332    pub fn cache_put(&self, key: &str, resp: &LlmResponse) -> anyhow::Result<()> {
333        let conn = self.conn.lock().unwrap();
334        let created_at = now_rfc3339ish();
335        let mut to_store = resp.clone();
336        to_store.cached = false;
337        conn.execute(
338            "INSERT INTO cache(key, response_json, created_at) VALUES (?1, ?2, ?3)
339             ON CONFLICT(key) DO UPDATE SET response_json=excluded.response_json, created_at=excluded.created_at",
340            params![key, serde_json::to_string(&to_store)?, created_at],
341        )?;
342        Ok(())
343    }
344
345    // embeddings
346    pub fn get_embedding(&self, key: &str) -> anyhow::Result<Option<(String, Vec<f32>)>> {
347        let conn = self.conn.lock().unwrap();
348        let mut stmt = conn.prepare("SELECT model, vec FROM embeddings WHERE key = ?1 LIMIT 1")?;
349        let mut rows = stmt.query(params![key])?;
350
351        if let Some(row) = rows.next()? {
352            let model: String = row.get(0)?;
353            let blob: Vec<u8> = row.get(1)?;
354            let vec = crate::embeddings::util::decode_vec_f32(&blob)?;
355            Ok(Some((model, vec)))
356        } else {
357            Ok(None)
358        }
359    }
360
361    pub fn put_embedding(&self, key: &str, model: &str, vec: &[f32]) -> anyhow::Result<()> {
362        let conn = self.conn.lock().unwrap();
363        let blob = crate::embeddings::util::encode_vec_f32(vec);
364        let dims = vec.len() as i64;
365        let created_at = now_rfc3339ish();
366
367        conn.execute(
368            "INSERT OR REPLACE INTO embeddings (key, model, dims, vec, created_at)
369             VALUES (?1, ?2, ?3, ?4, ?5)",
370            params![key, model, dims, blob, created_at],
371        )?;
372        Ok(())
373    }
374    pub fn stats_best_effort(&self) -> anyhow::Result<StoreStats> {
375        let conn = self.conn.lock().unwrap();
376
377        let runs: Option<u64> = conn
378            .query_row("SELECT COUNT(*) FROM runs", [], |r| {
379                r.get::<_, i64>(0).map(|x| x as u64)
380            })
381            .ok();
382        let results: Option<u64> = conn
383            .query_row("SELECT COUNT(*) FROM results", [], |r| {
384                r.get::<_, i64>(0).map(|x| x as u64)
385            })
386            .ok();
387
388        let last: Option<(i64, String)> = conn
389            .query_row(
390                "SELECT id, started_at FROM runs ORDER BY id DESC LIMIT 1",
391                [],
392                |r| Ok((r.get(0)?, r.get(1)?)),
393            )
394            .ok();
395
396        let (last_id, last_started) = if let Some((id, s)) = last {
397            (Some(id), Some(s))
398        } else {
399            (None, None)
400        };
401
402        let v_str: Option<String> = conn
403            .query_row("PRAGMA user_version", [], |r| r.get(0))
404            .ok()
405            .map(|v: i64| v.to_string());
406
407        Ok(StoreStats {
408            runs,
409            results,
410            last_run_id: last_id,
411            last_run_at: last_started,
412            version: v_str,
413        })
414    }
415
416    // --- Assertions Support ---
417
418    pub fn get_episode_graph(
419        &self,
420        run_id: i64,
421        test_id: &str,
422    ) -> anyhow::Result<crate::agent_assertions::EpisodeGraph> {
423        let conn = self.conn.lock().unwrap();
424
425        // 1. Find Episode
426        let mut stmt = conn.prepare("SELECT id FROM episodes WHERE run_id = ? AND test_id = ?")?;
427        let mut rows = stmt.query(params![run_id, test_id])?;
428
429        let mut episode_ids = Vec::new();
430        while let Some(row) = rows.next()? {
431            episode_ids.push(row.get::<_, String>(0)?);
432        }
433
434        if episode_ids.is_empty() {
435            anyhow::bail!(
436                "E_TRACE_EPISODE_MISSING: No episode found for run_id={} test_id={}",
437                run_id,
438                test_id
439            );
440        }
441        if episode_ids.len() > 1 {
442            anyhow::bail!(
443                "E_TRACE_EPISODE_AMBIGUOUS: Multiple episodes ({}) found for run_id={} test_id={}",
444                episode_ids.len(),
445                run_id,
446                test_id
447            );
448        }
449        let episode_id = episode_ids[0].clone();
450
451        load_episode_graph_for_episode_id(&conn, &episode_id)
452    }
453
454    // --- Trace V2 Storage ---
455
456    pub fn insert_event(
457        &self,
458        event: &TraceEvent,
459        run_id: Option<i64>,
460        test_id: Option<&str>,
461    ) -> anyhow::Result<()> {
462        let mut conn = self.conn.lock().unwrap();
463        let tx = conn.transaction()?;
464        match event {
465            TraceEvent::EpisodeStart(e) => Self::insert_episode(&tx, e, run_id, test_id)?,
466            TraceEvent::Step(e) => Self::insert_step(&tx, e)?,
467            TraceEvent::ToolCall(e) => Self::insert_tool_call(&tx, e)?,
468            TraceEvent::EpisodeEnd(e) => Self::update_episode_end(&tx, e)?,
469        }
470        tx.commit()?;
471        Ok(())
472    }
473
474    pub fn insert_batch(
475        &self,
476        events: &[TraceEvent],
477        run_id: Option<i64>,
478        test_id: Option<&str>,
479    ) -> anyhow::Result<()> {
480        let mut conn = self.conn.lock().unwrap();
481        let tx = conn.transaction()?;
482        for event in events {
483            match event {
484                TraceEvent::EpisodeStart(e) => Self::insert_episode(&tx, e, run_id, test_id)?,
485                TraceEvent::Step(e) => Self::insert_step(&tx, e)?,
486                TraceEvent::ToolCall(e) => Self::insert_tool_call(&tx, e)?,
487                TraceEvent::EpisodeEnd(e) => Self::update_episode_end(&tx, e)?,
488            }
489        }
490        tx.commit()?;
491        Ok(())
492    }
493
494    fn insert_episode(
495        tx: &rusqlite::Transaction<'_>,
496        e: &EpisodeStart,
497        run_id: Option<i64>,
498        test_id: Option<&str>,
499    ) -> anyhow::Result<()> {
500        let prompt_val = e.input.get("prompt").unwrap_or(&serde_json::Value::Null);
501        let prompt_str = if let Some(s) = prompt_val.as_str() {
502            s.to_string()
503        } else {
504            serde_json::to_string(prompt_val).unwrap_or_default()
505        };
506        let meta = serde_json::to_string(&e.meta).unwrap_or_default();
507
508        // PR-406: Support test_id in meta (from MCP import) to override episode_id default
509        let meta_test_id = e.meta.get("test_id").and_then(|v| v.as_str());
510        let effective_test_id = test_id.or(meta_test_id).or(Some(&e.episode_id));
511
512        // Idempotent: OR REPLACE to update meta/prompt if re-ingesting? Or OR IGNORE?
513        // User said: "INSERT OR IGNORE op episode_id" for idempotency of IDs.
514        tx.execute(
515            "INSERT INTO episodes (id, run_id, test_id, timestamp, prompt, meta_json) VALUES (?, ?, ?, ?, ?, ?)
516             ON CONFLICT(id) DO UPDATE SET
517                run_id=COALESCE(excluded.run_id, episodes.run_id),
518                test_id=COALESCE(excluded.test_id, episodes.test_id),
519                timestamp=excluded.timestamp,
520                prompt=excluded.prompt,
521                meta_json=excluded.meta_json",
522            (
523                &e.episode_id,
524                run_id,
525                effective_test_id,
526                e.timestamp,
527                prompt_str,
528                meta,
529            ),
530        ).context("insert episode")?;
531        Ok(())
532    }
533
534    fn insert_step(tx: &rusqlite::Transaction<'_>, e: &StepEntry) -> anyhow::Result<()> {
535        let meta = serde_json::to_string(&e.meta).unwrap_or_default();
536        let trunc = serde_json::to_string(&e.truncations).unwrap_or_default();
537
538        // Idempotency: UNIQUE(episode_id, idx)
539        tx.execute(
540            "INSERT INTO steps (id, episode_id, idx, kind, name, content, content_sha256, truncations_json, meta_json)
541             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
542             ON CONFLICT(id) DO UPDATE SET content=excluded.content, meta_json=excluded.meta_json",
543            (
544                &e.step_id,
545                &e.episode_id,
546                e.idx,
547                &e.kind,
548                e.name.as_deref(),
549                e.content.as_deref(),
550                e.content_sha256.as_deref(),
551                trunc,
552                meta
553            ),
554        ).context("insert step")?;
555        Ok(())
556    }
557
558    fn insert_tool_call(tx: &rusqlite::Transaction<'_>, e: &ToolCallEntry) -> anyhow::Result<()> {
559        let args = serde_json::to_string(&e.args).unwrap_or_default();
560        let result = e
561            .result
562            .as_ref()
563            .map(|r| serde_json::to_string(r).unwrap_or_default());
564        let trunc = serde_json::to_string(&e.truncations).unwrap_or_default();
565
566        let call_idx = e.call_index.unwrap_or(0); // Default 0
567
568        tx.execute(
569            "INSERT INTO tool_calls (step_id, episode_id, tool_name, call_index, args, args_sha256, result, result_sha256, error, truncations_json)
570             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
571             ON CONFLICT(step_id, call_index) DO NOTHING",
572            (
573                &e.step_id,
574                &e.episode_id,
575                &e.tool_name,
576                call_idx,
577                args,
578                e.args_sha256.as_deref(),
579                result,
580                e.result_sha256.as_deref(),
581                e.error.as_deref(),
582                trunc
583            ),
584        ).context("insert tool call")?;
585        Ok(())
586    }
587
588    pub fn count_rows(&self, table: &str) -> anyhow::Result<i64> {
589        let conn = self.conn.lock().unwrap();
590        // Validation to prevent SQL injection (simple allowlist)
591        if !["episodes", "steps", "tool_calls", "runs", "results"].contains(&table) {
592            anyhow::bail!("Invalid table name for count_rows: {}", table);
593        }
594        let sql = format!("SELECT COUNT(*) FROM {}", table);
595        let n: i64 = conn.query_row(&sql, [], |r| r.get(0))?;
596        Ok(n)
597    }
598
599    fn update_episode_end(tx: &rusqlite::Transaction<'_>, e: &EpisodeEnd) -> anyhow::Result<()> {
600        tx.execute(
601            "UPDATE episodes SET outcome = ? WHERE id = ?",
602            (e.outcome.as_deref(), &e.episode_id),
603        )
604        .context("update episode outcome")?;
605        Ok(())
606    }
607}
608
609fn status_to_outcome(s: &TestStatus) -> &'static str {
610    store_internal::results::status_to_outcome_impl(s)
611}
612
613fn migrate_v030(conn: &Connection) -> anyhow::Result<()> {
614    store_internal::schema::migrate_v030_impl(conn)
615}
616
617impl Store {
618    pub fn get_latest_episode_graph_by_test_id(
619        &self,
620        test_id: &str,
621    ) -> anyhow::Result<crate::agent_assertions::EpisodeGraph> {
622        let conn = self.conn.lock().unwrap();
623
624        // 1. Find latest episode for this test_id
625        let mut stmt = conn.prepare(
626            "SELECT id FROM episodes
627             WHERE test_id = ?1
628             ORDER BY timestamp DESC
629             LIMIT 1",
630        )?;
631
632        let episode_id: String = stmt.query_row(params![test_id], |row| row.get(0))
633            .map_err(|e| anyhow::anyhow!("E_TRACE_EPISODE_MISSING: No episode found for test_id={} (fallback check) : {}", test_id, e))?;
634
635        load_episode_graph_for_episode_id(&conn, &episode_id)
636    }
637}
638
639fn row_to_test_result(row: &rusqlite::Row<'_>) -> rusqlite::Result<TestResultRow> {
640    store_internal::results::row_to_test_result_impl(row)
641}
642
643fn insert_run_row(
644    conn: &Connection,
645    suite: &str,
646    started_at: &str,
647    status: &str,
648    config_json: Option<&str>,
649) -> anyhow::Result<i64> {
650    store_internal::results::insert_run_row_impl(conn, suite, started_at, status, config_json)
651}
652
653fn load_episode_graph_for_episode_id(
654    conn: &Connection,
655    episode_id: &str,
656) -> anyhow::Result<crate::agent_assertions::EpisodeGraph> {
657    store_internal::episodes::load_episode_graph_for_episode_id_impl(conn, episode_id)
658}