Skip to main content

baracuda_kernels/elementwise/
cast.rs

1//! Cast plan — heterogeneous dtype conversion (`y = (TOut) x`).
2//!
3//! Distinct from the same-dtype [`crate::UnaryPlan`] family because the
4//! input and output element types differ. The plan is generic over
5//! both `TIn: Element` and `TOut: Element`, with [`select`] dispatching
6//! on the runtime `(input_element, output_element)` pair.
7//!
//! Today wired (at the Plan-layer dispatch): every `(TIn, TOut)` pair —
//! same-dtype identity casts included — drawn from
//! `{f32, f64, f16, bf16, i32, i64} × {same set}`. The kernels in
10//! `baracuda-kernels-sys` cover a broader set (also `u8` / `i8`
11//! endpoints) — those would route via the [`IntElement`] family with a
12//! distinct plan shape, deferred. Bool is not wired today either —
13//! its truthiness convention (`x != 0 → 1`) would need a dedicated
14//! kernel rather than a pure `static_cast`. FP8 endpoints land in a
15//! future Phase-2 FP8 fanout. The kernel itself is contig-only —
16//! baracuda's plan layer materializes strided views upstream.
17//!
18//! [`IntElement`]: baracuda_kernels_types::IntElement
19
20use core::ffi::c_void;
21use core::marker::PhantomData;
22
23use baracuda_cutlass::{Error, Result};
24use baracuda_driver::Stream;
25use baracuda_kernels_types::{
26    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
27    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
28};
29
30/// Descriptor for a dtype cast.
31///
32/// `numel` is the total number of elements (both input and output have
33/// the same number — cast doesn't change shape). `input_element` must
34/// match `TIn::KIND` and `output_element` must match `TOut::KIND` at
35/// `select` time.
36#[derive(Copy, Clone, Debug)]
37pub struct CastDescriptor {
38    /// Number of elements in both input and output.
39    pub numel: i32,
40    /// Input element type.
41    pub input_element: ElementKind,
42    /// Output element type.
43    pub output_element: ElementKind,
44}
45
46/// Args bundle for a cast launch. Both `input` and `output` are
47/// rank-1 contiguous views over `numel` elements.
48pub struct CastArgs<'a, TIn: Element, TOut: Element> {
49    /// Input — `TIn` element type.
50    pub input: TensorRef<'a, TIn, 1>,
51    /// Output — `TOut` element type.
52    pub output: TensorMut<'a, TOut, 1>,
53}
54
55/// Cast plan.
56///
57/// `TIn` is the input element type. `TOut` is the output element type.
58pub struct CastPlan<TIn: Element, TOut: Element> {
59    desc: CastDescriptor,
60    sku: KernelSku,
61    _marker_in: PhantomData<TIn>,
62    _marker_out: PhantomData<TOut>,
63}
64
65impl<TIn: Element, TOut: Element> CastPlan<TIn, TOut> {
66    /// Pick a kernel for `desc`.
67    pub fn select(
68        _stream: &Stream,
69        desc: &CastDescriptor,
70        _pref: PlanPreference,
71    ) -> Result<Self> {
72        if desc.input_element != TIn::KIND {
73            return Err(Error::Unsupported(
74                "baracuda-kernels::CastPlan: descriptor input_element != type parameter TIn",
75            ));
76        }
77        if desc.output_element != TOut::KIND {
78            return Err(Error::Unsupported(
79                "baracuda-kernels::CastPlan: descriptor output_element != type parameter TOut",
80            ));
81        }
82        if desc.numel < 0 {
83            return Err(Error::InvalidProblem(
84                "baracuda-kernels::CastPlan: numel must be non-negative",
85            ));
86        }
87        if !pair_in_scope(TIn::KIND, TOut::KIND) {
88            return Err(Error::Unsupported(
89                "baracuda-kernels::CastPlan: this (TIn, TOut) pair is not wired today; \
90                 supported set is {f32, f64, f16, bf16, i32, i64} × {same}",
91            ));
92        }
93
94        // Cast is a pure copy + numeric conversion — no fused math.
95        // Precision guarantee reflects the f32 detour for half-precision
96        // endpoints but is a no-op (bit-stable copy) for same-dtype
97        // pairs and for purely integer cross-casts.
98        let precision_guarantee = PrecisionGuarantee {
99            math_precision: MathPrecision::F32,
100            accumulator: ElementKind::F32,
101            bit_stable_on_same_hardware: true,
102            deterministic: true,
103        };
104        let sku = KernelSku {
105            category: OpCategory::UnaryElementwise,
106            op: UnaryKind::Cast as u16,
107            element: TIn::KIND,
108            aux_element: Some(TOut::KIND),
109            layout: None,
110            epilogue: None,
111            arch: ArchSku::Sm80,
112            backend: BackendKind::Bespoke,
113            precision_guarantee,
114        };
115        Ok(Self {
116            desc: *desc,
117            sku,
118            _marker_in: PhantomData,
119            _marker_out: PhantomData,
120        })
121    }
122
123    /// Validate args.
124    pub fn can_implement(&self, args: &CastArgs<'_, TIn, TOut>) -> Result<()> {
125        let expected = self.desc.numel as i64;
126        if args.input.numel() != expected {
127            return Err(Error::InvalidProblem(
128                "baracuda-kernels::CastPlan: input numel mismatch with descriptor",
129            ));
130        }
131        if args.output.numel() != expected {
132            return Err(Error::InvalidProblem(
133                "baracuda-kernels::CastPlan: output numel mismatch with descriptor",
134            ));
135        }
136        if (args.input.data.len() as i64) < expected {
137            return Err(Error::BufferTooSmall {
138                needed: expected as usize,
139                got: args.input.data.len(),
140            });
141        }
142        if (args.output.data.len() as i64) < expected {
143            return Err(Error::BufferTooSmall {
144                needed: expected as usize,
145                got: args.output.data.len(),
146            });
147        }
148        Ok(())
149    }
150
151    /// Workspace size in bytes. Always `0`.
152    #[inline]
153    pub fn workspace_size(&self) -> usize {
154        0
155    }
156
157    /// Identity of the kernel this plan picked.
158    #[inline]
159    pub fn sku(&self) -> KernelSku {
160        self.sku
161    }
162
163    /// Numerical guarantees for this plan's kernel.
164    #[inline]
165    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
166        self.sku.precision_guarantee
167    }
168
169    /// Launch.
170    pub fn run(
171        &self,
172        stream: &Stream,
173        _workspace: Workspace<'_>,
174        args: CastArgs<'_, TIn, TOut>,
175    ) -> Result<()> {
176        self.can_implement(&args)?;
177        let numel = self.desc.numel as i64;
178        if numel == 0 {
179            return Ok(());
180        }
181        let x_ptr = args.input.data.as_raw().0 as *const c_void;
182        let y_ptr = args.output.data.as_raw().0 as *mut c_void;
183        let stream_ptr = stream.as_raw() as *mut c_void;
184
185        // Dispatch table — each cell calls the matching
186        // `cast_<sin>_<sout>` FFI. The `match` is on the runtime kinds
187        // (which `select` already proved consistent with TIn / TOut);
188        // unreachable arm is the "select bug" guard.
189        let status = match (TIn::KIND, TOut::KIND) {
190            // f32 -> *
191            (ElementKind::F32, ElementKind::F32) => unsafe {
192                baracuda_kernels_sys::baracuda_kernels_cast_f32_f32_run(
193                    numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
194                )
195            },
196            (ElementKind::F32, ElementKind::F64) => unsafe {
197                baracuda_kernels_sys::baracuda_kernels_cast_f32_f64_run(
198                    numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
199                )
200            },
201            (ElementKind::F32, ElementKind::F16) => unsafe {
202                baracuda_kernels_sys::baracuda_kernels_cast_f32_f16_run(
203                    numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
204                )
205            },
206            (ElementKind::F32, ElementKind::Bf16) => unsafe {
207                baracuda_kernels_sys::baracuda_kernels_cast_f32_bf16_run(
208                    numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
209                )
210            },
211            (ElementKind::F32, ElementKind::I32) => unsafe {
212                baracuda_kernels_sys::baracuda_kernels_cast_f32_i32_run(
213                    numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
214                )
215            },
216            (ElementKind::F32, ElementKind::I64) => unsafe {
217                baracuda_kernels_sys::baracuda_kernels_cast_f32_i64_run(
218                    numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
219                )
220            },
221            // f64 -> *
222            (ElementKind::F64, ElementKind::F32) => unsafe {
223                baracuda_kernels_sys::baracuda_kernels_cast_f64_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
224            },
225            (ElementKind::F64, ElementKind::F64) => unsafe {
226                baracuda_kernels_sys::baracuda_kernels_cast_f64_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
227            },
228            (ElementKind::F64, ElementKind::F16) => unsafe {
229                baracuda_kernels_sys::baracuda_kernels_cast_f64_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
230            },
231            (ElementKind::F64, ElementKind::Bf16) => unsafe {
232                baracuda_kernels_sys::baracuda_kernels_cast_f64_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
233            },
234            (ElementKind::F64, ElementKind::I32) => unsafe {
235                baracuda_kernels_sys::baracuda_kernels_cast_f64_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
236            },
237            (ElementKind::F64, ElementKind::I64) => unsafe {
238                baracuda_kernels_sys::baracuda_kernels_cast_f64_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
239            },
240            // f16 -> *
241            (ElementKind::F16, ElementKind::F32) => unsafe {
242                baracuda_kernels_sys::baracuda_kernels_cast_f16_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
243            },
244            (ElementKind::F16, ElementKind::F64) => unsafe {
245                baracuda_kernels_sys::baracuda_kernels_cast_f16_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
246            },
247            (ElementKind::F16, ElementKind::F16) => unsafe {
248                baracuda_kernels_sys::baracuda_kernels_cast_f16_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
249            },
250            (ElementKind::F16, ElementKind::Bf16) => unsafe {
251                baracuda_kernels_sys::baracuda_kernels_cast_f16_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
252            },
253            (ElementKind::F16, ElementKind::I32) => unsafe {
254                baracuda_kernels_sys::baracuda_kernels_cast_f16_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
255            },
256            (ElementKind::F16, ElementKind::I64) => unsafe {
257                baracuda_kernels_sys::baracuda_kernels_cast_f16_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
258            },
259            // bf16 -> *
260            (ElementKind::Bf16, ElementKind::F32) => unsafe {
261                baracuda_kernels_sys::baracuda_kernels_cast_bf16_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
262            },
263            (ElementKind::Bf16, ElementKind::F64) => unsafe {
264                baracuda_kernels_sys::baracuda_kernels_cast_bf16_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
265            },
266            (ElementKind::Bf16, ElementKind::F16) => unsafe {
267                baracuda_kernels_sys::baracuda_kernels_cast_bf16_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
268            },
269            (ElementKind::Bf16, ElementKind::Bf16) => unsafe {
270                baracuda_kernels_sys::baracuda_kernels_cast_bf16_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
271            },
272            (ElementKind::Bf16, ElementKind::I32) => unsafe {
273                baracuda_kernels_sys::baracuda_kernels_cast_bf16_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
274            },
275            (ElementKind::Bf16, ElementKind::I64) => unsafe {
276                baracuda_kernels_sys::baracuda_kernels_cast_bf16_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
277            },
278            // i32 -> *
279            (ElementKind::I32, ElementKind::F32) => unsafe {
280                baracuda_kernels_sys::baracuda_kernels_cast_i32_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
281            },
282            (ElementKind::I32, ElementKind::F64) => unsafe {
283                baracuda_kernels_sys::baracuda_kernels_cast_i32_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
284            },
285            (ElementKind::I32, ElementKind::F16) => unsafe {
286                baracuda_kernels_sys::baracuda_kernels_cast_i32_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
287            },
288            (ElementKind::I32, ElementKind::Bf16) => unsafe {
289                baracuda_kernels_sys::baracuda_kernels_cast_i32_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
290            },
291            (ElementKind::I32, ElementKind::I32) => unsafe {
292                baracuda_kernels_sys::baracuda_kernels_cast_i32_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
293            },
294            (ElementKind::I32, ElementKind::I64) => unsafe {
295                baracuda_kernels_sys::baracuda_kernels_cast_i32_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
296            },
297            // i64 -> *
298            (ElementKind::I64, ElementKind::F32) => unsafe {
299                baracuda_kernels_sys::baracuda_kernels_cast_i64_f32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
300            },
301            (ElementKind::I64, ElementKind::F64) => unsafe {
302                baracuda_kernels_sys::baracuda_kernels_cast_i64_f64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
303            },
304            (ElementKind::I64, ElementKind::F16) => unsafe {
305                baracuda_kernels_sys::baracuda_kernels_cast_i64_f16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
306            },
307            (ElementKind::I64, ElementKind::Bf16) => unsafe {
308                baracuda_kernels_sys::baracuda_kernels_cast_i64_bf16_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
309            },
310            (ElementKind::I64, ElementKind::I32) => unsafe {
311                baracuda_kernels_sys::baracuda_kernels_cast_i64_i32_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
312            },
313            (ElementKind::I64, ElementKind::I64) => unsafe {
314                baracuda_kernels_sys::baracuda_kernels_cast_i64_i64_run(numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr)
315            },
316            _ => {
317                return Err(Error::Unsupported(
318                    "baracuda-kernels::CastPlan::run reached an unimplemented \
319                     (TIn, TOut) pair — select() should have caught this",
320                ));
321            }
322        };
323        map_status(status)
324    }
325}
326
327/// Supported (TIn, TOut) pairs — same set the dispatch table covers.
328fn pair_in_scope(input: ElementKind, output: ElementKind) -> bool {
329    fn allowed(k: ElementKind) -> bool {
330        matches!(
331            k,
332            ElementKind::F32
333                | ElementKind::F64
334                | ElementKind::F16
335                | ElementKind::Bf16
336                | ElementKind::I32
337                | ElementKind::I64
338        )
339    }
340    allowed(input) && allowed(output)
341}
342
343fn map_status(code: i32) -> Result<()> {
344    match code {
345        0 => Ok(()),
346        1 => Err(Error::MisalignedOperand),
347        2 => Err(Error::InvalidProblem(
348            "baracuda-kernels-sys reported invalid problem",
349        )),
350        3 => Err(Error::Unsupported(
351            "baracuda-kernels-sys reported unsupported configuration",
352        )),
353        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
354        n => Err(Error::CutlassInternal(n)),
355    }
356}