use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::optim::{HyperparameterSpace, ParameterDomain, ParameterValue};
pub use super::tune_searchers::{
AshaScheduler, GridSearcher, MedianScheduler, NoScheduler, RandomSearcher, TpeSearcher,
TuneScheduler, TuneSearcher,
};
/// Search strategy used to propose hyperparameter configurations.
///
/// Parsed case-insensitively via `FromStr` (`"bayesian"` is accepted as an
/// alias for `Tpe`); rendered as its lowercase name via `Display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TuneStrategy {
/// TPE-based search (backed by `TpeSearcher`).
Tpe,
/// Grid search over a discretized space (backed by `GridSearcher`).
Grid,
/// Uniform random sampling (backed by `RandomSearcher`).
Random,
}
impl std::str::FromStr for TuneStrategy {
    type Err = String;

    /// Parses a strategy name, ignoring case. `"bayesian"` is an alias for TPE.
    ///
    /// # Errors
    /// Returns a message listing the accepted names on unknown input.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        if lowered == "tpe" || lowered == "bayesian" {
            Ok(Self::Tpe)
        } else if lowered == "grid" {
            Ok(Self::Grid)
        } else if lowered == "random" {
            Ok(Self::Random)
        } else {
            Err(format!("Unknown strategy: {s}. Use: tpe, grid, random"))
        }
    }
}
impl std::fmt::Display for TuneStrategy {
    /// Writes the strategy's lowercase name (the same token `FromStr` accepts).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Tpe => "tpe",
            Self::Grid => "grid",
            Self::Random => "random",
        };
        f.write_str(name)
    }
}
/// Early-stopping scheduler choice for trials; parsed case-insensitively via `FromStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SchedulerKind {
/// Asynchronous successive halving (backed by `AshaScheduler`).
Asha,
/// Median-based stopping rule (backed by `MedianScheduler`).
Median,
/// No early stopping; every trial runs to completion (backed by `NoScheduler`).
None,
}
impl std::str::FromStr for SchedulerKind {
    type Err = String;

    /// Parses a scheduler name, ignoring case.
    ///
    /// # Errors
    /// Returns a message listing the accepted names on unknown input.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        match lowered.as_str() {
            "none" => Ok(Self::None),
            "median" => Ok(Self::Median),
            "asha" => Ok(Self::Asha),
            other => Err(format!("Unknown scheduler: {s}. Use: asha, median, none").replace(other, other)),
        }
    }
}
/// User-facing configuration for a tuning run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TuneConfig {
/// Number of trials to run; must be > 0 (validated in `ClassifyTuner::new`).
pub budget: usize,
/// Search strategy used to propose configurations.
pub strategy: TuneStrategy,
/// Early-stopping scheduler; ignored when `scout` is set (see `build_scheduler`).
pub scheduler: SchedulerKind,
/// Scout mode: disables early stopping so every trial runs to completion.
pub scout: bool,
// Presumably the per-trial epoch cap consumed by the training driver —
// not read directly in this module; confirm against the caller.
pub max_epochs: usize,
/// Number of target classes; must be > 0 (validated in `ClassifyTuner::new`).
pub num_classes: usize,
/// RNG seed for reproducible sampling.
pub seed: u64,
/// Optional wall-clock limit in seconds; `None` means unlimited.
/// NOTE(review): not enforced in this module — presumably checked by the run driver.
pub time_limit_secs: Option<u64>,
}
impl Default for TuneConfig {
fn default() -> Self {
Self {
budget: 10,
strategy: TuneStrategy::Tpe,
scheduler: SchedulerKind::Asha,
scout: false,
max_epochs: 20,
num_classes: 5,
seed: 42,
time_limit_secs: None,
}
}
}
/// Outcome of a single trial, as recorded on the tuner's leaderboard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrialSummary {
/// Trial identifier (reported as `best_trial_id` in `TuneResult`).
pub id: usize,
/// Validation loss; the leaderboard sorts ascending on this field.
pub val_loss: f64,
/// Validation accuracy.
pub val_accuracy: f64,
/// Training loss.
pub train_loss: f64,
/// Training accuracy.
pub train_accuracy: f64,
/// Number of epochs actually executed (may be fewer than the cap if stopped early).
pub epochs_run: usize,
/// Wall-clock duration of the trial in milliseconds.
pub time_ms: u64,
/// The sampled hyperparameter configuration for this trial.
pub config: HashMap<String, ParameterValue>,
// Free-form status label set by the driver — presumably distinguishes
// completed vs. early-stopped/failed trials; confirm against the caller.
pub status: String,
}
/// Serializable summary of a whole tuning run (see `ClassifyTuner::into_result`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TuneResult {
/// Strategy name as rendered by `TuneStrategy::Display` ("tpe"/"grid"/"random").
pub strategy: String,
/// Either "scout" or "full", depending on `TuneConfig::scout`.
pub mode: String,
/// The configured trial budget.
pub budget: usize,
/// All recorded trials, sorted by ascending `val_loss`.
pub trials: Vec<TrialSummary>,
/// `id` of the best (lowest-loss) trial, or 0 when no trials were recorded.
pub best_trial_id: usize,
/// Total wall-clock duration of the run in milliseconds.
pub total_time_ms: u64,
}
pub fn default_classify_search_space() -> HyperparameterSpace {
let mut space = HyperparameterSpace::new();
space.add(
"learning_rate",
ParameterDomain::Continuous { low: 5e-6, high: 5e-4, log_scale: true },
);
space.add("lora_rank", ParameterDomain::Discrete { low: 1, high: 16 });
space.add(
"lora_alpha_ratio",
ParameterDomain::Continuous { low: 0.5, high: 2.0, log_scale: false },
);
space.add(
"batch_size",
ParameterDomain::Categorical {
choices: vec![
"8".to_string(),
"16".to_string(),
"32".to_string(),
"64".to_string(),
"128".to_string(),
],
},
);
space.add(
"warmup_fraction",
ParameterDomain::Continuous { low: 0.01, high: 0.2, log_scale: false },
);
space.add(
"gradient_clip_norm",
ParameterDomain::Continuous { low: 0.5, high: 5.0, log_scale: false },
);
space.add(
"class_weights",
ParameterDomain::Categorical {
choices: vec![
"uniform".to_string(),
"inverse_freq".to_string(),
"sqrt_inverse".to_string(),
],
},
);
space.add(
"target_modules",
ParameterDomain::Categorical {
choices: vec!["qv".to_string(), "qkv".to_string(), "all_linear".to_string()],
},
);
space.add(
"lr_min_ratio",
ParameterDomain::Continuous { low: 0.001, high: 0.1, log_scale: true },
);
space
}
/// Extracts typed training hyperparameters from a trial's sampled config,
/// substituting a sensible default for any missing or mistyped entry.
///
/// Returns `(lr, lora_rank, lora_alpha, batch_size, warmup_fraction,
/// grad_clip_norm, class_weights_strategy, target_modules, lr_min_ratio)`.
#[allow(clippy::implicit_hasher)]
pub fn extract_trial_params(
    config: &HashMap<String, ParameterValue>,
) -> (f32, usize, f32, usize, f32, f32, String, String, f32) {
    let lr = config.get("learning_rate").and_then(ParameterValue::as_float).unwrap_or(1e-4) as f32;
    let rank_raw = config.get("lora_rank").and_then(ParameterValue::as_int).unwrap_or(4) as usize;
    // The sampled domain is 1..=16, scaled by 4 into [4, 64]. Clamp BEFORE
    // multiplying: the previous `(rank_raw * 4).clamp(4, 64)` could overflow
    // (panic in debug, wrap in release) if an out-of-domain value — e.g. a
    // negative i64 cast to a huge usize — ever appeared in the config. For
    // every non-overflowing input the result is identical.
    let rank = rank_raw.clamp(1, 16) * 4;
    let alpha_ratio =
        config.get("lora_alpha_ratio").and_then(ParameterValue::as_float).unwrap_or(1.0) as f32;
    // LoRA alpha is expressed relative to the (scaled) rank.
    let alpha = rank as f32 * alpha_ratio;
    // Batch size is stored as a categorical string choice; fall back to 32
    // if absent or unparseable.
    let batch_size = config
        .get("batch_size")
        .and_then(ParameterValue::as_str)
        .and_then(|s| s.parse::<usize>().ok())
        .unwrap_or(32);
    let warmup =
        config.get("warmup_fraction").and_then(ParameterValue::as_float).unwrap_or(0.1) as f32;
    let clip =
        config.get("gradient_clip_norm").and_then(ParameterValue::as_float).unwrap_or(1.0) as f32;
    let weights_strategy = config
        .get("class_weights")
        .and_then(ParameterValue::as_str)
        .unwrap_or("uniform")
        .to_string();
    let targets =
        config.get("target_modules").and_then(ParameterValue::as_str).unwrap_or("qv").to_string();
    let lr_min_ratio =
        config.get("lr_min_ratio").and_then(ParameterValue::as_float).unwrap_or(0.01) as f32;
    (lr, rank, alpha, batch_size, warmup, clip, weights_strategy, targets, lr_min_ratio)
}
/// Orchestrates a classification tuning run: owns the validated run
/// configuration, the search space, and a loss-sorted leaderboard of trials.
#[derive(Debug)]
pub struct ClassifyTuner {
/// Validated run configuration (see `ClassifyTuner::new`).
pub config: TuneConfig,
/// Hyperparameter search space (the default classification space).
pub space: HyperparameterSpace,
/// Recorded trials, kept sorted by ascending `val_loss` (see `record_trial`).
pub leaderboard: Vec<TrialSummary>,
}
impl ClassifyTuner {
    /// Validates `config` and builds a tuner over the default classification
    /// search space.
    ///
    /// # Errors
    /// Returns `ConfigError` when `budget` or `num_classes` is zero.
    pub fn new(config: TuneConfig) -> crate::Result<Self> {
        if config.budget == 0 {
            return Err(crate::Error::ConfigError(
                "FALSIFY-TUNE-001: budget must be > 0".to_string(),
            ));
        }
        if config.num_classes == 0 {
            return Err(crate::Error::ConfigError(
                "FALSIFY-TUNE-004: num_classes must be > 0".to_string(),
            ));
        }
        let space = default_classify_search_space();
        Ok(Self { config, space, leaderboard: Vec::new() })
    }

    /// Instantiates the searcher matching the configured strategy.
    ///
    /// TPE warms up on a third of the budget (at least 3 trials) before
    /// modeling the objective; grid search uses 3 points per continuous axis.
    pub fn build_searcher(&self) -> Box<dyn TuneSearcher> {
        let n_startup = (self.config.budget / 3).max(3);
        match self.config.strategy {
            TuneStrategy::Tpe => Box::new(TpeSearcher::new(self.space.clone(), n_startup)),
            TuneStrategy::Grid => Box::new(GridSearcher::new(self.space.clone(), 3)),
            TuneStrategy::Random => Box::new(RandomSearcher::new(self.space.clone())),
        }
    }

    /// Instantiates the configured early-stopping scheduler.
    ///
    /// Scout mode runs every trial to completion, so it overrides the
    /// configured scheduler with `NoScheduler`.
    pub fn build_scheduler(&self) -> Box<dyn TuneScheduler> {
        if self.config.scout {
            return Box::new(NoScheduler);
        }
        match self.config.scheduler {
            SchedulerKind::Asha => Box::new(AshaScheduler::new(1, 3.0)),
            SchedulerKind::Median => Box::new(MedianScheduler::new(1)),
            SchedulerKind::None => Box::new(NoScheduler),
        }
    }

    /// Records a finished trial, keeping the leaderboard sorted by ascending
    /// validation loss.
    ///
    /// Uses `f64::total_cmp` so a NaN loss sorts *after* every finite loss.
    /// The previous `partial_cmp(..).unwrap_or(Equal)` treated NaN as equal
    /// to everything, which could leave a NaN trial at the front and report
    /// it as the best. The sort is stable, so ties keep insertion order.
    pub fn record_trial(&mut self, summary: TrialSummary) {
        self.leaderboard.push(summary);
        self.leaderboard.sort_by(|a, b| a.val_loss.total_cmp(&b.val_loss));
    }

    /// Returns the trial with the lowest validation loss so far, if any.
    pub fn best_trial(&self) -> Option<&TrialSummary> {
        self.leaderboard.first()
    }

    /// Consumes the tuner and packages the run into a serializable result.
    /// `best_trial_id` falls back to 0 when no trials were recorded.
    pub fn into_result(self, total_time_ms: u64) -> TuneResult {
        let best_id = self.leaderboard.first().map_or(0, |t| t.id);
        TuneResult {
            strategy: self.config.strategy.to_string(),
            mode: if self.config.scout { "scout".to_string() } else { "full".to_string() },
            budget: self.config.budget,
            trials: self.leaderboard,
            best_trial_id: best_id,
            total_time_ms,
        }
    }
}
// Unit tests live in a sibling file wired in via `#[path]`; `unwrap` is
// permitted there as is conventional for test code.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[path = "classify_tuner_tests.rs"]
mod tests;