pub struct TriuPlan<T: Element, const N: usize> { /* private fields */ }
Plan for the triu (upper-triangular) operation.
Computes y = torch.triu(x, diagonal), an upper-triangular mask on the last
two dims of x: element (i, j) of each matrix is kept when
j - i >= diagonal and set to zero otherwise.
When to use: forward triu. Pair with
TriuBackwardPlan for the backward pass, which applies the same
mask to d_output.
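The mask rule and its backward counterpart can be illustrated with a small CPU reference. This is a sketch, not the library's kernel: triu_reference and triu_backward_reference are hypothetical names, the input is assumed to be a row-major contiguous buffer of batch matrices of shape m x n, and T::default() stands in for zero (0 for the integer and float types, false for bool; f16 and bf16 are not in std and are omitted).

```rust
/// Hypothetical CPU reference for torch.triu semantics (not the library kernel).
/// `x` holds `batch` row-major m x n matrices stored back to back.
/// Element (i, j) is kept when j - i >= diagonal; otherwise it becomes T::default().
fn triu_reference<T: Copy + Default>(
    x: &[T],
    batch: usize,
    m: usize,
    n: usize,
    diagonal: i64,
) -> Vec<T> {
    assert_eq!(x.len(), batch * m * n, "buffer length must equal batch * m * n");
    let mut y = Vec::with_capacity(x.len());
    for b in 0..batch {
        for i in 0..m {
            for j in 0..n {
                let v = x[(b * m + i) * n + j];
                // Pure select: copy the element or write zero, no arithmetic on values.
                if (j as i64) - (i as i64) >= diagonal {
                    y.push(v);
                } else {
                    y.push(T::default());
                }
            }
        }
    }
    y
}

/// Hypothetical backward reference: the gradient is the same mask applied to d_output.
fn triu_backward_reference<T: Copy + Default>(
    d_output: &[T],
    batch: usize,
    m: usize,
    n: usize,
    diagonal: i64,
) -> Vec<T> {
    triu_reference(d_output, batch, m, n, diagonal)
}

fn main() {
    let x: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    // One 3 x 3 matrix, main diagonal kept.
    println!("{:?}", triu_reference(&x, 1, 3, 3, 0));
    // prints [1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]
    let d_out = vec![1.0f32; 9];
    println!("{:?}", triu_backward_reference(&d_out, 1, 3, 3, 1));
}
```

Because the backward mask is identical to the forward mask, a gradient check reduces to comparing the backward output against the forward function applied to d_output.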
Dtypes: {f16, bf16, f32, f64, i32, i64, Bool}.
Shape limits: rank in [2, 8]; last two dims (M = shape[N-2],
N_cols = shape[N-1]) define the matrix; everything before is the
batch prefix.
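The shape rule above can be sketched as a helper that splits a shape into (batch, M, N_cols). The name split_triu_shape and the String error type are hypothetical; the sketch assumes the batch count is the product of the leading dims, which is 1 for a rank-2 tensor.

```rust
/// Hypothetical helper: split a shape into (batch, m, n_cols) following the
/// documented limits. Rank must be in [2, 8]; the last two dims form the
/// matrix and all leading dims form the batch prefix.
fn split_triu_shape(shape: &[usize]) -> Result<(usize, usize, usize), String> {
    let rank = shape.len();
    if !(2..=8).contains(&rank) {
        return Err(format!("rank {rank} is outside [2, 8]"));
    }
    let m = shape[rank - 2];
    let n_cols = shape[rank - 1];
    // Product over an empty batch prefix is 1.
    let batch: usize = shape[..rank - 2].iter().product();
    Ok((batch, m, n_cols))
}

fn main() {
    println!("{:?}", split_triu_shape(&[2, 3, 4, 5])); // prints Ok((6, 4, 5))
    println!("{:?}", split_triu_shape(&[7])); // prints Err("rank 1 is outside [2, 8]")
}
```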
Workspace: none.
Precision guarantee: deterministic, bit-stable, bit-exact. The kernel only selects each element or writes zero; it performs no arithmetic.
Implementations
impl<T: Element, const N: usize> TriuPlan<T, N>
pub fn select( _stream: &Stream, desc: &TriuDescriptor<N>, _pref: PlanPreference, ) -> Result<Self>
Selects a kernel for the problem described by desc.
pub fn can_implement(&self, args: &TriuArgs<'_, T, N>) -> Result<()>
Validates args against this plan, returning an error if the plan cannot run them.
pub fn workspace_size(&self) -> usize
Workspace size in bytes. Always 0.
pub fn precision_guarantee(&self) -> PrecisionGuarantee
Returns the plan's numerical guarantees (for triu: deterministic, bit-stable, bit-exact).
pub fn run( &self, stream: &Stream, _workspace: Workspace<'_>, args: TriuArgs<'_, T, N>, ) -> Result<()>
Launches the triu kernel on stream. The workspace argument is unused because the plan needs no workspace.
Dispatch policy: if both input and output are canonical
row-major contiguous, run routes to the contiguous fast path
(baracuda_kernels_triu_<dtype>_run). Otherwise it routes to the
strided sibling (baracuda_kernels_triu_<dtype>_strided_run,
Phase 14.3), which passes per-axis signed strides for input
and output through the kernel parameter block.
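The dispatch rule and the strided indexing it falls back to can be sketched on the CPU. Everything here is hypothetical: TriuPath, is_row_major_contiguous, choose_path, and triu_strided are illustrative names, strides are assumed to be in elements, and each tensor is assumed to carry a signed base offset so that negative strides stay in bounds. The contiguity check is strict (it does not special-case size-1 dims), which is an assumption about what "canonical" means here.

```rust
#[derive(Debug, PartialEq)]
enum TriuPath {
    Contiguous,
    Strided,
}

/// True when `strides` (in elements) exactly equal canonical row-major strides for `shape`.
fn is_row_major_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    if shape.len() != strides.len() {
        return false;
    }
    let mut expected: isize = 1;
    // Walk from the last axis (stride 1) toward the first.
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        if stride != expected {
            return false;
        }
        expected *= dim as isize;
    }
    true
}

/// Fast path only when both input and output are canonical row-major.
fn choose_path(shape: &[usize], in_strides: &[isize], out_strides: &[isize]) -> TriuPath {
    if is_row_major_contiguous(shape, in_strides) && is_row_major_contiguous(shape, out_strides) {
        TriuPath::Contiguous
    } else {
        TriuPath::Strided
    }
}

/// Strided triu: offsets are base + sum(index[a] * stride[a]) with signed strides.
#[allow(clippy::too_many_arguments)]
fn triu_strided<T: Copy + Default>(
    x: &[T], x_offset: isize, x_strides: &[isize],
    y: &mut [T], y_offset: isize, y_strides: &[isize],
    shape: &[usize], diagonal: i64,
) {
    let rank = shape.len();
    assert!(rank >= 2 && x_strides.len() == rank && y_strides.len() == rank);
    if shape.iter().any(|&d| d == 0) {
        return; // empty tensor: nothing to write
    }
    let mut idx = vec![0usize; rank];
    loop {
        let (mut xo, mut yo) = (x_offset, y_offset);
        for a in 0..rank {
            xo += idx[a] as isize * x_strides[a];
            yo += idx[a] as isize * y_strides[a];
        }
        let (i, j) = (idx[rank - 2] as i64, idx[rank - 1] as i64);
        y[yo as usize] = if j - i >= diagonal { x[xo as usize] } else { T::default() };
        // Advance the multi-index like an odometer, last axis fastest.
        let mut a = rank;
        loop {
            if a == 0 {
                return;
            }
            a -= 1;
            idx[a] += 1;
            if idx[a] < shape[a] {
                break;
            }
            idx[a] = 0;
        }
    }
}

fn main() {
    // A 3 x 3 row-major buffer viewed with rows flipped: base offset 6, strides [-3, 1].
    let x: Vec<i32> = (1..=9).collect();
    let mut y = vec![0; 9];
    println!("{:?}", choose_path(&[3, 3], &[-3, 1], &[3, 1])); // prints Strided
    triu_strided(&x, 6, &[-3, 1], &mut y, 0, &[3, 1], &[3, 3], 0);
    println!("{:?}", y); // prints [7, 8, 9, 0, 5, 6, 0, 0, 3]
}
```

Carrying a base offset alongside signed strides is one common way to express flipped views without copying; whether the real parameter block uses the same layout is not stated here.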