Skip to main content

kaizen/retro/
engine.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Merge heuristics, dedupe vs prior reports, rank top bets.
3
4use crate::retro::heuristics;
5use crate::retro::types::{Bet, BetCategory, Confidence, Inputs, Report, RetroMeta, RetroStats};
6use std::collections::{HashMap, HashSet};
7
/// Maximum number of bets taken in the first selection pass of
/// `select_grouped_bets` (bets whose confidence is `High`).
const TOP_BET_N: usize = 1;
/// Maximum number of bets taken in the second selection pass
/// (`Investigation` category with `High` or `Medium` confidence).
const INVESTIGATE_N: usize = 2;
/// Maximum number of bets taken in the third selection pass
/// (`QuickWin` or `Hygiene` category, any confidence).
const HYGIENE_N: usize = 2;
11
12/// Pure ranking step after `Inputs` are assembled.
13pub fn run(inputs: &Inputs, prior_bet_ids: &HashSet<String>) -> Report {
14    let mut candidates = heuristics::all_bets(inputs);
15    candidates.iter_mut().for_each(enrich_bet);
16    candidates.sort_by(|a, b| {
17        b.score()
18            .partial_cmp(&a.score())
19            .unwrap_or(std::cmp::Ordering::Equal)
20            .then_with(|| b.evidence_recency_ms.cmp(&a.evidence_recency_ms))
21            .then_with(|| a.id.cmp(&b.id))
22    });
23
24    let (available, skipped) = available_candidates(candidates, prior_bet_ids);
25    let top = select_grouped_bets(&available);
26
27    let session_count = inputs.aggregates.unique_session_ids.len() as u64;
28    let (top_model, top_model_pct) = top_model_share(&inputs.aggregates.model_session_counts);
29    let (top_tool, top_tool_pct) = top_tool_share(&inputs.aggregates.tool_event_counts);
30    let median_min = median_session_minutes(inputs);
31
32    Report {
33        meta: RetroMeta {
34            week_label: String::new(),
35            span_start_ms: inputs.window_start_ms,
36            span_end_ms: inputs.window_end_ms,
37            session_count,
38            total_cost_usd_e6: inputs.aggregates.total_cost_usd_e6,
39        },
40        top_bets: top,
41        skipped_deduped: skipped,
42        stats: RetroStats {
43            sessions: session_count,
44            total_cost_usd_e6: inputs.aggregates.total_cost_usd_e6,
45            top_model,
46            top_model_pct,
47            top_tool,
48            top_tool_pct,
49            median_session_minutes: median_min,
50        },
51    }
52}
53
54fn enrich_bet(bet: &mut Bet) {
55    let (confidence, category) = heuristic_metadata(&bet.heuristic_id);
56    bet.confidence = Some(confidence);
57    bet.category = Some(category);
58}
59
60fn heuristic_metadata(heuristic_id: &str) -> (Confidence, BetCategory) {
61    match heuristic_id {
62        "H1" | "H29" => (Confidence::High, BetCategory::QuickWin),
63        "H9" | "H10" | "H12" | "H19" | "H21" | "H27" | "H33" => {
64            (Confidence::High, BetCategory::Investigation)
65        }
66        "H2" | "H3" | "H11" | "H14" | "H22" | "H24" => {
67            (Confidence::Medium, BetCategory::Investigation)
68        }
69        "H7" | "H20" | "H28" | "H30" | "H31" | "H32" => (Confidence::Medium, BetCategory::Hygiene),
70        _ => (Confidence::Low, BetCategory::Hygiene),
71    }
72}
73
74fn available_candidates(
75    candidates: Vec<Bet>,
76    prior_bet_ids: &HashSet<String>,
77) -> (Vec<Bet>, Vec<String>) {
78    let mut skipped = Vec::new();
79    let mut seen: HashSet<String> = HashSet::new();
80    let mut available = Vec::new();
81    for bet in candidates {
82        if prior_bet_ids.contains(&bet.id) {
83            skipped.push(format!("{} ({})", bet.title, bet.id));
84        } else if seen.insert(bet.id.clone()) {
85            available.push(bet);
86        }
87    }
88    (available, skipped)
89}
90
91fn select_grouped_bets(candidates: &[Bet]) -> Vec<Bet> {
92    let mut selected_ids: HashSet<String> = HashSet::new();
93    let mut top = Vec::new();
94    push_matches(&mut top, &mut selected_ids, candidates, TOP_BET_N, |b| {
95        b.confidence == Some(Confidence::High)
96    });
97    push_matches(
98        &mut top,
99        &mut selected_ids,
100        candidates,
101        INVESTIGATE_N,
102        |b| {
103            b.category == Some(BetCategory::Investigation)
104                && matches!(b.confidence, Some(Confidence::High | Confidence::Medium))
105        },
106    );
107    push_matches(&mut top, &mut selected_ids, candidates, HYGIENE_N, |b| {
108        matches!(
109            b.category,
110            Some(BetCategory::QuickWin | BetCategory::Hygiene)
111        )
112    });
113    top
114}
115
116fn push_matches<F>(
117    out: &mut Vec<Bet>,
118    selected_ids: &mut HashSet<String>,
119    candidates: &[Bet],
120    limit: usize,
121    mut pred: F,
122) where
123    F: FnMut(&Bet) -> bool,
124{
125    let mut added = 0;
126    for bet in candidates {
127        if added == limit {
128            break;
129        }
130        if pred(bet) && selected_ids.insert(bet.id.clone()) {
131            out.push(bet.clone());
132            added += 1;
133        }
134    }
135}
136
/// Shared implementation for [`top_model_share`] and [`top_tool_share`].
///
/// Returns the key with the largest count together with its share of the
/// total as a whole-number percentage (rounded down). Returns `(None, None)`
/// when the map is empty or every count is zero.
///
/// Ties on the count are resolved in favour of the lexicographically smallest
/// key, so the result does not depend on `HashMap` iteration order.
fn top_share(counts: &HashMap<String, u64>) -> (Option<String>, Option<u64>) {
    // Sum in u128 so neither the total nor `count * 100` can overflow.
    let total: u128 = counts.values().map(|&count| u128::from(count)).sum();
    if total == 0 {
        return (None, None);
    }

    let leader = counts.iter().max_by(|(name_a, count_a), (name_b, count_b)| {
        count_a
            .cmp(count_b)
            // Reversed key comparison: on equal counts the smaller key wins.
            .then_with(|| name_b.cmp(name_a))
    });

    match leader {
        Some((name, &count)) => {
            // `count <= total`, so the quotient is at most 100 and fits in u64.
            let pct = (u128::from(count) * 100 / total) as u64;
            (Some(name.clone()), Some(pct))
        }
        None => (None, None),
    }
}

/// Returns the model with the most sessions and its percentage of all
/// counted sessions, or `(None, None)` when there is nothing to count.
///
/// Ties go to the lexicographically smallest model name.
fn top_model_share(m: &HashMap<String, u64>) -> (Option<String>, Option<u64>) {
    top_share(m)
}

/// Returns the tool with the most events and its percentage of all counted
/// tool events, or `(None, None)` when there is nothing to count.
///
/// Ties go to the lexicographically smallest tool name.
fn top_tool_share(m: &HashMap<String, u64>) -> (Option<String>, Option<u64>) {
    top_share(m)
}
156
157fn median_session_minutes(inputs: &Inputs) -> Option<u64> {
158    let mut by_id: HashMap<String, (u64, Option<u64>)> = HashMap::new();
159    for (s, _) in &inputs.events {
160        by_id
161            .entry(s.id.clone())
162            .or_insert((s.started_at_ms, s.ended_at_ms));
163    }
164    let mut durations: Vec<u64> = by_id
165        .into_values()
166        .map(|(start, end)| {
167            let e = end.unwrap_or(inputs.window_end_ms);
168            e.saturating_sub(start) / 60_000
169        })
170        .collect();
171    if durations.is_empty() {
172        return None;
173    }
174    durations.sort_unstable();
175    Some(durations[durations.len() / 2])
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::{Event, EventKind, EventSource, SessionRecord, SessionStatus};
    use crate::retro::types::{Bet, RetroAggregates, SkillFileOnDisk};
    use serde_json::json;
    use std::collections::HashSet;

    /// Builds a bare bet with the given id, heuristic and weekly token
    /// savings; confidence and category are left unset for `enrich_bet`.
    fn bet(id: &str, heuristic_id: &str, tokens: f64) -> Bet {
        Bet {
            id: id.into(),
            heuristic_id: heuristic_id.into(),
            title: id.into(),
            hypothesis: String::new(),
            expected_tokens_saved_per_week: tokens,
            effort_minutes: 10,
            evidence: vec![],
            apply_step: String::new(),
            evidence_recency_ms: 0,
            confidence: None,
            category: None,
        }
    }

    /// The single finished session ("s1") used by the fixture.
    fn sample_session() -> SessionRecord {
        SessionRecord {
            id: "s1".into(),
            agent: "cursor".into(),
            model: Some("m".into()),
            workspace: "/w".into(),
            started_at_ms: 0,
            ended_at_ms: Some(120_000),
            status: SessionStatus::Done,
            trace_path: "".into(),
            start_commit: None,
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
            prompt_fingerprint: None,
            parent_session_id: None,
            agent_version: None,
            os: None,
            arch: None,
            repo_file_count: None,
            repo_total_loc: None,
        }
    }

    /// One `read_file` tool call belonging to session "s1".
    fn sample_event() -> Event {
        Event {
            session_id: "s1".into(),
            seq: 0,
            ts_ms: 100,
            ts_exact: false,
            kind: EventKind::ToolCall,
            source: EventSource::Tail,
            tool: Some("read_file".into()),
            tool_call_id: None,
            tokens_in: None,
            tokens_out: None,
            reasoning_tokens: None,
            cost_usd_e6: None,
            stop_reason: None,
            latency_ms: None,
            ttft_ms: None,
            retry_count: None,
            context_used_tokens: None,
            context_max_tokens: None,
            cache_creation_tokens: None,
            cache_read_tokens: None,
            system_prompt_tokens: None,
            payload: json!({}),
        }
    }

    /// Smallest `Inputs` the engine can run on: one session, one event, two
    /// tool counters and one skill file on disk.
    fn minimal_inputs() -> Inputs {
        let mut aggregates = RetroAggregates::default();
        aggregates.unique_session_ids.insert("s1".into());
        aggregates.tool_event_counts.insert("read_file".into(), 20);
        aggregates.tool_event_counts.insert("x".into(), 2);

        let unused_skill = SkillFileOnDisk {
            slug: "z".into(),
            size_bytes: 100,
            mtime_ms: 0,
        };

        Inputs {
            window_start_ms: 0,
            window_end_ms: 1000,
            events: vec![(sample_session(), sample_event())],
            files_touched: vec![],
            skills_used: vec![],
            tool_spans: vec![],
            skills_used_recent_slugs: HashSet::new(),
            usage_lookback_ms: 0,
            skill_files_on_disk: vec![unused_skill],
            rule_files_on_disk: vec![],
            rules_used_recent_slugs: HashSet::new(),
            file_facts: HashMap::new(),
            eval_scores: vec![],
            aggregates,
            prompt_fingerprints: vec![],
            feedback: vec![],
            session_outcomes: vec![],
            session_sample_aggs: vec![],
        }
    }

    /// A bet id present in the prior set must never be proposed again.
    #[test]
    fn dedupes_prior_ids() {
        let prior: HashSet<String> = std::iter::once(String::from("H4:read_file")).collect();
        let report = run(&minimal_inputs(), &prior);

        let resurfaced = report.top_bets.iter().any(|b| b.id == "H4:read_file");
        assert!(!resurfaced);
        assert!(!report.skipped_deduped.is_empty() || report.top_bets.len() <= 4);
    }

    /// Every bet that reaches the report carries confidence and category.
    #[test]
    fn metadata_is_added_to_bets() {
        let no_prior = HashSet::new();
        let report = run(&minimal_inputs(), &no_prior);

        for b in &report.top_bets {
            assert!(b.confidence.is_some());
        }
        for b in &report.top_bets {
            assert!(b.category.is_some());
        }
    }

    /// Selection yields one top bet, two investigations and two
    /// quick-win/hygiene bets, in that order.
    #[test]
    fn selection_uses_one_two_two_shape() {
        let specs = [
            ("H1:a", "H1", 1000.0),
            ("H9:a", "H9", 900.0),
            ("H2:a", "H2", 800.0),
            ("H7:a", "H7", 700.0),
            ("H4:a", "H4", 600.0),
        ];
        let mut ranked: Vec<Bet> = specs
            .iter()
            .map(|&(id, heuristic_id, tokens)| bet(id, heuristic_id, tokens))
            .collect();
        for b in ranked.iter_mut() {
            enrich_bet(b);
        }
        ranked.sort_by(|x, y| y.score().partial_cmp(&x.score()).unwrap());

        let picked = select_grouped_bets(&ranked);
        let picked_ids: Vec<&str> = picked.iter().map(|b| b.id.as_str()).collect();
        assert_eq!(picked_ids, vec!["H1:a", "H9:a", "H2:a", "H7:a", "H4:a"]);
    }
}
316}