1use crate::eval::{stable_hash_hex, EvaluationDataset, EvaluationSample};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
12pub enum OptimizationBudget {
13 Light,
14 #[default]
15 Medium,
16 Heavy,
17}
18
19impl OptimizationBudget {
20 pub fn from_label(value: &str) -> anyhow::Result<Self> {
21 match value {
22 "light" => Ok(Self::Light),
23 "medium" => Ok(Self::Medium),
24 "heavy" => Ok(Self::Heavy),
25 other => anyhow::bail!("unknown optimization budget: {other}"),
26 }
27 }
28
29 pub fn label(self) -> &'static str {
30 match self {
31 Self::Light => "light",
32 Self::Medium => "medium",
33 Self::Heavy => "heavy",
34 }
35 }
36
37 pub fn candidate_limit(self, requested: u32) -> usize {
38 let cap = match self {
39 Self::Light => 2,
40 Self::Medium => 4,
41 Self::Heavy => 8,
42 };
43 requested.max(1).min(cap) as usize
44 }
45
46 pub fn holdout_percent(self) -> usize {
47 match self {
48 Self::Light => 20,
49 Self::Medium => 25,
50 Self::Heavy => 30,
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
56pub struct DatasetSplit {
57 pub train: Vec<EvaluationSample>,
58 pub holdout: Vec<EvaluationSample>,
59 pub holdout_percent: usize,
60}
61
62pub fn split_dataset(dataset: &EvaluationDataset, budget: OptimizationBudget) -> DatasetSplit {
63 if dataset.samples.len() < 2 {
64 return DatasetSplit {
65 train: dataset.samples.clone(),
66 holdout: Vec::new(),
67 holdout_percent: 0,
68 };
69 }
70
71 let holdout_percent = budget.holdout_percent();
72 let holdout_len = ((dataset.samples.len() * holdout_percent).div_ceil(100))
73 .max(1)
74 .min(dataset.samples.len() - 1);
75 let split_at = dataset.samples.len() - holdout_len;
76
77 DatasetSplit {
78 train: dataset.samples[..split_at].to_vec(),
79 holdout: dataset.samples[split_at..].to_vec(),
80 holdout_percent,
81 }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
85pub struct PromptVariantRecord {
86 pub id: String,
87 pub strategy: String,
88 pub target_file: String,
89 pub patch_hash: String,
90 pub description: String,
91}
92
93impl PromptVariantRecord {
94 pub fn from_patch(
95 strategy: impl Into<String>,
96 target_file: impl Into<String>,
97 description: impl Into<String>,
98 patch: &str,
99 ) -> Self {
100 let patch_hash = stable_hash_hex(patch.as_bytes());
101 Self {
102 id: patch_hash.replace(':', "_"),
103 strategy: strategy.into(),
104 target_file: target_file.into(),
105 patch_hash,
106 description: description.into(),
107 }
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
112pub struct ExperimentLedger {
113 pub budget: OptimizationBudget,
114 pub dataset_version: String,
115 pub dataset_hash: String,
116 pub train_samples: usize,
117 pub holdout_samples: usize,
118 pub variants: Vec<PromptVariantRecord>,
119}
120
121impl ExperimentLedger {
122 pub fn new(
123 budget: OptimizationBudget,
124 dataset: &EvaluationDataset,
125 split: &DatasetSplit,
126 ) -> Self {
127 Self {
128 budget,
129 dataset_version: dataset.version.clone(),
130 dataset_hash: dataset.content_hash(),
131 train_samples: split.train.len(),
132 holdout_samples: split.holdout.len(),
133 variants: Vec::new(),
134 }
135 }
136
137 pub fn record_variant(&mut self, variant: PromptVariantRecord) {
138 self.variants.push(variant);
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Every budget tier, so per-variant behavior can be checked exhaustively.
    const ALL_BUDGETS: [OptimizationBudget; 3] = [
        OptimizationBudget::Light,
        OptimizationBudget::Medium,
        OptimizationBudget::Heavy,
    ];

    #[test]
    fn split_dataset_keeps_at_least_one_train_sample() {
        let dataset = EvaluationDataset::synthetic_v1();
        let split = split_dataset(&dataset, OptimizationBudget::Heavy);

        assert!(!split.train.is_empty());
        assert!(!split.holdout.is_empty());
        assert_eq!(
            split.train.len() + split.holdout.len(),
            dataset.samples.len()
        );
    }

    #[test]
    fn split_dataset_partitions_every_sample_for_each_budget() {
        let dataset = EvaluationDataset::synthetic_v1();

        for budget in ALL_BUDGETS {
            let split = split_dataset(&dataset, budget);

            // No sample may be dropped or duplicated by the split.
            assert_eq!(
                split.train.len() + split.holdout.len(),
                dataset.samples.len()
            );
            // The nominal percentage is only reported when a split happened.
            if dataset.samples.len() >= 2 {
                assert_eq!(split.holdout_percent, budget.holdout_percent());
            } else {
                assert_eq!(split.holdout_percent, 0);
            }
        }
    }

    #[test]
    fn variant_ids_are_stable_hashes() {
        let first = PromptVariantRecord::from_patch("schema", "src/main.rs", "desc", "patch");
        let second = PromptVariantRecord::from_patch("schema", "src/main.rs", "desc", "patch");

        assert_eq!(first.id, second.id);
        assert!(first.patch_hash.starts_with("fnv1a64:"));
        // The id is the hash with its ':' separator replaced.
        assert!(!first.id.contains(':'));
        assert_eq!(first.id, first.patch_hash.replace(':', "_"));
    }

    #[test]
    fn budget_caps_candidates_but_never_to_zero() {
        assert_eq!(OptimizationBudget::Light.candidate_limit(99), 2);
        assert_eq!(OptimizationBudget::Medium.candidate_limit(99), 4);
        assert_eq!(OptimizationBudget::Heavy.candidate_limit(99), 8);
        assert_eq!(OptimizationBudget::Light.candidate_limit(0), 1);
        // Requests already inside the allowed range pass through unchanged.
        assert_eq!(OptimizationBudget::Light.candidate_limit(1), 1);
        assert_eq!(OptimizationBudget::Heavy.candidate_limit(3), 3);
    }

    #[test]
    fn budget_labels_round_trip() {
        for budget in ALL_BUDGETS {
            let parsed = OptimizationBudget::from_label(budget.label())
                .expect("every label produced by `label` must parse");
            assert_eq!(parsed, budget);
        }
    }

    #[test]
    fn unknown_budget_labels_are_rejected() {
        assert!(OptimizationBudget::from_label("extreme").is_err());
        assert!(OptimizationBudget::from_label("").is_err());
    }

    #[test]
    fn default_budget_is_medium() {
        assert_eq!(OptimizationBudget::default(), OptimizationBudget::Medium);
    }

    #[test]
    fn ledger_recording_variants_does_not_record_acceptance() {
        let dataset = EvaluationDataset::synthetic_v1();
        let split = split_dataset(&dataset, OptimizationBudget::Medium);
        let mut ledger = ExperimentLedger::new(OptimizationBudget::Medium, &dataset, &split);

        ledger.record_variant(PromptVariantRecord::from_patch(
            "schema",
            "src/main.rs",
            "candidate only",
            "patch",
        ));

        assert_eq!(ledger.variants.len(), 1);
        assert_eq!(
            ledger.train_samples + ledger.holdout_samples,
            dataset.samples.len()
        );
    }
}