trustformers-training

Comprehensive training infrastructure for transformer models with distributed training, hyperparameter optimization, and advanced training techniques.

Current State

Version: 0.1.0 | Status: Stable | Released: 2026-03-21

This crate provides production-ready training capabilities including distributed training across multiple nodes, mixed precision training, hyperparameter tuning, and quantization-aware training. The design closely follows HuggingFace Transformers' Trainer API for familiarity.

  • 333 passing tests, 0 failures
  • 846 public API items
  • 38,667 SLoC
  • 0 stubs — all functionality fully implemented

Features

Core Training Infrastructure

  • Trainer API: HuggingFace-compatible training interface
  • TrainingArguments: Comprehensive configuration system
  • Gradient Management: Accumulation, clipping, and scaling
  • Checkpoint System: Save/resume with full state preservation
  • Early Stopping: Patience-based training termination
  • Callback System: Extensible hooks for custom behavior
  • Experiment Management: TensorBoard, W&B, Neptune, ClearML, MLflow

Distributed Training

  • Data Parallel Training: Multi-GPU with gradient synchronization
  • Model Parallel Support: Tensor and pipeline parallelism
  • Pipeline Parallel: GPipe-style scheduling with microbatching
  • ZeRO Optimization: All three stages implemented
    • Stage 1: Optimizer state sharding
    • Stage 2: Optimizer + gradient sharding
    • Stage 3: Full parameter sharding with CPU offload
  • Multi-Node Support: MPI backend for cluster training
  • Process Groups: NCCL, Gloo, MPI backends

Mixed Precision Training

  • Automatic Mixed Precision (AMP): FP16/BF16 training
  • Dynamic Loss Scaling: Prevent gradient underflow (sketched after this list)
  • Gradient Scaling: Automatic scale adjustment
  • Memory Savings: ~50% reduction in memory usage
  • Performance: 1.5-2x training speedup
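
The dynamic loss-scaling behavior can be sketched as a small standalone type. This is an illustrative sketch, not the crate's actual scaler API: the loss is multiplied by the scale before the backward pass, an overflow skips the step and halves the scale, and a long run of stable steps grows it again.

// Illustrative dynamic loss scaler (not the crate's API).
struct LossScaler {
    scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: u32, // stable steps required before growing the scale
    good_steps: u32,
}

impl LossScaler {
    fn new() -> Self {
        Self {
            scale: 65536.0, // 2^16, a common initial scale
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            good_steps: 0,
        }
    }

    /// Update after unscaling gradients; returns true if the optimizer
    /// step should run (no Inf/NaN was found in the gradients).
    fn update(&mut self, grads_finite: bool) -> bool {
        if grads_finite {
            self.good_steps += 1;
            if self.good_steps >= self.growth_interval {
                self.scale *= self.growth_factor;
                self.good_steps = 0;
            }
            true
        } else {
            // Overflow: skip this step and back off the scale.
            self.scale *= self.backoff_factor;
            self.good_steps = 0;
            false
        }
    }
}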

RLHF and Alignment

  • PPO (Proximal Policy Optimization): Full RL-based fine-tuning pipeline
  • DPO (Direct Preference Optimization): Sigmoid-based preference loss (formula below), log_softmax for log probabilities, and a preference accuracy metric
  • KTO (Kahneman-Tversky Optimization): Prospect-theory-based alignment
  • Reward Models: Training and inference for human preference modeling
  • Feedback Processing: Human preference data ingestion and formatting
  • Reference Model Management: KL divergence constraint and stability
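
For reference, the sigmoid-based DPO loss noted above has the standard form below, where y_w is the preferred response, y_l the rejected one, pi_ref the frozen reference model, and beta the same temperature as DPOConfig's beta field:

\mathcal{L}_{\mathrm{DPO}}(\theta) =
  -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[
    \log \sigma\!\left(
      \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
      - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
    \right)
  \right]

Larger beta keeps the policy closer to the reference model; the log probabilities inside the ratios are the log_softmax values mentioned above.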

Few-Shot and Meta-Learning

  • MAML: Model-Agnostic Meta-Learning with inner/outer loop
  • First-Order MAML (FOMAML): Efficient approximation that drops second-order terms (sketched after this list)
  • Reptile: Simple first-order meta-learning
  • In-Context Learning: Prompt-based few-shot inference
  • Prompt Tuning: Soft-prompt optimization
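
To make the inner/outer loop structure concrete, here is a toy first-order MAML step over flat parameter vectors. It sketches the algorithm itself, not the crate's API; each task is represented by a bare gradient function:

// Toy FOMAML step (algorithm sketch, not the crate's API).
fn fomaml_step(
    theta: &mut [f32],
    tasks: &[&dyn Fn(&[f32]) -> Vec<f32>], // one gradient fn per task
    inner_lr: f32,
    outer_lr: f32,
) {
    let mut outer_grad = vec![0.0f32; theta.len()];
    for grad_fn in tasks {
        // Inner loop: adapt to the task with one SGD step from the shared init.
        let g = grad_fn(theta);
        let adapted: Vec<f32> =
            theta.iter().zip(&g).map(|(t, gi)| t - inner_lr * gi).collect();
        // First-order approximation: use the gradient at the adapted
        // parameters as the meta-gradient, dropping second-order terms.
        for (og, gi) in outer_grad.iter_mut().zip(grad_fn(&adapted)) {
            *og += gi / tasks.len() as f32;
        }
    }
    // Outer loop: move the shared initialization against the meta-gradient.
    for (t, og) in theta.iter_mut().zip(&outer_grad) {
        *t -= outer_lr * og;
    }
}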

Continual Learning

  • Elastic Weight Consolidation (EWC): Importance-weighted regularization (penalty sketched after this list)
  • Progressive Neural Networks: Column-based architecture growth
  • Learning without Forgetting (LwF): Distillation-based continual learning
  • Replay Buffers: Experience replay for catastrophic forgetting prevention
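
The EWC regularizer reduces to a small function over flat parameter vectors. This is a sketch of the standard penalty, not the crate's API: each parameter is anchored to its value after the previous task, weighted by a diagonal Fisher information estimate of its importance.

// EWC quadratic penalty (algorithm sketch, not the crate's API).
fn ewc_penalty(theta: &[f32], theta_old: &[f32], fisher: &[f32], lambda: f32) -> f32 {
    let sum: f32 = theta
        .iter()
        .zip(theta_old)
        .zip(fisher)
        .map(|((t, t0), f)| f * (t - t0).powi(2))
        .sum();
    0.5 * lambda * sum
}

The total training loss is then the new task's loss plus this penalty, so parameters important to earlier tasks resist drift while unimportant ones remain free to adapt.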

Curriculum Learning

  • Length-Based Curriculum: Short-to-long sequence ordering (sketched after this list)
  • Difficulty-Based Curriculum: Dynamic example difficulty scoring
  • Self-Paced Learning: Performance-driven progression
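
A length-based curriculum reduces to ordering examples by sequence length before batching; the helper below is illustrative, not the crate's API:

// Order (tokens, label) examples from short to long (illustrative sketch).
fn order_by_length<T>(mut examples: Vec<(Vec<u32>, T)>) -> Vec<(Vec<u32>, T)> {
    examples.sort_by_key(|(tokens, _)| tokens.len());
    examples
}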

Hyperparameter Optimization

  • Random Search: Sampling from distributions with configurable trials (sketched after this list)
  • Bayesian Optimization: Gaussian process surrogate with EI/UCB acquisition
  • Grid Search: Exhaustive parameter combinations with parallel execution
  • Hyperband: Successive halving for resource-efficient search
  • Population-Based Training (PBT): Evolutionary online adaptation
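
As a concrete illustration of random search, the sketch below samples a learning rate log-uniformly and a batch size from a fixed set, keeping the best trial by validation loss. It assumes the rand crate and is not the crate's actual search API:

use rand::Rng;

// Random search over (learning rate, batch size); illustrative sketch.
fn random_search(trials: usize, eval: impl Fn(f64, usize) -> f64) -> (f64, usize) {
    let mut rng = rand::thread_rng();
    let (mut best_loss, mut best_lr, mut best_batch) = (f64::INFINITY, 0.0, 0);
    for _ in 0..trials {
        let lr = 10f64.powf(rng.gen_range(-6.0..-3.0)); // log-uniform in [1e-6, 1e-3)
        let batch = [16usize, 32, 64][rng.gen_range(0..3)];
        let loss = eval(lr, batch); // train briefly, return validation loss
        if loss < best_loss {
            best_loss = loss;
            best_lr = lr;
            best_batch = batch;
        }
    }
    (best_lr, best_batch)
}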

Advanced Training Features

  • Quantization-Aware Training (QAT): Fake quantization, per-tensor and per-channel schemes
  • Gradient Checkpointing: Trade compute for memory
  • Learning Rate Schedulers (linear warmup and decay sketched after this list):
    • Linear, cosine, polynomial decay
    • Warmup strategies
    • Exponential and step decay
  • Custom Loss Functions: CrossEntropy and MSE, with optional label smoothing
  • Metrics System: Accuracy, F1, perplexity, BLEU, preference accuracy, custom metrics
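
The warmup-plus-linear-decay schedule can be expressed as a single function of the step count. This is an illustrative formula, not the crate's LinearScheduler API:

// Linear warmup to base_lr, then linear decay to zero (illustrative).
fn linear_warmup_decay(step: u32, warmup_steps: u32, total_steps: u32, base_lr: f64) -> f64 {
    if step < warmup_steps {
        base_lr * f64::from(step) / f64::from(warmup_steps)
    } else {
        let remaining = f64::from(total_steps.saturating_sub(step));
        base_lr * remaining / f64::from(total_steps - warmup_steps)
    }
}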

Usage Example

use trustformers_training::{
    Trainer, TrainingArguments,
    optimizers::{AdamW, AdamWConfig},
    schedulers::{LinearScheduler, SchedulerConfig},
};

// Configure training
let args = TrainingArguments {
    output_dir: "output".to_string(),
    num_train_epochs: 3,
    per_device_train_batch_size: 32,
    learning_rate: 5e-5,
    warmup_steps: 500,
    logging_steps: 100,
    save_steps: 1000,
    fp16: true,
    gradient_accumulation_steps: 4,
    ..Default::default()
};

// Create optimizer
let optimizer = AdamW::new(AdamWConfig {
    lr: args.learning_rate,
    weight_decay: 0.01,
    ..Default::default()
})?;

// Create trainer
let trainer = Trainer::new(
    model,
    args,
    train_dataset,
    eval_dataset,
    optimizer,
)?;

// Train
trainer.train()?;

Distributed Training Example

use trustformers_training::distributed::{
    DistributedTrainer,
    ProcessGroup,
    ZeroStage,
};

// Initialize distributed environment
let process_group = ProcessGroup::new_from_env()?;

// Configure ZeRO
let trainer = DistributedTrainer::new(
    model,
    args,
    process_group,
    ZeroStage::Three, // Full parameter sharding
)?;

// Train across multiple GPUs/nodes
trainer.train()?;

RLHF / DPO Example

use trustformers_training::rlhf::{DPOTrainer, DPOConfig};

let dpo_config = DPOConfig {
    beta: 0.1,
    reference_model_path: "path/to/sft_model".to_string(),
    ..Default::default()
};

let trainer = DPOTrainer::new(model, dpo_config, args)?;
trainer.train()?; // uses log_softmax-based log probs + preference accuracy metric

Architecture

trustformers-training/
├── src/
│   ├── trainer.rs          # Main trainer implementation
│   ├── args.rs             # Training arguments
│   ├── distributed/        # Distributed training
│   │   ├── data_parallel.rs
│   │   ├── model_parallel.rs
│   │   ├── pipeline_parallel.rs
│   │   ├── zero.rs         # ZeRO optimizer (stages 1/2/3)
│   │   └── process_group.rs
│   ├── mixed_precision/    # AMP implementation
│   ├── rlhf/               # PPO, DPO, KTO, reward models
│   ├── few_shot/           # MAML, Reptile, in-context, prompt tuning
│   ├── continual/          # EWC, progressive networks
│   ├── hyperparameter/     # HP optimization (Bayesian, Grid, Random, PBT)
│   ├── loss/               # Loss functions
│   ├── metrics/            # Evaluation metrics
│   ├── callbacks/          # Callback system
│   └── schedulers/         # LR schedulers

Performance Features

Memory Optimization

  • Gradient Checkpointing: Re-compute activations during backward pass
  • CPU Offloading: Move optimizer states to CPU (ZeRO Stage 3)
  • Mixed Precision: Reduce memory with FP16/BF16 weights
  • Gradient Accumulation: Larger effective batch sizes (loop sketched after this list)
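
The gradient-accumulation loop has the shape sketched below (illustrative, not the crate's API). Each micro-batch loss is scaled by 1/accum_steps so the summed gradient matches one large batch, and the optimizer steps once per accumulation window; the effective batch size is per_device_batch * accum_steps * world_size.

// Gradient accumulation (illustrative sketch). `backward` runs one
// micro-batch with its loss scaled by the given factor; `step` applies
// the optimizer update and clears gradients.
fn accumulate(
    num_micro_batches: usize,
    accum_steps: usize,
    mut backward: impl FnMut(f32),
    mut step: impl FnMut(),
) {
    for i in 0..num_micro_batches {
        backward(1.0 / accum_steps as f32);
        if (i + 1) % accum_steps == 0 {
            step(); // optimizer step + zero_grad once per window
        }
    }
}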

Speed Optimization

  • Data Parallel: Linear scaling with multiple GPUs
  • Fused Optimizers: Combined update operations
  • Efficient Communication: Optimized all-reduce
  • Overlapped Computation: Hide communication latency

Benchmarks

Configuration  GPUs  Model       Throughput       Speedup
Single GPU     1     BERT-Large  250 samples/s    1.0x
Data Parallel  8     BERT-Large  1,920 samples/s  7.7x
ZeRO Stage 2   8     GPT-2 1.5B  450 samples/s    -
ZeRO Stage 3   16    LLaMA 7B    320 samples/s    -

Benchmarks on NVIDIA A100 GPUs with NVLink

Advanced Features

Callbacks

  • EarlyStoppingCallback: Stop training when metric plateaus
  • ModelCheckpoint: Save best models during training
  • TensorBoardCallback: Log metrics to TensorBoard
  • Custom Callbacks: Implement the TrainerCallback trait
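
A custom callback might look like the following. The TrainerCallback trait is the crate's extension point, but the method name and TrainerState fields used here are hypothetical stand-ins for illustration:

// Hypothetical callback sketch; the method and state fields are assumptions.
struct LossLogger;

impl TrainerCallback for LossLogger {
    fn on_step_end(&mut self, state: &TrainerState) {
        if state.global_step % 100 == 0 {
            println!("step {}: loss {:.4}", state.global_step, state.last_loss);
        }
    }
}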

Metrics

  • Classification: Accuracy, F1, Precision, Recall
  • Generation: Perplexity, BLEU
  • Alignment: Preference accuracy (DPO/KTO)
  • Custom Metrics: Implement the Metric trait (sketch after this list)
  • Metric Collections: Combine multiple metrics
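
A custom metric plugs in similarly. The Metric trait is the crate's extension point, but the compute signature shown here is a hypothetical stand-in:

// Hypothetical metric sketch; the trait signature is an assumption.
struct ExactMatch;

impl Metric for ExactMatch {
    fn compute(&self, predictions: &[u32], labels: &[u32]) -> f32 {
        let hits = predictions.iter().zip(labels).filter(|(p, l)| p == l).count();
        hits as f32 / labels.len().max(1) as f32
    }
}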

Testing

  • 333 unit and integration tests, all passing
  • Tests for distributed training, RLHF (PPO, DPO, KTO), few-shot, continual learning, QAT
  • Memory leak detection
  • Performance benchmarks

License

Apache-2.0