Skip to main content

entrenar/train/
mod.rs

1//! High-level training loop
2//!
3//! This module provides a complete training framework with:
4//! - Loss functions (MSE, Cross-Entropy, BCE-with-logits, causal LM, Huber/SmoothL1, L1, plus weighted and sample-weighted wrappers)
5//! - Evaluation metrics (Accuracy, Precision, Recall, F1, R², MAE, RMSE)
6//! - Curriculum learning (Linear, Tiered, Adaptive)
7//! - Trainer abstraction
8//! - Training configuration
9//! - Metrics tracking
10//! - Checkpoint support
11//!
12//! # Example
13//!
14//! ```no_run
15//! use entrenar::train::{Trainer, TrainConfig, Batch};
16//! use entrenar::optim::Adam;
17//! use entrenar::Tensor;
18//!
19//! let params = vec![Tensor::zeros(10, true)];
20//! let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
21//! let config = TrainConfig::default();
22//!
23//! let mut trainer = Trainer::new(params, Box::new(optimizer), config);
24//!
25//! // Training loop
26//! // for epoch in 0..10 {
27//! //     let loss = trainer.train_epoch(&dataloader);
28//! //     println!("Epoch {}: loss={:.4}", epoch, loss);
29//! // }
30//! ```
31
32mod batch;
33pub mod callback;
34mod config;
35mod curriculum;
36pub mod device;
37pub mod gputrain_003;
38pub mod gputrain_004;
39pub mod gputrain_005;
40pub mod gputrain_006;
41pub mod gputrain_007;
42mod loss;
43mod metrics;
44pub mod pretrain;
45pub mod pretrain_real;
46#[cfg(feature = "cuda")]
47pub mod pretrain_real_cuda;
48pub mod shard_reader;
49mod trainer;
50pub mod transformer_trainer;
51pub mod tui;
52
53#[cfg(test)]
54mod tests;
55
56pub use batch::Batch;
57pub use callback::{
58    CallbackAction, CallbackContext, CallbackManager, CheckpointCallback, EarlyStopping,
59    ExplainMethod, ExplainabilityCallback, FeatureImportanceResult, LRSchedulerCallback,
60    MonitorCallback, ProgressCallback, TrainerCallback,
61};
62pub use config::{MetricsTracker, TrainConfig};
63pub use curriculum::{
64    efficiency_score, select_optimal_tier, AdaptiveCurriculum, CurriculumScheduler,
65    LinearCurriculum, TieredCurriculum,
66};
67pub use device::{resolve_device, Device, DeviceError};
68pub use loss::{
69    BCEWithLogitsLoss, CausalLMLoss, CrossEntropyLoss, HuberLoss, L1Loss, LossFn, MSELoss,
70    SampleWeightedLoss, SmoothL1Loss, WeightedLoss,
71};
72pub use metrics::{Accuracy, F1Score, Metric, Precision, R2Score, Recall, MAE, RMSE};
73pub use pretrain::{
74    check_non_divergence, check_numerical_stability, CheckpointFn, EpochArtifact, EpochMetadata,
75    LinearDecaySynthetic, NanAtStepSynthetic, PretrainAbort, PretrainConfig, PretrainLoop,
76    RunStatus, ScriptedVal, StepFn, StepMetrics, TrainingRegime, ValFn, DIVERGENCE_RATIO_LIMIT,
77    EPOCH_ZERO_VAL_LOSS_LIMIT,
78};
79pub use trainer::{TrainResult, Trainer};
80pub use transformer_trainer::distributed_checkpoint::{
81    checkpoint_path, hash_weights, should_save_checkpoint, verify_weight_consistency,
82    CheckpointPhase,
83};
84pub use transformer_trainer::grad_accumulator::BLOCK_GRAD_COMPONENTS;
85pub use transformer_trainer::{
86    perplexity,
87    // DDP (#133)
88    shard_batches,
89    tokens_per_second,
90    BlockGradientSet,
91    // Parallelism strategies
92    CausalMaskType,
93    ColumnParallelShard,
94    CudaTransformerTrainer,
95    DistributedBackend,
96    DistributedCheckpointCoordinator,
97    DistributedRole,
98    DistributedTrainConfig,
99    ElasticCoordinator,
100    LMBatch,
101    OptimizerShard,
102    PerBlockGradientAccumulator,
103    PipelineAction,
104    PipelineActivationBuffer,
105    PipelineStage,
106    RingAttentionSchedule,
107    RowParallelShard,
108    SequenceParallelConfig,
109    SpCommCost,
110    TensorParallelConfig,
111    TpCommCost,
112    TransformerTrainConfig,
113    TransformerTrainer,
114    ZeroShardMap,
115};
116#[cfg(feature = "cuda")]
117pub use transformer_trainer::{DistributedComm, DistributedCudaTrainer};
118pub use tui::{
119    format_duration, sparkline, sparkline_range, Alert, AlertLevel, AndonSystem, DashboardLayout,
120    FeatureImportanceChart, GradientFlowHeatmap, KalmanEta, LossCurveDisplay, MetricsBuffer,
121    MonitorConfig, ProgressBar, ReferenceCurve, RefreshPolicy, SeriesSummaryTuple,
122    TerminalCapabilities, TerminalMode, TerminalMonitorCallback, SPARK_CHARS,
123};