Skip to main content

kaizen/guidance/
validation.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2//! Held-out gate for applied guidance candidates.
3
4use crate::experiment::store as exp_store;
5use crate::experiment::types::{Experiment, Metric};
6use crate::experiment::{self, Report};
7use crate::guidance::{CandidateStatus, GuidanceCandidate};
8use crate::store::Store;
9use anyhow::{Result, anyhow};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13
14#[derive(Clone, Debug, Serialize, Deserialize)]
15pub struct ValidationGate {
16    pub candidate_id: String,
17    pub experiment_id: String,
18    pub outcome: ValidationOutcome,
19    pub n_control: usize,
20    pub n_treatment: usize,
21    pub delta_pct: Option<f64>,
22    pub target_met: Option<bool>,
23    pub guardrail_violations: usize,
24}
25
26#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
27#[serde(rename_all = "snake_case")]
28pub enum ValidationOutcome {
29    Validated,
30    Rejected,
31    InsufficientEvidence,
32}
33
34impl ValidationGate {
35    pub fn next_status(&self) -> Option<CandidateStatus> {
36        match self.outcome {
37            ValidationOutcome::Validated => Some(CandidateStatus::Validated),
38            ValidationOutcome::Rejected => Some(CandidateStatus::Rejected),
39            ValidationOutcome::InsufficientEvidence => None,
40        }
41    }
42}
43
44pub fn evaluate(store: &Store, workspace: &Path, c: &GuidanceCandidate) -> Result<ValidationGate> {
45    let exp_id = c
46        .experiment_id
47        .as_deref()
48        .ok_or_else(|| anyhow!("candidate has no prompt-bound experiment"))?;
49    let exp = exp_store::load_experiment(store, exp_id)?
50        .ok_or_else(|| anyhow!("experiment not found: {exp_id}"))?;
51    Ok(from_report(c, &report(store, workspace, &exp)?, exp_id))
52}
53
54fn from_report(c: &GuidanceCandidate, report: &Report, exp_id: &str) -> ValidationGate {
55    let summary = &report.summary;
56    let violations = report
57        .guardrail_results
58        .iter()
59        .filter(|g| g.violated)
60        .count();
61    ValidationGate {
62        candidate_id: c.id.clone(),
63        experiment_id: exp_id.into(),
64        outcome: outcome(
65            report.target_met,
66            summary.n_control,
67            summary.n_treatment,
68            violations,
69        ),
70        n_control: summary.n_control,
71        n_treatment: summary.n_treatment,
72        delta_pct: summary.delta_pct,
73        target_met: report.target_met,
74        guardrail_violations: violations,
75    }
76}
77
78fn outcome(
79    target_met: Option<bool>,
80    n_control: usize,
81    n_treatment: usize,
82    guardrails: usize,
83) -> ValidationOutcome {
84    match (n_control, n_treatment, target_met, guardrails) {
85        (0, _, _, _) | (_, 0, _, _) | (_, _, None, _) => ValidationOutcome::InsufficientEvidence,
86        (_, _, _, n) if n > 0 => ValidationOutcome::Rejected,
87        (_, _, Some(true), _) => ValidationOutcome::Validated,
88        (_, _, Some(false), _) => ValidationOutcome::Rejected,
89    }
90}
91
92fn report(store: &Store, workspace: &Path, exp: &Experiment) -> Result<Report> {
93    let ws = workspace.to_string_lossy().to_string();
94    let (start, end) = window_for(exp);
95    let manual = exp_store::manual_tags(store, &exp.id)?;
96    let (sessions, values) = metric_values(store, &ws, start, end, exp.metric)?;
97    let guardrails = guardrail_values(store, &ws, start, end, exp)?;
98    Ok(experiment::run_from_metric_values(
99        exp,
100        &sessions,
101        &values,
102        &guardrails,
103        &manual,
104        workspace,
105        false,
106    ))
107}
108
109fn guardrail_values(
110    store: &Store,
111    ws: &str,
112    start: u64,
113    end: u64,
114    exp: &Experiment,
115) -> Result<HashMap<Metric, HashMap<String, f64>>> {
116    exp.guardrails
117        .iter()
118        .map(|g| metric_values(store, ws, start, end, g.metric).map(|(_, vals)| (g.metric, vals)))
119        .collect()
120}
121
122fn metric_values(
123    store: &Store,
124    ws: &str,
125    start: u64,
126    end: u64,
127    metric: Metric,
128) -> Result<(Vec<crate::core::event::SessionRecord>, HashMap<String, f64>)> {
129    let rows = store.experiment_metric_values_in_window(ws, start, end, metric)?;
130    Ok(rows.into_iter().fold(
131        (Vec::new(), HashMap::new()),
132        |(mut sessions, mut values), (session, value)| {
133            values.insert(session.id.clone(), value);
134            sessions.push(session);
135            (sessions, values)
136        },
137    ))
138}
139
140fn window_for(e: &Experiment) -> (u64, u64) {
141    let end = e
142        .concluded_at_ms
143        .unwrap_or_else(|| e.created_at_ms + (e.duration_days as u64) * 86_400_000);
144    (e.created_at_ms, end.max(e.created_at_ms))
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn target_hit_validates() {
        assert_eq!(outcome(Some(true), 30, 30, 0), ValidationOutcome::Validated);
    }

    #[test]
    fn target_miss_rejects() {
        assert_eq!(outcome(Some(false), 30, 30, 0), ValidationOutcome::Rejected);
    }

    #[test]
    fn missing_arm_keeps_applied() {
        assert_eq!(
            outcome(None, 0, 30, 0),
            ValidationOutcome::InsufficientEvidence
        );
    }

    /// A violated guardrail rejects even when the primary target was hit.
    #[test]
    fn guardrail_violation_rejects_despite_target_hit() {
        assert_eq!(outcome(Some(true), 30, 30, 1), ValidationOutcome::Rejected);
    }

    /// Each insufficiency cause is detected on its own, not only in combination.
    #[test]
    fn each_insufficiency_cause_is_detected() {
        assert_eq!(
            outcome(Some(true), 0, 30, 0),
            ValidationOutcome::InsufficientEvidence
        );
        assert_eq!(
            outcome(Some(true), 30, 0, 0),
            ValidationOutcome::InsufficientEvidence
        );
        assert_eq!(
            outcome(None, 30, 30, 0),
            ValidationOutcome::InsufficientEvidence
        );
    }

    /// Missing evidence takes precedence over guardrail violations.
    #[test]
    fn insufficient_evidence_precedes_guardrails() {
        assert_eq!(
            outcome(Some(true), 0, 30, 2),
            ValidationOutcome::InsufficientEvidence
        );
        assert_eq!(
            outcome(None, 30, 30, 2),
            ValidationOutcome::InsufficientEvidence
        );
    }

    /// `next_status` proposes a transition only for decisive outcomes.
    #[test]
    fn next_status_follows_outcome() {
        let gate = |verdict| ValidationGate {
            candidate_id: "candidate".into(),
            experiment_id: "experiment".into(),
            outcome: verdict,
            n_control: 30,
            n_treatment: 30,
            delta_pct: None,
            target_met: None,
            guardrail_violations: 0,
        };
        assert!(matches!(
            gate(ValidationOutcome::Validated).next_status(),
            Some(CandidateStatus::Validated)
        ));
        assert!(matches!(
            gate(ValidationOutcome::Rejected).next_status(),
            Some(CandidateStatus::Rejected)
        ));
        assert!(gate(ValidationOutcome::InsufficientEvidence)
            .next_status()
            .is_none());
    }
}