Skip to main content

entrenar/finetune/
classify_tuner.rs

//! Automatic hyperparameter tuning for classification fine-tuning (SPEC-TUNE-2026-001).
//!
//! Provides `ClassifyTuner`, which orchestrates HPO search over LoRA + classifier
//! configurations using the existing TPE/Grid/Hyperband infrastructure.
//!
//! # Architecture
//!
//! ```text
//! ClassifyTuner
//!   ├── TuneSearcher (TPE / Grid / Random)
//!   ├── TuneScheduler (ASHA / Median / None)
//!   └── per trial: ClassifyPipeline → ClassifyTrainer → EpochMetrics
//! ```
14
15use std::collections::HashMap;
16
17use serde::{Deserialize, Serialize};
18
19use crate::optim::{HyperparameterSpace, ParameterDomain, ParameterValue};
20
21// Re-export searcher/scheduler types from submodule
22pub use super::tune_searchers::{
23    AshaScheduler, GridSearcher, MedianScheduler, NoScheduler, RandomSearcher, TpeSearcher,
24    TuneScheduler, TuneSearcher,
25};
26
27// ═══════════════════════════════════════════════════════════════════════
28// Configuration and result types
29// ═══════════════════════════════════════════════════════════════════════
30
31/// Tuning strategy selection.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
33pub enum TuneStrategy {
34    Tpe,
35    Grid,
36    Random,
37}
38
39impl std::str::FromStr for TuneStrategy {
40    type Err = String;
41
42    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
43        match s.to_lowercase().as_str() {
44            "tpe" | "bayesian" => Ok(Self::Tpe),
45            "grid" => Ok(Self::Grid),
46            "random" => Ok(Self::Random),
47            _ => Err(format!("Unknown strategy: {s}. Use: tpe, grid, random")),
48        }
49    }
50}
51
52impl std::fmt::Display for TuneStrategy {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match self {
55            Self::Tpe => write!(f, "tpe"),
56            Self::Grid => write!(f, "grid"),
57            Self::Random => write!(f, "random"),
58        }
59    }
60}
61
62/// Scheduler selection.
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
64pub enum SchedulerKind {
65    Asha,
66    Median,
67    None,
68}
69
70impl std::str::FromStr for SchedulerKind {
71    type Err = String;
72
73    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
74        match s.to_lowercase().as_str() {
75            "asha" => Ok(Self::Asha),
76            "median" => Ok(Self::Median),
77            "none" => Ok(Self::None),
78            _ => Err(format!("Unknown scheduler: {s}. Use: asha, median, none")),
79        }
80    }
81}
82
83/// Tuning run configuration.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct TuneConfig {
86    /// Maximum number of trials.
87    pub budget: usize,
88    /// Search strategy (TPE, Grid, Random).
89    pub strategy: TuneStrategy,
90    /// Scheduler for early stopping.
91    pub scheduler: SchedulerKind,
92    /// Scout mode: 1 epoch per trial, no scheduling.
93    pub scout: bool,
94    /// Maximum epochs per trial (full mode).
95    pub max_epochs: usize,
96    /// Number of output classes.
97    pub num_classes: usize,
98    /// Random seed for reproducibility.
99    pub seed: u64,
100    /// Optional time limit in seconds.
101    pub time_limit_secs: Option<u64>,
102}
103
104impl Default for TuneConfig {
105    fn default() -> Self {
106        Self {
107            budget: 10,
108            strategy: TuneStrategy::Tpe,
109            scheduler: SchedulerKind::Asha,
110            scout: false,
111            max_epochs: 20,
112            num_classes: 5,
113            seed: 42,
114            time_limit_secs: None,
115        }
116    }
117}
118
119/// Summary of a single completed trial.
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct TrialSummary {
122    /// Trial index.
123    pub id: usize,
124    /// Validation loss (best epoch).
125    pub val_loss: f64,
126    /// Validation accuracy (best epoch).
127    pub val_accuracy: f64,
128    /// Training loss (final epoch).
129    pub train_loss: f64,
130    /// Training accuracy (final epoch).
131    pub train_accuracy: f64,
132    /// Number of epochs actually run.
133    pub epochs_run: usize,
134    /// Wall-clock time in milliseconds.
135    pub time_ms: u64,
136    /// Hyperparameter configuration.
137    pub config: HashMap<String, ParameterValue>,
138    /// Trial status.
139    pub status: String,
140}
141
142/// Result of a complete tuning run.
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct TuneResult {
145    /// Strategy used.
146    pub strategy: String,
147    /// Mode (scout or full).
148    pub mode: String,
149    /// Budget (total trials attempted).
150    pub budget: usize,
151    /// All trial summaries, sorted by val_loss ascending (best first).
152    pub trials: Vec<TrialSummary>,
153    /// ID of the best trial.
154    pub best_trial_id: usize,
155    /// Total wall-clock time in milliseconds.
156    pub total_time_ms: u64,
157}
158
159/// Build the default 9-parameter search space for classification HPO.
160///
161/// Parameters from SPEC-TUNE-2026-001 §4.1.
162pub fn default_classify_search_space() -> HyperparameterSpace {
163    let mut space = HyperparameterSpace::new();
164
165    // Learning rate: 5e-6 .. 5e-4 (log-scale)
166    space.add(
167        "learning_rate",
168        ParameterDomain::Continuous { low: 5e-6, high: 5e-4, log_scale: true },
169    );
170
171    // LoRA rank: 4 .. 64 (discrete, step of 4 → 4,8,12,...,64)
172    space.add("lora_rank", ParameterDomain::Discrete { low: 1, high: 16 });
173
174    // Alpha ratio: 0.5 .. 2.0 (ties alpha to rank)
175    space.add(
176        "lora_alpha_ratio",
177        ParameterDomain::Continuous { low: 0.5, high: 2.0, log_scale: false },
178    );
179
180    // Batch size: categorical [8, 16, 32, 64, 128]
181    space.add(
182        "batch_size",
183        ParameterDomain::Categorical {
184            choices: vec![
185                "8".to_string(),
186                "16".to_string(),
187                "32".to_string(),
188                "64".to_string(),
189                "128".to_string(),
190            ],
191        },
192    );
193
194    // Warmup fraction: 0.01 .. 0.2
195    space.add(
196        "warmup_fraction",
197        ParameterDomain::Continuous { low: 0.01, high: 0.2, log_scale: false },
198    );
199
200    // Gradient clip norm: 0.5 .. 5.0
201    space.add(
202        "gradient_clip_norm",
203        ParameterDomain::Continuous { low: 0.5, high: 5.0, log_scale: false },
204    );
205
206    // Class weights strategy
207    space.add(
208        "class_weights",
209        ParameterDomain::Categorical {
210            choices: vec![
211                "uniform".to_string(),
212                "inverse_freq".to_string(),
213                "sqrt_inverse".to_string(),
214            ],
215        },
216    );
217
218    // Target modules
219    space.add(
220        "target_modules",
221        ParameterDomain::Categorical {
222            choices: vec!["qv".to_string(), "qkv".to_string(), "all_linear".to_string()],
223        },
224    );
225
226    // LR min ratio (cosine decay floor = lr * ratio)
227    space.add(
228        "lr_min_ratio",
229        ParameterDomain::Continuous { low: 0.001, high: 0.1, log_scale: true },
230    );
231
232    space
233}
234
235/// Convert a trial's ParameterValue map into concrete hyperparameter values.
236///
237/// Returns (learning_rate, lora_rank, lora_alpha, batch_size, warmup_fraction,
238///          gradient_clip_norm, class_weights_strategy, target_modules, lr_min_ratio).
239#[allow(clippy::implicit_hasher)]
240pub fn extract_trial_params(
241    config: &HashMap<String, ParameterValue>,
242) -> (f32, usize, f32, usize, f32, f32, String, String, f32) {
243    let lr = config.get("learning_rate").and_then(ParameterValue::as_float).unwrap_or(1e-4) as f32;
244
245    // lora_rank: discrete 1-16 maps to rank * 4 → 4,8,...,64
246    let rank_raw = config.get("lora_rank").and_then(ParameterValue::as_int).unwrap_or(4) as usize;
247    let rank = (rank_raw * 4).clamp(4, 64);
248
249    let alpha_ratio =
250        config.get("lora_alpha_ratio").and_then(ParameterValue::as_float).unwrap_or(1.0) as f32;
251    let alpha = rank as f32 * alpha_ratio;
252
253    let batch_size = config
254        .get("batch_size")
255        .and_then(ParameterValue::as_str)
256        .and_then(|s| s.parse::<usize>().ok())
257        .unwrap_or(32);
258
259    let warmup =
260        config.get("warmup_fraction").and_then(ParameterValue::as_float).unwrap_or(0.1) as f32;
261
262    let clip =
263        config.get("gradient_clip_norm").and_then(ParameterValue::as_float).unwrap_or(1.0) as f32;
264
265    let weights_strategy = config
266        .get("class_weights")
267        .and_then(ParameterValue::as_str)
268        .unwrap_or("uniform")
269        .to_string();
270
271    let targets =
272        config.get("target_modules").and_then(ParameterValue::as_str).unwrap_or("qv").to_string();
273
274    let lr_min_ratio =
275        config.get("lr_min_ratio").and_then(ParameterValue::as_float).unwrap_or(0.01) as f32;
276
277    (lr, rank, alpha, batch_size, warmup, clip, weights_strategy, targets, lr_min_ratio)
278}
279
280// ═══════════════════════════════════════════════════════════════════════
281// ClassifyTuner
282// ═══════════════════════════════════════════════════════════════════════
283
284/// Orchestrates hyperparameter optimization for classification fine-tuning.
285///
286/// Coordinates:
287/// 1. Searcher (TPE/Grid/Random) to suggest configs
288/// 2. Scheduler (ASHA/Median/None) for early stopping
289/// 3. ClassifyTrainer execution per trial
290/// 4. Leaderboard ranking and persistence
291#[derive(Debug)]
292pub struct ClassifyTuner {
293    /// Tuning configuration.
294    pub config: TuneConfig,
295    /// Search space.
296    pub space: HyperparameterSpace,
297    /// Completed trial summaries.
298    pub leaderboard: Vec<TrialSummary>,
299}
300
301impl ClassifyTuner {
302    /// Create a new tuner with the default classification search space.
303    pub fn new(config: TuneConfig) -> crate::Result<Self> {
304        if config.budget == 0 {
305            return Err(crate::Error::ConfigError(
306                "FALSIFY-TUNE-001: budget must be > 0".to_string(),
307            ));
308        }
309        if config.num_classes == 0 {
310            return Err(crate::Error::ConfigError(
311                "FALSIFY-TUNE-004: num_classes must be > 0".to_string(),
312            ));
313        }
314
315        let space = default_classify_search_space();
316
317        Ok(Self { config, space, leaderboard: Vec::new() })
318    }
319
320    /// Create the appropriate searcher based on strategy.
321    pub fn build_searcher(&self) -> Box<dyn TuneSearcher> {
322        let n_startup = (self.config.budget / 3).max(3);
323        match self.config.strategy {
324            TuneStrategy::Tpe => Box::new(TpeSearcher::new(self.space.clone(), n_startup)),
325            TuneStrategy::Grid => Box::new(GridSearcher::new(self.space.clone(), 3)),
326            TuneStrategy::Random => Box::new(RandomSearcher::new(self.space.clone())),
327        }
328    }
329
330    /// Create the appropriate scheduler.
331    pub fn build_scheduler(&self) -> Box<dyn TuneScheduler> {
332        if self.config.scout {
333            return Box::new(NoScheduler);
334        }
335        match self.config.scheduler {
336            SchedulerKind::Asha => Box::new(AshaScheduler::new(1, 3.0)),
337            SchedulerKind::Median => Box::new(MedianScheduler::new(1)),
338            SchedulerKind::None => Box::new(NoScheduler),
339        }
340    }
341
342    /// Record a completed trial result and update the leaderboard.
343    pub fn record_trial(&mut self, summary: TrialSummary) {
344        self.leaderboard.push(summary);
345        // Sort leaderboard by val_loss ascending (best first)
346        self.leaderboard.sort_by(|a, b| {
347            a.val_loss.partial_cmp(&b.val_loss).unwrap_or(std::cmp::Ordering::Equal)
348        });
349    }
350
351    /// Get the best trial summary from the leaderboard.
352    pub fn best_trial(&self) -> Option<&TrialSummary> {
353        self.leaderboard.first()
354    }
355
356    /// Build the final TuneResult from collected trials.
357    pub fn into_result(self, total_time_ms: u64) -> TuneResult {
358        let best_id = self.leaderboard.first().map_or(0, |t| t.id);
359        TuneResult {
360            strategy: self.config.strategy.to_string(),
361            mode: if self.config.scout { "scout".to_string() } else { "full".to_string() },
362            budget: self.config.budget,
363            trials: self.leaderboard,
364            best_trial_id: best_id,
365            total_time_ms,
366        }
367    }
368}
369
// Unit tests for this module are kept in a separate file, `classify_tuner_tests.rs`.
#[cfg(test)]
#[path = "classify_tuner_tests.rs"]
#[allow(clippy::unwrap_used)]
mod tests;