Skip to main content

baracuda_kernels/elementwise/
affine.rs

//! Affine plan — fused `y[i] = a * x[i] + b` with scalar `a`, `b`.
//!
//! Phase 3 fanout from `fuel-cuda-kernels/affine.cu`. Input and output share
//! one dtype, but the op carries two scalar parameters (`a`, `b`), so it gets
//! its own plan shape instead of routing through the unified
//! [`crate::UnaryPlan`].
//!
//! Currently wired for `{f32, f64, f16, bf16, i32, i64}` — every
//! [`Element`]-implementing numeric scalar in the unified Plan layer.
//! `u8` / `i8` kernels also ship in `baracuda-kernels-sys`, but those types
//! belong to the `IntElement` family, which has its own (deferred) plan
//! shape. f16 / bf16 compute through f32 internally; `a` / `b` cross the FFI
//! as `f32` for those dtypes (matching the rest of the elementwise family's
//! f32-accumulator precision-guarantee contract).
//! Canonical contiguous rank-1 views take the contiguous fast-path kernel;
//! any other stride (broadcast, flipped, or strided view) is routed to the
//! `*_strided_run` sibling kernel.
17
18use core::ffi::c_void;
19use core::marker::PhantomData;
20
21use baracuda_cutlass::{Error, Result};
22use baracuda_driver::Stream;
23use baracuda_kernels_types::{
24    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
25    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
26};
27use half::{bf16, f16};
28
29/// Descriptor for an affine op.
30///
31/// `a` and `b` are the scalar multiplier / bias. Both share the
32/// output element type (no cross-dtype scalar parameters). `element`
33/// must match `T::KIND` at `select` time.
34#[derive(Copy, Clone, Debug)]
35pub struct AffineDescriptor<T: Element> {
36    /// Number of elements in input and output.
37    pub numel: i32,
38    /// Multiplier — same dtype as input / output.
39    pub a: T,
40    /// Additive bias — same dtype as input / output.
41    pub b: T,
42    /// Input / output element type. Must equal `T::KIND`.
43    pub element: ElementKind,
44}
45
46/// Args bundle for an affine launch.
47pub struct AffineArgs<'a, T: Element> {
48    /// Input tensor — rank-1 contiguous view.
49    pub input: TensorRef<'a, T, 1>,
50    /// Output tensor — rank-1 contiguous view, same numel as input.
51    pub output: TensorMut<'a, T, 1>,
52}
53
54/// Affine plan.
55pub struct AffinePlan<T: Element> {
56    desc: AffineDescriptor<T>,
57    sku: KernelSku,
58    _marker: PhantomData<T>,
59}
60
61impl<T: Element> AffinePlan<T> {
62    /// Pick a kernel for `desc`.
63    pub fn select(
64        _stream: &Stream,
65        desc: &AffineDescriptor<T>,
66        _pref: PlanPreference,
67    ) -> Result<Self> {
68        if desc.element != T::KIND {
69            return Err(Error::Unsupported(
70                "baracuda-kernels::AffinePlan: descriptor element != type parameter T",
71            ));
72        }
73        if desc.numel < 0 {
74            return Err(Error::InvalidProblem(
75                "baracuda-kernels::AffinePlan: numel must be non-negative",
76            ));
77        }
78        if !dtype_in_scope(T::KIND) {
79            return Err(Error::Unsupported(
80                "baracuda-kernels::AffinePlan: dtype not wired today; supported set is \
81                 {f32, f64, f16, bf16, i32, i64}",
82            ));
83        }
84
85        let precision_guarantee = PrecisionGuarantee {
86            math_precision: MathPrecision::F32,
87            accumulator: ElementKind::F32,
88            bit_stable_on_same_hardware: true,
89            deterministic: true,
90        };
91        let sku = KernelSku {
92            category: OpCategory::UnaryElementwise,
93            op: UnaryKind::Affine as u16,
94            element: T::KIND,
95            aux_element: None,
96            layout: None,
97            epilogue: None,
98            arch: ArchSku::Sm80,
99            backend: BackendKind::Bespoke,
100            precision_guarantee,
101        };
102        Ok(Self {
103            desc: *desc,
104            sku,
105            _marker: PhantomData,
106        })
107    }
108
109    /// Validate args.
110    pub fn can_implement(&self, args: &AffineArgs<'_, T>) -> Result<()> {
111        let expected = self.desc.numel as i64;
112        if args.input.numel() != expected {
113            return Err(Error::InvalidProblem(
114                "baracuda-kernels::AffinePlan: input numel mismatch with descriptor",
115            ));
116        }
117        if args.output.numel() != expected {
118            return Err(Error::InvalidProblem(
119                "baracuda-kernels::AffinePlan: output numel mismatch with descriptor",
120            ));
121        }
122        if (args.input.data.len() as i64) < expected {
123            return Err(Error::BufferTooSmall {
124                needed: expected as usize,
125                got: args.input.data.len(),
126            });
127        }
128        if (args.output.data.len() as i64) < expected {
129            return Err(Error::BufferTooSmall {
130                needed: expected as usize,
131                got: args.output.data.len(),
132            });
133        }
134        Ok(())
135    }
136
137    /// Workspace size in bytes. Always `0`.
138    #[inline]
139    pub fn workspace_size(&self) -> usize {
140        0
141    }
142
143    /// Identity of the kernel this plan picked.
144    #[inline]
145    pub fn sku(&self) -> KernelSku {
146        self.sku
147    }
148
149    /// Numerical guarantees for this plan's kernel.
150    #[inline]
151    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
152        self.sku.precision_guarantee
153    }
154
155    /// Launch.
156    ///
157    /// Dispatch policy:
158    /// * Both `input` and `output` canonical contiguous (stride matches
159    ///   `[1]` for rank-1) → contig FFI fast path (`affine_<dtype>_run`).
160    /// * Either side has non-trivial strides (broadcast, flipped, or
161    ///   strided view) → strided FFI sibling (`affine_<dtype>_strided_run`).
162    pub fn run(
163        &self,
164        stream: &Stream,
165        _workspace: Workspace<'_>,
166        args: AffineArgs<'_, T>,
167    ) -> Result<()> {
168        self.can_implement(&args)?;
169        let numel = self.desc.numel as i64;
170        if numel == 0 {
171            return Ok(());
172        }
173        let x_ptr = args.input.data.as_raw().0 as *const c_void;
174        let y_ptr = args.output.data.as_raw().0 as *mut c_void;
175        let stream_ptr = stream.as_raw() as *mut c_void;
176
177        // Contig fast path: rank-1, input.stride == [1], output.stride == [1].
178        let contig =
179            is_canonical_contig(&args.input.shape, &args.input.stride)
180            && is_canonical_contig(&args.output.shape, &args.output.stride);
181
182        // SAFETY: each match arm only fires when `T::KIND` equals the
183        // matched ElementKind. The `transmute_copy` of `desc.a` /
184        // `desc.b` preserves the bit pattern across monomorphized
185        // layouts of the same logical type. f16 / bf16 are upcast to
186        // f32 before crossing the FFI.
187        let status = unsafe {
188            if contig {
189                match T::KIND {
190                    ElementKind::F32 => {
191                        let a: f32 = core::mem::transmute_copy(&self.desc.a);
192                        let b: f32 = core::mem::transmute_copy(&self.desc.b);
193                        baracuda_kernels_sys::baracuda_kernels_affine_f32_run(
194                            numel, x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
195                        )
196                    }
197                    ElementKind::F64 => {
198                        let a: f64 = core::mem::transmute_copy(&self.desc.a);
199                        let b: f64 = core::mem::transmute_copy(&self.desc.b);
200                        baracuda_kernels_sys::baracuda_kernels_affine_f64_run(
201                            numel, x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
202                        )
203                    }
204                    ElementKind::I32 => {
205                        let a: i32 = core::mem::transmute_copy(&self.desc.a);
206                        let b: i32 = core::mem::transmute_copy(&self.desc.b);
207                        baracuda_kernels_sys::baracuda_kernels_affine_i32_run(
208                            numel, x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
209                        )
210                    }
211                    ElementKind::I64 => {
212                        let a: i64 = core::mem::transmute_copy(&self.desc.a);
213                        let b: i64 = core::mem::transmute_copy(&self.desc.b);
214                        baracuda_kernels_sys::baracuda_kernels_affine_i64_run(
215                            numel, x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
216                        )
217                    }
218                    ElementKind::F16 => {
219                        let a: f16 = core::mem::transmute_copy(&self.desc.a);
220                        let b: f16 = core::mem::transmute_copy(&self.desc.b);
221                        baracuda_kernels_sys::baracuda_kernels_affine_f16_run(
222                            numel, x_ptr, y_ptr, a.to_f32(), b.to_f32(),
223                            core::ptr::null_mut(), 0, stream_ptr,
224                        )
225                    }
226                    ElementKind::Bf16 => {
227                        let a: bf16 = core::mem::transmute_copy(&self.desc.a);
228                        let b: bf16 = core::mem::transmute_copy(&self.desc.b);
229                        baracuda_kernels_sys::baracuda_kernels_affine_bf16_run(
230                            numel, x_ptr, y_ptr, a.to_f32(), b.to_f32(),
231                            core::ptr::null_mut(), 0, stream_ptr,
232                        )
233                    }
234                    _ => {
235                        return Err(Error::Unsupported(
236                            "baracuda-kernels::AffinePlan::run reached an unimplemented dtype \
237                             — select() should have caught this",
238                        ));
239                    }
240                }
241            } else {
242                // Strided slow path. Pass `shape` (logical, equal on
243                // both sides) plus signed-i64 strides for x and y.
244                let shape_ptr = args.input.shape.as_ptr();
245                let stride_x_ptr = args.input.stride.as_ptr();
246                let stride_y_ptr = args.output.stride.as_ptr();
247                let rank: i32 = 1;
248                match T::KIND {
249                    ElementKind::F32 => {
250                        let a: f32 = core::mem::transmute_copy(&self.desc.a);
251                        let b: f32 = core::mem::transmute_copy(&self.desc.b);
252                        baracuda_kernels_sys::baracuda_kernels_affine_f32_strided_run(
253                            numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
254                            x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
255                        )
256                    }
257                    ElementKind::F64 => {
258                        let a: f64 = core::mem::transmute_copy(&self.desc.a);
259                        let b: f64 = core::mem::transmute_copy(&self.desc.b);
260                        baracuda_kernels_sys::baracuda_kernels_affine_f64_strided_run(
261                            numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
262                            x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
263                        )
264                    }
265                    ElementKind::I32 => {
266                        let a: i32 = core::mem::transmute_copy(&self.desc.a);
267                        let b: i32 = core::mem::transmute_copy(&self.desc.b);
268                        baracuda_kernels_sys::baracuda_kernels_affine_i32_strided_run(
269                            numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
270                            x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
271                        )
272                    }
273                    ElementKind::I64 => {
274                        let a: i64 = core::mem::transmute_copy(&self.desc.a);
275                        let b: i64 = core::mem::transmute_copy(&self.desc.b);
276                        baracuda_kernels_sys::baracuda_kernels_affine_i64_strided_run(
277                            numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
278                            x_ptr, y_ptr, a, b, core::ptr::null_mut(), 0, stream_ptr,
279                        )
280                    }
281                    ElementKind::F16 => {
282                        let a: f16 = core::mem::transmute_copy(&self.desc.a);
283                        let b: f16 = core::mem::transmute_copy(&self.desc.b);
284                        baracuda_kernels_sys::baracuda_kernels_affine_f16_strided_run(
285                            numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
286                            x_ptr, y_ptr, a.to_f32(), b.to_f32(),
287                            core::ptr::null_mut(), 0, stream_ptr,
288                        )
289                    }
290                    ElementKind::Bf16 => {
291                        let a: bf16 = core::mem::transmute_copy(&self.desc.a);
292                        let b: bf16 = core::mem::transmute_copy(&self.desc.b);
293                        baracuda_kernels_sys::baracuda_kernels_affine_bf16_strided_run(
294                            numel, rank, shape_ptr, stride_x_ptr, stride_y_ptr,
295                            x_ptr, y_ptr, a.to_f32(), b.to_f32(),
296                            core::ptr::null_mut(), 0, stream_ptr,
297                        )
298                    }
299                    _ => {
300                        return Err(Error::Unsupported(
301                            "baracuda-kernels::AffinePlan::run reached an unimplemented dtype \
302                             — select() should have caught this",
303                        ));
304                    }
305                }
306            }
307        };
308        map_status(status)
309    }
310}
311
/// Reports whether `stride` is exactly the row-major contiguous layout of
/// `shape`: the innermost axis steps by 1, and each outer axis steps by the
/// product of all extents inside it. [`AffinePlan::run`] uses this to choose
/// between the contiguous fast path and the strided kernel.
///
/// Broadcast axes (stride 0) and reversed axes (negative stride) never
/// qualify. A rank-0 layout trivially does.
#[inline]
fn is_canonical_contig<const N: usize>(shape: &[i32; N], stride: &[i64; N]) -> bool {
    let mut want: i64 = 1;
    for (&extent, &step) in shape.iter().zip(stride.iter()).rev() {
        if step != want {
            return false;
        }
        want = want.saturating_mul(i64::from(extent));
    }
    true
}
334
335fn dtype_in_scope(k: ElementKind) -> bool {
336    matches!(
337        k,
338        ElementKind::F32
339            | ElementKind::F64
340            | ElementKind::F16
341            | ElementKind::Bf16
342            | ElementKind::I32
343            | ElementKind::I64
344    )
345}
346
347fn map_status(code: i32) -> Result<()> {
348    match code {
349        0 => Ok(()),
350        1 => Err(Error::MisalignedOperand),
351        2 => Err(Error::InvalidProblem(
352            "baracuda-kernels-sys reported invalid problem",
353        )),
354        3 => Err(Error::Unsupported(
355            "baracuda-kernels-sys reported unsupported configuration",
356        )),
357        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
358        n => Err(Error::CutlassInternal(n)),
359    }
360}