1use crate::core::event::SessionRecord;
7use crate::experiment::types::{Binding, Classification};
8use anyhow::Result;
9use std::collections::HashMap;
10use std::path::Path;
11use std::process::Command;
12
13pub type ManualTags = HashMap<String, Classification>;
15
16pub fn classify(
21 session: &SessionRecord,
22 binding: &Binding,
23 manual_tags: &ManualTags,
24 workspace: &Path,
25) -> Classification {
26 if let Some(v) = manual_tags.get(&session.id) {
27 return v.clone();
28 }
29 match binding {
30 Binding::ManualTag { .. } => Classification::Excluded,
31 Binding::GitCommit {
32 control_commit,
33 treatment_commit,
34 } => classify_git(session, control_commit, treatment_commit, workspace),
35 Binding::Branch {
36 control_branch,
37 treatment_branch,
38 } => classify_git(session, control_branch, treatment_branch, workspace),
39 Binding::PromptFingerprint {
40 control_fingerprint,
41 treatment_fingerprint,
42 } => classify_prompt(session, control_fingerprint, treatment_fingerprint),
43 }
44}
45
46fn classify_prompt(session: &SessionRecord, control: &str, treatment: &str) -> Classification {
47 match session.prompt_fingerprint.as_deref() {
48 Some(fp) if fp == control => Classification::Control,
49 Some(fp) if fp == treatment => Classification::Treatment,
50 _ => Classification::Excluded,
51 }
52}
53
54fn classify_git(
55 session: &SessionRecord,
56 control_commit: &str,
57 treatment_commit: &str,
58 workspace: &Path,
59) -> Classification {
60 let Some(start) = session.start_commit.as_deref() else {
61 return Classification::Excluded;
62 };
63 let on_treatment = is_ancestor(workspace, start, treatment_commit).unwrap_or(false);
64 let on_control = is_ancestor(workspace, start, control_commit).unwrap_or(false);
65 match (on_treatment, on_control) {
66 (false, true) => Classification::Control,
68 (true, false) => Classification::Treatment,
70 _ => Classification::Excluded,
72 }
73}
74
75fn is_ancestor(workspace: &Path, maybe_ancestor: &str, descendant: &str) -> Result<bool> {
76 if maybe_ancestor == descendant {
77 return Ok(true);
78 }
79 let out = Command::new("git")
80 .arg("-C")
81 .arg(workspace)
82 .args(["merge-base", "--is-ancestor", maybe_ancestor, descendant])
83 .output()?;
84 Ok(out.status.success())
85}
86
87pub fn partition<'a>(
89 sessions: &'a [SessionRecord],
90 binding: &Binding,
91 manual_tags: &ManualTags,
92 workspace: &Path,
93) -> (
94 Vec<&'a SessionRecord>,
95 Vec<&'a SessionRecord>,
96 Vec<&'a SessionRecord>,
97) {
98 let mut control = Vec::new();
99 let mut treatment = Vec::new();
100 let mut excluded = Vec::new();
101 for s in sessions {
102 match classify(s, binding, manual_tags, workspace) {
103 Classification::Control => control.push(s),
104 Classification::Treatment => treatment.push(s),
105 Classification::Excluded => excluded.push(s),
106 }
107 }
108 (control, treatment, excluded)
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::event::SessionStatus;

    /// Build a minimal finished session with the given id and optional start commit.
    fn mk(id: &str, commit: Option<&str>) -> SessionRecord {
        SessionRecord {
            id: id.into(),
            agent: "cursor".into(),
            model: None,
            workspace: "/ws".into(),
            started_at_ms: 0,
            ended_at_ms: None,
            status: SessionStatus::Done,
            trace_path: String::new(),
            start_commit: commit.map(Into::into),
            end_commit: None,
            branch: None,
            dirty_start: None,
            dirty_end: None,
            repo_binding_source: None,
            prompt_fingerprint: None,
            parent_session_id: None,
            agent_version: None,
            os: None,
            arch: None,
            repo_file_count: None,
            repo_total_loc: None,
        }
    }

    #[test]
    fn manual_tag_beats_git_binding() {
        let s = mk("s1", Some("abc"));
        let binding = Binding::GitCommit {
            control_commit: "c".into(),
            treatment_commit: "t".into(),
        };
        let mut tags = ManualTags::new();
        tags.insert("s1".into(), Classification::Treatment);
        let got = classify(&s, &binding, &tags, Path::new("/no"));
        assert_eq!(got, Classification::Treatment);
    }

    #[test]
    fn no_start_commit_excludes() {
        let s = mk("s1", None);
        let binding = Binding::GitCommit {
            control_commit: "c".into(),
            treatment_commit: "t".into(),
        };
        let tags = ManualTags::new();
        let got = classify(&s, &binding, &tags, Path::new("/no"));
        assert_eq!(got, Classification::Excluded);
    }

    #[test]
    fn prompt_fingerprint_classifies_exact_matches() {
        let mut s = mk("s1", None);
        s.prompt_fingerprint = Some("fp-b".into());
        let binding = Binding::PromptFingerprint {
            control_fingerprint: "fp-a".into(),
            treatment_fingerprint: "fp-b".into(),
        };
        let got = classify(&s, &binding, &ManualTags::new(), Path::new("/no"));
        assert_eq!(got, Classification::Treatment);
    }

    #[test]
    fn prompt_fingerprint_control_unmatched_and_missing() {
        let binding = Binding::PromptFingerprint {
            control_fingerprint: "fp-a".into(),
            treatment_fingerprint: "fp-b".into(),
        };
        let tags = ManualTags::new();
        let ws = Path::new("/no");

        let mut control = mk("c", None);
        control.prompt_fingerprint = Some("fp-a".into());
        assert_eq!(classify(&control, &binding, &tags, ws), Classification::Control);

        let mut other = mk("o", None);
        other.prompt_fingerprint = Some("fp-z".into());
        assert_eq!(classify(&other, &binding, &tags, ws), Classification::Excluded);

        // No fingerprint recorded at all.
        let missing = mk("m", None);
        assert_eq!(classify(&missing, &binding, &tags, ws), Classification::Excluded);
    }

    #[test]
    fn identical_refs_are_ancestors_without_git() {
        // The identical-string short-circuit must not depend on a repository existing.
        assert!(is_ancestor(Path::new("/no"), "abc", "abc").unwrap());
    }

    #[test]
    fn partition_splits_three_ways() {
        let s1 = mk("1", None);
        let s2 = mk("2", None);
        let s3 = mk("3", None);
        let all = vec![s1, s2, s3];
        let binding = Binding::GitCommit {
            control_commit: "c".into(),
            treatment_commit: "t".into(),
        };
        let mut tags = ManualTags::new();
        tags.insert("1".into(), Classification::Control);
        tags.insert("2".into(), Classification::Treatment);
        let (c, t, e) = partition(&all, &binding, &tags, Path::new("/no"));
        assert_eq!(c.len(), 1);
        assert_eq!(t.len(), 1);
        assert_eq!(e.len(), 1);
        // Sizes alone would not catch sessions landing in the wrong group.
        assert_eq!(c[0].id, "1");
        assert_eq!(t[0].id, "2");
        assert_eq!(e[0].id, "3");
    }
}