use crate::amp::{GradScaler, ScalerStats, StepResult};
use crate::optim::Optimizer;
use crate::tensor::Tensor;
/// Optimizer wrapper that routes updates through a [`GradScaler`] for automatic mixed
/// precision and keeps simple per-step success/overflow counters.
pub struct AMPOptimizer<O: Optimizer> {
    /// The wrapped optimizer that applies the actual parameter updates.
    optimizer: O,
    /// Loss scaler used for AMP steps (via `step_with_clipping`).
    scaler: GradScaler,
    /// Registered parameter groups. When empty, `step` hands all parameters to the scaler
    /// in a single call.
    param_groups: Vec<ParamGroup>,
    /// Number of `step` calls since construction or the last `reset`.
    step_count: usize,
    /// Number of steps whose result was `Overflow` or `InfNan`.
    overflow_count: usize,
    /// Number of steps whose result was `Success`.
    successful_steps: usize,
}
/// A subset of parameters, selected by index, together with how `AMPOptimizer` should step it.
#[derive(Debug, Clone)]
pub struct ParamGroup {
    /// Indices into the `params`/`grads` slices passed to `AMPOptimizer::step`.
    /// Out-of-range indices are skipped rather than causing a panic.
    pub param_ids: Vec<usize>,
    /// Requests gradient clipping to `max_grad_norm` for this group.
    pub clip_gradients: bool,
    /// Clipping threshold; no clipping threshold is available when this is `None`.
    pub max_grad_norm: Option<f32>,
    /// When `true`, the group is stepped through the loss scaler (`step_with_clipping`).
    /// When `false`, the wrapped optimizer is stepped directly, one parameter at a time.
    pub use_amp: bool,
}
impl<O: Optimizer> AMPOptimizer<O> {
    /// Wraps `optimizer` for automatic mixed-precision training.
    ///
    /// `scaler_config` supplies a pre-configured [`GradScaler`]; `None` falls back to
    /// `GradScaler::default()`. The new instance has no parameter groups and zeroed counters.
    pub fn new(optimizer: O, scaler_config: Option<GradScaler>) -> Self {
        let scaler = scaler_config.unwrap_or_else(GradScaler::default);
        Self {
            optimizer,
            scaler,
            param_groups: Vec::new(),
            step_count: 0,
            overflow_count: 0,
            successful_steps: 0,
        }
    }

    /// Appends a parameter group.
    ///
    /// Once any group is registered, [`step`](Self::step) only updates parameters whose
    /// indices appear in some group. Indices not listed in any group are left untouched.
    pub fn add_param_group(&mut self, group: ParamGroup) {
        self.param_groups.push(group);
    }

    /// Runs one optimization step and records its outcome in the training counters.
    ///
    /// With no parameter groups, every parameter/gradient pair is passed to
    /// `GradScaler::step_with_clipping` in a single call with no clipping threshold.
    /// With groups, each group is processed in registration order. The per-group results are
    /// merged so that a failure from any AMP group is reported instead of a success.
    ///
    /// Entries of `grads` may be rewritten in place: clipped or scaler-processed gradients are
    /// copied back into the slice.
    pub fn step(&mut self, params: &[Tensor<f32>], grads: &mut [Tensor<f32>]) -> StepResult {
        self.step_count += 1;
        if self.param_groups.is_empty() {
            let result = self
                .scaler
                .step_with_clipping(&mut self.optimizer, params, grads, None);
            self.update_stats(&result);
            result
        } else {
            self.step_with_groups(params, grads)
        }
    }

    /// Steps every registered group and merges their results into a single `StepResult`.
    ///
    /// If no AMP group produced a result (all groups are non-AMP, or all their indices are out
    /// of range), `Success { scale: 1.0, grad_norm: None }` is returned.
    fn step_with_groups(
        &mut self,
        params: &[Tensor<f32>],
        grads: &mut [Tensor<f32>],
    ) -> StepResult {
        let mut overall_result = StepResult::Success {
            scale: 1.0,
            grad_norm: None,
        };
        // Borrowing `param_groups`, `scaler` and `optimizer` as separate fields lets the
        // helpers mutate the latter two while iterating the former.
        for group in &self.param_groups {
            if group.use_amp {
                // NOTE(review): the scaler runs once per AMP group, so an overflow in a later
                // group does not undo updates already applied for earlier groups — confirm this
                // is the intended semantics.
                if let Some(result) = Self::step_amp_group(
                    &mut self.scaler,
                    &mut self.optimizer,
                    group,
                    params,
                    grads,
                ) {
                    overall_result = Self::merge_results(overall_result, result);
                }
            } else {
                Self::step_plain_group(&mut self.optimizer, group, params, grads);
            }
        }
        self.update_stats(&overall_result);
        overall_result
    }

    /// Runs the loss scaler over one AMP group and copies the resulting gradients back.
    ///
    /// Only indices that are valid for BOTH `params` and `grads` participate. This keeps the
    /// gathered parameters, gathered gradients and write-back targets aligned element by
    /// element. `max_grad_norm` is forwarded only when `clip_gradients` is set, matching the
    /// non-AMP path. Returns `None` when the group has no usable index.
    fn step_amp_group(
        scaler: &mut GradScaler,
        optimizer: &mut O,
        group: &ParamGroup,
        params: &[Tensor<f32>],
        grads: &mut [Tensor<f32>],
    ) -> Option<StepResult> {
        let ids: Vec<usize> = group
            .param_ids
            .iter()
            .copied()
            .filter(|&id| id < params.len() && id < grads.len())
            .collect();
        if ids.is_empty() {
            return None;
        }

        let group_params: Vec<Tensor<f32>> = ids.iter().map(|&id| params[id].clone()).collect();
        let mut group_grads: Vec<Tensor<f32>> =
            ids.iter().map(|&id| grads[id].clone()).collect();
        let max_norm = group.max_grad_norm.filter(|_| group.clip_gradients);

        let result =
            scaler.step_with_clipping(optimizer, &group_params, &mut group_grads, max_norm);

        // `ids` and `group_grads` were built from the same index list, so zipping them writes
        // each processed gradient back to the slot it came from.
        for (&id, grad) in ids.iter().zip(group_grads) {
            grads[id] = grad;
        }
        Some(result)
    }

    /// Steps one non-AMP group parameter by parameter, bypassing the loss scaler.
    ///
    /// When `clip_gradients` is set and `max_grad_norm` is `Some`, each gradient is clipped
    /// on its own, before that parameter's optimizer step.
    /// NOTE(review): this is per-tensor clipping, not a single norm across the whole group.
    fn step_plain_group(
        optimizer: &mut O,
        group: &ParamGroup,
        params: &[Tensor<f32>],
        grads: &mut [Tensor<f32>],
    ) {
        let max_norm = group.max_grad_norm.filter(|_| group.clip_gradients);
        for &param_id in &group.param_ids {
            if param_id >= params.len() || param_id >= grads.len() {
                continue;
            }
            if let Some(max_norm) = max_norm {
                // The clipping utility works on a list of tensors, so wrap the single gradient.
                let mut single_grad = vec![grads[param_id].clone()];
                crate::amp::dtype_utils::utils::clip_grad_norm(&mut single_grad, max_norm);
                if let Some(clipped) = single_grad.pop() {
                    grads[param_id] = clipped;
                }
            }
            optimizer.step(&params[param_id], &grads[param_id]);
        }
    }

    /// Folds one group's result into the running result.
    ///
    /// An `Overflow`/`InfNan` result always replaces the running value. A running `Success`
    /// is replaced by whatever comes next. So the most recent failure wins, and a later
    /// success never masks an earlier failure.
    fn merge_results(current: StepResult, next: StepResult) -> StepResult {
        let next_failed = matches!(
            next,
            StepResult::Overflow { .. } | StepResult::InfNan { .. }
        );
        if next_failed || matches!(current, StepResult::Success { .. }) {
            next
        } else {
            current
        }
    }

    /// Counts `Success` as a successful step and `Overflow`/`InfNan` as an overflow.
    fn update_stats(&mut self, result: &StepResult) {
        match result {
            StepResult::Success { .. } => {
                self.successful_steps += 1;
            }
            StepResult::Overflow { .. } | StepResult::InfNan { .. } => {
                self.overflow_count += 1;
            }
        }
    }

    /// Fraction of recorded steps that overflowed; 0.0 before the first step.
    fn overflow_rate(&self) -> f32 {
        if self.step_count == 0 {
            0.0
        } else {
            self.overflow_count as f32 / self.step_count as f32
        }
    }

    /// Returns a snapshot of the step counters together with `self.scaler.get_stats()`.
    pub fn get_training_stats(&self) -> TrainingStats {
        TrainingStats {
            total_steps: self.step_count,
            successful_steps: self.successful_steps,
            overflow_count: self.overflow_count,
            overflow_rate: self.overflow_rate(),
            scaler_stats: self.scaler.get_stats(),
        }
    }

    /// Shared access to the wrapped optimizer.
    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }

    /// Mutable access to the wrapped optimizer.
    pub fn optimizer_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }

    /// Shared access to the loss scaler.
    pub fn scaler(&self) -> &GradScaler {
        &self.scaler
    }

    /// Mutable access to the loss scaler.
    pub fn scaler_mut(&mut self) -> &mut GradScaler {
        &mut self.scaler
    }

    /// Currently a no-op.
    ///
    /// This wrapper stores no gradient buffers; gradients are supplied by the caller on every
    /// `step`. NOTE(review): callers expecting the wrapped optimizer's state to be cleared
    /// should confirm this is intended.
    pub fn zero_grad(&mut self) {}

    /// Feeds the current overflow rate to `GradScaler::adaptive_growth_interval`.
    pub fn update_schedule(&mut self) {
        let overflow_rate = self.overflow_rate();
        self.scaler.adaptive_growth_interval(overflow_rate);
    }

    /// Resets the scaler and all step counters.
    ///
    /// Registered parameter groups and the wrapped optimizer are kept as they are.
    pub fn reset(&mut self) {
        self.scaler.reset();
        self.step_count = 0;
        self.overflow_count = 0;
        self.successful_steps = 0;
    }
}
/// Snapshot of `AMPOptimizer` step counters plus the scaler's own statistics,
/// as produced by `AMPOptimizer::get_training_stats`.
#[derive(Debug, Clone)]
pub struct TrainingStats {
    /// Number of `step` calls recorded.
    pub total_steps: usize,
    /// Steps whose result was `StepResult::Success`.
    pub successful_steps: usize,
    /// Steps whose result was `StepResult::Overflow` or `StepResult::InfNan`.
    pub overflow_count: usize,
    /// `overflow_count / total_steps` when filled by `get_training_stats` (0.0 with no steps).
    pub overflow_rate: f32,
    /// Statistics reported by the `GradScaler` when the snapshot was taken.
    pub scaler_stats: ScalerStats,
}
impl TrainingStats {
    /// Overflow rate strictly below which training counts as stable.
    const STABLE_OVERFLOW_RATE: f32 = 0.05;

    /// Fraction of recorded steps that succeeded; 0.0 when no steps were recorded.
    pub fn success_rate(&self) -> f32 {
        match self.total_steps {
            0 => 0.0,
            total => self.successful_steps as f32 / total as f32,
        }
    }

    /// Returns `true` when the overflow rate is below 5%.
    pub fn is_stable(&self) -> bool {
        self.overflow_rate < Self::STABLE_OVERFLOW_RATE
    }

    /// Produces human-readable tuning hints from the overflow rate and the current loss scale.
    ///
    /// Hints appear in a fixed order. If none apply, the single entry
    /// "Training appears stable" is returned, so the result is never empty.
    pub fn get_recommendations(&self) -> Vec<String> {
        let rate = self.overflow_rate;
        // Each rule pairs its trigger condition with the hint it emits.
        let rules: [(bool, &str); 4] = [
            (rate > 0.1, "Consider reducing initial loss scale"),
            (rate > 0.1, "Consider increasing growth interval"),
            (rate > 0.2, "Consider using gradient clipping"),
            (
                rate < 0.01 && self.scaler_stats.current_scale < 1000.0,
                "Consider increasing initial loss scale",
            ),
        ];
        let mut hints: Vec<String> = rules
            .iter()
            .filter(|(applies, _)| *applies)
            .map(|(_, text)| String::from(*text))
            .collect();
        if hints.is_empty() {
            hints.push("Training appears stable".to_string());
        }
        hints
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::sgd::SGD;

    /// Builds a `TrainingStats` with the given counters and a fixed, representative scaler
    /// snapshot (scale 65536).
    fn sample_stats(
        total_steps: usize,
        successful_steps: usize,
        overflow_rate: f32,
    ) -> TrainingStats {
        TrainingStats {
            total_steps,
            successful_steps,
            overflow_count: total_steps - successful_steps,
            overflow_rate,
            scaler_stats: ScalerStats {
                current_scale: 65536.0,
                growth_factor: 2.0,
                backoff_factor: 0.5,
                growth_interval: 2000,
                growth_tracker: 500,
                consecutive_non_overflow: 10,
                enabled: true,
                has_overflow: false,
            },
        }
    }

    #[test]
    fn test_amp_optimizer_creation() {
        let sgd = SGD::new(0.01);
        let amp_optimizer = AMPOptimizer::new(sgd, None);
        assert_eq!(amp_optimizer.step_count, 0);
        assert_eq!(amp_optimizer.overflow_count, 0);
        assert_eq!(amp_optimizer.successful_steps, 0);
        assert!(amp_optimizer.param_groups.is_empty());

        let stats = amp_optimizer.get_training_stats();
        assert_eq!(stats.total_steps, 0);
        // No steps yet: the rate is defined as exactly zero rather than NaN.
        assert_eq!(stats.overflow_rate, 0.0);
    }

    #[test]
    fn test_param_group() {
        let group = ParamGroup {
            param_ids: vec![0, 1, 2],
            clip_gradients: true,
            max_grad_norm: Some(1.0),
            use_amp: true,
        };
        assert_eq!(group.param_ids.len(), 3);
        assert!(group.clip_gradients);
        assert_eq!(group.max_grad_norm, Some(1.0));
        assert!(group.use_amp);
    }

    #[test]
    fn test_training_stats() {
        let stats = sample_stats(100, 98, 0.02);
        // Compare floats with a tolerance instead of exact equality.
        assert!((stats.success_rate() - 0.98).abs() < 1e-6);
        assert!(stats.is_stable());
        assert_eq!(
            stats.get_recommendations(),
            vec!["Training appears stable".to_string()]
        );
    }

    #[test]
    fn test_training_stats_without_steps() {
        let stats = sample_stats(0, 0, 0.0);
        assert_eq!(stats.success_rate(), 0.0);
        assert!(stats.is_stable());
    }

    #[test]
    fn test_recommendations_for_high_overflow() {
        let stats = sample_stats(100, 75, 0.25);
        assert!(!stats.is_stable());
        assert_eq!(
            stats.get_recommendations(),
            vec![
                "Consider reducing initial loss scale".to_string(),
                "Consider increasing growth interval".to_string(),
                "Consider using gradient clipping".to_string(),
            ]
        );
    }
}