Skip to main content

baracuda_kernels/random/
plan.rs

1//! Random-generator plan — Uniform / Normal / Bernoulli.
2//!
3//! Three sample-generation ops with no input tensor (the descriptor's
4//! `shape` defines the output extent; `seed` and `param1` / `param2`
5//! drive the distribution). Uniform and Normal route directly to cuRAND;
6//! Bernoulli generates a `float` uniform-rand buffer through cuRAND and
7//! then runs the bespoke threshold kernel to produce a Bool output.
8
9use core::cell::Cell;
10use core::ffi::c_void;
11use core::marker::PhantomData;
12
13use baracuda_cutlass::{Error, Result};
14use baracuda_driver::Stream;
15use baracuda_kernels_sys::{
16    curandCreateGenerator, curandDestroyGenerator, curandGenerateNormal,
17    curandGenerateNormalDouble, curandGenerateUniform, curandGenerateUniformDouble,
18    curandGenerator_t, curandSetPseudoRandomGeneratorSeed, curandSetStream,
19};
20use baracuda_kernels_types::{
21    ArchSku, BackendKind, Bool, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22    PlanPreference, PrecisionGuarantee, RandomKind, TensorMut, Workspace,
23};
24
/// Descriptor for a random-generator op.
///
/// Plain `Copy` data describing one sampling problem: the distribution,
/// the output extent, the output element kind, two distribution
/// parameters and the RNG seed. `RandomPlan::select` validates it.
#[derive(Copy, Clone, Debug)]
pub struct RandomDescriptor<const N: usize> {
    /// Which distribution.
    pub kind: RandomKind,
    /// Output tensor shape. `RandomPlan::select` rejects negative
    /// dimensions; a zero-element output makes `run` return early
    /// without generating anything.
    pub shape: [i32; N],
    /// Output element type. For [`RandomKind::Uniform`] / [`RandomKind::Normal`]
    /// this is the produced FP type (f32 / f64). For
    /// [`RandomKind::Bernoulli`] this is `ElementKind::Bool` —
    /// the plan's type parameter `T` is `Bool`. `select` rejects a
    /// descriptor whose `element` differs from the plan's `T::KIND`.
    pub element: ElementKind,
    /// For `Uniform`: low. For `Normal`: mean. For `Bernoulli`: probability `p`
    /// (validated to lie in `[0, 1]`).
    pub param1: f32,
    /// For `Uniform`: high. For `Normal`: stddev (validated to be `> 0`).
    /// Ignored for `Bernoulli`.
    pub param2: f32,
    /// Deterministic seed, applied when the plan lazily creates its
    /// cuRAND generator on the first `run`.
    ///
    /// NOTE(review): the generator is seeded once and then reused, so
    /// repeated `run` calls on the *same* plan presumably continue the
    /// random stream rather than replaying it; a freshly selected plan
    /// with the same descriptor is what restarts the sequence — confirm
    /// that this is the intended reproducibility contract.
    pub seed: u64,
}
46
/// Args bundle for Uniform / Normal (T = f32 | f64).
///
/// Passed by value to `RandomPlan::<f32, N>::run` /
/// `RandomPlan::<f64, N>::run`.
pub struct RandomArgs<'a, T: Element, const N: usize> {
    /// Output tensor — written by cuRAND directly. Must be contiguous.
    /// Its shape must equal the descriptor's shape and its backing
    /// buffer must hold at least `numel` elements; `run` checks both.
    pub y: TensorMut<'a, T, N>,
}
52
/// Args bundle for Bernoulli (output is `Bool`).
///
/// Bernoulli runs cuRAND uniform into the caller-provided workspace and
/// then writes Bool output cells via the bespoke threshold kernel. The
/// workspace must be at least `numel * sizeof(f32)` bytes (see
/// [`RandomPlan::workspace_size`]); `run` returns
/// `Error::WorkspaceTooSmall` when it is absent or shorter than that.
pub struct RandomBoolArgs<'a, const N: usize> {
    /// Output tensor — packed Bool, one byte per cell. Its shape must
    /// equal the descriptor's shape.
    pub y: TensorMut<'a, Bool, N>,
}
63
/// Random-generator plan.
///
/// Generic on `T` so the same type can carry the FP generators
/// (`RandomPlan<f32, N>`, `RandomPlan<f64, N>`) and the Bernoulli
/// generator (`RandomPlan<Bool, N>`). The element kind is reasserted in
/// `select()` against the descriptor.
///
/// The plan owns a single cuRAND generator handle, created lazily on the
/// first call to one of the per-dtype `run` methods and destroyed in
/// `Drop`. cuRAND generators are not thread-safe; the plan is `!Sync`
/// and `!Send` as a consequence: the `Cell` wrapper rules out `Sync`,
/// and the raw-pointer handle stored inside it rules out `Send`.
pub struct RandomPlan<T: Element, const N: usize> {
    // Validated copy of the descriptor the plan was selected for.
    desc: RandomDescriptor<N>,
    // Kernel identity + precision guarantee computed in `select`.
    sku: KernelSku,
    // Lazy cuRAND handle. `null` means "not yet created"; the first
    // `run` call constructs + seeds it.
    generator: Cell<curandGenerator_t>,
    // Ties the plan to its element type without storing a `T`.
    _marker: PhantomData<T>,
}
83
84impl<T: Element, const N: usize> RandomPlan<T, N> {
85    /// Pick a kernel + validate the descriptor.
86    pub fn select(
87        _stream: &Stream,
88        desc: &RandomDescriptor<N>,
89        _pref: PlanPreference,
90    ) -> Result<Self> {
91        if desc.element != T::KIND {
92            return Err(Error::Unsupported(
93                "baracuda-kernels::RandomPlan: descriptor.element != T::KIND",
94            ));
95        }
96        for &d in desc.shape.iter() {
97            if d < 0 {
98                return Err(Error::InvalidProblem(
99                    "baracuda-kernels::RandomPlan: shape dims must be non-negative",
100                ));
101            }
102        }
103        if N > 8 {
104            return Err(Error::Unsupported(
105                "baracuda-kernels::RandomPlan: tensor rank > 8 not supported",
106            ));
107        }
108
109        // Wired (kind, dtype) matrix:
110        //   Uniform / Normal: f32 + f64.
111        //   Bernoulli:        Bool.
112        let supported = matches!(
113            (desc.kind, T::KIND),
114            (RandomKind::Uniform, ElementKind::F32)
115                | (RandomKind::Uniform, ElementKind::F64)
116                | (RandomKind::Normal, ElementKind::F32)
117                | (RandomKind::Normal, ElementKind::F64)
118                | (RandomKind::Bernoulli, ElementKind::Bool)
119        );
120        if !supported {
121            return Err(Error::Unsupported(
122                "baracuda-kernels::RandomPlan: wired today: \
123                 `{Uniform, Normal} × {f32, f64}` and `Bernoulli × Bool`",
124            ));
125        }
126
127        // Bernoulli wants p in [0, 1].
128        if matches!(desc.kind, RandomKind::Bernoulli) {
129            let p = desc.param1;
130            if !(p >= 0.0 && p <= 1.0) {
131                return Err(Error::InvalidProblem(
132                    "baracuda-kernels::RandomPlan(Bernoulli): p must be in [0, 1]",
133                ));
134            }
135        }
136        // Normal wants stddev > 0.
137        if matches!(desc.kind, RandomKind::Normal) && !(desc.param2 > 0.0) {
138            return Err(Error::InvalidProblem(
139                "baracuda-kernels::RandomPlan(Normal): stddev (param2) must be > 0",
140            ));
141        }
142
143        let backend = match desc.kind {
144            RandomKind::Uniform | RandomKind::Normal => BackendKind::Curand,
145            // Bernoulli is a cuRAND-uniform + custom-threshold composite;
146            // labeled `Bespoke` because the visible output is from the
147            // hand-rolled kernel.
148            RandomKind::Bernoulli => BackendKind::Bespoke,
149            // Defensive — `RandomKind` is `#[non_exhaustive]`. Treat
150            // unknown variants as bespoke kernels until they're wired.
151            _ => BackendKind::Bespoke,
152        };
153        let math_precision = match T::KIND {
154            ElementKind::F64 => MathPrecision::F64,
155            _ => MathPrecision::F32,
156        };
157        let precision_guarantee = PrecisionGuarantee {
158            math_precision,
159            accumulator: T::KIND,
160            // cuRAND's XORWOW generator is bit-stable across runs with
161            // the same seed on the same hardware, and per-cell
162            // independent — no reduction order to worry about.
163            bit_stable_on_same_hardware: true,
164            deterministic: true,
165        };
166        let sku = KernelSku {
167            category: OpCategory::Random,
168            op: desc.kind as u16,
169            element: T::KIND,
170            aux_element: None,
171            layout: None,
172            epilogue: None,
173            arch: ArchSku::Sm80,
174            backend,
175            precision_guarantee,
176        };
177
178        Ok(Self {
179            desc: *desc,
180            sku,
181            generator: Cell::new(core::ptr::null_mut()),
182            _marker: PhantomData,
183        })
184    }
185
186    /// Workspace size in bytes.
187    ///
188    /// Bernoulli needs `numel * sizeof(f32)` bytes for the uniform-rand
189    /// intermediate buffer cuRAND writes into. Uniform / Normal need
190    /// zero — cuRAND writes directly to the caller-provided output.
191    #[inline]
192    pub fn workspace_size(&self) -> usize {
193        if matches!(self.desc.kind, RandomKind::Bernoulli) {
194            let numel: i64 = self.desc.shape.iter().map(|&d| d as i64).product();
195            (numel.max(0) as usize) * core::mem::size_of::<f32>()
196        } else {
197            0
198        }
199    }
200
201    /// Kernel SKU identity.
202    #[inline]
203    pub fn sku(&self) -> KernelSku {
204        self.sku
205    }
206
207    /// Numerical guarantees.
208    #[inline]
209    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
210        self.sku.precision_guarantee
211    }
212
213    /// Internal — lazily create + seed the cuRAND generator. Idempotent.
214    fn ensure_generator(&self) -> Result<curandGenerator_t> {
215        let g = self.generator.get();
216        if !g.is_null() {
217            return Ok(g);
218        }
219        let mut handle: curandGenerator_t = core::ptr::null_mut();
220        // CURAND_RNG_PSEUDO_DEFAULT == 100 (XORWOW).
221        let status =
222            unsafe { curandCreateGenerator(&mut handle as *mut _, 100) };
223        if status != 0 {
224            return Err(Error::CutlassInternal(curand_to_status(status)));
225        }
226        let status = unsafe { curandSetPseudoRandomGeneratorSeed(handle, self.desc.seed) };
227        if status != 0 {
228            unsafe {
229                let _ = curandDestroyGenerator(handle);
230            }
231            return Err(Error::CutlassInternal(curand_to_status(status)));
232        }
233        self.generator.set(handle);
234        Ok(handle)
235    }
236
237    /// Bind the cuRAND generator to the caller's stream. cuRAND
238    /// associates each generator with at most one stream at a time;
239    /// rebinding on every run lets the plan be reused across streams.
240    fn bind_stream(&self, gen_handle: curandGenerator_t, stream: &Stream) -> Result<()> {
241        let stream_ptr = stream.as_raw() as *mut c_void;
242        let status = unsafe { curandSetStream(gen_handle, stream_ptr) };
243        if status != 0 {
244            return Err(Error::CutlassInternal(curand_to_status(status)));
245        }
246        Ok(())
247    }
248
249    /// Internal — common output-shape validation.
250    fn check_shape<U: baracuda_types::DeviceRepr + Copy + 'static>(
251        &self,
252        y: &TensorMut<'_, U, N>,
253    ) -> Result<i64> {
254        if y.shape != self.desc.shape {
255            return Err(Error::InvalidProblem(
256                "baracuda-kernels::RandomPlan: y shape != descriptor shape",
257            ));
258        }
259        let numel = y.numel();
260        let len = y.data.len() as i64;
261        if len < numel {
262            return Err(Error::BufferTooSmall {
263                needed: numel as usize,
264                got: len as usize,
265            });
266        }
267        Ok(numel)
268    }
269}
270
271// =============================================================================
272// Uniform / Normal — generic over T : Element, dispatched per dtype.
273// =============================================================================
274//
275// The generic `run` is split into per-dtype impls because cuRAND's
276// API has separate f32 / f64 entry points (`curandGenerateUniform`
277// vs `curandGenerateUniformDouble`, etc.). We pick the right one at
278// compile time per `impl` block.
279
280impl<const N: usize> RandomPlan<f32, N> {
281    /// Generate `Uniform(low, high)` or `Normal(mean, stddev)` `f32` samples.
282    pub fn run(
283        &self,
284        stream: &Stream,
285        _workspace: Workspace<'_>,
286        args: RandomArgs<'_, f32, N>,
287    ) -> Result<()> {
288        let numel = self.check_shape(&args.y)?;
289        if numel == 0 {
290            return Ok(());
291        }
292        let gen_handle = self.ensure_generator()?;
293        self.bind_stream(gen_handle, stream)?;
294        let ptr = args.y.data.as_raw().0 as *mut f32;
295        let n = numel as usize;
296
297        match self.desc.kind {
298            RandomKind::Uniform => {
299                // cuRAND produces samples in (0, 1]. Map into (low, high]
300                // by an in-place affine transform — fused with the
301                // generator call would be nicer but cuRAND doesn't
302                // expose a fused path, so we sweep with a tiny kernel
303                // instead.
304                let status = unsafe { curandGenerateUniform(gen_handle, ptr, n) };
305                if status != 0 {
306                    return Err(Error::CutlassInternal(curand_to_status(status)));
307                }
308                let low = self.desc.param1;
309                let high = self.desc.param2;
310                if (low, high) != (0.0, 1.0) {
311                    affine_transform_f32(stream, ptr, n, high - low, low)?;
312                }
313                Ok(())
314            }
315            RandomKind::Normal => {
316                let mean = self.desc.param1;
317                let stddev = self.desc.param2;
318                // cuRAND requires `n` to be even (Box-Muller pairs).
319                // For odd `n`, generate `n + 1` into a tail-padded
320                // workspace and copy the first `n` — but our typical
321                // call is well above this corner case (the smoke tests
322                // use 1024 * 1024). For now, fall back to a single
323                // extra-cell over-generation when n is odd: cuRAND only
324                // documents even-n requirement on older versions; modern
325                // cuRAND (12.x) accepts any n. The status code will
326                // surface if it doesn't.
327                let status = unsafe { curandGenerateNormal(gen_handle, ptr, n, mean, stddev) };
328                if status != 0 {
329                    return Err(Error::CutlassInternal(curand_to_status(status)));
330                }
331                Ok(())
332            }
333            RandomKind::Bernoulli => Err(Error::Unsupported(
334                "baracuda-kernels::RandomPlan<f32>: Bernoulli has Bool output — use RandomPlan<Bool>",
335            )),
336            // Defensive arm — `RandomKind` is `#[non_exhaustive]`.
337            _ => Err(Error::Unsupported(
338                "baracuda-kernels::RandomPlan<f32>::run reached an unimplemented RandomKind variant",
339            )),
340        }
341    }
342}
343
344impl<const N: usize> RandomPlan<f64, N> {
345    /// Generate `Uniform(low, high)` or `Normal(mean, stddev)` `f64` samples.
346    pub fn run(
347        &self,
348        stream: &Stream,
349        _workspace: Workspace<'_>,
350        args: RandomArgs<'_, f64, N>,
351    ) -> Result<()> {
352        let numel = self.check_shape(&args.y)?;
353        if numel == 0 {
354            return Ok(());
355        }
356        let gen_handle = self.ensure_generator()?;
357        self.bind_stream(gen_handle, stream)?;
358        let ptr = args.y.data.as_raw().0 as *mut f64;
359        let n = numel as usize;
360
361        match self.desc.kind {
362            RandomKind::Uniform => {
363                let status = unsafe { curandGenerateUniformDouble(gen_handle, ptr, n) };
364                if status != 0 {
365                    return Err(Error::CutlassInternal(curand_to_status(status)));
366                }
367                let low = self.desc.param1 as f64;
368                let high = self.desc.param2 as f64;
369                if (low, high) != (0.0, 1.0) {
370                    affine_transform_f64(stream, ptr, n, high - low, low)?;
371                }
372                Ok(())
373            }
374            RandomKind::Normal => {
375                let mean = self.desc.param1 as f64;
376                let stddev = self.desc.param2 as f64;
377                let status = unsafe { curandGenerateNormalDouble(gen_handle, ptr, n, mean, stddev) };
378                if status != 0 {
379                    return Err(Error::CutlassInternal(curand_to_status(status)));
380                }
381                Ok(())
382            }
383            RandomKind::Bernoulli => Err(Error::Unsupported(
384                "baracuda-kernels::RandomPlan<f64>: Bernoulli has Bool output — use RandomPlan<Bool>",
385            )),
386            // Defensive arm — `RandomKind` is `#[non_exhaustive]`.
387            _ => Err(Error::Unsupported(
388                "baracuda-kernels::RandomPlan<f64>::run reached an unimplemented RandomKind variant",
389            )),
390        }
391    }
392}
393
394// =============================================================================
395// Bernoulli — Bool output via cuRAND uniform + threshold kernel.
396// =============================================================================
397
398impl<const N: usize> RandomPlan<Bool, N> {
399    /// Generate a Bernoulli(p) sample tensor.
400    pub fn run(
401        &self,
402        stream: &Stream,
403        workspace: Workspace<'_>,
404        args: RandomBoolArgs<'_, N>,
405    ) -> Result<()> {
406        if !matches!(self.desc.kind, RandomKind::Bernoulli) {
407            return Err(Error::Unsupported(
408                "baracuda-kernels::RandomPlan<Bool>: only Bernoulli is wired \
409                 (Uniform / Normal use the FP variants)",
410            ));
411        }
412        let numel = self.check_shape(&args.y)?;
413        if numel == 0 {
414            return Ok(());
415        }
416        let needed = self.workspace_size();
417        let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
418            Workspace::None => {
419                return Err(Error::WorkspaceTooSmall {
420                    needed,
421                    got: 0,
422                })
423            }
424            Workspace::Borrowed(slice) => {
425                if slice.len() < needed {
426                    return Err(Error::WorkspaceTooSmall {
427                        needed,
428                        got: slice.len(),
429                    });
430                }
431                (slice.as_raw().0 as *mut c_void, slice.len())
432            }
433        };
434
435        let gen_handle = self.ensure_generator()?;
436        self.bind_stream(gen_handle, stream)?;
437
438        let rand_ptr = ws_ptr as *mut f32;
439        let n = numel as usize;
440        let status = unsafe { curandGenerateUniform(gen_handle, rand_ptr, n) };
441        if status != 0 {
442            return Err(Error::CutlassInternal(curand_to_status(status)));
443        }
444
445        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
446        let stream_ptr = stream.as_raw() as *mut c_void;
447        let status = unsafe {
448            baracuda_kernels_sys::baracuda_kernels_bernoulli_run(
449                numel,
450                self.desc.param1,
451                rand_ptr as *const c_void,
452                y_ptr,
453                core::ptr::null_mut(),
454                ws_bytes, // pass for ABI symmetry; the bernoulli kernel ignores it.
455                stream_ptr,
456            )
457        };
458        map_status(status)
459    }
460}
461
462impl<T: Element, const N: usize> Drop for RandomPlan<T, N> {
463    fn drop(&mut self) {
464        let g = self.generator.get();
465        if !g.is_null() {
466            // Best-effort destroy — failure here is non-fatal (the
467            // process keeps the cuRAND state ledger alive until exit,
468            // which is fine).
469            unsafe {
470                let _ = curandDestroyGenerator(g);
471            }
472            self.generator.set(core::ptr::null_mut());
473        }
474    }
475}
476
// Fold a cuRAND status code into the integer carried by
// `Error::CutlassInternal`. Success (0) stays 0; every other code is
// negated. The kernel launchers report small non-negative statuses
// (0..=5), so a negative value makes it obvious that the failure came
// from cuRAND rather than from a launcher.
fn curand_to_status(curand_code: i32) -> i32 {
    match curand_code {
        0 => 0,
        failure => -failure,
    }
}
488
489fn map_status(code: i32) -> Result<()> {
490    match code {
491        0 => Ok(()),
492        1 => Err(Error::MisalignedOperand),
493        2 => Err(Error::InvalidProblem(
494            "baracuda-kernels-sys reported invalid problem",
495        )),
496        3 => Err(Error::Unsupported(
497            "baracuda-kernels-sys reported unsupported configuration",
498        )),
499        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
500        n => Err(Error::CutlassInternal(n)),
501    }
502}
503
// ----------------------------------------------------------------------------
// Affine transform helpers — rescale Uniform(0, 1] samples to
// Uniform(low, high] in place by computing `x = scale · x + offset` for
// every element. Rather than pulling a general elementwise op family into
// the random module, each helper calls a small dedicated in-place launcher
// (one per dtype). If another op family ever needs a fused-affine path,
// these can graduate to a shared kernel.
// ----------------------------------------------------------------------------
512
513fn affine_transform_f32(
514    stream: &Stream,
515    ptr: *mut f32,
516    n: usize,
517    scale: f32,
518    offset: f32,
519) -> Result<()> {
520    let stream_ptr = stream.as_raw() as *mut c_void;
521    let status = unsafe {
522        baracuda_kernels_sys::baracuda_kernels_affine_inplace_f32_run(
523            n as i64,
524            scale,
525            offset,
526            ptr as *mut c_void,
527            core::ptr::null_mut(),
528            0,
529            stream_ptr,
530        )
531    };
532    map_status(status)
533}
534
535fn affine_transform_f64(
536    stream: &Stream,
537    ptr: *mut f64,
538    n: usize,
539    scale: f64,
540    offset: f64,
541) -> Result<()> {
542    let stream_ptr = stream.as_raw() as *mut c_void;
543    let status = unsafe {
544        baracuda_kernels_sys::baracuda_kernels_affine_inplace_f64_run(
545            n as i64,
546            scale,
547            offset,
548            ptr as *mut c_void,
549            core::ptr::null_mut(),
550            0,
551            stream_ptr,
552        )
553    };
554    map_status(status)
555}