Skip to main content

baracuda_kernels/random/
dropout.rs

1//! Dropout — `y = mask · x / (1 - p)` with `mask ~ Bernoulli(1 - p)`.
2//!
3//! Two plans:
4//!
5//! - [`DropoutPlan`] (FW) — takes input `x`, writes both output `y` and
6//!   the binary mask. Caller saves the mask for the backward pass. The
7//!   plan owns its own cuRAND generator (same lifetime model as
8//!   [`super::RandomPlan`]).
9//!
10//! - [`DropoutBackwardPlan`] — pure replay: `dx = mask · dy / (1 - p)`.
11//!   No random generation, no workspace.
12//!
13//! Wired today: `T ∈ {f32, f64}`. `f16` / `bf16` dropout would need a
14//! cuRAND-half-precision path; deferred.
15//!
//! Edge cases:
//! - `p == 0` — dropout is the identity. The forward still requires the
//!   caller-provided workspace and still draws the uniform samples; the
//!   kernel writes `mask = all-ones` and performs the scale-1 multiply
//!   (matches PyTorch's behavior of always touching `mask`).
//! - `p == 1` — every cell would be dropped. No zero-fill path is wired
//!   yet: the forward kernel rejects this case, so `DropoutPlan::run`
//!   returns `Error::InvalidProblem`. The backward substitutes
//!   `scale = 0` for `1 / (1 - p)`, so the `1 / 0` scale is never formed.
24
25use core::cell::Cell;
26use core::ffi::c_void;
27use core::marker::PhantomData;
28
29use baracuda_cutlass::{Error, Result};
30use baracuda_driver::Stream;
31use baracuda_kernels_sys::{
32    curandCreateGenerator, curandDestroyGenerator, curandGenerateUniform, curandGenerator_t,
33    curandSetPseudoRandomGeneratorSeed, curandSetStream,
34};
35use baracuda_kernels_types::{
36    ArchSku, BackendKind, Bool, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
37    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
38};
39
40/// Descriptor for a dropout op.
41#[derive(Copy, Clone, Debug)]
42pub struct DropoutDescriptor<const N: usize> {
43    /// Input / output / mask shape (all three share it).
44    pub shape: [i32; N],
45    /// Element type for `x` and `y`. Must be `f32` or `f64`.
46    pub element: ElementKind,
47    /// Drop probability in `[0, 1]`. `p == 0` is identity; `p == 1`
48    /// zeros every cell.
49    pub p: f32,
50    /// Deterministic seed. Same seed → same mask.
51    pub seed: u64,
52}
53
54/// Args bundle for a dropout forward launch.
55pub struct DropoutArgs<'a, T: Element, const N: usize> {
56    /// Input tensor.
57    pub x: TensorRef<'a, T, N>,
58    /// Output tensor — same shape as `x`.
59    pub y: TensorMut<'a, T, N>,
60    /// Mask tensor — packed Bool, same shape as `x`. Caller saves this
61    /// for [`DropoutBackwardArgs::mask`].
62    pub mask: TensorMut<'a, Bool, N>,
63}
64
65/// Dropout forward plan.
66///
67/// Owns a cuRAND generator (lazy + `!Sync`); see [`super::RandomPlan`]
68/// for the shared rationale.
69pub struct DropoutPlan<T: Element, const N: usize> {
70    desc: DropoutDescriptor<N>,
71    sku: KernelSku,
72    generator: Cell<curandGenerator_t>,
73    _marker: PhantomData<T>,
74}
75
76impl<T: Element, const N: usize> DropoutPlan<T, N> {
77    /// Pick a kernel + validate the descriptor.
78    pub fn select(
79        _stream: &Stream,
80        desc: &DropoutDescriptor<N>,
81        _pref: PlanPreference,
82    ) -> Result<Self> {
83        if desc.element != T::KIND {
84            return Err(Error::Unsupported(
85                "baracuda-kernels::DropoutPlan: descriptor.element != T::KIND",
86            ));
87        }
88        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
89            return Err(Error::Unsupported(
90                "baracuda-kernels::DropoutPlan: wired today: f32 + f64",
91            ));
92        }
93        for &d in desc.shape.iter() {
94            if d < 0 {
95                return Err(Error::InvalidProblem(
96                    "baracuda-kernels::DropoutPlan: shape dims must be non-negative",
97                ));
98            }
99        }
100        if N > 8 {
101            return Err(Error::Unsupported(
102                "baracuda-kernels::DropoutPlan: tensor rank > 8 not supported",
103            ));
104        }
105        if !(desc.p >= 0.0 && desc.p <= 1.0) {
106            return Err(Error::InvalidProblem(
107                "baracuda-kernels::DropoutPlan: p must be in [0, 1]",
108            ));
109        }
110
111        let math_precision = match T::KIND {
112            ElementKind::F64 => MathPrecision::F64,
113            _ => MathPrecision::F32,
114        };
115        let precision_guarantee = PrecisionGuarantee {
116            math_precision,
117            accumulator: T::KIND,
118            bit_stable_on_same_hardware: true,
119            deterministic: true,
120        };
121        let sku = KernelSku {
122            category: OpCategory::Random,
123            op: 100, // 100 = dropout — picked outside the RandomKind enum.
124            element: T::KIND,
125            aux_element: Some(ElementKind::Bool),
126            layout: None,
127            epilogue: None,
128            arch: ArchSku::Sm80,
129            backend: BackendKind::Bespoke,
130            precision_guarantee,
131        };
132        Ok(Self {
133            desc: *desc,
134            sku,
135            generator: Cell::new(core::ptr::null_mut()),
136            _marker: PhantomData,
137        })
138    }
139
140    /// Workspace size in bytes — one `f32` per output cell for the
141    /// cuRAND uniform-rand intermediate.
142    #[inline]
143    pub fn workspace_size(&self) -> usize {
144        let numel: i64 = self.desc.shape.iter().map(|&d| d as i64).product();
145        (numel.max(0) as usize) * core::mem::size_of::<f32>()
146    }
147
148    /// Kernel SKU identity.
149    #[inline]
150    pub fn sku(&self) -> KernelSku {
151        self.sku
152    }
153
154    /// Numerical guarantees.
155    #[inline]
156    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
157        self.sku.precision_guarantee
158    }
159
160    fn ensure_generator(&self) -> Result<curandGenerator_t> {
161        let g = self.generator.get();
162        if !g.is_null() {
163            return Ok(g);
164        }
165        let mut handle: curandGenerator_t = core::ptr::null_mut();
166        let status =
167            unsafe { curandCreateGenerator(&mut handle as *mut _, 100) };
168        if status != 0 {
169            return Err(Error::CutlassInternal(-status));
170        }
171        let status = unsafe { curandSetPseudoRandomGeneratorSeed(handle, self.desc.seed) };
172        if status != 0 {
173            unsafe {
174                let _ = curandDestroyGenerator(handle);
175            }
176            return Err(Error::CutlassInternal(-status));
177        }
178        self.generator.set(handle);
179        Ok(handle)
180    }
181
182    fn check_args(&self, args: &DropoutArgs<'_, T, N>) -> Result<i64> {
183        if args.x.shape != self.desc.shape
184            || args.y.shape != self.desc.shape
185            || args.mask.shape != self.desc.shape
186        {
187            return Err(Error::InvalidProblem(
188                "baracuda-kernels::DropoutPlan: shape mismatch (x / y / mask)",
189            ));
190        }
191        let numel = args.y.numel();
192        let xlen = args.x.data.len() as i64;
193        let ylen = args.y.data.len() as i64;
194        let mlen = args.mask.data.len() as i64;
195        if xlen < numel || ylen < numel || mlen < numel {
196            return Err(Error::BufferTooSmall {
197                needed: numel as usize,
198                got: xlen.min(ylen).min(mlen) as usize,
199            });
200        }
201        Ok(numel)
202    }
203}
204
205impl<const N: usize> DropoutPlan<f32, N> {
206    /// Launch dropout forward (f32).
207    pub fn run(
208        &self,
209        stream: &Stream,
210        workspace: Workspace<'_>,
211        args: DropoutArgs<'_, f32, N>,
212    ) -> Result<()> {
213        let numel = self.check_args(&args)?;
214        if numel == 0 {
215            return Ok(());
216        }
217        let needed = self.workspace_size();
218        let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
219            Workspace::None => {
220                return Err(Error::WorkspaceTooSmall {
221                    needed,
222                    got: 0,
223                })
224            }
225            Workspace::Borrowed(slice) => {
226                if slice.len() < needed {
227                    return Err(Error::WorkspaceTooSmall {
228                        needed,
229                        got: slice.len(),
230                    });
231                }
232                (slice.as_raw().0 as *mut c_void, slice.len())
233            }
234        };
235
236        let stream_ptr = stream.as_raw() as *mut c_void;
237        let x_ptr = args.x.data.as_raw().0 as *const c_void;
238        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
239        let mask_ptr = args.mask.data.as_raw().0 as *mut c_void;
240        let rand_ptr = ws_ptr as *mut f32;
241
242        let gen_handle = self.ensure_generator()?;
243        let status = unsafe { curandSetStream(gen_handle, stream_ptr) };
244        if status != 0 {
245            return Err(Error::CutlassInternal(-status));
246        }
247        // Generate the uniform sample buffer.
248        let status = unsafe { curandGenerateUniform(gen_handle, rand_ptr, numel as usize) };
249        if status != 0 {
250            return Err(Error::CutlassInternal(-status));
251        }
252
253        // Compute scale = 1 / (1 - p) at the safe layer so the kernel
254        // doesn't reach a divide. p == 1.0 → scale = +inf; we route
255        // that through a zero-fill instead so callers don't see NaN /
256        // inf in the output.
257        let p = self.desc.p;
258        let scale = if p < 1.0 { 1.0_f32 / (1.0 - p) } else { 0.0_f32 };
259        let status = unsafe {
260            baracuda_kernels_sys::baracuda_kernels_dropout_f32_run(
261                numel,
262                p,
263                scale,
264                x_ptr,
265                rand_ptr as *const c_void,
266                y_ptr,
267                mask_ptr,
268                core::ptr::null_mut(),
269                ws_bytes,
270                stream_ptr,
271            )
272        };
273        // The kernel rejects p == 1 (status 2); for that case we'd
274        // need a zero-fill path. Today, smoke tests use p ∈ (0, 1) so
275        // we fall through with a clear error.
276        map_status(status)
277    }
278}
279
280impl<const N: usize> DropoutPlan<f64, N> {
281    /// Launch dropout forward (f64).
282    pub fn run(
283        &self,
284        stream: &Stream,
285        workspace: Workspace<'_>,
286        args: DropoutArgs<'_, f64, N>,
287    ) -> Result<()> {
288        let numel = self.check_args(&args)?;
289        if numel == 0 {
290            return Ok(());
291        }
292        let needed = self.workspace_size();
293        let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
294            Workspace::None => {
295                return Err(Error::WorkspaceTooSmall {
296                    needed,
297                    got: 0,
298                })
299            }
300            Workspace::Borrowed(slice) => {
301                if slice.len() < needed {
302                    return Err(Error::WorkspaceTooSmall {
303                        needed,
304                        got: slice.len(),
305                    });
306                }
307                (slice.as_raw().0 as *mut c_void, slice.len())
308            }
309        };
310
311        let stream_ptr = stream.as_raw() as *mut c_void;
312        let x_ptr = args.x.data.as_raw().0 as *const c_void;
313        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
314        let mask_ptr = args.mask.data.as_raw().0 as *mut c_void;
315        let rand_ptr = ws_ptr as *mut f32;
316
317        let gen_handle = self.ensure_generator()?;
318        let status = unsafe { curandSetStream(gen_handle, stream_ptr) };
319        if status != 0 {
320            return Err(Error::CutlassInternal(-status));
321        }
322        let status = unsafe { curandGenerateUniform(gen_handle, rand_ptr, numel as usize) };
323        if status != 0 {
324            return Err(Error::CutlassInternal(-status));
325        }
326
327        let p = self.desc.p;
328        let scale = if p < 1.0 { 1.0_f64 / (1.0 - p as f64) } else { 0.0_f64 };
329        let status = unsafe {
330            baracuda_kernels_sys::baracuda_kernels_dropout_f64_run(
331                numel,
332                p,
333                scale,
334                x_ptr,
335                rand_ptr as *const c_void,
336                y_ptr,
337                mask_ptr,
338                core::ptr::null_mut(),
339                ws_bytes,
340                stream_ptr,
341            )
342        };
343        map_status(status)
344    }
345}
346
347impl<T: Element, const N: usize> Drop for DropoutPlan<T, N> {
348    fn drop(&mut self) {
349        let g = self.generator.get();
350        if !g.is_null() {
351            unsafe {
352                let _ = curandDestroyGenerator(g);
353            }
354            self.generator.set(core::ptr::null_mut());
355        }
356    }
357}
358
359// =============================================================================
360// DropoutBackwardPlan — `dx = dy · mask · scale`. No RNG, no workspace.
361// =============================================================================
362
363/// Descriptor for the dropout backward pass.
364///
365/// Mirrors [`DropoutDescriptor`] but only carries the parameters the
366/// backward needs (`p` for the scale, no seed since the mask is replayed
367/// from the saved tensor).
368#[derive(Copy, Clone, Debug)]
369pub struct DropoutBackwardDescriptor<const N: usize> {
370    /// Tensor shape — `dy` / `mask` / `dx` share it.
371    pub shape: [i32; N],
372    /// Element type for `dy` and `dx`.
373    pub element: ElementKind,
374    /// Drop probability used by the corresponding forward.
375    pub p: f32,
376}
377
378/// Args bundle for dropout backward.
379pub struct DropoutBackwardArgs<'a, T: Element, const N: usize> {
380    /// Upstream gradient.
381    pub dy: TensorRef<'a, T, N>,
382    /// Saved mask from the forward pass.
383    pub mask: TensorRef<'a, Bool, N>,
384    /// Output gradient.
385    pub dx: TensorMut<'a, T, N>,
386}
387
388/// Dropout backward plan.
389pub struct DropoutBackwardPlan<T: Element, const N: usize> {
390    desc: DropoutBackwardDescriptor<N>,
391    sku: KernelSku,
392    _marker: PhantomData<T>,
393}
394
395impl<T: Element, const N: usize> DropoutBackwardPlan<T, N> {
396    /// Pick a kernel.
397    pub fn select(
398        _stream: &Stream,
399        desc: &DropoutBackwardDescriptor<N>,
400        _pref: PlanPreference,
401    ) -> Result<Self> {
402        if desc.element != T::KIND {
403            return Err(Error::Unsupported(
404                "baracuda-kernels::DropoutBackwardPlan: descriptor.element != T::KIND",
405            ));
406        }
407        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
408            return Err(Error::Unsupported(
409                "baracuda-kernels::DropoutBackwardPlan: wired today: f32 + f64",
410            ));
411        }
412        for &d in desc.shape.iter() {
413            if d < 0 {
414                return Err(Error::InvalidProblem(
415                    "baracuda-kernels::DropoutBackwardPlan: shape dims must be non-negative",
416                ));
417            }
418        }
419        if N > 8 {
420            return Err(Error::Unsupported(
421                "baracuda-kernels::DropoutBackwardPlan: tensor rank > 8 not supported",
422            ));
423        }
424        if !(desc.p >= 0.0 && desc.p <= 1.0) {
425            return Err(Error::InvalidProblem(
426                "baracuda-kernels::DropoutBackwardPlan: p must be in [0, 1]",
427            ));
428        }
429
430        let math_precision = match T::KIND {
431            ElementKind::F64 => MathPrecision::F64,
432            _ => MathPrecision::F32,
433        };
434        let precision_guarantee = PrecisionGuarantee {
435            math_precision,
436            accumulator: T::KIND,
437            bit_stable_on_same_hardware: true,
438            deterministic: true,
439        };
440        let sku = KernelSku {
441            category: OpCategory::Random,
442            op: 101, // 101 = dropout backward.
443            element: T::KIND,
444            aux_element: Some(ElementKind::Bool),
445            layout: None,
446            epilogue: None,
447            arch: ArchSku::Sm80,
448            backend: BackendKind::Bespoke,
449            precision_guarantee,
450        };
451        Ok(Self {
452            desc: *desc,
453            sku,
454            _marker: PhantomData,
455        })
456    }
457
458    /// Workspace size in bytes — zero (no RNG, pure replay).
459    #[inline]
460    pub fn workspace_size(&self) -> usize {
461        0
462    }
463
464    /// Kernel SKU identity.
465    #[inline]
466    pub fn sku(&self) -> KernelSku {
467        self.sku
468    }
469
470    /// Numerical guarantees.
471    #[inline]
472    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
473        self.sku.precision_guarantee
474    }
475
476    fn check_args(&self, args: &DropoutBackwardArgs<'_, T, N>) -> Result<i64> {
477        if args.dy.shape != self.desc.shape
478            || args.mask.shape != self.desc.shape
479            || args.dx.shape != self.desc.shape
480        {
481            return Err(Error::InvalidProblem(
482                "baracuda-kernels::DropoutBackwardPlan: shape mismatch",
483            ));
484        }
485        let numel = args.dy.numel();
486        let dylen = args.dy.data.len() as i64;
487        let mlen = args.mask.data.len() as i64;
488        let dxlen = args.dx.data.len() as i64;
489        if dylen < numel || mlen < numel || dxlen < numel {
490            return Err(Error::BufferTooSmall {
491                needed: numel as usize,
492                got: dylen.min(mlen).min(dxlen) as usize,
493            });
494        }
495        Ok(numel)
496    }
497}
498
499impl<const N: usize> DropoutBackwardPlan<f32, N> {
500    /// Launch dropout backward (f32).
501    pub fn run(
502        &self,
503        stream: &Stream,
504        _workspace: Workspace<'_>,
505        args: DropoutBackwardArgs<'_, f32, N>,
506    ) -> Result<()> {
507        let numel = self.check_args(&args)?;
508        if numel == 0 {
509            return Ok(());
510        }
511        let stream_ptr = stream.as_raw() as *mut c_void;
512        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
513        let mask_ptr = args.mask.data.as_raw().0 as *const c_void;
514        let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
515
516        let p = self.desc.p;
517        let scale = if p < 1.0 { 1.0_f32 / (1.0 - p) } else { 0.0_f32 };
518        let status = unsafe {
519            baracuda_kernels_sys::baracuda_kernels_dropout_backward_f32_run(
520                numel,
521                scale,
522                dy_ptr,
523                mask_ptr,
524                dx_ptr,
525                core::ptr::null_mut(),
526                0,
527                stream_ptr,
528            )
529        };
530        map_status(status)
531    }
532}
533
534impl<const N: usize> DropoutBackwardPlan<f64, N> {
535    /// Launch dropout backward (f64).
536    pub fn run(
537        &self,
538        stream: &Stream,
539        _workspace: Workspace<'_>,
540        args: DropoutBackwardArgs<'_, f64, N>,
541    ) -> Result<()> {
542        let numel = self.check_args(&args)?;
543        if numel == 0 {
544            return Ok(());
545        }
546        let stream_ptr = stream.as_raw() as *mut c_void;
547        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
548        let mask_ptr = args.mask.data.as_raw().0 as *const c_void;
549        let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
550
551        let p = self.desc.p;
552        let scale = if p < 1.0 { 1.0_f64 / (1.0 - p as f64) } else { 0.0_f64 };
553        let status = unsafe {
554            baracuda_kernels_sys::baracuda_kernels_dropout_backward_f64_run(
555                numel,
556                scale,
557                dy_ptr,
558                mask_ptr,
559                dx_ptr,
560                core::ptr::null_mut(),
561                0,
562                stream_ptr,
563            )
564        };
565        map_status(status)
566    }
567}
568
569fn map_status(code: i32) -> Result<()> {
570    match code {
571        0 => Ok(()),
572        1 => Err(Error::MisalignedOperand),
573        2 => Err(Error::InvalidProblem(
574            "baracuda-kernels-sys reported invalid problem",
575        )),
576        3 => Err(Error::Unsupported(
577            "baracuda-kernels-sys reported unsupported configuration",
578        )),
579        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
580        n => Err(Error::CutlassInternal(n)),
581    }
582}