nt_neural/training/
mod.rs1#[cfg(feature = "candle")]
4pub mod data_loader;
5#[cfg(feature = "candle")]
6pub mod optimizer;
7#[cfg(feature = "candle")]
8pub mod trainer;
9#[cfg(feature = "candle")]
10pub mod nhits_trainer;
11
12pub mod cpu_trainer;
14pub mod simple_cpu_trainer;
15
16use serde::{Serialize, Deserialize};
17
18#[cfg(feature = "candle")]
20pub use data_loader::{DataLoader, TimeSeriesDataset};
21#[cfg(feature = "candle")]
22pub use optimizer::{LRScheduler, Optimizer, OptimizerConfig, OptimizerType};
23#[cfg(feature = "candle")]
24pub use trainer::{CheckpointMetadata, Trainer, quantile_loss};
25#[cfg(feature = "candle")]
26pub use nhits_trainer::{NHITSTrainer, NHITSTrainingConfig};
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct TrainingConfig {
31 pub batch_size: usize,
32 pub num_epochs: usize,
33 pub learning_rate: f64,
34 pub weight_decay: f64,
35 pub gradient_clip: Option<f64>,
36 pub early_stopping_patience: usize,
37 pub validation_split: f64,
38 pub mixed_precision: bool,
39}
40
41impl Default for TrainingConfig {
42 fn default() -> Self {
43 Self {
44 batch_size: 32,
45 num_epochs: 100,
46 learning_rate: 1e-3,
47 weight_decay: 1e-5,
48 gradient_clip: Some(1.0),
49 early_stopping_patience: 10,
50 validation_split: 0.2,
51 mixed_precision: true,
52 }
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct TrainingMetrics {
59 pub epoch: usize,
60 pub train_loss: f64,
61 pub val_loss: Option<f64>,
62 pub learning_rate: f64,
63 #[serde(default)]
64 pub epoch_time_seconds: f64,
65}
66
67impl Default for TrainingMetrics {
68 fn default() -> Self {
69 Self {
70 epoch: 0,
71 train_loss: 0.0,
72 val_loss: None,
73 learning_rate: 0.001,
74 epoch_time_seconds: 0.0,
75 }
76 }
77}