1use crate::experiment::binding::ManualTags;
5use crate::experiment::types::{Classification, Experiment, State};
6use crate::store::Store;
7use anyhow::{Context, Result};
8use rusqlite::{OptionalExtension, params};
9
10pub fn save_experiment(store: &Store, exp: &Experiment) -> Result<()> {
11 let json = serde_json::to_string(exp).context("serialize experiment")?;
12 store.conn().execute(
13 "INSERT INTO experiments (id, name, created_at_ms, metadata, state, concluded_at_ms)
14 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
15 ON CONFLICT(id) DO UPDATE SET
16 name=excluded.name,
17 metadata=excluded.metadata,
18 state=excluded.state,
19 concluded_at_ms=excluded.concluded_at_ms",
20 params![
21 exp.id,
22 exp.name,
23 exp.created_at_ms as i64,
24 json,
25 format!("{:?}", exp.state),
26 exp.concluded_at_ms.map(|v| v as i64),
27 ],
28 )?;
29 Ok(())
30}
31
32pub fn load_experiment(store: &Store, id: &str) -> Result<Option<Experiment>> {
33 let row: Option<String> = store
34 .conn()
35 .query_row(
36 "SELECT metadata FROM experiments WHERE id = ?1",
37 params![id],
38 |r| r.get(0),
39 )
40 .optional()?;
41 match row {
42 Some(s) => Ok(Some(serde_json::from_str(&s)?)),
43 None => Ok(None),
44 }
45}
46
47pub fn list_experiments(store: &Store) -> Result<Vec<Experiment>> {
48 let mut stmt = store
49 .conn()
50 .prepare("SELECT metadata FROM experiments ORDER BY created_at_ms DESC")?;
51 let rows = stmt.query_map([], |r| r.get::<_, String>(0))?;
52 let mut out = Vec::new();
53 for row in rows {
54 let s = row?;
55 if let Ok(e) = serde_json::from_str::<Experiment>(&s) {
56 out.push(e);
57 }
58 }
59 Ok(out)
60}
61
62pub fn set_state(store: &Store, id: &str, state: State, now_ms: u64) -> Result<()> {
63 let Some(mut exp) = load_experiment(store, id)? else {
64 anyhow::bail!("experiment not found: {id}");
65 };
66 exp.state = state;
67 if matches!(state, State::Concluded) {
68 exp.concluded_at_ms = Some(now_ms);
69 }
70 save_experiment(store, &exp)
71}
72
73pub fn tag_session(
79 store: &Store,
80 exp_id: &str,
81 session_id: &str,
82 variant: Classification,
83) -> Result<()> {
84 let existing: Option<String> = store
85 .conn()
86 .query_row(
87 "SELECT variant FROM experiment_tags WHERE experiment_id=?1 AND session_id=?2",
88 params![exp_id, session_id],
89 |r| r.get(0),
90 )
91 .optional()?;
92 if let Some(prev) = existing {
93 let prev_cls = parse_variant(&prev);
94 if prev_cls != variant {
95 anyhow::bail!(
96 "variant conflict: session {session_id} already tagged as {prev} \
97 for experiment {exp_id}; cannot retag as {:?}",
98 variant
99 );
100 }
101 return Ok(());
102 }
103 store.conn().execute(
104 "INSERT INTO experiment_tags (experiment_id, session_id, variant) VALUES (?1, ?2, ?3)",
105 params![exp_id, session_id, format!("{:?}", variant)],
106 )?;
107 Ok(())
108}
109
110fn parse_variant(s: &str) -> Classification {
111 match s {
112 "Control" => Classification::Control,
113 "Treatment" => Classification::Treatment,
114 _ => Classification::Excluded,
115 }
116}
117
118pub fn manual_tags(store: &Store, exp_id: &str) -> Result<ManualTags> {
119 let mut stmt = store
120 .conn()
121 .prepare("SELECT session_id, variant FROM experiment_tags WHERE experiment_id = ?1")?;
122 let rows = stmt.query_map(params![exp_id], |r| {
123 Ok((r.get::<_, String>(0)?, r.get::<_, String>(1)?))
124 })?;
125 let mut out = ManualTags::new();
126 for row in rows {
127 let (sid, variant) = row?;
128 out.insert(sid, parse_variant(&variant));
129 }
130 Ok(out)
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::experiment::types::{Binding, Criterion, Direction, Metric, State};
    use tempfile::TempDir;

    /// Builds a draft experiment fixture with the given id and fixed placeholder fields.
    fn mk(id: &str) -> Experiment {
        Experiment {
            id: id.into(),
            name: format!("exp-{id}"),
            hypothesis: "h".into(),
            change_description: "c".into(),
            metric: Metric::TokensPerSession,
            binding: Binding::GitCommit {
                control_commit: "c1".into(),
                treatment_commit: "c2".into(),
            },
            duration_days: 14,
            success_criterion: Criterion::Delta {
                direction: Direction::Decrease,
                target_pct: 10.0,
            },
            state: State::Draft,
            created_at_ms: 1000,
            concluded_at_ms: None,
            guardrails: Vec::new(),
        }
    }

    /// Opens a store on `k.db` inside a fresh temporary directory.
    ///
    /// The directory guard is returned alongside the store so the caller keeps it alive for
    /// as long as the store is in use.
    fn temp_store() -> (TempDir, Store) {
        let dir = TempDir::new().unwrap();
        let store = Store::open(&dir.path().join("k.db")).unwrap();
        (dir, store)
    }

    /// A saved experiment can be loaded back with its id and state intact.
    #[test]
    fn round_trip_save_load() {
        let (_dir, store) = temp_store();
        let original = mk("a");
        save_experiment(&store, &original).unwrap();

        let loaded = load_experiment(&store, "a").unwrap().unwrap();
        assert_eq!(loaded.id, "a");
        assert_eq!(loaded.state, State::Draft);
    }

    /// State changes are persisted, and concluding records the supplied timestamp.
    #[test]
    fn set_state_transitions() {
        let (_dir, store) = temp_store();
        save_experiment(&store, &mk("b")).unwrap();

        set_state(&store, "b", State::Running, 5_000).unwrap();
        let running = load_experiment(&store, "b").unwrap().unwrap();
        assert_eq!(running.state, State::Running);

        set_state(&store, "b", State::Concluded, 9_000).unwrap();
        let concluded = load_experiment(&store, "b").unwrap().unwrap();
        assert_eq!(concluded.state, State::Concluded);
        assert_eq!(concluded.concluded_at_ms, Some(9_000));
    }

    /// Tags written for different sessions are returned with their own variants.
    #[test]
    fn tags_round_trip() {
        let (_dir, store) = temp_store();
        save_experiment(&store, &mk("e")).unwrap();
        tag_session(&store, "e", "s1", Classification::Treatment).unwrap();
        tag_session(&store, "e", "s2", Classification::Control).unwrap();

        let tags = manual_tags(&store, "e").unwrap();
        assert_eq!(tags.get("s1"), Some(&Classification::Treatment));
        assert_eq!(tags.get("s2"), Some(&Classification::Control));
    }

    /// Tagging a session twice with the same variant succeeds both times.
    #[test]
    fn tag_same_variant_is_idempotent() {
        let (_dir, store) = temp_store();
        save_experiment(&store, &mk("idem")).unwrap();
        tag_session(&store, "idem", "s1", Classification::Control).unwrap();
        tag_session(&store, "idem", "s1", Classification::Control).unwrap();

        let tags = manual_tags(&store, "idem").unwrap();
        assert_eq!(tags.get("s1"), Some(&Classification::Control));
    }

    /// Re-tagging a session with a different variant is rejected as a conflict.
    #[test]
    fn tag_different_variant_is_error() {
        let (_dir, store) = temp_store();
        save_experiment(&store, &mk("conflict")).unwrap();
        tag_session(&store, "conflict", "s1", Classification::Control).unwrap();

        let err = tag_session(&store, "conflict", "s1", Classification::Treatment).unwrap_err();
        assert!(
            err.to_string().contains("variant conflict"),
            "expected variant conflict, got: {err}"
        );
    }

    /// Eight threads, each with its own store handle, tag the same session concurrently.
    ///
    /// NOTE(review): the final assertion only inspects the value returned by `manual_tags`;
    /// it does not count rows in `experiment_tags`, so "exactly one row" is not verified here.
    #[test]
    fn concurrent_tag_produces_one_row() {
        use std::sync::Arc;
        use std::thread;

        let dir = TempDir::new().unwrap();
        let db_path = dir.path().join("k.db");
        {
            // Seed the experiment, then release this handle before the workers start.
            let seed = Store::open(&db_path).unwrap();
            save_experiment(&seed, &mk("concur")).unwrap();
        }

        let path = Arc::new(db_path);
        let mut workers = Vec::new();
        for _ in 0..8 {
            let worker_path = Arc::clone(&path);
            workers.push(thread::spawn(move || {
                let store = Store::open(&worker_path).unwrap();
                tag_session(&store, "concur", "sess", Classification::Treatment)
            }));
        }

        let outcomes: Vec<_> = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .collect();
        let ok_count = outcomes.iter().filter(|outcome| outcome.is_ok()).count();
        assert!(ok_count >= 1, "at least one thread must succeed");

        let verify = Store::open(&path).unwrap();
        let tags = manual_tags(&verify, "concur").unwrap();
        assert_eq!(
            tags.get("sess"),
            Some(&Classification::Treatment),
            "exactly one row, correct variant"
        );
    }
}