Skip to main content

entrenar/train/transformer_trainer/
mod.rs

1//! Transformer-specific training utilities
2//!
3//! Provides specialized training components for transformer language models,
4//! including tokenized batch creation and language modeling training loops.
5
6mod batch;
7mod config;
8mod cuda_trainer;
9pub mod distributed_checkpoint;
10mod distributed_trainer;
11pub mod elastic;
12pub mod gpu_grad_accumulator;
13pub mod grad_accumulator;
14pub mod pipeline;
15pub mod sequence_parallel;
16pub mod step_profiler;
17pub mod tensor_parallel;
18mod trainer;
19mod utils;
20#[cfg(feature = "gpu")]
21pub mod wgpu_attention;
22#[cfg(feature = "gpu")]
23pub mod wgpu_backward;
24#[cfg(feature = "gpu")]
25pub mod wgpu_checkpoint;
26#[cfg(feature = "gpu")]
27pub mod wgpu_nf4;
28#[cfg(feature = "gpu")]
29pub mod wgpu_runner;
30#[cfg(feature = "gpu")]
31pub mod wgpu_trainer;
32pub mod zero;
33
34#[cfg(test)]
35mod falsify_lora_tests;
36#[cfg(test)]
37mod tests;
38
39// Re-export selected public items so callers can import them from this module directly
40pub use batch::LMBatch;
41pub use config::{
42    DistributedBackend, DistributedRole, DistributedTrainConfig, TransformerTrainConfig,
43};
44pub use cuda_trainer::CudaTransformerTrainer;
45pub use distributed_checkpoint::DistributedCheckpointCoordinator;
46pub use distributed_trainer::shard_batches;
// Re-export the distributed trainer's communication and trainer items only when the
// `cuda` feature is enabled.
// NOTE(review): `CudaTransformerTrainer` and `shard_batches` are re-exported without a
// feature gate, while these items are gated on `cuda` — confirm the asymmetry is intended.
// NOTE(review): the `unused_imports` allow carries no rationale; presumably some of these
// names are not referenced inside the crate in certain configurations — confirm and record
// the reason here, or remove the allow if it is no longer needed.
47#[cfg(feature = "cuda")]
48#[allow(unused_imports)]
49pub use distributed_trainer::{DistributedComm, DistributedCudaTrainer, GradientMessage};
50pub use elastic::ElasticCoordinator;
51pub use grad_accumulator::{BlockGradientSet, PerBlockGradientAccumulator};
52pub use pipeline::{PipelineAction, PipelineActivationBuffer, PipelineStage};
53pub use sequence_parallel::{
54    CausalMaskType, RingAttentionSchedule, SequenceParallelConfig, SpCommCost,
55};
56pub use tensor_parallel::{
57    ColumnParallelShard, RowParallelShard, TensorParallelConfig, TpCommCost,
58};
59pub use trainer::TransformerTrainer;
60pub use utils::{perplexity, tokens_per_second};
61pub use zero::{OptimizerShard, ZeroShardMap};