Skip to main content

entrenar/autograd/precision/
mod.rs

1//! Mixed-precision training utilities
2//!
3//! Provides support for training with reduced precision (fp16/bf16) while
4//! maintaining numerical stability through loss scaling and master weights.
5//!
6//! ## Overview
7//!
8//! Mixed-precision training uses lower precision (fp16/bf16) for:
9//! - Forward pass activations (memory savings)
10//! - Gradient computation (compute speedup)
11//!
12//! Full precision (fp32) is retained for:
13//! - Master weights (numerical stability)
14//! - Loss scaling (gradient underflow prevention)
15//!
16//! ## Example
17//!
18//! ```ignore
19//! use entrenar::autograd::precision::{MixedPrecisionConfig, Precision, GradScaler};
20//!
21//! let config = MixedPrecisionConfig::bf16();
22//! let mut scaler = GradScaler::new(config.initial_scale);
23//!
24//! // Forward pass in reduced precision
25//! let loss = model.forward(&input);
26//!
27//! // Scale loss before backward
28//! let mut scaled_loss = scaler.scale(loss);
29//! backward(&mut scaled_loss, None);
30//!
31//! // Unscale and update
32//! scaler.unscale_grads(&mut params);
33//! optimizer.step(&mut params);
34//! scaler.update();
35//! ```
36
37mod config;
38mod conversions;
39mod precision_types;
40mod scaler;
41
42#[cfg(test)]
43mod tests;
44
45// Re-export all public types and functions
46pub use config::MixedPrecisionConfig;
47pub use conversions::{
48    bf16_to_f32, bf16_truncate, estimate_memory_savings, f32_to_bf16, f32_to_fp16, fp16_to_f32,
49    gemm_bf16_reference,
50};
51pub use precision_types::Precision;
52pub use scaler::GradScaler;