1use core::cell::Cell;
26use core::ffi::c_void;
27use core::marker::PhantomData;
28
29use baracuda_cutlass::{Error, Result};
30use baracuda_driver::Stream;
31use baracuda_kernels_sys::{
32 curandCreateGenerator, curandDestroyGenerator, curandGenerateUniform, curandGenerator_t,
33 curandSetPseudoRandomGeneratorSeed, curandSetStream,
34};
35use baracuda_kernels_types::{
36 ArchSku, BackendKind, Bool, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
37 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
38};
39
/// Problem description for a forward dropout plan over a rank-`N` tensor.
///
/// `DropoutPlan::select` validates a descriptor and stores a copy of it in
/// the plan.
#[derive(Copy, Clone, Debug)]
pub struct DropoutDescriptor<const N: usize> {
    /// Extent of each dimension. `DropoutPlan::select` rejects negative
    /// entries, and the argument tensors passed to `run` must have exactly
    /// this shape.
    pub shape: [i32; N],
    /// Element kind of the data tensors; must equal the plan's `T::KIND`.
    pub element: ElementKind,
    /// Dropout parameter, required to lie in `[0, 1]` (NaN is rejected).
    /// `run` forwards it to the kernel together with the derived scale
    /// `1 / (1 - p)` (or `0` when `p == 1`).
    /// NOTE(review): presumably the probability of dropping an element —
    /// confirm against the kernel implementation.
    pub p: f32,
    /// Seed handed to `curandSetPseudoRandomGeneratorSeed` when the plan's
    /// generator is first created.
    pub seed: u64,
}
53
/// Tensor arguments for one forward dropout launch.
///
/// All three tensors must have the shape recorded in the plan's descriptor
/// and hold at least `numel` elements (checked by `DropoutPlan::check_args`).
pub struct DropoutArgs<'a, T: Element, const N: usize> {
    /// Input tensor; passed to the kernel as a read-only pointer.
    pub x: TensorRef<'a, T, N>,
    /// Output tensor; passed to the kernel as a mutable pointer.
    pub y: TensorMut<'a, T, N>,
    /// Boolean tensor passed to the kernel as a mutable pointer.
    /// NOTE(review): presumably receives the keep/drop mask consumed later by
    /// the backward plan — confirm against the kernel implementation.
    pub mask: TensorMut<'a, Bool, N>,
}
64
/// Forward dropout plan for element type `T` and tensor rank `N`.
///
/// Built by `select`; executed through the `run` methods defined for `f32`
/// and `f64`.
pub struct DropoutPlan<T: Element, const N: usize> {
    /// Validated copy of the descriptor passed to `select`.
    desc: DropoutDescriptor<N>,
    /// Kernel identity reported by `sku()` / `precision_guarantee()`.
    sku: KernelSku,
    /// cuRAND generator handle. Null until `ensure_generator` creates it on
    /// first use; destroyed in `Drop`. Stored in a `Cell` so it can be set
    /// through `&self`.
    generator: Cell<curandGenerator_t>,
    /// Ties the plan to its element type without storing a `T`.
    _marker: PhantomData<T>,
}
75
76impl<T: Element, const N: usize> DropoutPlan<T, N> {
77 pub fn select(
79 _stream: &Stream,
80 desc: &DropoutDescriptor<N>,
81 _pref: PlanPreference,
82 ) -> Result<Self> {
83 if desc.element != T::KIND {
84 return Err(Error::Unsupported(
85 "baracuda-kernels::DropoutPlan: descriptor.element != T::KIND",
86 ));
87 }
88 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
89 return Err(Error::Unsupported(
90 "baracuda-kernels::DropoutPlan: wired today: f32 + f64",
91 ));
92 }
93 for &d in desc.shape.iter() {
94 if d < 0 {
95 return Err(Error::InvalidProblem(
96 "baracuda-kernels::DropoutPlan: shape dims must be non-negative",
97 ));
98 }
99 }
100 if N > 8 {
101 return Err(Error::Unsupported(
102 "baracuda-kernels::DropoutPlan: tensor rank > 8 not supported",
103 ));
104 }
105 if !(desc.p >= 0.0 && desc.p <= 1.0) {
106 return Err(Error::InvalidProblem(
107 "baracuda-kernels::DropoutPlan: p must be in [0, 1]",
108 ));
109 }
110
111 let math_precision = match T::KIND {
112 ElementKind::F64 => MathPrecision::F64,
113 _ => MathPrecision::F32,
114 };
115 let precision_guarantee = PrecisionGuarantee {
116 math_precision,
117 accumulator: T::KIND,
118 bit_stable_on_same_hardware: true,
119 deterministic: true,
120 };
121 let sku = KernelSku {
122 category: OpCategory::Random,
123 op: 100, element: T::KIND,
125 aux_element: Some(ElementKind::Bool),
126 layout: None,
127 epilogue: None,
128 arch: ArchSku::Sm80,
129 backend: BackendKind::Bespoke,
130 precision_guarantee,
131 };
132 Ok(Self {
133 desc: *desc,
134 sku,
135 generator: Cell::new(core::ptr::null_mut()),
136 _marker: PhantomData,
137 })
138 }
139
140 #[inline]
143 pub fn workspace_size(&self) -> usize {
144 let numel: i64 = self.desc.shape.iter().map(|&d| d as i64).product();
145 (numel.max(0) as usize) * core::mem::size_of::<f32>()
146 }
147
148 #[inline]
150 pub fn sku(&self) -> KernelSku {
151 self.sku
152 }
153
154 #[inline]
156 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
157 self.sku.precision_guarantee
158 }
159
160 fn ensure_generator(&self) -> Result<curandGenerator_t> {
161 let g = self.generator.get();
162 if !g.is_null() {
163 return Ok(g);
164 }
165 let mut handle: curandGenerator_t = core::ptr::null_mut();
166 let status =
167 unsafe { curandCreateGenerator(&mut handle as *mut _, 100) };
168 if status != 0 {
169 return Err(Error::CutlassInternal(-status));
170 }
171 let status = unsafe { curandSetPseudoRandomGeneratorSeed(handle, self.desc.seed) };
172 if status != 0 {
173 unsafe {
174 let _ = curandDestroyGenerator(handle);
175 }
176 return Err(Error::CutlassInternal(-status));
177 }
178 self.generator.set(handle);
179 Ok(handle)
180 }
181
182 fn check_args(&self, args: &DropoutArgs<'_, T, N>) -> Result<i64> {
183 if args.x.shape != self.desc.shape
184 || args.y.shape != self.desc.shape
185 || args.mask.shape != self.desc.shape
186 {
187 return Err(Error::InvalidProblem(
188 "baracuda-kernels::DropoutPlan: shape mismatch (x / y / mask)",
189 ));
190 }
191 let numel = args.y.numel();
192 let xlen = args.x.data.len() as i64;
193 let ylen = args.y.data.len() as i64;
194 let mlen = args.mask.data.len() as i64;
195 if xlen < numel || ylen < numel || mlen < numel {
196 return Err(Error::BufferTooSmall {
197 needed: numel as usize,
198 got: xlen.min(ylen).min(mlen) as usize,
199 });
200 }
201 Ok(numel)
202 }
203}
204
205impl<const N: usize> DropoutPlan<f32, N> {
206 pub fn run(
208 &self,
209 stream: &Stream,
210 workspace: Workspace<'_>,
211 args: DropoutArgs<'_, f32, N>,
212 ) -> Result<()> {
213 let numel = self.check_args(&args)?;
214 if numel == 0 {
215 return Ok(());
216 }
217 let needed = self.workspace_size();
218 let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
219 Workspace::None => {
220 return Err(Error::WorkspaceTooSmall {
221 needed,
222 got: 0,
223 })
224 }
225 Workspace::Borrowed(slice) => {
226 if slice.len() < needed {
227 return Err(Error::WorkspaceTooSmall {
228 needed,
229 got: slice.len(),
230 });
231 }
232 (slice.as_raw().0 as *mut c_void, slice.len())
233 }
234 };
235
236 let stream_ptr = stream.as_raw() as *mut c_void;
237 let x_ptr = args.x.data.as_raw().0 as *const c_void;
238 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
239 let mask_ptr = args.mask.data.as_raw().0 as *mut c_void;
240 let rand_ptr = ws_ptr as *mut f32;
241
242 let gen_handle = self.ensure_generator()?;
243 let status = unsafe { curandSetStream(gen_handle, stream_ptr) };
244 if status != 0 {
245 return Err(Error::CutlassInternal(-status));
246 }
247 let status = unsafe { curandGenerateUniform(gen_handle, rand_ptr, numel as usize) };
249 if status != 0 {
250 return Err(Error::CutlassInternal(-status));
251 }
252
253 let p = self.desc.p;
258 let scale = if p < 1.0 { 1.0_f32 / (1.0 - p) } else { 0.0_f32 };
259 let status = unsafe {
260 baracuda_kernels_sys::baracuda_kernels_dropout_f32_run(
261 numel,
262 p,
263 scale,
264 x_ptr,
265 rand_ptr as *const c_void,
266 y_ptr,
267 mask_ptr,
268 core::ptr::null_mut(),
269 ws_bytes,
270 stream_ptr,
271 )
272 };
273 map_status(status)
277 }
278}
279
280impl<const N: usize> DropoutPlan<f64, N> {
281 pub fn run(
283 &self,
284 stream: &Stream,
285 workspace: Workspace<'_>,
286 args: DropoutArgs<'_, f64, N>,
287 ) -> Result<()> {
288 let numel = self.check_args(&args)?;
289 if numel == 0 {
290 return Ok(());
291 }
292 let needed = self.workspace_size();
293 let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
294 Workspace::None => {
295 return Err(Error::WorkspaceTooSmall {
296 needed,
297 got: 0,
298 })
299 }
300 Workspace::Borrowed(slice) => {
301 if slice.len() < needed {
302 return Err(Error::WorkspaceTooSmall {
303 needed,
304 got: slice.len(),
305 });
306 }
307 (slice.as_raw().0 as *mut c_void, slice.len())
308 }
309 };
310
311 let stream_ptr = stream.as_raw() as *mut c_void;
312 let x_ptr = args.x.data.as_raw().0 as *const c_void;
313 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
314 let mask_ptr = args.mask.data.as_raw().0 as *mut c_void;
315 let rand_ptr = ws_ptr as *mut f32;
316
317 let gen_handle = self.ensure_generator()?;
318 let status = unsafe { curandSetStream(gen_handle, stream_ptr) };
319 if status != 0 {
320 return Err(Error::CutlassInternal(-status));
321 }
322 let status = unsafe { curandGenerateUniform(gen_handle, rand_ptr, numel as usize) };
323 if status != 0 {
324 return Err(Error::CutlassInternal(-status));
325 }
326
327 let p = self.desc.p;
328 let scale = if p < 1.0 { 1.0_f64 / (1.0 - p as f64) } else { 0.0_f64 };
329 let status = unsafe {
330 baracuda_kernels_sys::baracuda_kernels_dropout_f64_run(
331 numel,
332 p,
333 scale,
334 x_ptr,
335 rand_ptr as *const c_void,
336 y_ptr,
337 mask_ptr,
338 core::ptr::null_mut(),
339 ws_bytes,
340 stream_ptr,
341 )
342 };
343 map_status(status)
344 }
345}
346
347impl<T: Element, const N: usize> Drop for DropoutPlan<T, N> {
348 fn drop(&mut self) {
349 let g = self.generator.get();
350 if !g.is_null() {
351 unsafe {
352 let _ = curandDestroyGenerator(g);
353 }
354 self.generator.set(core::ptr::null_mut());
355 }
356 }
357}
358
/// Problem description for a backward dropout plan over a rank-`N` tensor.
///
/// `DropoutBackwardPlan::select` validates a descriptor and stores a copy of
/// it in the plan.
#[derive(Copy, Clone, Debug)]
pub struct DropoutBackwardDescriptor<const N: usize> {
    /// Extent of each dimension. `DropoutBackwardPlan::select` rejects
    /// negative entries, and the argument tensors passed to `run` must have
    /// exactly this shape.
    pub shape: [i32; N],
    /// Element kind of the gradient tensors; must equal the plan's `T::KIND`.
    pub element: ElementKind,
    /// Dropout parameter, required to lie in `[0, 1]` (NaN is rejected).
    /// `run` passes only the derived scale `1 / (1 - p)` (or `0` when
    /// `p == 1`) to the kernel.
    /// NOTE(review): presumably must match the `p` used in the forward pass —
    /// confirm with callers.
    pub p: f32,
}
377
/// Tensor arguments for one backward dropout launch.
///
/// All three tensors must have the shape recorded in the plan's descriptor
/// and hold at least `numel` elements (checked by
/// `DropoutBackwardPlan::check_args`).
pub struct DropoutBackwardArgs<'a, T: Element, const N: usize> {
    /// Incoming gradient; passed to the kernel as a read-only pointer.
    pub dy: TensorRef<'a, T, N>,
    /// Boolean mask; passed to the kernel as a read-only pointer.
    /// NOTE(review): presumably the mask written by the forward pass —
    /// confirm with callers.
    pub mask: TensorRef<'a, Bool, N>,
    /// Outgoing gradient; passed to the kernel as a mutable pointer.
    pub dx: TensorMut<'a, T, N>,
}
387
/// Backward dropout plan for element type `T` and tensor rank `N`.
///
/// Built by `select`; executed through the `run` methods defined for `f32`
/// and `f64`. It holds no generator and needs no workspace
/// (`workspace_size()` is 0).
pub struct DropoutBackwardPlan<T: Element, const N: usize> {
    /// Validated copy of the descriptor passed to `select`.
    desc: DropoutBackwardDescriptor<N>,
    /// Kernel identity reported by `sku()` / `precision_guarantee()`.
    sku: KernelSku,
    /// Ties the plan to its element type without storing a `T`.
    _marker: PhantomData<T>,
}
394
395impl<T: Element, const N: usize> DropoutBackwardPlan<T, N> {
396 pub fn select(
398 _stream: &Stream,
399 desc: &DropoutBackwardDescriptor<N>,
400 _pref: PlanPreference,
401 ) -> Result<Self> {
402 if desc.element != T::KIND {
403 return Err(Error::Unsupported(
404 "baracuda-kernels::DropoutBackwardPlan: descriptor.element != T::KIND",
405 ));
406 }
407 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
408 return Err(Error::Unsupported(
409 "baracuda-kernels::DropoutBackwardPlan: wired today: f32 + f64",
410 ));
411 }
412 for &d in desc.shape.iter() {
413 if d < 0 {
414 return Err(Error::InvalidProblem(
415 "baracuda-kernels::DropoutBackwardPlan: shape dims must be non-negative",
416 ));
417 }
418 }
419 if N > 8 {
420 return Err(Error::Unsupported(
421 "baracuda-kernels::DropoutBackwardPlan: tensor rank > 8 not supported",
422 ));
423 }
424 if !(desc.p >= 0.0 && desc.p <= 1.0) {
425 return Err(Error::InvalidProblem(
426 "baracuda-kernels::DropoutBackwardPlan: p must be in [0, 1]",
427 ));
428 }
429
430 let math_precision = match T::KIND {
431 ElementKind::F64 => MathPrecision::F64,
432 _ => MathPrecision::F32,
433 };
434 let precision_guarantee = PrecisionGuarantee {
435 math_precision,
436 accumulator: T::KIND,
437 bit_stable_on_same_hardware: true,
438 deterministic: true,
439 };
440 let sku = KernelSku {
441 category: OpCategory::Random,
442 op: 101, element: T::KIND,
444 aux_element: Some(ElementKind::Bool),
445 layout: None,
446 epilogue: None,
447 arch: ArchSku::Sm80,
448 backend: BackendKind::Bespoke,
449 precision_guarantee,
450 };
451 Ok(Self {
452 desc: *desc,
453 sku,
454 _marker: PhantomData,
455 })
456 }
457
458 #[inline]
460 pub fn workspace_size(&self) -> usize {
461 0
462 }
463
464 #[inline]
466 pub fn sku(&self) -> KernelSku {
467 self.sku
468 }
469
470 #[inline]
472 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
473 self.sku.precision_guarantee
474 }
475
476 fn check_args(&self, args: &DropoutBackwardArgs<'_, T, N>) -> Result<i64> {
477 if args.dy.shape != self.desc.shape
478 || args.mask.shape != self.desc.shape
479 || args.dx.shape != self.desc.shape
480 {
481 return Err(Error::InvalidProblem(
482 "baracuda-kernels::DropoutBackwardPlan: shape mismatch",
483 ));
484 }
485 let numel = args.dy.numel();
486 let dylen = args.dy.data.len() as i64;
487 let mlen = args.mask.data.len() as i64;
488 let dxlen = args.dx.data.len() as i64;
489 if dylen < numel || mlen < numel || dxlen < numel {
490 return Err(Error::BufferTooSmall {
491 needed: numel as usize,
492 got: dylen.min(mlen).min(dxlen) as usize,
493 });
494 }
495 Ok(numel)
496 }
497}
498
499impl<const N: usize> DropoutBackwardPlan<f32, N> {
500 pub fn run(
502 &self,
503 stream: &Stream,
504 _workspace: Workspace<'_>,
505 args: DropoutBackwardArgs<'_, f32, N>,
506 ) -> Result<()> {
507 let numel = self.check_args(&args)?;
508 if numel == 0 {
509 return Ok(());
510 }
511 let stream_ptr = stream.as_raw() as *mut c_void;
512 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
513 let mask_ptr = args.mask.data.as_raw().0 as *const c_void;
514 let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
515
516 let p = self.desc.p;
517 let scale = if p < 1.0 { 1.0_f32 / (1.0 - p) } else { 0.0_f32 };
518 let status = unsafe {
519 baracuda_kernels_sys::baracuda_kernels_dropout_backward_f32_run(
520 numel,
521 scale,
522 dy_ptr,
523 mask_ptr,
524 dx_ptr,
525 core::ptr::null_mut(),
526 0,
527 stream_ptr,
528 )
529 };
530 map_status(status)
531 }
532}
533
534impl<const N: usize> DropoutBackwardPlan<f64, N> {
535 pub fn run(
537 &self,
538 stream: &Stream,
539 _workspace: Workspace<'_>,
540 args: DropoutBackwardArgs<'_, f64, N>,
541 ) -> Result<()> {
542 let numel = self.check_args(&args)?;
543 if numel == 0 {
544 return Ok(());
545 }
546 let stream_ptr = stream.as_raw() as *mut c_void;
547 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
548 let mask_ptr = args.mask.data.as_raw().0 as *const c_void;
549 let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
550
551 let p = self.desc.p;
552 let scale = if p < 1.0 { 1.0_f64 / (1.0 - p as f64) } else { 0.0_f64 };
553 let status = unsafe {
554 baracuda_kernels_sys::baracuda_kernels_dropout_backward_f64_run(
555 numel,
556 scale,
557 dy_ptr,
558 mask_ptr,
559 dx_ptr,
560 core::ptr::null_mut(),
561 0,
562 stream_ptr,
563 )
564 };
565 map_status(status)
566 }
567}
568
569fn map_status(code: i32) -> Result<()> {
570 match code {
571 0 => Ok(()),
572 1 => Err(Error::MisalignedOperand),
573 2 => Err(Error::InvalidProblem(
574 "baracuda-kernels-sys reported invalid problem",
575 )),
576 3 => Err(Error::Unsupported(
577 "baracuda-kernels-sys reported unsupported configuration",
578 )),
579 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
580 n => Err(Error::CutlassInternal(n)),
581 }
582}