1use std::collections::HashMap;
16
17use serde::{Deserialize, Serialize};
18
19use crate::optim::{HyperparameterSpace, ParameterDomain, ParameterValue};
20
21pub use super::tune_searchers::{
23 AshaScheduler, GridSearcher, MedianScheduler, NoScheduler, RandomSearcher, TpeSearcher,
24 TuneScheduler, TuneSearcher,
25};
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
33pub enum TuneStrategy {
34 Tpe,
35 Grid,
36 Random,
37}
38
39impl std::str::FromStr for TuneStrategy {
40 type Err = String;
41
42 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
43 match s.to_lowercase().as_str() {
44 "tpe" | "bayesian" => Ok(Self::Tpe),
45 "grid" => Ok(Self::Grid),
46 "random" => Ok(Self::Random),
47 _ => Err(format!("Unknown strategy: {s}. Use: tpe, grid, random")),
48 }
49 }
50}
51
52impl std::fmt::Display for TuneStrategy {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 Self::Tpe => write!(f, "tpe"),
56 Self::Grid => write!(f, "grid"),
57 Self::Random => write!(f, "random"),
58 }
59 }
60}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
64pub enum SchedulerKind {
65 Asha,
66 Median,
67 None,
68}
69
70impl std::str::FromStr for SchedulerKind {
71 type Err = String;
72
73 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
74 match s.to_lowercase().as_str() {
75 "asha" => Ok(Self::Asha),
76 "median" => Ok(Self::Median),
77 "none" => Ok(Self::None),
78 _ => Err(format!("Unknown scheduler: {s}. Use: asha, median, none")),
79 }
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct TuneConfig {
86 pub budget: usize,
88 pub strategy: TuneStrategy,
90 pub scheduler: SchedulerKind,
92 pub scout: bool,
94 pub max_epochs: usize,
96 pub num_classes: usize,
98 pub seed: u64,
100 pub time_limit_secs: Option<u64>,
102}
103
104impl Default for TuneConfig {
105 fn default() -> Self {
106 Self {
107 budget: 10,
108 strategy: TuneStrategy::Tpe,
109 scheduler: SchedulerKind::Asha,
110 scout: false,
111 max_epochs: 20,
112 num_classes: 5,
113 seed: 42,
114 time_limit_secs: None,
115 }
116 }
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct TrialSummary {
122 pub id: usize,
124 pub val_loss: f64,
126 pub val_accuracy: f64,
128 pub train_loss: f64,
130 pub train_accuracy: f64,
132 pub epochs_run: usize,
134 pub time_ms: u64,
136 pub config: HashMap<String, ParameterValue>,
138 pub status: String,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct TuneResult {
145 pub strategy: String,
147 pub mode: String,
149 pub budget: usize,
151 pub trials: Vec<TrialSummary>,
153 pub best_trial_id: usize,
155 pub total_time_ms: u64,
157}
158
159pub fn default_classify_search_space() -> HyperparameterSpace {
163 let mut space = HyperparameterSpace::new();
164
165 space.add(
167 "learning_rate",
168 ParameterDomain::Continuous { low: 5e-6, high: 5e-4, log_scale: true },
169 );
170
171 space.add("lora_rank", ParameterDomain::Discrete { low: 1, high: 16 });
173
174 space.add(
176 "lora_alpha_ratio",
177 ParameterDomain::Continuous { low: 0.5, high: 2.0, log_scale: false },
178 );
179
180 space.add(
182 "batch_size",
183 ParameterDomain::Categorical {
184 choices: vec![
185 "8".to_string(),
186 "16".to_string(),
187 "32".to_string(),
188 "64".to_string(),
189 "128".to_string(),
190 ],
191 },
192 );
193
194 space.add(
196 "warmup_fraction",
197 ParameterDomain::Continuous { low: 0.01, high: 0.2, log_scale: false },
198 );
199
200 space.add(
202 "gradient_clip_norm",
203 ParameterDomain::Continuous { low: 0.5, high: 5.0, log_scale: false },
204 );
205
206 space.add(
208 "class_weights",
209 ParameterDomain::Categorical {
210 choices: vec![
211 "uniform".to_string(),
212 "inverse_freq".to_string(),
213 "sqrt_inverse".to_string(),
214 ],
215 },
216 );
217
218 space.add(
220 "target_modules",
221 ParameterDomain::Categorical {
222 choices: vec!["qv".to_string(), "qkv".to_string(), "all_linear".to_string()],
223 },
224 );
225
226 space.add(
228 "lr_min_ratio",
229 ParameterDomain::Continuous { low: 0.001, high: 0.1, log_scale: true },
230 );
231
232 space
233}
234
235#[allow(clippy::implicit_hasher)]
240pub fn extract_trial_params(
241 config: &HashMap<String, ParameterValue>,
242) -> (f32, usize, f32, usize, f32, f32, String, String, f32) {
243 let lr = config.get("learning_rate").and_then(ParameterValue::as_float).unwrap_or(1e-4) as f32;
244
245 let rank_raw = config.get("lora_rank").and_then(ParameterValue::as_int).unwrap_or(4) as usize;
247 let rank = (rank_raw * 4).clamp(4, 64);
248
249 let alpha_ratio =
250 config.get("lora_alpha_ratio").and_then(ParameterValue::as_float).unwrap_or(1.0) as f32;
251 let alpha = rank as f32 * alpha_ratio;
252
253 let batch_size = config
254 .get("batch_size")
255 .and_then(ParameterValue::as_str)
256 .and_then(|s| s.parse::<usize>().ok())
257 .unwrap_or(32);
258
259 let warmup =
260 config.get("warmup_fraction").and_then(ParameterValue::as_float).unwrap_or(0.1) as f32;
261
262 let clip =
263 config.get("gradient_clip_norm").and_then(ParameterValue::as_float).unwrap_or(1.0) as f32;
264
265 let weights_strategy = config
266 .get("class_weights")
267 .and_then(ParameterValue::as_str)
268 .unwrap_or("uniform")
269 .to_string();
270
271 let targets =
272 config.get("target_modules").and_then(ParameterValue::as_str).unwrap_or("qv").to_string();
273
274 let lr_min_ratio =
275 config.get("lr_min_ratio").and_then(ParameterValue::as_float).unwrap_or(0.01) as f32;
276
277 (lr, rank, alpha, batch_size, warmup, clip, weights_strategy, targets, lr_min_ratio)
278}
279
280#[derive(Debug)]
292pub struct ClassifyTuner {
293 pub config: TuneConfig,
295 pub space: HyperparameterSpace,
297 pub leaderboard: Vec<TrialSummary>,
299}
300
301impl ClassifyTuner {
302 pub fn new(config: TuneConfig) -> crate::Result<Self> {
304 if config.budget == 0 {
305 return Err(crate::Error::ConfigError(
306 "FALSIFY-TUNE-001: budget must be > 0".to_string(),
307 ));
308 }
309 if config.num_classes == 0 {
310 return Err(crate::Error::ConfigError(
311 "FALSIFY-TUNE-004: num_classes must be > 0".to_string(),
312 ));
313 }
314
315 let space = default_classify_search_space();
316
317 Ok(Self { config, space, leaderboard: Vec::new() })
318 }
319
320 pub fn build_searcher(&self) -> Box<dyn TuneSearcher> {
322 let n_startup = (self.config.budget / 3).max(3);
323 match self.config.strategy {
324 TuneStrategy::Tpe => Box::new(TpeSearcher::new(self.space.clone(), n_startup)),
325 TuneStrategy::Grid => Box::new(GridSearcher::new(self.space.clone(), 3)),
326 TuneStrategy::Random => Box::new(RandomSearcher::new(self.space.clone())),
327 }
328 }
329
330 pub fn build_scheduler(&self) -> Box<dyn TuneScheduler> {
332 if self.config.scout {
333 return Box::new(NoScheduler);
334 }
335 match self.config.scheduler {
336 SchedulerKind::Asha => Box::new(AshaScheduler::new(1, 3.0)),
337 SchedulerKind::Median => Box::new(MedianScheduler::new(1)),
338 SchedulerKind::None => Box::new(NoScheduler),
339 }
340 }
341
342 pub fn record_trial(&mut self, summary: TrialSummary) {
344 self.leaderboard.push(summary);
345 self.leaderboard.sort_by(|a, b| {
347 a.val_loss.partial_cmp(&b.val_loss).unwrap_or(std::cmp::Ordering::Equal)
348 });
349 }
350
351 pub fn best_trial(&self) -> Option<&TrialSummary> {
353 self.leaderboard.first()
354 }
355
356 pub fn into_result(self, total_time_ms: u64) -> TuneResult {
358 let best_id = self.leaderboard.first().map_or(0, |t| t.id);
359 TuneResult {
360 strategy: self.config.strategy.to_string(),
361 mode: if self.config.scout { "scout".to_string() } else { "full".to_string() },
362 budget: self.config.budget,
363 trials: self.leaderboard,
364 best_trial_id: best_id,
365 total_time_ms,
366 }
367 }
368}
369
370#[cfg(test)]
371#[allow(clippy::unwrap_used)]
372#[path = "classify_tuner_tests.rs"]
373mod tests;