yscv_model/mixed_precision.rs

use yscv_autograd::Graph;
use yscv_tensor::{DType, Tensor};

use crate::ModelError;

/// Mixed-precision training configuration.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Dtype used for forward pass computation.
    pub forward_dtype: DType,
    /// Dtype used for weight storage and gradient accumulation.
    pub master_dtype: DType,
    /// Loss scaling factor to prevent gradient underflow in half precision.
    pub loss_scale: f32,
    /// Enable dynamic loss scaling that adjusts scale based on overflow detection.
    pub dynamic_loss_scaling: bool,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            forward_dtype: DType::F16,
            master_dtype: DType::F32,
            loss_scale: 1024.0,
            dynamic_loss_scaling: true,
        }
    }
}
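
// A minimal sketch (not part of the original API surface) of overriding individual
// config fields while keeping the remaining defaults. The helper name
// `_example_config` is illustrative only.
#[allow(dead_code)]
fn _example_config() -> MixedPrecisionConfig {
    MixedPrecisionConfig {
        // Start with a smaller fixed scale and disable dynamic adjustment.
        loss_scale: 512.0,
        dynamic_loss_scaling: false,
        ..Default::default()
    }
}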

/// State for dynamic loss scaling during mixed-precision training.
#[derive(Debug, Clone)]
pub struct DynamicLossScaler {
    current_scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: u32,
    steps_since_last_overflow: u32,
}

impl DynamicLossScaler {
    pub fn new(initial_scale: f32) -> Self {
        Self {
            current_scale: initial_scale,
            // Double the scale after `growth_interval` consecutive overflow-free
            // steps, and halve it whenever an overflow is detected.
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            steps_since_last_overflow: 0,
        }
    }

    pub fn scale(&self) -> f32 {
        self.current_scale
    }

    /// Scale a loss tensor by the current loss scale factor.
    pub fn scale_loss(&self, loss: &Tensor) -> Result<Tensor, ModelError> {
        Ok(loss.scale(self.current_scale))
    }

    /// Unscale gradients by dividing by the current scale factor.
    ///
    /// Each gradient is materialized into a new tensor; the inputs are not modified.
    pub fn unscale_gradients(&self, gradients: &[Tensor]) -> Vec<Tensor> {
        let inv_scale = 1.0 / self.current_scale;
        gradients
            .iter()
            .map(|g| {
                let unscaled: Vec<f32> = g.data().iter().map(|&v| v * inv_scale).collect();
                Tensor::from_vec(g.shape().to_vec(), unscaled).expect("shape matches data")
            })
            .collect()
    }

    /// Check if any gradient contains inf/nan (overflow indicator).
    pub fn check_overflow(gradients: &[Tensor]) -> bool {
        gradients.iter().any(|g| !g.all_finite())
    }

    /// Update the scaler state after a training step.
    /// Returns `true` if the step should be applied (no overflow), `false` to skip.
    pub fn update(&mut self, overflow: bool) -> bool {
        if overflow {
            // Back off on overflow, but never let the scale drop below 1.0.
            self.current_scale *= self.backoff_factor;
            self.steps_since_last_overflow = 0;
            if self.current_scale < 1.0 {
                self.current_scale = 1.0;
            }
            false
        } else {
            self.steps_since_last_overflow += 1;
            if self.steps_since_last_overflow >= self.growth_interval {
                // Grow after a full interval of clean steps, capped at f16::MAX (65504).
                self.current_scale *= self.growth_factor;
                self.steps_since_last_overflow = 0;
                if self.current_scale > 65504.0 {
                    self.current_scale = 65504.0;
                }
            }
            true
        }
    }
}
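
// A minimal sketch of the scaler protocol using only the methods defined above.
// It assumes the caller has already run backprop and collected `grads`; how the
// gradients are obtained is outside this sketch, and the helper name
// `scaled_step_sketch` is illustrative only.
#[allow(dead_code)]
fn scaled_step_sketch(
    scaler: &mut DynamicLossScaler,
    loss: &Tensor,
    grads: &[Tensor],
) -> Result<Option<Vec<Tensor>>, ModelError> {
    // Scaling the loss before backprop is what keeps small F16 gradients from
    // flushing to zero; shown here only to illustrate the ordering of calls.
    let _scaled_loss = scaler.scale_loss(loss)?;
    // Detect inf/nan in the still-scaled gradients.
    let overflow = DynamicLossScaler::check_overflow(grads);
    // Grow or back off the scale; `false` means this optimizer step should be skipped.
    if !scaler.update(overflow) {
        return Ok(None);
    }
    // Divide the surviving gradients by the scale before the optimizer applies them.
    Ok(Some(scaler.unscale_gradients(grads)))
}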

/// Convert model parameters from master precision to forward precision.
pub fn cast_params_for_forward(
    graph: &Graph,
    param_nodes: &[yscv_autograd::NodeId],
    target_dtype: DType,
) -> Result<Vec<Tensor>, ModelError> {
    let mut casted = Vec::with_capacity(param_nodes.len());
    for &node in param_nodes {
        let tensor = graph.value(node)?;
        casted.push(tensor.to_dtype(target_dtype));
    }
    Ok(casted)
}

/// Cast a list of tensors back to master dtype for gradient accumulation.
pub fn cast_to_master(tensors: &[Tensor], master_dtype: DType) -> Vec<Tensor> {
    tensors.iter().map(|t| t.to_dtype(master_dtype)).collect()
}
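
// A sketch of the master-weights round trip built from the two helpers above. It
// assumes the caller already holds the parameter `NodeId`s; obtaining them from a
// model is not shown, and the helper name `cast_round_trip_sketch` is illustrative.
#[allow(dead_code)]
fn cast_round_trip_sketch(
    graph: &Graph,
    param_nodes: &[yscv_autograd::NodeId],
    config: &MixedPrecisionConfig,
) -> Result<Vec<Tensor>, ModelError> {
    // Low-precision copies of the master weights, used for the forward pass.
    let forward_params =
        cast_params_for_forward(graph, param_nodes, config.forward_dtype.clone())?;
    // Cast back to the master dtype before accumulating gradients into them.
    Ok(cast_to_master(&forward_params, config.master_dtype.clone()))
}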

/// Runs a mixed-precision forward+backward step.
///
/// 1. Cast the input to `forward_dtype`, then back to F32 for graph computation
/// 2. Run the forward pass and compute the loss
/// 3. Backprop through the graph
/// 4. Check the input gradient for overflow
/// 5. Update the scaler state
///
/// Returns `(loss_value, step_applied)`; a usage sketch follows the function below.
pub fn mixed_precision_train_step(
    graph: &mut Graph,
    model: &crate::SequentialModel,
    input: &Tensor,
    target: &Tensor,
    config: &MixedPrecisionConfig,
    scaler: &mut DynamicLossScaler,
) -> Result<(f32, bool), ModelError> {
    // Round-trip the input through the forward dtype to emulate reduced precision,
    // then compute on F32 values in the graph.
    let casted_input = input.to_dtype(config.forward_dtype.clone()).to_dtype(DType::F32);
    let input_node = graph.variable(casted_input);
    let target_node = graph.constant(target.clone());
    let pred = model.forward(graph, input_node)?;
    let loss = crate::mse_loss(graph, pred, target_node)?;

    let loss_val = graph.value(loss)?.data()[0];

    graph.backward(loss)?;

    // Non-finite values produced anywhere during backprop propagate into the
    // input gradient, so it serves as the overflow detector here.
    let input_grad = graph.grad(input_node)?;
    let has_overflow = input_grad.is_some_and(|g| !g.all_finite());
    let should_apply = scaler.update(has_overflow);

    Ok((loss_val, should_apply))
}
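
// A sketch of a training loop driving `mixed_precision_train_step`. Building the
// graph, model, and batches is outside this module, so they are taken as
// arguments; the helper name `train_loop_sketch` and the batch layout are
// illustrative only.
#[allow(dead_code)]
fn train_loop_sketch(
    graph: &mut Graph,
    model: &crate::SequentialModel,
    batches: &[(Tensor, Tensor)],
    config: &MixedPrecisionConfig,
) -> Result<(), ModelError> {
    let mut scaler = DynamicLossScaler::new(config.loss_scale);
    for (input, target) in batches {
        let (loss_val, applied) =
            mixed_precision_train_step(graph, model, input, target, config, &mut scaler)?;
        if !applied {
            // Overflow detected: the scaler backed off, so this step is skipped.
            continue;
        }
        // An optimizer update would go here; `loss_val` is the raw scalar loss.
        let _ = loss_val;
    }
    Ok(())
}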