Skip to main content

baracuda_kernels/shape_layout/
triu.rs

1//! `triu` plan — upper-triangular matrix mask (Phase 13.4).
2//!
3//! `output[..., i, j] = input[..., i, j] if j >= i + diagonal else 0`.
4//! Operates on the last two dims of a rank-≥2 tensor; the batch prefix
5//! (anything before the matrix axes) is masked independently with the
6//! same `diagonal`. Output shape equals input shape.
7//!
8//! - `diagonal == 0`: main diagonal.
//! - `diagonal > 0`: shift the boundary UP (triu keeps fewer elements).
//! - `diagonal < 0`: shift the boundary DOWN (triu keeps more elements).
11//!
12//! Driven by Fuel team's CPU-only triu/tril gap. Pair with
13//! [`TriuBackwardPlan`](crate::TriuBackwardPlan) — the BW applies the
14//! same mask to the upstream gradient (`d_input = triu(d_output,
15//! diagonal)`) so it reuses the FW launch symbol.
16
17use core::ffi::c_void;
18use core::marker::PhantomData;
19
20use baracuda_cutlass::{Error, Result};
21use baracuda_driver::Stream;
22use baracuda_kernels_types::{
23    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
24    PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
25};
26
27/// Descriptor for a `triu` op.
28///
29/// `shape` is the logical tensor shape (input and output agree).
30/// `diagonal` shifts the mask boundary — `0` is the main diagonal,
31/// positive shifts up (more zeros below), negative shifts down (more
32/// elements kept).
33#[derive(Copy, Clone, Debug)]
34pub struct TriuDescriptor<const N: usize> {
35    /// Logical tensor shape. `N >= 2` enforced at `select` time.
36    pub shape: [i32; N],
37    /// Diagonal offset. `0` == main diagonal.
38    pub diagonal: i32,
39    /// Element type.
40    pub element: ElementKind,
41}
42
43/// Args bundle for a Triu launch.
44pub struct TriuArgs<'a, T: Element, const N: usize> {
45    /// Input — same shape as output.
46    pub input: TensorRef<'a, T, N>,
47    /// Output — same shape as input. Off-diagonal positions are zeroed.
48    pub output: TensorMut<'a, T, N>,
49}
50
51/// `triu` plan.
52///
53/// `y = torch.triu(x, diagonal)` — upper-triangular mask on the last
54/// two dims of `x`.
55///
56/// **When to use**: forward triu. Pair with
57/// [`TriuBackwardPlan`](crate::TriuBackwardPlan) — the BW is the same
58/// mask applied to `d_output`.
59///
60/// **Dtypes**: `{f16, bf16, f32, f64, i32, i64, Bool}`.
61///
62/// **Shape limits**: rank in `[2, 8]`; last two dims (`M = shape[N-2]`,
63/// `N_cols = shape[N-1]`) define the matrix; everything before is the
64/// batch prefix.
65///
66/// **Workspace**: none.
67///
68/// **Precision guarantee**: deterministic, bit-stable, bit-exact —
69/// pure element select + zero, no arithmetic.
70pub struct TriuPlan<T: Element, const N: usize> {
71    desc: TriuDescriptor<N>,
72    sku: KernelSku,
73    _marker: PhantomData<T>,
74}
75
76impl<T: Element, const N: usize> TriuPlan<T, N> {
77    /// Pick a kernel for `desc`.
78    pub fn select(
79        _stream: &Stream,
80        desc: &TriuDescriptor<N>,
81        _pref: PlanPreference,
82    ) -> Result<Self> {
83        if desc.element != T::KIND {
84            return Err(Error::Unsupported(
85                "baracuda-kernels::TriuPlan: descriptor element != type parameter T",
86            ));
87        }
88        if N < 2 {
89            return Err(Error::InvalidProblem(
90                "baracuda-kernels::TriuPlan: tensor rank must be >= 2 \
91                 (need at least an (M, N) matrix to mask)",
92            ));
93        }
94        if N > 8 {
95            return Err(Error::Unsupported(
96                "baracuda-kernels::TriuPlan: tensor rank > 8 not supported",
97            ));
98        }
99        for &d in desc.shape.iter() {
100            if d < 0 {
101                return Err(Error::InvalidProblem(
102                    "baracuda-kernels::TriuPlan: shape dims must be non-negative",
103                ));
104            }
105        }
106        if !dtype_in_scope(T::KIND) {
107            return Err(Error::Unsupported(
108                "baracuda-kernels::TriuPlan: dtype not wired; supported set is \
109                 {f16, bf16, f32, f64, i32, i64, Bool}",
110            ));
111        }
112        let precision_guarantee = PrecisionGuarantee {
113            math_precision: MathPrecision::F32,
114            accumulator: ElementKind::F32,
115            // Pure select-or-zero — no arithmetic.
116            bit_stable_on_same_hardware: true,
117            deterministic: true,
118        };
119        let sku = KernelSku {
120            category: OpCategory::ShapeLayout,
121            op: ShapeLayoutKind::Triu as u16,
122            element: T::KIND,
123            aux_element: None,
124            layout: None,
125            epilogue: None,
126            arch: ArchSku::Sm80,
127            backend: BackendKind::Bespoke,
128            precision_guarantee,
129        };
130        Ok(Self {
131            desc: *desc,
132            sku,
133            _marker: PhantomData,
134        })
135    }
136
137    /// Validate args.
138    pub fn can_implement(&self, args: &TriuArgs<'_, T, N>) -> Result<()> {
139        if args.input.shape != self.desc.shape {
140            return Err(Error::InvalidProblem(
141                "baracuda-kernels::TriuPlan: input shape mismatch with descriptor",
142            ));
143        }
144        if args.output.shape != self.desc.shape {
145            return Err(Error::InvalidProblem(
146                "baracuda-kernels::TriuPlan: output shape mismatch with descriptor",
147            ));
148        }
149        let numel = args.output.numel();
150        let in_len = args.input.data.len() as i64;
151        let out_len = args.output.data.len() as i64;
152        if in_len < numel || out_len < numel {
153            return Err(Error::BufferTooSmall {
154                needed: numel as usize,
155                got: in_len.min(out_len) as usize,
156            });
157        }
158        Ok(())
159    }
160
161    /// Workspace size in bytes. Always `0`.
162    #[inline]
163    pub fn workspace_size(&self) -> usize {
164        0
165    }
166    /// Identity of the kernel this plan picked.
167    #[inline]
168    pub fn sku(&self) -> KernelSku {
169        self.sku
170    }
171    /// Numerical guarantees.
172    #[inline]
173    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
174        self.sku.precision_guarantee
175    }
176
177    /// Launch.
178    ///
179    /// Dispatch policy: if both `input` and `output` are canonical
180    /// row-major contiguous, route to the contig fast path
181    /// (`baracuda_kernels_triu_<dtype>_run`). Otherwise route to the
182    /// strided sibling (`baracuda_kernels_triu_<dtype>_strided_run`,
183    /// Phase 14.3) which threads per-axis signed strides for input
184    /// and output through the kernel parameter block.
185    pub fn run(
186        &self,
187        stream: &Stream,
188        _workspace: Workspace<'_>,
189        args: TriuArgs<'_, T, N>,
190    ) -> Result<()> {
191        self.can_implement(&args)?;
192        let numel = args.output.numel();
193        if numel == 0 {
194            return Ok(());
195        }
196        let input_ptr = args.input.data.as_raw().0 as *const c_void;
197        let output_ptr = args.output.data.as_raw().0 as *mut c_void;
198        let stream_ptr = stream.as_raw() as *mut c_void;
199        let shape = self.desc.shape;
200        let rank = N as i32;
201        let diagonal = self.desc.diagonal;
202
203        // Canonical-contig fast path: both operands canonical row-major.
204        // Any other layout (transposed, sliced, broadcast, flipped)
205        // routes to the strided sibling.
206        let all_contig = args.input.is_contiguous() && args.output.is_contiguous();
207
208        if !all_contig {
209            let stride_x = args.input.stride;
210            let stride_y = args.output.stride;
211            let status = match T::KIND {
212                ElementKind::F16 => unsafe {
213                    baracuda_kernels_sys::baracuda_kernels_triu_f16_strided_run(
214                        input_ptr, output_ptr, shape.as_ptr(), rank,
215                        stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
216                    )
217                },
218                ElementKind::Bf16 => unsafe {
219                    baracuda_kernels_sys::baracuda_kernels_triu_bf16_strided_run(
220                        input_ptr, output_ptr, shape.as_ptr(), rank,
221                        stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
222                    )
223                },
224                ElementKind::F32 => unsafe {
225                    baracuda_kernels_sys::baracuda_kernels_triu_f32_strided_run(
226                        input_ptr, output_ptr, shape.as_ptr(), rank,
227                        stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
228                    )
229                },
230                ElementKind::F64 => unsafe {
231                    baracuda_kernels_sys::baracuda_kernels_triu_f64_strided_run(
232                        input_ptr, output_ptr, shape.as_ptr(), rank,
233                        stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
234                    )
235                },
236                ElementKind::I32 => unsafe {
237                    baracuda_kernels_sys::baracuda_kernels_triu_i32_strided_run(
238                        input_ptr, output_ptr, shape.as_ptr(), rank,
239                        stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
240                    )
241                },
242                ElementKind::I64 => unsafe {
243                    baracuda_kernels_sys::baracuda_kernels_triu_i64_strided_run(
244                        input_ptr, output_ptr, shape.as_ptr(), rank,
245                        stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
246                    )
247                },
248                ElementKind::Bool => unsafe {
249                    baracuda_kernels_sys::baracuda_kernels_triu_bool_strided_run(
250                        input_ptr, output_ptr, shape.as_ptr(), rank,
251                        stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
252                    )
253                },
254                _ => {
255                    return Err(Error::Unsupported(
256                        "baracuda-kernels::TriuPlan::run: dtype not wired (strided) \
257                         (should have been rejected at select())",
258                    ));
259                }
260            };
261            return map_status(status);
262        }
263
264        let status = match T::KIND {
265            ElementKind::F16 => unsafe {
266                baracuda_kernels_sys::baracuda_kernels_triu_f16_run(
267                    input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
268                )
269            },
270            ElementKind::Bf16 => unsafe {
271                baracuda_kernels_sys::baracuda_kernels_triu_bf16_run(
272                    input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
273                )
274            },
275            ElementKind::F32 => unsafe {
276                baracuda_kernels_sys::baracuda_kernels_triu_f32_run(
277                    input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
278                )
279            },
280            ElementKind::F64 => unsafe {
281                baracuda_kernels_sys::baracuda_kernels_triu_f64_run(
282                    input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
283                )
284            },
285            ElementKind::I32 => unsafe {
286                baracuda_kernels_sys::baracuda_kernels_triu_i32_run(
287                    input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
288                )
289            },
290            ElementKind::I64 => unsafe {
291                baracuda_kernels_sys::baracuda_kernels_triu_i64_run(
292                    input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
293                )
294            },
295            ElementKind::Bool => unsafe {
296                baracuda_kernels_sys::baracuda_kernels_triu_bool_run(
297                    input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
298                )
299            },
300            _ => {
301                return Err(Error::Unsupported(
302                    "baracuda-kernels::TriuPlan::run: dtype not wired \
303                     (should have been rejected at select())",
304                ));
305            }
306        };
307        map_status(status)
308    }
309}
310
311fn dtype_in_scope(k: ElementKind) -> bool {
312    matches!(
313        k,
314        ElementKind::F16
315            | ElementKind::Bf16
316            | ElementKind::F32
317            | ElementKind::F64
318            | ElementKind::I32
319            | ElementKind::I64
320            | ElementKind::Bool
321    )
322}
323
324fn map_status(code: i32) -> Result<()> {
325    match code {
326        0 => Ok(()),
327        1 => Err(Error::MisalignedOperand),
328        2 => Err(Error::InvalidProblem(
329            "baracuda-kernels-sys reported invalid problem",
330        )),
331        3 => Err(Error::Unsupported(
332            "baracuda-kernels-sys reported unsupported configuration",
333        )),
334        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
335        n => Err(Error::CutlassInternal(n)),
336    }
337}