pub enum PrecisionPolicy {
    AlwaysF32,
    AlwaysF16,
    AutoMixedConservative,
    AutoMixed,
    AutoMixedBf16,
    Custom(HashMap<OpKind, Precision>),
}
Declarative precision policy for graph compilation.
Variants§
AlwaysF32
All ops at F32. Default; safe; baseline accuracy.
AlwaysF16
All ops at F16. Maximum speed; may lose accuracy on reductions.
AutoMixedConservative
Mixed precision, conservative variant. Forces F32 at every reduction boundary, matching PyTorch’s pre-2024 autocast and HuggingFace’s historical default. Accuracy is the highest of the AMP variants; performance suffers from a Cast node before and after every LayerNorm / Softmax in the graph.
Compute → F16
Reduction → F32 (the cast tax; see AutoMixed for the fix)
Elementwise → F16
DataMovement → F16
Boundary (input/param/output) → F32
AutoMixed
Mixed precision (Phase G; current default). Reductions stay in the input dtype; the kernels themselves promote to f32 internally for the accumulation. This eliminates the dozens of Cast nodes that AutoMixedConservative inserts at LayerNorm / Softmax boundaries without sacrificing the f32 reduction accumulation that matters. Matches what modern PyTorch autocast actually does on Metal.
Compute → F16
Reduction → F16 (kernel accumulates in f32 internally)
Elementwise → F16
DataMovement → F16
Boundary (input/param/output) → F32
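The difference between the two mixed-precision tables above can be illustrated by counting precision changes along a chain of ops. The following is a minimal sketch, not the crate's implementation: the `Precision` and `OpKind` definitions are simplified stand-ins, `conservative`, `auto_mixed`, and `count_casts` are hypothetical helpers, and the model assumes one Cast node per precision change between neighbouring ops in a linear chain (boundary casts are ignored).

```rust
// Hypothetical stand-ins for the crate's types; the real definitions may differ.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Precision {
    F16,
    F32,
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum OpKind {
    Compute,
    Reduction,
    Elementwise,
}

/// AutoMixedConservative table: reductions run at F32, everything else at F16.
pub fn conservative(kind: OpKind) -> Precision {
    match kind {
        OpKind::Reduction => Precision::F32,
        _ => Precision::F16,
    }
}

/// AutoMixed table: every interior op stays at F16
/// (the reduction kernel is assumed to accumulate in f32 internally).
pub fn auto_mixed(_kind: OpKind) -> Precision {
    Precision::F16
}

/// Count the casts along a linear chain of ops, assuming one cast
/// per change of precision between neighbouring ops.
pub fn count_casts(chain: &[OpKind], policy: fn(OpKind) -> Precision) -> usize {
    chain
        .windows(2)
        .filter(|w| policy(w[0]) != policy(w[1]))
        .count()
}
```

Under this simplified model, a chain such as Compute, Reduction, Elementwise needs two casts with the conservative table (one on each side of the reduction) and none with the AutoMixed table.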
AutoMixedBf16
Mixed precision targeting BF16 on TPU/XLA. Same shape as
AutoMixed (compute, reduction, elementwise, and data movement
in the chosen low precision; boundaries stay F32), but the low
precision is BF16 instead of F16. BF16 is the native compute
dtype on TPU and recent GPUs; this matches what JAX picks when
jax.config.update("jax_default_dtype_bits", "bfloat16") is set.
Compute → BF16
Reduction → BF16 (XLA’s TPU codegen accumulates in f32)
Elementwise → BF16
DataMovement → BF16
Boundary → F32
Custom(HashMap<OpKind, Precision>)
Explicit per-op-kind override.
Implementations§
impl PrecisionPolicy
pub fn precision_for(&self, kind: OpKind) -> Precision
Resolve the target precision for an op kind.
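As a rough illustration of how the per-variant tables could translate into a lookup, here is a minimal sketch. It is not the crate's source: the `OpKind` and `Precision` variants are inferred from the tables in the variant descriptions, and the fallback to F32 for op kinds missing from a `Custom` map is an assumption made for this example.

```rust
use std::collections::HashMap;

// Hypothetical definitions inferred from the variant tables; the real types may differ.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum OpKind {
    Compute,
    Reduction,
    Elementwise,
    DataMovement,
    Boundary,
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Precision {
    F32,
    F16,
    Bf16,
}

#[derive(Clone)]
pub enum PrecisionPolicy {
    AlwaysF32,
    AlwaysF16,
    AutoMixedConservative,
    AutoMixed,
    AutoMixedBf16,
    Custom(HashMap<OpKind, Precision>),
}

impl PrecisionPolicy {
    /// Resolve the target precision for an op kind, following the documented tables.
    pub fn precision_for(&self, kind: OpKind) -> Precision {
        use OpKind::*;
        match (self, kind) {
            (Self::AlwaysF32, _) => Precision::F32,
            (Self::AlwaysF16, _) => Precision::F16,
            // Conservative: reductions and graph boundaries are forced to F32.
            (Self::AutoMixedConservative, Reduction | Boundary) => Precision::F32,
            (Self::AutoMixedConservative, _) => Precision::F16,
            // AutoMixed: only graph boundaries stay at F32.
            (Self::AutoMixed, Boundary) => Precision::F32,
            (Self::AutoMixed, _) => Precision::F16,
            // Same shape as AutoMixed, with BF16 as the low precision.
            (Self::AutoMixedBf16, Boundary) => Precision::F32,
            (Self::AutoMixedBf16, _) => Precision::Bf16,
            // Assumption: op kinds missing from the map fall back to F32.
            (Self::Custom(map), k) => map.get(&k).copied().unwrap_or(Precision::F32),
        }
    }
}
```

With these definitions, `PrecisionPolicy::AutoMixed.precision_for(OpKind::Reduction)` resolves to F16, while the same call on `AutoMixedConservative` resolves to F32. A `Custom` policy built from a `HashMap` overrides only the kinds it lists.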
Trait Implementations§
impl Clone for PrecisionPolicy
fn clone(&self) -> PrecisionPolicy
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.