§hybrid-predict-trainer-rs
Hybridized predictive training framework that accelerates deep learning through intelligent phase-based training with whole-phase prediction and residual correction.
§Overview
This crate implements a novel training paradigm that achieves 5-10x speedup over traditional training by predicting training outcomes rather than computing every gradient step. The key insight is that training dynamics evolve on low-dimensional manifolds, making whole-phase prediction tractable.
§Training Phases
The training loop cycles through four distinct phases:
- Warmup Phase - Initial training steps to establish baseline dynamics
- Full Training Phase - Traditional forward/backward pass computation
- Predictive Phase - Skip backward passes using learned dynamics model
- Correction Phase - Apply residual corrections to maintain accuracy
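The cycle can be sketched as a small state machine. This is an illustrative simplification of the crate's `phases::Phase` and `PhaseController`; a real controller also uses runtime signals (divergence, phase budgets) to decide transitions:

```rust
/// Simplified view of the four training phases (mirrors `phases::Phase`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Phase {
    Warmup,
    Full,
    Predict,
    Correct,
}

impl Phase {
    /// Next phase on the happy path; when divergence is detected during
    /// prediction, fall back to full training immediately.
    fn next(self, diverged: bool) -> Phase {
        match (self, diverged) {
            (Phase::Warmup, _) => Phase::Full,
            (Phase::Full, _) => Phase::Predict,
            (Phase::Predict, true) => Phase::Full, // bail out early
            (Phase::Predict, false) => Phase::Correct,
            (Phase::Correct, _) => Phase::Full, // close the loop
        }
    }
}

fn main() {
    let mut phase = Phase::Warmup;
    for _ in 0..4 {
        println!("{phase:?}");
        phase = phase.next(false);
    }
    // Warmup -> Full -> Predict -> Correct, then back to Full
}
```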
┌─────────┐
│ WARMUP │
└────┬────┘
│
▼
┌─────────────────────┐
│ │
▼ │
┌──────────┐ │
┌───▶│ FULL │◀───────────────┤
│ └────┬─────┘ │
│ │ │
│ ▼ │
│ ┌──────────┐ │
│ │ PREDICT │ │
│ └────┬─────┘ │
│ │ │
│ ▼ │
│ ┌──────────┐ │
│ │ CORRECT │────────────────┘
│ └────┬─────┘
│ │
└─────────┘

§Quick Start

```rust
use hybrid_predict_trainer_rs::{HybridTrainer, HybridTrainerConfig};

// Create a configuration with sensible defaults
let config = HybridTrainerConfig::default();

// Initialize the hybrid trainer
// let trainer = HybridTrainer::new(model, optimizer, config)?;

// Training loop
// for batch in dataloader {
//     let result = trainer.step(&batch)?;
//     println!("Loss: {}, Phase: {:?}", result.loss, result.phase);
// }
```

§Features
- GPU Acceleration - CubeCL and Burn backends for high-performance compute
- Adaptive Phase Selection - Bandit-based algorithm for optimal phase lengths
- Divergence Detection - Multi-signal monitoring prevents training instability
- Residual Correction - Online learning corrects prediction errors
- Checkpoint Support - Save/restore full training state including predictor
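Multi-signal divergence detection can be illustrated with a toy monitor. This sketch is hypothetical and not the crate's `divergence` API: it flags a step when the loss spikes well above its running mean or the gradient norm exceeds a ceiling:

```rust
/// Toy multi-signal divergence monitor (illustrative sketch, not the
/// crate's `divergence` module): flags a step when the loss jumps far
/// above its running mean or the gradient norm explodes.
struct DivergenceMonitor {
    mean_loss: Option<f64>, // EMA of recent losses
    alpha: f64,             // EMA smoothing factor
    loss_factor: f64,       // spike threshold relative to the mean
    grad_limit: f64,        // absolute gradient-norm ceiling
}

impl DivergenceMonitor {
    fn new() -> Self {
        Self { mean_loss: None, alpha: 0.1, loss_factor: 3.0, grad_limit: 1e3 }
    }

    /// Returns true if either signal indicates divergence.
    fn check(&mut self, loss: f64, grad_norm: f64) -> bool {
        let spiked = matches!(self.mean_loss, Some(m) if loss > self.loss_factor * m);
        let m = self.mean_loss.unwrap_or(loss);
        self.mean_loss = Some((1.0 - self.alpha) * m + self.alpha * loss);
        spiked || grad_norm > self.grad_limit
    }
}

fn main() {
    let mut mon = DivergenceMonitor::new();
    assert!(!mon.check(1.0, 2.0)); // first step establishes the baseline
    assert!(!mon.check(0.9, 1.5)); // normal step
    assert!(mon.check(10.0, 1.5)); // loss spike -> divergence
    assert!(mon.check(0.8, 1e6));  // exploding gradient -> divergence
    println!("divergence monitor ok");
}
```

Combining several cheap signals like these is what lets the trainer abandon a predictive phase before instability compounds.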
§Feature Flags
- std - Enable standard library support (default)
- cuda - Enable CUDA GPU acceleration via CubeCL
- candle - Enable Candle tensor operations for model compatibility
- async - Enable async/await support with Tokio
- full - Enable all features
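In `Cargo.toml`, these flags combine in the usual way (the version number below is illustrative, not the crate's current release):

```toml
[dependencies]
# Default features (std only):
hybrid-predict-trainer-rs = "0.1"

# Or opt into GPU and async support explicitly:
# hybrid-predict-trainer-rs = { version = "0.1", features = ["cuda", "async"] }
```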
§Architecture
The crate is organized into the following modules:
- config - Training configuration and serialization
- error - Error types with recovery actions
- phases - Phase state machine and execution control
- state - Training state encoding and management
- dynamics - RSSM-lite dynamics model for prediction
- residuals - Residual extraction and storage
- corrector - Prediction correction via residual application
- divergence - Multi-signal divergence detection
- metrics - Training metrics collection and reporting
- gpu - GPU acceleration kernels (requires the cuda feature)
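The interplay of `dynamics` and `corrector` can be sketched with a toy residual corrector. This is a hypothetical simplification, not the crate's API: it learns the predictor's systematic error online and adds it back to future predictions:

```rust
/// Toy residual corrector (illustrative, not the crate's `corrector` API):
/// tracks a running estimate of the dynamics model's prediction error and
/// applies it as a correction to later predictions.
struct ResidualCorrector {
    bias: f64, // running estimate of the predictor's systematic error
    lr: f64,   // online learning rate
}

impl ResidualCorrector {
    fn new(lr: f64) -> Self {
        Self { bias: 0.0, lr }
    }

    /// Correct a raw prediction with the learned residual.
    fn correct(&self, predicted: f64) -> f64 {
        predicted + self.bias
    }

    /// After the true value is observed, update the residual estimate.
    fn observe(&mut self, predicted: f64, actual: f64) {
        let residual = actual - predicted;
        self.bias += self.lr * (residual - self.bias);
    }
}

fn main() {
    let mut c = ResidualCorrector::new(0.5);
    // Suppose the dynamics model consistently under-predicts by 0.2.
    for _ in 0..20 {
        c.observe(1.0, 1.2);
    }
    // The corrector has learned the systematic offset.
    assert!((c.correct(1.0) - 1.2).abs() < 1e-3);
    println!("residual corrector ok");
}
```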
§References
This implementation is based on research findings documented in
predictive-training-research.md, synthesizing insights from:
- Neural Tangent Kernel (NTK) theory for training dynamics
- RSSM world models from DreamerV3
- K-FAC for structured gradient approximation
- PowerSGD for low-rank gradient compression
Re-exports§
pub use auto_tuning::AutoTuningConfig;
pub use auto_tuning::AutoTuningController;
pub use auto_tuning::AutoTuningUpdate;
pub use auto_tuning::BatchPrediction;
pub use auto_tuning::BatchPredictionRecommendation;
pub use auto_tuning::HealthClassification;
pub use auto_tuning::HealthRecommendation;
pub use auto_tuning::HealthScorer;
pub use auto_tuning::HealthWeights;
pub use auto_tuning::MultiStepPredictor;
pub use auto_tuning::TrainingHealthScore;
pub use config::HybridTrainerConfig;
pub use error::HybridResult;
pub use error::HybridTrainingError;
pub use error::RecoveryAction;
pub use phases::Phase;
pub use phases::PhaseController;
pub use phases::PhaseDecision;
pub use phases::PhaseOutcome;
pub use residuals::Residual;
pub use residuals::ResidualStore;
pub use state::TrainingState;
pub use timing::Duration;
pub use timing::Timer;
pub use timing::TimingMetrics;
pub use timing::TimingStats;
Modules§
- auto_tuning - Automatic tuning and health monitoring for hybrid predictive training.
- bandit - Bandit-based adaptive phase selection.
- burn_integration - Burn framework integration for hybrid-predict-trainer-rs.
- checkpoint - Checkpoint save/restore functionality for hybrid trainer.
- config - Configuration types for hybrid predictive training.
- corrector - Residual correction for prediction adjustment.
- delta_accumulator - Delta accumulation for batched weight updates.
- divergence - Multi-signal divergence detection.
- dynamics - RSSM-lite dynamics model for training trajectory prediction.
- ecosystem - Rust AI ecosystem integration.
- error - Error types and recovery actions for hybrid predictive training.
- full_train - Full training phase implementation.
- gpu - GPU acceleration kernels via CubeCL and Burn (requires the cuda feature).
- gradient_accumulation - Gradient accumulation for memory-efficient training with large effective batch sizes.
- gru - GRU (Gated Recurrent Unit) implementation with forward pass and training.
- metrics - Training metrics collection and reporting.
- mixed_precision - Mixed precision training support for memory optimization.
- models - Reference model implementations for validation.
- phases - Phase state machine and execution control.
- predict_aware_memory - Predict-aware memory management for HybridTrainer.
- predictive - Predictive training phase implementation.
- prelude - Prelude module for convenient imports.
- residuals - Residual extraction and storage for prediction correction.
- state - Training state management and encoding.
- timing - High-precision timing utilities for training metrics.
- vram_budget - VRAM budget management and auto-configuration.
- vram_manager - VRAM management utilities for controlling GPU memory usage.
- warmup - Warmup phase implementation.
Structs§
- GradientInfo - Gradient information from a backward pass.
- HybridTrainer - The main hybrid trainer that orchestrates phase-based predictive training.
- StepResult - Result of a single training step.