use crate::error::{InferenceError, InferenceResult};
use half::{bf16, f16};
use scirs2_core::ndarray::{Array1, Array2};

/// Numerical precision mode used for inference computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Default)]
pub enum PrecisionMode {
    /// Full 32-bit floating point (default).
    #[default]
    FP32,
    /// IEEE 754 half precision (16-bit).
    FP16,
    /// Brain floating point 16-bit format (bfloat16).
    BF16,
    /// Mixed precision: reduced-precision compute with optional FP32 accumulation.
    Mixed {
        /// Precision used for the compute step.
        compute: ComputePrecision,
        /// Whether accumulation stays in FP32.
        accumulate_fp32: bool,
    },
}

/// Reduced precision available for the compute step of mixed-precision mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ComputePrecision {
    FP16,
    BF16,
}

impl PrecisionMode {
    /// Returns true for any mode other than full FP32.
    pub fn is_reduced_precision(&self) -> bool {
        !matches!(self, PrecisionMode::FP32)
    }

    /// Fraction of the FP32 memory footprint used by values stored in this mode.
    pub fn memory_reduction_factor(&self) -> f32 {
        match self {
            PrecisionMode::FP32 => 1.0,
            PrecisionMode::FP16 | PrecisionMode::BF16 => 0.5,
            PrecisionMode::Mixed { .. } => 0.75,
        }
    }

    /// Human-readable name of the mode.
    pub fn name(&self) -> &str {
        match self {
            PrecisionMode::FP32 => "FP32",
            PrecisionMode::FP16 => "FP16",
            PrecisionMode::BF16 => "BF16",
            PrecisionMode::Mixed { compute, .. } => match compute {
                ComputePrecision::FP16 => "Mixed-FP16",
                ComputePrecision::BF16 => "Mixed-BF16",
            },
        }
    }
}

/// Configuration for precision handling during inference.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PrecisionConfig {
    /// Precision mode to run in.
    pub mode: PrecisionMode,
    /// Loss scaling factor applied when using reduced precision.
    pub loss_scale: f32,
    /// Whether to adjust the loss scale dynamically.
    pub dynamic_loss_scale: bool,
    /// Optional gradient clipping threshold.
    pub grad_clip_threshold: Option<f32>,
}

impl Default for PrecisionConfig {
    fn default() -> Self {
        Self {
            mode: PrecisionMode::FP32,
            loss_scale: 1.0,
            dynamic_loss_scale: false,
            grad_clip_threshold: None,
        }
    }
}

impl PrecisionConfig {
    /// Creates a configuration with default (FP32) settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the precision mode.
    pub fn mode(mut self, mode: PrecisionMode) -> Self {
        self.mode = mode;
        self
    }

    /// Switches to pure FP16 mode.
    pub fn fp16(mut self) -> Self {
        self.mode = PrecisionMode::FP16;
        self
    }

    /// Switches to pure BF16 mode.
    pub fn bf16(mut self) -> Self {
        self.mode = PrecisionMode::BF16;
        self
    }

    /// Switches to mixed precision with FP16 compute.
    pub fn mixed_fp16(mut self, accumulate_fp32: bool) -> Self {
        self.mode = PrecisionMode::Mixed {
            compute: ComputePrecision::FP16,
            accumulate_fp32,
        };
        self
    }

    /// Switches to mixed precision with BF16 compute.
    pub fn mixed_bf16(mut self, accumulate_fp32: bool) -> Self {
        self.mode = PrecisionMode::Mixed {
            compute: ComputePrecision::BF16,
            accumulate_fp32,
        };
        self
    }

    /// Sets the loss scaling factor.
    pub fn loss_scale(mut self, scale: f32) -> Self {
        self.loss_scale = scale;
        self
    }

    /// Enables or disables dynamic loss scaling.
    pub fn dynamic_loss_scale(mut self, enabled: bool) -> Self {
        self.dynamic_loss_scale = enabled;
        self
    }

    /// Sets the gradient clipping threshold.
    pub fn grad_clip_threshold(mut self, threshold: f32) -> Self {
        self.grad_clip_threshold = Some(threshold);
        self
    }
}

/// Converts data between FP32 and reduced-precision representations
/// according to a [`PrecisionConfig`].
pub struct PrecisionConverter {
    config: PrecisionConfig,
}

impl PrecisionConverter {
    /// Creates a converter from the given configuration.
    pub fn new(config: PrecisionConfig) -> Self {
        Self { config }
    }

    /// Applies `op` to `data`, first round-tripping the input through the
    /// configured reduced precision (a no-op for FP32), so the result
    /// reflects the reduced-precision representation.
    pub fn convert_and_compute_1d(
        &self,
        data: &Array1<f32>,
        op: impl Fn(&Array1<f32>) -> Array1<f32>,
    ) -> InferenceResult<Array1<f32>> {
        match self.config.mode {
            PrecisionMode::FP32 => Ok(op(data)),
            PrecisionMode::FP16 => {
                let fp16_data = self.to_fp16_1d(data);
                Ok(op(&self.from_fp16_1d(&fp16_data)))
            }
            PrecisionMode::BF16 => {
                let bf16_data = self.to_bf16_1d(data);
                Ok(op(&self.from_bf16_1d(&bf16_data)))
            }
            PrecisionMode::Mixed { compute, .. } => {
                // On this CPU path the data is round-tripped back to f32 before
                // `op` runs, so the computation is the same whether or not
                // `accumulate_fp32` is requested.
                let reduced = match compute {
                    ComputePrecision::FP16 => {
                        let fp16_data = self.to_fp16_1d(data);
                        self.from_fp16_1d(&fp16_data)
                    }
                    ComputePrecision::BF16 => {
                        let bf16_data = self.to_bf16_1d(data);
                        self.from_bf16_1d(&bf16_data)
                    }
                };
                Ok(op(&reduced))
            }
        }
    }

    /// Converts an FP32 vector to FP16 values.
    pub fn to_fp16_1d(&self, data: &Array1<f32>) -> Vec<f16> {
        data.iter().map(|&x| f16::from_f32(x)).collect()
    }

    /// Converts FP16 values back to an FP32 vector.
    pub fn from_fp16_1d(&self, data: &[f16]) -> Array1<f32> {
        Array1::from_vec(data.iter().map(|&x| x.to_f32()).collect())
    }

    /// Converts an FP32 vector to BF16 values.
    pub fn to_bf16_1d(&self, data: &Array1<f32>) -> Vec<bf16> {
        data.iter().map(|&x| bf16::from_f32(x)).collect()
    }

    /// Converts BF16 values back to an FP32 vector.
    pub fn from_bf16_1d(&self, data: &[bf16]) -> Array1<f32> {
        Array1::from_vec(data.iter().map(|&x| x.to_f32()).collect())
    }

    /// Converts an FP32 matrix to FP16 values.
    pub fn to_fp16_2d(&self, data: &Array2<f32>) -> Vec<f16> {
        data.iter().map(|&x| f16::from_f32(x)).collect()
    }

    /// Converts FP16 values back to an FP32 matrix of the given shape.
    pub fn from_fp16_2d(
        &self,
        data: &[f16],
        shape: (usize, usize),
    ) -> InferenceResult<Array2<f32>> {
        let vec: Vec<f32> = data.iter().map(|&x| x.to_f32()).collect();
        Array2::from_shape_vec(shape, vec).map_err(|e| {
            InferenceError::ForwardError(format!("Shape error in FP16 conversion: {}", e))
        })
    }

    /// Converts an FP32 matrix to BF16 values.
    pub fn to_bf16_2d(&self, data: &Array2<f32>) -> Vec<bf16> {
        data.iter().map(|&x| bf16::from_f32(x)).collect()
    }

    /// Converts BF16 values back to an FP32 matrix of the given shape.
    pub fn from_bf16_2d(
        &self,
        data: &[bf16],
        shape: (usize, usize),
    ) -> InferenceResult<Array2<f32>> {
        let vec: Vec<f32> = data.iter().map(|&x| x.to_f32()).collect();
        Array2::from_shape_vec(shape, vec).map_err(|e| {
            InferenceError::ForwardError(format!("Shape error in BF16 conversion: {}", e))
        })
    }

    /// Returns the converter's configuration.
    pub fn config(&self) -> &PrecisionConfig {
        &self.config
    }
}

/// Statistics collected while converting between precisions.
#[derive(Debug, Clone, Default)]
pub struct PrecisionStats {
    /// Number of conversions performed.
    pub num_conversions: usize,
    /// Estimated memory saved, in bytes.
    pub memory_saved: usize,
    /// Running average conversion error.
    pub avg_error: f64,
    /// Largest conversion error observed.
    pub max_error: f64,
}

impl PrecisionStats {
    /// Creates an empty statistics record.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a conversion of `original_size` bytes under the given precision mode.
    pub fn record_conversion(&mut self, original_size: usize, precision_mode: &PrecisionMode) {
        self.num_conversions += 1;
        let saved =
            (original_size as f32 * (1.0 - precision_mode.memory_reduction_factor())) as usize;
        self.memory_saved += saved;
    }

    /// Records a conversion error and updates the running average and maximum.
    pub fn record_error(&mut self, error: f64) {
        // Guard against division by zero when no conversion has been recorded yet.
        let n = (self.num_conversions as f64).max(1.0);
        self.avg_error = (self.avg_error * (n - 1.0) + error) / n;
        self.max_error = self.max_error.max(error);
    }

    /// Memory saved, expressed in MB.
    pub fn memory_saved_mb(&self) -> f64 {
        self.memory_saved as f64 / (1024.0 * 1024.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_mode_creation() {
        let mode = PrecisionMode::FP32;
        assert_eq!(mode.name(), "FP32");
        assert!(!mode.is_reduced_precision());
    }

    #[test]
    fn test_precision_mode_fp16() {
        let mode = PrecisionMode::FP16;
        assert_eq!(mode.name(), "FP16");
        assert!(mode.is_reduced_precision());
        assert_eq!(mode.memory_reduction_factor(), 0.5);
    }

    #[test]
    fn test_precision_mode_bf16() {
        let mode = PrecisionMode::BF16;
        assert_eq!(mode.name(), "BF16");
        assert!(mode.is_reduced_precision());
        assert_eq!(mode.memory_reduction_factor(), 0.5);
    }

    #[test]
    fn test_precision_mode_mixed() {
        let mode = PrecisionMode::Mixed {
            compute: ComputePrecision::FP16,
            accumulate_fp32: true,
        };
        assert_eq!(mode.name(), "Mixed-FP16");
        assert!(mode.is_reduced_precision());
    }
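
    // An additional illustrative check (test name chosen here): the mixed mode
    // reports the intermediate 0.75 memory factor and a BF16-flavored name
    // when BF16 compute is selected.
    #[test]
    fn test_precision_mode_mixed_bf16_properties() {
        let mode = PrecisionMode::Mixed {
            compute: ComputePrecision::BF16,
            accumulate_fp32: false,
        };
        assert_eq!(mode.name(), "Mixed-BF16");
        assert_eq!(mode.memory_reduction_factor(), 0.75);
        assert!(mode.is_reduced_precision());
    }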

    #[test]
    fn test_precision_config_builder() {
        let config = PrecisionConfig::new()
            .fp16()
            .loss_scale(128.0)
            .dynamic_loss_scale(true);

        assert_eq!(config.mode, PrecisionMode::FP16);
        assert_eq!(config.loss_scale, 128.0);
        assert!(config.dynamic_loss_scale);
    }
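
    // An additional illustrative check of the remaining builder methods:
    // mixed_bf16 selects BF16 compute and grad_clip_threshold stores the value.
    #[test]
    fn test_precision_config_builder_mixed_bf16() {
        let config = PrecisionConfig::new()
            .mixed_bf16(true)
            .grad_clip_threshold(1.0);

        assert_eq!(
            config.mode,
            PrecisionMode::Mixed {
                compute: ComputePrecision::BF16,
                accumulate_fp32: true,
            }
        );
        assert_eq!(config.grad_clip_threshold, Some(1.0));
    }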

    #[test]
    fn test_fp16_conversion_1d() {
        let config = PrecisionConfig::new().fp16();
        let converter = PrecisionConverter::new(config);

        let data = Array1::from_vec(vec![1.0, 2.5, -3.75, 0.0]);
        let fp16_data = converter.to_fp16_1d(&data);
        let restored = converter.from_fp16_1d(&fp16_data);

        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.001);
        }
    }

    #[test]
    fn test_bf16_conversion_1d() {
        let config = PrecisionConfig::new().bf16();
        let converter = PrecisionConverter::new(config);

        let data = Array1::from_vec(vec![1.0, 2.5, -3.75, 0.0]);
        let bf16_data = converter.to_bf16_1d(&data);
        let restored = converter.from_bf16_1d(&bf16_data);

        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.01);
        }
    }

    #[test]
    fn test_convert_and_compute() {
        let config = PrecisionConfig::new().fp16();
        let converter = PrecisionConverter::new(config);

        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = converter
            .convert_and_compute_1d(&data, |x| x.mapv(|v| v * 2.0))
            .unwrap();

        for (i, &val) in result.iter().enumerate() {
            let expected = data[i] * 2.0;
            assert!((val - expected).abs() < 0.01);
        }
    }

    #[test]
    fn test_precision_stats() {
        let mut stats = PrecisionStats::new();

        assert_eq!(stats.num_conversions, 0);
        assert_eq!(stats.memory_saved, 0);

        let mode = PrecisionMode::FP16;
        stats.record_conversion(1000, &mode);

        assert_eq!(stats.num_conversions, 1);
        assert_eq!(stats.memory_saved, 500);

        stats.record_error(0.001);
        assert!(stats.avg_error > 0.0);
        assert!(stats.max_error > 0.0);
    }
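
    // An additional illustrative check: memory_saved_mb converts the recorded
    // byte count into megabytes (1024 * 1024 bytes per MB).
    #[test]
    fn test_precision_stats_memory_saved_mb() {
        let mut stats = PrecisionStats::new();
        stats.record_conversion(2 * 1024 * 1024, &PrecisionMode::FP16);

        // FP16 halves the footprint, so 2 MB of FP32 data saves 1 MB.
        assert_eq!(stats.memory_saved, 1024 * 1024);
        assert!((stats.memory_saved_mb() - 1.0).abs() < 1e-9);
    }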

    #[test]
    fn test_mixed_precision_compute() {
        let config = PrecisionConfig::new().mixed_fp16(true);
        let converter = PrecisionConverter::new(config);

        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = converter
            .convert_and_compute_1d(&data, |x| x.mapv(|v| v * 2.0))
            .unwrap();

        for (i, &val) in result.iter().enumerate() {
            let expected = data[i] * 2.0;
            assert!((val - expected).abs() < 0.001);
        }
    }
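
    // An additional illustrative check: the mixed path also works when FP32
    // accumulation is disabled, since the inputs used here are exact in FP16.
    #[test]
    fn test_mixed_precision_compute_no_fp32_accumulate() {
        let config = PrecisionConfig::new().mixed_fp16(false);
        let converter = PrecisionConverter::new(config);

        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = converter
            .convert_and_compute_1d(&data, |x| x.mapv(|v| v * 2.0))
            .unwrap();

        for (i, &val) in result.iter().enumerate() {
            assert!((val - data[i] * 2.0).abs() < 0.001);
        }
    }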

    #[test]
    fn test_fp16_2d_conversion() {
        let config = PrecisionConfig::new().fp16();
        let converter = PrecisionConverter::new(config);

        let data = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let fp16_data = converter.to_fp16_2d(&data);
        let restored = converter.from_fp16_2d(&fp16_data, (2, 3)).unwrap();

        assert_eq!(restored.shape(), &[2, 3]);

        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.001);
        }
    }

    #[test]
    fn test_bf16_2d_conversion() {
        let config = PrecisionConfig::new().bf16();
        let converter = PrecisionConverter::new(config);

        let data = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let bf16_data = converter.to_bf16_2d(&data);
        let restored = converter.from_bf16_2d(&bf16_data, (2, 2)).unwrap();

        assert_eq!(restored.shape(), &[2, 2]);

        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 0.01);
        }
    }
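
    // An additional illustrative check: a shape that does not match the
    // element count is surfaced as an error rather than a panic.
    #[test]
    fn test_fp16_2d_conversion_shape_mismatch() {
        let config = PrecisionConfig::new().fp16();
        let converter = PrecisionConverter::new(config);

        let data = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let fp16_data = converter.to_fp16_2d(&data);

        // Four elements cannot fill a (3, 3) matrix, so conversion must fail.
        assert!(converter.from_fp16_2d(&fp16_data, (3, 3)).is_err());
    }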
}