use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use crate::error::ConfigError;
/// Settings for a training run: input dimensions, model sizes, optimiser
/// schedule, loss weights, checkpointing and runtime options.
///
/// The struct derives serde's `Serialize`/`Deserialize`; `TrainingConfig::from_json`
/// and `TrainingConfig::to_json` use that to load and store it as JSON.
/// Constraints quoted on the fields below are the ones enforced by
/// `TrainingConfig::validate`; fields without a quoted constraint are not
/// checked there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Target subcarrier count (the `target` argument of `for_subcarriers`). Must be > 0.
pub num_subcarriers: usize,
/// Native subcarrier count (the `native` argument of `for_subcarriers`). Must be > 0.
/// When it differs from `num_subcarriers`, `needs_subcarrier_interp` returns `true`.
pub native_subcarriers: usize,
/// Transmit-antenna count. Must be > 0.
pub num_antennas_tx: usize,
/// Receive-antenna count. Must be > 0.
pub num_antennas_rx: usize,
/// Number of frames per window. Must be > 0.
pub window_frames: usize,
/// Heatmap size (presumably the side length of a square map — confirm). Must be > 0.
pub heatmap_size: usize,
/// Number of keypoints. Must be > 0.
pub num_keypoints: usize,
/// Number of body parts. Must be > 0.
pub num_body_parts: usize,
/// Backbone channel count. Must be > 0.
pub backbone_channels: usize,
/// Batch size. Must be > 0.
pub batch_size: usize,
/// Learning rate. Must be > 0.0.
pub learning_rate: f64,
/// Weight decay. Must be >= 0.0.
pub weight_decay: f64,
/// Total number of epochs. Must be > 0.
pub num_epochs: usize,
/// Number of warm-up epochs. Must be < `num_epochs`.
pub warmup_epochs: usize,
/// Learning-rate milestones: strictly increasing, each in `[1, num_epochs]`.
/// An empty list passes validation.
pub lr_milestones: Vec<usize>,
/// Learning-rate decay factor (presumably applied at each milestone — confirm).
/// Must be in (0.0, 1.0).
pub lr_gamma: f64,
/// Gradient-clipping norm. Must be > 0.0.
pub grad_clip_norm: f64,
/// Loss weight `kp`. Must be >= 0.0; at least one of the three `lambda_*` weights must be > 0.0.
pub lambda_kp: f64,
/// Loss weight `dp`. Must be >= 0.0.
pub lambda_dp: f64,
/// Loss weight `tr`. Must be >= 0.0.
pub lambda_tr: f64,
/// Validation interval in epochs. Must be > 0.
pub val_every_epochs: usize,
/// Early-stopping patience (presumably counted in validation rounds or epochs — confirm).
pub early_stopping_patience: usize,
/// Directory for checkpoints.
pub checkpoint_dir: PathBuf,
/// Directory for logs.
pub log_dir: PathBuf,
/// Number of top checkpoints to keep. Must be > 0.
pub save_top_k: usize,
/// Whether to use a GPU.
pub use_gpu: bool,
/// GPU device id (presumably only consulted when `use_gpu` is set — confirm).
pub gpu_device_id: i64,
/// Number of workers (presumably data-loading workers — confirm).
pub num_workers: usize,
/// Random seed.
pub seed: u64,
}
impl Default for TrainingConfig {
fn default() -> Self {
TrainingConfig {
num_subcarriers: 56,
native_subcarriers: 114,
num_antennas_tx: 3,
num_antennas_rx: 3,
window_frames: 100,
heatmap_size: 56,
num_keypoints: 17,
num_body_parts: 24,
backbone_channels: 256,
batch_size: 8,
learning_rate: 1e-3,
weight_decay: 1e-4,
num_epochs: 50,
warmup_epochs: 5,
lr_milestones: vec![30, 45],
lr_gamma: 0.1,
grad_clip_norm: 1.0,
lambda_kp: 0.3,
lambda_dp: 0.6,
lambda_tr: 0.1,
val_every_epochs: 1,
early_stopping_patience: 10,
checkpoint_dir: PathBuf::from("checkpoints"),
log_dir: PathBuf::from("logs"),
save_top_k: 3,
use_gpu: false,
gpu_device_id: 0,
num_workers: 4,
seed: 42,
}
}
}
impl TrainingConfig {
pub fn from_json(path: &Path) -> Result<Self, ConfigError> {
let contents = std::fs::read_to_string(path).map_err(|source| ConfigError::FileRead {
path: path.to_path_buf(),
source,
})?;
let cfg: TrainingConfig = serde_json::from_str(&contents)
.map_err(|e| ConfigError::invalid_value("(file)", e.to_string()))?;
cfg.validate()?;
Ok(cfg)
}
pub fn to_json(&self, path: &Path) -> Result<(), ConfigError> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|source| ConfigError::FileRead {
path: parent.to_path_buf(),
source,
})?;
}
let json = serde_json::to_string_pretty(self)
.map_err(|e| ConfigError::invalid_value("(serialization)", e.to_string()))?;
std::fs::write(path, json).map_err(|source| ConfigError::FileRead {
path: path.to_path_buf(),
source,
})?;
Ok(())
}
#[must_use]
pub fn for_subcarriers(native: usize, target: usize) -> Self {
Self {
native_subcarriers: native,
num_subcarriers: target,
..Self::default()
}
}
#[must_use]
pub fn mmfi() -> Self {
Self::default()
}
#[must_use]
pub fn ht40_192() -> Self {
Self::for_subcarriers(192, 56)
}
#[must_use]
pub fn multiband_168() -> Self {
Self::for_subcarriers(168, 56)
}
pub fn needs_subcarrier_interp(&self) -> bool {
self.native_subcarriers != self.num_subcarriers
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.num_subcarriers == 0 {
return Err(ConfigError::invalid_value("num_subcarriers", "must be > 0"));
}
if self.native_subcarriers == 0 {
return Err(ConfigError::invalid_value(
"native_subcarriers",
"must be > 0",
));
}
if self.num_antennas_tx == 0 {
return Err(ConfigError::invalid_value("num_antennas_tx", "must be > 0"));
}
if self.num_antennas_rx == 0 {
return Err(ConfigError::invalid_value("num_antennas_rx", "must be > 0"));
}
if self.window_frames == 0 {
return Err(ConfigError::invalid_value("window_frames", "must be > 0"));
}
if self.heatmap_size == 0 {
return Err(ConfigError::invalid_value("heatmap_size", "must be > 0"));
}
if self.num_keypoints == 0 {
return Err(ConfigError::invalid_value("num_keypoints", "must be > 0"));
}
if self.num_body_parts == 0 {
return Err(ConfigError::invalid_value("num_body_parts", "must be > 0"));
}
if self.backbone_channels == 0 {
return Err(ConfigError::invalid_value(
"backbone_channels",
"must be > 0",
));
}
if self.batch_size == 0 {
return Err(ConfigError::invalid_value("batch_size", "must be > 0"));
}
if self.learning_rate <= 0.0 {
return Err(ConfigError::invalid_value(
"learning_rate",
"must be > 0.0",
));
}
if self.weight_decay < 0.0 {
return Err(ConfigError::invalid_value(
"weight_decay",
"must be >= 0.0",
));
}
if self.grad_clip_norm <= 0.0 {
return Err(ConfigError::invalid_value(
"grad_clip_norm",
"must be > 0.0",
));
}
if self.num_epochs == 0 {
return Err(ConfigError::invalid_value("num_epochs", "must be > 0"));
}
if self.warmup_epochs >= self.num_epochs {
return Err(ConfigError::invalid_value(
"warmup_epochs",
"must be < num_epochs",
));
}
let mut prev = 0usize;
for &m in &self.lr_milestones {
if m == 0 || m > self.num_epochs {
return Err(ConfigError::invalid_value(
"lr_milestones",
"each milestone must be in [1, num_epochs]",
));
}
if m <= prev {
return Err(ConfigError::invalid_value(
"lr_milestones",
"milestones must be strictly increasing",
));
}
prev = m;
}
if self.lr_gamma <= 0.0 || self.lr_gamma >= 1.0 {
return Err(ConfigError::invalid_value(
"lr_gamma",
"must be in (0.0, 1.0)",
));
}
if self.lambda_kp < 0.0 {
return Err(ConfigError::invalid_value("lambda_kp", "must be >= 0.0"));
}
if self.lambda_dp < 0.0 {
return Err(ConfigError::invalid_value("lambda_dp", "must be >= 0.0"));
}
if self.lambda_tr < 0.0 {
return Err(ConfigError::invalid_value("lambda_tr", "must be >= 0.0"));
}
let total_weight = self.lambda_kp + self.lambda_dp + self.lambda_tr;
if total_weight <= 0.0 {
return Err(ConfigError::invalid_value(
"lambda_kp / lambda_dp / lambda_tr",
"at least one loss weight must be > 0.0",
));
}
if self.val_every_epochs == 0 {
return Err(ConfigError::invalid_value(
"val_every_epochs",
"must be > 0",
));
}
if self.save_top_k == 0 {
return Err(ConfigError::invalid_value("save_top_k", "must be > 0"));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// The built-in defaults must satisfy every validation rule.
    #[test]
    fn default_config_is_valid() {
        TrainingConfig::default()
            .validate()
            .expect("default config should be valid");
    }

    /// Writing a config to disk and reading it back preserves its values.
    #[test]
    fn json_round_trip() {
        let dir = tempdir().unwrap();
        let file = dir.path().join("config.json");

        let original = TrainingConfig::default();
        original.to_json(&file).expect("serialization should succeed");
        let loaded = TrainingConfig::from_json(&file).expect("deserialization should succeed");

        assert_eq!(loaded.num_subcarriers, original.num_subcarriers);
        assert_eq!(loaded.batch_size, original.batch_size);
        assert_eq!(loaded.seed, original.seed);
        assert_eq!(loaded.lr_milestones, original.lr_milestones);
    }

    /// A target subcarrier count of zero is rejected.
    #[test]
    fn zero_subcarriers_is_invalid() {
        let cfg = TrainingConfig {
            num_subcarriers: 0,
            ..TrainingConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    /// A negative learning rate is rejected.
    #[test]
    fn negative_learning_rate_is_invalid() {
        let cfg = TrainingConfig {
            learning_rate: -0.001,
            ..TrainingConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    /// Warm-up must be strictly shorter than the whole run.
    #[test]
    fn warmup_equal_to_epochs_is_invalid() {
        let base = TrainingConfig::default();
        let cfg = TrainingConfig {
            warmup_epochs: base.num_epochs,
            ..base
        };
        assert!(cfg.validate().is_err());
    }

    /// Milestones that go backwards are rejected.
    #[test]
    fn non_increasing_milestones_are_invalid() {
        let cfg = TrainingConfig {
            lr_milestones: vec![30, 20],
            ..TrainingConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    /// A milestone past the last epoch is rejected.
    #[test]
    fn milestone_beyond_epochs_is_invalid() {
        let base = TrainingConfig::default();
        let cfg = TrainingConfig {
            lr_milestones: vec![30, base.num_epochs + 1],
            ..base
        };
        assert!(cfg.validate().is_err());
    }

    /// At least one loss weight has to be positive.
    #[test]
    fn all_zero_loss_weights_are_invalid() {
        let cfg = TrainingConfig {
            lambda_kp: 0.0,
            lambda_dp: 0.0,
            lambda_tr: 0.0,
            ..TrainingConfig::default()
        };
        assert!(cfg.validate().is_err());
    }

    /// Interpolation is needed exactly when native and target counts differ.
    #[test]
    fn needs_subcarrier_interp_when_counts_differ() {
        let cfg = TrainingConfig {
            num_subcarriers: 56,
            native_subcarriers: 114,
            ..TrainingConfig::default()
        };
        assert!(cfg.needs_subcarrier_interp());

        let cfg = TrainingConfig {
            native_subcarriers: 56,
            ..cfg
        };
        assert!(!cfg.needs_subcarrier_interp());
    }

    /// Pins the documented default values so accidental changes are caught.
    #[test]
    fn config_fields_have_expected_defaults() {
        let cfg = TrainingConfig::default();

        // Dimensions.
        assert_eq!(cfg.num_subcarriers, 56);
        assert_eq!(cfg.native_subcarriers, 114);
        assert_eq!(cfg.num_antennas_tx, 3);
        assert_eq!(cfg.num_antennas_rx, 3);
        assert_eq!(cfg.window_frames, 100);
        assert_eq!(cfg.heatmap_size, 56);
        assert_eq!(cfg.num_keypoints, 17);
        assert_eq!(cfg.num_body_parts, 24);

        // Optimiser and schedule.
        assert_eq!(cfg.batch_size, 8);
        assert!((cfg.learning_rate - 1e-3).abs() < 1e-10);
        assert_eq!(cfg.num_epochs, 50);
        assert_eq!(cfg.warmup_epochs, 5);
        assert_eq!(cfg.lr_milestones, vec![30, 45]);
        assert!((cfg.lr_gamma - 0.1).abs() < 1e-10);

        // Loss weights.
        assert!((cfg.lambda_kp - 0.3).abs() < 1e-10);
        assert!((cfg.lambda_dp - 0.6).abs() < 1e-10);
        assert!((cfg.lambda_tr - 0.1).abs() < 1e-10);

        // Runtime.
        assert_eq!(cfg.seed, 42);
    }
}