Skip to main content

kaizen/extensions/
aggregates.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2
3use crate::core::event::{Event, EventKind};
4use crate::store::Store;
5use anyhow::Result;
6use rusqlite::{OptionalExtension, params};
7use serde::{Deserialize, Serialize};
8
/// Per-session rollup of event data, stored one row per session in the
/// `session_aggregates` table.
///
/// Counters are kept as `u64` in Rust and read from / written to SQLite as
/// `i64` by the functions in this file.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct SessionAggregate {
    /// Session this row summarizes; the conflict key of `session_aggregates`.
    pub session_id: String,
    /// Number of events counted for the session.
    pub event_count: u64,
    /// Number of events whose kind is `ToolCall`.
    pub tool_call_count: u64,
    /// Number of events whose kind is `Error`.
    pub error_count: u64,
    /// Sum of the events' `tokens_in` values (missing values count as 0).
    pub tokens_in: u64,
    /// Sum of the events' `tokens_out` values (missing values count as 0).
    pub tokens_out: u64,
    /// Sum of the events' `reasoning_tokens` values (missing values count as 0).
    pub reasoning_tokens: u64,
    /// Sum of the events' `cache_read_tokens` values (missing values count as 0).
    pub cache_read_tokens: u64,
    /// Sum of the events' `cache_creation_tokens` values (missing values count as 0).
    pub cache_creation_tokens: u64,
    /// Sum of the events' `cost_usd_e6` values.
    /// NOTE(review): presumably US dollars scaled by 1e6 — confirm the unit.
    pub cost_usd_e6: i64,
    /// Smallest event `ts_ms` of the session; `None` when a rebuild found no events.
    pub first_event_ms: Option<u64>,
    /// Largest event `ts_ms` of the session; `None` when a rebuild found no events.
    pub last_event_ms: Option<u64>,
    /// Wall-clock time (milliseconds since the Unix epoch) at which this row
    /// was last written, either by a full rebuild or by an incremental apply.
    pub rebuilt_at_ms: u64,
}
25
/// Statement that folds a single event into its session's aggregate row.
///
/// Positional parameters, as bound by `apply_event`:
/// `?1` session id, `?2` tool-call flag (0/1), `?3` error flag (0/1),
/// `?4`..`?8` token counts (in, out, reasoning, cache read, cache creation),
/// `?9` cost, `?10` event timestamp (used for both first and last bound),
/// `?11` write time stored in `rebuilt_at_ms`.
///
/// When no row exists the event becomes the whole aggregate (`event_count` 1).
/// On conflict the counters are added to the stored row and the time bounds
/// are widened.
///
/// The time bounds are wrapped in `COALESCE` because SQLite's multi-argument
/// `MIN`/`MAX` return NULL as soon as one argument is NULL. A row written for
/// a session that had no events yet carries NULL bounds, and without the
/// fallback they would stay NULL after events are applied.
const APPLY_EVENT_SQL: &str = "
INSERT INTO session_aggregates (
    session_id, event_count, tool_call_count, error_count, tokens_in,
    tokens_out, reasoning_tokens, cache_read_tokens, cache_creation_tokens,
    cost_usd_e6, first_event_ms, last_event_ms, rebuilt_at_ms
) VALUES (?1,1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?10,?11)
ON CONFLICT(session_id) DO UPDATE SET
    event_count=session_aggregates.event_count + 1,
    tool_call_count=session_aggregates.tool_call_count + excluded.tool_call_count,
    error_count=session_aggregates.error_count + excluded.error_count,
    tokens_in=session_aggregates.tokens_in + excluded.tokens_in,
    tokens_out=session_aggregates.tokens_out + excluded.tokens_out,
    reasoning_tokens=session_aggregates.reasoning_tokens + excluded.reasoning_tokens,
    cache_read_tokens=session_aggregates.cache_read_tokens + excluded.cache_read_tokens,
    cache_creation_tokens=session_aggregates.cache_creation_tokens + excluded.cache_creation_tokens,
    cost_usd_e6=session_aggregates.cost_usd_e6 + excluded.cost_usd_e6,
    first_event_ms=COALESCE(MIN(session_aggregates.first_event_ms, excluded.first_event_ms), excluded.first_event_ms),
    last_event_ms=COALESCE(MAX(session_aggregates.last_event_ms, excluded.last_event_ms), excluded.last_event_ms),
    rebuilt_at_ms=excluded.rebuilt_at_ms";
45
46pub fn apply_event(store: &Store, event: &Event) -> Result<()> {
47    if get(store, &event.session_id)?.is_none() {
48        upsert_session(store, &event.session_id)?;
49        return Ok(());
50    }
51    store.conn().execute(
52        APPLY_EVENT_SQL,
53        params![
54            event.session_id,
55            i64::from(event.kind == EventKind::ToolCall),
56            i64::from(event.kind == EventKind::Error),
57            i64::from(event.tokens_in.unwrap_or(0)),
58            i64::from(event.tokens_out.unwrap_or(0)),
59            i64::from(event.reasoning_tokens.unwrap_or(0)),
60            i64::from(event.cache_read_tokens.unwrap_or(0)),
61            i64::from(event.cache_creation_tokens.unwrap_or(0)),
62            event.cost_usd_e6.unwrap_or(0),
63            event.ts_ms as i64,
64            now_ms() as i64,
65        ],
66    )?;
67    Ok(())
68}
69
70pub fn rebuild_workspace(store: &Store, workspace: &str) -> Result<usize> {
71    store
72        .list_sessions(workspace)?
73        .iter()
74        .map(|s| upsert_session(store, &s.id).map(|_| 1usize))
75        .sum()
76}
77
78pub fn upsert_session(store: &Store, session_id: &str) -> Result<SessionAggregate> {
79    let row = aggregate_raw(store, session_id, now_ms())?;
80    store.conn().execute(
81        "INSERT INTO session_aggregates (
82            session_id, event_count, tool_call_count, error_count, tokens_in,
83            tokens_out, reasoning_tokens, cache_read_tokens, cache_creation_tokens,
84            cost_usd_e6, first_event_ms, last_event_ms, rebuilt_at_ms
85        ) VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13)
86        ON CONFLICT(session_id) DO UPDATE SET
87            event_count=excluded.event_count, tool_call_count=excluded.tool_call_count,
88            error_count=excluded.error_count, tokens_in=excluded.tokens_in,
89            tokens_out=excluded.tokens_out, reasoning_tokens=excluded.reasoning_tokens,
90            cache_read_tokens=excluded.cache_read_tokens,
91            cache_creation_tokens=excluded.cache_creation_tokens,
92            cost_usd_e6=excluded.cost_usd_e6, first_event_ms=excluded.first_event_ms,
93            last_event_ms=excluded.last_event_ms, rebuilt_at_ms=excluded.rebuilt_at_ms",
94        params![
95            row.session_id,
96            row.event_count as i64,
97            row.tool_call_count as i64,
98            row.error_count as i64,
99            row.tokens_in as i64,
100            row.tokens_out as i64,
101            row.reasoning_tokens as i64,
102            row.cache_read_tokens as i64,
103            row.cache_creation_tokens as i64,
104            row.cost_usd_e6,
105            row.first_event_ms.map(|v| v as i64),
106            row.last_event_ms.map(|v| v as i64),
107            row.rebuilt_at_ms as i64,
108        ],
109    )?;
110    Ok(row)
111}
112
113pub fn get(store: &Store, session_id: &str) -> Result<Option<SessionAggregate>> {
114    store
115        .conn()
116        .query_row(
117            "SELECT session_id, event_count, tool_call_count, error_count, tokens_in,
118                    tokens_out, reasoning_tokens, cache_read_tokens, cache_creation_tokens,
119                    cost_usd_e6, first_event_ms, last_event_ms, rebuilt_at_ms
120             FROM session_aggregates WHERE session_id = ?1",
121            [session_id],
122            map_aggregate,
123        )
124        .optional()
125        .map_err(Into::into)
126}
127
128fn aggregate_raw(store: &Store, session_id: &str, rebuilt_at_ms: u64) -> Result<SessionAggregate> {
129    store
130        .conn()
131        .query_row(
132            "SELECT COUNT(*), COALESCE(SUM(kind='ToolCall'),0), COALESCE(SUM(kind='Error'),0),
133                COALESCE(SUM(tokens_in),0), COALESCE(SUM(tokens_out),0),
134                COALESCE(SUM(reasoning_tokens),0), COALESCE(SUM(cache_read_tokens),0),
135                COALESCE(SUM(cache_creation_tokens),0), COALESCE(SUM(cost_usd_e6),0),
136                MIN(ts_ms), MAX(ts_ms)
137         FROM events WHERE session_id = ?1",
138            [session_id],
139            |row| {
140                Ok(SessionAggregate {
141                    session_id: session_id.to_string(),
142                    event_count: row.get::<_, i64>(0)? as u64,
143                    tool_call_count: row.get::<_, i64>(1)? as u64,
144                    error_count: row.get::<_, i64>(2)? as u64,
145                    tokens_in: row.get::<_, i64>(3)? as u64,
146                    tokens_out: row.get::<_, i64>(4)? as u64,
147                    reasoning_tokens: row.get::<_, i64>(5)? as u64,
148                    cache_read_tokens: row.get::<_, i64>(6)? as u64,
149                    cache_creation_tokens: row.get::<_, i64>(7)? as u64,
150                    cost_usd_e6: row.get(8)?,
151                    first_event_ms: row.get::<_, Option<i64>>(9)?.map(|v| v as u64),
152                    last_event_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
153                    rebuilt_at_ms,
154                })
155            },
156        )
157        .map_err(Into::into)
158}
159
160fn map_aggregate(row: &rusqlite::Row<'_>) -> rusqlite::Result<SessionAggregate> {
161    Ok(SessionAggregate {
162        session_id: row.get(0)?,
163        event_count: row.get::<_, i64>(1)? as u64,
164        tool_call_count: row.get::<_, i64>(2)? as u64,
165        error_count: row.get::<_, i64>(3)? as u64,
166        tokens_in: row.get::<_, i64>(4)? as u64,
167        tokens_out: row.get::<_, i64>(5)? as u64,
168        reasoning_tokens: row.get::<_, i64>(6)? as u64,
169        cache_read_tokens: row.get::<_, i64>(7)? as u64,
170        cache_creation_tokens: row.get::<_, i64>(8)? as u64,
171        cost_usd_e6: row.get(9)?,
172        first_event_ms: row.get::<_, Option<i64>>(10)?.map(|v| v as u64),
173        last_event_ms: row.get::<_, Option<i64>>(11)?.map(|v| v as u64),
174        rebuilt_at_ms: row.get::<_, i64>(12)? as u64,
175    })
176}
177
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        // Clock set before 1970: fall back to a zero duration.
        Err(_) => std::time::Duration::default(),
    };
    since_epoch.as_millis() as u64
}