1use core::ffi::c_void;
15use core::marker::PhantomData;
16
17use baracuda_cutlass::{Error, Result};
18use baracuda_driver::Stream;
19use baracuda_kernels_types::{
20 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
21 PlanPreference, PrecisionGuarantee, ReduceKind, TensorMut, TensorRef, Workspace,
22};
23
24#[derive(Copy, Clone, Debug)]
30pub struct TraceDescriptor {
31 pub n: i32,
33 pub element: ElementKind,
35}
36
37pub struct TraceArgs<'a, T: Element> {
42 pub x: TensorRef<'a, T, 2>,
44 pub y: TensorMut<'a, T, 0>,
46}
47
48pub struct TracePlan<T: Element> {
53 desc: TraceDescriptor,
54 sku: KernelSku,
55 _marker: PhantomData<T>,
56}
57
58impl<T: Element> TracePlan<T> {
59 pub fn select(
61 _stream: &Stream,
62 desc: &TraceDescriptor,
63 _pref: PlanPreference,
64 ) -> Result<Self> {
65 if desc.element != T::KIND {
66 return Err(Error::Unsupported(
67 "baracuda-kernels::TracePlan: descriptor element != type parameter T",
68 ));
69 }
70 if desc.n < 0 {
71 return Err(Error::InvalidProblem(
72 "baracuda-kernels::TracePlan: n must be non-negative",
73 ));
74 }
75 let dtype_in_scope = matches!(
76 T::KIND,
77 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
78 );
79 if !dtype_in_scope {
80 return Err(Error::Unsupported(
81 "baracuda-kernels::TracePlan: supported dtypes are \
82 {f32, f16, bf16, f64}; other dtypes land in later fanout",
83 ));
84 }
85 let precision_guarantee = PrecisionGuarantee {
89 math_precision: MathPrecision::F32,
90 accumulator: ElementKind::F32,
91 bit_stable_on_same_hardware: true,
92 deterministic: true,
93 };
94 let sku = KernelSku {
95 category: OpCategory::Reduction,
96 op: ReduceKind::Trace as u16,
97 element: T::KIND,
98 aux_element: None,
99 layout: None,
100 epilogue: None,
101 arch: ArchSku::Sm80,
102 backend: BackendKind::Bespoke,
103 precision_guarantee,
104 };
105 Ok(Self {
106 desc: *desc,
107 sku,
108 _marker: PhantomData,
109 })
110 }
111
112 pub fn can_implement(&self, args: &TraceArgs<'_, T>) -> Result<()> {
114 if args.x.shape != [self.desc.n, self.desc.n] {
115 return Err(Error::InvalidProblem(
116 "baracuda-kernels::TracePlan: X shape must be [n, n] (square)",
117 ));
118 }
119 let y_shape: [i32; 0] = args.y.shape;
123 let _expected: [i32; 0] = [];
124 if y_shape != _expected {
125 return Err(Error::InvalidProblem(
126 "baracuda-kernels::TracePlan: Y must be a rank-0 scalar (empty shape)",
127 ));
128 }
129 let n = self.desc.n as i64;
130 let x_needed = n.saturating_mul(n);
131 let x_len = args.x.data.len() as i64;
132 if x_len < x_needed {
133 return Err(Error::BufferTooSmall {
134 needed: x_needed as usize,
135 got: x_len as usize,
136 });
137 }
138 if (args.y.data.len() as i64) < 1 {
139 return Err(Error::BufferTooSmall {
140 needed: 1,
141 got: args.y.data.len(),
142 });
143 }
144 Ok(())
145 }
146
147 #[inline]
149 pub fn workspace_size(&self) -> usize {
150 0
151 }
152
153 #[inline]
155 pub fn sku(&self) -> KernelSku {
156 self.sku
157 }
158
159 #[inline]
161 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
162 self.sku.precision_guarantee
163 }
164
165 pub fn run(
167 &self,
168 stream: &Stream,
169 _workspace: Workspace<'_>,
170 args: TraceArgs<'_, T>,
171 ) -> Result<()> {
172 self.can_implement(&args)?;
173 if self.desc.n == 0 {
174 }
177 let x_ptr = args.x.data.as_raw().0 as *const c_void;
178 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
179 let stream_ptr = stream.as_raw() as *mut c_void;
180 let n = self.desc.n;
181 let stride_row = args.x.stride[0];
182 let stride_col = args.x.stride[1];
183
184 macro_rules! dispatch {
185 ($sym:ident) => {{
186 unsafe {
187 baracuda_kernels_sys::$sym(
188 n,
189 stride_row,
190 stride_col,
191 x_ptr,
192 y_ptr,
193 core::ptr::null_mut(),
194 0,
195 stream_ptr,
196 )
197 }
198 }};
199 }
200
201 let status = match T::KIND {
202 ElementKind::F32 => dispatch!(baracuda_kernels_trace_f32_run),
203 ElementKind::F16 => dispatch!(baracuda_kernels_trace_f16_run),
204 ElementKind::Bf16 => dispatch!(baracuda_kernels_trace_bf16_run),
205 ElementKind::F64 => dispatch!(baracuda_kernels_trace_f64_run),
206 _ => {
207 return Err(Error::Unsupported(
208 "baracuda-kernels::TracePlan::run: dtype not wired",
209 ));
210 }
211 };
212 map_status(status)
213 }
214}
215
216fn map_status(code: i32) -> Result<()> {
217 match code {
218 0 => Ok(()),
219 1 => Err(Error::MisalignedOperand),
220 2 => Err(Error::InvalidProblem(
221 "baracuda-kernels-sys reported invalid problem",
222 )),
223 3 => Err(Error::Unsupported(
224 "baracuda-kernels-sys reported unsupported configuration",
225 )),
226 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
227 n => Err(Error::CutlassInternal(n)),
228 }
229}