Crate hybrid_predict_trainer_rs


§hybrid-predict-trainer-rs

Hybridized predictive training framework that accelerates deep learning through adaptive, phase-based training with whole-phase prediction and residual correction.

§Overview

This crate implements a novel training paradigm that achieves a 5-10x speedup over traditional training by predicting the outcomes of whole training phases rather than computing every gradient step. The key insight is that training dynamics evolve on low-dimensional manifolds, which makes whole-phase prediction tractable.
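How large the speedup can be depends on how many steps are predicted and how cheap a predicted step is compared with a full forward/backward step. The back-of-envelope model below is our own assumption, not a formula from the crate: a fraction `p` of steps is predicted at relative cost `c`, every other step costs 1, giving an Amdahl-style speedup of 1 / ((1 - p) + p * c). The helper `estimated_speedup` is hypothetical.

```rust
/// Hypothetical helper (not part of this crate): Amdahl-style estimate of
/// end-to-end speedup when some steps skip the backward pass.
///
/// `predicted_fraction`: fraction of steps that are predicted (0.0..=1.0).
/// `predicted_cost`: cost of a predicted step relative to a full step.
fn estimated_speedup(predicted_fraction: f64, predicted_cost: f64) -> f64 {
    assert!((0.0..=1.0).contains(&predicted_fraction));
    assert!(predicted_cost >= 0.0);
    // Wall-clock time relative to running every step as a full step.
    let relative_time = (1.0 - predicted_fraction) + predicted_fraction * predicted_cost;
    1.0 / relative_time
}

fn main() {
    // Example: 90% of steps predicted, each at 1% of the cost of a full step.
    let s = estimated_speedup(0.9, 0.01);
    println!("estimated speedup: {s:.2}x");
}
```

Under this model a 5x speedup requires predicting at least 80% of steps even if predicted steps were free (1 / 0.2 = 5), and 10x requires at least 90%, which is why long predictive phases matter.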

§Training Phases

The training loop uses four distinct phases. Warmup runs once; after that, training cycles through the full, predictive, and correction phases:

  1. Warmup Phase - Initial training steps to establish baseline dynamics
  2. Full Training Phase - Traditional forward/backward pass computation
  3. Predictive Phase - Skip backward passes using learned dynamics model
  4. Correction Phase - Apply residual corrections to maintain accuracy
                    ┌─────────┐
                    │ WARMUP  │
                    └────┬────┘
                         │
                         ▼
              ┌─────────────────────┐
              │                     │
              ▼                     │
        ┌──────────┐                │
   ┌───▶│   FULL   │◀───────────────┤
   │    └────┬─────┘                │
   │         │                      │
   │         ▼                      │
   │    ┌──────────┐                │
   │    │ PREDICT  │                │
   │    └────┬─────┘                │
   │         │                      │
   │         ▼                      │
   │    ┌──────────┐                │
   │    │ CORRECT  │────────────────┘
   │    └────┬─────┘
   │         │
   └─────────┘
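The transitions in the diagram can be sketched as a small step-counting state machine. Everything here is hypothetical: `SketchPhase`, `Schedule`, and `SketchController` are illustrative names, not the crate's `Phase` or `PhaseController` API, and the fixed phase lengths stand in for the adaptive lengths the crate selects.

```rust
/// Hypothetical phase enum mirroring the diagram (not the crate's `Phase` type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SketchPhase {
    Warmup,
    Full,
    Predict,
    Correct,
}

/// Fixed phase lengths in steps. Fixed lengths are an assumption of this
/// sketch; the crate adapts phase lengths at run time.
struct Schedule {
    warmup: usize,
    full: usize,
    predict: usize,
    correct: usize,
}

/// Hypothetical step-counting controller (not the crate's `PhaseController`).
struct SketchController {
    schedule: Schedule,
    phase: SketchPhase,
    steps_in_phase: usize,
}

impl SketchController {
    fn new(schedule: Schedule) -> Self {
        // Every phase must last at least one step.
        assert!(schedule.warmup > 0 && schedule.full > 0);
        assert!(schedule.predict > 0 && schedule.correct > 0);
        Self { schedule, phase: SketchPhase::Warmup, steps_in_phase: 0 }
    }

    fn len_of(&self, phase: SketchPhase) -> usize {
        match phase {
            SketchPhase::Warmup => self.schedule.warmup,
            SketchPhase::Full => self.schedule.full,
            SketchPhase::Predict => self.schedule.predict,
            SketchPhase::Correct => self.schedule.correct,
        }
    }

    /// Returns the phase used for this step, then advances the state machine.
    fn next_step(&mut self) -> SketchPhase {
        let current = self.phase;
        self.steps_in_phase += 1;
        if self.steps_in_phase == self.len_of(current) {
            // WARMUP runs once; afterwards FULL -> PREDICT -> CORRECT -> FULL.
            self.phase = match current {
                SketchPhase::Warmup | SketchPhase::Correct => SketchPhase::Full,
                SketchPhase::Full => SketchPhase::Predict,
                SketchPhase::Predict => SketchPhase::Correct,
            };
            self.steps_in_phase = 0;
        }
        current
    }
}

fn main() {
    let schedule = Schedule { warmup: 2, full: 2, predict: 3, correct: 1 };
    let mut ctrl = SketchController::new(schedule);
    let trace: Vec<SketchPhase> = (0..10).map(|_| ctrl.next_step()).collect();
    println!("{trace:?}");
}
```

With these lengths the first ten steps are two Warmup steps, two Full, three Predict, one Correct, and then Full again. The sketch omits any early exit from PREDICT, for example on detected divergence.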

§Quick Start

```rust
use hybrid_predict_trainer_rs::{HybridTrainer, HybridTrainerConfig};

// Create configuration with sensible defaults
let config = HybridTrainerConfig::default();

// Initialize the hybrid trainer. `model` and `optimizer` are your own types
// implementing the crate's `Model` and `Optimizer` traits.
// let trainer = HybridTrainer::new(model, optimizer, config)?;

// Training loop: each `batch` is a value implementing the `Batch` trait.
// for batch in dataloader {
//     let result = trainer.step(&batch)?;
//     println!("Loss: {}, Phase: {:?}", result.loss, result.phase);
// }
```

§Features

  • GPU Acceleration - CubeCL and Burn backends for high-performance compute
  • Adaptive Phase Selection - Bandit-based algorithm for optimal phase lengths
  • Divergence Detection - Multi-signal monitoring prevents training instability
  • Residual Correction - Online learning corrects prediction errors
  • Checkpoint Support - Save/restore full training state including predictor
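As one illustration of bandit-based phase selection, the sketch below uses UCB1 to choose among candidate predict-phase lengths. UCB1 is our choice for the example; the crate's `bandit` module may use a different algorithm and reward. The reward function is invented: it assumes each completed phase can be scored in [0, 1], for instance by compute saved minus a divergence penalty. `Ucb1` is a hypothetical type.

```rust
/// Hypothetical UCB1 bandit over candidate predict-phase lengths
/// (illustrative only; not the crate's `bandit` API).
struct Ucb1 {
    arms: Vec<usize>,  // candidate phase lengths, in steps
    counts: Vec<u32>,  // how often each arm has been pulled
    totals: Vec<f64>,  // summed rewards per arm
}

impl Ucb1 {
    fn new(arms: Vec<usize>) -> Self {
        assert!(!arms.is_empty(), "need at least one candidate length");
        let n = arms.len();
        Self { arms, counts: vec![0; n], totals: vec![0.0; n] }
    }

    /// Index of the arm to try next.
    fn select(&self) -> usize {
        // Pull every arm once before using confidence bounds.
        if let Some(i) = self.counts.iter().position(|&c| c == 0) {
            return i;
        }
        let total: u32 = self.counts.iter().sum();
        let ln_t = (total as f64).ln();
        let mut best = 0;
        let mut best_score = f64::NEG_INFINITY;
        for i in 0..self.arms.len() {
            let n = self.counts[i] as f64;
            // Mean reward plus UCB1 exploration bonus.
            let score = self.totals[i] / n + (2.0 * ln_t / n).sqrt();
            if score > best_score {
                best = i;
                best_score = score;
            }
        }
        best
    }

    fn update(&mut self, arm: usize, reward: f64) {
        self.counts[arm] += 1;
        self.totals[arm] += reward;
    }
}

fn main() {
    // Candidate predict-phase lengths (hypothetical values).
    let mut bandit = Ucb1::new(vec![10, 20, 40]);
    // Hypothetical deterministic reward in [0, 1]: the 20-step phase is assumed
    // to give the best trade-off between compute saved and divergence risk.
    let reward = |len: usize| -> f64 {
        match len {
            10 => 0.3,
            20 => 0.6,
            _ => 0.1,
        }
    };
    for _ in 0..200 {
        let arm = bandit.select();
        let r = reward(bandit.arms[arm]);
        bandit.update(arm, r);
    }
    println!("pulls per arm: {:?}", bandit.counts);
}
```

UCB1 assumes rewards in [0, 1]. With the rewards above, the 20-step arm ends up pulled most often, while the other lengths still receive occasional exploratory pulls.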

§Feature Flags

  • std - Enable standard library support (default)
  • cuda - Enable CUDA GPU acceleration via CubeCL
  • candle - Enable Candle tensor operations for model compatibility
  • async - Enable async/await support with Tokio
  • full - Enable all features

§Architecture

The core modules of the crate are:

  • config - Training configuration and serialization
  • error - Error types with recovery actions
  • phases - Phase state machine and execution control
  • state - Training state encoding and management
  • dynamics - RSSM-lite dynamics model for prediction
  • residuals - Residual extraction and storage
  • corrector - Prediction correction via residual application
  • divergence - Multi-signal divergence detection
  • metrics - Training metrics collection and reporting
  • gpu - GPU acceleration kernels (requires cuda feature)
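The division of labour between `dynamics`, `residuals`, and `corrector` can be sketched on a scalar loss. All names are hypothetical: the linear extrapolator stands in for the RSSM-lite model, and using the mean of recent residuals as the correction is a simplifying assumption of this sketch, not the crate's method.

```rust
use std::collections::VecDeque;

/// Hypothetical stand-in for the `dynamics` model: extrapolate the next loss
/// linearly from the last two observations (the crate uses RSSM-lite).
fn predict_next(history: &[f64]) -> f64 {
    match history {
        [.., prev, last] => *last + (*last - *prev),
        [only] => *only,
        [] => 0.0,
    }
}

/// Hypothetical bounded residual store (the role of the `residuals` module).
struct SketchResidualStore {
    window: VecDeque<f64>,
    capacity: usize,
}

impl SketchResidualStore {
    fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be positive");
        Self { window: VecDeque::with_capacity(capacity), capacity }
    }

    /// Record `actual - predicted`, evicting the oldest residual when full.
    fn push(&mut self, residual: f64) {
        if self.window.len() == self.capacity {
            self.window.pop_front();
        }
        self.window.push_back(residual);
    }

    /// Corrector role: the mean of the stored residuals (0.0 when empty).
    fn mean_correction(&self) -> f64 {
        if self.window.is_empty() {
            0.0
        } else {
            self.window.iter().sum::<f64>() / self.window.len() as f64
        }
    }
}

/// For each step t >= 2, returns (raw prediction, corrected prediction) of observed[t].
fn corrected_predictions(observed: &[f64], capacity: usize) -> Vec<(f64, f64)> {
    let mut store = SketchResidualStore::new(capacity);
    let mut out = Vec::new();
    for t in 2..observed.len() {
        let raw = predict_next(&observed[..t]);
        let corrected = raw + store.mean_correction();
        // Once the true loss is known (e.g. from a full step), store the residual.
        store.push(observed[t] - raw);
        out.push((raw, corrected));
    }
    out
}

fn main() {
    // Made-up losses with constant curvature (second difference 0.1).
    let observed = [2.0, 1.6, 1.3, 1.1, 1.0];
    for (i, (raw, corrected)) in corrected_predictions(&observed, 3).iter().enumerate() {
        let t = i + 2;
        println!("step {t}: raw {raw:.3}, corrected {corrected:.3}, actual {:.3}", observed[t]);
    }
}
```

Because these made-up losses have constant curvature, the linear extrapolator is off by the same amount at every step, so after one observed residual the corrected prediction tracks the actual loss while the raw prediction stays 0.1 too low.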

§References

This implementation is based on research findings documented in predictive-training-research.md, synthesizing insights from:

  • Neural Tangent Kernel (NTK) theory for training dynamics
  • RSSM world models from DreamerV3
  • K-FAC for structured gradient approximation
  • PowerSGD for low-rank gradient compression

§Re-exports

pub use auto_tuning::AutoTuningConfig;
pub use auto_tuning::AutoTuningController;
pub use auto_tuning::AutoTuningUpdate;
pub use auto_tuning::BatchPrediction;
pub use auto_tuning::BatchPredictionRecommendation;
pub use auto_tuning::HealthClassification;
pub use auto_tuning::HealthRecommendation;
pub use auto_tuning::HealthScorer;
pub use auto_tuning::HealthWeights;
pub use auto_tuning::MultiStepPredictor;
pub use auto_tuning::TrainingHealthScore;
pub use config::HybridTrainerConfig;
pub use error::HybridResult;
pub use error::HybridTrainingError;
pub use error::RecoveryAction;
pub use phases::Phase;
pub use phases::PhaseController;
pub use phases::PhaseDecision;
pub use phases::PhaseOutcome;
pub use residuals::Residual;
pub use residuals::ResidualStore;
pub use state::TrainingState;
pub use timing::Duration;
pub use timing::Timer;
pub use timing::TimingMetrics;
pub use timing::TimingStats;

§Modules

auto_tuning
Automatic tuning and health monitoring for hybrid predictive training.
bandit
Bandit-based adaptive phase selection.
burn_integration
Burn framework integration for hybrid-predict-trainer-rs.
checkpoint
Checkpoint save/restore functionality for hybrid trainer.
config
Configuration types for hybrid predictive training.
corrector
Residual correction for prediction adjustment.
delta_accumulator
Delta accumulation for batched weight updates.
divergence
Multi-signal divergence detection.
dynamics
RSSM-lite dynamics model for training trajectory prediction.
ecosystem (requires ecosystem feature)
Rust AI ecosystem integration.
error
Error types and recovery actions for hybrid predictive training.
full_train
Full training phase implementation.
gpu (requires cuda feature)
GPU acceleration kernels via CubeCL and Burn.
gradient_accumulation
Gradient accumulation for memory-efficient training with large effective batch sizes.
gru
GRU (Gated Recurrent Unit) implementation with forward pass and training.
metrics
Training metrics collection and reporting.
mixed_precision
Mixed precision training support for memory optimization.
models (requires autodiff feature)
Reference model implementations for validation.
phases
Phase state machine and execution control.
predict_aware_memory
Predict-aware memory management for HybridTrainer.
predictive
Predictive training phase implementation.
prelude
Prelude module for convenient imports.
residuals
Residual extraction and storage for prediction correction.
state
Training state management and encoding.
timing
High-precision timing utilities for training metrics.
vram_budget
VRAM budget management and auto-configuration.
vram_manager
VRAM management utilities for controlling GPU memory usage.
warmup
Warmup phase implementation.

§Structs

GradientInfo
Gradient information from a backward pass.
HybridTrainer
The main hybrid trainer that orchestrates phase-based predictive training.
StepResult
Result of a single training step.

§Traits

Batch
Batch of training data.
Model
Trait for models that can be trained with the hybrid trainer.
Optimizer
Trait for optimizers that update model parameters.