1use core::ffi::c_void;
18use core::marker::PhantomData;
19
20use baracuda_cutlass::{Error, Result};
21use baracuda_driver::Stream;
22use baracuda_kernels_types::{
23 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
24 PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
25};
26
27#[derive(Copy, Clone, Debug)]
34pub struct TriuDescriptor<const N: usize> {
35 pub shape: [i32; N],
37 pub diagonal: i32,
39 pub element: ElementKind,
41}
42
43pub struct TriuArgs<'a, T: Element, const N: usize> {
45 pub input: TensorRef<'a, T, N>,
47 pub output: TensorMut<'a, T, N>,
49}
50
51pub struct TriuPlan<T: Element, const N: usize> {
71 desc: TriuDescriptor<N>,
72 sku: KernelSku,
73 _marker: PhantomData<T>,
74}
75
76impl<T: Element, const N: usize> TriuPlan<T, N> {
77 pub fn select(
79 _stream: &Stream,
80 desc: &TriuDescriptor<N>,
81 _pref: PlanPreference,
82 ) -> Result<Self> {
83 if desc.element != T::KIND {
84 return Err(Error::Unsupported(
85 "baracuda-kernels::TriuPlan: descriptor element != type parameter T",
86 ));
87 }
88 if N < 2 {
89 return Err(Error::InvalidProblem(
90 "baracuda-kernels::TriuPlan: tensor rank must be >= 2 \
91 (need at least an (M, N) matrix to mask)",
92 ));
93 }
94 if N > 8 {
95 return Err(Error::Unsupported(
96 "baracuda-kernels::TriuPlan: tensor rank > 8 not supported",
97 ));
98 }
99 for &d in desc.shape.iter() {
100 if d < 0 {
101 return Err(Error::InvalidProblem(
102 "baracuda-kernels::TriuPlan: shape dims must be non-negative",
103 ));
104 }
105 }
106 if !dtype_in_scope(T::KIND) {
107 return Err(Error::Unsupported(
108 "baracuda-kernels::TriuPlan: dtype not wired; supported set is \
109 {f16, bf16, f32, f64, i32, i64, Bool}",
110 ));
111 }
112 let precision_guarantee = PrecisionGuarantee {
113 math_precision: MathPrecision::F32,
114 accumulator: ElementKind::F32,
115 bit_stable_on_same_hardware: true,
117 deterministic: true,
118 };
119 let sku = KernelSku {
120 category: OpCategory::ShapeLayout,
121 op: ShapeLayoutKind::Triu as u16,
122 element: T::KIND,
123 aux_element: None,
124 layout: None,
125 epilogue: None,
126 arch: ArchSku::Sm80,
127 backend: BackendKind::Bespoke,
128 precision_guarantee,
129 };
130 Ok(Self {
131 desc: *desc,
132 sku,
133 _marker: PhantomData,
134 })
135 }
136
137 pub fn can_implement(&self, args: &TriuArgs<'_, T, N>) -> Result<()> {
139 if args.input.shape != self.desc.shape {
140 return Err(Error::InvalidProblem(
141 "baracuda-kernels::TriuPlan: input shape mismatch with descriptor",
142 ));
143 }
144 if args.output.shape != self.desc.shape {
145 return Err(Error::InvalidProblem(
146 "baracuda-kernels::TriuPlan: output shape mismatch with descriptor",
147 ));
148 }
149 let numel = args.output.numel();
150 let in_len = args.input.data.len() as i64;
151 let out_len = args.output.data.len() as i64;
152 if in_len < numel || out_len < numel {
153 return Err(Error::BufferTooSmall {
154 needed: numel as usize,
155 got: in_len.min(out_len) as usize,
156 });
157 }
158 Ok(())
159 }
160
161 #[inline]
163 pub fn workspace_size(&self) -> usize {
164 0
165 }
166 #[inline]
168 pub fn sku(&self) -> KernelSku {
169 self.sku
170 }
171 #[inline]
173 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
174 self.sku.precision_guarantee
175 }
176
177 pub fn run(
186 &self,
187 stream: &Stream,
188 _workspace: Workspace<'_>,
189 args: TriuArgs<'_, T, N>,
190 ) -> Result<()> {
191 self.can_implement(&args)?;
192 let numel = args.output.numel();
193 if numel == 0 {
194 return Ok(());
195 }
196 let input_ptr = args.input.data.as_raw().0 as *const c_void;
197 let output_ptr = args.output.data.as_raw().0 as *mut c_void;
198 let stream_ptr = stream.as_raw() as *mut c_void;
199 let shape = self.desc.shape;
200 let rank = N as i32;
201 let diagonal = self.desc.diagonal;
202
203 let all_contig = args.input.is_contiguous() && args.output.is_contiguous();
207
208 if !all_contig {
209 let stride_x = args.input.stride;
210 let stride_y = args.output.stride;
211 let status = match T::KIND {
212 ElementKind::F16 => unsafe {
213 baracuda_kernels_sys::baracuda_kernels_triu_f16_strided_run(
214 input_ptr, output_ptr, shape.as_ptr(), rank,
215 stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
216 )
217 },
218 ElementKind::Bf16 => unsafe {
219 baracuda_kernels_sys::baracuda_kernels_triu_bf16_strided_run(
220 input_ptr, output_ptr, shape.as_ptr(), rank,
221 stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
222 )
223 },
224 ElementKind::F32 => unsafe {
225 baracuda_kernels_sys::baracuda_kernels_triu_f32_strided_run(
226 input_ptr, output_ptr, shape.as_ptr(), rank,
227 stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
228 )
229 },
230 ElementKind::F64 => unsafe {
231 baracuda_kernels_sys::baracuda_kernels_triu_f64_strided_run(
232 input_ptr, output_ptr, shape.as_ptr(), rank,
233 stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
234 )
235 },
236 ElementKind::I32 => unsafe {
237 baracuda_kernels_sys::baracuda_kernels_triu_i32_strided_run(
238 input_ptr, output_ptr, shape.as_ptr(), rank,
239 stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
240 )
241 },
242 ElementKind::I64 => unsafe {
243 baracuda_kernels_sys::baracuda_kernels_triu_i64_strided_run(
244 input_ptr, output_ptr, shape.as_ptr(), rank,
245 stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
246 )
247 },
248 ElementKind::Bool => unsafe {
249 baracuda_kernels_sys::baracuda_kernels_triu_bool_strided_run(
250 input_ptr, output_ptr, shape.as_ptr(), rank,
251 stride_x.as_ptr(), stride_y.as_ptr(), diagonal, stream_ptr,
252 )
253 },
254 _ => {
255 return Err(Error::Unsupported(
256 "baracuda-kernels::TriuPlan::run: dtype not wired (strided) \
257 (should have been rejected at select())",
258 ));
259 }
260 };
261 return map_status(status);
262 }
263
264 let status = match T::KIND {
265 ElementKind::F16 => unsafe {
266 baracuda_kernels_sys::baracuda_kernels_triu_f16_run(
267 input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
268 )
269 },
270 ElementKind::Bf16 => unsafe {
271 baracuda_kernels_sys::baracuda_kernels_triu_bf16_run(
272 input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
273 )
274 },
275 ElementKind::F32 => unsafe {
276 baracuda_kernels_sys::baracuda_kernels_triu_f32_run(
277 input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
278 )
279 },
280 ElementKind::F64 => unsafe {
281 baracuda_kernels_sys::baracuda_kernels_triu_f64_run(
282 input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
283 )
284 },
285 ElementKind::I32 => unsafe {
286 baracuda_kernels_sys::baracuda_kernels_triu_i32_run(
287 input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
288 )
289 },
290 ElementKind::I64 => unsafe {
291 baracuda_kernels_sys::baracuda_kernels_triu_i64_run(
292 input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
293 )
294 },
295 ElementKind::Bool => unsafe {
296 baracuda_kernels_sys::baracuda_kernels_triu_bool_run(
297 input_ptr, output_ptr, shape.as_ptr(), rank, diagonal, stream_ptr,
298 )
299 },
300 _ => {
301 return Err(Error::Unsupported(
302 "baracuda-kernels::TriuPlan::run: dtype not wired \
303 (should have been rejected at select())",
304 ));
305 }
306 };
307 map_status(status)
308 }
309}
310
311fn dtype_in_scope(k: ElementKind) -> bool {
312 matches!(
313 k,
314 ElementKind::F16
315 | ElementKind::Bf16
316 | ElementKind::F32
317 | ElementKind::F64
318 | ElementKind::I32
319 | ElementKind::I64
320 | ElementKind::Bool
321 )
322}
323
324fn map_status(code: i32) -> Result<()> {
325 match code {
326 0 => Ok(()),
327 1 => Err(Error::MisalignedOperand),
328 2 => Err(Error::InvalidProblem(
329 "baracuda-kernels-sys reported invalid problem",
330 )),
331 3 => Err(Error::Unsupported(
332 "baracuda-kernels-sys reported unsupported configuration",
333 )),
334 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
335 n => Err(Error::CutlassInternal(n)),
336 }
337}