Skip to main content

mdx_rust_core/
ledger.rs

1//! Experiment budgeting and prompt variant ledger primitives.
2//!
3//! These records make optimization runs explainable without requiring a
4//! database. They are append-friendly JSON structures that can later be moved
5//! behind a richer storage layer.
6
7use crate::eval::{stable_hash_hex, EvaluationDataset, EvaluationSample};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
12pub enum OptimizationBudget {
13    Light,
14    #[default]
15    Medium,
16    Heavy,
17}
18
19impl OptimizationBudget {
20    pub fn from_label(value: &str) -> anyhow::Result<Self> {
21        match value {
22            "light" => Ok(Self::Light),
23            "medium" => Ok(Self::Medium),
24            "heavy" => Ok(Self::Heavy),
25            other => anyhow::bail!("unknown optimization budget: {other}"),
26        }
27    }
28
29    pub fn label(self) -> &'static str {
30        match self {
31            Self::Light => "light",
32            Self::Medium => "medium",
33            Self::Heavy => "heavy",
34        }
35    }
36
37    pub fn candidate_limit(self, requested: u32) -> usize {
38        let cap = match self {
39            Self::Light => 2,
40            Self::Medium => 4,
41            Self::Heavy => 8,
42        };
43        requested.max(1).min(cap) as usize
44    }
45
46    pub fn holdout_percent(self) -> usize {
47        match self {
48            Self::Light => 20,
49            Self::Medium => 25,
50            Self::Heavy => 30,
51        }
52    }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
56pub struct DatasetSplit {
57    pub train: Vec<EvaluationSample>,
58    pub holdout: Vec<EvaluationSample>,
59    pub holdout_percent: usize,
60}
61
62pub fn split_dataset(dataset: &EvaluationDataset, budget: OptimizationBudget) -> DatasetSplit {
63    if dataset.samples.len() < 2 {
64        return DatasetSplit {
65            train: dataset.samples.clone(),
66            holdout: Vec::new(),
67            holdout_percent: 0,
68        };
69    }
70
71    let holdout_percent = budget.holdout_percent();
72    let holdout_len = ((dataset.samples.len() * holdout_percent).div_ceil(100))
73        .max(1)
74        .min(dataset.samples.len() - 1);
75    let split_at = dataset.samples.len() - holdout_len;
76
77    DatasetSplit {
78        train: dataset.samples[..split_at].to_vec(),
79        holdout: dataset.samples[split_at..].to_vec(),
80        holdout_percent,
81    }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
85pub struct PromptVariantRecord {
86    pub id: String,
87    pub strategy: String,
88    pub target_file: String,
89    pub patch_hash: String,
90    pub description: String,
91}
92
93impl PromptVariantRecord {
94    pub fn from_patch(
95        strategy: impl Into<String>,
96        target_file: impl Into<String>,
97        description: impl Into<String>,
98        patch: &str,
99    ) -> Self {
100        let patch_hash = stable_hash_hex(patch.as_bytes());
101        Self {
102            id: patch_hash.replace(':', "_"),
103            strategy: strategy.into(),
104            target_file: target_file.into(),
105            patch_hash,
106            description: description.into(),
107        }
108    }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
112pub struct ExperimentLedger {
113    pub budget: OptimizationBudget,
114    pub dataset_version: String,
115    pub dataset_hash: String,
116    pub train_samples: usize,
117    pub holdout_samples: usize,
118    pub variants: Vec<PromptVariantRecord>,
119}
120
121impl ExperimentLedger {
122    pub fn new(
123        budget: OptimizationBudget,
124        dataset: &EvaluationDataset,
125        split: &DatasetSplit,
126    ) -> Self {
127        Self {
128            budget,
129            dataset_version: dataset.version.clone(),
130            dataset_hash: dataset.content_hash(),
131            train_samples: split.train.len(),
132            holdout_samples: split.holdout.len(),
133            variants: Vec::new(),
134        }
135    }
136
137    pub fn record_variant(&mut self, variant: PromptVariantRecord) {
138        self.variants.push(variant);
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_dataset_keeps_at_least_one_train_sample() {
        let dataset = EvaluationDataset::synthetic_v1();
        let split = split_dataset(&dataset, OptimizationBudget::Heavy);

        assert!(!split.train.is_empty());
        assert!(!split.holdout.is_empty());
        assert_eq!(
            split.train.len() + split.holdout.len(),
            dataset.samples.len()
        );
    }

    /// Datasets with fewer than two samples must not produce a holdout, and the
    /// recorded percentage drops to zero.
    #[test]
    fn split_dataset_keeps_tiny_datasets_entirely_in_train() {
        let mut dataset = EvaluationDataset::synthetic_v1();

        dataset.samples.truncate(1);
        let split = split_dataset(&dataset, OptimizationBudget::Heavy);
        assert_eq!(split.train.len(), 1);
        assert!(split.holdout.is_empty());
        assert_eq!(split.holdout_percent, 0);

        dataset.samples.clear();
        let split = split_dataset(&dataset, OptimizationBudget::Heavy);
        assert!(split.train.is_empty());
        assert!(split.holdout.is_empty());
        assert_eq!(split.holdout_percent, 0);
    }

    #[test]
    fn variant_ids_are_stable_hashes() {
        let first = PromptVariantRecord::from_patch("schema", "src/main.rs", "desc", "patch");
        let second = PromptVariantRecord::from_patch("schema", "src/main.rs", "desc", "patch");

        assert_eq!(first.id, second.id);
        assert!(first.patch_hash.starts_with("fnv1a64:"));
    }

    #[test]
    fn budget_caps_candidates_but_never_to_zero() {
        assert_eq!(OptimizationBudget::Light.candidate_limit(99), 2);
        assert_eq!(OptimizationBudget::Medium.candidate_limit(99), 4);
        assert_eq!(OptimizationBudget::Heavy.candidate_limit(99), 8);
        assert_eq!(OptimizationBudget::Light.candidate_limit(0), 1);
    }

    /// Every canonical label parses back to its budget; unknown labels fail.
    #[test]
    fn budget_labels_round_trip() {
        for budget in [
            OptimizationBudget::Light,
            OptimizationBudget::Medium,
            OptimizationBudget::Heavy,
        ] {
            assert_eq!(
                OptimizationBudget::from_label(budget.label()).unwrap(),
                budget
            );
        }
        assert!(OptimizationBudget::from_label("extreme").is_err());
        assert_eq!(OptimizationBudget::default(), OptimizationBudget::Medium);
    }

    #[test]
    fn ledger_recording_variants_does_not_record_acceptance() {
        let dataset = EvaluationDataset::synthetic_v1();
        let split = split_dataset(&dataset, OptimizationBudget::Medium);
        let mut ledger = ExperimentLedger::new(OptimizationBudget::Medium, &dataset, &split);

        ledger.record_variant(PromptVariantRecord::from_patch(
            "schema",
            "src/main.rs",
            "candidate only",
            "patch",
        ));

        assert_eq!(ledger.variants.len(), 1);
        assert_eq!(
            ledger.train_samples + ledger.holdout_samples,
            dataset.samples.len()
        );
    }
}