Skip to main content

kaizen/store/
tool_span_index.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Rebuild tool spans from one session event stream.
3
4use crate::core::event::{Event, EventKind, EventSource};
5use crate::store::event_index::paths_from_event_payload;
6use anyhow::Result;
7use rusqlite::{Connection, params};
8use std::collections::{BTreeSet, HashMap};
9
/// Mutable accumulator for one tool invocation while a session's events are replayed.
///
/// Populated by the `handle_*` event handlers, linked into a tree by `assign_parents`,
/// rolled up by `compute_subtree_costs`, and finally frozen into a `ToolSpanRecord`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(crate) struct SpanBuilder {
    /// Map key of the span: the event's `tool_call_id` when available, otherwise an id
    /// adopted from an already-open span or a synthetic one.
    pub span_id: String,
    /// Session the span belongs to.
    pub session_id: String,
    /// Tool name, if any contributing event reported one.
    pub tool: Option<String>,
    /// Tool-call id, if any contributing event reported one.
    pub tool_call_id: Option<String>,
    /// Timestamp of the pre-tool hook event.
    pub hook_start_ms: Option<u64>,
    /// Timestamp of the post-tool hook event.
    pub hook_end_ms: Option<u64>,
    /// Timestamp of the `ToolCall` event.
    pub call_start_ms: Option<u64>,
    /// Timestamp of the `ToolResult` event.
    pub result_end_ms: Option<u64>,
    /// `ts_exact` flag copied from the `ToolCall` event.
    pub call_start_exact: bool,
    /// `ts_exact` flag copied from the `ToolResult` event.
    pub result_end_exact: bool,
    /// Usage figures; a value reported by a later event replaces an earlier one.
    pub tokens_in: Option<u32>,
    pub tokens_out: Option<u32>,
    pub reasoning_tokens: Option<u32>,
    pub cost_usd_e6: Option<i64>,
    /// Paths extracted from the payload of every contributing event.
    pub paths: BTreeSet<String>,
    /// Set once a `ToolCall` event has been applied to this span.
    pub has_call: bool,
    /// Set once a `ToolResult` or post-tool hook event has been applied.
    pub has_end: bool,
    /// Enclosing span, filled in by `assign_parents`.
    pub parent_span_id: Option<String>,
    /// Nesting depth, 0 for spans without a parent.
    pub depth: u32,
    /// Own cost / tokens plus those of every descendant (see `compute_subtree_costs`).
    pub subtree_cost_usd_e6: Option<i64>,
    pub subtree_token_count: Option<u32>,
}
34
/// Persisted form of one reconstructed tool span (one row of `tool_spans`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ToolSpanRecord {
    /// Unique span key.
    pub span_id: String,
    /// Owning session.
    pub session_id: String,
    /// Tool name, when known.
    pub tool: Option<String>,
    /// Tool-call id, when known.
    pub tool_call_id: Option<String>,
    /// `"done"` when both a start and an end are known, otherwise `"orphaned"`.
    pub status: String,
    /// Span start (hook time preferred over call time).
    pub started_at_ms: Option<u64>,
    /// Span end (hook time preferred over result time).
    pub ended_at_ms: Option<u64>,
    /// Measured duration, when it could be derived reliably.
    pub lead_time_ms: Option<u64>,
    /// Usage figures of this span alone.
    pub tokens_in: Option<u32>,
    pub tokens_out: Option<u32>,
    pub reasoning_tokens: Option<u32>,
    pub cost_usd_e6: Option<i64>,
    /// Paths touched by the span, sorted and de-duplicated.
    pub paths: Vec<String>,
    /// Enclosing span and nesting depth.
    pub parent_span_id: Option<String>,
    pub depth: u32,
    /// Totals for this span plus all of its descendants.
    pub subtree_cost_usd_e6: Option<i64>,
    pub subtree_token_count: Option<u32>,
}
55
56pub fn rebuild_tool_spans_for_session(conn: &Connection, session_id: &str) -> Result<()> {
57    let events = load_session_events(conn, session_id)?;
58    clear_session_spans(conn, session_id)?;
59    for span in final_span_records(&events) {
60        upsert_tool_span_record(conn, &span)?;
61    }
62    Ok(())
63}
64
65pub(crate) fn span_start(s: &SpanBuilder) -> Option<u64> {
66    s.hook_start_ms.or(s.call_start_ms)
67}
68
69pub(crate) fn span_end(s: &SpanBuilder) -> Option<u64> {
70    s.hook_end_ms.or(s.result_end_ms)
71}
72
73pub(crate) fn assign_parents(spans: &mut [SpanBuilder]) {
74    // Sort: start ASC, end DESC so outer spans appear before inner spans.
75    spans.sort_by(|a, b| {
76        let sa = span_start(a).unwrap_or(u64::MAX);
77        let sb = span_start(b).unwrap_or(u64::MAX);
78        sa.cmp(&sb).then_with(|| {
79            let ea = span_end(a).unwrap_or(0);
80            let eb = span_end(b).unwrap_or(0);
81            eb.cmp(&ea)
82        })
83    });
84    for i in 0..spans.len() {
85        let (s_start, s_end) = match (span_start(&spans[i]), span_end(&spans[i])) {
86            (Some(s), Some(e)) => (s, e),
87            _ => continue,
88        };
89        let mut best: Option<(usize, u32)> = None;
90        for (j, candidate) in spans[..i].iter().enumerate() {
91            let (p_start, p_end) = match (span_start(candidate), span_end(candidate)) {
92                (Some(s), Some(e)) => (s, e),
93                _ => continue,
94            };
95            if p_start <= s_start && s_end <= p_end {
96                let d = candidate.depth;
97                if best.is_none_or(|(_, bd)| d > bd) {
98                    best = Some((j, d));
99                }
100            }
101        }
102        if let Some((pi, pd)) = best {
103            let pid = spans[pi].span_id.clone();
104            spans[i].parent_span_id = Some(pid);
105            spans[i].depth = pd + 1;
106        }
107    }
108}
109
110pub(crate) fn compute_subtree_costs(spans: &mut [SpanBuilder]) {
111    // Seed each span's subtree with its own cost/tokens.
112    for s in spans.iter_mut() {
113        s.subtree_cost_usd_e6 = s.cost_usd_e6;
114        s.subtree_token_count = s.tokens_in.map(|i| i + s.tokens_out.unwrap_or(0));
115    }
116    // Build an index: span_id → index in vec.
117    let ids: Vec<String> = spans.iter().map(|s| s.span_id.clone()).collect();
118    // Bottom-up: iterate in reverse depth order (deepest first).
119    let order: Vec<usize> = {
120        let mut v: Vec<usize> = (0..spans.len()).collect();
121        v.sort_by_key(|&i| u32::MAX - spans[i].depth);
122        v
123    };
124    for i in order {
125        let (cost, tokens, pid) = (
126            spans[i].subtree_cost_usd_e6,
127            spans[i].subtree_token_count,
128            spans[i].parent_span_id.clone(),
129        );
130        let Some(parent_id) = pid else { continue };
131        let Some(pi) = ids.iter().position(|id| id == &parent_id) else {
132            continue;
133        };
134        if let Some(c) = cost {
135            spans[pi].subtree_cost_usd_e6 = Some(spans[pi].subtree_cost_usd_e6.unwrap_or(0) + c);
136        }
137        if let Some(t) = tokens {
138            spans[pi].subtree_token_count = Some(spans[pi].subtree_token_count.unwrap_or(0) + t);
139        }
140    }
141}
142
143pub(crate) fn clear_session_spans(conn: &Connection, session_id: &str) -> Result<()> {
144    conn.execute(
145        "DELETE FROM tool_span_paths
146         WHERE span_id IN (SELECT span_id FROM tool_spans WHERE session_id = ?1)",
147        params![session_id],
148    )?;
149    conn.execute(
150        "DELETE FROM tool_spans WHERE session_id = ?1",
151        params![session_id],
152    )?;
153    Ok(())
154}
155
156pub(crate) fn final_span_records(events: &[Event]) -> Vec<ToolSpanRecord> {
157    let mut spans = build_spans(events);
158    assign_parents(&mut spans);
159    compute_subtree_costs(&mut spans);
160    spans.iter().map(ToolSpanRecord::from_builder).collect()
161}
162
163pub(crate) fn build_spans(events: &[Event]) -> Vec<SpanBuilder> {
164    let mut spans: HashMap<String, SpanBuilder> = HashMap::new();
165    let mut open_order: Vec<String> = Vec::new();
166    for event in events {
167        if !matches!(
168            event.kind,
169            EventKind::ToolCall | EventKind::ToolResult | EventKind::Hook
170        ) {
171            continue;
172        }
173        match event.kind {
174            EventKind::ToolCall => handle_tool_call(event, &mut spans, &mut open_order),
175            EventKind::ToolResult => handle_tool_result(event, &mut spans, &open_order),
176            EventKind::Hook => handle_hook(event, &mut spans, &mut open_order),
177            _ => {}
178        }
179    }
180    spans.into_values().collect()
181}
182
183pub(crate) fn handle_tool_call(
184    event: &Event,
185    spans: &mut HashMap<String, SpanBuilder>,
186    open_order: &mut Vec<String>,
187) {
188    let tool = event.tool.clone();
189    let existing = tool
190        .as_deref()
191        .and_then(|name| find_open_without_call(spans, open_order, name));
192    let span_id = event
193        .tool_call_id
194        .clone()
195        .unwrap_or_else(|| existing.unwrap_or_else(|| synthetic_span_id(event)));
196    let span = spans.entry(span_id.clone()).or_insert_with(|| SpanBuilder {
197        span_id: span_id.clone(),
198        session_id: event.session_id.clone(),
199        tool: tool.clone(),
200        tool_call_id: event.tool_call_id.clone(),
201        ..Default::default()
202    });
203    span.tool = tool;
204    span.tool_call_id = event.tool_call_id.clone();
205    span.call_start_ms = Some(event.ts_ms);
206    span.call_start_exact = event.ts_exact;
207    span.tokens_in = pick_u32(span.tokens_in, event.tokens_in);
208    span.tokens_out = pick_u32(span.tokens_out, event.tokens_out);
209    span.reasoning_tokens = pick_u32(span.reasoning_tokens, event.reasoning_tokens);
210    span.cost_usd_e6 = pick_i64(span.cost_usd_e6, event.cost_usd_e6);
211    span.paths.extend(paths_from_event_payload(&event.payload));
212    span.has_call = true;
213    if !open_order.iter().any(|id| id == &span_id) {
214        open_order.push(span_id);
215    }
216}
217
218pub(crate) fn handle_tool_result(
219    event: &Event,
220    spans: &mut HashMap<String, SpanBuilder>,
221    open_order: &[String],
222) {
223    let Some(span_id) = match_span_id(event, spans, open_order) else {
224        return;
225    };
226    let Some(span) = spans.get_mut(&span_id) else {
227        return;
228    };
229    span.result_end_ms = Some(event.ts_ms);
230    span.result_end_exact = event.ts_exact;
231    span.tokens_in = pick_u32(span.tokens_in, event.tokens_in);
232    span.tokens_out = pick_u32(span.tokens_out, event.tokens_out);
233    span.reasoning_tokens = pick_u32(span.reasoning_tokens, event.reasoning_tokens);
234    span.cost_usd_e6 = pick_i64(span.cost_usd_e6, event.cost_usd_e6);
235    span.paths.extend(paths_from_event_payload(&event.payload));
236    span.has_end = true;
237}
238
239pub(crate) fn handle_hook(
240    event: &Event,
241    spans: &mut HashMap<String, SpanBuilder>,
242    open_order: &mut Vec<String>,
243) {
244    let Some(kind) = hook_kind(&event.payload) else {
245        return;
246    };
247    let tool = hook_tool(&event.payload);
248    let span_id = event
249        .tool_call_id
250        .clone()
251        .or_else(|| {
252            tool.as_deref()
253                .and_then(|name| find_open_same_tool(spans, open_order, name))
254        })
255        .unwrap_or_else(|| synthetic_span_id(event));
256    let span = spans.entry(span_id.clone()).or_insert_with(|| SpanBuilder {
257        span_id: span_id.clone(),
258        session_id: event.session_id.clone(),
259        tool: tool.clone(),
260        tool_call_id: event.tool_call_id.clone(),
261        ..Default::default()
262    });
263    span.tool = span.tool.clone().or(tool);
264    span.tool_call_id = span.tool_call_id.clone().or(event.tool_call_id.clone());
265    span.paths.extend(paths_from_event_payload(&event.payload));
266    match kind {
267        "pre" => span.hook_start_ms = Some(event.ts_ms),
268        "post" => {
269            span.hook_end_ms = Some(event.ts_ms);
270            span.has_end = true;
271        }
272        _ => {}
273    }
274    if !open_order.iter().any(|id| id == &span_id) {
275        open_order.push(span_id);
276    }
277}
278
279fn load_session_events(conn: &Connection, session_id: &str) -> Result<Vec<Event>> {
280    let mut stmt = conn.prepare(
281        "SELECT session_id, seq, ts_ms, COALESCE(ts_exact, 0), kind, source, tool,
282                tool_call_id, tokens_in, tokens_out, reasoning_tokens, cost_usd_e6, payload
283         FROM events WHERE session_id = ?1 ORDER BY ts_ms ASC, seq ASC",
284    )?;
285    let rows = stmt.query_map(params![session_id], |row| {
286        let kind = match row.get::<_, String>(4)?.as_str() {
287            "ToolCall" => EventKind::ToolCall,
288            "ToolResult" => EventKind::ToolResult,
289            "Message" => EventKind::Message,
290            "Error" => EventKind::Error,
291            "Cost" => EventKind::Cost,
292            _ => EventKind::Hook,
293        };
294        let source = match row.get::<_, String>(5)?.as_str() {
295            "Tail" => EventSource::Tail,
296            "Proxy" => EventSource::Proxy,
297            _ => EventSource::Hook,
298        };
299        let payload: String = row.get(12)?;
300        Ok(Event {
301            session_id: row.get(0)?,
302            seq: row.get::<_, i64>(1)? as u64,
303            ts_ms: row.get::<_, i64>(2)? as u64,
304            ts_exact: row.get::<_, i64>(3)? != 0,
305            kind,
306            source,
307            tool: row.get(6)?,
308            tool_call_id: row.get(7)?,
309            tokens_in: row.get::<_, Option<i64>>(8)?.map(|v| v as u32),
310            tokens_out: row.get::<_, Option<i64>>(9)?.map(|v| v as u32),
311            reasoning_tokens: row.get::<_, Option<i64>>(10)?.map(|v| v as u32),
312            cost_usd_e6: row.get(11)?,
313            stop_reason: None,
314            latency_ms: None,
315            ttft_ms: None,
316            retry_count: None,
317            context_used_tokens: None,
318            context_max_tokens: None,
319            cache_creation_tokens: None,
320            cache_read_tokens: None,
321            system_prompt_tokens: None,
322            payload: serde_json::from_str(&payload).unwrap_or(serde_json::Value::Null),
323        })
324    })?;
325    Ok(rows.filter_map(|row| row.ok()).collect())
326}
327
328pub(crate) fn match_span_id(
329    event: &Event,
330    spans: &HashMap<String, SpanBuilder>,
331    open_order: &[String],
332) -> Option<String> {
333    if let Some(id) = event
334        .tool_call_id
335        .as_ref()
336        .filter(|id| spans.contains_key(*id))
337    {
338        return Some(id.clone());
339    }
340    event
341        .tool
342        .as_deref()
343        .and_then(|name| find_open_same_tool(spans, open_order, name))
344        .or_else(|| open_order.last().cloned())
345}
346
347pub(crate) fn find_open_without_call(
348    spans: &HashMap<String, SpanBuilder>,
349    open_order: &[String],
350    tool: &str,
351) -> Option<String> {
352    open_order.iter().rev().find_map(|id| {
353        spans.get(id).and_then(|span| {
354            if span.tool.as_deref() == Some(tool) && !span.has_call {
355                Some(id.clone())
356            } else {
357                None
358            }
359        })
360    })
361}
362
363pub(crate) fn find_open_same_tool(
364    spans: &HashMap<String, SpanBuilder>,
365    open_order: &[String],
366    tool: &str,
367) -> Option<String> {
368    open_order.iter().rev().find_map(|id| {
369        spans.get(id).and_then(|span| {
370            if span.tool.as_deref() == Some(tool) && !span.has_end {
371                Some(id.clone())
372            } else {
373                None
374            }
375        })
376    })
377}
378
379pub(crate) fn synthetic_span_id(event: &Event) -> String {
380    format!("{}:{}:{}", event.session_id, event.seq, event.ts_ms)
381}
382
383pub(crate) fn hook_kind(payload: &serde_json::Value) -> Option<&'static str> {
384    let raw = payload
385        .get("event")
386        .and_then(|v| v.as_str())
387        .or_else(|| payload.get("hook_event_name").and_then(|v| v.as_str()))?;
388    match raw {
389        "PreToolUse" | "pre_tool_use" => Some("pre"),
390        "PostToolUse" | "post_tool_use" => Some("post"),
391        _ => None,
392    }
393}
394
395pub(crate) fn hook_tool(payload: &serde_json::Value) -> Option<String> {
396    ["tool_name", "tool", "name"]
397        .iter()
398        .find_map(|k| payload.get(k).and_then(|v| v.as_str()))
399        .map(ToOwned::to_owned)
400}
401
/// Merge a counter: a value reported by the newer event wins; otherwise keep `current`.
pub(crate) fn pick_u32(current: Option<u32>, next: Option<u32>) -> Option<u32> {
    match next {
        Some(value) => Some(value),
        None => current,
    }
}
405
/// Merge a signed figure: a value reported by the newer event wins; otherwise keep
/// `current`.
pub(crate) fn pick_i64(current: Option<i64>, next: Option<i64>) -> Option<i64> {
    match next {
        Some(value) => Some(value),
        None => current,
    }
}
409
410impl ToolSpanRecord {
411    pub(crate) fn from_builder(span: &SpanBuilder) -> Self {
412        let lead = span
413            .hook_start_ms
414            .zip(span.hook_end_ms)
415            .map(|(a, b)| b.saturating_sub(a))
416            .or_else(|| {
417                if span.call_start_exact && span.result_end_exact {
418                    span.call_start_ms
419                        .zip(span.result_end_ms)
420                        .map(|(a, b)| b.saturating_sub(a))
421                } else {
422                    None
423                }
424            });
425        let started = span_start(span);
426        let ended = span_end(span);
427        let status = if started.is_some() && ended.is_some() {
428            "done"
429        } else {
430            "orphaned"
431        };
432        Self {
433            span_id: span.span_id.clone(),
434            session_id: span.session_id.clone(),
435            tool: span.tool.clone(),
436            tool_call_id: span.tool_call_id.clone(),
437            status: status.to_string(),
438            started_at_ms: started,
439            ended_at_ms: ended,
440            lead_time_ms: lead,
441            tokens_in: span.tokens_in,
442            tokens_out: span.tokens_out,
443            reasoning_tokens: span.reasoning_tokens,
444            cost_usd_e6: span.cost_usd_e6,
445            paths: span.paths.iter().cloned().collect(),
446            parent_span_id: span.parent_span_id.clone(),
447            depth: span.depth,
448            subtree_cost_usd_e6: span.subtree_cost_usd_e6,
449            subtree_token_count: span.subtree_token_count,
450        }
451    }
452}
453
454pub(crate) fn upsert_tool_span_record(conn: &Connection, span: &ToolSpanRecord) -> Result<()> {
455    conn.execute(
456        "INSERT INTO tool_spans (
457            span_id, session_id, tool, tool_call_id, status,
458            started_at_ms, ended_at_ms, lead_time_ms,
459            tokens_in, tokens_out, reasoning_tokens, cost_usd_e6, paths_json,
460            parent_span_id, depth, subtree_cost_usd_e6, subtree_token_count
461         ) VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13,?14,?15,?16,?17)
462         ON CONFLICT(span_id) DO UPDATE SET
463            session_id=excluded.session_id,
464            tool=excluded.tool,
465            tool_call_id=excluded.tool_call_id,
466            status=excluded.status,
467            started_at_ms=excluded.started_at_ms,
468            ended_at_ms=excluded.ended_at_ms,
469            lead_time_ms=excluded.lead_time_ms,
470            tokens_in=excluded.tokens_in,
471            tokens_out=excluded.tokens_out,
472            reasoning_tokens=excluded.reasoning_tokens,
473            cost_usd_e6=excluded.cost_usd_e6,
474            paths_json=excluded.paths_json,
475            parent_span_id=excluded.parent_span_id,
476            depth=excluded.depth,
477            subtree_cost_usd_e6=excluded.subtree_cost_usd_e6,
478            subtree_token_count=excluded.subtree_token_count",
479        params![
480            &span.span_id,
481            &span.session_id,
482            span.tool.as_deref(),
483            span.tool_call_id.as_deref(),
484            &span.status,
485            span.started_at_ms.map(|v| v as i64),
486            span.ended_at_ms.map(|v| v as i64),
487            span.lead_time_ms.map(|v| v as i64),
488            span.tokens_in.map(|v| v as i64),
489            span.tokens_out.map(|v| v as i64),
490            span.reasoning_tokens.map(|v| v as i64),
491            span.cost_usd_e6,
492            serde_json::to_string(&span.paths)?,
493            span.parent_span_id.as_deref(),
494            span.depth as i64,
495            span.subtree_cost_usd_e6,
496            span.subtree_token_count.map(|v| v as i64),
497        ],
498    )?;
499    conn.execute(
500        "DELETE FROM tool_span_paths WHERE span_id = ?1",
501        params![&span.span_id],
502    )?;
503    for path in &span.paths {
504        conn.execute(
505            "INSERT INTO tool_span_paths (span_id, path) VALUES (?1, ?2)",
506            params![&span.span_id, path],
507        )?;
508    }
509    Ok(())
510}