use serde::{Deserialize, Serialize};
/// Top-level training options.
///
/// Every field is optional: a key missing from the input deserializes to
/// `None`, and `None` fields are omitted when serializing. `Default` yields the
/// same all-`None` value, so a config can be built in code with
/// `TrainingConfig { epochs: Some(10), ..Default::default() }`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Number of training epochs, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub epochs: Option<usize>,
    /// Step limit, if set.
    // NOTE(review): how this interacts with `epochs` / `duration` is decided by
    // the consumer of this config, not visible here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_steps: Option<usize>,
    /// Free-form duration string, if set; its format is not validated here.
    // NOTE(review): confirm the accepted syntax with the code that parses it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration: Option<String>,
    /// Gradient accumulation / clipping options.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gradient: Option<GradientConfig>,
    /// Mixed-precision options.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mixed_precision: Option<MixedPrecisionConfig>,
    /// Distributed-training options.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub distributed: Option<DistributedConfig>,
    /// Checkpointing options.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub checkpoint: Option<CheckpointConfig>,
    /// Early-stopping options.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// Validation options.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub validation: Option<ValidationConfig>,
    /// Determinism flag, if set.
    // NOTE(review): presumably requests reproducible execution/seeding — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub deterministic: Option<bool>,
    /// Benchmark flag, if set.
    // NOTE(review): presumably toggles backend autotuning — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub benchmark: Option<bool>,
    /// Curriculum stages (see `crate::config::CurriculumStage`), if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub curriculum: Option<Vec<crate::config::CurriculumStage>>,
}
/// Gradient accumulation and clipping options.
///
/// All fields are optional: missing keys deserialize to `None` and `None`
/// fields are omitted when serializing. `Default` produces the all-`None`
/// value. Only `PartialEq` (not `Eq`) is derived because of the `f64` fields.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct GradientConfig {
    /// Gradient accumulation step count, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub accumulation_steps: Option<usize>,
    /// Norm threshold for gradient clipping, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub clip_norm: Option<f64>,
    /// Value threshold for gradient clipping, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub clip_value: Option<f64>,
}
/// Mixed-precision options.
///
/// `enabled` is a required key: it has no `#[serde(default)]`, so
/// deserialization fails if it is absent. The string fields are free-form and
/// not validated here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MixedPrecisionConfig {
    /// Whether mixed precision is turned on (required).
    pub enabled: bool,
    /// Reduced-precision dtype name, if set.
    // NOTE(review): presumably values like "fp16" / "bf16" — confirm with consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dtype: Option<String>,
    /// Loss-scaling setting as a free-form string, if set.
    // NOTE(review): accepted values (e.g. "dynamic" vs. a number) are not visible here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub loss_scale: Option<String>,
}
/// Distributed-training options.
///
/// `strategy` is a required key (no `#[serde(default)]`). The remaining fields
/// are optional and omitted from serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Strategy name as a free-form string (required; not validated here).
    // NOTE(review): presumably names like "ddp" — confirm the accepted set.
    pub strategy: String,
    /// World size (number of participating processes), if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub world_size: Option<usize>,
    /// `gradient_as_bucket_view` flag, if set.
    // NOTE(review): presumably forwarded to the DDP option of the same name — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gradient_as_bucket_view: Option<bool>,
    /// `find_unused_parameters` flag, if set.
    // NOTE(review): presumably forwarded to the DDP option of the same name — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub find_unused_parameters: Option<bool>,
}
/// Checkpointing options.
///
/// All fields are optional: missing keys deserialize to `None` and `None`
/// fields are omitted when serializing. `Default` produces the all-`None` value.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct CheckpointConfig {
    /// Checkpoint interval, if set.
    // NOTE(review): unit (steps vs. epochs) is not visible here — confirm with caller.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub save_every: Option<usize>,
    /// Number of most recent checkpoints to retain, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub keep_last: Option<usize>,
    /// Whether to also keep the best checkpoint, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub save_best: Option<bool>,
    /// Name of the metric used to judge "best", if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metric: Option<String>,
    /// Comparison direction for `metric` as a free-form string, if set.
    // NOTE(review): presumably "min" / "max"; not validated here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mode: Option<String>,
}
/// Early-stopping options.
///
/// `enabled` is a required key (no `#[serde(default)]`). Only `PartialEq`
/// (not `Eq`) is derived because `min_delta` is an `f64`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
    /// Whether early stopping is turned on (required).
    pub enabled: bool,
    /// Name of the monitored metric, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metric: Option<String>,
    /// Patience count, if set.
    // NOTE(review): unit (evaluations vs. epochs) is not visible here — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub patience: Option<usize>,
    /// Minimum change regarded as an improvement, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub min_delta: Option<f64>,
    /// Comparison direction for `metric` as a free-form string, if set.
    // NOTE(review): presumably "min" / "max"; not validated here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mode: Option<String>,
}
/// Validation options.
///
/// All fields are optional: missing keys deserialize to `None` and `None`
/// fields are omitted when serializing. `Default` produces the all-`None` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ValidationConfig {
    /// Validation interval, if set.
    // NOTE(review): unit (presumably steps) is not visible here — confirm with caller.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub every: Option<usize>,
    /// Whether to validate once per epoch, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub every_epoch: Option<bool>,
    /// Names of metrics to compute, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metrics: Option<Vec<String>>,
    /// Cross-validation options, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cross_validation: Option<CrossValidationConfig>,
}
/// Cross-validation options.
///
/// `folds` is a required key (no `#[serde(default)]`). No `Default` is
/// derived because there is no meaningful default fold count. The boolean
/// options are omitted from serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CrossValidationConfig {
    /// Number of folds (required). No range check is performed here.
    pub folds: usize,
    /// Whether folds are stratified, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stratified: Option<bool>,
    /// Whether data is shuffled before splitting, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shuffle: Option<bool>,
}