Skip to main content

kaizen/experiment/
engine.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Turn an `Experiment` + sessions into a report. Pure compute given inputs.
3
4use crate::core::event::{Event, SessionRecord};
5use crate::experiment::binding::{ManualTags, partition};
6use crate::experiment::metric::value_for;
7use crate::experiment::stats::sequential::{Decision, decide as seq_decide};
8use crate::experiment::stats::{DEFAULT_RESAMPLES, Summary, summarize};
9use crate::experiment::types::{Classification, Criterion, Direction, Experiment, GuardrailResult};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13
/// Outcome of evaluating one [`Experiment`] against a set of sessions.
///
/// Produced by [`run`] / [`run_from_metric_values`] and rendered by [`to_markdown`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Report {
    /// Copy of the experiment definition this report was computed for.
    pub experiment: Experiment,
    /// Control-vs-treatment summary of the experiment's primary metric.
    pub summary: Summary,
    /// Number of input sessions left out of the control/treatment comparison.
    pub excluded_count: usize,
    /// Whether the success criterion was met; `None` when it could not be evaluated
    /// (the summary lacked the CI bounds or treatment median the criterion needs).
    pub target_met: Option<bool>,
    /// One result per entry of `experiment.guardrails`, in declaration order.
    pub guardrail_results: Vec<GuardrailResult>,
    /// Sequential decision: sticky once Significant.
    pub sequential_decision: Decision,
    /// Pass this back on the next `run` call to preserve stickiness.
    pub ever_significant: bool,
}
26
27/// Pure ranking step once `sessions` + per-session `events` gathered.
28///
29/// Pass the previous `ever_significant` from the last report to preserve
30/// sequential stickiness across incremental calls.
31pub fn run(
32    exp: &Experiment,
33    sessions: &[(SessionRecord, Vec<Event>)],
34    manual_tags: &ManualTags,
35    workspace: &Path,
36    ever_significant: bool,
37) -> Report {
38    let records: Vec<SessionRecord> = sessions.iter().map(|(s, _)| s.clone()).collect();
39    let (control_s, treatment_s, excluded_s) =
40        partition(&records, &exp.binding, manual_tags, workspace);
41    let control = metric_values(
42        exp,
43        sessions,
44        &control_s,
45        Classification::Control,
46        manual_tags,
47    );
48    let treatment = metric_values(
49        exp,
50        sessions,
51        &treatment_s,
52        Classification::Treatment,
53        manual_tags,
54    );
55    let _ = excluded_s;
56    let excluded = records.len() - control.len() - treatment.len();
57    let summary = summarize(
58        &control,
59        &treatment,
60        stable_seed(&exp.id),
61        DEFAULT_RESAMPLES,
62    );
63    let target_met = evaluate_criterion(&exp.success_criterion, &summary);
64    let seq = seq_decide(
65        &control,
66        &treatment,
67        stable_seed(&exp.id),
68        DEFAULT_RESAMPLES,
69        ever_significant,
70    );
71    let guardrail_results = exp
72        .guardrails
73        .iter()
74        .map(|g| {
75            let gvals_c = metric_values_for(g.metric, sessions, &control_s);
76            let gvals_t = metric_values_for(g.metric, sessions, &treatment_s);
77            let gs = summarize(&gvals_c, &gvals_t, stable_seed(&exp.id), DEFAULT_RESAMPLES);
78            let violated = match g.regression_direction {
79                Direction::Increase => gs.ci95_lo.map(|lo| lo > 0.0).unwrap_or(false),
80                Direction::Decrease => gs.ci95_hi.map(|hi| hi < 0.0).unwrap_or(false),
81            };
82            GuardrailResult {
83                metric: g.metric,
84                delta_pct: gs.delta_pct,
85                violated,
86            }
87        })
88        .collect();
89    Report {
90        experiment: exp.clone(),
91        summary,
92        excluded_count: excluded,
93        target_met,
94        guardrail_results,
95        sequential_decision: seq.decision,
96        ever_significant: seq.ever_significant,
97    }
98}
99
100pub fn run_from_metric_values(
101    exp: &Experiment,
102    sessions: &[SessionRecord],
103    values: &HashMap<String, f64>,
104    guardrail_values: &HashMap<crate::experiment::types::Metric, HashMap<String, f64>>,
105    manual_tags: &ManualTags,
106    workspace: &Path,
107    ever_significant: bool,
108) -> Report {
109    let (control_s, treatment_s, excluded_s) =
110        partition(sessions, &exp.binding, manual_tags, workspace);
111    let control = values_for(&control_s, values);
112    let treatment = values_for(&treatment_s, values);
113    let summary = summarize(
114        &control,
115        &treatment,
116        stable_seed(&exp.id),
117        DEFAULT_RESAMPLES,
118    );
119    let target_met = evaluate_criterion(&exp.success_criterion, &summary);
120    let seq = seq_decide(
121        &control,
122        &treatment,
123        stable_seed(&exp.id),
124        DEFAULT_RESAMPLES,
125        ever_significant,
126    );
127    let guardrail_results = exp
128        .guardrails
129        .iter()
130        .map(|g| {
131            let empty = HashMap::new();
132            let vals = guardrail_values.get(&g.metric).unwrap_or(&empty);
133            let gs = summarize(
134                &values_for(&control_s, vals),
135                &values_for(&treatment_s, vals),
136                stable_seed(&exp.id),
137                DEFAULT_RESAMPLES,
138            );
139            let violated = match g.regression_direction {
140                Direction::Increase => gs.ci95_lo.map(|lo| lo > 0.0).unwrap_or(false),
141                Direction::Decrease => gs.ci95_hi.map(|hi| hi < 0.0).unwrap_or(false),
142            };
143            GuardrailResult {
144                metric: g.metric,
145                delta_pct: gs.delta_pct,
146                violated,
147            }
148        })
149        .collect();
150    Report {
151        experiment: exp.clone(),
152        summary,
153        excluded_count: excluded_s.len(),
154        target_met,
155        guardrail_results,
156        sequential_decision: seq.decision,
157        ever_significant: seq.ever_significant,
158    }
159}
160
161fn values_for(sessions: &[&SessionRecord], values: &HashMap<String, f64>) -> Vec<f64> {
162    sessions
163        .iter()
164        .filter_map(|s| values.get(&s.id).copied())
165        .collect()
166}
167
168fn metric_values(
169    exp: &Experiment,
170    sessions: &[(SessionRecord, Vec<Event>)],
171    picked: &[&SessionRecord],
172    _which: Classification,
173    _tags: &ManualTags,
174) -> Vec<f64> {
175    metric_values_for(exp.metric, sessions, picked)
176}
177
178fn metric_values_for(
179    metric: crate::experiment::types::Metric,
180    sessions: &[(SessionRecord, Vec<Event>)],
181    picked: &[&SessionRecord],
182) -> Vec<f64> {
183    let ids: std::collections::HashSet<&str> = picked.iter().map(|s| s.id.as_str()).collect();
184    sessions
185        .iter()
186        .filter(|(s, _)| ids.contains(s.id.as_str()))
187        .filter_map(|(s, evs)| value_for(metric, s, evs))
188        .collect()
189}
190
191/// Success when the 95% CI excludes zero in the declared direction.
192/// `delta_pct` is display-only and not used here.
193fn evaluate_criterion(c: &Criterion, s: &Summary) -> Option<bool> {
194    match c {
195        Criterion::Delta { direction, .. } => {
196            let lo = s.ci95_lo?;
197            let hi = s.ci95_hi?;
198            Some(match direction {
199                Direction::Decrease => hi < 0.0,
200                Direction::Increase => lo > 0.0,
201            })
202        }
203        Criterion::Absolute { metric_value } => {
204            let m = s.median_treatment?;
205            Some(m <= *metric_value)
206        }
207    }
208}
209
/// Deterministic 64-bit seed derived from an experiment id (FNV-1a-style byte hash).
///
/// NOTE: the offset basis below is `1469598103934665603`. The canonical FNV-1a 64-bit
/// basis is `14695981039346656037`, so this value has the last digit missing. Do not
/// "correct" it: every bootstrap is seeded from this value, so changing it would shift
/// the CIs of every previously computed report.
fn stable_seed(id: &str) -> u64 {
    const OFFSET_BASIS: u64 = 1469598103934665603;
    const FNV_PRIME: u64 = 1099511628211;
    id.bytes().fold(OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
218
219/// Human-readable markdown per `docs/experiments.md`.
220pub fn to_markdown(report: &Report) -> String {
221    let e = &report.experiment;
222    let s = &report.summary;
223    let mut out = String::new();
224    out.push_str(&format!("# Experiment: {}\n\n", e.name));
225    out.push_str(&format!(
226        "State: {:?} · Duration: {}d\nHypothesis: {}\nChange: {}\n\n",
227        e.state, e.duration_days, e.hypothesis, e.change_description
228    ));
229    let (ctl_label, trt_label) = match &e.binding {
230        crate::experiment::types::Binding::GitCommit {
231            control_commit,
232            treatment_commit,
233        } => (short(control_commit), short(treatment_commit)),
234        crate::experiment::types::Binding::Branch {
235            control_branch,
236            treatment_branch,
237        } => (control_branch.clone(), treatment_branch.clone()),
238        crate::experiment::types::Binding::PromptFingerprint {
239            control_fingerprint,
240            treatment_fingerprint,
241        } => (short(control_fingerprint), short(treatment_fingerprint)),
242        crate::experiment::types::Binding::ManualTag { variant_field } => {
243            (format!("manual:{}", variant_field), "manual".into())
244        }
245    };
246    out.push_str(&format!(
247        "Binding: control {} · treatment {}\nMetric: {}\n\n",
248        ctl_label,
249        trt_label,
250        e.metric.as_str()
251    ));
252    out.push_str("|          | N  | median | mean |\n|---|---|---|---|\n");
253    out.push_str(&format!(
254        "| control  | {} | {} | {} |\n",
255        s.n_control,
256        fmt_opt(s.median_control),
257        fmt_opt(s.mean_control),
258    ));
259    out.push_str(&format!(
260        "| treatment| {} | {} | {} |\n\n",
261        s.n_treatment,
262        fmt_opt(s.median_treatment),
263        fmt_opt(s.mean_treatment),
264    ));
265    if let Some(d) = s.delta_median {
266        out.push_str(&format!(
267            "Delta (median): {:+.2}{}\n",
268            d,
269            s.delta_pct
270                .map(|p| format!(" ({:+.1}%)", p))
271                .unwrap_or_default(),
272        ));
273    }
274    if let (Some(lo), Some(hi)) = (s.ci95_lo, s.ci95_hi) {
275        out.push_str(&format!(
276            "95% bootstrap CI on delta: [{:+.2}, {:+.2}]\n",
277            lo, hi
278        ));
279    }
280    if let Some(met) = report.target_met {
281        out.push_str(&format!(
282            "Target: {}\n",
283            if met { "MET" } else { "not met" }
284        ));
285    }
286    out.push_str(&format!("\nExcluded: {} sessions\n", report.excluded_count));
287    if s.small_sample_warning {
288        out.push_str("Warning: N per arm < 30 — CI may be unreliable.\n");
289    }
290    out
291}
292
/// Format an optional number with two decimals, or an em dash when absent.
fn fmt_opt(v: Option<f64>) -> String {
    match v {
        Some(value) => format!("{:.2}", value),
        None => String::from("—"),
    }
}
296
/// First seven characters (not bytes) of `commit`, or the whole string if shorter.
/// Also used for prompt fingerprints.
fn short(commit: &str) -> String {
    // Byte offset of the 8th char, if any. It is always a char boundary, so slicing is safe.
    let cut = commit
        .char_indices()
        .nth(7)
        .map_or(commit.len(), |(idx, _)| idx);
    commit[..cut].to_owned()
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::SessionStatus;
    use crate::experiment::types::{Binding, Criterion, Direction, Metric, State};

    /// Running experiment on tokens-per-session, bound to two git commits.
    fn exp() -> Experiment {
        Experiment {
            id: "e".into(),
            name: "e".into(),
            hypothesis: "h".into(),
            change_description: "c".into(),
            metric: Metric::TokensPerSession,
            binding: Binding::GitCommit {
                control_commit: "c".into(),
                treatment_commit: "t".into(),
            },
            duration_days: 14,
            success_criterion: Criterion::Delta {
                direction: Direction::Decrease,
                target_pct: -10.0,
            },
            state: State::Running,
            created_at_ms: 0,
            concluded_at_ms: None,
            guardrails: Vec::new(),
        }
    }

    /// Session `id` (status Done) with a single tool-call event carrying `tokens` input tokens.
    fn session_with(id: &str, tokens: u32) -> (SessionRecord, Vec<Event>) {
        let s = SessionRecord {
            id: id.into(),
            agent: "cursor".into(),
            model: None,
            workspace: "/ws".into(),
            started_at_ms: 0,
            ended_at_ms: None,
            status: SessionStatus::Done,
            trace_path: String::new(),
            start_commit: None,
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
            prompt_fingerprint: None,
            parent_session_id: None,
            agent_version: None,
            os: None,
            arch: None,
            repo_file_count: None,
            repo_total_loc: None,
        };
        let ev = Event {
            session_id: id.into(),
            seq: 0,
            ts_ms: 0,
            ts_exact: false,
            kind: crate::core::event::EventKind::ToolCall,
            source: crate::core::event::EventSource::Tail,
            tool: None,
            tool_call_id: None,
            tokens_in: Some(tokens),
            tokens_out: None,
            reasoning_tokens: None,
            cost_usd_e6: None,
            stop_reason: None,
            latency_ms: None,
            ttft_ms: None,
            retry_count: None,
            context_used_tokens: None,
            context_max_tokens: None,
            cache_creation_tokens: None,
            cache_read_tokens: None,
            system_prompt_tokens: None,
            payload: serde_json::Value::Null,
        };
        (s, vec![ev])
    }

    #[test]
    fn evaluate_criterion_ci_excludes_zero_decrease() {
        let c = Criterion::Delta {
            direction: Direction::Decrease,
            target_pct: -10.0,
        };
        let s = Summary {
            ci95_lo: Some(-20.0),
            ci95_hi: Some(-5.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(true)); // entire CI < 0

        let s = Summary {
            ci95_lo: Some(-20.0),
            ci95_hi: Some(2.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(false)); // CI straddles zero
    }

    #[test]
    fn evaluate_criterion_ci_excludes_zero_increase() {
        let c = Criterion::Delta {
            direction: Direction::Increase,
            target_pct: 10.0,
        };
        let s = Summary {
            ci95_lo: Some(5.0),
            ci95_hi: Some(20.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(true)); // entire CI > 0

        let s = Summary {
            ci95_lo: Some(-2.0),
            ci95_hi: Some(20.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(false)); // CI straddles zero
    }

    #[test]
    fn manual_tags_drive_partition_without_git() {
        let e = exp();
        let sessions = vec![
            session_with("a", 100),
            session_with("b", 80),
            session_with("c", 200),
            session_with("d", 70),
        ];
        let mut tags = ManualTags::new();
        tags.insert("a".into(), Classification::Control);
        tags.insert("b".into(), Classification::Control);
        tags.insert("c".into(), Classification::Treatment);
        tags.insert("d".into(), Classification::Treatment);
        let r = run(&e, &sessions, &tags, Path::new("/no"), false);
        assert_eq!(r.summary.n_control, 2);
        assert_eq!(r.summary.n_treatment, 2);
        // Every session landed in an arm with a value, so nothing is excluded.
        assert_eq!(r.excluded_count, 0);
    }

    #[test]
    fn stable_seed_is_deterministic_and_pins_its_basis() {
        // The basis differs from canonical FNV-1a (14695981039346656037). Keep it fixed so
        // existing experiments keep their bootstrap seeds.
        assert_eq!(stable_seed(""), 1469598103934665603);
        assert_eq!(stable_seed("exp-1"), stable_seed("exp-1"));
        assert_ne!(stable_seed("a"), stable_seed("b"));
    }

    #[test]
    fn fmt_opt_and_short_render_expected_text() {
        assert_eq!(fmt_opt(Some(1.23456)), "1.23");
        assert_eq!(fmt_opt(None), "—");
        assert_eq!(short("0123456789abcdef"), "0123456");
        assert_eq!(short("abc"), "abc");
    }
}