use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, NumAssign};
use std::collections::HashMap;
use std::fmt::Debug;
pub mod backprop_efficient;
pub mod gradient_accumulation;
pub mod gradient_checkpointing;
pub mod mixed_precision;
pub mod progress_monitor;
pub mod quantization_aware;
pub mod sparse_training;
pub mod enhanced_trainer;
pub mod optimized_dataloader;
pub use backprop_efficient::*;
pub use gradient_accumulation::*;
pub use gradient_checkpointing::*;
pub use mixed_precision::*;
pub use progress_monitor::*;
pub use quantization_aware::*;
pub use sparse_training::*;
pub use enhanced_trainer::{
EarlyStoppingConfig, EnhancedTrainer, EnhancedTrainingConfig, GradientAccumulationSettings,
LRWarmupConfig, OperationTiming, OptimizationAnalyzer, OptimizationRecommendation,
ProfilingConfig, ProfilingResults, ProgressConfig, RecommendationType, TrainingState,
ValidationConfig, WarmupSchedule,
};
pub use optimized_dataloader::{
BatchSizeOptimizationResult, BatchSizeOptimizer, LoadingStats, OptimizedDataLoader,
OptimizedLoaderConfig, PrefetchingIterator,
};
/// Top-level configuration for a training run: batching, data loading,
/// schedule length, and optional feature toggles.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of samples per training batch.
    pub batch_size: usize,
    /// Whether to shuffle the training data each epoch.
    pub shuffle: bool,
    /// Number of data-loading worker threads (0 presumably means loading on
    /// the calling thread — confirm against the data-loader implementation).
    pub num_workers: usize,
    /// Initial learning rate (converted to the model's float type by
    /// `TrainingSession::new`).
    pub learning_rate: f64,
    /// Total number of epochs to train for.
    pub epochs: usize,
    /// Verbosity level (0 = silent; exact scale of higher values — TODO confirm).
    pub verbose: usize,
    /// Optional validation settings; `None` disables validation.
    pub validation: Option<ValidationSettings>,
    /// Optional gradient-accumulation settings; `None` disables accumulation.
    pub gradient_accumulation: Option<GradientAccumulationConfig>,
    /// Optional mixed-precision settings; `None` trains in full precision.
    pub mixed_precision: Option<MixedPrecisionConfig>,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
batch_size: 32,
shuffle: true,
num_workers: 0,
learning_rate: 0.001,
epochs: 10,
verbose: 1,
validation: None,
gradient_accumulation: None,
mixed_precision: None,
}
}
}
/// Settings controlling how validation is performed during training.
#[derive(Debug, Clone)]
pub struct ValidationSettings {
    /// Whether validation is run at all.
    pub enabled: bool,
    /// Fraction of the training data held out for validation (e.g. 0.2 for
    /// 20%; presumably expected in (0, 1) — confirm against the trainer).
    pub validation_split: f64,
    /// Batch size used for validation passes.
    pub batch_size: usize,
    /// Number of worker threads for the validation data loader.
    pub num_workers: usize,
}
impl Default for ValidationSettings {
fn default() -> Self {
Self {
enabled: true,
validation_split: 0.2,
batch_size: 32,
num_workers: 0,
}
}
}
/// Mutable state tracked across one training run: per-metric history, epoch
/// counters, and early-stopping status.
///
/// NOTE(review): trait bounds on the struct definition itself are stricter
/// than strictly necessary (they could live only on the impl blocks), but
/// they are kept here to avoid changing the public type signature.
#[derive(Debug, Clone)]
pub struct TrainingSession<F: Float + Debug + ScalarOperand> {
    /// Metric name -> recorded values, one entry per `add_metric` call,
    /// in call order.
    pub history: HashMap<String, Vec<F>>,
    /// Learning rate the session was created with (from `TrainingConfig`).
    pub initial_learning_rate: F,
    /// Total number of epochs completed (incremented by `next_epoch`).
    pub epochs_trained: usize,
    /// Index of the current epoch (incremented by `next_epoch`).
    pub current_epoch: usize,
    /// Best validation score observed so far, if any has been recorded.
    pub best_validation_score: Option<F>,
    /// Set to `true` by `early_stop` when training was halted early.
    pub early_stopped: bool,
}
impl<F: Float + Debug + ScalarOperand> TrainingSession<F> {
    /// Builds a fresh session from `config`, seeding the initial learning
    /// rate and leaving the history empty and every counter at zero.
    ///
    /// # Panics
    /// Panics if `config.learning_rate` cannot be represented in `F`.
    pub fn new(config: TrainingConfig) -> Self {
        let initial_learning_rate =
            F::from(config.learning_rate).expect("Failed to convert to float");
        Self {
            history: HashMap::new(),
            initial_learning_rate,
            epochs_trained: 0,
            current_epoch: 0,
            best_validation_score: None,
            early_stopped: false,
        }
    }

    /// Appends `value` to the history for `metric_name`, creating the series
    /// on first use.
    pub fn add_metric(&mut self, metric_name: &str, value: F) {
        let series = self.history.entry(metric_name.to_string()).or_default();
        series.push(value);
    }

    /// Returns the recorded values for `metric_name`, or `None` if that
    /// metric was never added.
    pub fn get_metric_history(&self, metric_name: &str) -> Option<&Vec<F>> {
        self.history.get(metric_name)
    }

    /// Returns the names of all metrics recorded so far (unspecified order,
    /// as iterated from the underlying `HashMap`).
    pub fn get_metric_names(&self) -> Vec<&String> {
        self.history.keys().collect()
    }

    /// Advances to the next epoch, bumping both the completed-epoch counter
    /// and the current-epoch index.
    pub fn next_epoch(&mut self) {
        self.epochs_trained += 1;
        self.current_epoch += 1;
    }

    /// Marks training as finished. Currently a no-op; kept as a hook for
    /// end-of-training bookkeeping.
    pub fn finish_training(&mut self) {}

    /// Records that training was stopped early.
    pub fn early_stop(&mut self) {
        self.early_stopped = true;
    }
}
impl<F: Float + Debug + ScalarOperand> Default for TrainingSession<F> {
fn default() -> Self {
Self {
history: HashMap::new(),
initial_learning_rate: F::from(0.001).expect("Failed to convert constant to float"),
epochs_trained: 0,
current_epoch: 0,
best_validation_score: None,
early_stopped: false,
}
}
}