// kaizen/experiment/engine.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Turn an `Experiment` + sessions into a report. Pure compute given inputs.
3
4use crate::core::event::{Event, SessionRecord};
5use crate::experiment::binding::{ManualTags, partition};
6use crate::experiment::metric::value_for;
7use crate::experiment::stats::sequential::{Decision, decide as seq_decide};
8use crate::experiment::stats::{DEFAULT_RESAMPLES, Summary, summarize};
9use crate::experiment::types::{Classification, Criterion, Direction, Experiment, GuardrailResult};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13
/// Result of evaluating one experiment against a set of sessions.
///
/// Built by [`run`] and [`run_from_metric_values`]; serializable through serde
/// and rendered for humans by [`to_markdown`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Report {
    /// Copy of the experiment definition this report was computed for.
    pub experiment: Experiment,
    /// Summary statistics of the experiment's primary metric for both arms.
    pub summary: Summary,
    /// Number of sessions that were left out of the comparison.
    pub excluded_count: usize,
    /// Outcome of the experiment's success criterion; `None` when the summary
    /// lacks the statistic the criterion needs.
    pub target_met: Option<bool>,
    /// One entry per guardrail declared on the experiment, in declaration order.
    pub guardrail_results: Vec<GuardrailResult>,
    /// Sequential decision: sticky once Significant.
    pub sequential_decision: Decision,
    /// Pass this back on the next `run` call to preserve stickiness.
    pub ever_significant: bool,
}
26
27/// Pure ranking step once `sessions` + per-session `events` gathered.
28///
29/// Pass the previous `ever_significant` from the last report to preserve
30/// sequential stickiness across incremental calls.
31pub fn run(
32    exp: &Experiment,
33    sessions: &[(SessionRecord, Vec<Event>)],
34    manual_tags: &ManualTags,
35    workspace: &Path,
36    ever_significant: bool,
37) -> Report {
38    let records: Vec<SessionRecord> = sessions.iter().map(|(s, _)| s.clone()).collect();
39    let (control_s, treatment_s, excluded_s) =
40        partition(&records, &exp.binding, manual_tags, workspace);
41    let control = metric_values(
42        exp,
43        sessions,
44        &control_s,
45        Classification::Control,
46        manual_tags,
47    );
48    let treatment = metric_values(
49        exp,
50        sessions,
51        &treatment_s,
52        Classification::Treatment,
53        manual_tags,
54    );
55    let _ = excluded_s;
56    let excluded = records.len() - control.len() - treatment.len();
57    let summary = summarize(
58        &control,
59        &treatment,
60        stable_seed(&exp.id),
61        DEFAULT_RESAMPLES,
62    );
63    let target_met = evaluate_criterion(&exp.success_criterion, &summary);
64    let seq = seq_decide(
65        &control,
66        &treatment,
67        stable_seed(&exp.id),
68        DEFAULT_RESAMPLES,
69        ever_significant,
70    );
71    let guardrail_results = exp
72        .guardrails
73        .iter()
74        .map(|g| {
75            let gvals_c = metric_values_for(g.metric, sessions, &control_s);
76            let gvals_t = metric_values_for(g.metric, sessions, &treatment_s);
77            let gs = summarize(&gvals_c, &gvals_t, stable_seed(&exp.id), DEFAULT_RESAMPLES);
78            let violated = match g.regression_direction {
79                Direction::Increase => gs.ci95_lo.map(|lo| lo > 0.0).unwrap_or(false),
80                Direction::Decrease => gs.ci95_hi.map(|hi| hi < 0.0).unwrap_or(false),
81            };
82            GuardrailResult {
83                metric: g.metric,
84                delta_pct: gs.delta_pct,
85                violated,
86            }
87        })
88        .collect();
89    Report {
90        experiment: exp.clone(),
91        summary,
92        excluded_count: excluded,
93        target_met,
94        guardrail_results,
95        sequential_decision: seq.decision,
96        ever_significant: seq.ever_significant,
97    }
98}
99
100pub fn run_from_metric_values(
101    exp: &Experiment,
102    sessions: &[SessionRecord],
103    values: &HashMap<String, f64>,
104    guardrail_values: &HashMap<crate::experiment::types::Metric, HashMap<String, f64>>,
105    manual_tags: &ManualTags,
106    workspace: &Path,
107    ever_significant: bool,
108) -> Report {
109    let (control_s, treatment_s, excluded_s) =
110        partition(sessions, &exp.binding, manual_tags, workspace);
111    let control = values_for(&control_s, values);
112    let treatment = values_for(&treatment_s, values);
113    let summary = summarize(
114        &control,
115        &treatment,
116        stable_seed(&exp.id),
117        DEFAULT_RESAMPLES,
118    );
119    let target_met = evaluate_criterion(&exp.success_criterion, &summary);
120    let seq = seq_decide(
121        &control,
122        &treatment,
123        stable_seed(&exp.id),
124        DEFAULT_RESAMPLES,
125        ever_significant,
126    );
127    let guardrail_results = exp
128        .guardrails
129        .iter()
130        .map(|g| {
131            let empty = HashMap::new();
132            let vals = guardrail_values.get(&g.metric).unwrap_or(&empty);
133            let gs = summarize(
134                &values_for(&control_s, vals),
135                &values_for(&treatment_s, vals),
136                stable_seed(&exp.id),
137                DEFAULT_RESAMPLES,
138            );
139            let violated = match g.regression_direction {
140                Direction::Increase => gs.ci95_lo.map(|lo| lo > 0.0).unwrap_or(false),
141                Direction::Decrease => gs.ci95_hi.map(|hi| hi < 0.0).unwrap_or(false),
142            };
143            GuardrailResult {
144                metric: g.metric,
145                delta_pct: gs.delta_pct,
146                violated,
147            }
148        })
149        .collect();
150    Report {
151        experiment: exp.clone(),
152        summary,
153        excluded_count: excluded_s.len(),
154        target_met,
155        guardrail_results,
156        sequential_decision: seq.decision,
157        ever_significant: seq.ever_significant,
158    }
159}
160
161fn values_for(sessions: &[&SessionRecord], values: &HashMap<String, f64>) -> Vec<f64> {
162    sessions
163        .iter()
164        .filter_map(|s| values.get(&s.id).copied())
165        .collect()
166}
167
168fn metric_values(
169    exp: &Experiment,
170    sessions: &[(SessionRecord, Vec<Event>)],
171    picked: &[&SessionRecord],
172    _which: Classification,
173    _tags: &ManualTags,
174) -> Vec<f64> {
175    metric_values_for(exp.metric, sessions, picked)
176}
177
178fn metric_values_for(
179    metric: crate::experiment::types::Metric,
180    sessions: &[(SessionRecord, Vec<Event>)],
181    picked: &[&SessionRecord],
182) -> Vec<f64> {
183    let ids: std::collections::HashSet<&str> = picked.iter().map(|s| s.id.as_str()).collect();
184    sessions
185        .iter()
186        .filter(|(s, _)| ids.contains(s.id.as_str()))
187        .filter_map(|(s, evs)| value_for(metric, s, evs))
188        .collect()
189}
190
191/// Success when the 95% CI excludes zero in the declared direction.
192/// `delta_pct` is display-only and not used here.
193fn evaluate_criterion(c: &Criterion, s: &Summary) -> Option<bool> {
194    match c {
195        Criterion::Delta { direction, .. } => {
196            let lo = s.ci95_lo?;
197            let hi = s.ci95_hi?;
198            Some(match direction {
199                Direction::Decrease => hi < 0.0,
200                Direction::Increase => lo > 0.0,
201            })
202        }
203        Criterion::Absolute { metric_value } => {
204            let m = s.median_treatment?;
205            Some(m <= *metric_value)
206        }
207    }
208}
209
/// Derive a deterministic 64-bit seed from an experiment id.
///
/// Folds the id's UTF-8 bytes FNV-1a style: XOR each byte in, then multiply by
/// the 64-bit FNV prime with wrapping arithmetic. The starting value is not the
/// canonical FNV-1a offset basis (14695981039346656037); it is kept exactly as
/// is so that seeds for existing experiment ids do not change.
fn stable_seed(id: &str) -> u64 {
    const SEED_BASIS: u64 = 1469598103934665603;
    const FNV_PRIME: u64 = 1099511628211;
    id.bytes().fold(SEED_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
218
219/// Human-readable markdown per `docs/experiments.md`.
220pub fn to_markdown(report: &Report) -> String {
221    let e = &report.experiment;
222    let s = &report.summary;
223    let mut out = String::new();
224    out.push_str(&format!("# Experiment: {}\n\n", e.name));
225    out.push_str(&format!(
226        "State: {:?} · Duration: {}d\nHypothesis: {}\nChange: {}\n\n",
227        e.state, e.duration_days, e.hypothesis, e.change_description
228    ));
229    let (ctl_label, trt_label) = match &e.binding {
230        crate::experiment::types::Binding::GitCommit {
231            control_commit,
232            treatment_commit,
233        } => (short(control_commit), short(treatment_commit)),
234        crate::experiment::types::Binding::Branch {
235            control_branch,
236            treatment_branch,
237        } => (control_branch.clone(), treatment_branch.clone()),
238        crate::experiment::types::Binding::ManualTag { variant_field } => {
239            (format!("manual:{}", variant_field), "manual".into())
240        }
241    };
242    out.push_str(&format!(
243        "Binding: control {} · treatment {}\nMetric: {}\n\n",
244        ctl_label,
245        trt_label,
246        e.metric.as_str()
247    ));
248    out.push_str("|          | N  | median | mean |\n|---|---|---|---|\n");
249    out.push_str(&format!(
250        "| control  | {} | {} | {} |\n",
251        s.n_control,
252        fmt_opt(s.median_control),
253        fmt_opt(s.mean_control),
254    ));
255    out.push_str(&format!(
256        "| treatment| {} | {} | {} |\n\n",
257        s.n_treatment,
258        fmt_opt(s.median_treatment),
259        fmt_opt(s.mean_treatment),
260    ));
261    if let Some(d) = s.delta_median {
262        out.push_str(&format!(
263            "Delta (median): {:+.2}{}\n",
264            d,
265            s.delta_pct
266                .map(|p| format!(" ({:+.1}%)", p))
267                .unwrap_or_default(),
268        ));
269    }
270    if let (Some(lo), Some(hi)) = (s.ci95_lo, s.ci95_hi) {
271        out.push_str(&format!(
272            "95% bootstrap CI on delta: [{:+.2}, {:+.2}]\n",
273            lo, hi
274        ));
275    }
276    if let Some(met) = report.target_met {
277        out.push_str(&format!(
278            "Target: {}\n",
279            if met { "MET" } else { "not met" }
280        ));
281    }
282    out.push_str(&format!("\nExcluded: {} sessions\n", report.excluded_count));
283    if s.small_sample_warning {
284        out.push_str("Warning: N per arm < 30 — CI may be unreliable.\n");
285    }
286    out
287}
288
/// Format an optional number with two decimals, or an em dash when absent.
fn fmt_opt(v: Option<f64>) -> String {
    match v {
        Some(number) => format!("{:.2}", number),
        None => String::from("—"),
    }
}
292
/// Abbreviate a commit id to its first seven characters.
///
/// Ids that are already seven characters or shorter come back unchanged.
fn short(commit: &str) -> String {
    // Byte offset where the eighth character starts, if there is one.
    match commit.char_indices().nth(7) {
        Some((cut, _)) => commit[..cut].to_owned(),
        None => commit.to_owned(),
    }
}
296
/// Unit tests for criterion evaluation, seeding and manual-tag partitioning.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::SessionStatus;
    use crate::experiment::types::{Binding, Criterion, Direction, Metric, State};

    /// Minimal running experiment bound to two git commits, with no guardrails.
    fn exp() -> Experiment {
        Experiment {
            id: "e".into(),
            name: "e".into(),
            hypothesis: "h".into(),
            change_description: "c".into(),
            metric: Metric::TokensPerSession,
            binding: Binding::GitCommit {
                control_commit: "c".into(),
                treatment_commit: "t".into(),
            },
            duration_days: 14,
            success_criterion: Criterion::Delta {
                direction: Direction::Decrease,
                target_pct: -10.0,
            },
            state: State::Running,
            created_at_ms: 0,
            concluded_at_ms: None,
            guardrails: Vec::new(),
        }
    }

    /// A finished session with a single event carrying `tokens` input tokens.
    fn session_with(id: &str, tokens: u32) -> (SessionRecord, Vec<Event>) {
        let s = SessionRecord {
            id: id.into(),
            agent: "cursor".into(),
            model: None,
            workspace: "/ws".into(),
            started_at_ms: 0,
            ended_at_ms: None,
            status: SessionStatus::Done,
            trace_path: String::new(),
            start_commit: None,
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
            prompt_fingerprint: None,
            parent_session_id: None,
            agent_version: None,
            os: None,
            arch: None,
            repo_file_count: None,
            repo_total_loc: None,
        };
        // `tokens_in` is set once here; no later mutation is needed.
        let ev = Event {
            session_id: id.into(),
            seq: 0,
            ts_ms: 0,
            ts_exact: false,
            kind: crate::core::event::EventKind::ToolCall,
            source: crate::core::event::EventSource::Tail,
            tool: None,
            tool_call_id: None,
            tokens_in: Some(tokens),
            tokens_out: None,
            reasoning_tokens: None,
            cost_usd_e6: None,
            stop_reason: None,
            latency_ms: None,
            ttft_ms: None,
            retry_count: None,
            context_used_tokens: None,
            context_max_tokens: None,
            cache_creation_tokens: None,
            cache_read_tokens: None,
            system_prompt_tokens: None,
            payload: serde_json::Value::Null,
        };
        (s, vec![ev])
    }

    #[test]
    fn evaluate_criterion_ci_excludes_zero_decrease() {
        let c = Criterion::Delta {
            direction: Direction::Decrease,
            target_pct: -10.0,
        };
        let s = Summary {
            ci95_lo: Some(-20.0),
            ci95_hi: Some(-5.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(true)); // entire CI < 0

        let s = Summary {
            ci95_lo: Some(-20.0),
            ci95_hi: Some(2.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(false)); // CI straddles zero
    }

    #[test]
    fn evaluate_criterion_ci_excludes_zero_increase() {
        let c = Criterion::Delta {
            direction: Direction::Increase,
            target_pct: 10.0,
        };
        let s = Summary {
            ci95_lo: Some(5.0),
            ci95_hi: Some(20.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(true)); // entire CI > 0

        let s = Summary {
            ci95_lo: Some(-2.0),
            ci95_hi: Some(20.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(false)); // CI straddles zero
    }

    #[test]
    fn evaluate_criterion_delta_needs_both_ci_bounds() {
        let c = Criterion::Delta {
            direction: Direction::Decrease,
            target_pct: -10.0,
        };
        let s = Summary {
            ci95_lo: None,
            ci95_hi: Some(-5.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), None); // lower bound missing

        let s = Summary {
            ci95_lo: Some(-20.0),
            ci95_hi: None,
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), None); // upper bound missing
    }

    #[test]
    fn evaluate_criterion_absolute_compares_treatment_median() {
        let c = Criterion::Absolute {
            metric_value: 100.0,
        };
        let s = Summary {
            median_treatment: Some(90.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(true)); // below the threshold

        let s = Summary {
            median_treatment: Some(100.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(true)); // threshold is inclusive

        let s = Summary {
            median_treatment: Some(110.0),
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), Some(false)); // above the threshold

        let s = Summary {
            median_treatment: None,
            ..Default::default()
        };
        assert_eq!(evaluate_criterion(&c, &s), None); // no treatment median
    }

    #[test]
    fn stable_seed_is_deterministic() {
        assert_eq!(stable_seed(""), 1469598103934665603);
        assert_eq!(stable_seed("exp-1"), stable_seed("exp-1"));
    }

    #[test]
    fn manual_tags_drive_partition_without_git() {
        let e = exp();
        let sessions = vec![
            session_with("a", 100),
            session_with("b", 80),
            session_with("c", 200),
            session_with("d", 70),
        ];
        let mut tags = ManualTags::new();
        tags.insert("a".into(), Classification::Control);
        tags.insert("b".into(), Classification::Control);
        tags.insert("c".into(), Classification::Treatment);
        tags.insert("d".into(), Classification::Treatment);
        let r = run(&e, &sessions, &tags, Path::new("/no"), false);
        assert_eq!(r.summary.n_control, 2);
        assert_eq!(r.summary.n_treatment, 2);
    }
}