// nt_neural/training/mod.rs

//! Training infrastructure for neural models

// Candle-backed training stack (data loading, optimizers, trainers).
// Only compiled when the `candle` cargo feature is enabled.
#[cfg(feature = "candle")]
pub mod data_loader;
#[cfg(feature = "candle")]
pub mod optimizer;
#[cfg(feature = "candle")]
pub mod trainer;
#[cfg(feature = "candle")]
pub mod nhits_trainer;

// CPU-only training (no candle dependency)
pub mod cpu_trainer;
pub mod simple_cpu_trainer;

use serde::{Serialize, Deserialize};

// Re-export main types when candle is enabled, so callers can write
// `training::Trainer` instead of `training::trainer::Trainer`.
#[cfg(feature = "candle")]
pub use data_loader::{DataLoader, TimeSeriesDataset};
#[cfg(feature = "candle")]
pub use optimizer::{LRScheduler, Optimizer, OptimizerConfig, OptimizerType};
#[cfg(feature = "candle")]
pub use trainer::{CheckpointMetadata, Trainer, quantile_loss};
#[cfg(feature = "candle")]
pub use nhits_trainer::{NHITSTrainer, NHITSTrainingConfig};
27
/// Training configuration
///
/// Hyperparameters shared by the training backends in this module.
/// Serializable so a run's configuration can be saved alongside checkpoints.
/// See [`TrainingConfig::default`] for the baseline values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Number of samples per training batch.
    pub batch_size: usize,
    /// Total number of passes over the training set.
    pub num_epochs: usize,
    /// Initial optimizer learning rate.
    pub learning_rate: f64,
    /// L2 weight-decay coefficient.
    pub weight_decay: f64,
    /// Gradient-norm clipping threshold; `None` disables clipping.
    pub gradient_clip: Option<f64>,
    /// Epochs to wait before stopping early
    /// (presumably counted without validation improvement — confirm in trainer).
    pub early_stopping_patience: usize,
    /// Fraction of the data held out for validation (e.g. 0.2 = 20%).
    pub validation_split: f64,
    /// Whether to enable mixed-precision training
    /// (NOTE(review): consumed by the candle trainer — verify effect on CPU paths).
    pub mixed_precision: bool,
}
40
41impl Default for TrainingConfig {
42    fn default() -> Self {
43        Self {
44            batch_size: 32,
45            num_epochs: 100,
46            learning_rate: 1e-3,
47            weight_decay: 1e-5,
48            gradient_clip: Some(1.0),
49            early_stopping_patience: 10,
50            validation_split: 0.2,
51            mixed_precision: true,
52        }
53    }
54}
55
/// Training metrics
///
/// Per-epoch snapshot of training progress; serializable so metrics can
/// be logged or persisted across runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Zero-based index of the epoch this snapshot describes.
    pub epoch: usize,
    /// Average training loss for the epoch.
    pub train_loss: f64,
    /// Validation loss, or `None` when no validation pass was run.
    pub val_loss: Option<f64>,
    /// Learning rate in effect during this epoch.
    pub learning_rate: f64,
    /// Wall-clock duration of the epoch in seconds.
    /// `#[serde(default)]` keeps older serialized records (which lack
    /// this field) deserializable; missing values become 0.0.
    #[serde(default)]
    pub epoch_time_seconds: f64,
}
66
67impl Default for TrainingMetrics {
68    fn default() -> Self {
69        Self {
70            epoch: 0,
71            train_loss: 0.0,
72            val_loss: None,
73            learning_rate: 0.001,
74            epoch_time_seconds: 0.0,
75        }
76    }
77}