// hybrid_predict_trainer_rs — crate root (lib.rs)
1//! # hybrid-predict-trainer-rs
2//!
3//! Hybridized predictive training framework that accelerates deep learning through
4//! intelligent phase-based training with whole-phase prediction and residual correction.
5//!
6//! ## Overview
7//!
8//! This crate implements a novel training paradigm that achieves 5-10x speedup over
9//! traditional training by predicting training outcomes rather than computing every
10//! gradient step. The key insight is that training dynamics evolve on low-dimensional
11//! manifolds, making whole-phase prediction tractable.
12//!
13//! ## Training Phases
14//!
15//! The training loop cycles through four distinct phases:
16//!
17//! 1. **Warmup Phase** - Initial training steps to establish baseline dynamics
18//! 2. **Full Training Phase** - Traditional forward/backward pass computation
19//! 3. **Predictive Phase** - Skip backward passes using learned dynamics model
20//! 4. **Correction Phase** - Apply residual corrections to maintain accuracy
21//!
22//! ```text
23//!                     ┌─────────┐
24//!                     │ WARMUP  │
25//!                     └────┬────┘
26//!                          │
27//!                          ▼
28//!               ┌─────────────────────┐
29//!               │                     │
30//!               ▼                     │
31//!         ┌──────────┐                │
32//!    ┌───▶│   FULL   │◀───────────────┤
33//!    │    └────┬─────┘                │
34//!    │         │                      │
35//!    │         ▼                      │
36//!    │    ┌──────────┐                │
37//!    │    │ PREDICT  │                │
38//!    │    └────┬─────┘                │
39//!    │         │                      │
40//!    │         ▼                      │
41//!    │    ┌──────────┐                │
42//!    │    │ CORRECT  │────────────────┘
43//!    │    └────┬─────┘
44//!    │         │
45//!    └─────────┘
46//! ```
47//!
48//! ## Quick Start
49//!
50//! ```no_run
51//! use hybrid_predict_trainer_rs::{HybridTrainer, HybridTrainerConfig};
52//!
53//! // Create configuration with sensible defaults
54//! let config = HybridTrainerConfig::default();
55//!
56//! // Initialize the hybrid trainer
57//! // let trainer = HybridTrainer::new(model, optimizer, config)?;
58//!
59//! // Training loop
60//! // for batch in dataloader {
61//! //     let result = trainer.step(&batch)?;
62//! //     println!("Loss: {}, Phase: {:?}", result.loss, result.phase);
63//! // }
64//! ```
65//!
66//! ## Features
67//!
68//! - **GPU Acceleration** - `CubeCL` and Burn backends for high-performance compute
69//! - **Adaptive Phase Selection** - Bandit-based algorithm for optimal phase lengths
70//! - **Divergence Detection** - Multi-signal monitoring prevents training instability
71//! - **Residual Correction** - Online learning corrects prediction errors
72//! - **Checkpoint Support** - Save/restore full training state including predictor
73//!
74//! ## Feature Flags
75//!
76//! - `std` - Enable standard library support (default)
77//! - `cuda` - Enable CUDA GPU acceleration via `CubeCL`
78//! - `candle` - Enable Candle tensor operations for model compatibility
79//! - `async` - Enable async/await support with Tokio
80//! - `full` - Enable all features
81//!
82//! ## Architecture
83//!
84//! The crate is organized into the following modules:
85//!
86//! - [`config`] - Training configuration and serialization
87//! - [`error`] - Error types with recovery actions
88//! - [`phases`] - Phase state machine and execution control
89//! - [`state`] - Training state encoding and management
90//! - [`dynamics`] - RSSM-lite dynamics model for prediction
91//! - [`residuals`] - Residual extraction and storage
92//! - [`corrector`] - Prediction correction via residual application
93//! - [`divergence`] - Multi-signal divergence detection
94//! - [`metrics`] - Training metrics collection and reporting
95#![cfg_attr(
96    feature = "cuda",
97    doc = "- [`gpu`] - GPU acceleration kernels (requires `cuda` feature)"
98)]
99#![cfg_attr(
100    not(feature = "cuda"),
101    doc = "- `gpu` - GPU acceleration kernels (requires `cuda` feature)"
102)]
103//!
104//! ## References
105//!
106//! This implementation is based on research findings documented in
107//! `predictive-training-research.md`, synthesizing insights from:
108//!
109//! - Neural Tangent Kernel (NTK) theory for training dynamics
110//! - RSSM world models from `DreamerV3`
111//! - K-FAC for structured gradient approximation
112//! - `PowerSGD` for low-rank gradient compression
113
114#![cfg_attr(docsrs, feature(doc_cfg))]
115#![warn(missing_docs)]
116#![warn(rustdoc::missing_crate_level_docs)]
117#![deny(unsafe_code)]
118// Allow precision loss casts - acceptable in ML numerical code
119#![allow(clippy::cast_precision_loss)]
120#![allow(clippy::cast_possible_truncation)]
121#![allow(clippy::cast_sign_loss)]
122// Suppress documentation warnings during development
123#![allow(clippy::missing_errors_doc)]
124#![allow(clippy::missing_panics_doc)]
125#![allow(clippy::too_many_lines)]
126// Allow other common patterns
127#![allow(clippy::needless_range_loop)]
128#![allow(clippy::items_after_statements)]
129#![allow(clippy::unused_self)]
130#![allow(clippy::manual_clamp)]
131
132// Core modules
133pub mod config;
134pub mod error;
135pub mod phases;
136pub mod state;
137
138// Training phase implementations
139pub mod corrector;
140pub mod delta_accumulator;
141pub mod full_train;
142pub mod predictive;
143pub mod residuals;
144pub mod warmup;
145
146// Prediction and control
147pub mod bandit;
148pub mod divergence;
149pub mod dynamics;
150pub mod gru;
151
152// Metrics and monitoring
153pub mod metrics;
154
155// High-precision timing utilities
156pub mod timing;
157
158// Mixed precision training support
159pub mod mixed_precision;
160
161// Gradient accumulation for memory efficiency
162pub mod gradient_accumulation;
163
164// Predict-aware memory management (unique to HybridTrainer)
165pub mod predict_aware_memory;
166
167// Automatic tuning and optimization
168pub mod auto_tuning;
169
170// Burn framework integration
171pub mod burn_integration;
172
173// VRAM budget management
174pub mod vram_budget;
175
176// Checkpoint save/restore
177pub mod checkpoint;
178
179// VRAM manager for tracking and cleaning up GPU memory
180pub mod vram_manager;
181
182// Rust AI ecosystem integration (GpuDispatchable trait)
183#[cfg(feature = "ecosystem")]
184#[cfg_attr(docsrs, doc(cfg(feature = "ecosystem")))]
185pub mod ecosystem;
186
187// Reference model implementations for validation
188#[cfg(feature = "autodiff")]
189pub mod models;
190
191// GPU acceleration (feature-gated)
192#[cfg(feature = "cuda")]
193#[cfg_attr(docsrs, doc(cfg(feature = "cuda")))]
194pub mod gpu;
195
196// Re-exports for convenient access
197pub use auto_tuning::{
198    AutoTuningConfig, AutoTuningController, AutoTuningUpdate, BatchPrediction,
199    BatchPredictionRecommendation, HealthClassification, HealthRecommendation, HealthScorer,
200    HealthWeights, MultiStepPredictor, TrainingHealthScore,
201};
202pub use config::HybridTrainerConfig;
203pub use error::{HybridResult, HybridTrainingError, RecoveryAction};
204pub use phases::{Phase, PhaseController, PhaseDecision, PhaseOutcome};
205pub use residuals::{Residual, ResidualStore};
206pub use state::TrainingState;
207pub use timing::{Duration, Timer, TimingMetrics, TimingStats};
208
209// Standard library imports
210use std::sync::Arc;
211use std::time::Instant;
212
213// External crate imports
214// Note: Mutex used instead of RwLock for model/optimizer storage to support !Sync types
215
/// Batch of training data.
///
/// Generic container for a batch of input data that will be fed to the model
/// during training. The actual batch format depends on the model implementation.
///
/// Implementations must be `Send + Sync` so batches can be shared across
/// threads (e.g., produced by a data-loading pipeline on another thread).
pub trait Batch: Send + Sync {
    /// Returns the batch size (number of samples).
    fn batch_size(&self) -> usize;
}
224
/// Gradient information from a backward pass.
///
/// Contains the computed gradients and loss for a training step. Produced by
/// [`Model::backward`] and consumed by [`Optimizer::step`].
#[derive(Debug, Clone)]
pub struct GradientInfo {
    /// The computed loss value.
    pub loss: f32,
    /// L2 norm of all gradients.
    pub gradient_norm: f32,
    /// Per-parameter gradient norms (optional, for debugging).
    ///
    /// `None` when the implementation does not track per-parameter norms;
    /// populating this is expected to add memory/compute overhead.
    pub per_param_norms: Option<Vec<f32>>,
}
237
/// Trait for models that can be trained with the hybrid trainer.
///
/// Models must implement forward pass, backward pass, and parameter access.
/// The trainer will call these methods during different training phases.
///
/// # Why This Trait?
///
/// The hybrid trainer is framework-agnostic. By requiring only forward/backward
/// and weight delta application, it works with any deep learning framework
/// (Burn, Candle, tch-rs, etc.) that can implement these operations.
///
/// # Type Parameters
///
/// - `B`: The batch type containing input data
///
/// # Example
///
/// ```rust,ignore
/// impl Model<MyBatch> for MyModel {
///     fn forward(&mut self, batch: &MyBatch) -> HybridResult<f32> {
///         // Compute forward pass and return loss
///     }
///
///     fn backward(&mut self) -> HybridResult<GradientInfo> {
///         // Compute gradients (assumes forward was just called)
///     }
///
///     fn parameter_count(&self) -> usize {
///         self.parameters.iter().map(|p| p.numel()).sum()
///     }
/// }
/// ```
///
/// # Thread Safety
///
/// Models must be `Send` to allow moving between threads, but `Sync` is not
/// required. This enables integration with autodiff frameworks (like Burn)
/// that use gradient types which are `!Sync` by design.
///
/// For multi-threaded access to models, use `Arc<Mutex<>>` rather than
/// `Arc<RwLock<>>` since the model itself may not be `Sync`.
pub trait Model<B: Batch>: Send {
    /// Executes the forward pass and returns the loss.
    ///
    /// Takes `&mut self` because implementations may need to cache forward
    /// state (e.g., activations or an autodiff graph) for a later `backward()`.
    ///
    /// # Arguments
    ///
    /// * `batch` - The input batch data
    ///
    /// # Returns
    ///
    /// The loss value for this batch.
    fn forward(&mut self, batch: &B) -> HybridResult<f32>;

    /// Executes the backward pass (gradient computation).
    ///
    /// Should be called after `forward()`. Computes gradients with respect
    /// to the loss returned by the most recent forward pass. Note there is no
    /// batch argument: implementations rely on state cached by `forward()`.
    ///
    /// # Returns
    ///
    /// Gradient information including loss and gradient norms.
    fn backward(&mut self) -> HybridResult<GradientInfo>;

    /// Clears forward pass state when backward() won't be called.
    ///
    /// This method should be called during Predict phase after forward()
    /// when backward() will be skipped. It allows implementations to free
    /// resources associated with the forward pass (e.g., autodiff graphs,
    /// cached activations) to prevent memory accumulation.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// // During Predict phase
    /// let loss = model.forward(batch)?;
    /// model.clear_forward_state(); // Won't call backward()
    /// ```
    ///
    /// Default implementation does nothing (for implementations that don't
    /// need cleanup).
    fn clear_forward_state(&mut self) {
        // Default: no-op
    }

    /// Returns the total number of trainable parameters.
    fn parameter_count(&self) -> usize;

    /// Applies a weight delta to the model parameters.
    ///
    /// Used during predictive phase to apply predicted weight updates.
    ///
    /// # Arguments
    ///
    /// * `delta` - The weight changes to apply
    ///
    /// # Errors
    ///
    /// Implementations may fail if the delta cannot be applied (e.g., a
    /// shape/layout mismatch — exact conditions are implementation-defined).
    fn apply_weight_delta(&mut self, delta: &state::WeightDelta) -> HybridResult<()>;
}
334
/// Trait for optimizers that update model parameters.
///
/// Optimizers implement the parameter update rule (SGD, Adam, etc.).
///
/// # Why Separate from Model?
///
/// Optimizer state (momentum, variance estimates) is distinct from model
/// parameters. Separating them allows:
/// - **Swapping optimizers**: Try different optimizers without changing model code
/// - **Independent serialization**: Save/load optimizer state separately
/// - **Stateful updates**: Adam/AdaGrad need per-parameter state across steps
///
/// # Example
///
/// ```rust,ignore
/// impl<M: Model<B>, B: Batch> Optimizer<M, B> for AdamOptimizer {
///     fn step(&mut self, model: &mut M, gradients: &GradientInfo) -> HybridResult<()> {
///         // Apply Adam update rule to model parameters
///     }
/// }
/// ```
///
/// # Thread Safety
///
/// Optimizers must be `Send` to allow moving between threads, but `Sync` is not
/// required. This matches the `Model` trait's threading constraints.
pub trait Optimizer<M, B: Batch>: Send
where
    M: Model<B>,
{
    /// Performs a single optimization step.
    ///
    /// Updates model parameters using the computed gradients.
    ///
    /// # Arguments
    ///
    /// * `model` - The model to update
    /// * `gradients` - Gradient information from backward pass
    ///
    /// # Errors
    ///
    /// Returns an error when the update cannot be applied; exact failure
    /// conditions are implementation-defined.
    fn step(&mut self, model: &mut M, gradients: &GradientInfo) -> HybridResult<()>;

    /// Returns the current learning rate.
    fn learning_rate(&self) -> f32;

    /// Sets the learning rate (for warmup/decay schedules).
    fn set_learning_rate(&mut self, lr: f32);

    /// Zeros all accumulated gradients.
    fn zero_grad(&mut self);
}
384
/// The main hybrid trainer that orchestrates phase-based predictive training.
///
/// # Overview
///
/// `HybridTrainer` wraps a model and optimizer, managing the training loop through
/// warmup, full training, predictive, and correction phases. It automatically
/// selects optimal phase lengths using bandit-based algorithms and monitors for
/// divergence to ensure training stability.
///
/// # Type Parameters
///
/// - `M`: The model type (must implement `Model` trait)
/// - `O`: The optimizer type (must implement `Optimizer` trait)
///
/// # Example
///
/// ```no_run
/// use hybrid_predict_trainer_rs::{HybridTrainer, HybridTrainerConfig};
///
/// // Configure the trainer
/// let config = HybridTrainerConfig::builder()
///     .warmup_steps(100)
///     .full_steps(20)
///     .max_predict_steps(50)  // Can be higher with more VRAM
///     .confidence_threshold(0.85)
///     .build();
///
/// // Create trainer (model and optimizer types are inferred)
/// // let trainer = HybridTrainer::new(model, optimizer, config)?;
/// ```
pub struct HybridTrainer<M, O> {
    /// The model being trained.
    ///
    /// Uses `Mutex` instead of `RwLock` because models may not be `Sync`
    /// (e.g., when using autodiff frameworks with !Sync gradient types).
    model: Arc<parking_lot::Mutex<M>>,

    /// The optimizer for parameter updates.
    ///
    /// Uses `Mutex` instead of `RwLock` for consistency with model storage.
    optimizer: Arc<parking_lot::Mutex<O>>,

    /// Training configuration.
    config: HybridTrainerConfig,

    /// Current training state.
    state: TrainingState,

    /// Phase controller for state machine management.
    phase_controller: phases::DefaultPhaseController,

    /// Dynamics model for whole-phase prediction.
    dynamics_model: dynamics::RSSMLite,

    /// Divergence monitor for stability detection.
    divergence_monitor: divergence::DivergenceMonitor,

    /// Residual corrector for prediction adjustment.
    residual_corrector: corrector::ResidualCorrector,

    /// Storage for residuals extracted from prediction errors.
    residual_store: residuals::ResidualStore,

    /// Metrics collector for training statistics.
    metrics: metrics::MetricsCollector,

    /// Current phase and remaining steps (for respecting multi-step phase decisions).
    ///
    /// `None` means the phase controller must be consulted for a fresh
    /// decision on the next step.
    phase_budget: Option<(Phase, usize)>,

    /// Automatic tuning controller (optional).
    ///
    /// `None` when no auto-tuning config was supplied at construction.
    auto_tuning: Option<AutoTuningController>,

    /// Last auto-tuning update (for external access).
    last_auto_tuning_update: Option<AutoTuningUpdate>,

    /// Checkpoint manager for automatic checkpointing (optional).
    ///
    /// `None` when `checkpoint_config.save_interval` is 0 (checkpointing disabled).
    checkpoint_manager: Option<checkpoint::CheckpointManager>,

    /// Delta accumulator for batched weight updates (VRAM optimization).
    delta_accumulator: delta_accumulator::DeltaAccumulator,

    /// VRAM manager for tracking and cleaning up GPU memory.
    vram_manager: vram_manager::VramManager,
}
469
470impl<M, O> HybridTrainer<M, O> {
471    /// Creates a new hybrid trainer with the given model, optimizer, and configuration.
472    ///
473    /// # Arguments
474    ///
475    /// * `model` - The model to train
476    /// * `optimizer` - The optimizer for parameter updates
477    /// * `config` - Training configuration
478    ///
479    /// # Returns
480    ///
481    /// A new `HybridTrainer` instance wrapped in a `HybridResult`.
482    ///
483    /// # Errors
484    ///
485    /// Returns an error if the configuration is invalid or initialization fails.
486    pub fn new(model: M, optimizer: O, config: HybridTrainerConfig) -> HybridResult<Self> {
487        let state = TrainingState::new();
488        let phase_controller = phases::DefaultPhaseController::new(&config);
489        let dynamics_model = dynamics::RSSMLite::new(&config.predictor_config)?;
490        let divergence_monitor = divergence::DivergenceMonitor::new(&config);
491        let residual_corrector = corrector::ResidualCorrector::new(&config);
492        let residual_store = residuals::ResidualStore::new(1000);
493        let metrics = metrics::MetricsCollector::new(config.collect_metrics);
494
495        // Initialize auto-tuning controller if config provided
496        let auto_tuning = if let Some(auto_config) = config.auto_tuning_config.clone() {
497            let max_steps = config.max_steps.unwrap_or(10000); // Default if not provided
498            Some(auto_tuning::AutoTuningController::new(
499                auto_config,
500                max_steps,
501            ))
502        } else {
503            None
504        };
505
506        // Initialize checkpoint manager if save_interval > 0
507        let checkpoint_manager = if config.checkpoint_config.save_interval > 0 {
508            // Use "./checkpoints" as default directory if not specified
509            let checkpoint_dir = std::path::PathBuf::from("./checkpoints");
510            Some(checkpoint::CheckpointManager::new(
511                checkpoint_dir,
512                config.checkpoint_config.save_interval,
513                config.checkpoint_config.keep_last_n,
514            )?)
515        } else {
516            None
517        };
518
519        Ok(Self {
520            model: Arc::new(parking_lot::Mutex::new(model)),
521            optimizer: Arc::new(parking_lot::Mutex::new(optimizer)),
522            config,
523            state,
524            phase_controller,
525            dynamics_model,
526            divergence_monitor,
527            residual_corrector,
528            residual_store,
529            metrics,
530            phase_budget: None,
531            auto_tuning,
532            last_auto_tuning_update: None,
533            checkpoint_manager,
534            delta_accumulator: delta_accumulator::DeltaAccumulator::new(),
535            vram_manager: vram_manager::VramManager::new(),
536        })
537    }
538
539    /// Returns the current training step.
540    ///
541    /// # Returns
542    ///
543    /// The current step number (0-indexed).
544    #[must_use]
545    pub fn current_step(&self) -> u64 {
546        self.state.step
547    }
548
549    /// Returns the current training phase.
550    ///
551    /// # Returns
552    ///
553    /// The current [`Phase`] of training.
554    #[must_use]
555    pub fn current_phase(&self) -> Phase {
556        self.phase_controller.current_phase()
557    }
558
559    /// Returns the current predictor confidence level.
560    ///
561    /// # Returns
562    ///
563    /// A confidence score between 0.0 and 1.0 indicating how reliable
564    /// the predictor's outputs are estimated to be.
565    #[must_use]
566    pub fn current_confidence(&self) -> f32 {
567        self.dynamics_model.prediction_confidence(&self.state)
568    }
569
570    /// Returns training statistics and metrics.
571    ///
572    /// # Returns
573    ///
574    /// A [`metrics::TrainingStatistics`] struct containing aggregate metrics.
575    #[must_use]
576    pub fn statistics(&mut self) -> metrics::TrainingStatistics {
577        self.metrics.statistics()
578    }
579
580    /// Returns the last auto-tuning update, if available.
581    ///
582    /// # Returns
583    ///
584    /// The most recent [`auto_tuning::AutoTuningUpdate`] if auto-tuning is enabled,
585    /// or `None` if auto-tuning is disabled or no updates have occurred yet.
586    #[must_use]
587    pub fn last_auto_tuning_update(&self) -> Option<&auto_tuning::AutoTuningUpdate> {
588        self.last_auto_tuning_update.as_ref()
589    }
590
591    /// Returns a read lock on the model.
592    ///
593    /// Use this to access model state for checkpointing or inspection.
594    pub fn model(&self) -> parking_lot::MutexGuard<'_, M> {
595        self.model.lock()
596    }
597
598    /// Returns a write lock on the model.
599    ///
600    /// Use this for operations that need to modify the model directly.
601    pub fn model_mut(&self) -> parking_lot::MutexGuard<'_, M> {
602        self.model.lock()
603    }
604
605    /// Sets the learning rate on the underlying optimizer.
606    ///
607    /// This is used for learning rate scheduling (warmup, decay, etc.).
608    ///
609    /// # Arguments
610    ///
611    /// * `lr` - The new learning rate
612    ///
613    /// # Example
614    ///
615    /// ```rust,ignore
616    /// // Update learning rate based on schedule
617    /// let lr = scheduler.get_lr(trainer.current_step());
618    /// trainer.set_learning_rate(lr);
619    /// ```
620    pub fn set_learning_rate<B>(&self, lr: f32)
621    where
622        B: Batch,
623        M: Model<B>,
624        O: Optimizer<M, B>,
625    {
626        self.optimizer.lock().set_learning_rate(lr);
627    }
628
629    /// Returns the current learning rate from the optimizer.
630    ///
631    /// # Returns
632    ///
633    /// The current learning rate value.
634    pub fn learning_rate<B>(&self) -> f32
635    where
636        B: Batch,
637        M: Model<B>,
638        O: Optimizer<M, B>,
639    {
640        self.optimizer.lock().learning_rate()
641    }
642
643    // TODO: Enable when auto_tuning fields are added to struct
644    // /// Returns the last auto-tuning update, if auto-tuning is enabled.
645    // ///
646    // /// # Returns
647    // ///
648    // /// The most recent auto-tuning update, or None if auto-tuning is disabled
649    // /// or no updates have occurred yet.
650    // #[must_use]
651    // pub fn last_auto_tuning_update(&self) -> Option<&auto_tuning::AutoTuningUpdate> {
652    //     self.last_auto_tuning_update.as_ref()
653    // }
654
655    /// Returns the last recorded gradient norm.
656    ///
657    /// # Returns
658    ///
659    /// The gradient norm from the most recent backward pass.
660    #[must_use]
661    fn last_gradient_norm(&self) -> f32 {
662        self.state.gradient_norm
663    }
664
665    /// Collects per-layer gradient statistics.
666    ///
667    /// This is a stub implementation that distributes the global gradient
668    /// norm across dummy layers. A real implementation would track per-layer
669    /// gradients during the backward pass.
670    ///
671    /// # Returns
672    ///
673    /// `HashMap` of `layer_name` -> (`grad_norm`, `weight_norm`).
674    fn collect_layer_gradients(&self) -> std::collections::HashMap<String, (f32, f32)> {
675        let global_norm = self.last_gradient_norm();
676
677        // Stub: distribute gradient across typical transformer layers
678        // In a real implementation, this would come from actual per-layer tracking
679        let mut map = std::collections::HashMap::new();
680        map.insert("embed".to_string(), (global_norm * 0.8, 10.0));
681        map.insert("attention".to_string(), (global_norm * 1.0, 15.0));
682        map.insert("mlp".to_string(), (global_norm * 1.2, 20.0));
683        map.insert("lm_head".to_string(), (global_norm * 0.9, 8.0));
684        map
685    }
686}
687
688impl<M, O> HybridTrainer<M, O> {
689    /// Executes a single training step.
690    ///
691    /// This is the main entry point for the training loop. The trainer
692    /// automatically selects the appropriate phase (warmup, full, predict,
693    /// or correct) based on the current state and configuration.
694    ///
695    /// # Type Parameters
696    ///
697    /// * `B` - The batch type containing input data
698    ///
699    /// # Arguments
700    ///
701    /// * `batch` - The training batch to process
702    ///
703    /// # Returns
704    ///
705    /// A [`StepResult`] containing the loss, phase info, and prediction metadata.
706    ///
707    /// # Errors
708    ///
709    /// Returns an error if training diverges or encounters numerical issues.
710    /// The error includes a suggested recovery action when possible.
711    ///
712    /// # Example
713    ///
714    /// ```rust,ignore
715    /// for batch in dataloader {
716    ///     let result = trainer.step(&batch)?;
717    ///     println!("Step {}: loss={:.4}, phase={:?}",
718    ///         trainer.current_step(),
719    ///         result.loss,
720    ///         result.phase
721    ///     );
722    /// }
723    /// ```
724    pub fn step<B>(&mut self, batch: &B) -> HybridResult<StepResult>
725    where
726        B: Batch,
727        M: Model<B>,
728        O: Optimizer<M, B>,
729    {
730        let start_time = Instant::now();
731
732        // Update predictor confidence for phase controller
733        let confidence = self.dynamics_model.prediction_confidence(&self.state);
734        self.phase_controller.set_predictor_confidence(confidence);
735
736        // Track previous phase to detect transitions
737        let previous_phase = self.state.current_phase;
738
739        // Get phase from budget or request new decision
740        let phase = match &mut self.phase_budget {
741            Some((current_phase, remaining)) if *remaining > 0 => {
742                // Use current phase and decrement budget
743                *remaining -= 1;
744                *current_phase
745            }
746            _ => {
747                // Budget exhausted or None - get new phase decision
748                let decision = self.phase_controller.select_next_phase(&self.state);
749                let new_phase = decision.phase();
750                let steps = decision.steps();
751
752                // Set budget for remaining steps (steps-1 since we're using one now)
753                if steps > 1 {
754                    self.phase_budget = Some((new_phase, steps - 1));
755                } else {
756                    self.phase_budget = None;
757                }
758
759                new_phase
760            }
761        };
762
763        // Reset steps_in_current_phase counter on phase transitions
764        if phase != previous_phase {
765            self.state.steps_in_current_phase = 0;
766            self.state.current_phase = phase;
767
768            // Log VRAM usage at phase transitions
769            let vram_mb = self.vram_manager.last_vram_mb();
770            println!(
771                "Phase transition: {:?} → {:?} | VRAM: {} MB | Copies: {}",
772                previous_phase,
773                phase,
774                vram_mb,
775                crate::vram_manager::VramManager::total_copies()
776            );
777
778            // Flush accumulated weight deltas at phase transitions (VRAM optimization)
779            // This applies all accumulated deltas from the previous phase in one batch,
780            // minimizing model copies from Burn's .map() API
781            if let Some(merged_delta) = self.delta_accumulator.flush() {
782                let mut model = self.model.lock();
783                model.apply_weight_delta(&merged_delta)?;
784            }
785        }
786
787        // Execute the appropriate phase
788        let (loss, was_predicted, prediction_error) = match phase {
789            Phase::Warmup | Phase::Full => self.execute_full_step(batch)?,
790            Phase::Predict => self.execute_predict_step(batch)?,
791            Phase::Correct => self.execute_correct_step(batch)?,
792        };
793
794        // Check for divergence (skip during warmup - NaN values are expected initially)
795        if phase != Phase::Warmup {
796            let divergence_result = self.divergence_monitor.check(&self.state, prediction_error);
797            if divergence_result.level > error::DivergenceLevel::Caution {
798                let recovery = self
799                    .phase_controller
800                    .handle_divergence(divergence_result.level);
801                if !recovery.can_continue() {
802                    return Err((
803                        HybridTrainingError::PredictionDivergence {
804                            actual: loss,
805                            predicted: self.state.loss,
806                            delta: (loss - self.state.loss).abs(),
807                            step: self.state.step,
808                        },
809                        Some(recovery),
810                    ));
811                }
812                // Reset phase budget to force Full training after divergence
813                self.phase_budget = Some((Phase::Full, self.config.full_steps));
814            }
815        }
816
817        // Auto-tuning integration: Update controller and apply recommendations
818        if self.auto_tuning.is_some() {
819            // Collect gradients before taking mutable borrow of auto_tuning
820            let layer_grads: Vec<(String, f32, f32)> = self
821                .collect_layer_gradients()
822                .into_iter()
823                .map(|(name, (grad_norm, weight_norm))| (name, grad_norm, weight_norm))
824                .collect();
825
826            #[allow(clippy::unnecessary_unwrap)]
827            let update = self.auto_tuning.as_mut().unwrap().update(
828                self.state.step,
829                loss,
830                self.state.gradient_norm,
831                &layer_grads,
832                confidence,
833            );
834
835            // Store update for external access (TODO: wire to optimizer when available)
836            self.last_auto_tuning_update = Some(update);
837        }
838
839        // Update state
840        self.state.step += 1;
841        self.state.loss = loss;
842        self.state.loss_history.push(loss);
843        self.state.loss_ema.update(loss); // CRITICAL FIX: Update EMA for stability confidence
844
845        // Update divergence monitor
846        self.divergence_monitor
847            .observe(loss, self.state.gradient_norm);
848
849        // Record metrics
850        let step_metrics = if self.config.collect_metrics {
851            Some(self.metrics.record_step_data(
852                self.state.step,
853                loss,
854                phase,
855                was_predicted,
856                prediction_error,
857                confidence,
858            ))
859        } else {
860            None
861        };
862
863        // Auto-checkpoint if enabled and interval reached OR VRAM critical
864        // Compute checkpoint state before taking mutable borrow
865        const VRAM_CHECKPOINT_THRESHOLD_MB: usize = 14_000; // 14 GB triggers emergency checkpoint
866        let vram_critical = self.vram_manager.last_vram_mb() > VRAM_CHECKPOINT_THRESHOLD_MB;
867
868        let checkpoint_to_save = if self.checkpoint_manager.as_ref().map_or(false, |mgr| {
869            mgr.should_save(self.state.step) || vram_critical
870        }) {
871            if vram_critical {
872                eprintln!(
873                    "🚨 Emergency checkpoint triggered by high VRAM ({} MB > {} MB)",
874                    self.vram_manager.last_vram_mb(),
875                    VRAM_CHECKPOINT_THRESHOLD_MB
876                );
877            }
878            use crate::checkpoint::*;
879
880            Some(TrainingCheckpoint::new(
881                self.config.clone(),
882                self.state.clone(),
883                DynamicsState::default(),
884                ResidualStoreState::default(),
885                PhaseControllerState {
886                    current_phase: self.phase_controller.current_phase(),
887                    predictor_confidence: self.current_confidence(),
888                    warmup_complete: self.phase_controller.is_warmup_complete(),
889                    phase_stats: Vec::new(),
890                },
891                DivergenceMonitorState::default(),
892                CorrectorState::default(),
893            ))
894        } else {
895            None
896        };
897
898        // Now we can take mutable borrow to save
899        if let Some(checkpoint) = checkpoint_to_save {
900            if let Some(ref mut checkpoint_mgr) = self.checkpoint_manager {
901                // Save checkpoint (errors are logged but don't stop training)
902                if let Err((err, _)) = checkpoint_mgr.save(&checkpoint) {
903                    eprintln!(
904                        "Warning: Failed to save checkpoint at step {}: {}",
905                        self.state.step, err
906                    );
907                }
908            }
909        }
910
911        // Check if VRAM cleanup is needed (workaround for Burn's model.map() leak)
912        if self.vram_manager.should_cleanup() {
913            self.vram_manager.force_cleanup();
914        }
915
916        let elapsed_ms = start_time.elapsed().as_secs_f64() * 1000.0;
917
918        Ok(StepResult {
919            loss,
920            phase,
921            was_predicted,
922            prediction_error,
923            confidence,
924            gradient_norm: self.state.gradient_norm,
925            step_time_ms: elapsed_ms,
926            metrics: step_metrics,
927        })
928    }
929
    /// Executes a full training step (forward + backward + optimizer step).
    ///
    /// Used during Warmup and Full phases. After warmup completes, each full
    /// step also feeds the predictive machinery: a 1-step prediction is made
    /// before the update so the realized loss can be turned into a residual,
    /// the residual corrector's online model is updated, and the dynamics
    /// model observes the actual gradient.
    ///
    /// # Returns
    ///
    /// `(loss, was_predicted, prediction_error)` — `was_predicted` is always
    /// `false` and `prediction_error` is always `None` for full steps.
    ///
    /// # Errors
    ///
    /// Propagates any error from the model's forward or backward pass, or
    /// from the optimizer step.
    fn execute_full_step<B>(&mut self, batch: &B) -> HybridResult<(f32, bool, Option<f32>)>
    where
        B: Batch,
        M: Model<B>,
        O: Optimizer<M, B>,
    {
        // Lock order: model first, then optimizer (matches usage below).
        let mut model = self.model.lock();
        let mut optimizer = self.optimizer.lock();

        // Zero gradients
        optimizer.zero_grad();

        // Forward pass
        let loss = model.forward(batch)?;

        // Backward pass
        let grad_info = model.backward()?;

        // Optimizer step
        optimizer.step(&mut *model, &grad_info)?;

        // Update training state with gradient info
        self.state.gradient_norm = grad_info.gradient_norm;
        self.state
            .gradient_norm_history
            .push(grad_info.gradient_norm);

        // Train the dynamics model during full steps (not warmup)
        if self.phase_controller.is_warmup_complete() {
            // Capture state features before further mutation so the stored
            // residual is tied to the state the prediction was made from.
            let state_features_before = self.state.compute_features();
            let confidence = self.dynamics_model.prediction_confidence(&self.state);

            // Get 1-step prediction to compute gradient residuals
            let (prediction, _) = self.dynamics_model.predict_y_steps(&self.state, 1);
            let predicted_loss = prediction.predicted_final_loss;

            // Compute loss residual (actual - predicted)
            let loss_residual = loss - predicted_loss;

            // Create gradient residuals from per-param norms if available.
            // Layer names are synthesized by index since per_param_norms is a
            // plain list without names.
            let gradient_residuals = if let Some(ref per_param) = grad_info.per_param_norms {
                per_param
                    .iter()
                    .enumerate()
                    .map(|(idx, &actual_norm)| residuals::LayerResidual {
                        layer_name: format!("layer_{}", idx),
                        magnitude: actual_norm,
                        compressed: None,       // TODO: Add compression support
                        cosine_similarity: 1.0, // Perfect match when no prediction available
                    })
                    .collect()
            } else {
                Vec::new()
            };

            // Create and store residual for weight-level corrections.
            // Horizon is 1 because the prediction above was a single step.
            let residual = residuals::Residual {
                step: self.state.step,
                phase: Phase::Full,
                prediction_horizon: 1,
                loss_residual,
                gradient_residuals,
                state_features: state_features_before,
                prediction_confidence: confidence,
            };

            // Store the residual for future correction
            self.residual_store.add(residual.clone());

            // Update the corrector's online model with this residual
            self.residual_corrector
                .update_from_residual(&residual, &self.state);

            // Train the dynamics model on the observed gradient
            self.dynamics_model
                .observe_gradient(&self.state, &grad_info);
        }

        Ok((loss, false, None))
    }
1014
    /// Executes a predictive step (forward only, apply predicted weight delta).
    ///
    /// Uses the phase controller's `compute_predict_steps()` to determine the
    /// optimal prediction horizon based on current confidence and history,
    /// then calls `predict_y_steps()` with that horizon for multi-step
    /// prediction. This enables the dynamics model to predict further ahead
    /// when confidence is high, yielding greater training speedup.
    ///
    /// Used during Predict phase - skips backward pass for speedup.
    ///
    /// # Returns
    ///
    /// `(actual_loss, was_predicted, prediction_error)` — `was_predicted` is
    /// always `true` and `prediction_error` is the absolute difference
    /// between predicted and actual loss.
    ///
    /// # Errors
    ///
    /// Propagates errors from applying weight deltas or from the validation
    /// forward pass.
    fn execute_predict_step<B>(&mut self, batch: &B) -> HybridResult<(f32, bool, Option<f32>)>
    where
        B: Batch,
        M: Model<B>,
    {
        let mut model = self.model.lock();

        // Capture state before prediction for residual extraction
        let state_features_before = self.state.compute_features();
        let confidence = self.dynamics_model.prediction_confidence(&self.state);

        // Compute adaptive prediction horizon based on confidence and history
        let y_steps = self.phase_controller.compute_predict_steps();

        // Get multi-step prediction from dynamics model
        let (prediction, _uncertainty) = self.dynamics_model.predict_y_steps(&self.state, y_steps);
        let predicted_loss = prediction.predicted_final_loss;

        // Apply predicted weight delta immediately (Burn limitation - can't defer)
        // TODO: Accumulation strategy doesn't work due to forward pass dependency
        model.apply_weight_delta(&prediction.weight_delta)?;

        // Forward pass to get actual loss (for validation)
        let actual_loss = model.forward(batch)?;

        // Clear forward state immediately (no backward in Predict phase)
        // This prevents memory accumulation from unused autodiff graphs
        model.clear_forward_state();

        // Compute prediction error (absolute difference)
        let prediction_error = (actual_loss - predicted_loss).abs();

        // Create and store residual (actual - predicted)
        let loss_residual = actual_loss - predicted_loss;
        let residual = residuals::Residual {
            step: self.state.step,
            phase: Phase::Predict,
            prediction_horizon: y_steps,
            loss_residual,
            gradient_residuals: Vec::new(), // No gradient info in predict phase
            state_features: state_features_before,
            prediction_confidence: confidence,
        };

        // Store the residual for future correction
        self.residual_store.add(residual.clone());

        // Update the corrector's online model with this residual
        self.residual_corrector
            .update_from_residual(&residual, &self.state);

        // Count this step toward the micro-correction cadence check below.
        self.state.steps_in_current_phase += 1;

        // Check if micro-correction is needed (intra-horizon correction).
        // correction_interval == 0 disables micro-corrections entirely.
        if self.config.correction_interval > 0
            && self.state.steps_in_current_phase % self.config.correction_interval == 0
        {
            // Apply micro-correction without transitioning to Correct phase
            let correction = if self.residual_store.is_empty() {
                corrector::Correction::zero()
            } else {
                self.residual_corrector.compute_correction(
                    &self.state,
                    &self.residual_store,
                    actual_loss,
                )
            };

            // Apply weight correction if available and significant
            if let Some(ref weight_correction) = correction.weight_correction {
                model.apply_weight_delta(weight_correction)?;
            } else if correction.is_significant(0.01) {
                // Apply simple correction if loss correction is significant but no weight delta
                if let Some(simple_delta) = self
                    .residual_corrector
                    .compute_simple_correction(&self.state)
                {
                    model.apply_weight_delta(&simple_delta)?;
                }
            }

            // Record that we applied a micro-correction
            self.metrics.record_micro_correction();
        }

        Ok((actual_loss, true, Some(prediction_error)))
    }
1111
1112    /// Executes a correction step (apply residual corrections).
1113    ///
1114    /// Used during Correct phase to adjust for accumulated prediction errors.
1115    /// Uses stored residuals from prediction phase to compute corrections.
1116    fn execute_correct_step<B>(&mut self, batch: &B) -> HybridResult<(f32, bool, Option<f32>)>
1117    where
1118        B: Batch,
1119        M: Model<B>,
1120    {
1121        let mut model = self.model.lock();
1122
1123        // Compute correction using stored residuals for context-aware adjustment
1124        let correction = if self.residual_store.is_empty() {
1125            // Fall back to simple correction if no residuals available
1126            corrector::Correction::zero()
1127        } else {
1128            // Use full correction with residual store for better estimates
1129            let predicted_loss = self.state.loss; // Use current loss as baseline
1130            self.residual_corrector.compute_correction(
1131                &self.state,
1132                &self.residual_store,
1133                predicted_loss,
1134            )
1135        };
1136
1137        // Apply weight delta correction immediately (Burn limitation - can't defer)
1138        if let Some(ref delta) = correction.weight_correction {
1139            model.apply_weight_delta(delta)?;
1140        } else if correction.is_significant(0.01) {
1141            // If no weight correction but loss correction is significant,
1142            // apply a simple scaled correction
1143            let simple_delta = self
1144                .residual_corrector
1145                .compute_simple_correction(&self.state);
1146            if let Some(delta) = simple_delta {
1147                model.apply_weight_delta(&delta)?;
1148            }
1149        }
1150
1151        // Forward pass to validate correction
1152        let loss = model.forward(batch)?;
1153
1154        // Clear forward state immediately (no backward in Correct phase)
1155        model.clear_forward_state();
1156
1157        // Compute how much the correction changed the loss (for metrics)
1158        let correction_effect = if correction.loss_correction.abs() > 0.001 {
1159            Some((self.state.loss - loss).abs())
1160        } else {
1161            None
1162        };
1163
1164        Ok((loss, false, correction_effect))
1165    }
1166
1167    /// Forces the trainer into full training mode for the specified number of steps.
1168    ///
1169    /// Useful for recovery from divergence or manual intervention.
1170    ///
1171    /// # Arguments
1172    ///
1173    /// * `steps` - Number of full steps to force
1174    pub fn force_full_phase(&mut self, steps: usize) {
1175        self.phase_controller.force_phase(Phase::Full);
1176        // Set the phase budget to enforce the requested number of steps
1177        self.phase_budget = Some((Phase::Full, steps));
1178    }
1179
1180    /// Saves a checkpoint of the trainer state.
1181    ///
1182    /// This saves all hybrid trainer state including training state, dynamics model,
1183    /// residual store, and phase controller state. It does NOT save model or optimizer
1184    /// state - those must be checkpointed separately via your deep learning framework.
1185    ///
1186    /// # Arguments
1187    ///
1188    /// * `path` - Path where the checkpoint should be saved
1189    ///
1190    /// # Returns
1191    ///
1192    /// `Ok(())` on success.
1193    ///
1194    /// # Errors
1195    ///
1196    /// Returns an error if checkpoint serialization or file I/O fails.
1197    ///
1198    /// # Example
1199    ///
1200    /// ```rust,ignore
1201    /// // Save hybrid trainer state
1202    /// trainer.save_checkpoint("checkpoints/hybrid_step_1000.bin")?;
1203    ///
1204    /// // User should also save model and optimizer separately
1205    /// model.save("checkpoints/model_step_1000.safetensors")?;
1206    /// optimizer.save("checkpoints/optimizer_step_1000.bin")?;
1207    /// ```
1208    pub fn save_checkpoint(&self, path: impl AsRef<std::path::Path>) -> HybridResult<()> {
1209        use crate::checkpoint::*;
1210
1211        // Extract serializable state from components
1212        let dynamics_state = DynamicsState::default(); // TODO: Extract from self.dynamics_model
1213        let residual_store_state = ResidualStoreState::default(); // TODO: Extract from self.residual_store
1214        let phase_controller_state = PhaseControllerState {
1215            current_phase: self.phase_controller.current_phase(),
1216            predictor_confidence: self.current_confidence(),
1217            warmup_complete: self.phase_controller.is_warmup_complete(),
1218            phase_stats: Vec::new(), // TODO: Extract stats
1219        };
1220        let divergence_monitor_state = DivergenceMonitorState::default(); // TODO: Extract from self.divergence_monitor
1221        let corrector_state = CorrectorState::default(); // TODO: Extract from self.residual_corrector
1222
1223        let checkpoint = TrainingCheckpoint::new(
1224            self.config.clone(),
1225            self.state.clone(),
1226            dynamics_state,
1227            residual_store_state,
1228            phase_controller_state,
1229            divergence_monitor_state,
1230            corrector_state,
1231        );
1232
1233        checkpoint.save(path)
1234    }
1235
1236    /// Loads a checkpoint and creates a new trainer with the given model and optimizer.
1237    ///
1238    /// This restores all hybrid trainer state from a checkpoint file. The model and
1239    /// optimizer must be provided separately since they're framework-specific.
1240    ///
1241    /// # Type Parameters
1242    ///
1243    /// * `M` - The model type
1244    /// * `O` - The optimizer type
1245    ///
1246    /// # Arguments
1247    ///
1248    /// * `path` - Path to the checkpoint file
1249    /// * `model` - The model to train (should be loaded from a separate checkpoint)
1250    /// * `optimizer` - The optimizer (should be loaded from a separate checkpoint)
1251    ///
1252    /// # Returns
1253    ///
1254    /// A new `HybridTrainer` instance with restored state.
1255    ///
1256    /// # Errors
1257    ///
1258    /// Returns an error if checkpoint loading or deserialization fails.
1259    ///
1260    /// # Example
1261    ///
1262    /// ```rust,ignore
1263    /// // Load model and optimizer from their checkpoints
1264    /// let model = MyModel::load("checkpoints/model_step_1000.safetensors")?;
1265    /// let optimizer = MyOptimizer::load("checkpoints/optimizer_step_1000.bin")?;
1266    ///
1267    /// // Load hybrid trainer state
1268    /// let trainer = HybridTrainer::load_checkpoint(
1269    ///     "checkpoints/hybrid_step_1000.bin",
1270    ///     model,
1271    ///     optimizer
1272    /// )?;
1273    /// ```
1274    pub fn load_checkpoint(
1275        path: impl AsRef<std::path::Path>,
1276        model: M,
1277        optimizer: O,
1278    ) -> HybridResult<Self> {
1279        use crate::checkpoint::*;
1280
1281        let checkpoint = TrainingCheckpoint::load(path)?;
1282
1283        // Reconstruct trainer components from checkpoint state
1284        let phase_controller = phases::DefaultPhaseController::new(&checkpoint.config);
1285        let dynamics_model = dynamics::RSSMLite::new(&checkpoint.config.predictor_config)?;
1286        let divergence_monitor = divergence::DivergenceMonitor::new(&checkpoint.config);
1287        let residual_corrector = corrector::ResidualCorrector::new(&checkpoint.config);
1288        let residual_store = residuals::ResidualStore::new(1000);
1289        let metrics = metrics::MetricsCollector::new(checkpoint.config.collect_metrics);
1290
1291        // TODO: Restore state into dynamics_model, residual_store, etc. from checkpoint
1292
1293        // Initialize auto-tuning controller if config provided
1294        let auto_tuning = if let Some(auto_config) = checkpoint.config.auto_tuning_config.clone() {
1295            let max_steps = checkpoint.config.max_steps.unwrap_or(10000);
1296            Some(auto_tuning::AutoTuningController::new(
1297                auto_config,
1298                max_steps,
1299            ))
1300        } else {
1301            None
1302        };
1303
1304        // Note: checkpoint_manager is NOT restored from checkpoint
1305        // It will be re-initialized if needed based on the config
1306        let checkpoint_manager = if checkpoint.config.checkpoint_config.save_interval > 0 {
1307            let checkpoint_dir = std::path::PathBuf::from("./checkpoints");
1308            Some(checkpoint::CheckpointManager::new(
1309                checkpoint_dir,
1310                checkpoint.config.checkpoint_config.save_interval,
1311                checkpoint.config.checkpoint_config.keep_last_n,
1312            )?)
1313        } else {
1314            None
1315        };
1316
1317        Ok(Self {
1318            model: Arc::new(parking_lot::Mutex::new(model)),
1319            optimizer: Arc::new(parking_lot::Mutex::new(optimizer)),
1320            config: checkpoint.config,
1321            state: checkpoint.training_state,
1322            phase_controller,
1323            dynamics_model,
1324            divergence_monitor,
1325            residual_corrector,
1326            residual_store,
1327            metrics,
1328            phase_budget: None,
1329            auto_tuning,
1330            last_auto_tuning_update: None,
1331            checkpoint_manager,
1332            delta_accumulator: delta_accumulator::DeltaAccumulator::new(),
1333            vram_manager: vram_manager::VramManager::new(),
1334        })
1335    }
1336}
1337
/// Result of a single training step.
///
/// Contains the loss value, phase information, and prediction metadata
/// for monitoring training progress and predictor accuracy.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// The loss value for this step.
    pub loss: f32,

    /// The phase during which this step was executed.
    pub phase: Phase,

    /// Whether this step used predicted gradients (true) or computed gradients (false).
    pub was_predicted: bool,

    /// The error between predicted and actual loss (if applicable).
    /// `Some` for predicted steps; correction steps may instead report the
    /// measured correction effect here.
    pub prediction_error: Option<f32>,

    /// The predictor's confidence for this step.
    pub confidence: f32,

    /// Gradient norm from the most recent backward pass.
    ///
    /// NOTE(review): predicted steps do not recompute gradients and do not
    /// reset this field, so it carries the last full step's norm rather than
    /// 0.0 — do not use it to detect predicted steps; use `was_predicted`.
    pub gradient_norm: f32,

    /// Wall-clock time for this step in milliseconds.
    pub step_time_ms: f64,

    /// Detailed metrics (if collection is enabled).
    pub metrics: Option<metrics::StepMetrics>,
}
1368
/// Prelude module for convenient imports.
///
/// Re-exports the core traits (`Model`, `Optimizer`, `Batch`), the trainer
/// and its configuration, and the common result, error, phase, and state
/// types, so callers can pull in the whole public surface with one glob.
///
/// # Example
///
/// ```
/// use hybrid_predict_trainer_rs::prelude::*;
/// ```
pub mod prelude {
    pub use crate::{
        Batch, GradientInfo, HybridResult, HybridTrainer, HybridTrainerConfig, HybridTrainingError,
        Model, Optimizer, Phase, PhaseDecision, RecoveryAction, StepResult, TrainingState,
    };
}
1382
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration ships with the documented phase lengths
    /// and confidence threshold.
    #[test]
    fn test_default_config() {
        let config = HybridTrainerConfig::default();
        assert_eq!(config.warmup_steps, 100);
        assert_eq!(config.full_steps, 20);
        // Default reduced to 15 for VRAM optimization.
        assert_eq!(config.max_predict_steps, 15);
        let threshold_delta = (config.confidence_threshold - 0.85).abs();
        assert!(threshold_delta < f32::EPSILON);
    }

    /// Micro-corrections are disabled by default and configurable through
    /// the builder.
    #[test]
    fn test_correction_interval_config() {
        // Disabled out of the box (interval of 0).
        assert_eq!(HybridTrainerConfig::default().correction_interval, 0);

        // Enabled via the builder pattern.
        let config = HybridTrainerConfig::builder()
            .correction_interval(10)
            .build();
        assert_eq!(config.correction_interval, 10);
    }

    /// `steps_in_current_phase` starts at zero and resets on every phase
    /// transition.
    #[test]
    fn test_steps_in_current_phase_counter() {
        let mut state = TrainingState::new();
        assert_eq!(state.steps_in_current_phase, 0);

        // Entering a phase leaves the counter at zero.
        state.enter_phase(Phase::Predict);
        assert_eq!(state.steps_in_current_phase, 0);

        // Steps taken within the phase accumulate.
        state.steps_in_current_phase = 5;
        assert_eq!(state.steps_in_current_phase, 5);

        // A fresh transition resets the counter.
        state.enter_phase(Phase::Correct);
        assert_eq!(state.steps_in_current_phase, 0);
    }
}