Skip to main content

kaizen/experiment/
binding.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Classify sessions into Control / Treatment / Excluded under a binding.
3//!
4//! Resolution order: manual tag > git ancestry. Branch-binding is a
5//! special case of GitCommit: caller resolves branch tips to commits.
6
7use crate::core::event::SessionRecord;
8use crate::experiment::types::{Binding, Classification};
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::Path;
12use std::process::Command;
13
/// Per-experiment manual tags keyed by `session_id` → variant.
///
/// [`classify`] looks a session up here by its `id` first; a hit is returned
/// as that session's classification regardless of the binding in effect.
pub type ManualTags = HashMap<String, Classification>;
16
17/// Classify a session. Manual tag wins over git binding.
18///
19/// Pure w.r.t. `session` + `manual_tags`; shells out to `git` only when a
20/// `GitCommit` binding is in effect and no manual tag overrides.
21pub fn classify(
22    session: &SessionRecord,
23    binding: &Binding,
24    manual_tags: &ManualTags,
25    workspace: &Path,
26) -> Classification {
27    if let Some(v) = manual_tags.get(&session.id) {
28        return v.clone();
29    }
30    match binding {
31        Binding::ManualTag { .. } => Classification::Excluded,
32        Binding::GitCommit {
33            control_commit,
34            treatment_commit,
35        } => classify_git(session, control_commit, treatment_commit, workspace),
36        Binding::Branch {
37            control_branch,
38            treatment_branch,
39        } => classify_git(session, control_branch, treatment_branch, workspace),
40    }
41}
42
43fn classify_git(
44    session: &SessionRecord,
45    control_commit: &str,
46    treatment_commit: &str,
47    workspace: &Path,
48) -> Classification {
49    let Some(start) = session.start_commit.as_deref() else {
50        return Classification::Excluded;
51    };
52    let on_treatment = is_ancestor(workspace, start, treatment_commit).unwrap_or(false);
53    let on_control = is_ancestor(workspace, start, control_commit).unwrap_or(false);
54    match (on_treatment, on_control) {
55        // strictly descended from control (not yet at treatment boundary)
56        (false, true) => Classification::Control,
57        // past the treatment boundary
58        (true, false) => Classification::Treatment,
59        // straddles or unknown
60        _ => Classification::Excluded,
61    }
62}
63
64fn is_ancestor(workspace: &Path, maybe_ancestor: &str, descendant: &str) -> Result<bool> {
65    if maybe_ancestor == descendant {
66        return Ok(true);
67    }
68    let out = Command::new("git")
69        .arg("-C")
70        .arg(workspace)
71        .args(["merge-base", "--is-ancestor", maybe_ancestor, descendant])
72        .output()?;
73    Ok(out.status.success())
74}
75
76/// Partition sessions by classification, preserving input order.
77pub fn partition<'a>(
78    sessions: &'a [SessionRecord],
79    binding: &Binding,
80    manual_tags: &ManualTags,
81    workspace: &Path,
82) -> (
83    Vec<&'a SessionRecord>,
84    Vec<&'a SessionRecord>,
85    Vec<&'a SessionRecord>,
86) {
87    let mut control = Vec::new();
88    let mut treatment = Vec::new();
89    let mut excluded = Vec::new();
90    for s in sessions {
91        match classify(s, binding, manual_tags, workspace) {
92            Classification::Control => control.push(s),
93            Classification::Treatment => treatment.push(s),
94            Classification::Excluded => excluded.push(s),
95        }
96    }
97    (control, treatment, excluded)
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::SessionStatus;

    /// Build a finished session with the given id and optional start commit;
    /// every other field is left empty or zeroed.
    fn mk(id: &str, commit: Option<&str>) -> SessionRecord {
        SessionRecord {
            id: id.into(),
            agent: "cursor".into(),
            model: None,
            workspace: "/ws".into(),
            started_at_ms: 0,
            ended_at_ms: None,
            status: SessionStatus::Done,
            trace_path: String::new(),
            start_commit: commit.map(Into::into),
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
            prompt_fingerprint: None,
            parent_session_id: None,
            agent_version: None,
            os: None,
            arch: None,
            repo_file_count: None,
            repo_total_loc: None,
        }
    }

    /// A `GitCommit` binding with placeholder revisions. Every test below
    /// resolves through a manual tag or a missing start commit, so `git` is
    /// never invoked.
    fn git_binding() -> Binding {
        Binding::GitCommit {
            control_commit: "c".into(),
            treatment_commit: "t".into(),
        }
    }

    #[test]
    fn manual_tag_beats_git_binding() {
        let s = mk("s1", Some("abc"));
        let mut tags = ManualTags::new();
        tags.insert("s1".into(), Classification::Treatment);
        let got = classify(&s, &git_binding(), &tags, Path::new("/no"));
        assert_eq!(got, Classification::Treatment);
    }

    #[test]
    fn no_start_commit_excludes() {
        let s = mk("s1", None);
        let tags = ManualTags::new();
        let got = classify(&s, &git_binding(), &tags, Path::new("/no"));
        assert_eq!(got, Classification::Excluded);
    }

    #[test]
    fn partition_splits_three_ways() {
        let all = vec![mk("1", None), mk("2", None), mk("3", None)];
        let mut tags = ManualTags::new();
        tags.insert("1".into(), Classification::Control);
        tags.insert("2".into(), Classification::Treatment);
        let (c, t, e) = partition(&all, &git_binding(), &tags, Path::new("/no"));
        assert_eq!(c.len(), 1);
        assert_eq!(t.len(), 1);
        assert_eq!(e.len(), 1);
        // Check membership, not just bucket sizes.
        assert_eq!(c[0].id, "1");
        assert_eq!(t[0].id, "2");
        assert_eq!(e[0].id, "3");
    }

    #[test]
    fn partition_preserves_input_order() {
        let all = vec![mk("a", None), mk("b", None), mk("c", None)];
        let mut tags = ManualTags::new();
        for s in &all {
            tags.insert(s.id.clone(), Classification::Control);
        }
        let (c, t, e) = partition(&all, &git_binding(), &tags, Path::new("/no"));
        let ids: Vec<&str> = c.iter().map(|s| s.id.as_str()).collect();
        assert_eq!(ids, ["a", "b", "c"]);
        assert!(t.is_empty());
        assert!(e.is_empty());
    }
}