1use crate::core::event::SessionRecord;
8use crate::experiment::types::{Binding, Classification};
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::Path;
12use std::process::Command;
13
/// Per-session classification overrides, keyed by session id.
///
/// When `classify` finds an entry for a session's id here, that value is
/// returned as-is, ahead of any binding-based classification.
pub type ManualTags = HashMap<String, Classification>;
16
17pub fn classify(
22 session: &SessionRecord,
23 binding: &Binding,
24 manual_tags: &ManualTags,
25 workspace: &Path,
26) -> Classification {
27 if let Some(v) = manual_tags.get(&session.id) {
28 return v.clone();
29 }
30 match binding {
31 Binding::ManualTag { .. } => Classification::Excluded,
32 Binding::GitCommit {
33 control_commit,
34 treatment_commit,
35 } => classify_git(session, control_commit, treatment_commit, workspace),
36 Binding::Branch {
37 control_branch,
38 treatment_branch,
39 } => classify_git(session, control_branch, treatment_branch, workspace),
40 }
41}
42
43fn classify_git(
44 session: &SessionRecord,
45 control_commit: &str,
46 treatment_commit: &str,
47 workspace: &Path,
48) -> Classification {
49 let Some(start) = session.start_commit.as_deref() else {
50 return Classification::Excluded;
51 };
52 let on_treatment = is_ancestor(workspace, start, treatment_commit).unwrap_or(false);
53 let on_control = is_ancestor(workspace, start, control_commit).unwrap_or(false);
54 match (on_treatment, on_control) {
55 (false, true) => Classification::Control,
57 (true, false) => Classification::Treatment,
59 _ => Classification::Excluded,
61 }
62}
63
/// Reports whether `maybe_ancestor` is an ancestor of (or the same revision as)
/// `descendant` in the git repository at `workspace`.
///
/// Identical revision strings short-circuit to `true` without invoking git.
/// Otherwise this runs `git -C <workspace> merge-base --is-ancestor`, which
/// exits 0 for "yes" and 1 for "no".
///
/// # Errors
///
/// Returns an error when:
/// - either revision begins with `-`, since git would parse it as an option
///   rather than a revision;
/// - `git` cannot be spawned;
/// - git exits with any status other than 0 or 1, e.g. for an unknown revision
///   or a workspace that is not a repository.
///
/// Such failures mean ancestry could not be determined. That is deliberately
/// distinct from `Ok(false)`.
fn is_ancestor(workspace: &Path, maybe_ancestor: &str, descendant: &str) -> Result<bool> {
    // Built as `io::Error` so it converts into whatever error type `Result` uses here.
    let failure =
        |msg: String| std::io::Error::new(std::io::ErrorKind::Other, msg);

    for rev in [maybe_ancestor, descendant] {
        if rev.starts_with('-') {
            return Err(failure(format!("refusing option-like git revision {rev:?}")).into());
        }
    }
    if maybe_ancestor == descendant {
        return Ok(true);
    }
    let out = Command::new("git")
        .arg("-C")
        .arg(workspace)
        .args(["merge-base", "--is-ancestor", maybe_ancestor, descendant])
        .output()?;
    match out.status.code() {
        Some(0) => Ok(true),
        Some(1) => Ok(false),
        // Any other status (or death by signal) is a git error, not a "no".
        _ => Err(failure(format!(
            "git merge-base --is-ancestor {maybe_ancestor} {descendant} failed ({}): {}",
            out.status,
            String::from_utf8_lossy(&out.stderr).trim()
        ))
        .into()),
    }
}
75
76pub fn partition<'a>(
78 sessions: &'a [SessionRecord],
79 binding: &Binding,
80 manual_tags: &ManualTags,
81 workspace: &Path,
82) -> (
83 Vec<&'a SessionRecord>,
84 Vec<&'a SessionRecord>,
85 Vec<&'a SessionRecord>,
86) {
87 let mut control = Vec::new();
88 let mut treatment = Vec::new();
89 let mut excluded = Vec::new();
90 for s in sessions {
91 match classify(s, binding, manual_tags, workspace) {
92 Classification::Control => control.push(s),
93 Classification::Treatment => treatment.push(s),
94 Classification::Excluded => excluded.push(s),
95 }
96 }
97 (control, treatment, excluded)
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::SessionStatus;

    /// Workspace path for tests; none of these cases ever reaches git.
    const NO_REPO: &str = "/no";

    /// Builds a completed session fixture with the given id and optional start commit.
    fn session(id: &str, start_commit: Option<&str>) -> SessionRecord {
        SessionRecord {
            id: id.into(),
            agent: "cursor".into(),
            model: None,
            workspace: "/ws".into(),
            started_at_ms: 0,
            ended_at_ms: None,
            status: SessionStatus::Done,
            trace_path: String::new(),
            start_commit: start_commit.map(Into::into),
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
            prompt_fingerprint: None,
            parent_session_id: None,
            agent_version: None,
            os: None,
            arch: None,
            repo_file_count: None,
            repo_total_loc: None,
        }
    }

    /// Commit binding shared by every case. Each case is settled by a manual tag
    /// or a missing start commit, so these refs are never resolved.
    fn commit_binding() -> Binding {
        Binding::GitCommit {
            control_commit: "c".into(),
            treatment_commit: "t".into(),
        }
    }

    #[test]
    fn manual_tag_beats_git_binding() {
        let tags = ManualTags::from([("s1".to_string(), Classification::Treatment)]);
        let got = classify(
            &session("s1", Some("abc")),
            &commit_binding(),
            &tags,
            Path::new(NO_REPO),
        );
        assert_eq!(got, Classification::Treatment);
    }

    #[test]
    fn no_start_commit_excludes() {
        let got = classify(
            &session("s1", None),
            &commit_binding(),
            &ManualTags::new(),
            Path::new(NO_REPO),
        );
        assert_eq!(got, Classification::Excluded);
    }

    #[test]
    fn partition_splits_three_ways() {
        let all: Vec<SessionRecord> = ["1", "2", "3"].iter().map(|id| session(id, None)).collect();
        let tags = ManualTags::from([
            ("1".to_string(), Classification::Control),
            ("2".to_string(), Classification::Treatment),
        ]);
        let (control, treatment, excluded) =
            partition(&all, &commit_binding(), &tags, Path::new(NO_REPO));
        assert_eq!((control.len(), treatment.len(), excluded.len()), (1, 1, 1));
    }
}