Skip to main content

kaizen/experiment/
types.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Pure data for experiments. See `docs/experiments.md`.
3
4use serde::{Deserialize, Serialize};
5
/// Variant a session falls into under a binding.
///
/// Unlike the other field-less enums in this file (`Metric`, `Direction`,
/// `State`), this type does not derive `Copy`; values are moved or cloned
/// explicitly.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Classification {
    /// Session belongs to the control arm.
    Control,
    /// Session belongs to the treatment arm.
    Treatment,
    /// Session belongs to neither arm.
    /// NOTE(review): the exclusion rules are not defined in this file —
    /// confirm against the code that performs the classification.
    Excluded,
}
13
/// Quantity that an experiment or guardrail measures.
///
/// Used as the primary metric of an `Experiment` and as the watched metric of
/// a `GuardrailSpec` / `GuardrailResult`.
///
/// Two textual forms exist for each variant:
/// * `Metric::as_str` / `Metric::parse` use snake_case names such as
///   `"tokens_per_session"`.
/// * The serde derives carry no rename attribute here, so serialization uses
///   the variant identifiers as written (e.g. `TokensPerSession`).
///
/// NOTE(review): how each metric is computed is defined outside this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Metric {
    /// Text name: `"tokens_per_session"`.
    TokensPerSession,
    /// Text name: `"cost_per_session"`.
    CostPerSession,
    /// Text name: `"success_rate"`.
    SuccessRate,
    /// Text name: `"tool_loops"`.
    ToolLoops,
    /// Text name: `"duration_minutes"`.
    DurationMinutes,
    /// Text name: `"files_per_session"`.
    FilesPerSession,
    /// Text name: `"success_rate_by_prompt"`.
    SuccessRateByPrompt,
    /// Text name: `"cost_by_prompt"`.
    CostByPrompt,
}
25
/// Rule that ties sessions to the arms of an experiment.
///
/// Three variants carry one identifier for the control arm and one for the
/// treatment arm; `ManualTag` carries a single field name instead.
/// NOTE(review): the matching logic that turns a binding into a
/// `Classification` lives outside this file — the per-field notes below are
/// inferred from field names and should be confirmed there.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Binding {
    /// Arms identified by git commit.
    GitCommit {
        /// Commit identifying the control arm (presumably a commit hash).
        control_commit: String,
        /// Commit identifying the treatment arm (presumably a commit hash).
        treatment_commit: String,
    },
    /// Arms identified by branch.
    Branch {
        /// Branch name identifying the control arm.
        control_branch: String,
        /// Branch name identifying the treatment arm.
        treatment_branch: String,
    },
    /// Arms identified by prompt fingerprint.
    PromptFingerprint {
        /// Fingerprint identifying the control arm.
        control_fingerprint: String,
        /// Fingerprint identifying the treatment arm.
        treatment_fingerprint: String,
    },
    /// Arm read from a manually applied tag.
    ManualTag {
        /// Name of the field that holds the variant.
        /// NOTE(review): expected values of that field are not defined here.
        variant_field: String,
    },
}
44
/// Direction in which a metric value moves.
///
/// Used by `Criterion::Delta` as the desired direction of change and by
/// `GuardrailSpec::regression_direction` as the direction that counts as a
/// regression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Direction {
    /// The metric value goes down.
    Decrease,
    /// The metric value goes up.
    Increase,
}
50
/// Success criterion of an experiment (see `Experiment::success_criterion`).
///
/// Only `PartialEq` is derived (not `Eq`) because the variants hold `f64`
/// values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Criterion {
    /// Relative change target.
    Delta {
        /// Direction the metric is expected to move.
        direction: Direction,
        /// Size of the targeted change, as a percentage.
        /// NOTE(review): confirm whether 5% is written `5.0` or `0.05`.
        target_pct: f64,
    },
    /// Absolute target value.
    Absolute {
        /// Target value of the metric.
        /// NOTE(review): whether this is an upper or lower bound is not
        /// expressed in this type — confirm against the evaluator.
        metric_value: f64,
    },
}
61
/// Lifecycle state. `Archived` is terminal.
///
/// The only transitions enabled by `transition` in this file are, in order:
/// `Draft` --"start"--> `Running` --"conclude"--> `Concluded`
/// --"archive"--> `Archived`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum State {
    /// Initial state; leaves via the `"start"` action.
    Draft,
    /// Entered from `Draft`; leaves via the `"conclude"` action.
    Running,
    /// Entered from `Running`; leaves via the `"archive"` action.
    Concluded,
    /// Entered from `Concluded`; no action leads out of this state.
    Archived,
}
70
/// Guardrail: a secondary metric that must not regress.
///
/// If the CI (presumably the confidence interval of the measured change —
/// confirm against the report code) shows a regression beyond `threshold_pct`
/// in the specified direction, the report flags the guardrail as violated.
///
/// Only `PartialEq` is derived (not `Eq`) because `threshold_pct` is an `f64`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GuardrailSpec {
    /// Metric being watched by this guardrail.
    pub metric: Metric,
    /// Direction of *regression* (e.g. `Increase` for cost = cost going up is bad).
    pub regression_direction: Direction,
    /// Flag if CI endpoint crosses this threshold.
    /// NOTE(review): percentage scale (`5.0` vs `0.05`) is not fixed by this
    /// file — confirm against the evaluator.
    pub threshold_pct: f64,
}
83
/// Per-guardrail result in the experiment report.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GuardrailResult {
    /// Metric the guardrail was evaluated on.
    pub metric: Metric,
    /// Observed change of the metric, as a percentage, when available.
    /// NOTE(review): presumably `None` when the delta could not be computed
    /// (e.g. insufficient data) — confirm against the report code.
    pub delta_pct: Option<f64>,
    /// `true` when the guardrail was flagged as violated.
    pub violated: bool,
}
91
/// Full description of one experiment, including its lifecycle state.
///
/// This is plain data: it holds no behaviour of its own. Only `PartialEq` is
/// derived (not `Eq`) because `Criterion` and `GuardrailSpec` contain `f64`s.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Experiment {
    /// Identifier of the experiment.
    /// NOTE(review): format and uniqueness rules are not defined in this file.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Hypothesis being tested.
    pub hypothesis: String,
    /// Description of the change under test.
    pub change_description: String,
    /// Primary metric the success criterion is evaluated on.
    pub metric: Metric,
    /// Rule that assigns sessions to control / treatment.
    pub binding: Binding,
    /// Duration in days (presumably the planned run length — confirm).
    pub duration_days: u32,
    /// Condition under which the experiment counts as successful.
    pub success_criterion: Criterion,
    /// Current lifecycle state; see `transition` for the allowed moves.
    pub state: State,
    /// Creation time in milliseconds (presumably since the Unix epoch — confirm).
    pub created_at_ms: u64,
    /// Conclusion time in milliseconds, if set (same clock as `created_at_ms`,
    /// presumably — confirm).
    pub concluded_at_ms: Option<u64>,
    /// Secondary metrics that must not regress. `#[serde(default)]` makes a
    /// missing `guardrails` field deserialize as an empty list.
    #[serde(default)]
    pub guardrails: Vec<GuardrailSpec>,
}
108
109impl Metric {
110    pub fn as_str(&self) -> &'static str {
111        match self {
112            Metric::TokensPerSession => "tokens_per_session",
113            Metric::CostPerSession => "cost_per_session",
114            Metric::SuccessRate => "success_rate",
115            Metric::ToolLoops => "tool_loops",
116            Metric::DurationMinutes => "duration_minutes",
117            Metric::FilesPerSession => "files_per_session",
118            Metric::SuccessRateByPrompt => "success_rate_by_prompt",
119            Metric::CostByPrompt => "cost_by_prompt",
120        }
121    }
122
123    pub fn parse(s: &str) -> Option<Metric> {
124        Some(match s {
125            "tokens_per_session" => Metric::TokensPerSession,
126            "cost_per_session" => Metric::CostPerSession,
127            "success_rate" => Metric::SuccessRate,
128            "tool_loops" => Metric::ToolLoops,
129            "duration_minutes" => Metric::DurationMinutes,
130            "files_per_session" => Metric::FilesPerSession,
131            "success_rate_by_prompt" => Metric::SuccessRateByPrompt,
132            "cost_by_prompt" => Metric::CostByPrompt,
133            _ => return None,
134        })
135    }
136}
137
138/// Pure state-machine transition. Returns `Some(next)` when `action` is enabled.
139pub fn transition(state: State, action: &str) -> Option<State> {
140    Some(match (state, action) {
141        (State::Draft, "start") => State::Running,
142        (State::Running, "conclude") => State::Concluded,
143        (State::Concluded, "archive") => State::Archived,
144        _ => return None,
145    })
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// The three forward steps of the lifecycle are enabled, in order.
    #[test]
    fn transitions_follow_spec_order() {
        assert_eq!(transition(State::Draft, "start"), Some(State::Running));
        assert_eq!(
            transition(State::Running, "conclude"),
            Some(State::Concluded)
        );
        assert_eq!(
            transition(State::Concluded, "archive"),
            Some(State::Archived)
        );
    }

    /// No known action leads out of `Archived`.
    #[test]
    fn archived_is_terminal() {
        assert_eq!(transition(State::Archived, "start"), None);
        assert_eq!(transition(State::Archived, "conclude"), None);
        assert_eq!(transition(State::Archived, "archive"), None);
    }

    #[test]
    fn no_backward_transitions() {
        assert_eq!(transition(State::Concluded, "start"), None);
        assert_eq!(transition(State::Running, "archive"), None);
    }

    /// A state cannot be skipped: `Draft` only accepts `"start"`.
    #[test]
    fn no_skipping_transitions() {
        assert_eq!(transition(State::Draft, "conclude"), None);
        assert_eq!(transition(State::Draft, "archive"), None);
        assert_eq!(transition(State::Running, "start"), None);
    }

    /// Action names are matched exactly; unknown strings enable nothing.
    #[test]
    fn unknown_actions_are_rejected() {
        for state in [
            State::Draft,
            State::Running,
            State::Concluded,
            State::Archived,
        ] {
            assert_eq!(transition(state, ""), None);
            assert_eq!(transition(state, "Start"), None);
            assert_eq!(transition(state, "restart"), None);
        }
    }

    /// `parse(as_str(m)) == m` for every variant, including the two
    /// `*ByPrompt` metrics.
    #[test]
    fn metric_round_trip() {
        for m in [
            Metric::TokensPerSession,
            Metric::CostPerSession,
            Metric::SuccessRate,
            Metric::ToolLoops,
            Metric::DurationMinutes,
            Metric::FilesPerSession,
            Metric::SuccessRateByPrompt,
            Metric::CostByPrompt,
        ] {
            assert_eq!(Metric::parse(m.as_str()), Some(m));
        }
    }

    /// Names that are not produced by `as_str` do not parse.
    #[test]
    fn metric_parse_rejects_unknown_names() {
        assert_eq!(Metric::parse(""), None);
        assert_eq!(Metric::parse("TokensPerSession"), None);
        assert_eq!(Metric::parse("tokens_per_session "), None);
    }
}