// yscv_model/mixed_precision.rs

use yscv_autograd::Graph;
use yscv_tensor::{DType, Tensor};

use crate::ModelError;
/// Configuration for mixed-precision training.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Dtype used for parameters and activations in the forward pass.
    pub forward_dtype: DType,
    /// Dtype of the master copy of the weights used for optimizer updates.
    pub master_dtype: DType,
    /// Initial loss-scale factor applied to the loss before backward.
    pub loss_scale: f32,
    /// Whether the scale is adjusted at runtime when overflows occur.
    pub dynamic_loss_scaling: bool,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            forward_dtype: DType::F16,
            master_dtype: DType::F32,
            loss_scale: 1024.0,
            dynamic_loss_scaling: true,
        }
    }
}

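/// Dynamic loss scaler: grows the loss scale while training is stable and
/// backs off when gradients overflow.
///
/// A minimal usage sketch (illustrative, not a doctest; assumes `loss` and
/// `raw_grads` come from the surrounding training loop):
///
/// ```ignore
/// let mut scaler = DynamicLossScaler::new(1024.0);
/// let scaled_loss = scaler.scale_loss(&loss)?;
/// // ... backward pass on the scaled loss ...
/// let grads = scaler.unscale_gradients(&raw_grads);
/// let overflow = DynamicLossScaler::check_overflow(&grads);
/// if scaler.update(overflow) {
///     // gradients are finite: apply them to the master weights
/// }
/// ```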
#[derive(Debug, Clone)]
pub struct DynamicLossScaler {
    /// The scale currently applied to the loss.
    current_scale: f32,
    /// Multiplier applied to the scale after a stable stretch of steps.
    growth_factor: f32,
    /// Multiplier applied to the scale after an overflow.
    backoff_factor: f32,
    /// Number of consecutive overflow-free steps before the scale grows.
    growth_interval: u32,
    /// Counter of overflow-free steps since the last overflow.
    steps_since_last_overflow: u32,
}

impl DynamicLossScaler {
    /// Creates a scaler with the given initial scale. The constants follow
    /// the usual dynamic-scaling recipe: halve on overflow, double after
    /// 2000 consecutive clean steps.
    pub fn new(initial_scale: f32) -> Self {
        Self {
            current_scale: initial_scale,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            steps_since_last_overflow: 0,
        }
    }

    /// Returns the current loss-scale factor.
    pub fn scale(&self) -> f32 {
        self.current_scale
    }

    /// Multiplies the loss by the current scale so that small gradients
    /// survive the reduced-precision backward pass.
    pub fn scale_loss(&self, loss: &Tensor) -> Result<Tensor, ModelError> {
        Ok(loss.scale(self.current_scale))
    }

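    /// Divides every gradient element by the current scale, materializing
    /// new tensors. Sketch (illustrative, not a doctest; `scaled_grad` is
    /// assumed to hold gradients of a loss that was scaled by 1024):
    ///
    /// ```ignore
    /// let scaler = DynamicLossScaler::new(1024.0);
    /// let grads = scaler.unscale_gradients(&[scaled_grad]);
    /// // grads[0] now holds the true (unscaled) gradient values
    /// ```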
    pub fn unscale_gradients(&self, gradients: &[Tensor]) -> Vec<Tensor> {
        let inv_scale = 1.0 / self.current_scale;
        gradients
            .iter()
            .map(|g| {
                // Multiply by 1/scale to recover the true gradient values.
                let unscaled: Vec<f32> = g.data().iter().map(|&v| v * inv_scale).collect();
                Tensor::from_vec(g.shape().to_vec(), unscaled).expect("shape matches data")
            })
            .collect()
    }

    /// Returns `true` if any gradient contains a non-finite value. This is
    /// an associated function, so it can be called without a scaler:
    /// `DynamicLossScaler::check_overflow(&grads)`.
    pub fn check_overflow(gradients: &[Tensor]) -> bool {
        gradients.iter().any(|g| !g.all_finite())
    }

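    /// Records the overflow status of a step and returns whether the step
    /// should be applied. A sketch of the scale trajectory, following the
    /// constants set in `new` (illustrative, not a doctest):
    ///
    /// ```ignore
    /// let mut s = DynamicLossScaler::new(1024.0);
    /// assert!(!s.update(true)); // overflow: skip the step, scale -> 512
    /// for _ in 0..2000 {
    ///     s.update(false); // 2000 stable steps double the scale -> 1024
    /// }
    /// assert_eq!(s.scale(), 1024.0);
    /// ```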
    pub fn update(&mut self, overflow: bool) -> bool {
        if overflow {
            // Back off and skip this step; never drop the scale below 1.0.
            self.current_scale *= self.backoff_factor;
            self.steps_since_last_overflow = 0;
            if self.current_scale < 1.0 {
                self.current_scale = 1.0;
            }
            false
        } else {
            // Grow the scale after a stretch of stable steps, capped at the
            // largest finite f16 value (65504).
            self.steps_since_last_overflow += 1;
            if self.steps_since_last_overflow >= self.growth_interval {
                self.current_scale *= self.growth_factor;
                self.steps_since_last_overflow = 0;
                if self.current_scale > 65504.0 {
                    self.current_scale = 65504.0;
                }
            }
            true
        }
    }
}

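/// Reads each parameter tensor out of the graph and casts it to the forward
/// dtype. Sketch (illustrative, not a doctest; assumes `param_nodes` were
/// registered on `graph` earlier):
///
/// ```ignore
/// let f16_params = cast_params_for_forward(&graph, &param_nodes, DType::F16)?;
/// ```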
pub fn cast_params_for_forward(
    graph: &Graph,
    param_nodes: &[yscv_autograd::NodeId],
    target_dtype: DType,
) -> Result<Vec<Tensor>, ModelError> {
    let mut casted = Vec::with_capacity(param_nodes.len());
    for &node in param_nodes {
        let tensor = graph.value(node)?;
        casted.push(tensor.to_dtype(target_dtype));
    }
    Ok(casted)
}

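/// Casts tensors back to the master dtype for optimizer updates. Sketch of
/// the round trip with `cast_params_for_forward` (illustrative, not a
/// doctest):
///
/// ```ignore
/// let f16_params = cast_params_for_forward(&graph, &param_nodes, DType::F16)?;
/// // ... forward/backward in f16 ...
/// let master = cast_to_master(&f16_params, DType::F32);
/// ```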
pub fn cast_to_master(tensors: &[Tensor], master_dtype: DType) -> Vec<Tensor> {
    tensors.iter().map(|t| t.to_dtype(master_dtype)).collect()
}

/// Runs a single training step and updates the dynamic loss scaler.
///
/// Returns the loss value and whether the step is safe to apply. Note that
/// this simplified step runs backward on the *unscaled* loss and checks
/// overflow only on the input gradient; a full implementation would scale
/// the loss by `scaler.scale()` before backward, unscale every parameter
/// gradient afterwards, and honor `_config` (currently unused here).
pub fn mixed_precision_train_step(
    graph: &mut Graph,
    model: &crate::SequentialModel,
    input: &Tensor,
    target: &Tensor,
    _config: &MixedPrecisionConfig,
    scaler: &mut DynamicLossScaler,
) -> Result<(f32, bool), ModelError> {
    let input_node = graph.variable(input.clone());
    let target_node = graph.constant(target.clone());
    let pred = model.forward(graph, input_node)?;
    let loss = crate::mse_loss(graph, pred, target_node)?;

    let loss_val = graph.value(loss)?.data()[0];

    graph.backward(loss)?;

    // Treat a non-finite input gradient as evidence of overflow somewhere
    // in the backward pass.
    let input_grad = graph.grad(input_node)?;
    let has_overflow = input_grad.is_some_and(|g| !g.all_finite());
    let should_apply = scaler.update(has_overflow);

    Ok((loss_val, should_apply))
}
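
// A minimal driver sketch (illustrative; assumes `graph`, `model`, `config`,
// and an iterator of `(input, target)` batches already exist):
//
//     let mut scaler = DynamicLossScaler::new(config.loss_scale);
//     for (input, target) in batches {
//         let (loss, apply) = mixed_precision_train_step(
//             &mut graph, &model, &input, &target, &config, &mut scaler,
//         )?;
//         if apply {
//             // optimizer step on the master weights goes here
//         }
//     }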