baracuda_kernels/shape_layout/
fill.rs1use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22 PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, Workspace,
23};
24use half::{bf16, f16};
25
26#[derive(Copy, Clone, Debug)]
31pub struct FillDescriptor<T: Element> {
32 pub numel: i32,
34 pub value: T,
37 pub element: ElementKind,
39}
40
41pub struct FillArgs<'a, T: Element> {
43 pub output: TensorMut<'a, T, 1>,
45}
46
47pub struct FillPlan<T: Element> {
67 desc: FillDescriptor<T>,
68 sku: KernelSku,
69 _marker: PhantomData<T>,
70}
71
72impl<T: Element> FillPlan<T> {
73 pub fn select(
75 _stream: &Stream,
76 desc: &FillDescriptor<T>,
77 _pref: PlanPreference,
78 ) -> Result<Self> {
79 if desc.element != T::KIND {
80 return Err(Error::Unsupported(
81 "baracuda-kernels::FillPlan: descriptor element != type parameter T",
82 ));
83 }
84 if desc.numel < 0 {
85 return Err(Error::InvalidProblem(
86 "baracuda-kernels::FillPlan: numel must be non-negative",
87 ));
88 }
89 if !dtype_in_scope(T::KIND) {
90 return Err(Error::Unsupported(
91 "baracuda-kernels::FillPlan: dtype not wired today; supported set is \
92 {f32, f64, f16, bf16, i32, i64}",
93 ));
94 }
95
96 let precision_guarantee = PrecisionGuarantee {
98 math_precision: MathPrecision::F32,
99 accumulator: ElementKind::F32,
100 bit_stable_on_same_hardware: true,
101 deterministic: true,
102 };
103 let sku = KernelSku {
104 category: OpCategory::ShapeLayout,
105 op: ShapeLayoutKind::Fill as u16,
106 element: T::KIND,
107 aux_element: None,
108 layout: None,
109 epilogue: None,
110 arch: ArchSku::Sm80,
111 backend: BackendKind::Bespoke,
112 precision_guarantee,
113 };
114 Ok(Self {
115 desc: *desc,
116 sku,
117 _marker: PhantomData,
118 })
119 }
120
121 pub fn can_implement(&self, args: &FillArgs<'_, T>) -> Result<()> {
123 let expected = self.desc.numel as i64;
124 if args.output.numel() != expected {
125 return Err(Error::InvalidProblem(
126 "baracuda-kernels::FillPlan: output numel mismatch with descriptor",
127 ));
128 }
129 if (args.output.data.len() as i64) < expected {
130 return Err(Error::BufferTooSmall {
131 needed: expected as usize,
132 got: args.output.data.len(),
133 });
134 }
135 Ok(())
136 }
137
138 #[inline]
140 pub fn workspace_size(&self) -> usize {
141 0
142 }
143
144 #[inline]
146 pub fn sku(&self) -> KernelSku {
147 self.sku
148 }
149
150 #[inline]
152 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
153 self.sku.precision_guarantee
154 }
155
156 pub fn run(
158 &self,
159 stream: &Stream,
160 _workspace: Workspace<'_>,
161 args: FillArgs<'_, T>,
162 ) -> Result<()> {
163 self.can_implement(&args)?;
164 let numel = self.desc.numel as i64;
165 if numel == 0 {
166 return Ok(());
167 }
168 let y_ptr = args.output.data.as_raw().0 as *mut c_void;
169 let stream_ptr = stream.as_raw() as *mut c_void;
170
171 let status = unsafe {
181 match T::KIND {
182 ElementKind::F32 => {
183 let v: f32 = core::mem::transmute_copy(&self.desc.value);
184 baracuda_kernels_sys::baracuda_kernels_fill_f32_run(
185 numel, y_ptr, v, core::ptr::null_mut(), 0, stream_ptr,
186 )
187 }
188 ElementKind::F64 => {
189 let v: f64 = core::mem::transmute_copy(&self.desc.value);
190 baracuda_kernels_sys::baracuda_kernels_fill_f64_run(
191 numel, y_ptr, v, core::ptr::null_mut(), 0, stream_ptr,
192 )
193 }
194 ElementKind::I32 => {
195 let v: i32 = core::mem::transmute_copy(&self.desc.value);
196 baracuda_kernels_sys::baracuda_kernels_fill_i32_run(
197 numel, y_ptr, v, core::ptr::null_mut(), 0, stream_ptr,
198 )
199 }
200 ElementKind::I64 => {
201 let v: i64 = core::mem::transmute_copy(&self.desc.value);
202 baracuda_kernels_sys::baracuda_kernels_fill_i64_run(
203 numel, y_ptr, v, core::ptr::null_mut(), 0, stream_ptr,
204 )
205 }
206 ElementKind::F16 => {
207 let v: f16 = core::mem::transmute_copy(&self.desc.value);
208 baracuda_kernels_sys::baracuda_kernels_fill_f16_run(
209 numel, y_ptr, v.to_bits(), core::ptr::null_mut(), 0, stream_ptr,
210 )
211 }
212 ElementKind::Bf16 => {
213 let v: bf16 = core::mem::transmute_copy(&self.desc.value);
214 baracuda_kernels_sys::baracuda_kernels_fill_bf16_run(
215 numel, y_ptr, v.to_bits(), core::ptr::null_mut(), 0, stream_ptr,
216 )
217 }
218 _ => {
219 return Err(Error::Unsupported(
220 "baracuda-kernels::FillPlan::run reached an unimplemented dtype \
221 — select() should have caught this",
222 ));
223 }
224 }
225 };
226 map_status(status)
227 }
228}
229
230fn dtype_in_scope(k: ElementKind) -> bool {
231 matches!(
232 k,
233 ElementKind::F32
234 | ElementKind::F64
235 | ElementKind::F16
236 | ElementKind::Bf16
237 | ElementKind::I32
238 | ElementKind::I64
239 )
240}
241
242fn map_status(code: i32) -> Result<()> {
243 match code {
244 0 => Ok(()),
245 1 => Err(Error::MisalignedOperand),
246 2 => Err(Error::InvalidProblem(
247 "baracuda-kernels-sys reported invalid problem",
248 )),
249 3 => Err(Error::Unsupported(
250 "baracuda-kernels-sys reported unsupported configuration",
251 )),
252 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
253 n => Err(Error::CutlassInternal(n)),
254 }
255}