pub struct SoftmaxPlan<T: Element, const N: usize> { /* private fields */ }
Softmax forward plan — see the module-level docs for formulas, dtypes, workspace, and precision guarantees.
T: Element is the element type (f32 / f64 / f16 / bf16).
const N: usize is the tensor rank (1..=8).
Implementations
impl<T: Element, const N: usize> SoftmaxPlan<T, N>
pub fn select(
    _stream: &Stream,
    desc: &SoftmaxDescriptor<N>,
    _pref: PlanPreference,
) -> Result<Self>
Pick a kernel for desc. Validates that softmax_axis < N, that the dtype belongs to the wired-up floating-point family (f32 / f64 / f16 / bf16), and that the tensor rank is at most 8. Returns Error::Unsupported for cells outside the support matrix and Error::InvalidProblem for malformed shapes or axes.
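The validation rules listed for select can be mirrored in a small standalone sketch. Dtype, Error, and validate_select below are hypothetical stand-ins, not this crate's types. Treating zero-length dimensions as a malformed shape and reporting rank > 8 as Unsupported (rather than InvalidProblem) are assumptions made for illustration.

```rust
// Hypothetical stand-ins for the crate's real types (illustration only).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Dtype {
    F32,
    F64,
    F16,
    Bf16,
    I8, // hypothetical non-FP dtype, used to show the Unsupported path
}

#[derive(Debug, PartialEq, Eq)]
enum Error {
    Unsupported(&'static str),
    InvalidProblem(&'static str),
}

/// Hypothetical mirror of the checks `select` is documented to perform.
fn validate_select<const N: usize>(
    dtype: Dtype,
    shape: [usize; N],
    softmax_axis: usize,
) -> Result<(), Error> {
    // Assumption: a rank outside 1..=8 is reported as an unsupported cell.
    if N == 0 || N > 8 {
        return Err(Error::Unsupported("rank must be in 1..=8"));
    }
    // Only the floating-point family is wired up.
    if !matches!(dtype, Dtype::F32 | Dtype::F64 | Dtype::F16 | Dtype::Bf16) {
        return Err(Error::Unsupported("dtype not in the FP family"));
    }
    // A bad axis is a malformed request, not a missing kernel.
    if softmax_axis >= N {
        return Err(Error::InvalidProblem("softmax_axis must be < N"));
    }
    // Assumption: zero-length dimensions count as a malformed shape.
    if shape.iter().any(|&d| d == 0) {
        return Err(Error::InvalidProblem("zero-length dimension"));
    }
    Ok(())
}
```

The split mirrors the documented contract. Unsupported means "valid request, no kernel for it", and InvalidProblem means "the request itself is wrong", so a caller can fall back to another backend only in the first case.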
pub fn can_implement(&self, args: &SoftmaxArgs<'_, T, N>) -> Result<()>
Validate args against this plan. Returns an error if the plan cannot run with them.
pub fn workspace_size(&self) -> usize
Workspace size in bytes. Always zero: the kernel performs its two-pass scan entirely in registers, so it needs no scratch buffer.
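A two-pass softmax can run with only a running maximum and a running sum as state, which is why no workspace is needed. The CPU sketch below uses the "online softmax" formulation for one row. Whether the real kernel uses exactly this scheme is an assumption, and softmax_row_two_pass is a hypothetical name.

```rust
/// Softmax over one row in two passes with O(1) extra state.
/// Sketch of the online-softmax scheme; not the crate's kernel.
fn softmax_row_two_pass(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());

    // Pass 1: running max `m` and running sum `s` of exp(x - m).
    // When a new maximum appears, rescale the old sum to the new max.
    let mut m = f32::NEG_INFINITY;
    let mut s = 0.0f32;
    for &x in input {
        if x > m {
            s = s * (m - x).exp() + 1.0;
            m = x;
        } else {
            s += (x - m).exp();
        }
    }

    // Pass 2: normalize every element with the final max and sum.
    for (y, &x) in output.iter_mut().zip(input) {
        *y = (x - m).exp() / s;
    }
}
```

Subtracting the running maximum keeps every exp argument at or below zero, so the result does not overflow for large inputs.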
pub fn sku(&self) -> KernelSku
Identity of the kernel this plan picked, used for telemetry and as an autotuner cache key.
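The SKU works as a cache key when it has value semantics. The sketch below assumes a KernelSku that implements Hash + Eq + Clone. Its fields, the TuneCache type, and the kernel name "softmax_two_pass" are all hypothetical. The cache keeps the fastest timing observed per SKU.

```rust
use std::collections::HashMap;
use std::time::Duration;

// Hypothetical stand-in for `KernelSku`; real fields and trait impls
// are not documented here, so Hash + Eq + Clone is an assumption.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct KernelSku {
    name: &'static str,
    dtype: &'static str,
    rank: usize,
}

/// Minimal autotuner cache: best observed time per kernel SKU.
#[derive(Default)]
struct TuneCache {
    best: HashMap<KernelSku, Duration>,
}

impl TuneCache {
    /// Record a timing, keeping only the fastest one seen for this SKU.
    fn record(&mut self, sku: KernelSku, t: Duration) {
        self.best
            .entry(sku)
            .and_modify(|cur| if t < *cur { *cur = t })
            .or_insert(t);
    }

    /// Look up the best timing recorded for a SKU, if any.
    fn lookup(&self, sku: &KernelSku) -> Option<Duration> {
        self.best.get(sku).copied()
    }
}
```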
pub fn precision_guarantee(&self) -> PrecisionGuarantee
Numerical guarantees for this plan's kernel. Results are deterministic and bit-stable on the same hardware. f16 and bf16 inputs take an f32 detour, so they accumulate in f32. f32 and f64 inputs accumulate natively in their own precision.
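The f32 detour for reduced-precision inputs can be sketched on the CPU with bf16 values stored as raw u16 bits. Each input is widened to f32, the max, exp, and sum are all computed in f32, and only the final outputs are rounded back to bf16 with round-to-nearest-even. The function names are hypothetical, and NaN handling in the narrowing step is omitted. The reduction visits elements in a fixed order, which is one way to make repeated runs bit-identical. This sketch does not claim to describe how the real kernel orders its reductions.

```rust
/// bf16 is the upper 16 bits of an IEEE-754 f32, so widening is a shift.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Narrow f32 to bf16 with round-to-nearest-even (NaN handling omitted).
fn f32_to_bf16(x: f32) -> u16 {
    let b = x.to_bits();
    let lsb = (b >> 16) & 1; // parity of the bit that survives truncation
    let rounded = b.wrapping_add(0x7FFF + lsb);
    (rounded >> 16) as u16
}

/// Softmax over a bf16 row with an f32 accumulator.
fn softmax_bf16_row(input: &[u16]) -> Vec<u16> {
    // Widen once; all arithmetic below happens in f32.
    let xs: Vec<f32> = input.iter().map(|&b| bf16_to_f32(b)).collect();
    let m = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Sequential sum in a fixed order: same inputs give the same bits.
    let s: f32 = xs.iter().map(|&x| (x - m).exp()).sum();
    // Round to bf16 only at the very end.
    xs.iter().map(|&x| f32_to_bf16((x - m).exp() / s)).collect()
}
```

Accumulating in f32 means rounding error in the sum stays near f32 precision. Only the per-element output rounding to bf16's 8-bit significand is visible in the results.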