Skip to main content

kaizen/store/sqlite/
experiment_windows.rs

1use super::rows::*;
2use super::*;
3impl Store {
4    /// Events in `[start_ms, end_ms]` for a workspace, with session metadata per row.
5    pub fn retro_events_in_window(
6        &self,
7        workspace: &str,
8        start_ms: u64,
9        end_ms: u64,
10    ) -> Result<Vec<(SessionRecord, Event)>> {
11        let mut stmt = self.conn.prepare(
12            "SELECT e.session_id, e.seq, e.ts_ms, COALESCE(e.ts_exact, 0), e.kind, e.source, e.tool, e.tool_call_id,
13                    e.tokens_in, e.tokens_out, e.reasoning_tokens, e.cost_usd_e6, e.payload,
14                    s.id, s.agent, s.model, s.workspace, s.started_at_ms, s.ended_at_ms, s.status, s.trace_path,
15                    s.start_commit, s.end_commit, s.branch, s.dirty_start, s.dirty_end, s.repo_binding_source,
16                    s.prompt_fingerprint, s.parent_session_id, s.agent_version, s.os, s.arch,
17                    s.repo_file_count, s.repo_total_loc,
18                    e.stop_reason, e.latency_ms, e.ttft_ms, e.retry_count,
19                    e.context_used_tokens, e.context_max_tokens,
20                    e.cache_creation_tokens, e.cache_read_tokens, e.system_prompt_tokens
21             FROM events e
22             JOIN sessions s ON s.id = e.session_id
23             WHERE s.workspace = ?1
24               AND (
25                 (e.ts_ms >= ?2 AND e.ts_ms <= ?3)
26                 OR (e.ts_ms < ?4 AND s.started_at_ms >= ?2 AND s.started_at_ms <= ?3)
27               )
28             ORDER BY e.ts_ms ASC, e.session_id ASC, e.seq ASC",
29        )?;
30        let rows = stmt.query_map(
31            params![
32                workspace,
33                start_ms as i64,
34                end_ms as i64,
35                SYNTHETIC_TS_CEILING_MS,
36            ],
37            |row| {
38                let payload_str: String = row.get(12)?;
39                let status_str: String = row.get(19)?;
40                Ok((
41                    SessionRecord {
42                        id: row.get(13)?,
43                        agent: row.get(14)?,
44                        model: row.get(15)?,
45                        workspace: row.get(16)?,
46                        started_at_ms: row.get::<_, i64>(17)? as u64,
47                        ended_at_ms: row.get::<_, Option<i64>>(18)?.map(|v| v as u64),
48                        status: status_from_str(&status_str),
49                        trace_path: row.get(20)?,
50                        start_commit: row.get(21)?,
51                        end_commit: row.get(22)?,
52                        branch: row.get(23)?,
53                        dirty_start: row.get::<_, Option<i64>>(24)?.map(i64_to_bool),
54                        dirty_end: row.get::<_, Option<i64>>(25)?.map(i64_to_bool),
55                        repo_binding_source: empty_to_none(row.get::<_, String>(26)?),
56                        prompt_fingerprint: row.get(27)?,
57                        parent_session_id: row.get(28)?,
58                        agent_version: row.get(29)?,
59                        os: row.get(30)?,
60                        arch: row.get(31)?,
61                        repo_file_count: row.get::<_, Option<i64>>(32)?.map(|v| v as u32),
62                        repo_total_loc: row.get::<_, Option<i64>>(33)?.map(|v| v as u64),
63                    },
64                    Event {
65                        session_id: row.get(0)?,
66                        seq: row.get::<_, i64>(1)? as u64,
67                        ts_ms: row.get::<_, i64>(2)? as u64,
68                        ts_exact: row.get::<_, i64>(3)? != 0,
69                        kind: kind_from_str(&row.get::<_, String>(4)?),
70                        source: source_from_str(&row.get::<_, String>(5)?),
71                        tool: row.get(6)?,
72                        tool_call_id: row.get(7)?,
73                        tokens_in: row.get::<_, Option<i64>>(8)?.map(|v| v as u32),
74                        tokens_out: row.get::<_, Option<i64>>(9)?.map(|v| v as u32),
75                        reasoning_tokens: row.get::<_, Option<i64>>(10)?.map(|v| v as u32),
76                        cost_usd_e6: row.get(11)?,
77                        payload: serde_json::from_str(&payload_str)
78                            .unwrap_or(serde_json::Value::Null),
79                        stop_reason: row.get(34)?,
80                        latency_ms: row.get::<_, Option<i64>>(35)?.map(|v| v as u32),
81                        ttft_ms: row.get::<_, Option<i64>>(36)?.map(|v| v as u32),
82                        retry_count: row.get::<_, Option<i64>>(37)?.map(|v| v as u16),
83                        context_used_tokens: row.get::<_, Option<i64>>(38)?.map(|v| v as u32),
84                        context_max_tokens: row.get::<_, Option<i64>>(39)?.map(|v| v as u32),
85                        cache_creation_tokens: row.get::<_, Option<i64>>(40)?.map(|v| v as u32),
86                        cache_read_tokens: row.get::<_, Option<i64>>(41)?.map(|v| v as u32),
87                        system_prompt_tokens: row.get::<_, Option<i64>>(42)?.map(|v| v as u32),
88                    },
89                ))
90            },
91        )?;
92
93        let mut out = Vec::new();
94        for r in rows {
95            out.push(r?);
96        }
97        Ok(out)
98    }
99    pub fn experiment_metric_values_in_window(
100        &self,
101        workspace: &str,
102        start_ms: u64,
103        end_ms: u64,
104        metric: crate::experiment::types::Metric,
105    ) -> Result<Vec<(SessionRecord, f64)>> {
106        use crate::experiment::types::Metric;
107        let session_cols = "s.id, s.agent, s.model, s.workspace, s.started_at_ms, s.ended_at_ms,
108            s.status, s.trace_path, s.start_commit, s.end_commit, s.branch, s.dirty_start,
109            s.dirty_end, s.repo_binding_source, s.prompt_fingerprint, s.parent_session_id,
110            s.agent_version, s.os, s.arch, s.repo_file_count, s.repo_total_loc";
111        let window = "s.workspace = ?1 AND ((e.ts_ms >= ?2 AND e.ts_ms <= ?3)
112            OR (e.ts_ms < ?4 AND s.started_at_ms >= ?2 AND s.started_at_ms <= ?3))";
113        let sql = match metric {
114            Metric::TokensPerSession => format!(
115                "SELECT {session_cols},
116                    SUM(COALESCE(e.tokens_in,0)+COALESCE(e.tokens_out,0)+COALESCE(e.reasoning_tokens,0)) AS value
117                 FROM sessions s JOIN events e ON e.session_id = s.id
118                 WHERE {window}
119                 GROUP BY s.id"
120            ),
121            Metric::CostPerSession => format!(
122                "SELECT {session_cols}, SUM(COALESCE(e.cost_usd_e6,0)) / 1000000.0 AS value
123                 FROM sessions s JOIN events e ON e.session_id = s.id
124                 WHERE {window}
125                 GROUP BY s.id"
126            ),
127            Metric::SuccessRate => format!(
128                "SELECT {session_cols},
129                    CASE WHEN SUM(CASE WHEN e.kind='Error' THEN 1 ELSE 0 END) > 0 THEN 0.0 ELSE 1.0 END AS value
130                 FROM sessions s JOIN events e ON e.session_id = s.id
131                 WHERE {window}
132                 GROUP BY s.id"
133            ),
134            Metric::DurationMinutes => format!(
135                "SELECT {session_cols},
136                    (s.ended_at_ms - s.started_at_ms) / 60000.0 AS value
137                 FROM sessions s
138                 WHERE s.workspace = ?1
139                   AND s.ended_at_ms IS NOT NULL
140                   AND EXISTS (
141                     SELECT 1 FROM events e
142                     WHERE e.session_id = s.id
143                       AND ((e.ts_ms >= ?2 AND e.ts_ms <= ?3)
144                         OR (e.ts_ms < ?4 AND s.started_at_ms >= ?2 AND s.started_at_ms <= ?3))
145                   )"
146            ),
147            Metric::FilesPerSession => format!(
148                "SELECT {session_cols}, COUNT(DISTINCT ft.path) AS value
149                 FROM sessions s
150                 JOIN events e ON e.session_id = s.id
151                 LEFT JOIN files_touched ft ON ft.session_id = s.id
152                 WHERE {window}
153                 GROUP BY s.id"
154            ),
155            Metric::SuccessRateByPrompt => format!(
156                "SELECT {session_cols},
157                    1.0 - (MIN(
158                      SUM(CASE WHEN e.kind='Error' THEN 1 ELSE 0 END),
159                      SUM(CASE WHEN e.kind='Message' THEN 1 ELSE 0 END)
160                    ) * 1.0 / SUM(CASE WHEN e.kind='Message' THEN 1 ELSE 0 END)) AS value
161                 FROM sessions s JOIN events e ON e.session_id = s.id
162                 WHERE {window}
163                 GROUP BY s.id
164                 HAVING SUM(CASE WHEN e.kind='Message' THEN 1 ELSE 0 END) > 0"
165            ),
166            Metric::CostByPrompt => format!(
167                "SELECT {session_cols},
168                    SUM(COALESCE(e.cost_usd_e6,0)) / 1000000.0 /
169                    SUM(CASE WHEN e.kind='Message' THEN 1 ELSE 0 END) AS value
170                 FROM sessions s JOIN events e ON e.session_id = s.id
171                 WHERE {window}
172                 GROUP BY s.id
173                 HAVING SUM(CASE WHEN e.kind='Message' THEN 1 ELSE 0 END) > 0"
174            ),
175            Metric::ToolLoops => format!(
176                "WITH calls AS (
177                   SELECT e.session_id, e.tool,
178                     LAG(e.tool) OVER (PARTITION BY e.session_id ORDER BY e.ts_ms, e.seq) AS prev_tool
179                   FROM events e JOIN sessions s ON s.id = e.session_id
180                   WHERE {window} AND e.kind='ToolCall' AND e.tool IS NOT NULL
181                 )
182                 SELECT {session_cols},
183                    SUM(CASE WHEN calls.tool = calls.prev_tool THEN 1 ELSE 0 END) AS value
184                 FROM sessions s JOIN calls ON calls.session_id = s.id
185                 GROUP BY s.id"
186            ),
187        };
188        let mut stmt = self.conn.prepare(&sql)?;
189        let rows = stmt.query_map(
190            params![
191                workspace,
192                start_ms as i64,
193                end_ms as i64,
194                SYNTHETIC_TS_CEILING_MS,
195            ],
196            |row| Ok((session_row(row)?, row.get::<_, f64>(21)?)),
197        )?;
198        rows.map(|r| r.map_err(anyhow::Error::from)).collect()
199    }
200}