Skip to main content

kaizen/experiment/
engine.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Turn an `Experiment` + sessions into a report. Pure compute given inputs.
3
4use crate::core::event::{Event, SessionRecord};
5use crate::experiment::binding::{ManualTags, partition};
6use crate::experiment::metric::value_for;
7use crate::experiment::stats::sequential::{Decision, decide as seq_decide};
8use crate::experiment::stats::{DEFAULT_RESAMPLES, Summary, summarize};
9use crate::experiment::types::{Classification, Criterion, Direction, Experiment, GuardrailResult};
10use serde::{Deserialize, Serialize};
11use std::path::Path;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Report {
15    pub experiment: Experiment,
16    pub summary: Summary,
17    pub excluded_count: usize,
18    pub target_met: Option<bool>,
19    pub guardrail_results: Vec<GuardrailResult>,
20    /// Sequential decision: sticky once Significant.
21    pub sequential_decision: Decision,
22    /// Pass this back on the next `run` call to preserve stickiness.
23    pub ever_significant: bool,
24}
25
26/// Pure ranking step once `sessions` + per-session `events` gathered.
27///
28/// Pass the previous `ever_significant` from the last report to preserve
29/// sequential stickiness across incremental calls.
30pub fn run(
31    exp: &Experiment,
32    sessions: &[(SessionRecord, Vec<Event>)],
33    manual_tags: &ManualTags,
34    workspace: &Path,
35    ever_significant: bool,
36) -> Report {
37    let records: Vec<SessionRecord> = sessions.iter().map(|(s, _)| s.clone()).collect();
38    let (control_s, treatment_s, excluded_s) =
39        partition(&records, &exp.binding, manual_tags, workspace);
40    let control = metric_values(
41        exp,
42        sessions,
43        &control_s,
44        Classification::Control,
45        manual_tags,
46    );
47    let treatment = metric_values(
48        exp,
49        sessions,
50        &treatment_s,
51        Classification::Treatment,
52        manual_tags,
53    );
54    let _ = excluded_s;
55    let excluded = records.len() - control.len() - treatment.len();
56    let summary = summarize(
57        &control,
58        &treatment,
59        stable_seed(&exp.id),
60        DEFAULT_RESAMPLES,
61    );
62    let target_met = evaluate_criterion(&exp.success_criterion, &summary);
63    let seq = seq_decide(
64        &control,
65        &treatment,
66        stable_seed(&exp.id),
67        DEFAULT_RESAMPLES,
68        ever_significant,
69    );
70    let guardrail_results = exp
71        .guardrails
72        .iter()
73        .map(|g| {
74            let gvals_c = metric_values_for(g.metric, sessions, &control_s);
75            let gvals_t = metric_values_for(g.metric, sessions, &treatment_s);
76            let gs = summarize(&gvals_c, &gvals_t, stable_seed(&exp.id), DEFAULT_RESAMPLES);
77            let violated = match g.regression_direction {
78                Direction::Increase => gs.ci95_lo.map(|lo| lo > 0.0).unwrap_or(false),
79                Direction::Decrease => gs.ci95_hi.map(|hi| hi < 0.0).unwrap_or(false),
80            };
81            GuardrailResult {
82                metric: g.metric,
83                delta_pct: gs.delta_pct,
84                violated,
85            }
86        })
87        .collect();
88    Report {
89        experiment: exp.clone(),
90        summary,
91        excluded_count: excluded,
92        target_met,
93        guardrail_results,
94        sequential_decision: seq.decision,
95        ever_significant: seq.ever_significant,
96    }
97}
98
99fn metric_values(
100    exp: &Experiment,
101    sessions: &[(SessionRecord, Vec<Event>)],
102    picked: &[&SessionRecord],
103    _which: Classification,
104    _tags: &ManualTags,
105) -> Vec<f64> {
106    metric_values_for(exp.metric, sessions, picked)
107}
108
109fn metric_values_for(
110    metric: crate::experiment::types::Metric,
111    sessions: &[(SessionRecord, Vec<Event>)],
112    picked: &[&SessionRecord],
113) -> Vec<f64> {
114    let ids: std::collections::HashSet<&str> = picked.iter().map(|s| s.id.as_str()).collect();
115    sessions
116        .iter()
117        .filter(|(s, _)| ids.contains(s.id.as_str()))
118        .filter_map(|(s, evs)| value_for(metric, s, evs))
119        .collect()
120}
121
122/// Success when the 95% CI excludes zero in the declared direction.
123/// `delta_pct` is display-only and not used here.
124fn evaluate_criterion(c: &Criterion, s: &Summary) -> Option<bool> {
125    match c {
126        Criterion::Delta { direction, .. } => {
127            let lo = s.ci95_lo?;
128            let hi = s.ci95_hi?;
129            Some(match direction {
130                Direction::Decrease => hi < 0.0,
131                Direction::Increase => lo > 0.0,
132            })
133        }
134        Criterion::Absolute { metric_value } => {
135            let m = s.median_treatment?;
136            Some(m <= *metric_value)
137        }
138    }
139}
140
/// Deterministic 64-bit seed derived from an experiment id.
///
/// XOR-then-multiply over the id's bytes, in the style of FNV-1a, using
/// wrapping arithmetic. An empty id yields the initial value unchanged.
fn stable_seed(id: &str) -> u64 {
    // NOTE(review): this initial value is not the canonical FNV-1a 64-bit
    // offset basis (14695981039346656037). It is kept as-is because changing
    // it would change the seed produced for every existing experiment id.
    const INITIAL: u64 = 1469598103934665603;
    const MULTIPLIER: u64 = 1099511628211;

    id.bytes()
        .fold(INITIAL, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(MULTIPLIER))
}
149
150/// Human-readable markdown per `docs/experiments.md`.
151pub fn to_markdown(report: &Report) -> String {
152    let e = &report.experiment;
153    let s = &report.summary;
154    let mut out = String::new();
155    out.push_str(&format!("# Experiment: {}\n\n", e.name));
156    out.push_str(&format!(
157        "State: {:?} · Duration: {}d\nHypothesis: {}\nChange: {}\n\n",
158        e.state, e.duration_days, e.hypothesis, e.change_description
159    ));
160    let (ctl_label, trt_label) = match &e.binding {
161        crate::experiment::types::Binding::GitCommit {
162            control_commit,
163            treatment_commit,
164        } => (short(control_commit), short(treatment_commit)),
165        crate::experiment::types::Binding::Branch {
166            control_branch,
167            treatment_branch,
168        } => (control_branch.clone(), treatment_branch.clone()),
169        crate::experiment::types::Binding::ManualTag { variant_field } => {
170            (format!("manual:{}", variant_field), "manual".into())
171        }
172    };
173    out.push_str(&format!(
174        "Binding: control {} · treatment {}\nMetric: {}\n\n",
175        ctl_label,
176        trt_label,
177        e.metric.as_str()
178    ));
179    out.push_str("|          | N  | median | mean |\n|---|---|---|---|\n");
180    out.push_str(&format!(
181        "| control  | {} | {} | {} |\n",
182        s.n_control,
183        fmt_opt(s.median_control),
184        fmt_opt(s.mean_control),
185    ));
186    out.push_str(&format!(
187        "| treatment| {} | {} | {} |\n\n",
188        s.n_treatment,
189        fmt_opt(s.median_treatment),
190        fmt_opt(s.mean_treatment),
191    ));
192    if let Some(d) = s.delta_median {
193        out.push_str(&format!(
194            "Delta (median): {:+.2}{}\n",
195            d,
196            s.delta_pct
197                .map(|p| format!(" ({:+.1}%)", p))
198                .unwrap_or_default(),
199        ));
200    }
201    if let (Some(lo), Some(hi)) = (s.ci95_lo, s.ci95_hi) {
202        out.push_str(&format!(
203            "95% bootstrap CI on delta: [{:+.2}, {:+.2}]\n",
204            lo, hi
205        ));
206    }
207    if let Some(met) = report.target_met {
208        out.push_str(&format!(
209            "Target: {}\n",
210            if met { "MET" } else { "not met" }
211        ));
212    }
213    out.push_str(&format!("\nExcluded: {} sessions\n", report.excluded_count));
214    if s.small_sample_warning {
215        out.push_str("Warning: N per arm < 30 — CI may be unreliable.\n");
216    }
217    out
218}
219
/// Format an optional number with two decimals, or an em dash when absent.
fn fmt_opt(v: Option<f64>) -> String {
    match v {
        Some(x) => format!("{:.2}", x),
        None => String::from("—"),
    }
}
223
/// Abbreviate a commit id to its first seven characters.
///
/// Counts characters, not bytes, so multi-byte input is never split
/// mid-character; ids of seven characters or fewer are returned whole.
fn short(commit: &str) -> String {
    // Byte offset of the eighth character, or the full length if there is none.
    let cut = commit
        .char_indices()
        .nth(7)
        .map_or(commit.len(), |(offset, _)| offset);
    commit[..cut].to_owned()
}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::SessionStatus;
    use crate::experiment::types::{Binding, Criterion, Direction, Metric, State};

    /// Minimal running experiment: tokens-per-session metric, git-commit
    /// binding, "decrease" delta criterion, no guardrails.
    fn exp() -> Experiment {
        Experiment {
            id: "e".into(),
            name: "e".into(),
            hypothesis: "h".into(),
            change_description: "c".into(),
            metric: Metric::TokensPerSession,
            binding: Binding::GitCommit {
                control_commit: "c".into(),
                treatment_commit: "t".into(),
            },
            duration_days: 14,
            success_criterion: Criterion::Delta {
                direction: Direction::Decrease,
                target_pct: -10.0,
            },
            state: State::Running,
            created_at_ms: 0,
            concluded_at_ms: None,
            guardrails: Vec::new(),
        }
    }

    /// A finished session with no git metadata and a single tool-call event
    /// carrying `tokens` input tokens.
    fn session_with(id: &str, tokens: u32) -> (SessionRecord, Vec<Event>) {
        let s = SessionRecord {
            id: id.into(),
            agent: "cursor".into(),
            model: None,
            workspace: "/ws".into(),
            started_at_ms: 0,
            ended_at_ms: None,
            status: SessionStatus::Done,
            trace_path: String::new(),
            start_commit: None,
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
            prompt_fingerprint: None,
            parent_session_id: None,
            agent_version: None,
            os: None,
            arch: None,
            repo_file_count: None,
            repo_total_loc: None,
        };
        let ev = Event {
            session_id: id.into(),
            seq: 0,
            ts_ms: 0,
            ts_exact: false,
            kind: crate::core::event::EventKind::ToolCall,
            source: crate::core::event::EventSource::Tail,
            tool: None,
            tool_call_id: None,
            tokens_in: Some(tokens),
            tokens_out: None,
            reasoning_tokens: None,
            cost_usd_e6: None,
            stop_reason: None,
            latency_ms: None,
            ttft_ms: None,
            retry_count: None,
            context_used_tokens: None,
            context_max_tokens: None,
            cache_creation_tokens: None,
            cache_read_tokens: None,
            system_prompt_tokens: None,
            payload: serde_json::Value::Null,
        };
        (s, vec![ev])
    }

    #[test]
    fn evaluate_criterion_ci_excludes_zero_decrease() {
        let c = Criterion::Delta {
            direction: Direction::Decrease,
            target_pct: -10.0,
        };
        let s = Summary {
            ci95_lo: Some(-20.0),
            ci95_hi: Some(-5.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(true)); // entire CI < 0

        let s = Summary {
            ci95_lo: Some(-20.0),
            ci95_hi: Some(2.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(false)); // CI straddles zero
    }

    #[test]
    fn evaluate_criterion_ci_excludes_zero_increase() {
        let c = Criterion::Delta {
            direction: Direction::Increase,
            target_pct: 10.0,
        };
        let s = Summary {
            ci95_lo: Some(5.0),
            ci95_hi: Some(20.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(true)); // entire CI > 0

        let s = Summary {
            ci95_lo: Some(-2.0),
            ci95_hi: Some(20.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(false)); // CI straddles zero
    }

    #[test]
    fn manual_tags_drive_partition_without_git() {
        let e = exp();
        // None of these sessions carries commit or branch data, so only the
        // manual tags below can place them in an arm.
        let sessions = vec![
            session_with("a", 100),
            session_with("b", 80),
            session_with("c", 200),
            session_with("d", 70),
        ];
        let mut tags = ManualTags::new();
        tags.insert("a".into(), Classification::Control);
        tags.insert("b".into(), Classification::Control);
        tags.insert("c".into(), Classification::Treatment);
        tags.insert("d".into(), Classification::Treatment);
        let r = run(&e, &sessions, &tags, Path::new("/no"), false);
        assert_eq!(r.summary.n_control, 2);
        assert_eq!(r.summary.n_treatment, 2);
    }
}
365        let r = run(&e, &sessions, &tags, Path::new("/no"), false);
366        assert_eq!(r.summary.n_control, 2);
367        assert_eq!(r.summary.n_treatment, 2);
368    }
369}