//! Training utilities: run configuration, session/metric tracking, and
//! re-exports of the training submodules (schedulers, checkpointing,
//! gradient tools, data loading, and more).

use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
pub mod augmentation;
pub mod backprop_efficient;
pub mod checkpoint;
pub mod curriculum;
pub mod early_stopping;
pub mod enhanced_trainer;
pub mod federated;
pub mod gradient_accumulation;
pub mod gradient_checkpointing;
pub mod gradient_clipping;
pub mod hparam_tuner;
pub mod lr_finder;
pub mod metrics_tracker;
pub mod mixed_precision;
pub mod optimized_dataloader;
pub mod pipeline_parallel;
pub mod profiler;
pub mod progress_monitor;
pub mod quantization_aware;
pub mod schedulers;
pub mod sparse_training;
pub mod tensor_parallel;
pub use augmentation::{apply_cutmix, apply_mixup, AugmentationPipeline, AugmentationType};
pub use backprop_efficient::*;
pub use checkpoint::{
    best_checkpoint, checkpoint_dir_name, latest_checkpoint, list_checkpoints, load_checkpoint,
    save_checkpoint, CheckpointConfig, CheckpointError, CheckpointManager, CheckpointMetadata,
    LrSchedulerState, OptimizerCheckpointState, OptimizerStateMetadata, ParamGroupState,
    TrainingCheckpoint,
};
pub use curriculum::{
    CompetenceSchedule, CurriculumConfig, CurriculumConfigBuilder, CurriculumLearner,
    CurriculumSchedule, CurriculumStrategy, DifficultyScorer, LossBasedScorer, StaticScorer,
};
pub use early_stopping::{
    EarlyStopping, EarlyStoppingWithState, StepResult as EarlyStoppingStepResult, StopReason,
    StoppingMode,
};
pub use enhanced_trainer::{
    EarlyStoppingConfig, EnhancedTrainer, EnhancedTrainingConfig, GradientAccumulationSettings,
    LRWarmupConfig, OperationTiming, OptimizationAnalyzer, OptimizationRecommendation,
    ProfilingConfig, ProfilingResults, ProgressConfig, RecommendationType, TrainingState,
    ValidationConfig, WarmupSchedule,
};
pub use federated::{
    clip_l2_norm, AggregationMethod, ClientSelectionStrategy, ClientUpdate,
    DifferentialPrivacyConfig, FederatedConfig, FederatedConfigBuilder, FederatedServer,
    GradientCompressionConfig, RoundStats,
};
pub use gradient_accumulation::*;
pub use gradient_checkpointing::*;
pub use gradient_clipping::{
    clip_grad_adaptive, clip_grad_agc, clip_grad_norm, clip_grad_value, grad_norm,
    AdaptiveGradClipConfig, ClipNormType, GradientClipResult,
};
pub use hparam_tuner::{
    HParamSpace, HParamTuner, HParamValue, SearchStrategy, SpaceType, TrialResult,
};
pub use lr_finder::{
    find_optimal_lr, LRFinder, LRFinderConfig, LRFinderConfigBuilder, LRFinderPoint,
    LRFinderResult, LRFinderStatus, LRScheduleType, TypedLRFinder,
};
pub use metrics_tracker::{
    BestMetric, MetricEntry, MetricGoal, MetricHistory, MetricStats, MetricsTracker,
    TrainingHistory,
};
pub use mixed_precision::*;
pub use optimized_dataloader::{
    BatchSizeOptimizationResult, BatchSizeOptimizer, LoadingStats, MemoryAwareConfig,
    MemoryAwareDataLoader, MemoryAwarePrefetchIter, OptimizedDataLoader, OptimizedLoaderConfig,
    PrefetchingIterator,
};
pub use profiler::{
    estimate_conv2d_memory, estimate_dense_flops, estimate_dense_memory, BatchStats, Bottleneck,
    EpochStats, LayerProfile, ProfilePhase, ProfileSummary, TrainingProfiler,
};
pub use progress_monitor::*;
pub use quantization_aware::*;
pub use schedulers::{
    ChainedScheduler, CosineAnnealingLR as CosineAnnealingScheduler, CosineAnnealingWarmRestarts,
    CyclicLR, CyclicMode, ExponentialLR, LRScheduler, LinearWarmup, MultiStepLR, OneCycleLR,
    PolynomialLR, ReduceOnPlateau as ReduceOnPlateauScheduler, StepLR, WarmupCosine,
};
pub use sparse_training::*;
/// Configuration for a training run.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of samples per mini-batch.
    pub batch_size: usize,
    /// Whether to shuffle the training data each epoch.
    pub shuffle: bool,
    /// Number of worker threads for data loading.
    pub num_workers: usize,
    /// Initial learning rate.
    pub learning_rate: f64,
    /// Total number of epochs to train for.
    pub epochs: usize,
    /// Verbosity level for progress output (higher prints more).
    pub verbose: usize,
    /// Optional validation settings; `None` disables validation.
    pub validation: Option<ValidationSettings>,
    /// Optional gradient accumulation settings.
    pub gradient_accumulation: Option<GradientAccumulationConfig>,
    /// Optional mixed-precision training settings.
    pub mixed_precision: Option<MixedPrecisionConfig>,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
batch_size: 32,
shuffle: true,
num_workers: 0,
learning_rate: 0.001,
epochs: 10,
verbose: 1,
validation: None,
gradient_accumulation: None,
mixed_precision: None,
}
}
}
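
// A minimal configuration sketch (the field values below are illustrative,
// not recommended defaults):
//
//     let config = TrainingConfig {
//         batch_size: 64,
//         learning_rate: 3e-4,
//         validation: Some(ValidationSettings::default()),
//         ..TrainingConfig::default()
//     };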
/// Settings controlling validation during training.
#[derive(Debug, Clone)]
pub struct ValidationSettings {
    /// Whether validation runs during training.
    pub enabled: bool,
    /// Fraction of the training data held out for validation (e.g. 0.2 for 20%).
    pub validation_split: f64,
    /// Batch size used for validation passes.
    pub batch_size: usize,
    /// Number of worker threads for validation data loading.
    pub num_workers: usize,
}
impl Default for ValidationSettings {
fn default() -> Self {
Self {
enabled: true,
validation_split: 0.2,
batch_size: 32,
num_workers: 0,
}
}
}
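
// Example: with `validation_split: 0.2`, a 1,000-sample dataset yields 800
// training and 200 validation samples (assuming the trainer applies the
// split before batching).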
/// Tracks the state and metric history of a single training run.
#[derive(Debug, Clone)]
pub struct TrainingSession<F: Float + Debug + ScalarOperand> {
    /// Per-metric history, keyed by metric name.
    pub history: HashMap<String, Vec<F>>,
    /// Learning rate the session started with.
    pub initial_learning_rate: F,
    /// Total number of epochs completed so far.
    pub epochs_trained: usize,
    /// Index of the epoch currently in progress.
    pub current_epoch: usize,
    /// Best validation score observed, if validation is enabled.
    pub best_validation_score: Option<F>,
    /// Whether training was halted early.
    pub early_stopped: bool,
}
impl<F: Float + Debug + ScalarOperand> TrainingSession<F> {
    /// Creates a new session from a training configuration.
    pub fn new(config: TrainingConfig) -> Self {
        Self {
            history: HashMap::new(),
            initial_learning_rate: F::from(config.learning_rate)
                .expect("learning_rate should be representable in the float type F"),
            epochs_trained: 0,
            current_epoch: 0,
            best_validation_score: None,
            early_stopped: false,
        }
    }

    /// Appends a value to the named metric's history, creating the entry if needed.
    pub fn add_metric(&mut self, metric_name: &str, value: F) {
        self.history
            .entry(metric_name.to_string())
            .or_default()
            .push(value);
    }

    /// Returns the recorded history for a metric, if any values have been logged.
    pub fn get_metric_history(&self, metric_name: &str) -> Option<&Vec<F>> {
        self.history.get(metric_name)
    }

    /// Returns the names of all metrics recorded so far.
    pub fn get_metric_names(&self) -> Vec<&String> {
        self.history.keys().collect()
    }

    /// Advances the session to the next epoch.
    pub fn next_epoch(&mut self) {
        self.current_epoch += 1;
        self.epochs_trained += 1;
    }

    /// Hook invoked when training completes; currently a no-op.
    pub fn finish_training(&mut self) {}

    /// Records that training was stopped early.
    pub fn early_stop(&mut self) {
        self.early_stopped = true;
    }
}
impl<F: Float + Debug + ScalarOperand> Default for TrainingSession<F> {
    fn default() -> Self {
        Self {
            history: HashMap::new(),
            initial_learning_rate: F::from(0.001)
                .expect("default learning rate should be representable in the float type F"),
            epochs_trained: 0,
            current_epoch: 0,
            best_validation_score: None,
            early_stopped: false,
        }
    }
}
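
#[cfg(test)]
mod tests {
    // A minimal sketch exercising the session bookkeeping above; the metric
    // values are illustrative, not drawn from a real training run.
    use super::*;

    #[test]
    fn records_metrics_and_epochs() {
        let mut session: TrainingSession<f64> = TrainingSession::new(TrainingConfig::default());
        assert_eq!(session.initial_learning_rate, 0.001);

        // Simulate two epochs, logging one loss value per epoch.
        for loss in [0.9, 0.7] {
            session.add_metric("train_loss", loss);
            session.next_epoch();
        }

        assert_eq!(session.epochs_trained, 2);
        assert_eq!(
            session.get_metric_history("train_loss"),
            Some(&vec![0.9, 0.7])
        );
        assert!(session.get_metric_history("val_loss").is_none());

        // Early stopping only flips the flag; history is preserved.
        session.early_stop();
        assert!(session.early_stopped);
    }
}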