Skip to main content

kaizen/retro/
engine.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Merge heuristics, dedupe vs prior reports, rank top bets.
3
4use crate::retro::heuristics;
5use crate::retro::types::{Inputs, Report, RetroMeta, RetroStats};
6use std::collections::{HashMap, HashSet};
7
/// Maximum number of bets `run` keeps in the report's `top_bets`.
const TOP_N: usize = 5;
9
10/// Pure ranking step after `Inputs` are assembled.
11pub fn run(inputs: &Inputs, prior_bet_ids: &HashSet<String>) -> Report {
12    let mut candidates = heuristics::all_bets(inputs);
13    candidates.sort_by(|a, b| {
14        b.score()
15            .partial_cmp(&a.score())
16            .unwrap_or(std::cmp::Ordering::Equal)
17            .then_with(|| b.evidence_recency_ms.cmp(&a.evidence_recency_ms))
18            .then_with(|| a.id.cmp(&b.id))
19    });
20
21    let mut skipped = Vec::new();
22    let mut seen: HashSet<String> = HashSet::new();
23    let mut top = Vec::new();
24    for bet in candidates {
25        if prior_bet_ids.contains(&bet.id) {
26            skipped.push(format!("{} ({})", bet.title, bet.id));
27            continue;
28        }
29        if !seen.insert(bet.id.clone()) {
30            continue;
31        }
32        if top.len() < TOP_N {
33            top.push(bet);
34        }
35    }
36
37    let session_count = inputs.aggregates.unique_session_ids.len() as u64;
38    let (top_model, top_model_pct) = top_model_share(&inputs.aggregates.model_session_counts);
39    let (top_tool, top_tool_pct) = top_tool_share(&inputs.aggregates.tool_event_counts);
40    let median_min = median_session_minutes(inputs);
41
42    Report {
43        meta: RetroMeta {
44            week_label: String::new(),
45            span_start_ms: inputs.window_start_ms,
46            span_end_ms: inputs.window_end_ms,
47            session_count,
48            total_cost_usd_e6: inputs.aggregates.total_cost_usd_e6,
49        },
50        top_bets: top,
51        skipped_deduped: skipped,
52        stats: RetroStats {
53            sessions: session_count,
54            total_cost_usd_e6: inputs.aggregates.total_cost_usd_e6,
55            top_model,
56            top_model_pct,
57            top_tool,
58            top_tool_pct,
59            median_session_minutes: median_min,
60        },
61    }
62}
63
/// Finds the key with the largest count and its share of the total, in whole
/// percent rounded down.
///
/// Returns `(None, None)` when `counts` is empty or every count is zero.
/// Ties on the count are resolved in favour of the lexicographically smallest
/// key, so the result does not depend on `HashMap` iteration order.
fn top_share(counts: &HashMap<String, u64>) -> (Option<String>, Option<u64>) {
    // Widen to u128 so neither the sum nor `count * 100` can overflow.
    let total: u128 = counts.values().map(|&c| u128::from(c)).sum();
    if total == 0 {
        return (None, None);
    }
    counts
        .iter()
        // Higher count wins; on equal counts the reversed key comparison makes
        // the smaller key compare as "greater", so it is the one selected.
        .max_by(|(ka, ca), (kb, cb)| ca.cmp(cb).then_with(|| kb.cmp(ka)))
        .map_or((None, None), |(name, &count)| {
            // `count <= total`, so the quotient is at most 100 and the cast
            // back to u64 cannot truncate.
            let pct = (u128::from(count) * 100 / total) as u64;
            (Some(name.clone()), Some(pct))
        })
}

/// Top entry of a per-model count map and its percentage of the total.
///
/// `run` calls this with `aggregates.model_session_counts`. See [`top_share`]
/// for the empty-input and tie-breaking rules.
fn top_model_share(m: &HashMap<String, u64>) -> (Option<String>, Option<u64>) {
    top_share(m)
}

/// Top entry of a per-tool count map and its percentage of the total.
///
/// `run` calls this with `aggregates.tool_event_counts`. See [`top_share`]
/// for the empty-input and tie-breaking rules.
fn top_tool_share(m: &HashMap<String, u64>) -> (Option<String>, Option<u64>) {
    top_share(m)
}
83
84fn median_session_minutes(inputs: &Inputs) -> Option<u64> {
85    let mut by_id: HashMap<String, (u64, Option<u64>)> = HashMap::new();
86    for (s, _) in &inputs.events {
87        by_id
88            .entry(s.id.clone())
89            .or_insert((s.started_at_ms, s.ended_at_ms));
90    }
91    let mut durations: Vec<u64> = by_id
92        .into_values()
93        .map(|(start, end)| {
94            let e = end.unwrap_or(inputs.window_end_ms);
95            e.saturating_sub(start) / 60_000
96        })
97        .collect();
98    if durations.is_empty() {
99        return None;
100    }
101    durations.sort_unstable();
102    Some(durations[durations.len() / 2])
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::{Event, EventKind, EventSource, SessionRecord, SessionStatus};
    use crate::retro::types::{RetroAggregates, SkillFileOnDisk};
    use serde_json::json;
    use std::collections::HashSet;

    /// Builds a small fixture: one finished session `s1` with a single
    /// `read_file` tool-call event, tool counts of `read_file` = 20 and
    /// `x` = 2, and one skill file `z` on disk. Everything else is empty.
    fn minimal_inputs() -> Inputs {
        let mut agg = RetroAggregates::default();
        agg.unique_session_ids.insert("s1".into());
        agg.tool_event_counts.insert("read_file".into(), 20);
        agg.tool_event_counts.insert("x".into(), 2);
        Inputs {
            window_start_ms: 0,
            window_end_ms: 1000,
            events: vec![(
                SessionRecord {
                    id: "s1".into(),
                    agent: "cursor".into(),
                    model: Some("m".into()),
                    workspace: "/w".into(),
                    started_at_ms: 0,
                    // Ends after `window_end_ms`; the session lasts two minutes.
                    ended_at_ms: Some(120_000),
                    status: SessionStatus::Done,
                    trace_path: "".into(),
                    start_commit: None,
                    end_commit: None,
                    branch: None,
                    dirty_start: None,
                    dirty_end: None,
                    repo_binding_source: None,
                    prompt_fingerprint: None,
                    parent_session_id: None,
                    agent_version: None,
                    os: None,
                    arch: None,
                    repo_file_count: None,
                    repo_total_loc: None,
                },
                Event {
                    session_id: "s1".into(),
                    seq: 0,
                    ts_ms: 100,
                    ts_exact: false,
                    kind: EventKind::ToolCall,
                    source: EventSource::Tail,
                    tool: Some("read_file".into()),
                    tool_call_id: None,
                    tokens_in: None,
                    tokens_out: None,
                    reasoning_tokens: None,
                    cost_usd_e6: None,
                    stop_reason: None,
                    latency_ms: None,
                    ttft_ms: None,
                    retry_count: None,
                    context_used_tokens: None,
                    context_max_tokens: None,
                    cache_creation_tokens: None,
                    cache_read_tokens: None,
                    system_prompt_tokens: None,
                    payload: json!({}),
                },
            )],
            files_touched: vec![],
            skills_used: vec![],
            tool_spans: vec![],
            skills_used_recent_slugs: HashSet::new(),
            usage_lookback_ms: 0,
            skill_files_on_disk: vec![SkillFileOnDisk {
                slug: "z".into(),
                size_bytes: 100,
                mtime_ms: 0,
            }],
            rule_files_on_disk: vec![],
            rules_used_recent_slugs: HashSet::new(),
            file_facts: HashMap::new(),
            eval_scores: vec![],
            aggregates: agg,
            prompt_fingerprints: vec![],
            feedback: vec![],
            session_outcomes: vec![],
            session_sample_aggs: vec![],
        }
    }

    /// A bet id listed in the prior set must never appear in `top_bets`.
    #[test]
    fn dedupes_prior_ids() {
        let inputs = minimal_inputs();
        let mut prior = HashSet::new();
        // Presumably the id a heuristic emits for the heavily used `read_file`
        // tool in this fixture — verify against `heuristics`.
        prior.insert("H4:read_file".into());
        let r = run(&inputs, &prior);
        assert!(r.top_bets.iter().all(|b| b.id != "H4:read_file"));
        // NOTE(review): weak check — it also passes whenever at most four bets
        // are returned, whether or not anything was actually skipped.
        assert!(!r.skipped_deduped.is_empty() || r.top_bets.len() <= 4);
    }
}
201}