Skip to main content

kaizen/retro/
engine.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
//! Merge bets from all heuristics, drop those already seen in prior reports, and rank the top bets.
3
4use crate::retro::heuristics;
5use crate::retro::types::{Inputs, Report, RetroMeta, RetroStats};
6use std::collections::{HashMap, HashSet};
7
8const TOP_N: usize = 5;
9
10/// Pure ranking step after `Inputs` are assembled.
11pub fn run(inputs: &Inputs, prior_bet_ids: &HashSet<String>) -> Report {
12    let mut candidates = heuristics::all_bets(inputs);
13    candidates.sort_by(|a, b| {
14        b.score()
15            .partial_cmp(&a.score())
16            .unwrap_or(std::cmp::Ordering::Equal)
17            .then_with(|| b.evidence_recency_ms.cmp(&a.evidence_recency_ms))
18            .then_with(|| a.id.cmp(&b.id))
19    });
20
21    let mut skipped = Vec::new();
22    let mut seen: HashSet<String> = HashSet::new();
23    let mut top = Vec::new();
24    for bet in candidates {
25        if prior_bet_ids.contains(&bet.id) {
26            skipped.push(format!("{} ({})", bet.title, bet.id));
27            continue;
28        }
29        if !seen.insert(bet.id.clone()) {
30            continue;
31        }
32        if top.len() < TOP_N {
33            top.push(bet);
34        }
35    }
36
37    let session_count = inputs.aggregates.unique_session_ids.len() as u64;
38    let (top_model, top_model_pct) = top_model_share(&inputs.aggregates.model_session_counts);
39    let (top_tool, top_tool_pct) = top_tool_share(&inputs.aggregates.tool_event_counts);
40    let median_min = median_session_minutes(inputs);
41
42    Report {
43        meta: RetroMeta {
44            week_label: String::new(),
45            span_start_ms: inputs.window_start_ms,
46            span_end_ms: inputs.window_end_ms,
47            session_count,
48            total_cost_usd_e6: inputs.aggregates.total_cost_usd_e6,
49        },
50        top_bets: top,
51        skipped_deduped: skipped,
52        stats: RetroStats {
53            sessions: session_count,
54            total_cost_usd_e6: inputs.aggregates.total_cost_usd_e6,
55            top_model,
56            top_model_pct,
57            top_tool,
58            top_tool_pct,
59            median_session_minutes: median_min,
60        },
61    }
62}
63
/// Model with the highest count in `m` and its share of the total, in whole percent.
///
/// `run` passes the aggregated per-model session counts. Delegates to
/// `top_share`, which defines the tie-breaking and empty-input rules.
fn top_model_share(m: &HashMap<String, u64>) -> (Option<String>, Option<u64>) {
    top_share(m)
}

/// Tool with the highest count in `m` and its share of the total, in whole percent.
///
/// `run` passes the aggregated per-tool event counts. Delegates to
/// `top_share`, which defines the tie-breaking and empty-input rules.
fn top_tool_share(m: &HashMap<String, u64>) -> (Option<String>, Option<u64>) {
    top_share(m)
}

/// Returns the key with the largest count and its floored share of the total.
///
/// * Returns `(None, None)` when `counts` is empty or every count is zero.
/// * The percentage is `count * 100 / total` with integer (floor) division, so
///   it lies in `0..=100`.
/// * Ties on count go to the lexicographically smallest key. Without this, the
///   result would depend on `HashMap` iteration order, which is randomized per
///   process.
fn top_share(counts: &HashMap<String, u64>) -> (Option<String>, Option<u64>) {
    // Work in u128 so neither the total nor `count * 100` can overflow.
    let total: u128 = counts.values().map(|&c| u128::from(c)).sum();
    if total == 0 {
        return (None, None);
    }
    let (key, count) = match counts
        .iter()
        .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
    {
        Some(entry) => entry,
        None => return (None, None),
    };
    // count <= total, so the quotient is at most 100 and always fits in u64.
    let pct = (u128::from(*count) * 100 / total) as u64;
    (Some(key.clone()), Some(pct))
}
83
84fn median_session_minutes(inputs: &Inputs) -> Option<u64> {
85    let mut by_id: HashMap<String, (u64, Option<u64>)> = HashMap::new();
86    for (s, _) in &inputs.events {
87        by_id
88            .entry(s.id.clone())
89            .or_insert((s.started_at_ms, s.ended_at_ms));
90    }
91    let mut durations: Vec<u64> = by_id
92        .into_values()
93        .map(|(start, end)| {
94            let e = end.unwrap_or(inputs.window_end_ms);
95            e.saturating_sub(start) / 60_000
96        })
97        .collect();
98    if durations.is_empty() {
99        return None;
100    }
101    durations.sort_unstable();
102    Some(durations[durations.len() / 2])
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::{Event, EventKind, EventSource, SessionRecord, SessionStatus};
    use crate::retro::types::{RetroAggregates, SkillFileOnDisk};
    use serde_json::json;
    use std::collections::HashSet;

    /// Smallest `Inputs` used by the ranking test. It contains:
    /// * one finished session "s1" running from 0 to 120 000 ms;
    /// * a single `read_file` tool-call event at 100 ms;
    /// * tool counts skewed towards `read_file`;
    /// * one skill file "z" on disk.
    fn minimal_inputs() -> Inputs {
        let session = SessionRecord {
            id: "s1".into(),
            agent: "cursor".into(),
            model: Some("m".into()),
            workspace: "/w".into(),
            started_at_ms: 0,
            ended_at_ms: Some(120_000),
            status: SessionStatus::Done,
            trace_path: "".into(),
            start_commit: None,
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
        };
        let tool_call = Event {
            session_id: "s1".into(),
            seq: 0,
            ts_ms: 100,
            ts_exact: false,
            kind: EventKind::ToolCall,
            source: EventSource::Tail,
            tool: Some("read_file".into()),
            tool_call_id: None,
            tokens_in: None,
            tokens_out: None,
            reasoning_tokens: None,
            cost_usd_e6: None,
            payload: json!({}),
        };
        let skill_z = SkillFileOnDisk {
            slug: "z".into(),
            size_bytes: 100,
            mtime_ms: 0,
        };

        let mut aggregates = RetroAggregates::default();
        aggregates.unique_session_ids.insert("s1".into());
        aggregates.tool_event_counts.insert("read_file".into(), 20);
        aggregates.tool_event_counts.insert("x".into(), 2);

        Inputs {
            window_start_ms: 0,
            window_end_ms: 1000,
            events: vec![(session, tool_call)],
            files_touched: Vec::new(),
            skills_used: Vec::new(),
            tool_spans: Vec::new(),
            skills_used_recent_slugs: HashSet::new(),
            usage_lookback_ms: 0,
            skill_files_on_disk: vec![skill_z],
            rule_files_on_disk: Vec::new(),
            rules_used_recent_slugs: HashSet::new(),
            file_facts: HashMap::new(),
            eval_scores: Vec::new(),
            aggregates,
        }
    }

    #[test]
    fn dedupes_prior_ids() {
        let inputs = minimal_inputs();
        let prior: HashSet<String> = std::iter::once(String::from("H4:read_file")).collect();
        let r = run(&inputs, &prior);
        assert!(r.top_bets.iter().all(|b| b.id != "H4:read_file"));
        assert!(!r.skipped_deduped.is_empty() || r.top_bets.len() <= 4);
    }
}