pub struct QuantizePerTensorBackwardPlan<TIn: Element, TOut: IntElement> { /* private fields */ }
quantize_per_tensor backward plan.
Straight-Through Estimator (STE):
dx = (dy / scale) * 1[qmin ≤ round(x/scale) + zp ≤ qmax],
where 1[·] is 1 when the condition holds and 0 otherwise. The in-range mask is recomputed in-kernel from the saved input x; no separate mask is saved in the forward (FW) pass.
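As a reference for the formula above, here is a minimal CPU sketch of the STE backward for f32. The function name quantize_per_tensor_backward_ref and its parameter list are hypothetical, not part of this crate. It uses f32::round (ties away from zero); the kernel's tie-breaking rule is not stated here and may differ, which can only matter for inputs that land exactly on a half-step at the range boundary.

```rust
/// Hypothetical CPU reference for the STE backward of quantize_per_tensor (not this crate's API).
/// dx[i] = dy[i] / scale when qmin <= round(x[i] / scale) + zp <= qmax, else 0.
fn quantize_per_tensor_backward_ref(
    x: &[f32],
    dy: &[f32],
    dx: &mut [f32],
    scale: f32,
    zero_point: i32,
    qmin: i32,
    qmax: i32,
) {
    // Flat [numel] layout: all buffers hold the same number of elements.
    assert_eq!(x.len(), dy.len());
    assert_eq!(x.len(), dx.len());
    for ((d, &xi), &g) in dx.iter_mut().zip(x).zip(dy) {
        // Recompute the pre-clamp quantized value from the saved input x.
        // Assumption: f32::round (ties away from zero) stands in for the kernel's rounding.
        let q = (xi / scale).round() + zero_point as f32;
        let in_range = q >= qmin as f32 && q <= qmax as f32;
        // Apply the 1/scale factor by division, exactly as written in the formula.
        *d = if in_range { g / scale } else { 0.0 };
    }
}

fn main() {
    let x = [0.5_f32, 31.75, 32.0, -32.0, -32.25];
    let dy = [1.0_f32; 5];
    let mut dx = [0.0_f32; 5];
    // scale = 0.25, zero point 0, int8 range [-128, 127].
    quantize_per_tensor_backward_ref(&x, &dy, &mut dx, 0.25, 0, -128, 127);
    println!("{:?}", dx); // [4.0, 4.0, 0.0, 4.0, 0.0]
}
```

With scale = 0.25 and an int8 range, x = 31.75 maps to 127 and keeps its gradient, while x = 32.0 maps to 128 and is masked; every in-range gradient is dy / 0.25, i.e. 4 times dy.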
When to use: as the backward for QuantizePerTensorPlan. The caller must retain the original input x from the forward pass.
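The integer output of the forward pass cannot stand in for x, because clamping merges in-range and out-of-range values. The sketch below shows two inputs that quantize to the same int8 value but need different gradients. The helpers quantize_per_tensor_ref and ste_mask are hypothetical, and the sketch assumes an int8 TOut and f32::round rounding.

```rust
/// Hypothetical forward reference, assuming TOut = i8:
/// q = clamp(round(x / scale) + zp, -128, 127).
fn quantize_per_tensor_ref(x: &[f32], scale: f32, zero_point: i32) -> Vec<i8> {
    x.iter()
        .map(|&v| {
            let q = (v / scale).round() + zero_point as f32;
            // Clamping to the int8 range is what discards the in-range information.
            q.clamp(i8::MIN as f32, i8::MAX as f32) as i8
        })
        .collect()
}

/// Hypothetical STE mask, recomputed from the saved input x as the backward does.
fn ste_mask(x: &[f32], scale: f32, zero_point: i32) -> Vec<bool> {
    x.iter()
        .map(|&v| {
            let q = (v / scale).round() + zero_point as f32;
            (i8::MIN as f32..=i8::MAX as f32).contains(&q)
        })
        .collect()
}

fn main() {
    let (scale, zp) = (0.25_f32, 0);
    // Both inputs quantize to 127, but only 31.75 was in range before clamping.
    let x = [31.75_f32, 32.0];
    println!("q    = {:?}", quantize_per_tensor_ref(&x, scale, zp)); // [127, 127]
    println!("mask = {:?}", ste_mask(&x, scale, zp)); // [true, false]
}
```

Since q alone cannot tell these cases apart, the backward needs either x or a saved mask; this plan recomputes the mask from x instead of storing one.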
Dtypes: the gradients dy and dx use the input floating-point type TIn, one of {f32, f64, f16, bf16}. TOut is the forward output integer type and is carried only for SKU consistency; the backward (BW) kernel does not consume an integer operand.
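To illustrate this dtype contract, the following sketch makes the computation generic over the float type and carries TOut only as a PhantomData marker, so it contributes no data and is never read. The Float trait and BackwardSketch type are hypothetical stand-ins (the crate's Element and IntElement traits are not shown here), and only f32 and f64 are covered because the Rust standard library has no stable f16 or bf16 types.

```rust
use std::marker::PhantomData;
use std::ops::{Add, Div};

/// Hypothetical float abstraction for this sketch; the crate's Element trait is not shown.
trait Float: Copy + PartialOrd + Add<Output = Self> + Div<Output = Self> {
    fn zero() -> Self;
    fn from_i32(v: i32) -> Self;
    fn round_nearest(self) -> Self;
}

impl Float for f32 {
    fn zero() -> Self { 0.0 }
    fn from_i32(v: i32) -> Self { v as f32 }
    fn round_nearest(self) -> Self { self.round() }
}

impl Float for f64 {
    fn zero() -> Self { 0.0 }
    fn from_i32(v: i32) -> Self { v as f64 }
    fn round_nearest(self) -> Self { self.round() }
}

/// Hypothetical backward sketch: gradients use TIn; TOut is only a type marker.
struct BackwardSketch<TIn, TOut> {
    scale: TIn,
    zero_point: i32,
    qmin: i32,
    qmax: i32,
    // Carries the forward int dtype without storing or reading any TOut value.
    _out: PhantomData<TOut>,
}

impl<TIn: Float, TOut> BackwardSketch<TIn, TOut> {
    fn new(scale: TIn, zero_point: i32, qmin: i32, qmax: i32) -> Self {
        Self { scale, zero_point, qmin, qmax, _out: PhantomData }
    }

    /// dx = dy / scale inside the quantization range, 0 outside; every operand is TIn.
    fn run(&self, x: &[TIn], dy: &[TIn], dx: &mut [TIn]) {
        let (lo, hi) = (TIn::from_i32(self.qmin), TIn::from_i32(self.qmax));
        for ((d, &xi), &g) in dx.iter_mut().zip(x).zip(dy) {
            let q = (xi / self.scale).round_nearest() + TIn::from_i32(self.zero_point);
            *d = if q >= lo && q <= hi { g / self.scale } else { TIn::zero() };
        }
    }
}

fn main() {
    // f64 gradients paired with an i8 forward output.
    let plan: BackwardSketch<f64, i8> = BackwardSketch::new(0.25, 0, -128, 127);
    let mut dx = [0.0_f64; 2];
    plan.run(&[0.5, 32.0], &[1.0, 1.0], &mut dx);
    println!("{:?}", dx); // [4.0, 0.0]
}
```

Changing TOut changes only the type, not the layout or the arithmetic, which matches the statement that the backward kernel never consumes an integer operand.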
Shape limits: flat [numel].
Workspace: none.
Precision guarantee: deterministic and bit-stable. The 1/scale factor is mandatory; omitting it is the most common STE-gradient bug.
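The 1/scale factor follows from the chain rule: with rounding treated as the identity, the forward becomes clamp(x/scale + zp, qmin, qmax), whose slope is 1/scale inside the range and 0 outside. The sketch below checks this numerically with a central finite difference in f64; ste_surrogate, ste_grad and finite_difference are hypothetical helpers. Because each element depends only on its own inputs and there is no reduction, a computation of this shape has no accumulation-order nondeterminism, which is consistent with the bit-stable guarantee.

```rust
/// Hypothetical STE surrogate of the forward: rounding replaced by the identity.
fn ste_surrogate(x: f64, scale: f64, zp: f64, qmin: f64, qmax: f64) -> f64 {
    (x / scale + zp).clamp(qmin, qmax)
}

/// Hypothetical single-element STE gradient with the mandatory 1/scale factor.
fn ste_grad(x: f64, dy: f64, scale: f64, zp: f64, qmin: f64, qmax: f64) -> f64 {
    let q = (x / scale).round() + zp;
    if q >= qmin && q <= qmax { dy / scale } else { 0.0 }
}

/// Central finite difference of the surrogate at x with step h.
fn finite_difference(x: f64, h: f64, scale: f64, zp: f64, qmin: f64, qmax: f64) -> f64 {
    let up = ste_surrogate(x + h, scale, zp, qmin, qmax);
    let down = ste_surrogate(x - h, scale, zp, qmin, qmax);
    (up - down) / (2.0 * h)
}

fn main() {
    let (scale, zp, qmin, qmax) = (0.1, 0.0, -128.0, 127.0);
    // 1.234 is in range (slope 1/scale); 20.0 maps to about 200 and is clamped (slope 0).
    for x in [1.234, 20.0] {
        let fd = finite_difference(x, 1e-6, scale, zp, qmin, qmax);
        let grad = ste_grad(x, 1.0, scale, zp, qmin, qmax);
        println!("x = {x}: finite difference = {fd:.6}, STE grad = {grad:.6}");
    }
}
```

A mask-only gradient (dy times the indicator, without 1/scale) would report 1.0 at x = 1.234 instead of about 10.0, off by exactly the missing factor.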
Implementations
impl<TIn: Element, TOut: IntElement> QuantizePerTensorBackwardPlan<TIn, TOut>
pub fn select(_stream: &Stream, desc: &QuantizePerTensorBackwardDescriptor, _pref: PlanPreference) -> Result<Self>
Selects a kernel for the operation described by desc.
pub fn can_implement(&self, args: &QuantizePerTensorBackwardArgs<'_, TIn, TOut>) -> Result<()>
Validates args against this plan, returning an error if the plan cannot handle them.
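The exact checks can_implement performs are not documented here. As a hedged sketch, a validator for this operation might enforce the flat [numel] shape limit and the preconditions of dx = dy / scale. The names validate_backward_args and ArgError are hypothetical, and requiring a finite, strictly positive scale and qmin ≤ qmax are assumptions rather than documented behavior.

```rust
/// Hypothetical validation errors for this sketch.
#[derive(Debug, PartialEq)]
enum ArgError {
    LengthMismatch,
    BadScale,
    BadRange,
}

/// Hypothetical validator; the real can_implement checks are not documented here.
fn validate_backward_args(
    x: &[f32],
    dy: &[f32],
    dx: &[f32],
    scale: f32,
    qmin: i32,
    qmax: i32,
) -> Result<(), ArgError> {
    // Flat [numel] layout: x, dy and dx must hold the same number of elements.
    if x.len() != dy.len() || x.len() != dx.len() {
        return Err(ArgError::LengthMismatch);
    }
    // dx = dy / scale: assume scale must be finite and strictly positive.
    if !(scale.is_finite() && scale > 0.0) {
        return Err(ArgError::BadScale);
    }
    // An empty integer range would make the in-range mask always false.
    if qmin > qmax {
        return Err(ArgError::BadRange);
    }
    Ok(())
}

fn main() {
    let buf = [0.0_f32; 4];
    println!("{:?}", validate_backward_args(&buf, &buf, &buf, 0.25, -128, 127)); // Ok(())
    println!("{:?}", validate_backward_args(&buf, &buf[..3], &buf, 0.25, -128, 127)); // Err(LengthMismatch)
}
```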
pub fn workspace_size(&self) -> usize
Returns the required workspace size in bytes; this plan needs no workspace.
pub fn precision_guarantee(&self) -> PrecisionGuarantee
Returns the plan's numerical guarantees (deterministic and bit-stable for this plan).
Auto Trait Implementations
impl<TIn, TOut> Freeze for QuantizePerTensorBackwardPlan<TIn, TOut>
impl<TIn, TOut> RefUnwindSafe for QuantizePerTensorBackwardPlan<TIn, TOut>
where
TIn: RefUnwindSafe,
TOut: RefUnwindSafe,
impl<TIn, TOut> Send for QuantizePerTensorBackwardPlan<TIn, TOut>
impl<TIn, TOut> Sync for QuantizePerTensorBackwardPlan<TIn, TOut>
impl<TIn, TOut> Unpin for QuantizePerTensorBackwardPlan<TIn, TOut>
impl<TIn, TOut> UnsafeUnpin for QuantizePerTensorBackwardPlan<TIn, TOut>
impl<TIn, TOut> UnwindSafe for QuantizePerTensorBackwardPlan<TIn, TOut>
where
TIn: UnwindSafe,
TOut: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.