1use core::ffi::c_void;
21use core::marker::PhantomData;
22
23use baracuda_cutlass::{Error, Result};
24use baracuda_driver::Stream;
25use baracuda_kernels_types::{
26 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
27 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
28};
29
30#[derive(Copy, Clone, Debug)]
37pub struct CastDescriptor {
38 pub numel: i32,
40 pub input_element: ElementKind,
42 pub output_element: ElementKind,
44}
45
46pub struct CastArgs<'a, TIn: Element, TOut: Element> {
49 pub input: TensorRef<'a, TIn, 1>,
51 pub output: TensorMut<'a, TOut, 1>,
53}
54
55pub struct CastPlan<TIn: Element, TOut: Element> {
59 desc: CastDescriptor,
60 sku: KernelSku,
61 _marker_in: PhantomData<TIn>,
62 _marker_out: PhantomData<TOut>,
63}
64
65impl<TIn: Element, TOut: Element> CastPlan<TIn, TOut> {
66 pub fn select(
68 _stream: &Stream,
69 desc: &CastDescriptor,
70 _pref: PlanPreference,
71 ) -> Result<Self> {
72 if desc.input_element != TIn::KIND {
73 return Err(Error::Unsupported(
74 "baracuda-kernels::CastPlan: descriptor input_element != type parameter TIn",
75 ));
76 }
77 if desc.output_element != TOut::KIND {
78 return Err(Error::Unsupported(
79 "baracuda-kernels::CastPlan: descriptor output_element != type parameter TOut",
80 ));
81 }
82 if desc.numel < 0 {
83 return Err(Error::InvalidProblem(
84 "baracuda-kernels::CastPlan: numel must be non-negative",
85 ));
86 }
87 if !pair_in_scope(TIn::KIND, TOut::KIND) {
88 return Err(Error::Unsupported(
89 "baracuda-kernels::CastPlan: this (TIn, TOut) pair is not wired today; \
90 supported set is {f32, f64, f16, bf16, i32, i64} × {same}",
91 ));
92 }
93
94 let precision_guarantee = PrecisionGuarantee {
99 math_precision: MathPrecision::F32,
100 accumulator: ElementKind::F32,
101 bit_stable_on_same_hardware: true,
102 deterministic: true,
103 };
104 let sku = KernelSku {
105 category: OpCategory::UnaryElementwise,
106 op: UnaryKind::Cast as u16,
107 element: TIn::KIND,
108 aux_element: Some(TOut::KIND),
109 layout: None,
110 epilogue: None,
111 arch: ArchSku::Sm80,
112 backend: BackendKind::Bespoke,
113 precision_guarantee,
114 };
115 Ok(Self {
116 desc: *desc,
117 sku,
118 _marker_in: PhantomData,
119 _marker_out: PhantomData,
120 })
121 }
122
123 pub fn can_implement(&self, args: &CastArgs<'_, TIn, TOut>) -> Result<()> {
125 let expected = self.desc.numel as i64;
126 if args.input.numel() != expected {
127 return Err(Error::InvalidProblem(
128 "baracuda-kernels::CastPlan: input numel mismatch with descriptor",
129 ));
130 }
131 if args.output.numel() != expected {
132 return Err(Error::InvalidProblem(
133 "baracuda-kernels::CastPlan: output numel mismatch with descriptor",
134 ));
135 }
136 if (args.input.data.len() as i64) < expected {
137 return Err(Error::BufferTooSmall {
138 needed: expected as usize,
139 got: args.input.data.len(),
140 });
141 }
142 if (args.output.data.len() as i64) < expected {
143 return Err(Error::BufferTooSmall {
144 needed: expected as usize,
145 got: args.output.data.len(),
146 });
147 }
148 Ok(())
149 }
150
151 #[inline]
153 pub fn workspace_size(&self) -> usize {
154 0
155 }
156
157 #[inline]
159 pub fn sku(&self) -> KernelSku {
160 self.sku
161 }
162
163 #[inline]
165 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
166 self.sku.precision_guarantee
167 }
168
169 pub fn run(
171 &self,
172 stream: &Stream,
173 _workspace: Workspace<'_>,
174 args: CastArgs<'_, TIn, TOut>,
175 ) -> Result<()> {
176 self.can_implement(&args)?;
177 let numel = self.desc.numel as i64;
178 if numel == 0 {
179 return Ok(());
180 }
181 let x_ptr = args.input.data.as_raw().0 as *const c_void;
182 let y_ptr = args.output.data.as_raw().0 as *mut c_void;
183 let stream_ptr = stream.as_raw() as *mut c_void;
184
185 let status = match (TIn::KIND, TOut::KIND) {
190 (ElementKind::F32, ElementKind::F32) => unsafe {
192 baracuda_kernels_sys::baracuda_kernels_cast_f32_f32_run(
193 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
194 )
195 },
196 (ElementKind::F32, ElementKind::F64) => unsafe {
197 baracuda_kernels_sys::baracuda_kernels_cast_f32_f64_run(
198 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
199 )
200 },
201 (ElementKind::F32, ElementKind::F16) => unsafe {
202 baracuda_kernels_sys::baracuda_kernels_cast_f32_f16_run(
203 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
204 )
205 },
206 (ElementKind::F32, ElementKind::Bf16) => unsafe {
207 baracuda_kernels_sys::baracuda_kernels_cast_f32_bf16_run(
208 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
209 )
210 },
211 (ElementKind::F32, ElementKind::I32) => unsafe {
212 baracuda_kernels_sys::baracuda_kernels_cast_f32_i32_run(
213 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
214 )
215 },
216 (ElementKind::F32, ElementKind::I64) => unsafe {
217 baracuda_kernels_sys::baracuda_kernels_cast_f32_i64_run(
218 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
219 )
220 },
221 (ElementKind::F64, ElementKind::F32) => unsafe {
223 baracuda_kernels_sys::baracuda_kernels_cast_f64_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
224 },
225 (ElementKind::F64, ElementKind::F64) => unsafe {
226 baracuda_kernels_sys::baracuda_kernels_cast_f64_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
227 },
228 (ElementKind::F64, ElementKind::F16) => unsafe {
229 baracuda_kernels_sys::baracuda_kernels_cast_f64_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
230 },
231 (ElementKind::F64, ElementKind::Bf16) => unsafe {
232 baracuda_kernels_sys::baracuda_kernels_cast_f64_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
233 },
234 (ElementKind::F64, ElementKind::I32) => unsafe {
235 baracuda_kernels_sys::baracuda_kernels_cast_f64_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
236 },
237 (ElementKind::F64, ElementKind::I64) => unsafe {
238 baracuda_kernels_sys::baracuda_kernels_cast_f64_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
239 },
240 (ElementKind::F16, ElementKind::F32) => unsafe {
242 baracuda_kernels_sys::baracuda_kernels_cast_f16_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
243 },
244 (ElementKind::F16, ElementKind::F64) => unsafe {
245 baracuda_kernels_sys::baracuda_kernels_cast_f16_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
246 },
247 (ElementKind::F16, ElementKind::F16) => unsafe {
248 baracuda_kernels_sys::baracuda_kernels_cast_f16_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
249 },
250 (ElementKind::F16, ElementKind::Bf16) => unsafe {
251 baracuda_kernels_sys::baracuda_kernels_cast_f16_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
252 },
253 (ElementKind::F16, ElementKind::I32) => unsafe {
254 baracuda_kernels_sys::baracuda_kernels_cast_f16_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
255 },
256 (ElementKind::F16, ElementKind::I64) => unsafe {
257 baracuda_kernels_sys::baracuda_kernels_cast_f16_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
258 },
259 (ElementKind::Bf16, ElementKind::F32) => unsafe {
261 baracuda_kernels_sys::baracuda_kernels_cast_bf16_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
262 },
263 (ElementKind::Bf16, ElementKind::F64) => unsafe {
264 baracuda_kernels_sys::baracuda_kernels_cast_bf16_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
265 },
266 (ElementKind::Bf16, ElementKind::F16) => unsafe {
267 baracuda_kernels_sys::baracuda_kernels_cast_bf16_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
268 },
269 (ElementKind::Bf16, ElementKind::Bf16) => unsafe {
270 baracuda_kernels_sys::baracuda_kernels_cast_bf16_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
271 },
272 (ElementKind::Bf16, ElementKind::I32) => unsafe {
273 baracuda_kernels_sys::baracuda_kernels_cast_bf16_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
274 },
275 (ElementKind::Bf16, ElementKind::I64) => unsafe {
276 baracuda_kernels_sys::baracuda_kernels_cast_bf16_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
277 },
278 (ElementKind::I32, ElementKind::F32) => unsafe {
280 baracuda_kernels_sys::baracuda_kernels_cast_i32_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
281 },
282 (ElementKind::I32, ElementKind::F64) => unsafe {
283 baracuda_kernels_sys::baracuda_kernels_cast_i32_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
284 },
285 (ElementKind::I32, ElementKind::F16) => unsafe {
286 baracuda_kernels_sys::baracuda_kernels_cast_i32_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
287 },
288 (ElementKind::I32, ElementKind::Bf16) => unsafe {
289 baracuda_kernels_sys::baracuda_kernels_cast_i32_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
290 },
291 (ElementKind::I32, ElementKind::I32) => unsafe {
292 baracuda_kernels_sys::baracuda_kernels_cast_i32_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
293 },
294 (ElementKind::I32, ElementKind::I64) => unsafe {
295 baracuda_kernels_sys::baracuda_kernels_cast_i32_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
296 },
297 (ElementKind::I64, ElementKind::F32) => unsafe {
299 baracuda_kernels_sys::baracuda_kernels_cast_i64_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
300 },
301 (ElementKind::I64, ElementKind::F64) => unsafe {
302 baracuda_kernels_sys::baracuda_kernels_cast_i64_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
303 },
304 (ElementKind::I64, ElementKind::F16) => unsafe {
305 baracuda_kernels_sys::baracuda_kernels_cast_i64_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
306 },
307 (ElementKind::I64, ElementKind::Bf16) => unsafe {
308 baracuda_kernels_sys::baracuda_kernels_cast_i64_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
309 },
310 (ElementKind::I64, ElementKind::I32) => unsafe {
311 baracuda_kernels_sys::baracuda_kernels_cast_i64_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
312 },
313 (ElementKind::I64, ElementKind::I64) => unsafe {
314 baracuda_kernels_sys::baracuda_kernels_cast_i64_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
315 },
316 _ => {
317 return Err(Error::Unsupported(
318 "baracuda-kernels::CastPlan::run reached an unimplemented \
319 (TIn, TOut) pair — select() should have caught this",
320 ));
321 }
322 };
323 map_status(status)
324 }
325}
326
327fn pair_in_scope(input: ElementKind, output: ElementKind) -> bool {
329 fn allowed(k: ElementKind) -> bool {
330 matches!(
331 k,
332 ElementKind::F32
333 | ElementKind::F64
334 | ElementKind::F16
335 | ElementKind::Bf16
336 | ElementKind::I32
337 | ElementKind::I64
338 )
339 }
340 allowed(input) && allowed(output)
341}
342
343fn map_status(code: i32) -> Result<()> {
344 match code {
345 0 => Ok(()),
346 1 => Err(Error::MisalignedOperand),
347 2 => Err(Error::InvalidProblem(
348 "baracuda-kernels-sys reported invalid problem",
349 )),
350 3 => Err(Error::Unsupported(
351 "baracuda-kernels-sys reported unsupported configuration",
352 )),
353 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
354 n => Err(Error::CutlassInternal(n)),
355 }
356}