pub enum PrecisionPolicy {
AlwaysF32,
AlwaysF16,
AutoMixedConservative,
AutoMixed,
AutoMixedBf16,
Custom(HashMap<OpKind, Precision>),
}
Declarative precision policy for graph compilation.
Variants
AlwaysF32
All ops at F32. Default; safe; baseline accuracy.
AlwaysF16
All ops at F16. Maximum speed; may lose accuracy on reductions.
AutoMixedConservative
Mixed precision, conservative variant. Forces F32 at every reduction boundary, matching PyTorch’s pre-2024 autocast and HuggingFace’s historical default. Accuracy is the highest of the AMP variants, but performance suffers from a Cast node before and after every LayerNorm / Softmax in the graph.
Compute → F16
Reduction → F32 (the cast tax; see AutoMixed for the fix)
Elementwise → F16
DataMovement → F16
Boundary (input/param/output) → F32
AutoMixed
Mixed precision (Phase G, current default). Reductions stay in the input dtype; the kernels themselves promote to f32 internally for the accumulation. This eliminates the dozens of Cast nodes that AutoMixedConservative inserts at LN/Softmax boundaries without sacrificing the f32 reduction accumulation that matters. Matches what modern PyTorch autocast actually does on Metal.
Compute → F16
Reduction → F16 (kernel accumulates in f32 internally)
Elementwise → F16
DataMovement → F16
Boundary (input/param/output) → F32
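To make the cast tax concrete, the sketch below counts the dtype changes along a linear chain of ops under the two mappings described above. The `Category` and `Precision` enums and the `count_casts` helper are hypothetical stand-ins rather than this crate's types, and a real graph is a DAG whose cast-insertion pass may behave differently.

```rust
/// Hypothetical stand-ins for the op categories in the tables above.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Category {
    Compute,
    Reduction,
    Elementwise,
}

#[derive(Clone, Copy, PartialEq, Debug)]
enum Precision {
    F16,
    F32,
}

/// Conservative mapping: reductions are forced to F32, everything else is F16.
fn conservative(c: Category) -> Precision {
    match c {
        Category::Reduction => Precision::F32,
        _ => Precision::F16,
    }
}

/// AutoMixed mapping: reductions stay in F16 (the kernel accumulates in f32 internally).
fn auto_mixed(_c: Category) -> Precision {
    Precision::F16
}

/// Count dtype changes between consecutive ops; each change needs one Cast node.
fn count_casts(chain: &[Category], policy: fn(Category) -> Precision) -> usize {
    chain
        .windows(2)
        .filter(|w| policy(w[0]) != policy(w[1]))
        .count()
}

fn main() {
    use Category::*;
    // MatMul -> Softmax -> MatMul -> Add -> LayerNorm -> MatMul
    let chain = [Compute, Reduction, Compute, Elementwise, Reduction, Compute];
    println!("conservative: {} casts", count_casts(&chain, conservative));
    println!("auto-mixed:   {} casts", count_casts(&chain, auto_mixed));
}
```

In this toy chain every reduction sits between two F16 ops, so the conservative mapping pays one cast on the way in and one on the way out, while the AutoMixed mapping pays none.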
AutoMixedBf16
Mixed precision targeting BF16 on TPU/XLA. Same shape as AutoMixed (compute, reduction, elementwise, and data movement in the chosen low precision; boundaries stay F32), but the low precision is BF16 instead of F16. BF16 is the native compute dtype on TPU and recent GPUs; this matches what JAX picks when jax.config.update("jax_default_dtype_bits", "bfloat16") is set.
Compute → BF16
Reduction → BF16 (XLA’s TPU codegen accumulates in f32)
Elementwise → BF16
DataMovement → BF16
Boundary → F32
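As background on why BF16 can stand in for F16 here: BF16 keeps f32's sign bit and 8 exponent bits and shortens the mantissa to 7 bits, so a conversion can be sketched as keeping the upper 16 bits of the f32 encoding. The helper names below are hypothetical, and the sketch truncates (rounds toward zero) for brevity, whereas hardware and libraries typically round to nearest even.

```rust
/// Truncate an f32 to BF16 by keeping its upper 16 bits
/// (sign, 8 exponent bits, 7 mantissa bits). Rounds toward zero.
fn f32_to_bf16_bits(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

/// Widen BF16 bits back to f32 by zero-filling the dropped mantissa bits.
fn bf16_bits_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

fn main() {
    // 1.0e30 is far outside F16's finite range (max about 65504),
    // but it survives a BF16 round trip with reduced precision.
    let big = 1.0e30_f32;
    let round_trip = bf16_bits_to_f32(f32_to_bf16_bits(big));
    println!("{big} -> {round_trip}");

    // Values that fit in 8 significant bits are unchanged.
    assert_eq!(bf16_bits_to_f32(f32_to_bf16_bits(1.5)), 1.5);
}
```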
Custom(HashMap<OpKind, Precision>)
Explicit per-op-kind override.
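A minimal sketch of how a per-op-kind override map might be built and consulted. The `OpKind` and `Precision` enums here are hypothetical stand-ins (this page does not list the real variants), and falling back to F32 for kinds missing from the map is an assumption, not documented behavior.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in; the crate's real OpKind variants are not listed on this page.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum OpKind {
    Compute,
    Reduction,
    Elementwise,
}

/// Hypothetical stand-in for the crate's Precision type.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Precision {
    F16,
    F32,
}

/// Look up an override. ASSUMPTION: kinds absent from the map fall back to F32.
fn resolve_custom(overrides: &HashMap<OpKind, Precision>, kind: OpKind) -> Precision {
    overrides.get(&kind).copied().unwrap_or(Precision::F32)
}

fn main() {
    let mut overrides = HashMap::new();
    overrides.insert(OpKind::Compute, Precision::F16);
    overrides.insert(OpKind::Reduction, Precision::F32);

    // Listed kinds use the override; unlisted kinds use the assumed fallback.
    assert_eq!(resolve_custom(&overrides, OpKind::Compute), Precision::F16);
    assert_eq!(resolve_custom(&overrides, OpKind::Elementwise), Precision::F32);
}
```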
Implementations
impl PrecisionPolicy
pub fn precision_for(&self, kind: OpKind) -> Precision
Resolve the target precision for an op kind.
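The per-variant tables documented above suggest a resolution along the following lines. This is a sketch under stated assumptions, not the crate's implementation: `SketchPolicy` is a hypothetical stand-in that omits `Custom`, and the `OpKind` variants simply reuse the category names from the tables, which may not match the real enum.

```rust
/// Hypothetical stand-in; only the category names appear on this page.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum OpKind {
    Compute,
    Reduction,
    Elementwise,
    DataMovement,
    Boundary,
}

/// Hypothetical stand-in for the crate's Precision type.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Precision {
    F16,
    Bf16,
    F32,
}

/// Stand-in for the built-in (non-Custom) policy variants.
#[derive(Clone, Copy, Debug)]
enum SketchPolicy {
    AlwaysF32,
    AlwaysF16,
    AutoMixedConservative,
    AutoMixed,
    AutoMixedBf16,
}

impl SketchPolicy {
    /// Resolve a precision by following the documented per-variant tables.
    fn precision_for(&self, kind: OpKind) -> Precision {
        match (self, kind) {
            (SketchPolicy::AlwaysF32, _) => Precision::F32,
            (SketchPolicy::AlwaysF16, _) => Precision::F16,
            // Boundaries (input/param/output) stay F32 in every mixed variant.
            (_, OpKind::Boundary) => Precision::F32,
            // Only the conservative variant forces reductions to F32.
            (SketchPolicy::AutoMixedConservative, OpKind::Reduction) => Precision::F32,
            (SketchPolicy::AutoMixedConservative, _) => Precision::F16,
            (SketchPolicy::AutoMixed, _) => Precision::F16,
            (SketchPolicy::AutoMixedBf16, _) => Precision::Bf16,
        }
    }
}

fn main() {
    let kinds = [
        OpKind::Compute,
        OpKind::Reduction,
        OpKind::Elementwise,
        OpKind::DataMovement,
        OpKind::Boundary,
    ];
    for kind in kinds {
        println!("{:?} -> {:?}", kind, SketchPolicy::AutoMixed.precision_for(kind));
    }
}
```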
Trait Implementations
impl Clone for PrecisionPolicy
fn clone(&self) -> PrecisionPolicy
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl Debug for PrecisionPolicy
impl Default for PrecisionPolicy
fn default() -> PrecisionPolicy
Auto Trait Implementations
impl Freeze for PrecisionPolicy
impl RefUnwindSafe for PrecisionPolicy
impl Send for PrecisionPolicy
impl Sync for PrecisionPolicy
impl Unpin for PrecisionPolicy
impl UnsafeUnpin for PrecisionPolicy
impl UnwindSafe for PrecisionPolicy
Blanket Implementations
impl<T> BorrowMut<T> for T where T: ?Sized
fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T where T: Clone
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.