Skip to main content

kaizen/experiment/
types.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Pure data for experiments. See `docs/experiments.md`.
3
4use serde::{Deserialize, Serialize};
5
6/// Variant a session falls into under a binding.
7#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
8pub enum Classification {
9    Control,
10    Treatment,
11    Excluded,
12}
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub enum Metric {
16    TokensPerSession,
17    CostPerSession,
18    SuccessRate,
19    ToolLoops,
20    DurationMinutes,
21    FilesPerSession,
22    SuccessRateByPrompt,
23    CostByPrompt,
24}
25
26#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
27pub enum Binding {
28    GitCommit {
29        control_commit: String,
30        treatment_commit: String,
31    },
32    Branch {
33        control_branch: String,
34        treatment_branch: String,
35    },
36    ManualTag {
37        variant_field: String,
38    },
39}
40
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
42pub enum Direction {
43    Decrease,
44    Increase,
45}
46
47#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
48pub enum Criterion {
49    Delta {
50        direction: Direction,
51        target_pct: f64,
52    },
53    Absolute {
54        metric_value: f64,
55    },
56}
57
58/// Lifecycle state. `Archived` is terminal.
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
60pub enum State {
61    Draft,
62    Running,
63    Concluded,
64    Archived,
65}
66
67/// Guardrail: a secondary metric that must not regress.
68///
69/// If the CI shows a regression beyond `threshold_pct` in the specified
70/// direction, the report flags the guardrail as violated.
71#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
72pub struct GuardrailSpec {
73    pub metric: Metric,
74    /// Direction of *regression* (e.g. `Increase` for cost = cost going up is bad).
75    pub regression_direction: Direction,
76    /// Flag if CI endpoint crosses this threshold.
77    pub threshold_pct: f64,
78}
79
80/// Per-guardrail result in the experiment report.
81#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
82pub struct GuardrailResult {
83    pub metric: Metric,
84    pub delta_pct: Option<f64>,
85    pub violated: bool,
86}
87
88#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
89pub struct Experiment {
90    pub id: String,
91    pub name: String,
92    pub hypothesis: String,
93    pub change_description: String,
94    pub metric: Metric,
95    pub binding: Binding,
96    pub duration_days: u32,
97    pub success_criterion: Criterion,
98    pub state: State,
99    pub created_at_ms: u64,
100    pub concluded_at_ms: Option<u64>,
101    #[serde(default)]
102    pub guardrails: Vec<GuardrailSpec>,
103}
104
105impl Metric {
106    pub fn as_str(&self) -> &'static str {
107        match self {
108            Metric::TokensPerSession => "tokens_per_session",
109            Metric::CostPerSession => "cost_per_session",
110            Metric::SuccessRate => "success_rate",
111            Metric::ToolLoops => "tool_loops",
112            Metric::DurationMinutes => "duration_minutes",
113            Metric::FilesPerSession => "files_per_session",
114            Metric::SuccessRateByPrompt => "success_rate_by_prompt",
115            Metric::CostByPrompt => "cost_by_prompt",
116        }
117    }
118
119    pub fn parse(s: &str) -> Option<Metric> {
120        Some(match s {
121            "tokens_per_session" => Metric::TokensPerSession,
122            "cost_per_session" => Metric::CostPerSession,
123            "success_rate" => Metric::SuccessRate,
124            "tool_loops" => Metric::ToolLoops,
125            "duration_minutes" => Metric::DurationMinutes,
126            "files_per_session" => Metric::FilesPerSession,
127            "success_rate_by_prompt" => Metric::SuccessRateByPrompt,
128            "cost_by_prompt" => Metric::CostByPrompt,
129            _ => return None,
130        })
131    }
132}
133
134/// Pure state-machine transition. Returns `Some(next)` when `action` is enabled.
135pub fn transition(state: State, action: &str) -> Option<State> {
136    Some(match (state, action) {
137        (State::Draft, "start") => State::Running,
138        (State::Running, "conclude") => State::Concluded,
139        (State::Concluded, "archive") => State::Archived,
140        _ => return None,
141    })
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Metric` variant. Must stay complete; a missing variant would
    /// leave its string mapping untested.
    const ALL_METRICS: [Metric; 8] = [
        Metric::TokensPerSession,
        Metric::CostPerSession,
        Metric::SuccessRate,
        Metric::ToolLoops,
        Metric::DurationMinutes,
        Metric::FilesPerSession,
        Metric::SuccessRateByPrompt,
        Metric::CostByPrompt,
    ];

    #[test]
    fn transitions_follow_spec_order() {
        assert_eq!(transition(State::Draft, "start"), Some(State::Running));
        assert_eq!(
            transition(State::Running, "conclude"),
            Some(State::Concluded)
        );
        assert_eq!(
            transition(State::Concluded, "archive"),
            Some(State::Archived)
        );
    }

    #[test]
    fn archived_is_terminal() {
        assert_eq!(transition(State::Archived, "start"), None);
        assert_eq!(transition(State::Archived, "conclude"), None);
        assert_eq!(transition(State::Archived, "archive"), None);
    }

    #[test]
    fn no_backward_transitions() {
        assert_eq!(transition(State::Concluded, "start"), None);
        assert_eq!(transition(State::Running, "archive"), None);
    }

    #[test]
    fn no_skipped_or_repeated_transitions() {
        // Skipping a stage is not allowed.
        assert_eq!(transition(State::Draft, "conclude"), None);
        assert_eq!(transition(State::Draft, "archive"), None);
        // Re-applying the action that led to the current state is not allowed.
        assert_eq!(transition(State::Running, "start"), None);
        assert_eq!(transition(State::Concluded, "conclude"), None);
    }

    #[test]
    fn unknown_actions_are_rejected() {
        // Action names are matched exactly.
        assert_eq!(transition(State::Draft, "Start"), None);
        assert_eq!(transition(State::Draft, ""), None);
        assert_eq!(transition(State::Running, "stop"), None);
    }

    #[test]
    fn metric_round_trip() {
        for m in ALL_METRICS {
            assert_eq!(Metric::parse(m.as_str()), Some(m));
        }
    }

    #[test]
    fn metric_strings_are_unique() {
        let names: std::collections::HashSet<&str> =
            ALL_METRICS.iter().map(|m| m.as_str()).collect();
        assert_eq!(names.len(), ALL_METRICS.len());
    }

    #[test]
    fn metric_parse_rejects_unknown() {
        assert_eq!(Metric::parse(""), None);
        assert_eq!(Metric::parse("TokensPerSession"), None);
        assert_eq!(Metric::parse("tokens-per-session"), None);
    }
}