mod early_stopping;
mod lr_schedule;
mod optimizer;
mod regularization;
pub use early_stopping::EarlyStopping;
pub use lr_schedule::LRSchedule;
pub use optimizer::{Optimizer, OptimizerConfig};
pub use regularization::RegularizationConfig;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{Result, TuneError};
/// Largest `batch_size` that [`TrainingConfig::validate`] accepts.
///
/// A configuration whose `batch_size` is greater than this value is rejected
/// with `TuneError::InvalidConfig`.
pub const MAX_BATCH_SIZE: usize = 8192;
/// Settings for a training run.
///
/// All fields are public, and builder-style methods on this type cover most of
/// them. Call `validate` to check the constraints noted on the individual
/// fields before using a configuration.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TrainingConfig {
/// Epoch count; `validate` rejects 0.
pub epochs: usize,
/// Batch size; `validate` requires `1..=MAX_BATCH_SIZE`.
pub batch_size: usize,
/// Optimizer settings. `learning_rate()` writes `optimizer.learning_rate`, and
/// `optimizer.optimizer` selects the per-parameter optimizer state size used by
/// `estimate_memory_mb`. Checked by `validate` via `OptimizerConfig::validate`.
pub optimizer: OptimizerConfig,
/// Learning-rate schedule; checked by `validate` via `LRSchedule::validate`.
pub lr_schedule: LRSchedule,
/// Regularization settings; checked by `validate` via `RegularizationConfig::validate`.
pub regularization: RegularizationConfig,
/// Early-stopping settings, or `None` (presumably meaning early stopping is
/// disabled; the consumer is not visible in this file).
pub early_stopping: Option<EarlyStopping>,
/// Validation split; `validate` requires a value in `0.0..=1.0`.
/// Presumably the fraction of data held out for validation (TODO confirm).
pub val_split: f32,
/// Optional seed (presumably for RNG reproducibility; the consumer is not visible here).
pub seed: Option<u64>,
/// Optional checkpoint directory, stored as a plain string.
pub checkpoint_dir: Option<String>,
/// Checkpoint interval. NOTE(review): the unit (epochs vs. steps) is not shown
/// in this file, and `validate` does not check this field.
pub checkpoint_interval: usize,
/// Logging interval. NOTE(review): the unit is not shown in this file, and
/// `validate` does not check this field.
pub log_interval: usize,
/// Gradient-accumulation step count; `validate` rejects 0. Multiplied with
/// `batch_size` by `effective_batch_size`.
pub accumulation_steps: usize,
/// Mixed-precision flag. Not read by anything in this file, including
/// `estimate_memory_mb`.
pub mixed_precision: bool,
/// Optional memory budget in MB, enforced by `check_memory_budget`.
/// `None` means no budget check is performed.
pub memory_budget_mb: Option<usize>,
}
impl Default for TrainingConfig {
    /// Baseline configuration.
    ///
    /// Scalar settings: 100 epochs, batch size 32, `val_split` 0.1, seed 42,
    /// checkpoint interval 10, log interval 100, one accumulation step, mixed
    /// precision off, no checkpoint directory and no memory budget. The optimizer,
    /// learning-rate schedule, regularization and early-stopping settings each use
    /// their own `Default` implementation (early stopping is enabled).
    fn default() -> Self {
        // Build the nested defaults first, in the same order the fields are declared.
        let optimizer = OptimizerConfig::default();
        let lr_schedule = LRSchedule::default();
        let regularization = RegularizationConfig::default();
        let early_stopping = Some(EarlyStopping::default());

        Self {
            epochs: 100,
            batch_size: 32,
            optimizer,
            lr_schedule,
            regularization,
            early_stopping,
            val_split: 0.1,
            seed: Some(42),
            checkpoint_dir: None,
            checkpoint_interval: 10,
            log_interval: 100,
            accumulation_steps: 1,
            mixed_precision: false,
            memory_budget_mb: None,
        }
    }
}
impl TrainingConfig {
    /// Preset for short runs.
    ///
    /// Starts from [`TrainingConfig::default`] and overrides: 10 epochs, batch size
    /// 64, `val_split` 0.0, no early stopping, checkpoint interval 5.
    #[must_use]
    pub fn quick() -> Self {
        Self {
            epochs: 10,
            batch_size: 64,
            val_split: 0.0,
            early_stopping: None,
            checkpoint_interval: 5,
            ..Default::default()
        }
    }

    /// Preset for long runs.
    ///
    /// 200 epochs at batch size 32, AdamW built with `OptimizerConfig::adamw(0.0001, 0.01)`,
    /// cosine annealing with warmup (500 warmup steps, `min_lr` 1e-7, `t_max` 200),
    /// `RegularizationConfig::strong()`, early stopping built with
    /// `EarlyStopping::val_loss(20)`, `val_split` 0.15, checkpoint interval 5,
    /// log interval 50 and mixed precision enabled.
    ///
    /// Every field is spelled out explicitly, so this preset does not follow
    /// later changes to [`TrainingConfig::default`].
    #[must_use]
    pub fn thorough() -> Self {
        Self {
            epochs: 200,
            batch_size: 32,
            optimizer: OptimizerConfig::adamw(0.0001, 0.01),
            lr_schedule: LRSchedule::CosineAnnealingWarmup {
                warmup_steps: 500,
                min_lr: 1e-7,
                t_max: 200,
            },
            regularization: RegularizationConfig::strong(),
            early_stopping: Some(EarlyStopping::val_loss(20)),
            val_split: 0.15,
            seed: Some(42),
            checkpoint_dir: None,
            checkpoint_interval: 5,
            log_interval: 50,
            accumulation_steps: 1,
            mixed_precision: true,
            memory_budget_mb: None,
        }
    }

    /// Sets the epoch count. Not checked here; see [`TrainingConfig::validate`].
    #[must_use]
    pub fn epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// Sets the batch size. Not checked here; see [`TrainingConfig::validate`].
    #[must_use]
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Sets the learning rate on the current optimizer configuration
    /// (`self.optimizer.learning_rate`).
    ///
    /// A later call to [`TrainingConfig::optimizer`] replaces the whole optimizer
    /// configuration, including this value.
    #[must_use]
    pub fn learning_rate(mut self, lr: f32) -> Self {
        self.optimizer.learning_rate = lr;
        self
    }

    /// Replaces the optimizer configuration.
    #[must_use]
    pub fn optimizer(mut self, optimizer: OptimizerConfig) -> Self {
        self.optimizer = optimizer;
        self
    }

    /// Replaces the learning-rate schedule.
    #[must_use]
    pub fn lr_schedule(mut self, schedule: LRSchedule) -> Self {
        self.lr_schedule = schedule;
        self
    }

    /// Replaces the regularization configuration.
    #[must_use]
    pub fn regularization(mut self, reg: RegularizationConfig) -> Self {
        self.regularization = reg;
        self
    }

    /// Stores the given early-stopping settings (`early_stopping = Some(es)`).
    #[must_use]
    pub fn early_stopping(mut self, es: EarlyStopping) -> Self {
        self.early_stopping = Some(es);
        self
    }

    /// Clears the early-stopping settings (`early_stopping = None`).
    #[must_use]
    pub fn no_early_stopping(mut self) -> Self {
        self.early_stopping = None;
        self
    }

    /// Sets the validation split. Not checked here; [`TrainingConfig::validate`]
    /// requires a value in `0.0..=1.0`.
    #[must_use]
    pub fn val_split(mut self, split: f32) -> Self {
        self.val_split = split;
        self
    }

    /// Sets the seed (`seed = Some(seed)`).
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets the checkpoint directory.
    #[must_use]
    pub fn checkpoint_dir(mut self, dir: impl Into<String>) -> Self {
        self.checkpoint_dir = Some(dir.into());
        self
    }

    /// Sets the memory budget, in MB, enforced by
    /// [`TrainingConfig::check_memory_budget`].
    #[must_use]
    pub fn memory_budget_mb(mut self, budget_mb: usize) -> Self {
        self.memory_budget_mb = Some(budget_mb);
        self
    }

    /// Checks `required_mb` against the configured memory budget.
    ///
    /// With no budget configured, every requirement is accepted. A requirement
    /// exactly equal to the budget is accepted.
    ///
    /// # Errors
    ///
    /// Returns `TuneError::MemoryBudgetExceeded` when a budget is set and
    /// `required_mb` is strictly greater than it.
    pub fn check_memory_budget(&self, required_mb: usize) -> Result<()> {
        if let Some(budget_mb) = self.memory_budget_mb {
            if required_mb > budget_mb {
                return Err(TuneError::MemoryBudgetExceeded {
                    required_mb,
                    budget_mb,
                });
            }
        }
        Ok(())
    }

    /// Estimates the memory needed for training, in MB (bytes / 1024², rounded up).
    ///
    /// The estimate is the sum of:
    /// - model weights: 4 bytes per parameter,
    /// - gradients: same size as the weights,
    /// - optimizer state: 8 bytes per parameter for Adam/AdamW, 4 bytes per
    ///   parameter for RMSprop/SGD-with-momentum, nothing for plain SGD,
    /// - batch data: `batch_size * embedding_dim * 8` bytes,
    /// - activations: twice the batch data,
    ///
    /// multiplied by 1.2 as headroom. `mixed_precision` and `accumulation_steps`
    /// are not taken into account.
    ///
    /// # Errors
    ///
    /// Returns `TuneError::InvalidConfig` if any intermediate byte count
    /// overflows `usize`.
    pub fn estimate_memory_mb(&self, model_params: usize, embedding_dim: usize) -> Result<usize> {
        let model_size_bytes = model_params.checked_mul(4).ok_or_else(|| {
            TuneError::InvalidConfig("model_params overflow in memory estimate".into())
        })?;
        // One gradient value per weight, stored at the same width.
        let gradient_size_bytes = model_size_bytes;
        let optimizer_state_bytes = match self.optimizer.optimizer {
            Optimizer::Adam | Optimizer::AdamW => model_params.checked_mul(8),
            Optimizer::RMSprop | Optimizer::SGDMomentum => model_params.checked_mul(4),
            Optimizer::SGD => Some(0),
        }
        .ok_or_else(|| {
            TuneError::InvalidConfig("optimizer state overflow in memory estimate".into())
        })?;
        let batch_data_bytes = self
            .batch_size
            .checked_mul(embedding_dim)
            .and_then(|v| v.checked_mul(8))
            .ok_or_else(|| {
                TuneError::InvalidConfig("batch data overflow in memory estimate".into())
            })?;
        let activation_bytes = batch_data_bytes.checked_mul(2).ok_or_else(|| {
            TuneError::InvalidConfig("activation overflow in memory estimate".into())
        })?;
        let total_bytes = model_size_bytes
            .checked_add(gradient_size_bytes)
            .and_then(|v| v.checked_add(optimizer_state_bytes))
            .and_then(|v| v.checked_add(batch_data_bytes))
            .and_then(|v| v.checked_add(activation_bytes))
            .ok_or_else(|| TuneError::InvalidConfig("total overflow in memory estimate".into()))?;
        // 20% headroom, bytes -> MB, rounded up; the float-to-usize cast saturates.
        Ok((total_bytes as f64 * 1.2 / (1024.0 * 1024.0)).ceil() as usize)
    }

    /// Validates the configuration.
    ///
    /// Checks, in order: `epochs > 0`, `batch_size > 0`,
    /// `batch_size <= MAX_BATCH_SIZE`, `val_split` within `0.0..=1.0` (NaN is
    /// rejected), `accumulation_steps > 0`, then delegates to the optimizer,
    /// regularization and learning-rate-schedule validators. The first failure
    /// is returned.
    ///
    /// `early_stopping`, `checkpoint_interval` and `log_interval` are not checked.
    ///
    /// # Errors
    ///
    /// Returns `TuneError::InvalidConfig` for a failed local check, or whatever
    /// error the nested `validate` calls produce.
    pub fn validate(&self) -> Result<()> {
        if self.epochs == 0 {
            return Err(TuneError::InvalidConfig("epochs must be > 0".to_string()));
        }
        if self.batch_size == 0 {
            return Err(TuneError::InvalidConfig(
                "batch_size must be > 0".to_string(),
            ));
        }
        if self.batch_size > MAX_BATCH_SIZE {
            return Err(TuneError::InvalidConfig(format!(
                "batch_size {} exceeds MAX_BATCH_SIZE {}",
                self.batch_size, MAX_BATCH_SIZE
            )));
        }
        // `contains` is false for NaN, so NaN is rejected here as well.
        if !(0.0..=1.0).contains(&self.val_split) {
            return Err(TuneError::InvalidConfig(format!(
                "val_split must be between 0 and 1, got {}",
                self.val_split
            )));
        }
        if self.accumulation_steps == 0 {
            return Err(TuneError::InvalidConfig(
                "accumulation_steps must be > 0".to_string(),
            ));
        }
        self.optimizer.validate()?;
        self.regularization.validate()?;
        self.lr_schedule.validate()?;
        Ok(())
    }

    /// Returns `batch_size * accumulation_steps`, saturating at `usize::MAX`
    /// instead of overflowing.
    #[must_use]
    pub fn effective_batch_size(&self) -> usize {
        self.batch_size.saturating_mul(self.accumulation_steps)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration validates and has the expected basic values.
    #[test]
    fn test_training_config_defaults() {
        let config = TrainingConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.epochs, 100);
        assert_eq!(config.batch_size, 32);
    }

    /// Builder methods store exactly the values they are given.
    #[test]
    fn test_training_config_builder() {
        let config = TrainingConfig::default()
            .epochs(50)
            .batch_size(64)
            .learning_rate(0.0001)
            .val_split(0.2);
        assert_eq!(config.epochs, 50);
        assert_eq!(config.batch_size, 64);
        // Exact float comparison is intended: the values are stored, not computed.
        assert_eq!(config.optimizer.learning_rate, 0.0001);
        assert_eq!(config.val_split, 0.2);
    }

    /// Zero epochs and an out-of-range validation split are rejected.
    #[test]
    fn test_training_config_validation() {
        let mut config = TrainingConfig::default();
        assert!(config.validate().is_ok());
        config.epochs = 0;
        assert!(config.validate().is_err());
        config.epochs = 100;
        config.val_split = 1.5;
        assert!(config.validate().is_err());
    }

    /// Batch size must be in `1..=MAX_BATCH_SIZE`.
    #[test]
    fn test_validation_batch_size_bounds() {
        let mut config = TrainingConfig::default();
        config.batch_size = 0;
        assert!(config.validate().is_err());
        config.batch_size = MAX_BATCH_SIZE;
        assert!(config.validate().is_ok());
        config.batch_size = MAX_BATCH_SIZE + 1;
        assert!(config.validate().is_err());
    }

    /// The validation split accepts both endpoints and rejects negatives and NaN.
    #[test]
    fn test_validation_val_split_edges() {
        let mut config = TrainingConfig::default();
        config.val_split = 0.0;
        assert!(config.validate().is_ok());
        config.val_split = 1.0;
        assert!(config.validate().is_ok());
        config.val_split = -0.1;
        assert!(config.validate().is_err());
        config.val_split = f32::NAN;
        assert!(config.validate().is_err());
    }

    /// Zero accumulation steps are rejected.
    #[test]
    fn test_validation_accumulation_steps() {
        let mut config = TrainingConfig::default();
        config.accumulation_steps = 0;
        assert!(config.validate().is_err());
    }

    /// Both presets validate; they differ from the default as expected.
    #[test]
    fn test_presets() {
        let quick = TrainingConfig::quick();
        assert!(quick.validate().is_ok());
        assert!(quick.epochs < TrainingConfig::default().epochs);
        assert!(quick.early_stopping.is_none());
        let thorough = TrainingConfig::thorough();
        assert!(thorough.validate().is_ok());
        assert!(thorough.epochs > TrainingConfig::default().epochs);
        assert!(thorough.mixed_precision);
    }

    /// The effective batch size is the product of batch size and accumulation
    /// steps, saturating on overflow.
    #[test]
    fn test_effective_batch_size() {
        let mut config = TrainingConfig::default().batch_size(16);
        assert_eq!(config.effective_batch_size(), 16);
        config.accumulation_steps = 4;
        assert_eq!(config.effective_batch_size(), 64);
        config.batch_size = usize::MAX;
        config.accumulation_steps = 2;
        assert_eq!(config.effective_batch_size(), usize::MAX);
    }

    /// The budget check passes with no budget, passes at the limit, and fails above it.
    #[test]
    fn test_check_memory_budget() {
        let unlimited = TrainingConfig::default();
        assert!(unlimited.check_memory_budget(usize::MAX).is_ok());
        let limited = TrainingConfig::default().memory_budget_mb(100);
        assert!(limited.check_memory_budget(100).is_ok());
        assert!(limited.check_memory_budget(101).is_err());
    }

    /// The memory estimate is zero for an empty model with zero-width
    /// embeddings, and reports overflow as an error instead of wrapping.
    #[test]
    fn test_estimate_memory_mb_edges() {
        let config = TrainingConfig::default();
        assert_eq!(config.estimate_memory_mb(0, 0).ok(), Some(0));
        assert!(config.estimate_memory_mb(usize::MAX, 1).is_err());
        assert!(config.estimate_memory_mb(1, usize::MAX).is_err());
    }
}