// hybrid_predict_trainer_rs/auto_tuning/controller.rs

1//! Automatic tuning controller that orchestrates health scoring, gradient tuning,
2//! and plateau detection to provide unified training optimization recommendations.
3//!
4//! This module integrates the `auto_tuning` subsystems into a single controller
5//! that can be integrated into the `HybridTrainer`'s `step()` method.
6
7use super::gradient_tuner::{GradientRangeTuner, PhaseGradientThresholds};
8use super::health_scorer::{HealthClassification, HealthRecommendation, HealthScorer};
9use super::plateau_detector::{PlateauDetector, PlateauStatus};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Configuration for the automatic tuning controller.
///
/// Controls the behavior of health scoring, gradient tuning, and plateau
/// detection. Serializable so it can be persisted alongside a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoTuningConfig {
    /// Window size for health scoring (number of recent steps to consider).
    pub health_window: usize,

    /// Window size for plateau detection (number of recent steps to consider).
    pub plateau_window: usize,

    /// Minimum loss-velocity magnitude below which training is considered
    /// plateaued (absolute value).
    pub plateau_velocity_threshold: f32,

    /// Whether to enable automatic warmup restarts when training is stuck.
    pub enable_warmup_restart: bool,

    /// Learning rate multiplier applied on a warmup restart (e.g. 1.5).
    pub warmup_restart_lr_multiplier: f32,

    /// Minimum number of steps that must elapse between warmup restarts.
    pub warmup_restart_cooldown: u64,

    /// Maximum number of warmup restarts allowed over the whole run.
    pub max_warmup_restarts: usize,

    /// Enable adaptive (per-layer) gradient clipping.
    pub enable_gradient_clipping: bool,

    /// AGC lambda parameter for per-layer clipping.
    pub agc_lambda: f32,

    /// Phase-specific gradient thresholds, keyed by phase name
    /// ("warmup", "early", "mid", "late").
    pub phase_thresholds: HashMap<String, PhaseGradientThresholds>,
}
48
49impl Default for AutoTuningConfig {
50    fn default() -> Self {
51        let mut phase_thresholds = HashMap::new();
52        phase_thresholds.insert(
53            "warmup".to_string(),
54            PhaseGradientThresholds::new(0.01, 0.5),
55        );
56        phase_thresholds.insert("early".to_string(), PhaseGradientThresholds::new(0.1, 1.0));
57        phase_thresholds.insert("mid".to_string(), PhaseGradientThresholds::new(0.3, 0.8));
58        phase_thresholds.insert("late".to_string(), PhaseGradientThresholds::new(0.1, 0.5));
59
60        Self {
61            health_window: 50,
62            plateau_window: 50,
63            plateau_velocity_threshold: 0.001,
64            enable_warmup_restart: true,
65            warmup_restart_lr_multiplier: 1.5,
66            warmup_restart_cooldown: 500,
67            max_warmup_restarts: 3,
68            enable_gradient_clipping: true,
69            agc_lambda: 0.01,
70            phase_thresholds,
71        }
72    }
73}
74
/// Update result from the auto-tuning controller.
///
/// Produced once per call to `AutoTuningController::update`; contains the
/// recommended actions and state information for the training loop.
#[derive(Debug, Clone)]
pub struct AutoTuningUpdate {
    /// Overall training health classification.
    pub health: HealthClassification,

    /// Recommended actions to improve training.
    pub recommendations: Vec<HealthRecommendation>,

    /// Plateau detection status.
    pub plateau_status: PlateauStatus,

    /// Per-layer gradient clipping factors (if enabled).
    ///
    /// Maps layer name to clipping coefficient. A coefficient of 1.0 means no
    /// clipping; values below 1.0 mean the layer's gradient should be scaled down.
    pub layer_clip_factors: HashMap<String, f32>,

    /// Learning rate multiplier for warmup restart (`Some(value)` if a restart
    /// is recommended this step, `None` otherwise).
    pub warmup_restart: Option<f32>,

    /// Current training progress percentage (0-100).
    pub progress_pct: f32,

    /// Loss velocity (negative means loss is falling, which is good).
    pub velocity: f32,

    /// Loss acceleration (second derivative of the loss trend).
    pub acceleration: f32,

    /// Gradient entropy score [0, 1].
    pub gradient_entropy: f32,

    /// Prediction accuracy score [0, 1].
    pub prediction_accuracy: f32,

    /// Gradient stability score [0, 1].
    pub gradient_stability: f32,

    /// Overall health score [0, 1].
    pub health_score: f32,
}
118
119impl AutoTuningUpdate {
120    /// Returns true if a warmup restart is recommended.
121    #[must_use]
122    pub fn should_restart(&self) -> bool {
123        self.warmup_restart.is_some()
124    }
125
126    /// Returns true if health is critical and requires immediate action.
127    #[must_use]
128    pub fn is_critical(&self) -> bool {
129        self.health == HealthClassification::Critical
130    }
131
132    /// Returns true if any gradient clipping is recommended.
133    #[must_use]
134    pub fn has_clipping(&self) -> bool {
135        self.layer_clip_factors.values().any(|&coeff| coeff < 1.0)
136    }
137}
138
/// Automatic tuning controller that orchestrates health monitoring and tuning.
///
/// Integrates health scoring, gradient tuning, and plateau detection to provide
/// unified training optimization recommendations.
pub struct AutoTuningController {
    /// Behavior configuration (windows, thresholds, restart policy).
    config: AutoTuningConfig,
    /// Aggregates loss/gradient signals into an overall health result.
    health_scorer: HealthScorer,
    /// Produces per-layer gradient clipping factors (AGC).
    gradient_tuner: GradientRangeTuner,
    /// Detects loss plateaus and supplies loss velocity/acceleration.
    plateau_detector: PlateauDetector,
    /// Total planned training steps; used to derive the progress percentage.
    max_steps: u64,
    /// Number of warmup restarts issued so far (capped by the config).
    warmup_restart_count: usize,
    /// Step at which the most recent warmup restart fired, if any.
    last_restart_step: Option<u64>,
    /// Most recent step passed to `update()`.
    current_step: u64,
}
153
154impl AutoTuningController {
155    /// Creates a new auto-tuning controller.
156    ///
157    /// # Arguments
158    ///
159    /// * `config` - Configuration for auto-tuning behavior
160    /// * `max_steps` - Total number of training steps (for progress percentage)
161    #[must_use]
162    pub fn new(config: AutoTuningConfig, max_steps: u64) -> Self {
163        let health_scorer = HealthScorer::new(config.health_window, max_steps as usize);
164        let gradient_tuner = GradientRangeTuner::with_agc_lambda(config.agc_lambda);
165        let plateau_detector =
166            PlateauDetector::new(config.plateau_window, config.plateau_velocity_threshold);
167
168        Self {
169            config,
170            health_scorer,
171            gradient_tuner,
172            plateau_detector,
173            max_steps,
174            warmup_restart_count: 0,
175            last_restart_step: None,
176            current_step: 0,
177        }
178    }
179
180    /// Updates the controller with new training step data.
181    ///
182    /// # Arguments
183    ///
184    /// * `step` - Current training step
185    /// * `loss` - Current loss value
186    /// * `gradient_norm` - Global gradient norm
187    /// * `layer_gradients` - Per-layer gradient statistics (name, `grad_norm`, `weight_norm`)
188    /// * `confidence` - Predictor confidence [0, 1]
189    ///
190    /// # Returns
191    ///
192    /// An `AutoTuningUpdate` containing recommendations and state information.
193    pub fn update(
194        &mut self,
195        step: u64,
196        loss: f32,
197        gradient_norm: f32,
198        layer_gradients: &[(String, f32, f32)],
199        confidence: f32,
200    ) -> AutoTuningUpdate {
201        self.current_step = step;
202        let progress_pct = (step as f32 / self.max_steps as f32) * 100.0;
203
204        // Update plateau detector
205        let plateau_status = self.plateau_detector.update(loss, step, progress_pct);
206
207        // Compute loss dynamics from plateau detector
208        let velocity = self.plateau_detector.velocity();
209        let acceleration = self.plateau_detector.acceleration();
210
211        // Compute gradient entropy and stability (simplified for now)
212        let gradient_entropy = self.compute_gradient_entropy(gradient_norm);
213        let gradient_stability = self.compute_gradient_stability(gradient_norm);
214
215        // Compute prediction accuracy from confidence
216        let prediction_accuracy = confidence;
217
218        // Update health scorer
219        let health_result = self.health_scorer.compute(
220            velocity,
221            acceleration,
222            gradient_entropy,
223            prediction_accuracy,
224            gradient_stability,
225            progress_pct,
226        );
227
228        // Update gradient tuner with layer statistics
229        let _training_phase = self.determine_training_phase(progress_pct);
230
231        // Convert layer_gradients to HashMap format for gradient_tuner
232        let mut layer_grads_map = HashMap::new();
233        for (layer_name, grad_norm, weight_norm) in layer_gradients {
234            layer_grads_map.insert(layer_name.clone(), (*grad_norm, *weight_norm));
235        }
236
237        let clip_factors =
238            self.gradient_tuner
239                .update(gradient_norm, &layer_grads_map, progress_pct);
240
241        // Layer clip factors come from gradient tuner directly
242        let layer_clip_factors = clip_factors;
243
244        // Check if warmup restart is needed
245        let warmup_restart = if self.config.enable_warmup_restart
246            && plateau_status == PlateauStatus::Stuck
247            && self.warmup_restart_count < self.config.max_warmup_restarts
248        {
249            // Check cooldown
250            let can_restart = match self.last_restart_step {
251                None => true,
252                Some(last_step) => step - last_step >= self.config.warmup_restart_cooldown,
253            };
254
255            if can_restart {
256                self.warmup_restart_count += 1;
257                self.last_restart_step = Some(step);
258                Some(self.config.warmup_restart_lr_multiplier)
259            } else {
260                None
261            }
262        } else {
263            None
264        };
265
266        AutoTuningUpdate {
267            health: health_result.classification,
268            recommendations: health_result.recommendations,
269            plateau_status,
270            layer_clip_factors,
271            warmup_restart,
272            progress_pct,
273            velocity,
274            acceleration,
275            gradient_entropy,
276            prediction_accuracy,
277            gradient_stability,
278            health_score: health_result.overall,
279        }
280    }
281
282    /// Determines the current training phase based on progress percentage.
283    fn determine_training_phase(&self, progress_pct: f32) -> &str {
284        if progress_pct < 10.0 {
285            "warmup"
286        } else if progress_pct < 40.0 {
287            "early"
288        } else if progress_pct < 80.0 {
289            "mid"
290        } else {
291            "late"
292        }
293    }
294
295    /// Computes gradient entropy score from gradient norm.
296    ///
297    /// This is a simplified implementation. A more sophisticated version would
298    /// track gradient variance and compute actual entropy.
299    fn compute_gradient_entropy(&self, gradient_norm: f32) -> f32 {
300        // Normalize to [0, 1] assuming reasonable gradient range [0.01, 10.0]
301        let log_norm = gradient_norm.max(0.01).ln();
302        let log_min = 0.01_f32.ln();
303        let log_max = 10.0_f32.ln();
304
305        ((log_norm - log_min) / (log_max - log_min)).clamp(0.0, 1.0)
306    }
307
308    /// Computes gradient stability score from gradient norm.
309    ///
310    /// This is a simplified implementation. A more sophisticated version would
311    /// track variance of gradient norms over time.
312    fn compute_gradient_stability(&self, gradient_norm: f32) -> f32 {
313        // For now, use inverse of gradient magnitude as stability proxy
314        // High gradients = low stability, low gradients = high stability
315        // Map [0.01, 10.0] to [1.0, 0.0]
316        let normalized = (gradient_norm.ln() - 0.01_f32.ln()) / (10.0_f32.ln() - 0.01_f32.ln());
317        (1.0 - normalized).clamp(0.0, 1.0)
318    }
319
320    /// Returns the current warmup restart count.
321    #[must_use]
322    pub fn warmup_restart_count(&self) -> usize {
323        self.warmup_restart_count
324    }
325
326    /// Returns the step at which the last warmup restart occurred.
327    #[must_use]
328    pub fn last_restart_step(&self) -> Option<u64> {
329        self.last_restart_step
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auto_tuning_config_default() {
        let cfg = AutoTuningConfig::default();
        assert_eq!(cfg.health_window, 50);
        assert_eq!(cfg.plateau_window, 50);
        assert!(cfg.enable_warmup_restart);
        assert!(cfg.enable_gradient_clipping);
    }

    #[test]
    fn test_auto_tuning_controller_creation() {
        let controller = AutoTuningController::new(AutoTuningConfig::default(), 1000);
        assert_eq!(controller.warmup_restart_count(), 0);
        assert!(controller.last_restart_step().is_none());
    }

    #[test]
    fn test_auto_tuning_update() {
        let mut controller = AutoTuningController::new(AutoTuningConfig::default(), 1000);

        let layer_grads = [
            ("embed".to_string(), 0.5, 10.0),
            ("attention".to_string(), 0.8, 15.0),
        ];

        let update = controller.update(0, 1.5, 0.6, &layer_grads, 0.9);

        assert!((0.0..=100.0).contains(&update.progress_pct));
        // No restart expected on the very first update with a reasonable loss.
        assert!(!update.should_restart());
    }

    #[test]
    fn test_auto_tuning_update_has_recommendations() {
        let update = AutoTuningUpdate {
            health: HealthClassification::Good,
            recommendations: Vec::new(),
            plateau_status: PlateauStatus::Normal,
            layer_clip_factors: HashMap::new(),
            warmup_restart: None,
            progress_pct: 50.0,
            velocity: -0.01,
            acceleration: 0.001,
            gradient_entropy: 0.5,
            prediction_accuracy: 0.9,
            gradient_stability: 0.8,
            health_score: 0.75,
        };

        assert!(!update.should_restart());
        assert!(!update.is_critical());
        assert!(!update.has_clipping());
    }
}