1use ghostflow_core::Tensor;
12use std::collections::HashMap;
13
/// Numeric precision used for mixed-precision compute.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` (the enum is
/// fieldless, so full equivalence holds) so the mode can be used as a
/// map/set key and compared without clippy's
/// `derive_partial_eq_without_eq` warning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrecisionMode {
    /// Full 32-bit IEEE 754 single precision — no casting, no loss scaling.
    FP32,
    /// IEEE 754 half precision (binary16): 5-bit exponent, 10-bit mantissa.
    FP16,
    /// bfloat16: f32's 8-bit exponent with a 7-bit mantissa.
    BF16,
    /// 8-bit float (E4M3-style range; largest finite value 448).
    FP8,
}
26
/// Configuration for mixed-precision training with dynamic loss scaling.
///
/// The scaling knobs mirror the usual GradScaler scheme implemented in
/// `GradScaler::update_scale`: the loss is multiplied by the scale, the
/// scale grows by `growth_factor` after `growth_interval` overflow-free
/// steps, and shrinks by `backoff_factor` whenever an Inf/NaN gradient
/// is observed.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Numeric format used for forward/backward compute.
    pub mode: PrecisionMode,
    /// Initial loss-scale value.
    pub init_scale: f32,
    /// Multiplier applied to the scale after `growth_interval` clean steps.
    pub growth_factor: f32,
    /// Multiplier applied to the scale when an Inf/NaN gradient is found.
    pub backoff_factor: f32,
    /// Consecutive overflow-free steps required before the scale grows.
    pub growth_interval: usize,
    /// When false, the scale stays fixed at `init_scale`.
    pub dynamic_loss_scale: bool,
}
43
44impl Default for MixedPrecisionConfig {
45 fn default() -> Self {
46 MixedPrecisionConfig {
47 mode: PrecisionMode::FP16,
48 init_scale: 65536.0,
49 growth_factor: 2.0,
50 backoff_factor: 0.5,
51 growth_interval: 2000,
52 dynamic_loss_scale: true,
53 }
54 }
55}
56
57impl MixedPrecisionConfig {
58 pub fn fp16() -> Self {
60 Self {
61 mode: PrecisionMode::FP16,
62 ..Default::default()
63 }
64 }
65
66 pub fn bf16() -> Self {
68 Self {
69 mode: PrecisionMode::BF16,
70 init_scale: 1.0, dynamic_loss_scale: false,
72 ..Default::default()
73 }
74 }
75
76 pub fn fp8() -> Self {
78 Self {
79 mode: PrecisionMode::FP8,
80 init_scale: 1024.0,
81 ..Default::default()
82 }
83 }
84}
85
/// Dynamic loss scaler for reduced-precision training: scales the loss up
/// before backprop, unscales gradients before the optimizer step, and
/// adapts the scale based on observed Inf/NaN gradients.
pub struct GradScaler {
    /// Scaling policy (mode, growth/backoff factors, …).
    config: MixedPrecisionConfig,
    /// Current loss-scale value.
    scale: f32,
    /// Consecutive overflow-free updates since the last growth/reset.
    growth_tracker: usize,
    /// Total number of updates where an Inf/NaN gradient was detected.
    found_inf_count: usize,
}
93
94impl GradScaler {
95 pub fn new(config: MixedPrecisionConfig) -> Self {
97 GradScaler {
98 scale: config.init_scale,
99 config,
100 growth_tracker: 0,
101 found_inf_count: 0,
102 }
103 }
104
105 pub fn scale_loss(&self, loss: &Tensor) -> Tensor {
107 if self.config.mode == PrecisionMode::FP32 {
108 return loss.clone();
109 }
110
111 loss.mul_scalar(self.scale)
112 }
113
114 pub fn unscale_gradients(&self, gradients: &mut HashMap<String, Tensor>) -> bool {
116 if self.config.mode == PrecisionMode::FP32 {
117 return true;
118 }
119
120 let inv_scale = 1.0 / self.scale;
121 let mut found_inf = false;
122
123 for (_name, grad) in gradients.iter_mut() {
124 if self.has_inf_or_nan(grad) {
126 found_inf = true;
127 break;
128 }
129
130 *grad = grad.mul_scalar(inv_scale);
132 }
133
134 !found_inf
135 }
136
137 pub fn step<F>(&mut self, optimizer_step: F, gradients: &mut HashMap<String, Tensor>) -> bool
139 where
140 F: FnOnce(),
141 {
142 let success = self.unscale_gradients(gradients);
144
145 if success {
146 optimizer_step();
148
149 self.update_scale(false);
151 true
152 } else {
153 self.update_scale(true);
155 false
156 }
157 }
158
159 fn update_scale(&mut self, found_inf: bool) {
161 if !self.config.dynamic_loss_scale {
162 return;
163 }
164
165 if found_inf {
166 self.scale *= self.config.backoff_factor;
168 self.scale = self.scale.max(1.0);
169 self.growth_tracker = 0;
170 self.found_inf_count += 1;
171 } else {
172 self.growth_tracker += 1;
174 if self.growth_tracker >= self.config.growth_interval {
175 self.scale *= self.config.growth_factor;
176 self.scale = self.scale.min(65536.0); self.growth_tracker = 0;
178 }
179 }
180 }
181
182 fn has_inf_or_nan(&self, tensor: &Tensor) -> bool {
184 let data = tensor.data_f32();
185 data.iter().any(|&x| x.is_infinite() || x.is_nan())
186 }
187
188 pub fn get_scale(&self) -> f32 {
190 self.scale
191 }
192
193 pub fn get_stats(&self) -> (f32, usize, usize) {
195 (self.scale, self.growth_tracker, self.found_inf_count)
196 }
197}
198
199pub fn to_half_precision(tensor: &Tensor, mode: PrecisionMode) -> Tensor {
201 match mode {
202 PrecisionMode::FP32 => tensor.clone(),
203 PrecisionMode::FP16 => convert_to_fp16(tensor),
204 PrecisionMode::BF16 => convert_to_bf16(tensor),
205 PrecisionMode::FP8 => convert_to_fp8(tensor),
206 }
207}
208
209pub fn to_full_precision(tensor: &Tensor, mode: PrecisionMode) -> Tensor {
211 match mode {
212 PrecisionMode::FP32 => tensor.clone(),
213 PrecisionMode::FP16 => convert_from_fp16(tensor),
214 PrecisionMode::BF16 => convert_from_bf16(tensor),
215 PrecisionMode::FP8 => convert_from_fp8(tensor),
216 }
217}
218
219fn convert_to_fp16(tensor: &Tensor) -> Tensor {
221 let data = tensor.data_f32();
222 let dims = tensor.dims();
223
224 let fp16_data: Vec<f32> = data.iter().map(|&x| {
226 let clamped = x.clamp(-65504.0, 65504.0);
228 let scale = 1024.0; (clamped * scale).round() / scale
231 }).collect();
232
233 Tensor::from_slice(&fp16_data, dims).unwrap()
234}
235
/// Converts an FP16-simulated tensor back to full precision.
///
/// The simulation stores values as f32 throughout, so the data is already
/// in full-precision storage and a clone suffices; precision lost in the
/// forward conversion is unrecoverable.
fn convert_from_fp16(tensor: &Tensor) -> Tensor {
    tensor.clone()
}
241
242fn convert_to_bf16(tensor: &Tensor) -> Tensor {
244 let data = tensor.data_f32();
245 let dims = tensor.dims();
246
247 let bf16_data: Vec<f32> = data.iter().map(|&x| {
250 let bits = x.to_bits();
252 let truncated = bits & 0xFFFF_0000; f32::from_bits(truncated)
254 }).collect();
255
256 Tensor::from_slice(&bf16_data, dims).unwrap()
257}
258
/// Converts a BF16-simulated tensor back to full precision.
///
/// The truncated values are already stored as f32, so this is an
/// identity clone; the discarded mantissa bits cannot be restored.
fn convert_from_bf16(tensor: &Tensor) -> Tensor {
    tensor.clone()
}
263
264fn convert_to_fp8(tensor: &Tensor) -> Tensor {
266 let data = tensor.data_f32();
267 let dims = tensor.dims();
268
269 let fp8_data: Vec<f32> = data.iter().map(|&x| {
272 let clamped = x.clamp(-448.0, 448.0);
274 let scale = 8.0;
276 (clamped * scale).round() / scale
277 }).collect();
278
279 Tensor::from_slice(&fp8_data, dims).unwrap()
280}
281
/// Converts an FP8-simulated tensor back to full precision.
///
/// The quantized values are already stored as f32, so a clone is all
/// that is needed; quantization error is not reversible.
fn convert_from_fp8(tensor: &Tensor) -> Tensor {
    tensor.clone()
}
286
/// Lightweight autocast scope: while enabled, `cast` converts tensors to
/// the configured lower precision; while disabled it clones them as-is.
pub struct AutocastContext {
    /// Target precision used by `cast` when the context is enabled.
    mode: PrecisionMode,
    /// Whether `cast` down-casts (true) or passes tensors through (false).
    enabled: bool,
}
292
293impl AutocastContext {
294 pub fn new(mode: PrecisionMode) -> Self {
296 AutocastContext {
297 mode,
298 enabled: true,
299 }
300 }
301
302 pub fn disable(&mut self) {
304 self.enabled = false;
305 }
306
307 pub fn enable(&mut self) {
309 self.enabled = true;
310 }
311
312 pub fn cast(&self, tensor: &Tensor) -> Tensor {
314 if self.enabled {
315 to_half_precision(tensor, self.mode)
316 } else {
317 tensor.clone()
318 }
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grad_scaler() {
        let config = MixedPrecisionConfig::fp16();
        let scaler = GradScaler::new(config);

        let loss = Tensor::from_slice(&[1.0f32], &[1]).unwrap();
        let scaled_loss = scaler.scale_loss(&loss);

        // Default FP16 init_scale is 65536.
        let scaled_data = scaled_loss.data_f32();
        assert_eq!(scaled_data[0], 65536.0);
    }

    #[test]
    fn test_unscale_gradients() {
        let config = MixedPrecisionConfig::fp16();
        let scaler = GradScaler::new(config);

        let mut gradients = HashMap::new();
        gradients.insert(
            "weight".to_string(),
            Tensor::from_slice(&[65536.0f32, 131072.0], &[2]).unwrap()
        );

        let success = scaler.unscale_gradients(&mut gradients);
        assert!(success);

        let grad = gradients.get("weight").unwrap();
        let data = grad.data_f32();
        assert!((data[0] - 1.0).abs() < 1e-5);
        assert!((data[1] - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_fp16_conversion() {
        // 1.5, -2.5 and 100.0 are exactly representable in FP16, so they
        // must survive the simulated round-trip within tight tolerances.
        let tensor = Tensor::from_slice(&[1.5f32, -2.5, 100.0], &[3]).unwrap();
        let fp16 = convert_to_fp16(&tensor);
        let data = fp16.data_f32();

        assert!((data[0] - 1.5).abs() < 0.01);
        assert!((data[1] + 2.5).abs() < 0.01);
        assert!((data[2] - 100.0).abs() < 0.1);
    }

    #[test]
    fn test_bf16_conversion() {
        let tensor = Tensor::from_slice(&[1.5f32, -2.5, 1000.0], &[3]).unwrap();
        let bf16 = convert_to_bf16(&tensor);
        let data = bf16.data_f32();

        // BF16 keeps only 7 mantissa bits, so 1000.0 may lose up to ~1%.
        assert!((data[2] - 1000.0).abs() < 10.0);
    }

    #[test]
    fn test_inf_detection() {
        let config = MixedPrecisionConfig::fp16();
        let scaler = GradScaler::new(config);

        let tensor = Tensor::from_slice(&[1.0f32, f32::INFINITY, 2.0], &[3]).unwrap();
        assert!(scaler.has_inf_or_nan(&tensor));

        let tensor = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        assert!(!scaler.has_inf_or_nan(&tensor));
    }

    #[test]
    fn test_autocast_context() {
        let mut ctx = AutocastContext::new(PrecisionMode::FP16);
        // 1/3 and 2/3 are NOT exactly representable after FP16-style
        // quantization, so casting must change the stored values. (The
        // previous fixture used 1.5 and 2.5, which the quantization
        // preserves exactly, making the `assert_ne!` below fail.)
        let tensor = Tensor::from_slice(&[1.0f32 / 3.0, 2.0 / 3.0], &[2]).unwrap();

        let casted = ctx.cast(&tensor);
        assert_ne!(casted.data_f32(), tensor.data_f32());

        ctx.disable();
        let not_casted = ctx.cast(&tensor);
        assert_eq!(not_casted.data_f32(), tensor.data_f32());
    }

    #[test]
    fn test_dynamic_loss_scaling() {
        // Start well below the 65536 cap in update_scale: the default
        // init_scale IS the cap, so growth from the default can never be
        // observed and the first assertion below would always fail.
        let config = MixedPrecisionConfig {
            init_scale: 8.0,
            ..MixedPrecisionConfig::fp16()
        };
        let mut scaler = GradScaler::new(config);

        let initial_scale = scaler.get_scale();

        // growth_interval clean steps trigger exactly one growth.
        for _ in 0..2000 {
            scaler.update_scale(false);
        }

        let grown_scale = scaler.get_scale();
        assert!(grown_scale > initial_scale);

        // One overflow backs the scale off again.
        scaler.update_scale(true);
        let reduced_scale = scaler.get_scale();
        assert!(reduced_scale < grown_scale);
    }
}