Module train

High-level training loop

This module provides a complete training framework with:

  • Loss functions (MSE, Cross-Entropy, Huber/SmoothL1, L1)
  • Evaluation metrics (Accuracy, Precision, Recall, F1, R², MAE, RMSE)
  • Curriculum learning (Linear, Tiered, Adaptive)
  • Trainer abstraction
  • Training configuration
  • Metrics tracking
  • Checkpoint support

Example

use entrenar::train::{Trainer, TrainConfig, Batch};
use entrenar::optim::Adam;
use entrenar::Tensor;

let params = vec![Tensor::zeros(10, true)];
let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
let config = TrainConfig::default();

let mut trainer = Trainer::new(params, Box::new(optimizer), config);

// Training loop
// for epoch in 0..10 {
//     let loss = trainer.train_epoch(&dataloader);
//     println!("Epoch {}: loss={:.4}", epoch, loss);
// }

Re-exports

pub use callback::CallbackAction;
pub use callback::CallbackContext;
pub use callback::CallbackManager;
pub use callback::CheckpointCallback;
pub use callback::EarlyStopping;
pub use callback::ExplainMethod;
pub use callback::ExplainabilityCallback;
pub use callback::FeatureImportanceResult;
pub use callback::LRSchedulerCallback;
pub use callback::MonitorCallback;
pub use callback::ProgressCallback;
pub use callback::TrainerCallback;
pub use device::resolve_device;
pub use device::Device;
pub use device::DeviceError;
pub use pretrain::check_non_divergence;
pub use pretrain::check_numerical_stability;
pub use pretrain::CheckpointFn;
pub use pretrain::EpochArtifact;
pub use pretrain::EpochMetadata;
pub use pretrain::LinearDecaySynthetic;
pub use pretrain::NanAtStepSynthetic;
pub use pretrain::PretrainAbort;
pub use pretrain::PretrainConfig;
pub use pretrain::PretrainLoop;
pub use pretrain::RunStatus;
pub use pretrain::ScriptedVal;
pub use pretrain::StepFn;
pub use pretrain::StepMetrics;
pub use pretrain::TrainingRegime;
pub use pretrain::ValFn;
pub use pretrain::DIVERGENCE_RATIO_LIMIT;
pub use pretrain::EPOCH_ZERO_VAL_LOSS_LIMIT;
pub use transformer_trainer::distributed_checkpoint::checkpoint_path;
pub use transformer_trainer::distributed_checkpoint::hash_weights;
pub use transformer_trainer::distributed_checkpoint::should_save_checkpoint;
pub use transformer_trainer::distributed_checkpoint::verify_weight_consistency;
pub use transformer_trainer::distributed_checkpoint::CheckpointPhase;
pub use transformer_trainer::grad_accumulator::BLOCK_GRAD_COMPONENTS;
pub use transformer_trainer::perplexity;
pub use transformer_trainer::shard_batches;
pub use transformer_trainer::tokens_per_second;
pub use transformer_trainer::BlockGradientSet;
pub use transformer_trainer::CausalMaskType;
pub use transformer_trainer::ColumnParallelShard;
pub use transformer_trainer::CudaTransformerTrainer;
pub use transformer_trainer::DistributedBackend;
pub use transformer_trainer::DistributedCheckpointCoordinator;
pub use transformer_trainer::DistributedRole;
pub use transformer_trainer::DistributedTrainConfig;
pub use transformer_trainer::ElasticCoordinator;
pub use transformer_trainer::LMBatch;
pub use transformer_trainer::OptimizerShard;
pub use transformer_trainer::PerBlockGradientAccumulator;
pub use transformer_trainer::PipelineAction;
pub use transformer_trainer::PipelineActivationBuffer;
pub use transformer_trainer::PipelineStage;
pub use transformer_trainer::RingAttentionSchedule;
pub use transformer_trainer::RowParallelShard;
pub use transformer_trainer::SequenceParallelConfig;
pub use transformer_trainer::SpCommCost;
pub use transformer_trainer::TensorParallelConfig;
pub use transformer_trainer::TpCommCost;
pub use transformer_trainer::TransformerTrainConfig;
pub use transformer_trainer::TransformerTrainer;
pub use transformer_trainer::ZeroShardMap;
pub use transformer_trainer::DistributedComm;
pub use transformer_trainer::DistributedCudaTrainer;
pub use tui::format_duration;
pub use tui::sparkline;
pub use tui::sparkline_range;
pub use tui::Alert;
pub use tui::AlertLevel;
pub use tui::AndonSystem;
pub use tui::DashboardLayout;
pub use tui::FeatureImportanceChart;
pub use tui::GradientFlowHeatmap;
pub use tui::KalmanEta;
pub use tui::LossCurveDisplay;
pub use tui::MetricsBuffer;
pub use tui::MonitorConfig;
pub use tui::ProgressBar;
pub use tui::ReferenceCurve;
pub use tui::RefreshPolicy;
pub use tui::SeriesSummaryTuple;
pub use tui::TerminalCapabilities;
pub use tui::TerminalMode;
pub use tui::TerminalMonitorCallback;
pub use tui::SPARK_CHARS;

Modules

callback
Callback system for training events
device
Device — selector for the training backend (apr pretrain).
gputrain_003
FALSIFY-GPUTRAIN-003 / INV-GPUTRAIN-003 / GATE-GPUTRAIN-003 — algorithm-level PARTIAL discharge.
gputrain_004
FALSIFY-GPUTRAIN-004 / INV-GPUTRAIN-004 / GATE-GPUTRAIN-005 — algorithm-level PARTIAL discharge.
gputrain_005
FALSIFY-GPUTRAIN-005 / INV-GPUTRAIN-005 / GATE-GPUTRAIN-004 — algorithm-level PARTIAL discharge.
gputrain_006
FALSIFY-GPUTRAIN-006 / INV-GPUTRAIN-006 — empirical reproducibility discharge.
gputrain_007
FALSIFY-GPUTRAIN-007 / INV-GPUTRAIN-007 / GATE-GPUTRAIN-006 — algorithm-level PARTIAL discharge.
pretrain
Pretraining loop driver for SHIP-TWO-001 MODEL-2 (albor 370M).
pretrain_real
Real-corpus StepFn / ValFn for MODEL-2 pretrain MVP (task #111).
pretrain_real_cuda
CUDA-backend StepFn / ValFn / CheckpointFn for the 370M pretrain loop (task #132 Phase 2, contract gpu-training-backend-v1).
shard_reader
Minimal tokenized-shard reader for MODEL-2 pretrain MVP (task #111).
transformer_trainer
Transformer-specific training utilities
tui
Real-Time Terminal Monitoring and Visualization (ENT-054 through ENT-067)

Structs

Accuracy
Accuracy metric for classification
AdaptiveCurriculum
Adaptive curriculum that adjusts based on error class performance
BCEWithLogitsLoss
Binary Cross-Entropy with Logits Loss.
Batch
A training batch containing inputs and targets
CausalLMLoss
Causal Language Modeling Loss
CrossEntropyLoss
Cross Entropy Loss (for classification)
F1Score
F1 Score (harmonic mean of precision and recall)
HuberLoss
Huber Loss (Smooth L1 Loss)
L1Loss
L1 Loss (Mean Absolute Error)
LinearCurriculum
Linear curriculum that increases difficulty over epochs
MAE
Mean Absolute Error (MAE) metric
MSELoss
Mean Squared Error Loss
MetricsTracker
Tracks training metrics across epochs
Precision
Precision metric (true positives / predicted positives)
R2Score
R² (coefficient of determination) for regression
RMSE
Root Mean Squared Error (RMSE) metric
Recall
Recall metric (true positives / actual positives)
SampleWeightedLoss
Per-sample weighted loss for fine-grained control
TieredCurriculum
Tiered curriculum for diagnostic verbosity levels
TrainConfig
Training configuration
TrainResult
Result of a training run
Trainer
High-level trainer that orchestrates the training loop
WeightedLoss
Weighted loss wrapper for sample reweighting
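
The classification metrics listed above (Accuracy, Precision, Recall, F1Score) follow the standard confusion-count definitions given in their one-line descriptions. The sketch below illustrates those definitions in standalone Rust. It does not use this module's Metric trait or any entrenar type; Counts, count, ratio and the metric functions are hypothetical names introduced only for illustration, and returning 0.0 for an empty denominator is an assumption, not documented crate behavior.

```rust
/// Confusion counts for a binary classifier (hypothetical helper type).
#[derive(Debug, Default, Clone, Copy, PartialEq)]
struct Counts {
    tp: usize,  // predicted positive, actually positive
    fp: usize,  // predicted positive, actually negative
    fn_: usize, // predicted negative, actually positive
    tn: usize,  // predicted negative, actually negative
}

/// Tally confusion counts from parallel slices of predictions and labels.
fn count(predicted: &[bool], actual: &[bool]) -> Counts {
    assert_eq!(predicted.len(), actual.len(), "length mismatch");
    let mut c = Counts::default();
    for (&p, &a) in predicted.iter().zip(actual) {
        match (p, a) {
            (true, true) => c.tp += 1,
            (true, false) => c.fp += 1,
            (false, true) => c.fn_ += 1,
            (false, false) => c.tn += 1,
        }
    }
    c
}

/// Safe division: an empty denominator yields 0.0 (an assumption of this sketch).
fn ratio(num: usize, den: usize) -> f64 {
    if den == 0 { 0.0 } else { num as f64 / den as f64 }
}

/// Correct predictions / all predictions.
fn accuracy(c: Counts) -> f64 {
    ratio(c.tp + c.tn, c.tp + c.fp + c.fn_ + c.tn)
}

/// True positives / predicted positives.
fn precision(c: Counts) -> f64 {
    ratio(c.tp, c.tp + c.fp)
}

/// True positives / actual positives.
fn recall(c: Counts) -> f64 {
    ratio(c.tp, c.tp + c.fn_)
}

/// Harmonic mean of precision and recall.
fn f1(c: Counts) -> f64 {
    let (p, r) = (precision(c), recall(c));
    if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
}

fn main() {
    let predicted = [true, true, false, true, false];
    let actual = [true, false, false, true, true];
    let c = count(&predicted, &actual);
    println!(
        "accuracy={:.3} precision={:.3} recall={:.3} f1={:.3}",
        accuracy(c), precision(c), recall(c), f1(c)
    );
}
```

For multi-class problems these quantities are usually computed per class and then averaged; which averaging the crate's structs apply is not stated on this page.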

Traits

CurriculumScheduler
Trait for curriculum learning schedulers
LossFn
Trait for loss functions
Metric
Trait for evaluation metrics

Functions

efficiency_score
Compute efficiency score as per CITL spec
select_optimal_tier
Compare tiers and select optimal based on efficiency

Type Aliases

SmoothL1Loss
Smooth L1 Loss (alias for HuberLoss with delta=1.0)
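
To make the relationship between Huber loss and Smooth L1 concrete, here is a minimal standalone sketch of the textbook Huber formula: quadratic for residuals up to delta, linear beyond it. With delta = 1.0 it coincides with the usual Smooth L1 definition. The function names huber, mean_huber and smooth_l1 are hypothetical, and mean reduction is an assumption; this page does not state which reduction HuberLoss uses or how it is constructed.

```rust
/// Element-wise Huber loss with threshold `delta` (textbook definition).
fn huber(pred: f64, target: f64, delta: f64) -> f64 {
    let r = (pred - target).abs();
    if r <= delta {
        0.5 * r * r // quadratic region, behaves like MSE
    } else {
        delta * (r - 0.5 * delta) // linear region, behaves like L1
    }
}

/// Mean Huber loss over a batch (mean reduction is an assumption of this sketch).
fn mean_huber(pred: &[f64], target: &[f64], delta: f64) -> f64 {
    assert_eq!(pred.len(), target.len(), "length mismatch");
    assert!(!pred.is_empty(), "empty batch");
    let sum: f64 = pred
        .iter()
        .zip(target)
        .map(|(&p, &t)| huber(p, t, delta))
        .sum();
    sum / pred.len() as f64
}

/// Smooth L1 expressed as Huber with delta fixed at 1.0.
fn smooth_l1(pred: &[f64], target: &[f64]) -> f64 {
    mean_huber(pred, target, 1.0)
}

fn main() {
    let pred = [0.5, 3.0];
    let target = [0.0, 0.0];
    // The small residual (0.5) is penalized quadratically,
    // the large one (3.0) only linearly.
    println!("smooth_l1 = {}", smooth_l1(&pred, &target));
}
```

The linear tail is what makes this loss less sensitive to outliers than MSE while staying smooth near zero, unlike plain L1.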