hybrid_predict_trainer_rs/auto_tuning/
controller.rs1use super::gradient_tuner::{GradientRangeTuner, PhaseGradientThresholds};
8use super::health_scorer::{HealthClassification, HealthRecommendation, HealthScorer};
9use super::plateau_detector::{PlateauDetector, PlateauStatus};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct AutoTuningConfig {
18 pub health_window: usize,
20
21 pub plateau_window: usize,
23
24 pub plateau_velocity_threshold: f32,
26
27 pub enable_warmup_restart: bool,
29
30 pub warmup_restart_lr_multiplier: f32,
32
33 pub warmup_restart_cooldown: u64,
35
36 pub max_warmup_restarts: usize,
38
39 pub enable_gradient_clipping: bool,
41
42 pub agc_lambda: f32,
44
45 pub phase_thresholds: HashMap<String, PhaseGradientThresholds>,
47}
48
49impl Default for AutoTuningConfig {
50 fn default() -> Self {
51 let mut phase_thresholds = HashMap::new();
52 phase_thresholds.insert(
53 "warmup".to_string(),
54 PhaseGradientThresholds::new(0.01, 0.5),
55 );
56 phase_thresholds.insert("early".to_string(), PhaseGradientThresholds::new(0.1, 1.0));
57 phase_thresholds.insert("mid".to_string(), PhaseGradientThresholds::new(0.3, 0.8));
58 phase_thresholds.insert("late".to_string(), PhaseGradientThresholds::new(0.1, 0.5));
59
60 Self {
61 health_window: 50,
62 plateau_window: 50,
63 plateau_velocity_threshold: 0.001,
64 enable_warmup_restart: true,
65 warmup_restart_lr_multiplier: 1.5,
66 warmup_restart_cooldown: 500,
67 max_warmup_restarts: 3,
68 enable_gradient_clipping: true,
69 agc_lambda: 0.01,
70 phase_thresholds,
71 }
72 }
73}
74
75#[derive(Debug, Clone)]
79pub struct AutoTuningUpdate {
80 pub health: HealthClassification,
82
83 pub recommendations: Vec<HealthRecommendation>,
85
86 pub plateau_status: PlateauStatus,
88
89 pub layer_clip_factors: HashMap<String, f32>,
93
94 pub warmup_restart: Option<f32>,
96
97 pub progress_pct: f32,
99
100 pub velocity: f32,
102
103 pub acceleration: f32,
105
106 pub gradient_entropy: f32,
108
109 pub prediction_accuracy: f32,
111
112 pub gradient_stability: f32,
114
115 pub health_score: f32,
117}
118
119impl AutoTuningUpdate {
120 #[must_use]
122 pub fn should_restart(&self) -> bool {
123 self.warmup_restart.is_some()
124 }
125
126 #[must_use]
128 pub fn is_critical(&self) -> bool {
129 self.health == HealthClassification::Critical
130 }
131
132 #[must_use]
134 pub fn has_clipping(&self) -> bool {
135 self.layer_clip_factors.values().any(|&coeff| coeff < 1.0)
136 }
137}
138
139pub struct AutoTuningController {
144 config: AutoTuningConfig,
145 health_scorer: HealthScorer,
146 gradient_tuner: GradientRangeTuner,
147 plateau_detector: PlateauDetector,
148 max_steps: u64,
149 warmup_restart_count: usize,
150 last_restart_step: Option<u64>,
151 current_step: u64,
152}
153
154impl AutoTuningController {
155 #[must_use]
162 pub fn new(config: AutoTuningConfig, max_steps: u64) -> Self {
163 let health_scorer = HealthScorer::new(config.health_window, max_steps as usize);
164 let gradient_tuner = GradientRangeTuner::with_agc_lambda(config.agc_lambda);
165 let plateau_detector =
166 PlateauDetector::new(config.plateau_window, config.plateau_velocity_threshold);
167
168 Self {
169 config,
170 health_scorer,
171 gradient_tuner,
172 plateau_detector,
173 max_steps,
174 warmup_restart_count: 0,
175 last_restart_step: None,
176 current_step: 0,
177 }
178 }
179
180 pub fn update(
194 &mut self,
195 step: u64,
196 loss: f32,
197 gradient_norm: f32,
198 layer_gradients: &[(String, f32, f32)],
199 confidence: f32,
200 ) -> AutoTuningUpdate {
201 self.current_step = step;
202 let progress_pct = (step as f32 / self.max_steps as f32) * 100.0;
203
204 let plateau_status = self.plateau_detector.update(loss, step, progress_pct);
206
207 let velocity = self.plateau_detector.velocity();
209 let acceleration = self.plateau_detector.acceleration();
210
211 let gradient_entropy = self.compute_gradient_entropy(gradient_norm);
213 let gradient_stability = self.compute_gradient_stability(gradient_norm);
214
215 let prediction_accuracy = confidence;
217
218 let health_result = self.health_scorer.compute(
220 velocity,
221 acceleration,
222 gradient_entropy,
223 prediction_accuracy,
224 gradient_stability,
225 progress_pct,
226 );
227
228 let _training_phase = self.determine_training_phase(progress_pct);
230
231 let mut layer_grads_map = HashMap::new();
233 for (layer_name, grad_norm, weight_norm) in layer_gradients {
234 layer_grads_map.insert(layer_name.clone(), (*grad_norm, *weight_norm));
235 }
236
237 let clip_factors =
238 self.gradient_tuner
239 .update(gradient_norm, &layer_grads_map, progress_pct);
240
241 let layer_clip_factors = clip_factors;
243
244 let warmup_restart = if self.config.enable_warmup_restart
246 && plateau_status == PlateauStatus::Stuck
247 && self.warmup_restart_count < self.config.max_warmup_restarts
248 {
249 let can_restart = match self.last_restart_step {
251 None => true,
252 Some(last_step) => step - last_step >= self.config.warmup_restart_cooldown,
253 };
254
255 if can_restart {
256 self.warmup_restart_count += 1;
257 self.last_restart_step = Some(step);
258 Some(self.config.warmup_restart_lr_multiplier)
259 } else {
260 None
261 }
262 } else {
263 None
264 };
265
266 AutoTuningUpdate {
267 health: health_result.classification,
268 recommendations: health_result.recommendations,
269 plateau_status,
270 layer_clip_factors,
271 warmup_restart,
272 progress_pct,
273 velocity,
274 acceleration,
275 gradient_entropy,
276 prediction_accuracy,
277 gradient_stability,
278 health_score: health_result.overall,
279 }
280 }
281
282 fn determine_training_phase(&self, progress_pct: f32) -> &str {
284 if progress_pct < 10.0 {
285 "warmup"
286 } else if progress_pct < 40.0 {
287 "early"
288 } else if progress_pct < 80.0 {
289 "mid"
290 } else {
291 "late"
292 }
293 }
294
295 fn compute_gradient_entropy(&self, gradient_norm: f32) -> f32 {
300 let log_norm = gradient_norm.max(0.01).ln();
302 let log_min = 0.01_f32.ln();
303 let log_max = 10.0_f32.ln();
304
305 ((log_norm - log_min) / (log_max - log_min)).clamp(0.0, 1.0)
306 }
307
308 fn compute_gradient_stability(&self, gradient_norm: f32) -> f32 {
313 let normalized = (gradient_norm.ln() - 0.01_f32.ln()) / (10.0_f32.ln() - 0.01_f32.ln());
317 (1.0 - normalized).clamp(0.0, 1.0)
318 }
319
320 #[must_use]
322 pub fn warmup_restart_count(&self) -> usize {
323 self.warmup_restart_count
324 }
325
326 #[must_use]
328 pub fn last_restart_step(&self) -> Option<u64> {
329 self.last_restart_step
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented window sizes and has
    /// both optional features switched on.
    #[test]
    fn test_auto_tuning_config_default() {
        let AutoTuningConfig {
            health_window,
            plateau_window,
            enable_warmup_restart,
            enable_gradient_clipping,
            ..
        } = AutoTuningConfig::default();

        assert_eq!(health_window, 50);
        assert_eq!(plateau_window, 50);
        assert!(enable_warmup_restart);
        assert!(enable_gradient_clipping);
    }

    /// A freshly built controller has no restart history.
    #[test]
    fn test_auto_tuning_controller_creation() {
        let controller = AutoTuningController::new(AutoTuningConfig::default(), 1000);

        assert_eq!(controller.warmup_restart_count(), 0);
        assert!(controller.last_restart_step().is_none());
    }

    /// The very first update reports an in-range progress percentage and does
    /// not request a restart.
    #[test]
    fn test_auto_tuning_update() {
        let mut controller = AutoTuningController::new(AutoTuningConfig::default(), 1000);

        let layer_grads: Vec<(String, f32, f32)> = vec![
            (String::from("embed"), 0.5, 10.0),
            (String::from("attention"), 0.8, 15.0),
        ];

        let update = controller.update(0, 1.5, 0.6, &layer_grads, 0.9);

        assert!((0.0..=100.0).contains(&update.progress_pct));
        assert!(!update.should_restart());
    }

    /// A hand-built healthy update answers `false` to every predicate helper.
    #[test]
    fn test_auto_tuning_update_has_recommendations() {
        let healthy = AutoTuningUpdate {
            health: HealthClassification::Good,
            recommendations: Vec::new(),
            plateau_status: PlateauStatus::Normal,
            layer_clip_factors: HashMap::new(),
            warmup_restart: None,
            progress_pct: 50.0,
            velocity: -0.01,
            acceleration: 0.001,
            gradient_entropy: 0.5,
            prediction_accuracy: 0.9,
            gradient_stability: 0.8,
            health_score: 0.75,
        };

        assert!(!healthy.should_restart());
        assert!(!healthy.is_critical());
        assert!(!healthy.has_clipping());
    }
}