Skip to main content

kaizen/experiment/
binding.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
//! Classify sessions into Control / Treatment / Excluded under a binding.
//!
//! Resolution order: manual tag > binding-specific classifier.
5
6use crate::core::event::SessionRecord;
7use crate::experiment::types::{Binding, Classification};
8use anyhow::Result;
9use std::collections::HashMap;
10use std::path::Path;
11use std::process::Command;
12
/// Per-experiment manual overrides, keyed by session id → forced variant.
///
/// `classify` looks sessions up here by `SessionRecord::id` before consulting
/// the binding, so an entry in this map always wins.
pub type ManualTags = HashMap<String, Classification>;
15
16/// Classify a session. Manual tag wins over git binding.
17///
18/// Pure w.r.t. `session` + `manual_tags`; shells out to `git` only when a
19/// `GitCommit` binding is in effect and no manual tag overrides.
20pub fn classify(
21    session: &SessionRecord,
22    binding: &Binding,
23    manual_tags: &ManualTags,
24    workspace: &Path,
25) -> Classification {
26    if let Some(v) = manual_tags.get(&session.id) {
27        return v.clone();
28    }
29    match binding {
30        Binding::ManualTag { .. } => Classification::Excluded,
31        Binding::GitCommit {
32            control_commit,
33            treatment_commit,
34        } => classify_git(session, control_commit, treatment_commit, workspace),
35        Binding::Branch {
36            control_branch,
37            treatment_branch,
38        } => classify_git(session, control_branch, treatment_branch, workspace),
39        Binding::PromptFingerprint {
40            control_fingerprint,
41            treatment_fingerprint,
42        } => classify_prompt(session, control_fingerprint, treatment_fingerprint),
43    }
44}
45
46fn classify_prompt(session: &SessionRecord, control: &str, treatment: &str) -> Classification {
47    match session.prompt_fingerprint.as_deref() {
48        Some(fp) if fp == control => Classification::Control,
49        Some(fp) if fp == treatment => Classification::Treatment,
50        _ => Classification::Excluded,
51    }
52}
53
54fn classify_git(
55    session: &SessionRecord,
56    control_commit: &str,
57    treatment_commit: &str,
58    workspace: &Path,
59) -> Classification {
60    let Some(start) = session.start_commit.as_deref() else {
61        return Classification::Excluded;
62    };
63    let on_treatment = is_ancestor(workspace, start, treatment_commit).unwrap_or(false);
64    let on_control = is_ancestor(workspace, start, control_commit).unwrap_or(false);
65    match (on_treatment, on_control) {
66        // strictly descended from control (not yet at treatment boundary)
67        (false, true) => Classification::Control,
68        // past the treatment boundary
69        (true, false) => Classification::Treatment,
70        // straddles or unknown
71        _ => Classification::Excluded,
72    }
73}
74
75fn is_ancestor(workspace: &Path, maybe_ancestor: &str, descendant: &str) -> Result<bool> {
76    if maybe_ancestor == descendant {
77        return Ok(true);
78    }
79    let out = Command::new("git")
80        .arg("-C")
81        .arg(workspace)
82        .args(["merge-base", "--is-ancestor", maybe_ancestor, descendant])
83        .output()?;
84    Ok(out.status.success())
85}
86
87/// Partition sessions by classification, preserving input order.
88pub fn partition<'a>(
89    sessions: &'a [SessionRecord],
90    binding: &Binding,
91    manual_tags: &ManualTags,
92    workspace: &Path,
93) -> (
94    Vec<&'a SessionRecord>,
95    Vec<&'a SessionRecord>,
96    Vec<&'a SessionRecord>,
97) {
98    let mut control = Vec::new();
99    let mut treatment = Vec::new();
100    let mut excluded = Vec::new();
101    for s in sessions {
102        match classify(s, binding, manual_tags, workspace) {
103            Classification::Control => control.push(s),
104            Classification::Treatment => treatment.push(s),
105            Classification::Excluded => excluded.push(s),
106        }
107    }
108    (control, treatment, excluded)
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::SessionStatus;

    /// Build a minimal finished session with the given id and start commit.
    fn mk(id: &str, commit: Option<&str>) -> SessionRecord {
        SessionRecord {
            id: id.into(),
            agent: "cursor".into(),
            model: None,
            workspace: "/ws".into(),
            started_at_ms: 0,
            ended_at_ms: None,
            status: SessionStatus::Done,
            trace_path: String::new(),
            start_commit: commit.map(Into::into),
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
            prompt_fingerprint: None,
            parent_session_id: None,
            agent_version: None,
            os: None,
            arch: None,
            repo_file_count: None,
            repo_total_loc: None,
        }
    }

    /// Git-commit binding. The tests below only use it where a manual tag
    /// wins or the session has no start commit.
    fn git_binding() -> Binding {
        Binding::GitCommit {
            control_commit: "c".into(),
            treatment_commit: "t".into(),
        }
    }

    fn fingerprint_binding() -> Binding {
        Binding::PromptFingerprint {
            control_fingerprint: "fp-a".into(),
            treatment_fingerprint: "fp-b".into(),
        }
    }

    fn ids<'a>(sessions: &[&'a SessionRecord]) -> Vec<&'a str> {
        sessions.iter().map(|s| s.id.as_str()).collect()
    }

    #[test]
    fn manual_tag_beats_git_binding() {
        let s = mk("s1", Some("abc"));
        let mut tags = ManualTags::new();
        tags.insert("s1".into(), Classification::Treatment);
        let got = classify(&s, &git_binding(), &tags, Path::new("/no"));
        assert_eq!(got, Classification::Treatment);
    }

    #[test]
    fn manual_tag_beats_prompt_fingerprint() {
        let mut s = mk("s1", None);
        s.prompt_fingerprint = Some("fp-a".into());
        let mut tags = ManualTags::new();
        tags.insert("s1".into(), Classification::Treatment);
        let got = classify(&s, &fingerprint_binding(), &tags, Path::new("/no"));
        assert_eq!(got, Classification::Treatment);
    }

    #[test]
    fn no_start_commit_excludes() {
        let s = mk("s1", None);
        let got = classify(&s, &git_binding(), &ManualTags::new(), Path::new("/no"));
        assert_eq!(got, Classification::Excluded);
    }

    #[test]
    fn branch_binding_without_start_commit_excludes() {
        let s = mk("s1", None);
        let binding = Binding::Branch {
            control_branch: "main".into(),
            treatment_branch: "exp".into(),
        };
        let got = classify(&s, &binding, &ManualTags::new(), Path::new("/no"));
        assert_eq!(got, Classification::Excluded);
    }

    #[test]
    fn prompt_fingerprint_classifies_exact_matches() {
        let mut s = mk("s1", None);
        s.prompt_fingerprint = Some("fp-b".into());
        let got = classify(&s, &fingerprint_binding(), &ManualTags::new(), Path::new("/no"));
        assert_eq!(got, Classification::Treatment);
    }

    #[test]
    fn prompt_fingerprint_control_and_unmatched() {
        let binding = fingerprint_binding();
        let tags = ManualTags::new();
        let ws = Path::new("/no");

        let mut control = mk("s1", None);
        control.prompt_fingerprint = Some("fp-a".into());
        assert_eq!(classify(&control, &binding, &tags, ws), Classification::Control);

        let mut other = mk("s2", None);
        other.prompt_fingerprint = Some("fp-c".into());
        assert_eq!(classify(&other, &binding, &tags, ws), Classification::Excluded);

        let missing = mk("s3", None);
        assert_eq!(classify(&missing, &binding, &tags, ws), Classification::Excluded);
    }

    #[test]
    fn partition_splits_three_ways() {
        let all = vec![mk("1", None), mk("2", None), mk("3", None)];
        let mut tags = ManualTags::new();
        tags.insert("1".into(), Classification::Control);
        tags.insert("2".into(), Classification::Treatment);
        let (c, t, e) = partition(&all, &git_binding(), &tags, Path::new("/no"));
        assert_eq!(ids(&c), ["1"]);
        assert_eq!(ids(&t), ["2"]);
        assert_eq!(ids(&e), ["3"]);
    }

    #[test]
    fn partition_preserves_input_order() {
        let all: Vec<_> = ["a", "b", "c", "d", "e"]
            .into_iter()
            .map(|id| mk(id, None))
            .collect();
        let mut tags = ManualTags::new();
        tags.insert("a".into(), Classification::Control);
        tags.insert("b".into(), Classification::Treatment);
        tags.insert("c".into(), Classification::Control);
        tags.insert("d".into(), Classification::Treatment);
        let (c, t, e) = partition(&all, &git_binding(), &tags, Path::new("/no"));
        assert_eq!(ids(&c), ["a", "c"]);
        assert_eq!(ids(&t), ["b", "d"]);
        assert_eq!(ids(&e), ["e"]);
    }

    #[test]
    fn is_ancestor_short_circuits_identical_revisions() {
        // No git process is needed: a revision is its own ancestor.
        assert!(is_ancestor(Path::new("/no"), "abc", "abc").unwrap());
    }
}