Skip to main content

kaizen/experiment/
binding.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
//! Classify sessions into Control / Treatment / Excluded under a binding.
//!
//! Resolution order: manual tag > git ancestry. A `Branch` binding is handled
//! exactly like `GitCommit`: the branch names are passed to `git` as revisions,
//! and git resolves each one to its current tip commit.
6
7use crate::core::event::SessionRecord;
8use crate::experiment::types::{Binding, Classification};
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::Path;
12use std::process::Command;
13
14/// Per-experiment manual tags keyed by `session_id` → variant.
15pub type ManualTags = HashMap<String, Classification>;
16
17/// Classify a session. Manual tag wins over git binding.
18///
19/// Pure w.r.t. `session` + `manual_tags`; shells out to `git` only when a
20/// `GitCommit` binding is in effect and no manual tag overrides.
21pub fn classify(
22    session: &SessionRecord,
23    binding: &Binding,
24    manual_tags: &ManualTags,
25    workspace: &Path,
26) -> Classification {
27    if let Some(v) = manual_tags.get(&session.id) {
28        return v.clone();
29    }
30    match binding {
31        Binding::ManualTag { .. } => Classification::Excluded,
32        Binding::GitCommit {
33            control_commit,
34            treatment_commit,
35        } => classify_git(session, control_commit, treatment_commit, workspace),
36        Binding::Branch {
37            control_branch,
38            treatment_branch,
39        } => classify_git(session, control_branch, treatment_branch, workspace),
40    }
41}
42
43fn classify_git(
44    session: &SessionRecord,
45    control_commit: &str,
46    treatment_commit: &str,
47    workspace: &Path,
48) -> Classification {
49    let Some(start) = session.start_commit.as_deref() else {
50        return Classification::Excluded;
51    };
52    let on_treatment = is_ancestor(workspace, start, treatment_commit).unwrap_or(false);
53    let on_control = is_ancestor(workspace, start, control_commit).unwrap_or(false);
54    match (on_treatment, on_control) {
55        // strictly descended from control (not yet at treatment boundary)
56        (false, true) => Classification::Control,
57        // past the treatment boundary
58        (true, false) => Classification::Treatment,
59        // straddles or unknown
60        _ => Classification::Excluded,
61    }
62}
63
/// Ask git whether `maybe_ancestor` is an ancestor of (or the same commit as)
/// `descendant` in the repository at `workspace`.
///
/// Identical revision strings short-circuit to `Ok(true)` without spawning
/// git.
///
/// # Errors
///
/// Returns an error when:
/// - either revision is empty or starts with `-`. Git would otherwise parse
///   it as a command-line option rather than a revision.
/// - `git` cannot be spawned.
/// - `git merge-base --is-ancestor` exits with anything other than 0 ("is an
///   ancestor") or 1 ("is not an ancestor"), e.g. for an unknown revision or
///   a workspace that is not a git repository.
fn is_ancestor(workspace: &Path, maybe_ancestor: &str, descendant: &str) -> Result<bool> {
    if maybe_ancestor == descendant {
        return Ok(true);
    }
    for rev in [maybe_ancestor, descendant] {
        if rev.is_empty() || rev.starts_with('-') {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("refusing to pass {rev:?} to git as a revision"),
            )
            .into());
        }
    }
    let out = Command::new("git")
        .arg("-C")
        .arg(workspace)
        .args(["merge-base", "--is-ancestor", maybe_ancestor, descendant])
        .output()?;
    match out.status.code() {
        Some(0) => Ok(true),
        Some(1) => Ok(false),
        // Any other exit code (or death by signal) means git could not
        // answer the question, which is different from "not an ancestor".
        _ => Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            format!(
                "`git merge-base --is-ancestor {maybe_ancestor} {descendant}` failed ({}): {}",
                out.status,
                String::from_utf8_lossy(&out.stderr).trim()
            ),
        )
        .into()),
    }
}
75
76/// Partition sessions by classification, preserving input order.
77pub fn partition<'a>(
78    sessions: &'a [SessionRecord],
79    binding: &Binding,
80    manual_tags: &ManualTags,
81    workspace: &Path,
82) -> (
83    Vec<&'a SessionRecord>,
84    Vec<&'a SessionRecord>,
85    Vec<&'a SessionRecord>,
86) {
87    let mut control = Vec::new();
88    let mut treatment = Vec::new();
89    let mut excluded = Vec::new();
90    for s in sessions {
91        match classify(s, binding, manual_tags, workspace) {
92            Classification::Control => control.push(s),
93            Classification::Treatment => treatment.push(s),
94            Classification::Excluded => excluded.push(s),
95        }
96    }
97    (control, treatment, excluded)
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::SessionStatus;

    /// Build a minimal finished session with the given id and start commit.
    fn mk(id: &str, commit: Option<&str>) -> SessionRecord {
        SessionRecord {
            id: id.into(),
            agent: "cursor".into(),
            model: None,
            workspace: "/ws".into(),
            started_at_ms: 0,
            ended_at_ms: None,
            status: SessionStatus::Done,
            trace_path: String::new(),
            start_commit: commit.map(Into::into),
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
        }
    }

    fn git_binding() -> Binding {
        Binding::GitCommit {
            control_commit: "c".into(),
            treatment_commit: "t".into(),
        }
    }

    fn branch_binding() -> Binding {
        Binding::Branch {
            control_branch: "main".into(),
            treatment_branch: "feature".into(),
        }
    }

    /// Session ids of a bucket, in bucket order.
    fn ids<'a>(bucket: &[&'a SessionRecord]) -> Vec<&'a str> {
        bucket.iter().map(|s| s.id.as_str()).collect()
    }

    #[test]
    fn manual_tag_beats_git_binding() {
        let s = mk("s1", Some("abc"));
        let mut tags = ManualTags::new();
        tags.insert("s1".into(), Classification::Treatment);
        let got = classify(&s, &git_binding(), &tags, Path::new("/no"));
        assert_eq!(got, Classification::Treatment);
    }

    #[test]
    fn manual_tag_beats_branch_binding() {
        let s = mk("s1", Some("abc"));
        let mut tags = ManualTags::new();
        tags.insert("s1".into(), Classification::Control);
        let got = classify(&s, &branch_binding(), &tags, Path::new("/no"));
        assert_eq!(got, Classification::Control);
    }

    #[test]
    fn no_start_commit_excludes() {
        let s = mk("s1", None);
        let tags = ManualTags::new();
        let got = classify(&s, &git_binding(), &tags, Path::new("/no"));
        assert_eq!(got, Classification::Excluded);
    }

    #[test]
    fn no_start_commit_excludes_under_branch_binding() {
        let s = mk("s1", None);
        let tags = ManualTags::new();
        let got = classify(&s, &branch_binding(), &tags, Path::new("/no"));
        assert_eq!(got, Classification::Excluded);
    }

    #[test]
    fn identical_revisions_are_ancestors_without_git() {
        // The short-circuit must not depend on a real repository.
        assert!(is_ancestor(Path::new("/no"), "abc", "abc").unwrap());
    }

    #[test]
    fn partition_splits_three_ways_preserving_order() {
        let all = vec![mk("1", None), mk("2", None), mk("3", None), mk("4", None)];
        let mut tags = ManualTags::new();
        tags.insert("1".into(), Classification::Control);
        tags.insert("2".into(), Classification::Treatment);
        tags.insert("4".into(), Classification::Control);
        let (c, t, e) = partition(&all, &git_binding(), &tags, Path::new("/no"));
        assert_eq!(ids(&c), ["1", "4"]);
        assert_eq!(ids(&t), ["2"]);
        // "3" has no tag and no start commit, so it falls out.
        assert_eq!(ids(&e), ["3"]);
    }
}
167}