hybrid_predict_trainer_rs/lib.rs
1//! # hybrid-predict-trainer-rs
2//!
3//! Hybridized predictive training framework that accelerates deep learning through
4//! intelligent phase-based training with whole-phase prediction and residual correction.
5//!
6//! ## Overview
7//!
8//! This crate implements a novel training paradigm that achieves 5-10x speedup over
9//! traditional training by predicting training outcomes rather than computing every
10//! gradient step. The key insight is that training dynamics evolve on low-dimensional
11//! manifolds, making whole-phase prediction tractable.
12//!
13//! ## Training Phases
14//!
15//! The training loop cycles through four distinct phases:
16//!
17//! 1. **Warmup Phase** - Initial training steps to establish baseline dynamics
18//! 2. **Full Training Phase** - Traditional forward/backward pass computation
19//! 3. **Predictive Phase** - Skip backward passes using learned dynamics model
20//! 4. **Correction Phase** - Apply residual corrections to maintain accuracy
21//!
22//! ```text
23//! ┌─────────┐
24//! │ WARMUP │
25//! └────┬────┘
26//! │
27//! ▼
28//! ┌─────────────────────┐
29//! │ │
30//! ▼ │
31//! ┌──────────┐ │
32//! ┌───▶│ FULL │◀───────────────┤
33//! │ └────┬─────┘ │
34//! │ │ │
35//! │ ▼ │
36//! │ ┌──────────┐ │
37//! │ │ PREDICT │ │
38//! │ └────┬─────┘ │
39//! │ │ │
40//! │ ▼ │
41//! │ ┌──────────┐ │
42//! │ │ CORRECT │────────────────┘
43//! │ └────┬─────┘
44//! │ │
45//! └─────────┘
46//! ```
47//!
48//! ## Quick Start
49//!
50//! ```no_run
51//! use hybrid_predict_trainer_rs::{HybridTrainer, HybridTrainerConfig};
52//!
53//! // Create configuration with sensible defaults
54//! let config = HybridTrainerConfig::default();
55//!
56//! // Initialize the hybrid trainer
57//! // let trainer = HybridTrainer::new(model, optimizer, config)?;
58//!
59//! // Training loop
60//! // for batch in dataloader {
61//! // let result = trainer.step(&batch)?;
62//! // println!("Loss: {}, Phase: {:?}", result.loss, result.phase);
63//! // }
64//! ```
65//!
66//! ## Features
67//!
68//! - **GPU Acceleration** - `CubeCL` and Burn backends for high-performance compute
69//! - **Adaptive Phase Selection** - Bandit-based algorithm for optimal phase lengths
70//! - **Divergence Detection** - Multi-signal monitoring prevents training instability
71//! - **Residual Correction** - Online learning corrects prediction errors
72//! - **Checkpoint Support** - Save/restore full training state including predictor
73//!
74//! ## Feature Flags
75//!
76//! - `std` - Enable standard library support (default)
77//! - `cuda` - Enable CUDA GPU acceleration via `CubeCL`
78//! - `candle` - Enable Candle tensor operations for model compatibility
79//! - `async` - Enable async/await support with Tokio
80//! - `full` - Enable all features
81//!
82//! ## Architecture
83//!
84//! The crate is organized into the following modules:
85//!
86//! - [`config`] - Training configuration and serialization
87//! - [`error`] - Error types with recovery actions
88//! - [`phases`] - Phase state machine and execution control
89//! - [`state`] - Training state encoding and management
90//! - [`dynamics`] - RSSM-lite dynamics model for prediction
91//! - [`residuals`] - Residual extraction and storage
92//! - [`corrector`] - Prediction correction via residual application
93//! - [`divergence`] - Multi-signal divergence detection
94//! - [`metrics`] - Training metrics collection and reporting
95#![cfg_attr(
96 feature = "cuda",
97 doc = "- [`gpu`] - GPU acceleration kernels (requires `cuda` feature)"
98)]
99#![cfg_attr(
100 not(feature = "cuda"),
101 doc = "- `gpu` - GPU acceleration kernels (requires `cuda` feature)"
102)]
103//!
104//! ## References
105//!
106//! This implementation is based on research findings documented in
107//! `predictive-training-research.md`, synthesizing insights from:
108//!
109//! - Neural Tangent Kernel (NTK) theory for training dynamics
110//! - RSSM world models from `DreamerV3`
111//! - K-FAC for structured gradient approximation
112//! - `PowerSGD` for low-rank gradient compression
113
114#![cfg_attr(docsrs, feature(doc_cfg))]
115#![warn(missing_docs)]
116#![warn(rustdoc::missing_crate_level_docs)]
117#![deny(unsafe_code)]
118// Allow precision loss casts - acceptable in ML numerical code
119#![allow(clippy::cast_precision_loss)]
120#![allow(clippy::cast_possible_truncation)]
121#![allow(clippy::cast_sign_loss)]
122// Suppress documentation warnings during development
123#![allow(clippy::missing_errors_doc)]
124#![allow(clippy::missing_panics_doc)]
125#![allow(clippy::too_many_lines)]
126// Allow other common patterns
127#![allow(clippy::needless_range_loop)]
128#![allow(clippy::items_after_statements)]
129#![allow(clippy::unused_self)]
130#![allow(clippy::manual_clamp)]
131
132// Core modules
133pub mod config;
134pub mod error;
135pub mod phases;
136pub mod state;
137
138// Training phase implementations
139pub mod corrector;
140pub mod delta_accumulator;
141pub mod full_train;
142pub mod predictive;
143pub mod residuals;
144pub mod warmup;
145
146// Prediction and control
147pub mod bandit;
148pub mod divergence;
149pub mod dynamics;
150pub mod gru;
151
152// Metrics and monitoring
153pub mod metrics;
154
155// High-precision timing utilities
156pub mod timing;
157
158// Mixed precision training support
159pub mod mixed_precision;
160
161// Gradient accumulation for memory efficiency
162pub mod gradient_accumulation;
163
164// Predict-aware memory management (unique to HybridTrainer)
165pub mod predict_aware_memory;
166
167// Automatic tuning and optimization
168pub mod auto_tuning;
169
170// Burn framework integration
171pub mod burn_integration;
172
173// VRAM budget management
174pub mod vram_budget;
175
176// Checkpoint save/restore
177pub mod checkpoint;
178
179// VRAM manager for tracking and cleaning up GPU memory
180pub mod vram_manager;
181
182// Rust AI ecosystem integration (GpuDispatchable trait)
183#[cfg(feature = "ecosystem")]
184#[cfg_attr(docsrs, doc(cfg(feature = "ecosystem")))]
185pub mod ecosystem;
186
187// Reference model implementations for validation
188#[cfg(feature = "autodiff")]
189pub mod models;
190
191// GPU acceleration (feature-gated)
192#[cfg(feature = "cuda")]
193#[cfg_attr(docsrs, doc(cfg(feature = "cuda")))]
194pub mod gpu;
195
196// Re-exports for convenient access
197pub use auto_tuning::{
198 AutoTuningConfig, AutoTuningController, AutoTuningUpdate, BatchPrediction,
199 BatchPredictionRecommendation, HealthClassification, HealthRecommendation, HealthScorer,
200 HealthWeights, MultiStepPredictor, TrainingHealthScore,
201};
202pub use config::HybridTrainerConfig;
203pub use error::{HybridResult, HybridTrainingError, RecoveryAction};
204pub use phases::{Phase, PhaseController, PhaseDecision, PhaseOutcome};
205pub use residuals::{Residual, ResidualStore};
206pub use state::TrainingState;
207pub use timing::{Duration, Timer, TimingMetrics, TimingStats};
208
209// Standard library imports
210use std::sync::Arc;
211use std::time::Instant;
212
213// External crate imports
214// Note: Mutex used instead of RwLock for model/optimizer storage to support !Sync types
215
/// Batch of training data.
///
/// Generic container for a batch of input data that will be fed to the model
/// during training. The actual batch format depends on the model implementation.
///
/// Implementations must be `Send + Sync` so batches can be shared across
/// threads (e.g., by a data-loading pipeline).
pub trait Batch: Send + Sync {
    /// Returns the batch size (number of samples).
    fn batch_size(&self) -> usize;
}
224
/// Gradient information from a backward pass.
///
/// Contains the computed gradients and loss for a training step. Produced by
/// [`Model::backward`] and consumed by [`Optimizer::step`].
#[derive(Debug, Clone)]
pub struct GradientInfo {
    /// The computed loss value.
    pub loss: f32,
    /// L2 norm of all gradients.
    pub gradient_norm: f32,
    /// Per-parameter gradient norms (optional, for debugging).
    ///
    /// `None` when the model does not track per-parameter norms.
    pub per_param_norms: Option<Vec<f32>>,
}
237
/// Trait for models that can be trained with the hybrid trainer.
///
/// Models must implement forward pass, backward pass, and parameter access.
/// The trainer will call these methods during different training phases.
///
/// # Why This Trait?
///
/// The hybrid trainer is framework-agnostic. By requiring only forward/backward
/// and weight delta application, it works with any deep learning framework
/// (Burn, Candle, tch-rs, etc.) that can implement these operations.
///
/// # Type Parameters
///
/// - `B`: The batch type containing input data
///
/// # Example
///
/// ```rust,ignore
/// impl Model<MyBatch> for MyModel {
///     fn forward(&mut self, batch: &MyBatch) -> HybridResult<f32> {
///         // Compute forward pass and return loss
///     }
///
///     fn backward(&mut self) -> HybridResult<GradientInfo> {
///         // Compute gradients (assumes forward was just called)
///     }
///
///     fn parameter_count(&self) -> usize {
///         self.parameters.iter().map(|p| p.numel()).sum()
///     }
/// }
/// ```
///
/// # Thread Safety
///
/// Models must be `Send` to allow moving between threads, but `Sync` is not
/// required. This enables integration with autodiff frameworks (like Burn)
/// that use gradient types which are `!Sync` by design.
///
/// For multi-threaded access to models, use `Arc<Mutex<>>` rather than
/// `Arc<RwLock<>>` since the model itself may not be `Sync`.
pub trait Model<B: Batch>: Send {
    /// Executes the forward pass and returns the loss.
    ///
    /// # Arguments
    ///
    /// * `batch` - The input batch data
    ///
    /// # Returns
    ///
    /// The loss value for this batch.
    fn forward(&mut self, batch: &B) -> HybridResult<f32>;

    /// Executes the backward pass (gradient computation).
    ///
    /// Should be called after `forward()`. Computes gradients with respect
    /// to the loss returned by the most recent forward pass.
    ///
    /// # Returns
    ///
    /// Gradient information including loss and gradient norms.
    fn backward(&mut self) -> HybridResult<GradientInfo>;

    /// Clears forward pass state when backward() won't be called.
    ///
    /// This method should be called during Predict phase after forward()
    /// when backward() will be skipped. It allows implementations to free
    /// resources associated with the forward pass (e.g., autodiff graphs,
    /// cached activations) to prevent memory accumulation.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// // During Predict phase
    /// let loss = model.forward(batch)?;
    /// model.clear_forward_state(); // Won't call backward()
    /// ```
    ///
    /// Default implementation does nothing (for implementations that don't
    /// need cleanup).
    fn clear_forward_state(&mut self) {
        // Default: no-op — override only if forward-pass resources must be freed.
    }

    /// Returns the total number of trainable parameters.
    fn parameter_count(&self) -> usize;

    /// Applies a weight delta to the model parameters.
    ///
    /// Used during predictive phase to apply predicted weight updates
    /// (and at phase transitions, when accumulated deltas are flushed).
    ///
    /// # Arguments
    ///
    /// * `delta` - The weight changes to apply
    fn apply_weight_delta(&mut self, delta: &state::WeightDelta) -> HybridResult<()>;
}
334
/// Trait for optimizers that update model parameters.
///
/// Optimizers implement the parameter update rule (SGD, Adam, etc.).
///
/// # Why Separate from Model?
///
/// Optimizer state (momentum, variance estimates) is distinct from model
/// parameters. Separating them allows:
/// - **Swapping optimizers**: Try different optimizers without changing model code
/// - **Independent serialization**: Save/load optimizer state separately
/// - **Stateful updates**: Adam/AdaGrad need per-parameter state across steps
///
/// # Example
///
/// ```rust,ignore
/// impl<M: Model<B>, B: Batch> Optimizer<M, B> for AdamOptimizer {
///     fn step(&mut self, model: &mut M, gradients: &GradientInfo) -> HybridResult<()> {
///         // Apply Adam update rule to model parameters
///     }
/// }
/// ```
///
/// # Thread Safety
///
/// Optimizers must be `Send` to allow moving between threads, but `Sync` is not
/// required. This matches the `Model` trait's threading constraints.
pub trait Optimizer<M, B: Batch>: Send
where
    M: Model<B>,
{
    /// Performs a single optimization step.
    ///
    /// Updates model parameters using the computed gradients.
    ///
    /// # Arguments
    ///
    /// * `model` - The model to update
    /// * `gradients` - Gradient information from backward pass
    fn step(&mut self, model: &mut M, gradients: &GradientInfo) -> HybridResult<()>;

    /// Returns the current learning rate.
    fn learning_rate(&self) -> f32;

    /// Sets the learning rate (for warmup/decay schedules).
    fn set_learning_rate(&mut self, lr: f32);

    /// Zeros all accumulated gradients.
    fn zero_grad(&mut self);
}
384
/// The main hybrid trainer that orchestrates phase-based predictive training.
///
/// # Overview
///
/// `HybridTrainer` wraps a model and optimizer, managing the training loop through
/// warmup, full training, predictive, and correction phases. It automatically
/// selects optimal phase lengths using bandit-based algorithms and monitors for
/// divergence to ensure training stability.
///
/// # Type Parameters
///
/// - `M`: The model type (must implement `Model` trait)
/// - `O`: The optimizer type (must implement `Optimizer` trait)
///
/// # Example
///
/// ```no_run
/// use hybrid_predict_trainer_rs::{HybridTrainer, HybridTrainerConfig};
///
/// // Configure the trainer
/// let config = HybridTrainerConfig::builder()
///     .warmup_steps(100)
///     .full_steps(20)
///     .max_predict_steps(50)  // Can be higher with more VRAM
///     .confidence_threshold(0.85)
///     .build();
///
/// // Create trainer (model and optimizer types are inferred)
/// // let trainer = HybridTrainer::new(model, optimizer, config)?;
/// ```
pub struct HybridTrainer<M, O> {
    /// The model being trained.
    ///
    /// Uses `Mutex` instead of `RwLock` because models may not be `Sync`
    /// (e.g., when using autodiff frameworks with !Sync gradient types).
    model: Arc<parking_lot::Mutex<M>>,

    /// The optimizer for parameter updates.
    ///
    /// Uses `Mutex` instead of `RwLock` for consistency with model storage.
    optimizer: Arc<parking_lot::Mutex<O>>,

    /// Training configuration.
    config: HybridTrainerConfig,

    /// Current training state.
    state: TrainingState,

    /// Phase controller for state machine management.
    phase_controller: phases::DefaultPhaseController,

    /// Dynamics model for whole-phase prediction.
    dynamics_model: dynamics::RSSMLite,

    /// Divergence monitor for stability detection.
    divergence_monitor: divergence::DivergenceMonitor,

    /// Residual corrector for prediction adjustment.
    residual_corrector: corrector::ResidualCorrector,

    /// Storage for residuals extracted from prediction errors.
    residual_store: residuals::ResidualStore,

    /// Metrics collector for training statistics.
    metrics: metrics::MetricsCollector,

    /// Current phase and remaining steps (for respecting multi-step phase decisions).
    ///
    /// `None` when no multi-step budget is active and a fresh phase decision
    /// should be requested from the phase controller.
    phase_budget: Option<(Phase, usize)>,

    /// Automatic tuning controller (optional).
    auto_tuning: Option<AutoTuningController>,

    /// Last auto-tuning update (for external access).
    last_auto_tuning_update: Option<AutoTuningUpdate>,

    /// Checkpoint manager for automatic checkpointing (optional).
    checkpoint_manager: Option<checkpoint::CheckpointManager>,

    /// Delta accumulator for batched weight updates (VRAM optimization).
    delta_accumulator: delta_accumulator::DeltaAccumulator,

    /// VRAM manager for tracking and cleaning up GPU memory.
    vram_manager: vram_manager::VramManager,
}
469
470impl<M, O> HybridTrainer<M, O> {
471 /// Creates a new hybrid trainer with the given model, optimizer, and configuration.
472 ///
473 /// # Arguments
474 ///
475 /// * `model` - The model to train
476 /// * `optimizer` - The optimizer for parameter updates
477 /// * `config` - Training configuration
478 ///
479 /// # Returns
480 ///
481 /// A new `HybridTrainer` instance wrapped in a `HybridResult`.
482 ///
483 /// # Errors
484 ///
485 /// Returns an error if the configuration is invalid or initialization fails.
486 pub fn new(model: M, optimizer: O, config: HybridTrainerConfig) -> HybridResult<Self> {
487 let state = TrainingState::new();
488 let phase_controller = phases::DefaultPhaseController::new(&config);
489 let dynamics_model = dynamics::RSSMLite::new(&config.predictor_config)?;
490 let divergence_monitor = divergence::DivergenceMonitor::new(&config);
491 let residual_corrector = corrector::ResidualCorrector::new(&config);
492 let residual_store = residuals::ResidualStore::new(1000);
493 let metrics = metrics::MetricsCollector::new(config.collect_metrics);
494
495 // Initialize auto-tuning controller if config provided
496 let auto_tuning = if let Some(auto_config) = config.auto_tuning_config.clone() {
497 let max_steps = config.max_steps.unwrap_or(10000); // Default if not provided
498 Some(auto_tuning::AutoTuningController::new(
499 auto_config,
500 max_steps,
501 ))
502 } else {
503 None
504 };
505
506 // Initialize checkpoint manager if save_interval > 0
507 let checkpoint_manager = if config.checkpoint_config.save_interval > 0 {
508 // Use "./checkpoints" as default directory if not specified
509 let checkpoint_dir = std::path::PathBuf::from("./checkpoints");
510 Some(checkpoint::CheckpointManager::new(
511 checkpoint_dir,
512 config.checkpoint_config.save_interval,
513 config.checkpoint_config.keep_last_n,
514 )?)
515 } else {
516 None
517 };
518
519 Ok(Self {
520 model: Arc::new(parking_lot::Mutex::new(model)),
521 optimizer: Arc::new(parking_lot::Mutex::new(optimizer)),
522 config,
523 state,
524 phase_controller,
525 dynamics_model,
526 divergence_monitor,
527 residual_corrector,
528 residual_store,
529 metrics,
530 phase_budget: None,
531 auto_tuning,
532 last_auto_tuning_update: None,
533 checkpoint_manager,
534 delta_accumulator: delta_accumulator::DeltaAccumulator::new(),
535 vram_manager: vram_manager::VramManager::new(),
536 })
537 }
538
539 /// Returns the current training step.
540 ///
541 /// # Returns
542 ///
543 /// The current step number (0-indexed).
544 #[must_use]
545 pub fn current_step(&self) -> u64 {
546 self.state.step
547 }
548
549 /// Returns the current training phase.
550 ///
551 /// # Returns
552 ///
553 /// The current [`Phase`] of training.
554 #[must_use]
555 pub fn current_phase(&self) -> Phase {
556 self.phase_controller.current_phase()
557 }
558
559 /// Returns the current predictor confidence level.
560 ///
561 /// # Returns
562 ///
563 /// A confidence score between 0.0 and 1.0 indicating how reliable
564 /// the predictor's outputs are estimated to be.
565 #[must_use]
566 pub fn current_confidence(&self) -> f32 {
567 self.dynamics_model.prediction_confidence(&self.state)
568 }
569
570 /// Returns training statistics and metrics.
571 ///
572 /// # Returns
573 ///
574 /// A [`metrics::TrainingStatistics`] struct containing aggregate metrics.
575 #[must_use]
576 pub fn statistics(&mut self) -> metrics::TrainingStatistics {
577 self.metrics.statistics()
578 }
579
580 /// Returns the last auto-tuning update, if available.
581 ///
582 /// # Returns
583 ///
584 /// The most recent [`auto_tuning::AutoTuningUpdate`] if auto-tuning is enabled,
585 /// or `None` if auto-tuning is disabled or no updates have occurred yet.
586 #[must_use]
587 pub fn last_auto_tuning_update(&self) -> Option<&auto_tuning::AutoTuningUpdate> {
588 self.last_auto_tuning_update.as_ref()
589 }
590
591 /// Returns a read lock on the model.
592 ///
593 /// Use this to access model state for checkpointing or inspection.
594 pub fn model(&self) -> parking_lot::MutexGuard<'_, M> {
595 self.model.lock()
596 }
597
598 /// Returns a write lock on the model.
599 ///
600 /// Use this for operations that need to modify the model directly.
601 pub fn model_mut(&self) -> parking_lot::MutexGuard<'_, M> {
602 self.model.lock()
603 }
604
605 /// Sets the learning rate on the underlying optimizer.
606 ///
607 /// This is used for learning rate scheduling (warmup, decay, etc.).
608 ///
609 /// # Arguments
610 ///
611 /// * `lr` - The new learning rate
612 ///
613 /// # Example
614 ///
615 /// ```rust,ignore
616 /// // Update learning rate based on schedule
617 /// let lr = scheduler.get_lr(trainer.current_step());
618 /// trainer.set_learning_rate(lr);
619 /// ```
620 pub fn set_learning_rate<B>(&self, lr: f32)
621 where
622 B: Batch,
623 M: Model<B>,
624 O: Optimizer<M, B>,
625 {
626 self.optimizer.lock().set_learning_rate(lr);
627 }
628
629 /// Returns the current learning rate from the optimizer.
630 ///
631 /// # Returns
632 ///
633 /// The current learning rate value.
634 pub fn learning_rate<B>(&self) -> f32
635 where
636 B: Batch,
637 M: Model<B>,
638 O: Optimizer<M, B>,
639 {
640 self.optimizer.lock().learning_rate()
641 }
642
643 // TODO: Enable when auto_tuning fields are added to struct
644 // /// Returns the last auto-tuning update, if auto-tuning is enabled.
645 // ///
646 // /// # Returns
647 // ///
648 // /// The most recent auto-tuning update, or None if auto-tuning is disabled
649 // /// or no updates have occurred yet.
650 // #[must_use]
651 // pub fn last_auto_tuning_update(&self) -> Option<&auto_tuning::AutoTuningUpdate> {
652 // self.last_auto_tuning_update.as_ref()
653 // }
654
655 /// Returns the last recorded gradient norm.
656 ///
657 /// # Returns
658 ///
659 /// The gradient norm from the most recent backward pass.
660 #[must_use]
661 fn last_gradient_norm(&self) -> f32 {
662 self.state.gradient_norm
663 }
664
665 /// Collects per-layer gradient statistics.
666 ///
667 /// This is a stub implementation that distributes the global gradient
668 /// norm across dummy layers. A real implementation would track per-layer
669 /// gradients during the backward pass.
670 ///
671 /// # Returns
672 ///
673 /// `HashMap` of `layer_name` -> (`grad_norm`, `weight_norm`).
674 fn collect_layer_gradients(&self) -> std::collections::HashMap<String, (f32, f32)> {
675 let global_norm = self.last_gradient_norm();
676
677 // Stub: distribute gradient across typical transformer layers
678 // In a real implementation, this would come from actual per-layer tracking
679 let mut map = std::collections::HashMap::new();
680 map.insert("embed".to_string(), (global_norm * 0.8, 10.0));
681 map.insert("attention".to_string(), (global_norm * 1.0, 15.0));
682 map.insert("mlp".to_string(), (global_norm * 1.2, 20.0));
683 map.insert("lm_head".to_string(), (global_norm * 0.9, 8.0));
684 map
685 }
686}
687
688impl<M, O> HybridTrainer<M, O> {
689 /// Executes a single training step.
690 ///
691 /// This is the main entry point for the training loop. The trainer
692 /// automatically selects the appropriate phase (warmup, full, predict,
693 /// or correct) based on the current state and configuration.
694 ///
695 /// # Type Parameters
696 ///
697 /// * `B` - The batch type containing input data
698 ///
699 /// # Arguments
700 ///
701 /// * `batch` - The training batch to process
702 ///
703 /// # Returns
704 ///
705 /// A [`StepResult`] containing the loss, phase info, and prediction metadata.
706 ///
707 /// # Errors
708 ///
709 /// Returns an error if training diverges or encounters numerical issues.
710 /// The error includes a suggested recovery action when possible.
711 ///
712 /// # Example
713 ///
714 /// ```rust,ignore
715 /// for batch in dataloader {
716 /// let result = trainer.step(&batch)?;
717 /// println!("Step {}: loss={:.4}, phase={:?}",
718 /// trainer.current_step(),
719 /// result.loss,
720 /// result.phase
721 /// );
722 /// }
723 /// ```
724 pub fn step<B>(&mut self, batch: &B) -> HybridResult<StepResult>
725 where
726 B: Batch,
727 M: Model<B>,
728 O: Optimizer<M, B>,
729 {
730 let start_time = Instant::now();
731
732 // Update predictor confidence for phase controller
733 let confidence = self.dynamics_model.prediction_confidence(&self.state);
734 self.phase_controller.set_predictor_confidence(confidence);
735
736 // Track previous phase to detect transitions
737 let previous_phase = self.state.current_phase;
738
739 // Get phase from budget or request new decision
740 let phase = match &mut self.phase_budget {
741 Some((current_phase, remaining)) if *remaining > 0 => {
742 // Use current phase and decrement budget
743 *remaining -= 1;
744 *current_phase
745 }
746 _ => {
747 // Budget exhausted or None - get new phase decision
748 let decision = self.phase_controller.select_next_phase(&self.state);
749 let new_phase = decision.phase();
750 let steps = decision.steps();
751
752 // Set budget for remaining steps (steps-1 since we're using one now)
753 if steps > 1 {
754 self.phase_budget = Some((new_phase, steps - 1));
755 } else {
756 self.phase_budget = None;
757 }
758
759 new_phase
760 }
761 };
762
763 // Reset steps_in_current_phase counter on phase transitions
764 if phase != previous_phase {
765 self.state.steps_in_current_phase = 0;
766 self.state.current_phase = phase;
767
768 // Log VRAM usage at phase transitions
769 let vram_mb = self.vram_manager.last_vram_mb();
770 println!(
771 "Phase transition: {:?} → {:?} | VRAM: {} MB | Copies: {}",
772 previous_phase,
773 phase,
774 vram_mb,
775 crate::vram_manager::VramManager::total_copies()
776 );
777
778 // Flush accumulated weight deltas at phase transitions (VRAM optimization)
779 // This applies all accumulated deltas from the previous phase in one batch,
780 // minimizing model copies from Burn's .map() API
781 if let Some(merged_delta) = self.delta_accumulator.flush() {
782 let mut model = self.model.lock();
783 model.apply_weight_delta(&merged_delta)?;
784 }
785 }
786
787 // Execute the appropriate phase
788 let (loss, was_predicted, prediction_error) = match phase {
789 Phase::Warmup | Phase::Full => self.execute_full_step(batch)?,
790 Phase::Predict => self.execute_predict_step(batch)?,
791 Phase::Correct => self.execute_correct_step(batch)?,
792 };
793
794 // Check for divergence (skip during warmup - NaN values are expected initially)
795 if phase != Phase::Warmup {
796 let divergence_result = self.divergence_monitor.check(&self.state, prediction_error);
797 if divergence_result.level > error::DivergenceLevel::Caution {
798 let recovery = self
799 .phase_controller
800 .handle_divergence(divergence_result.level);
801 if !recovery.can_continue() {
802 return Err((
803 HybridTrainingError::PredictionDivergence {
804 actual: loss,
805 predicted: self.state.loss,
806 delta: (loss - self.state.loss).abs(),
807 step: self.state.step,
808 },
809 Some(recovery),
810 ));
811 }
812 // Reset phase budget to force Full training after divergence
813 self.phase_budget = Some((Phase::Full, self.config.full_steps));
814 }
815 }
816
817 // Auto-tuning integration: Update controller and apply recommendations
818 if self.auto_tuning.is_some() {
819 // Collect gradients before taking mutable borrow of auto_tuning
820 let layer_grads: Vec<(String, f32, f32)> = self
821 .collect_layer_gradients()
822 .into_iter()
823 .map(|(name, (grad_norm, weight_norm))| (name, grad_norm, weight_norm))
824 .collect();
825
826 #[allow(clippy::unnecessary_unwrap)]
827 let update = self.auto_tuning.as_mut().unwrap().update(
828 self.state.step,
829 loss,
830 self.state.gradient_norm,
831 &layer_grads,
832 confidence,
833 );
834
835 // Store update for external access (TODO: wire to optimizer when available)
836 self.last_auto_tuning_update = Some(update);
837 }
838
839 // Update state
840 self.state.step += 1;
841 self.state.loss = loss;
842 self.state.loss_history.push(loss);
843 self.state.loss_ema.update(loss); // CRITICAL FIX: Update EMA for stability confidence
844
845 // Update divergence monitor
846 self.divergence_monitor
847 .observe(loss, self.state.gradient_norm);
848
849 // Record metrics
850 let step_metrics = if self.config.collect_metrics {
851 Some(self.metrics.record_step_data(
852 self.state.step,
853 loss,
854 phase,
855 was_predicted,
856 prediction_error,
857 confidence,
858 ))
859 } else {
860 None
861 };
862
863 // Auto-checkpoint if enabled and interval reached OR VRAM critical
864 // Compute checkpoint state before taking mutable borrow
865 const VRAM_CHECKPOINT_THRESHOLD_MB: usize = 14_000; // 14 GB triggers emergency checkpoint
866 let vram_critical = self.vram_manager.last_vram_mb() > VRAM_CHECKPOINT_THRESHOLD_MB;
867
868 let checkpoint_to_save = if self.checkpoint_manager.as_ref().map_or(false, |mgr| {
869 mgr.should_save(self.state.step) || vram_critical
870 }) {
871 if vram_critical {
872 eprintln!(
873 "🚨 Emergency checkpoint triggered by high VRAM ({} MB > {} MB)",
874 self.vram_manager.last_vram_mb(),
875 VRAM_CHECKPOINT_THRESHOLD_MB
876 );
877 }
878 use crate::checkpoint::*;
879
880 Some(TrainingCheckpoint::new(
881 self.config.clone(),
882 self.state.clone(),
883 DynamicsState::default(),
884 ResidualStoreState::default(),
885 PhaseControllerState {
886 current_phase: self.phase_controller.current_phase(),
887 predictor_confidence: self.current_confidence(),
888 warmup_complete: self.phase_controller.is_warmup_complete(),
889 phase_stats: Vec::new(),
890 },
891 DivergenceMonitorState::default(),
892 CorrectorState::default(),
893 ))
894 } else {
895 None
896 };
897
898 // Now we can take mutable borrow to save
899 if let Some(checkpoint) = checkpoint_to_save {
900 if let Some(ref mut checkpoint_mgr) = self.checkpoint_manager {
901 // Save checkpoint (errors are logged but don't stop training)
902 if let Err((err, _)) = checkpoint_mgr.save(&checkpoint) {
903 eprintln!(
904 "Warning: Failed to save checkpoint at step {}: {}",
905 self.state.step, err
906 );
907 }
908 }
909 }
910
911 // Check if VRAM cleanup is needed (workaround for Burn's model.map() leak)
912 if self.vram_manager.should_cleanup() {
913 self.vram_manager.force_cleanup();
914 }
915
916 let elapsed_ms = start_time.elapsed().as_secs_f64() * 1000.0;
917
918 Ok(StepResult {
919 loss,
920 phase,
921 was_predicted,
922 prediction_error,
923 confidence,
924 gradient_norm: self.state.gradient_norm,
925 step_time_ms: elapsed_ms,
926 metrics: step_metrics,
927 })
928 }
929
930 /// Executes a full training step (forward + backward + optimizer step).
931 ///
932 /// Used during Warmup and Full phases.
933 fn execute_full_step<B>(&mut self, batch: &B) -> HybridResult<(f32, bool, Option<f32>)>
934 where
935 B: Batch,
936 M: Model<B>,
937 O: Optimizer<M, B>,
938 {
939 let mut model = self.model.lock();
940 let mut optimizer = self.optimizer.lock();
941
942 // Zero gradients
943 optimizer.zero_grad();
944
945 // Forward pass
946 let loss = model.forward(batch)?;
947
948 // Backward pass
949 let grad_info = model.backward()?;
950
951 // Optimizer step
952 optimizer.step(&mut *model, &grad_info)?;
953
954 // Update training state with gradient info
955 self.state.gradient_norm = grad_info.gradient_norm;
956 self.state
957 .gradient_norm_history
958 .push(grad_info.gradient_norm);
959
960 // Train the dynamics model during full steps (not warmup)
961 if self.phase_controller.is_warmup_complete() {
962 // Capture state features before model update
963 let state_features_before = self.state.compute_features();
964 let confidence = self.dynamics_model.prediction_confidence(&self.state);
965
966 // Get 1-step prediction to compute gradient residuals
967 let (prediction, _) = self.dynamics_model.predict_y_steps(&self.state, 1);
968 let predicted_loss = prediction.predicted_final_loss;
969
970 // Compute loss residual (actual - predicted)
971 let loss_residual = loss - predicted_loss;
972
973 // Create gradient residuals from per-param norms if available
974 let gradient_residuals = if let Some(ref per_param) = grad_info.per_param_norms {
975 per_param
976 .iter()
977 .enumerate()
978 .map(|(idx, &actual_norm)| residuals::LayerResidual {
979 layer_name: format!("layer_{}", idx),
980 magnitude: actual_norm,
981 compressed: None, // TODO: Add compression support
982 cosine_similarity: 1.0, // Perfect match when no prediction available
983 })
984 .collect()
985 } else {
986 Vec::new()
987 };
988
989 // Create and store residual for weight-level corrections
990 let residual = residuals::Residual {
991 step: self.state.step,
992 phase: Phase::Full,
993 prediction_horizon: 1,
994 loss_residual,
995 gradient_residuals,
996 state_features: state_features_before,
997 prediction_confidence: confidence,
998 };
999
1000 // Store the residual for future correction
1001 self.residual_store.add(residual.clone());
1002
1003 // Update the corrector's online model with this residual
1004 self.residual_corrector
1005 .update_from_residual(&residual, &self.state);
1006
1007 // Train the dynamics model
1008 self.dynamics_model
1009 .observe_gradient(&self.state, &grad_info);
1010 }
1011
1012 Ok((loss, false, None))
1013 }
1014
1015 /// Executes a predictive step (forward only, apply predicted weight delta).
1016 ///
1017 /// Uses the phase controller's `compute_predict_steps()` to determine the
1018 /// optimal prediction horizon based on current confidence and history,
1019 /// then calls `predict_y_steps()` with that horizon for multi-step
1020 /// prediction. This enables the dynamics model to predict further ahead
1021 /// when confidence is high, yielding greater training speedup.
1022 ///
1023 /// Used during Predict phase - skips backward pass for speedup.
1024 fn execute_predict_step<B>(&mut self, batch: &B) -> HybridResult<(f32, bool, Option<f32>)>
1025 where
1026 B: Batch,
1027 M: Model<B>,
1028 {
1029 let mut model = self.model.lock();
1030
1031 // Capture state before prediction for residual extraction
1032 let state_features_before = self.state.compute_features();
1033 let confidence = self.dynamics_model.prediction_confidence(&self.state);
1034
1035 // Compute adaptive prediction horizon based on confidence and history
1036 let y_steps = self.phase_controller.compute_predict_steps();
1037
1038 // Get multi-step prediction from dynamics model
1039 let (prediction, _uncertainty) = self.dynamics_model.predict_y_steps(&self.state, y_steps);
1040 let predicted_loss = prediction.predicted_final_loss;
1041
1042 // Apply predicted weight delta immediately (Burn limitation - can't defer)
1043 // TODO: Accumulation strategy doesn't work due to forward pass dependency
1044 model.apply_weight_delta(&prediction.weight_delta)?;
1045
1046 // Forward pass to get actual loss (for validation)
1047 let actual_loss = model.forward(batch)?;
1048
1049 // Clear forward state immediately (no backward in Predict phase)
1050 // This prevents memory accumulation from unused autodiff graphs
1051 model.clear_forward_state();
1052
1053 // Compute prediction error (absolute difference)
1054 let prediction_error = (actual_loss - predicted_loss).abs();
1055
1056 // Create and store residual (actual - predicted)
1057 let loss_residual = actual_loss - predicted_loss;
1058 let residual = residuals::Residual {
1059 step: self.state.step,
1060 phase: Phase::Predict,
1061 prediction_horizon: y_steps,
1062 loss_residual,
1063 gradient_residuals: Vec::new(), // No gradient info in predict phase
1064 state_features: state_features_before,
1065 prediction_confidence: confidence,
1066 };
1067
1068 // Store the residual for future correction
1069 self.residual_store.add(residual.clone());
1070
1071 // Update the corrector's online model with this residual
1072 self.residual_corrector
1073 .update_from_residual(&residual, &self.state);
1074
1075 // Check if micro-correction is needed (intra-horizon correction)
1076 self.state.steps_in_current_phase += 1;
1077
1078 if self.config.correction_interval > 0
1079 && self.state.steps_in_current_phase % self.config.correction_interval == 0
1080 {
1081 // Apply micro-correction without transitioning to Correct phase
1082 let correction = if self.residual_store.is_empty() {
1083 corrector::Correction::zero()
1084 } else {
1085 self.residual_corrector.compute_correction(
1086 &self.state,
1087 &self.residual_store,
1088 actual_loss,
1089 )
1090 };
1091
1092 // Apply weight correction if available and significant
1093 if let Some(ref weight_correction) = correction.weight_correction {
1094 model.apply_weight_delta(weight_correction)?;
1095 } else if correction.is_significant(0.01) {
1096 // Apply simple correction if loss correction is significant but no weight delta
1097 if let Some(simple_delta) = self
1098 .residual_corrector
1099 .compute_simple_correction(&self.state)
1100 {
1101 model.apply_weight_delta(&simple_delta)?;
1102 }
1103 }
1104
1105 // Record that we applied a micro-correction
1106 self.metrics.record_micro_correction();
1107 }
1108
1109 Ok((actual_loss, true, Some(prediction_error)))
1110 }
1111
1112 /// Executes a correction step (apply residual corrections).
1113 ///
1114 /// Used during Correct phase to adjust for accumulated prediction errors.
1115 /// Uses stored residuals from prediction phase to compute corrections.
1116 fn execute_correct_step<B>(&mut self, batch: &B) -> HybridResult<(f32, bool, Option<f32>)>
1117 where
1118 B: Batch,
1119 M: Model<B>,
1120 {
1121 let mut model = self.model.lock();
1122
1123 // Compute correction using stored residuals for context-aware adjustment
1124 let correction = if self.residual_store.is_empty() {
1125 // Fall back to simple correction if no residuals available
1126 corrector::Correction::zero()
1127 } else {
1128 // Use full correction with residual store for better estimates
1129 let predicted_loss = self.state.loss; // Use current loss as baseline
1130 self.residual_corrector.compute_correction(
1131 &self.state,
1132 &self.residual_store,
1133 predicted_loss,
1134 )
1135 };
1136
1137 // Apply weight delta correction immediately (Burn limitation - can't defer)
1138 if let Some(ref delta) = correction.weight_correction {
1139 model.apply_weight_delta(delta)?;
1140 } else if correction.is_significant(0.01) {
1141 // If no weight correction but loss correction is significant,
1142 // apply a simple scaled correction
1143 let simple_delta = self
1144 .residual_corrector
1145 .compute_simple_correction(&self.state);
1146 if let Some(delta) = simple_delta {
1147 model.apply_weight_delta(&delta)?;
1148 }
1149 }
1150
1151 // Forward pass to validate correction
1152 let loss = model.forward(batch)?;
1153
1154 // Clear forward state immediately (no backward in Correct phase)
1155 model.clear_forward_state();
1156
1157 // Compute how much the correction changed the loss (for metrics)
1158 let correction_effect = if correction.loss_correction.abs() > 0.001 {
1159 Some((self.state.loss - loss).abs())
1160 } else {
1161 None
1162 };
1163
1164 Ok((loss, false, correction_effect))
1165 }
1166
1167 /// Forces the trainer into full training mode for the specified number of steps.
1168 ///
1169 /// Useful for recovery from divergence or manual intervention.
1170 ///
1171 /// # Arguments
1172 ///
1173 /// * `steps` - Number of full steps to force
1174 pub fn force_full_phase(&mut self, steps: usize) {
1175 self.phase_controller.force_phase(Phase::Full);
1176 // Set the phase budget to enforce the requested number of steps
1177 self.phase_budget = Some((Phase::Full, steps));
1178 }
1179
1180 /// Saves a checkpoint of the trainer state.
1181 ///
1182 /// This saves all hybrid trainer state including training state, dynamics model,
1183 /// residual store, and phase controller state. It does NOT save model or optimizer
1184 /// state - those must be checkpointed separately via your deep learning framework.
1185 ///
1186 /// # Arguments
1187 ///
1188 /// * `path` - Path where the checkpoint should be saved
1189 ///
1190 /// # Returns
1191 ///
1192 /// `Ok(())` on success.
1193 ///
1194 /// # Errors
1195 ///
1196 /// Returns an error if checkpoint serialization or file I/O fails.
1197 ///
1198 /// # Example
1199 ///
1200 /// ```rust,ignore
1201 /// // Save hybrid trainer state
1202 /// trainer.save_checkpoint("checkpoints/hybrid_step_1000.bin")?;
1203 ///
1204 /// // User should also save model and optimizer separately
1205 /// model.save("checkpoints/model_step_1000.safetensors")?;
1206 /// optimizer.save("checkpoints/optimizer_step_1000.bin")?;
1207 /// ```
1208 pub fn save_checkpoint(&self, path: impl AsRef<std::path::Path>) -> HybridResult<()> {
1209 use crate::checkpoint::*;
1210
1211 // Extract serializable state from components
1212 let dynamics_state = DynamicsState::default(); // TODO: Extract from self.dynamics_model
1213 let residual_store_state = ResidualStoreState::default(); // TODO: Extract from self.residual_store
1214 let phase_controller_state = PhaseControllerState {
1215 current_phase: self.phase_controller.current_phase(),
1216 predictor_confidence: self.current_confidence(),
1217 warmup_complete: self.phase_controller.is_warmup_complete(),
1218 phase_stats: Vec::new(), // TODO: Extract stats
1219 };
1220 let divergence_monitor_state = DivergenceMonitorState::default(); // TODO: Extract from self.divergence_monitor
1221 let corrector_state = CorrectorState::default(); // TODO: Extract from self.residual_corrector
1222
1223 let checkpoint = TrainingCheckpoint::new(
1224 self.config.clone(),
1225 self.state.clone(),
1226 dynamics_state,
1227 residual_store_state,
1228 phase_controller_state,
1229 divergence_monitor_state,
1230 corrector_state,
1231 );
1232
1233 checkpoint.save(path)
1234 }
1235
1236 /// Loads a checkpoint and creates a new trainer with the given model and optimizer.
1237 ///
1238 /// This restores all hybrid trainer state from a checkpoint file. The model and
1239 /// optimizer must be provided separately since they're framework-specific.
1240 ///
1241 /// # Type Parameters
1242 ///
1243 /// * `M` - The model type
1244 /// * `O` - The optimizer type
1245 ///
1246 /// # Arguments
1247 ///
1248 /// * `path` - Path to the checkpoint file
1249 /// * `model` - The model to train (should be loaded from a separate checkpoint)
1250 /// * `optimizer` - The optimizer (should be loaded from a separate checkpoint)
1251 ///
1252 /// # Returns
1253 ///
1254 /// A new `HybridTrainer` instance with restored state.
1255 ///
1256 /// # Errors
1257 ///
1258 /// Returns an error if checkpoint loading or deserialization fails.
1259 ///
1260 /// # Example
1261 ///
1262 /// ```rust,ignore
1263 /// // Load model and optimizer from their checkpoints
1264 /// let model = MyModel::load("checkpoints/model_step_1000.safetensors")?;
1265 /// let optimizer = MyOptimizer::load("checkpoints/optimizer_step_1000.bin")?;
1266 ///
1267 /// // Load hybrid trainer state
1268 /// let trainer = HybridTrainer::load_checkpoint(
1269 /// "checkpoints/hybrid_step_1000.bin",
1270 /// model,
1271 /// optimizer
1272 /// )?;
1273 /// ```
1274 pub fn load_checkpoint(
1275 path: impl AsRef<std::path::Path>,
1276 model: M,
1277 optimizer: O,
1278 ) -> HybridResult<Self> {
1279 use crate::checkpoint::*;
1280
1281 let checkpoint = TrainingCheckpoint::load(path)?;
1282
1283 // Reconstruct trainer components from checkpoint state
1284 let phase_controller = phases::DefaultPhaseController::new(&checkpoint.config);
1285 let dynamics_model = dynamics::RSSMLite::new(&checkpoint.config.predictor_config)?;
1286 let divergence_monitor = divergence::DivergenceMonitor::new(&checkpoint.config);
1287 let residual_corrector = corrector::ResidualCorrector::new(&checkpoint.config);
1288 let residual_store = residuals::ResidualStore::new(1000);
1289 let metrics = metrics::MetricsCollector::new(checkpoint.config.collect_metrics);
1290
1291 // TODO: Restore state into dynamics_model, residual_store, etc. from checkpoint
1292
1293 // Initialize auto-tuning controller if config provided
1294 let auto_tuning = if let Some(auto_config) = checkpoint.config.auto_tuning_config.clone() {
1295 let max_steps = checkpoint.config.max_steps.unwrap_or(10000);
1296 Some(auto_tuning::AutoTuningController::new(
1297 auto_config,
1298 max_steps,
1299 ))
1300 } else {
1301 None
1302 };
1303
1304 // Note: checkpoint_manager is NOT restored from checkpoint
1305 // It will be re-initialized if needed based on the config
1306 let checkpoint_manager = if checkpoint.config.checkpoint_config.save_interval > 0 {
1307 let checkpoint_dir = std::path::PathBuf::from("./checkpoints");
1308 Some(checkpoint::CheckpointManager::new(
1309 checkpoint_dir,
1310 checkpoint.config.checkpoint_config.save_interval,
1311 checkpoint.config.checkpoint_config.keep_last_n,
1312 )?)
1313 } else {
1314 None
1315 };
1316
1317 Ok(Self {
1318 model: Arc::new(parking_lot::Mutex::new(model)),
1319 optimizer: Arc::new(parking_lot::Mutex::new(optimizer)),
1320 config: checkpoint.config,
1321 state: checkpoint.training_state,
1322 phase_controller,
1323 dynamics_model,
1324 divergence_monitor,
1325 residual_corrector,
1326 residual_store,
1327 metrics,
1328 phase_budget: None,
1329 auto_tuning,
1330 last_auto_tuning_update: None,
1331 checkpoint_manager,
1332 delta_accumulator: delta_accumulator::DeltaAccumulator::new(),
1333 vram_manager: vram_manager::VramManager::new(),
1334 })
1335 }
1336}
1337
/// Result of a single training step.
///
/// Contains the loss value, phase information, and prediction metadata
/// for monitoring training progress and predictor accuracy. One instance
/// is produced per call to the trainer's step function, regardless of
/// which phase executed the step.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// The loss value for this step. For predicted steps this is the loss
    /// measured by the validation forward pass after applying the predicted
    /// weight delta, not the predicted loss itself.
    pub loss: f32,

    /// The phase during which this step was executed.
    pub phase: Phase,

    /// Whether this step used predicted gradients (true) or computed gradients (false).
    pub was_predicted: bool,

    /// The error between predicted and actual loss (if applicable).
    /// NOTE(review): for Correct-phase steps the step executor appears to
    /// return the correction's observed loss effect in this slot instead —
    /// confirm against `execute_correct_step` before relying on the name.
    pub prediction_error: Option<f32>,

    /// The predictor's confidence for this step.
    pub confidence: f32,

    /// Gradient norm from the backward pass (0.0 if predicted step).
    pub gradient_norm: f32,

    /// Wall-clock time for this step in milliseconds.
    pub step_time_ms: f64,

    /// Detailed metrics (if collection is enabled).
    pub metrics: Option<metrics::StepMetrics>,
}
1368
/// Prelude module for convenient imports.
///
/// Re-exports the crate's core traits ([`Model`], [`Optimizer`], [`Batch`]),
/// the trainer and its configuration, and the common result/state types so a
/// single glob import brings the whole public surface into scope.
///
/// # Example
///
/// ```
/// use hybrid_predict_trainer_rs::prelude::*;
/// ```
pub mod prelude {
    pub use crate::{
        Batch, GradientInfo, HybridResult, HybridTrainer, HybridTrainerConfig, HybridTrainingError,
        Model, Optimizer, Phase, PhaseDecision, RecoveryAction, StepResult, TrainingState,
    };
}
1382
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration must expose the documented phase lengths and
    /// confidence threshold; guards against accidental default changes.
    #[test]
    fn test_default_config() {
        let config = HybridTrainerConfig::default();
        assert_eq!(config.warmup_steps, 100);
        assert_eq!(config.full_steps, 20);
        assert_eq!(config.max_predict_steps, 15); // Updated: default reduced for VRAM optimization
        assert!((config.confidence_threshold - 0.85).abs() < f32::EPSILON);
    }

    /// Micro-corrections are off by default (interval 0) and can be enabled
    /// via the builder.
    #[test]
    fn test_correction_interval_config() {
        // Test default (disabled)
        let config = HybridTrainerConfig::default();
        assert_eq!(config.correction_interval, 0);

        // Test builder pattern
        let config = HybridTrainerConfig::builder()
            .correction_interval(10)
            .build();
        assert_eq!(config.correction_interval, 10);
    }

    /// `steps_in_current_phase` starts at zero and is reset by every
    /// `enter_phase` transition, while remaining freely mutable in between.
    #[test]
    fn test_steps_in_current_phase_counter() {
        let mut state = TrainingState::new();
        assert_eq!(state.steps_in_current_phase, 0);

        // Simulate phase transition
        state.enter_phase(Phase::Predict);
        assert_eq!(state.steps_in_current_phase, 0);

        // Simulate steps within phase
        state.steps_in_current_phase = 5;
        assert_eq!(state.steps_in_current_phase, 5);

        // Phase transition should reset
        state.enter_phase(Phase::Correct);
        assert_eq!(state.steps_in_current_phase, 0);
    }
}