Skip to main content

baracuda_kernels/shape_layout/
fill.rs

1//! Fill plan — `y[i] = value` for all `i`.
2//!
3//! Phase 3 fanout from `fuel-cuda-kernels/fill.cu`. Lives under the
4//! shape-layout family because its descriptor produces an output
5//! tensor with no input dependency — same family slot as `torch.full`.
6//!
7//! Today wired across `{f32, f64, f16, bf16, i32, i64}` — every
8//! [`Element`]-implementing numeric scalar baracuda exposes through
9//! the unified Plan layer. `u8` / `i8` kernels also ship in
10//! `baracuda-kernels-sys` but those types live on the `IntElement`
11//! family with its own (deferred) plan shape. f16 / bf16 transport
12//! their `value` over the FFI as a raw `u16` bit pattern; the
13//! safe-plan layer below performs the bit cast.
14
15use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22    PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, Workspace,
23};
24use half::{bf16, f16};
25
26/// Descriptor for a fill op.
27///
28/// `value` is consumed in-place by the launcher (no descriptor-time
29/// dtype conversion). `element` must match `T::KIND` at `select` time.
30#[derive(Copy, Clone, Debug)]
31pub struct FillDescriptor<T: Element> {
32    /// Number of elements to write.
33    pub numel: i32,
34    /// Scalar to broadcast across the output. Same dtype as the output
35    /// tensor (no internal conversion).
36    pub value: T,
37    /// Output element type. Must equal `T::KIND`.
38    pub element: ElementKind,
39}
40
41/// Args bundle for a fill launch.
42pub struct FillArgs<'a, T: Element> {
43    /// Output tensor — rank-1 contiguous view over `numel` elements.
44    pub output: TensorMut<'a, T, 1>,
45}
46
47/// Fill plan.
48///
49/// `y[i] = value` for all `i` (PyTorch `torch.full`).
50///
51/// **When to use**: zero-init / constant-init of output buffers, or
52/// any broadcast-fill. No BW — a constant tensor has zero gradient.
53///
54/// **Dtypes**: `{f32, f64, f16, bf16, i32, i64}` — every numeric
55/// [`Element`] dtype baracuda exposes through the unified Plan
56/// layer. `u8` / `i8` ship in the sys crate but on the `IntElement`
57/// family (deferred plan shape).
58///
59/// **Shape limits**: flat `[numel]`; `numel ≥ 0`.
60///
61/// **Workspace**: none.
62///
63/// **Precision guarantee**: deterministic, bit-stable, bit-exact.
64/// f16 / bf16 transport `value` via raw `u16` bit pattern; the safe
65/// plan layer performs the bit cast.
66pub struct FillPlan<T: Element> {
67    desc: FillDescriptor<T>,
68    sku: KernelSku,
69    _marker: PhantomData<T>,
70}
71
72impl<T: Element> FillPlan<T> {
73    /// Pick a kernel for `desc`.
74    pub fn select(
75        _stream: &Stream,
76        desc: &FillDescriptor<T>,
77        _pref: PlanPreference,
78    ) -> Result<Self> {
79        if desc.element != T::KIND {
80            return Err(Error::Unsupported(
81                "baracuda-kernels::FillPlan: descriptor element != type parameter T",
82            ));
83        }
84        if desc.numel < 0 {
85            return Err(Error::InvalidProblem(
86                "baracuda-kernels::FillPlan: numel must be non-negative",
87            ));
88        }
89        if !dtype_in_scope(T::KIND) {
90            return Err(Error::Unsupported(
91                "baracuda-kernels::FillPlan: dtype not wired today; supported set is \
92                 {f32, f64, f16, bf16, i32, i64}",
93            ));
94        }
95
96        // Pure copy — no arithmetic.
97        let precision_guarantee = PrecisionGuarantee {
98            math_precision: MathPrecision::F32,
99            accumulator: ElementKind::F32,
100            bit_stable_on_same_hardware: true,
101            deterministic: true,
102        };
103        let sku = KernelSku {
104            category: OpCategory::ShapeLayout,
105            op: ShapeLayoutKind::Fill as u16,
106            element: T::KIND,
107            aux_element: None,
108            layout: None,
109            epilogue: None,
110            arch: ArchSku::Sm80,
111            backend: BackendKind::Bespoke,
112            precision_guarantee,
113        };
114        Ok(Self {
115            desc: *desc,
116            sku,
117            _marker: PhantomData,
118        })
119    }
120
121    /// Validate args.
122    pub fn can_implement(&self, args: &FillArgs<'_, T>) -> Result<()> {
123        let expected = self.desc.numel as i64;
124        if args.output.numel() != expected {
125            return Err(Error::InvalidProblem(
126                "baracuda-kernels::FillPlan: output numel mismatch with descriptor",
127            ));
128        }
129        if (args.output.data.len() as i64) < expected {
130            return Err(Error::BufferTooSmall {
131                needed: expected as usize,
132                got: args.output.data.len(),
133            });
134        }
135        Ok(())
136    }
137
138    /// Workspace size in bytes. Always `0`.
139    #[inline]
140    pub fn workspace_size(&self) -> usize {
141        0
142    }
143
144    /// Identity of the kernel this plan picked.
145    #[inline]
146    pub fn sku(&self) -> KernelSku {
147        self.sku
148    }
149
150    /// Numerical guarantees for this plan's kernel.
151    #[inline]
152    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
153        self.sku.precision_guarantee
154    }
155
156    /// Launch.
157    pub fn run(
158        &self,
159        stream: &Stream,
160        _workspace: Workspace<'_>,
161        args: FillArgs<'_, T>,
162    ) -> Result<()> {
163        self.can_implement(&args)?;
164        let numel = self.desc.numel as i64;
165        if numel == 0 {
166            return Ok(());
167        }
168        let y_ptr = args.output.data.as_raw().0 as *mut c_void;
169        let stream_ptr = stream.as_raw() as *mut c_void;
170
171        // Dispatch by runtime element kind. The descriptor's `value`
172        // is already typed as `T` at the Rust level — we just need to
173        // pick the right FFI per dtype. For f16 / bf16 the value
174        // crosses the FFI as a u16 bit pattern.
175        //
176        // SAFETY: each match arm only fires when `T::KIND` equals the
177        // matched ElementKind, by the construction of `T: Element`.
178        // The `transmute_copy` calls preserve the bit pattern between
179        // monomorphized layouts of the same logical type.
180        let status = unsafe {
181            match T::KIND {
182                ElementKind::F32 => {
183                    let v: f32 = core::mem::transmute_copy(&self.desc.value);
184                    baracuda_kernels_sys::baracuda_kernels_fill_f32_run(
185                        numel, y_ptr, v, core::ptr::null_mut(), 0, stream_ptr,
186                    )
187                }
188                ElementKind::F64 => {
189                    let v: f64 = core::mem::transmute_copy(&self.desc.value);
190                    baracuda_kernels_sys::baracuda_kernels_fill_f64_run(
191                        numel, y_ptr, v, core::ptr::null_mut(), 0, stream_ptr,
192                    )
193                }
194                ElementKind::I32 => {
195                    let v: i32 = core::mem::transmute_copy(&self.desc.value);
196                    baracuda_kernels_sys::baracuda_kernels_fill_i32_run(
197                        numel, y_ptr, v, core::ptr::null_mut(), 0, stream_ptr,
198                    )
199                }
200                ElementKind::I64 => {
201                    let v: i64 = core::mem::transmute_copy(&self.desc.value);
202                    baracuda_kernels_sys::baracuda_kernels_fill_i64_run(
203                        numel, y_ptr, v, core::ptr::null_mut(), 0, stream_ptr,
204                    )
205                }
206                ElementKind::F16 => {
207                    let v: f16 = core::mem::transmute_copy(&self.desc.value);
208                    baracuda_kernels_sys::baracuda_kernels_fill_f16_run(
209                        numel, y_ptr, v.to_bits(), core::ptr::null_mut(), 0, stream_ptr,
210                    )
211                }
212                ElementKind::Bf16 => {
213                    let v: bf16 = core::mem::transmute_copy(&self.desc.value);
214                    baracuda_kernels_sys::baracuda_kernels_fill_bf16_run(
215                        numel, y_ptr, v.to_bits(), core::ptr::null_mut(), 0, stream_ptr,
216                    )
217                }
218                _ => {
219                    return Err(Error::Unsupported(
220                        "baracuda-kernels::FillPlan::run reached an unimplemented dtype \
221                         — select() should have caught this",
222                    ));
223                }
224            }
225        };
226        map_status(status)
227    }
228}
229
230fn dtype_in_scope(k: ElementKind) -> bool {
231    matches!(
232        k,
233        ElementKind::F32
234            | ElementKind::F64
235            | ElementKind::F16
236            | ElementKind::Bf16
237            | ElementKind::I32
238            | ElementKind::I64
239    )
240}
241
242fn map_status(code: i32) -> Result<()> {
243    match code {
244        0 => Ok(()),
245        1 => Err(Error::MisalignedOperand),
246        2 => Err(Error::InvalidProblem(
247            "baracuda-kernels-sys reported invalid problem",
248        )),
249        3 => Err(Error::Unsupported(
250            "baracuda-kernels-sys reported unsupported configuration",
251        )),
252        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
253        n => Err(Error::CutlassInternal(n)),
254    }
255}