1use core::cell::Cell;
10use core::ffi::c_void;
11use core::marker::PhantomData;
12
13use baracuda_cutlass::{Error, Result};
14use baracuda_driver::Stream;
15use baracuda_kernels_sys::{
16 curandCreateGenerator, curandDestroyGenerator, curandGenerateNormal,
17 curandGenerateNormalDouble, curandGenerateUniform, curandGenerateUniformDouble,
18 curandGenerator_t, curandSetPseudoRandomGeneratorSeed, curandSetStream,
19};
20use baracuda_kernels_types::{
21 ArchSku, BackendKind, Bool, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22 PlanPreference, PrecisionGuarantee, RandomKind, TensorMut, Workspace,
23};
24
25#[derive(Copy, Clone, Debug)]
27pub struct RandomDescriptor<const N: usize> {
28 pub kind: RandomKind,
30 pub shape: [i32; N],
32 pub element: ElementKind,
37 pub param1: f32,
39 pub param2: f32,
41 pub seed: u64,
45}
46
47pub struct RandomArgs<'a, T: Element, const N: usize> {
49 pub y: TensorMut<'a, T, N>,
51}
52
53pub struct RandomBoolArgs<'a, const N: usize> {
60 pub y: TensorMut<'a, Bool, N>,
62}
63
64pub struct RandomPlan<T: Element, const N: usize> {
76 desc: RandomDescriptor<N>,
77 sku: KernelSku,
78 generator: Cell<curandGenerator_t>,
81 _marker: PhantomData<T>,
82}
83
84impl<T: Element, const N: usize> RandomPlan<T, N> {
85 pub fn select(
87 _stream: &Stream,
88 desc: &RandomDescriptor<N>,
89 _pref: PlanPreference,
90 ) -> Result<Self> {
91 if desc.element != T::KIND {
92 return Err(Error::Unsupported(
93 "baracuda-kernels::RandomPlan: descriptor.element != T::KIND",
94 ));
95 }
96 for &d in desc.shape.iter() {
97 if d < 0 {
98 return Err(Error::InvalidProblem(
99 "baracuda-kernels::RandomPlan: shape dims must be non-negative",
100 ));
101 }
102 }
103 if N > 8 {
104 return Err(Error::Unsupported(
105 "baracuda-kernels::RandomPlan: tensor rank > 8 not supported",
106 ));
107 }
108
109 let supported = matches!(
113 (desc.kind, T::KIND),
114 (RandomKind::Uniform, ElementKind::F32)
115 | (RandomKind::Uniform, ElementKind::F64)
116 | (RandomKind::Normal, ElementKind::F32)
117 | (RandomKind::Normal, ElementKind::F64)
118 | (RandomKind::Bernoulli, ElementKind::Bool)
119 );
120 if !supported {
121 return Err(Error::Unsupported(
122 "baracuda-kernels::RandomPlan: wired today: \
123 `{Uniform, Normal} × {f32, f64}` and `Bernoulli × Bool`",
124 ));
125 }
126
127 if matches!(desc.kind, RandomKind::Bernoulli) {
129 let p = desc.param1;
130 if !(p >= 0.0 && p <= 1.0) {
131 return Err(Error::InvalidProblem(
132 "baracuda-kernels::RandomPlan(Bernoulli): p must be in [0, 1]",
133 ));
134 }
135 }
136 if matches!(desc.kind, RandomKind::Normal) && !(desc.param2 > 0.0) {
138 return Err(Error::InvalidProblem(
139 "baracuda-kernels::RandomPlan(Normal): stddev (param2) must be > 0",
140 ));
141 }
142
143 let backend = match desc.kind {
144 RandomKind::Uniform | RandomKind::Normal => BackendKind::Curand,
145 RandomKind::Bernoulli => BackendKind::Bespoke,
149 _ => BackendKind::Bespoke,
152 };
153 let math_precision = match T::KIND {
154 ElementKind::F64 => MathPrecision::F64,
155 _ => MathPrecision::F32,
156 };
157 let precision_guarantee = PrecisionGuarantee {
158 math_precision,
159 accumulator: T::KIND,
160 bit_stable_on_same_hardware: true,
164 deterministic: true,
165 };
166 let sku = KernelSku {
167 category: OpCategory::Random,
168 op: desc.kind as u16,
169 element: T::KIND,
170 aux_element: None,
171 layout: None,
172 epilogue: None,
173 arch: ArchSku::Sm80,
174 backend,
175 precision_guarantee,
176 };
177
178 Ok(Self {
179 desc: *desc,
180 sku,
181 generator: Cell::new(core::ptr::null_mut()),
182 _marker: PhantomData,
183 })
184 }
185
186 #[inline]
192 pub fn workspace_size(&self) -> usize {
193 if matches!(self.desc.kind, RandomKind::Bernoulli) {
194 let numel: i64 = self.desc.shape.iter().map(|&d| d as i64).product();
195 (numel.max(0) as usize) * core::mem::size_of::<f32>()
196 } else {
197 0
198 }
199 }
200
201 #[inline]
203 pub fn sku(&self) -> KernelSku {
204 self.sku
205 }
206
207 #[inline]
209 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
210 self.sku.precision_guarantee
211 }
212
213 fn ensure_generator(&self) -> Result<curandGenerator_t> {
215 let g = self.generator.get();
216 if !g.is_null() {
217 return Ok(g);
218 }
219 let mut handle: curandGenerator_t = core::ptr::null_mut();
220 let status =
222 unsafe { curandCreateGenerator(&mut handle as *mut _, 100) };
223 if status != 0 {
224 return Err(Error::CutlassInternal(curand_to_status(status)));
225 }
226 let status = unsafe { curandSetPseudoRandomGeneratorSeed(handle, self.desc.seed) };
227 if status != 0 {
228 unsafe {
229 let _ = curandDestroyGenerator(handle);
230 }
231 return Err(Error::CutlassInternal(curand_to_status(status)));
232 }
233 self.generator.set(handle);
234 Ok(handle)
235 }
236
237 fn bind_stream(&self, gen_handle: curandGenerator_t, stream: &Stream) -> Result<()> {
241 let stream_ptr = stream.as_raw() as *mut c_void;
242 let status = unsafe { curandSetStream(gen_handle, stream_ptr) };
243 if status != 0 {
244 return Err(Error::CutlassInternal(curand_to_status(status)));
245 }
246 Ok(())
247 }
248
249 fn check_shape<U: baracuda_types::DeviceRepr + Copy + 'static>(
251 &self,
252 y: &TensorMut<'_, U, N>,
253 ) -> Result<i64> {
254 if y.shape != self.desc.shape {
255 return Err(Error::InvalidProblem(
256 "baracuda-kernels::RandomPlan: y shape != descriptor shape",
257 ));
258 }
259 let numel = y.numel();
260 let len = y.data.len() as i64;
261 if len < numel {
262 return Err(Error::BufferTooSmall {
263 needed: numel as usize,
264 got: len as usize,
265 });
266 }
267 Ok(numel)
268 }
269}
270
271impl<const N: usize> RandomPlan<f32, N> {
281 pub fn run(
283 &self,
284 stream: &Stream,
285 _workspace: Workspace<'_>,
286 args: RandomArgs<'_, f32, N>,
287 ) -> Result<()> {
288 let numel = self.check_shape(&args.y)?;
289 if numel == 0 {
290 return Ok(());
291 }
292 let gen_handle = self.ensure_generator()?;
293 self.bind_stream(gen_handle, stream)?;
294 let ptr = args.y.data.as_raw().0 as *mut f32;
295 let n = numel as usize;
296
297 match self.desc.kind {
298 RandomKind::Uniform => {
299 let status = unsafe { curandGenerateUniform(gen_handle, ptr, n) };
305 if status != 0 {
306 return Err(Error::CutlassInternal(curand_to_status(status)));
307 }
308 let low = self.desc.param1;
309 let high = self.desc.param2;
310 if (low, high) != (0.0, 1.0) {
311 affine_transform_f32(stream, ptr, n, high - low, low)?;
312 }
313 Ok(())
314 }
315 RandomKind::Normal => {
316 let mean = self.desc.param1;
317 let stddev = self.desc.param2;
318 let status = unsafe { curandGenerateNormal(gen_handle, ptr, n, mean, stddev) };
328 if status != 0 {
329 return Err(Error::CutlassInternal(curand_to_status(status)));
330 }
331 Ok(())
332 }
333 RandomKind::Bernoulli => Err(Error::Unsupported(
334 "baracuda-kernels::RandomPlan<f32>: Bernoulli has Bool output — use RandomPlan<Bool>",
335 )),
336 _ => Err(Error::Unsupported(
338 "baracuda-kernels::RandomPlan<f32>::run reached an unimplemented RandomKind variant",
339 )),
340 }
341 }
342}
343
344impl<const N: usize> RandomPlan<f64, N> {
345 pub fn run(
347 &self,
348 stream: &Stream,
349 _workspace: Workspace<'_>,
350 args: RandomArgs<'_, f64, N>,
351 ) -> Result<()> {
352 let numel = self.check_shape(&args.y)?;
353 if numel == 0 {
354 return Ok(());
355 }
356 let gen_handle = self.ensure_generator()?;
357 self.bind_stream(gen_handle, stream)?;
358 let ptr = args.y.data.as_raw().0 as *mut f64;
359 let n = numel as usize;
360
361 match self.desc.kind {
362 RandomKind::Uniform => {
363 let status = unsafe { curandGenerateUniformDouble(gen_handle, ptr, n) };
364 if status != 0 {
365 return Err(Error::CutlassInternal(curand_to_status(status)));
366 }
367 let low = self.desc.param1 as f64;
368 let high = self.desc.param2 as f64;
369 if (low, high) != (0.0, 1.0) {
370 affine_transform_f64(stream, ptr, n, high - low, low)?;
371 }
372 Ok(())
373 }
374 RandomKind::Normal => {
375 let mean = self.desc.param1 as f64;
376 let stddev = self.desc.param2 as f64;
377 let status = unsafe { curandGenerateNormalDouble(gen_handle, ptr, n, mean, stddev) };
378 if status != 0 {
379 return Err(Error::CutlassInternal(curand_to_status(status)));
380 }
381 Ok(())
382 }
383 RandomKind::Bernoulli => Err(Error::Unsupported(
384 "baracuda-kernels::RandomPlan<f64>: Bernoulli has Bool output — use RandomPlan<Bool>",
385 )),
386 _ => Err(Error::Unsupported(
388 "baracuda-kernels::RandomPlan<f64>::run reached an unimplemented RandomKind variant",
389 )),
390 }
391 }
392}
393
394impl<const N: usize> RandomPlan<Bool, N> {
399 pub fn run(
401 &self,
402 stream: &Stream,
403 workspace: Workspace<'_>,
404 args: RandomBoolArgs<'_, N>,
405 ) -> Result<()> {
406 if !matches!(self.desc.kind, RandomKind::Bernoulli) {
407 return Err(Error::Unsupported(
408 "baracuda-kernels::RandomPlan<Bool>: only Bernoulli is wired \
409 (Uniform / Normal use the FP variants)",
410 ));
411 }
412 let numel = self.check_shape(&args.y)?;
413 if numel == 0 {
414 return Ok(());
415 }
416 let needed = self.workspace_size();
417 let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
418 Workspace::None => {
419 return Err(Error::WorkspaceTooSmall {
420 needed,
421 got: 0,
422 })
423 }
424 Workspace::Borrowed(slice) => {
425 if slice.len() < needed {
426 return Err(Error::WorkspaceTooSmall {
427 needed,
428 got: slice.len(),
429 });
430 }
431 (slice.as_raw().0 as *mut c_void, slice.len())
432 }
433 };
434
435 let gen_handle = self.ensure_generator()?;
436 self.bind_stream(gen_handle, stream)?;
437
438 let rand_ptr = ws_ptr as *mut f32;
439 let n = numel as usize;
440 let status = unsafe { curandGenerateUniform(gen_handle, rand_ptr, n) };
441 if status != 0 {
442 return Err(Error::CutlassInternal(curand_to_status(status)));
443 }
444
445 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
446 let stream_ptr = stream.as_raw() as *mut c_void;
447 let status = unsafe {
448 baracuda_kernels_sys::baracuda_kernels_bernoulli_run(
449 numel,
450 self.desc.param1,
451 rand_ptr as *const c_void,
452 y_ptr,
453 core::ptr::null_mut(),
454 ws_bytes, stream_ptr,
456 )
457 };
458 map_status(status)
459 }
460}
461
462impl<T: Element, const N: usize> Drop for RandomPlan<T, N> {
463 fn drop(&mut self) {
464 let g = self.generator.get();
465 if !g.is_null() {
466 unsafe {
470 let _ = curandDestroyGenerator(g);
471 }
472 self.generator.set(core::ptr::null_mut());
473 }
474 }
475}
476
/// Maps a cuRAND status code into the integer carried by
/// `Error::CutlassInternal`: success (`0`) stays `0`, every other code is
/// negated.
fn curand_to_status(curand_code: i32) -> i32 {
    match curand_code {
        0 => 0,
        failure => -failure,
    }
}
488
489fn map_status(code: i32) -> Result<()> {
490 match code {
491 0 => Ok(()),
492 1 => Err(Error::MisalignedOperand),
493 2 => Err(Error::InvalidProblem(
494 "baracuda-kernels-sys reported invalid problem",
495 )),
496 3 => Err(Error::Unsupported(
497 "baracuda-kernels-sys reported unsupported configuration",
498 )),
499 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
500 n => Err(Error::CutlassInternal(n)),
501 }
502}
503
504fn affine_transform_f32(
514 stream: &Stream,
515 ptr: *mut f32,
516 n: usize,
517 scale: f32,
518 offset: f32,
519) -> Result<()> {
520 let stream_ptr = stream.as_raw() as *mut c_void;
521 let status = unsafe {
522 baracuda_kernels_sys::baracuda_kernels_affine_inplace_f32_run(
523 n as i64,
524 scale,
525 offset,
526 ptr as *mut c_void,
527 core::ptr::null_mut(),
528 0,
529 stream_ptr,
530 )
531 };
532 map_status(status)
533}
534
535fn affine_transform_f64(
536 stream: &Stream,
537 ptr: *mut f64,
538 n: usize,
539 scale: f64,
540 offset: f64,
541) -> Result<()> {
542 let stream_ptr = stream.as_raw() as *mut c_void;
543 let status = unsafe {
544 baracuda_kernels_sys::baracuda_kernels_affine_inplace_f64_run(
545 n as i64,
546 scale,
547 offset,
548 ptr as *mut c_void,
549 core::ptr::null_mut(),
550 0,
551 stream_ptr,
552 )
553 };
554 map_status(status)
555}