Skip to main content

kaizen/experiment/
store.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Persistence for experiments. IO at boundary; pure types in `types.rs`.
3
4use crate::experiment::binding::ManualTags;
5use crate::experiment::types::{Classification, Experiment, State};
6use crate::store::Store;
7use anyhow::{Context, Result};
8use rusqlite::{OptionalExtension, params};
9
10pub fn save_experiment(store: &Store, exp: &Experiment) -> Result<()> {
11    let json = serde_json::to_string(exp).context("serialize experiment")?;
12    store.conn().execute(
13        "INSERT INTO experiments (id, name, created_at_ms, metadata, state, concluded_at_ms)
14         VALUES (?1, ?2, ?3, ?4, ?5, ?6)
15         ON CONFLICT(id) DO UPDATE SET
16           name=excluded.name,
17           metadata=excluded.metadata,
18           state=excluded.state,
19           concluded_at_ms=excluded.concluded_at_ms",
20        params![
21            exp.id,
22            exp.name,
23            exp.created_at_ms as i64,
24            json,
25            format!("{:?}", exp.state),
26            exp.concluded_at_ms.map(|v| v as i64),
27        ],
28    )?;
29    Ok(())
30}
31
32pub fn load_experiment(store: &Store, id: &str) -> Result<Option<Experiment>> {
33    let row: Option<String> = store
34        .conn()
35        .query_row(
36            "SELECT metadata FROM experiments WHERE id = ?1",
37            params![id],
38            |r| r.get(0),
39        )
40        .optional()?;
41    match row {
42        Some(s) => Ok(Some(serde_json::from_str(&s)?)),
43        None => Ok(None),
44    }
45}
46
47pub fn list_experiments(store: &Store) -> Result<Vec<Experiment>> {
48    let mut stmt = store
49        .conn()
50        .prepare("SELECT metadata FROM experiments ORDER BY created_at_ms DESC")?;
51    let rows = stmt.query_map([], |r| r.get::<_, String>(0))?;
52    let mut out = Vec::new();
53    for row in rows {
54        let s = row?;
55        if let Ok(e) = serde_json::from_str::<Experiment>(&s) {
56            out.push(e);
57        }
58    }
59    Ok(out)
60}
61
62pub fn set_state(store: &Store, id: &str, state: State, now_ms: u64) -> Result<()> {
63    let Some(mut exp) = load_experiment(store, id)? else {
64        anyhow::bail!("experiment not found: {id}");
65    };
66    exp.state = state;
67    if matches!(state, State::Concluded) {
68        exp.concluded_at_ms = Some(now_ms);
69    }
70    save_experiment(store, &exp)
71}
72
73/// Tag a session with a variant.
74///
75/// Idempotent when the same variant is supplied. Returns `Err` when the session
76/// already carries a *different* variant — the caller must resolve the conflict
77/// rather than silently overwrite.
78pub fn tag_session(
79    store: &Store,
80    exp_id: &str,
81    session_id: &str,
82    variant: Classification,
83) -> Result<()> {
84    let existing: Option<String> = store
85        .conn()
86        .query_row(
87            "SELECT variant FROM experiment_tags WHERE experiment_id=?1 AND session_id=?2",
88            params![exp_id, session_id],
89            |r| r.get(0),
90        )
91        .optional()?;
92    if let Some(prev) = existing {
93        let prev_cls = parse_variant(&prev);
94        if prev_cls != variant {
95            anyhow::bail!(
96                "variant conflict: session {session_id} already tagged as {prev} \
97                 for experiment {exp_id}; cannot retag as {:?}",
98                variant
99            );
100        }
101        return Ok(());
102    }
103    store.conn().execute(
104        "INSERT INTO experiment_tags (experiment_id, session_id, variant) VALUES (?1, ?2, ?3)",
105        params![exp_id, session_id, format!("{:?}", variant)],
106    )?;
107    Ok(())
108}
109
110fn parse_variant(s: &str) -> Classification {
111    match s {
112        "Control" => Classification::Control,
113        "Treatment" => Classification::Treatment,
114        _ => Classification::Excluded,
115    }
116}
117
118pub fn manual_tags(store: &Store, exp_id: &str) -> Result<ManualTags> {
119    let mut stmt = store
120        .conn()
121        .prepare("SELECT session_id, variant FROM experiment_tags WHERE experiment_id = ?1")?;
122    let rows = stmt.query_map(params![exp_id], |r| {
123        Ok((r.get::<_, String>(0)?, r.get::<_, String>(1)?))
124    })?;
125    let mut out = ManualTags::new();
126    for row in rows {
127        let (sid, variant) = row?;
128        out.insert(sid, parse_variant(&variant));
129    }
130    Ok(out)
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::experiment::types::{Binding, Criterion, Direction, Metric, State};
    use tempfile::TempDir;

    /// Build a draft experiment fixture with the given id.
    fn mk(id: &str) -> Experiment {
        let binding = Binding::GitCommit {
            control_commit: "c1".into(),
            treatment_commit: "c2".into(),
        };
        let success_criterion = Criterion::Delta {
            direction: Direction::Decrease,
            target_pct: 10.0,
        };
        Experiment {
            id: id.into(),
            name: format!("exp-{id}"),
            hypothesis: "h".into(),
            change_description: "c".into(),
            metric: Metric::TokensPerSession,
            binding,
            duration_days: 14,
            success_criterion,
            state: State::Draft,
            created_at_ms: 1000,
            concluded_at_ms: None,
            guardrails: Vec::new(),
        }
    }

    /// Open a store on `k.db` inside a fresh temporary directory.
    ///
    /// The `TempDir` is handed back so the directory stays alive for as long
    /// as the caller keeps it bound.
    fn open_temp_store() -> (TempDir, Store) {
        let dir = TempDir::new().unwrap();
        let store = Store::open(&dir.path().join("k.db")).unwrap();
        (dir, store)
    }

    /// A saved experiment can be loaded back by id.
    #[test]
    fn round_trip_save_load() {
        let (_dir, store) = open_temp_store();
        let original = mk("a");
        save_experiment(&store, &original).unwrap();
        let loaded = load_experiment(&store, "a").unwrap().unwrap();
        assert_eq!(loaded.id, "a");
        assert_eq!(loaded.state, State::Draft);
    }

    /// `set_state` persists each transition and stamps the conclusion time.
    #[test]
    fn set_state_transitions() {
        let (_dir, store) = open_temp_store();
        save_experiment(&store, &mk("b")).unwrap();

        set_state(&store, "b", State::Running, 5_000).unwrap();
        let running = load_experiment(&store, "b").unwrap().unwrap();
        assert_eq!(running.state, State::Running);

        set_state(&store, "b", State::Concluded, 9_000).unwrap();
        let concluded = load_experiment(&store, "b").unwrap().unwrap();
        assert_eq!(concluded.state, State::Concluded);
        assert_eq!(concluded.concluded_at_ms, Some(9_000));
    }

    /// Tags written for different sessions are read back unchanged.
    #[test]
    fn tags_round_trip() {
        let (_dir, store) = open_temp_store();
        save_experiment(&store, &mk("e")).unwrap();
        tag_session(&store, "e", "s1", Classification::Treatment).unwrap();
        tag_session(&store, "e", "s2", Classification::Control).unwrap();
        let tags = manual_tags(&store, "e").unwrap();
        assert_eq!(tags.get("s1"), Some(&Classification::Treatment));
        assert_eq!(tags.get("s2"), Some(&Classification::Control));
    }

    /// Re-tagging a session with the variant it already has succeeds.
    #[test]
    fn tag_same_variant_is_idempotent() {
        let (_dir, store) = open_temp_store();
        save_experiment(&store, &mk("idem")).unwrap();
        // Both calls must succeed: the second one repeats the stored variant.
        for _ in 0..2 {
            tag_session(&store, "idem", "s1", Classification::Control).unwrap();
        }
        let tags = manual_tags(&store, "idem").unwrap();
        assert_eq!(tags.get("s1"), Some(&Classification::Control));
    }

    /// Re-tagging a session with another variant is rejected.
    #[test]
    fn tag_different_variant_is_error() {
        let (_dir, store) = open_temp_store();
        save_experiment(&store, &mk("conflict")).unwrap();
        tag_session(&store, "conflict", "s1", Classification::Control).unwrap();
        let err = tag_session(&store, "conflict", "s1", Classification::Treatment).unwrap_err();
        assert!(
            err.to_string().contains("variant conflict"),
            "expected variant conflict, got: {err}"
        );
    }

    /// Several threads tagging the same session leave it tagged as Treatment.
    ///
    /// NOTE(review): the final check goes through `manual_tags`, which is
    /// looked up by session id, so it presumably cannot tell one stored row
    /// from several — confirm whether a row count should be asserted too.
    #[test]
    fn concurrent_tag_produces_one_row() {
        use std::sync::Arc;
        use std::thread;

        let dir = TempDir::new().unwrap();
        let path = Arc::new(dir.path().join("k.db"));
        {
            // Seed the experiment, then close this store before the threads start.
            let seed = Store::open(&path).unwrap();
            save_experiment(&seed, &mk("concur")).unwrap();
        }

        // 8 threads, each with its own store, tag the same session as Treatment.
        let mut workers = Vec::with_capacity(8);
        for _ in 0..8 {
            let p = Arc::clone(&path);
            workers.push(thread::spawn(move || {
                let s = Store::open(&p).unwrap();
                tag_session(&s, "concur", "sess", Classification::Treatment)
            }));
        }

        let mut ok_count = 0;
        for worker in workers {
            if worker.join().unwrap().is_ok() {
                ok_count += 1;
            }
        }
        assert!(ok_count >= 1, "at least one thread must succeed");

        let store2 = Store::open(&path).unwrap();
        let tags = manual_tags(&store2, "concur").unwrap();
        assert_eq!(
            tags.get("sess"),
            Some(&Classification::Treatment),
            "exactly one row, correct variant"
        );
    }
}
259}