High-level training loop
This module provides a complete training framework with:
- Loss functions (MSE, Cross-Entropy, Huber/SmoothL1, L1)
- Evaluation metrics (Accuracy, Precision, Recall, F1, R², MAE, RMSE)
- Curriculum learning (Linear, Tiered, Adaptive)
- Trainer abstraction
- Training configuration
- Metrics tracking
- Checkpoint support
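The loss functions listed above can be made concrete with a minimal standalone sketch. It computes MSE, L1, Huber and single-example cross-entropy on plain slices. It does not use entrenar's `MSELoss`, `HuberLoss` or `LossFn` types, whose signatures are not shown on this page. The function names are hypothetical, and the Huber formula assumes the common convention (quadratic below `delta`, linear above).

```rust
// Standalone sketch; function names are hypothetical, not entrenar's API.

/// Mean squared error: mean of (pred - target)^2.
fn mse(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(pred.len(), target.len());
    let n = pred.len() as f32;
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f32>() / n
}

/// L1 loss (mean absolute error).
fn l1(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(pred.len(), target.len());
    let n = pred.len() as f32;
    pred.iter().zip(target).map(|(p, t)| (p - t).abs()).sum::<f32>() / n
}

/// Huber loss (assumed convention): 0.5 r^2 for |r| <= delta, delta (|r| - 0.5 delta) beyond.
fn huber(pred: &[f32], target: &[f32], delta: f32) -> f32 {
    assert_eq!(pred.len(), target.len());
    let n = pred.len() as f32;
    pred.iter()
        .zip(target)
        .map(|(p, t)| {
            let r = (p - t).abs();
            if r <= delta { 0.5 * r * r } else { delta * (r - 0.5 * delta) }
        })
        .sum::<f32>()
        / n
}

/// Cross-entropy for one example: -log softmax(logits)[class],
/// computed with the log-sum-exp trick for numerical stability.
fn cross_entropy(logits: &[f32], class: usize) -> f32 {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = max + logits.iter().map(|z| (z - max).exp()).sum::<f32>().ln();
    log_sum_exp - logits[class]
}

fn main() {
    let pred = [1.0_f32, 2.0, 5.0];
    let target = [1.0_f32, 3.0, 2.0];
    println!("MSE   = {:.4}", mse(&pred, &target)); // 3.3333
    println!("L1    = {:.4}", l1(&pred, &target)); // 1.3333
    println!("Huber = {:.4}", huber(&pred, &target, 1.0)); // 1.0000
    println!("CE    = {:.4}", cross_entropy(&[2.0, 1.0, 0.1], 0));
}
```

Huber behaves like MSE for small residuals and like L1 for large ones, which is why it is often preferred when targets contain outliers.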
§Example
```rust
use entrenar::train::{Trainer, TrainConfig, Batch};
use entrenar::optim::Adam;
use entrenar::Tensor;

// Parameters to optimize: a single 10-element tensor of zeros.
let params = vec![Tensor::zeros(10, true)];
// Adam hyperparameters: learning rate, beta1, beta2, epsilon.
let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
let config = TrainConfig::default();
let mut trainer = Trainer::new(params, Box::new(optimizer), config);

// Training loop (`dataloader` is not defined in this example)
// for epoch in 0..10 {
//     let loss = trainer.train_epoch(&dataloader);
//     println!("Epoch {}: loss={:.4}", epoch, loss);
// }
```
Re-exports§
pub use callback::CallbackAction;
pub use callback::CallbackContext;
pub use callback::CallbackManager;
pub use callback::CheckpointCallback;
pub use callback::EarlyStopping;
pub use callback::ExplainMethod;
pub use callback::ExplainabilityCallback;
pub use callback::FeatureImportanceResult;
pub use callback::LRSchedulerCallback;
pub use callback::MonitorCallback;
pub use callback::ProgressCallback;
pub use callback::TrainerCallback;
pub use device::resolve_device;
pub use device::Device;
pub use device::DeviceError;
pub use pretrain::check_non_divergence;
pub use pretrain::check_numerical_stability;
pub use pretrain::CheckpointFn;
pub use pretrain::EpochArtifact;
pub use pretrain::EpochMetadata;
pub use pretrain::LinearDecaySynthetic;
pub use pretrain::NanAtStepSynthetic;
pub use pretrain::PretrainAbort;
pub use pretrain::PretrainConfig;
pub use pretrain::PretrainLoop;
pub use pretrain::RunStatus;
pub use pretrain::ScriptedVal;
pub use pretrain::StepFn;
pub use pretrain::StepMetrics;
pub use pretrain::TrainingRegime;
pub use pretrain::ValFn;
pub use pretrain::DIVERGENCE_RATIO_LIMIT;
pub use pretrain::EPOCH_ZERO_VAL_LOSS_LIMIT;
pub use transformer_trainer::distributed_checkpoint::checkpoint_path;
pub use transformer_trainer::distributed_checkpoint::hash_weights;
pub use transformer_trainer::distributed_checkpoint::should_save_checkpoint;
pub use transformer_trainer::distributed_checkpoint::verify_weight_consistency;
pub use transformer_trainer::distributed_checkpoint::CheckpointPhase;
pub use transformer_trainer::grad_accumulator::BLOCK_GRAD_COMPONENTS;
pub use transformer_trainer::perplexity;
pub use transformer_trainer::shard_batches;
pub use transformer_trainer::tokens_per_second;
pub use transformer_trainer::BlockGradientSet;
pub use transformer_trainer::CausalMaskType;
pub use transformer_trainer::ColumnParallelShard;
pub use transformer_trainer::CudaTransformerTrainer;
pub use transformer_trainer::DistributedBackend;
pub use transformer_trainer::DistributedCheckpointCoordinator;
pub use transformer_trainer::DistributedRole;
pub use transformer_trainer::DistributedTrainConfig;
pub use transformer_trainer::ElasticCoordinator;
pub use transformer_trainer::LMBatch;
pub use transformer_trainer::OptimizerShard;
pub use transformer_trainer::PerBlockGradientAccumulator;
pub use transformer_trainer::PipelineAction;
pub use transformer_trainer::PipelineActivationBuffer;
pub use transformer_trainer::PipelineStage;
pub use transformer_trainer::RingAttentionSchedule;
pub use transformer_trainer::RowParallelShard;
pub use transformer_trainer::SequenceParallelConfig;
pub use transformer_trainer::SpCommCost;
pub use transformer_trainer::TensorParallelConfig;
pub use transformer_trainer::TpCommCost;
pub use transformer_trainer::TransformerTrainConfig;
pub use transformer_trainer::TransformerTrainer;
pub use transformer_trainer::ZeroShardMap;
pub use transformer_trainer::DistributedComm;
pub use transformer_trainer::DistributedCudaTrainer;
pub use tui::format_duration;
pub use tui::sparkline;
pub use tui::sparkline_range;
pub use tui::Alert;
pub use tui::AlertLevel;
pub use tui::AndonSystem;
pub use tui::DashboardLayout;
pub use tui::FeatureImportanceChart;
pub use tui::GradientFlowHeatmap;
pub use tui::KalmanEta;
pub use tui::LossCurveDisplay;
pub use tui::MetricsBuffer;
pub use tui::MonitorConfig;
pub use tui::ProgressBar;
pub use tui::ReferenceCurve;
pub use tui::RefreshPolicy;
pub use tui::SeriesSummaryTuple;
pub use tui::TerminalCapabilities;
pub use tui::TerminalMode;
pub use tui::TerminalMonitorCallback;
pub use tui::SPARK_CHARS;
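Among the re-exports are `perplexity` and `tokens_per_second` from `transformer_trainer`. Their signatures are not shown on this page, so the following is only a standalone sketch of the standard quantities they presumably report. It assumes perplexity is the exponential of the mean per-token cross-entropy measured in nats, and that throughput is tokens divided by wall-clock seconds. The helper names `perplexity_from_losses` and `tokens_per_sec` are hypothetical.

```rust
use std::time::Duration;

// Hypothetical helpers; not entrenar's `perplexity` / `tokens_per_second` signatures.

/// Perplexity = exp(mean per-token cross-entropy), with losses in nats (natural log).
fn perplexity_from_losses(token_losses: &[f64]) -> f64 {
    assert!(!token_losses.is_empty(), "need at least one token loss");
    let mean = token_losses.iter().sum::<f64>() / token_losses.len() as f64;
    mean.exp()
}

/// Throughput as tokens processed per wall-clock second.
fn tokens_per_sec(tokens: u64, elapsed: Duration) -> f64 {
    tokens as f64 / elapsed.as_secs_f64()
}

fn main() {
    // A model that is uniform over a 50-token vocabulary has loss ln(50) per token,
    // so its perplexity is 50.
    let uniform = vec![50f64.ln(); 8];
    println!("perplexity = {:.1}", perplexity_from_losses(&uniform)); // 50.0
    println!("throughput = {} tok/s", tokens_per_sec(4096, Duration::from_millis(500))); // 8192
}
```

The uniform-model case is a useful sanity check: perplexity equals the effective number of equally likely choices per token.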
Modules§
- callback - Callback system for training events
- device - Device selector for the training backend (apr pretrain)
- gputrain_003 - FALSIFY-GPUTRAIN-003 / INV-GPUTRAIN-003 / GATE-GPUTRAIN-003: algorithm-level PARTIAL discharge
- gputrain_004 - FALSIFY-GPUTRAIN-004 / INV-GPUTRAIN-004 / GATE-GPUTRAIN-005: algorithm-level PARTIAL discharge
- gputrain_005 - FALSIFY-GPUTRAIN-005 / INV-GPUTRAIN-005 / GATE-GPUTRAIN-004: algorithm-level PARTIAL discharge
- gputrain_006 - FALSIFY-GPUTRAIN-006 / INV-GPUTRAIN-006: empirical reproducibility discharge
- gputrain_007 - FALSIFY-GPUTRAIN-007 / INV-GPUTRAIN-007 / GATE-GPUTRAIN-006: algorithm-level PARTIAL discharge
- pretrain - Pretraining loop driver for SHIP-TWO-001 MODEL-2 (albor 370M)
- pretrain_real - Real-corpus StepFn/ValFn for MODEL-2 pretrain MVP (task #111)
- pretrain_real_cuda - CUDA-backend StepFn/ValFn/CheckpointFn for the 370M pretrain loop (task #132 Phase 2, contract gpu-training-backend-v1)
- shard_reader - Minimal tokenized-shard reader for MODEL-2 pretrain MVP (task #111)
- transformer_trainer - Transformer-specific training utilities
- tui - Real-Time Terminal Monitoring and Visualization (ENT-054 through ENT-067)
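The `tui` module re-exports `sparkline` and `SPARK_CHARS`, which suggests compact loss-curve rendering in the terminal. The sketch below shows the usual technique: scale each value linearly between the series minimum and maximum and map it to one of eight Unicode block glyphs. The glyph set, the flat-series behavior, and the name `render_sparkline` are assumptions; entrenar's actual `sparkline` signature and `SPARK_CHARS` contents are not shown on this page.

```rust
// Assumed glyph set; entrenar's SPARK_CHARS may differ.
const BLOCKS: [char; 8] = ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];

/// Hypothetical sparkline renderer: scales values linearly between the
/// series min and max and picks the nearest of the eight block glyphs.
fn render_sparkline(values: &[f32]) -> String {
    let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let range = max - min;
    values
        .iter()
        .map(|&v| {
            // A flat series (range == 0) maps every point to the lowest glyph.
            let idx = if range > 0.0 {
                (((v - min) / range) * (BLOCKS.len() - 1) as f32).round() as usize
            } else {
                0
            };
            BLOCKS[idx.min(BLOCKS.len() - 1)]
        })
        .collect()
}

fn main() {
    let losses = [4.0_f32, 2.5, 1.8, 1.2, 1.0];
    println!("loss {}", render_sparkline(&losses));
    println!("ramp {}", render_sparkline(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])); // ▁▂▃▄▅▆▇█
}
```

Scaling to the series' own range maximizes visual contrast but hides absolute magnitude, which is presumably why a separate range-aware variant such as `sparkline_range` is also exported.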
Structs§
- Accuracy - Accuracy metric for classification
- AdaptiveCurriculum - Adaptive curriculum that adjusts based on error class performance
- BCEWithLogitsLoss - Binary Cross-Entropy with Logits Loss
- Batch - A training batch containing inputs and targets
- CausalLMLoss - Causal Language Modeling Loss
- CrossEntropyLoss - Cross-Entropy Loss (for classification)
- F1Score - F1 Score (harmonic mean of precision and recall)
- HuberLoss - Huber Loss (Smooth L1 Loss)
- L1Loss - L1 Loss (Mean Absolute Error)
- LinearCurriculum - Linear curriculum that increases difficulty over epochs
- MAE - Mean Absolute Error (MAE) metric
- MSELoss - Mean Squared Error Loss
- MetricsTracker - Tracks training metrics across epochs
- Precision - Precision metric (true positives / predicted positives)
- R2Score - R² (coefficient of determination) for regression
- RMSE - Root Mean Squared Error (RMSE) metric
- Recall - Recall metric (true positives / actual positives)
- SampleWeightedLoss - Per-sample weighted loss for fine-grained control
- TieredCurriculum - Tiered curriculum for diagnostic verbosity levels
- TrainConfig - Training configuration
- TrainResult - Result of a training run
- Trainer - High-level trainer that orchestrates the training loop
- WeightedLoss - Weighted loss wrapper for sample reweighting
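The metric structs above follow standard definitions: accuracy, precision = TP / (TP + FP), recall = TP / (TP + FN), F1 as their harmonic mean, and MAE, RMSE and R² = 1 − SS_res / SS_tot for regression. The standalone sketch below computes them on plain slices. It does not use entrenar's `Metric` trait or the `Accuracy`/`F1Score`/`R2Score` structs, whose signatures are not shown here. All names are hypothetical, and returning 0.0 on a zero denominator is an assumed convention.

```rust
// Standalone sketch; names are hypothetical, not entrenar's API.

/// Confusion-matrix counts for binary predictions (true = positive class).
#[derive(Default)]
struct Counts {
    tp: f64,
    fp: f64,
    fn_: f64, // `fn` is a keyword, hence the trailing underscore
    tn: f64,
}

fn confusion(pred: &[bool], actual: &[bool]) -> Counts {
    assert_eq!(pred.len(), actual.len());
    let mut c = Counts::default();
    for (&p, &a) in pred.iter().zip(actual) {
        match (p, a) {
            (true, true) => c.tp += 1.0,
            (true, false) => c.fp += 1.0,
            (false, true) => c.fn_ += 1.0,
            (false, false) => c.tn += 1.0,
        }
    }
    c
}

/// Assumed convention: return 0.0 instead of NaN when the denominator is zero.
fn ratio(num: f64, den: f64) -> f64 {
    if den == 0.0 { 0.0 } else { num / den }
}

fn accuracy(c: &Counts) -> f64 { ratio(c.tp + c.tn, c.tp + c.fp + c.fn_ + c.tn) }
fn precision(c: &Counts) -> f64 { ratio(c.tp, c.tp + c.fp) }
fn recall(c: &Counts) -> f64 { ratio(c.tp, c.tp + c.fn_) }
fn f1(c: &Counts) -> f64 {
    let (p, r) = (precision(c), recall(c));
    ratio(2.0 * p * r, p + r)
}

/// Mean absolute error.
fn mae(pred: &[f64], target: &[f64]) -> f64 {
    pred.iter().zip(target).map(|(p, t)| (p - t).abs()).sum::<f64>() / pred.len() as f64
}

/// Root mean squared error.
fn rmse(pred: &[f64], target: &[f64]) -> f64 {
    (pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f64>() / pred.len() as f64).sqrt()
}

/// Coefficient of determination: 1 - SS_res / SS_tot.
fn r2(pred: &[f64], target: &[f64]) -> f64 {
    let mean = target.iter().sum::<f64>() / target.len() as f64;
    let ss_res: f64 = pred.iter().zip(target).map(|(p, t)| (t - p).powi(2)).sum();
    let ss_tot: f64 = target.iter().map(|t| (t - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn main() {
    let c = confusion(&[true, true, false, false, true], &[true, false, false, true, true]);
    println!(
        "accuracy={:.3} precision={:.3} recall={:.3} f1={:.3}",
        accuracy(&c), precision(&c), recall(&c), f1(&c)
    ); // accuracy=0.600 precision=0.667 recall=0.667 f1=0.667

    let (y_hat, y) = ([2.5, 0.0, 2.0, 8.0], [3.0, -0.5, 2.0, 7.0]);
    println!("MAE={:.3} RMSE={:.3} R2={:.3}", mae(&y_hat, &y), rmse(&y_hat, &y), r2(&y_hat, &y));
    // MAE=0.500 RMSE=0.612 R2=0.949
}
```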
Traits§
- CurriculumScheduler - Trait for curriculum learning schedulers
- LossFn - Trait for loss functions
- Metric - Trait for evaluation metrics
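The `CurriculumScheduler` trait and `LinearCurriculum` struct imply a schedule that raises task difficulty as training progresses. The methods of the real trait are not listed on this page, so the sketch below defines its own hypothetical trait, `DifficultySchedule`, which maps an epoch to a difficulty level in [0, 1]. It implements a linear ramp that then holds at the final level. Filtering samples by a per-sample difficulty score is likewise an illustrative assumption about how such a level might be used.

```rust
/// Hypothetical scheduler trait (not entrenar's CurriculumScheduler):
/// maps an epoch index to a difficulty level in [0, 1].
trait DifficultySchedule {
    fn difficulty(&self, epoch: usize) -> f32;
}

/// Linear ramp from `start` to `end` over `ramp_epochs`, then held at `end`.
struct LinearRamp {
    start: f32,
    end: f32,
    ramp_epochs: usize,
}

impl DifficultySchedule for LinearRamp {
    fn difficulty(&self, epoch: usize) -> f32 {
        if self.ramp_epochs == 0 {
            return self.end;
        }
        let t = (epoch as f32 / self.ramp_epochs as f32).min(1.0);
        self.start + t * (self.end - self.start)
    }
}

/// Keep only samples whose difficulty score does not exceed the current level.
fn admitted<'a>(samples: &[(&'a str, f32)], level: f32) -> Vec<&'a str> {
    samples
        .iter()
        .filter(|(_, d)| *d <= level)
        .map(|(name, _)| *name)
        .collect()
}

fn main() {
    let schedule = LinearRamp { start: 0.2, end: 1.0, ramp_epochs: 4 };
    let samples: [(&str, f32); 3] = [("easy", 0.1), ("medium", 0.5), ("hard", 0.9)];
    for epoch in 0..6 {
        let level = schedule.difficulty(epoch);
        println!("epoch {epoch}: level={level:.2} samples={:?}", admitted(&samples, level));
    }
}
```

Putting the schedule behind a trait lets the training loop swap a linear ramp for a tiered or adaptive policy without changing how it queries the current difficulty.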
Functions§
- efficiency_score - Compute the efficiency score as defined in the CITL spec
- select_optimal_tier - Compare tiers and select the optimal tier based on efficiency
Type Aliases§
- SmoothL1Loss - Smooth L1 Loss (alias for HuberLoss with delta=1.0)
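For reference, the common Huber definition with threshold δ applied to a residual r is shown below, together with the Smooth L1 form it reduces to at δ = 1. This assumes entrenar's `HuberLoss` follows the same convention as, for example, PyTorch's `HuberLoss`. The exact scaling used in the implementation is not shown on this page. Both branches equal ½δ² at |r| = δ, so the loss is continuous there.

```latex
\mathrm{Huber}_\delta(r) =
\begin{cases}
  \tfrac{1}{2} r^2, & |r| \le \delta \\
  \delta \left( |r| - \tfrac{1}{2}\delta \right), & |r| > \delta
\end{cases}
\qquad
\mathrm{SmoothL1}(r) = \mathrm{Huber}_1(r) =
\begin{cases}
  \tfrac{1}{2} r^2, & |r| \le 1 \\
  |r| - \tfrac{1}{2}, & |r| > 1
\end{cases}
```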