Skip to main content

kaizen/store/sqlite/
evals.rs

1use super::rows::*;
2use super::*;
3
4impl Store {
5    pub fn upsert_eval(&self, eval: &crate::eval::types::EvalRow) -> rusqlite::Result<()> {
6        self.conn.execute(
7            "INSERT OR REPLACE INTO session_evals
8             (id, session_id, judge_model, rubric_id, score, rationale, flagged, created_at_ms)
9             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
10            rusqlite::params![
11                eval.id,
12                eval.session_id,
13                eval.judge_model,
14                eval.rubric_id,
15                eval.score,
16                eval.rationale,
17                eval.flagged as i64,
18                eval.created_at_ms as i64,
19            ],
20        )?;
21        Ok(())
22    }
23
24    pub fn list_evals_in_window(
25        &self,
26        start_ms: u64,
27        end_ms: u64,
28    ) -> rusqlite::Result<Vec<crate::eval::types::EvalRow>> {
29        let mut stmt = self.conn.prepare(
30            "SELECT id, session_id, judge_model, rubric_id, score,
31                    rationale, flagged, created_at_ms
32             FROM session_evals
33             WHERE created_at_ms >= ?1 AND created_at_ms < ?2
34             ORDER BY created_at_ms ASC",
35        )?;
36        let rows = stmt.query_map(rusqlite::params![start_ms as i64, end_ms as i64], |r| {
37            Ok(crate::eval::types::EvalRow {
38                id: r.get(0)?,
39                session_id: r.get(1)?,
40                judge_model: r.get(2)?,
41                rubric_id: r.get(3)?,
42                score: r.get(4)?,
43                rationale: r.get(5)?,
44                flagged: r.get::<_, i64>(6)? != 0,
45                created_at_ms: r.get::<_, i64>(7)? as u64,
46            })
47        })?;
48        rows.collect()
49    }
50
51    pub fn list_evals_for_session(
52        &self,
53        session_id: &str,
54    ) -> rusqlite::Result<Vec<crate::eval::types::EvalRow>> {
55        let mut stmt = self.conn.prepare(
56            "SELECT id, session_id, judge_model, rubric_id, score,
57                    rationale, flagged, created_at_ms
58             FROM session_evals
59             WHERE session_id = ?1
60             ORDER BY created_at_ms DESC",
61        )?;
62        let rows = stmt.query_map(rusqlite::params![session_id], |r| {
63            Ok(crate::eval::types::EvalRow {
64                id: r.get(0)?,
65                session_id: r.get(1)?,
66                judge_model: r.get(2)?,
67                rubric_id: r.get(3)?,
68                score: r.get(4)?,
69                rationale: r.get(5)?,
70                flagged: r.get::<_, i64>(6)? != 0,
71                created_at_ms: r.get::<_, i64>(7)? as u64,
72            })
73        })?;
74        rows.collect()
75    }
76
77    pub fn list_sessions_for_eval(
78        &self,
79        since_ms: u64,
80        min_cost_usd: f64,
81    ) -> Result<Vec<crate::core::event::SessionRecord>> {
82        let min_cost_e6 = (min_cost_usd * 1_000_000.0) as i64;
83        let mut stmt = self.conn.prepare(
84            "SELECT s.id, s.agent, s.model, s.workspace, s.started_at_ms, s.ended_at_ms,
85                    s.status, s.trace_path, s.start_commit, s.end_commit, s.branch,
86                    s.dirty_start, s.dirty_end, s.repo_binding_source, s.prompt_fingerprint,
87                    s.parent_session_id, s.agent_version, s.os, s.arch, s.repo_file_count, s.repo_total_loc
88             FROM sessions s
89             WHERE s.started_at_ms >= ?1
90               AND COALESCE((SELECT SUM(e.cost_usd_e6) FROM events e WHERE e.session_id = s.id), 0) >= ?2
91               AND NOT EXISTS (SELECT 1 FROM session_evals ev WHERE ev.session_id = s.id)
92             ORDER BY s.started_at_ms DESC",
93        )?;
94        let rows = stmt.query_map(params![since_ms as i64, min_cost_e6], |r| {
95            Ok((
96                r.get::<_, String>(0)?,
97                r.get::<_, String>(1)?,
98                r.get::<_, Option<String>>(2)?,
99                r.get::<_, String>(3)?,
100                r.get::<_, i64>(4)?,
101                r.get::<_, Option<i64>>(5)?,
102                r.get::<_, String>(6)?,
103                r.get::<_, String>(7)?,
104                r.get::<_, Option<String>>(8)?,
105                r.get::<_, Option<String>>(9)?,
106                r.get::<_, Option<String>>(10)?,
107                r.get::<_, Option<i64>>(11)?,
108                r.get::<_, Option<i64>>(12)?,
109                r.get::<_, Option<String>>(13)?,
110                r.get::<_, Option<String>>(14)?,
111                r.get::<_, Option<String>>(15)?,
112                r.get::<_, Option<String>>(16)?,
113                r.get::<_, Option<String>>(17)?,
114                r.get::<_, Option<String>>(18)?,
115                r.get::<_, Option<i64>>(19)?,
116                r.get::<_, Option<i64>>(20)?,
117            ))
118        })?;
119        let mut out = Vec::new();
120        for row in rows {
121            let (
122                id,
123                agent,
124                model,
125                workspace,
126                started,
127                ended,
128                status_str,
129                trace,
130                start_commit,
131                end_commit,
132                branch,
133                dirty_start,
134                dirty_end,
135                source,
136                prompt_fingerprint,
137                parent_session_id,
138                agent_version,
139                os,
140                arch,
141                repo_file_count,
142                repo_total_loc,
143            ) = row?;
144            out.push(crate::core::event::SessionRecord {
145                id,
146                agent,
147                model,
148                workspace,
149                started_at_ms: started as u64,
150                ended_at_ms: ended.map(|v| v as u64),
151                status: status_from_str(&status_str),
152                trace_path: trace,
153                start_commit,
154                end_commit,
155                branch,
156                dirty_start: dirty_start.map(i64_to_bool),
157                dirty_end: dirty_end.map(i64_to_bool),
158                repo_binding_source: source.and_then(|s| if s.is_empty() { None } else { Some(s) }),
159                prompt_fingerprint,
160                parent_session_id,
161                agent_version,
162                os,
163                arch,
164                repo_file_count: repo_file_count.map(|v| v as u32),
165                repo_total_loc: repo_total_loc.map(|v| v as u64),
166            });
167        }
168        Ok(out)
169    }
170}