1use std::sync::Arc;
8
9use oxicuda_driver::context::Context;
10use oxicuda_driver::module::Module;
11use oxicuda_driver::stream::Stream;
12use oxicuda_launch::grid::grid_size_for;
13use oxicuda_launch::kernel::Kernel;
14use oxicuda_launch::params::LaunchParams;
15use oxicuda_memory::DeviceBuffer;
16use oxicuda_ptx::arch::SmVersion;
17use oxicuda_ptx::builder::KernelBuilder;
18use oxicuda_ptx::error::PtxGenError;
19use oxicuda_ptx::ir::PtxType;
20
21use crate::engines::{mrg32k3a, philox, philox_optimized, xorwow};
22use crate::error::{RandError, RandResult};
23
24const LOG_NORMAL_EXP_KERNEL_F32: &str = "log_normal_exp_f32";
25const LOG_NORMAL_EXP_KERNEL_F64: &str = "log_normal_exp_f64";
26const POISSON_POSTPROCESS_KERNEL_F32: &str = "poisson_postprocess_f32";
27
28fn log_normal_exp_kernel_name(precision: PtxType) -> &'static str {
29 match precision {
30 PtxType::F32 => LOG_NORMAL_EXP_KERNEL_F32,
31 PtxType::F64 => LOG_NORMAL_EXP_KERNEL_F64,
32 _ => LOG_NORMAL_EXP_KERNEL_F32,
33 }
34}
35
36fn poisson_postprocess_kernel_name() -> &'static str {
37 POISSON_POSTPROCESS_KERNEL_F32
38}
39
/// Generates PTX for a kernel that exponentiates a device buffer in place:
/// each element `x` becomes `e^x`, turning normal samples into log-normal
/// samples. One thread handles one element.
fn generate_log_normal_exp_ptx(precision: PtxType, sm: SmVersion) -> Result<String, PtxGenError> {
    let kernel_name = log_normal_exp_kernel_name(precision);
    // Element stride in bytes (4 for f32, 8 for f64).
    let stride_bytes = precision.size_bytes() as u32;

    KernelBuilder::new(kernel_name)
        .target(sm)
        .param("out_ptr", PtxType::U64) // device pointer to the buffer
        .param("n", PtxType::U32) // element count
        .max_threads_per_block(256)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let n_reg = b.load_param_u32("n");

            // Bounds guard: only threads with gid < n touch memory.
            b.if_lt_u32(gid.clone(), n_reg, move |b| {
                let out_ptr = b.load_param_u64("out_ptr");
                let addr = b.byte_offset_addr(out_ptr, gid.clone(), stride_bytes);

                match precision {
                    PtxType::F32 => {
                        let normal_val = b.load_global_f32(addr.clone());
                        // e^x = 2^(x * log2(e)); 0f3FB8AA3B is log2(e) ≈ 1.4426950
                        // encoded as raw f32 bits.
                        let log2e = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mov.f32 {log2e}, 0f3FB8AA3B;"));
                        let scaled = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mul.rn.f32 {scaled}, {normal_val}, {log2e};"));
                        // ex2.approx computes 2^x with hardware approximation.
                        let result = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("ex2.approx.f32 {result}, {scaled};"));
                        b.store_global_f32(addr, result);
                    }
                    PtxType::F64 => {
                        // NOTE(review): the f64 path narrows to f32, exponentiates
                        // with the f32 hardware approximation, then widens back —
                        // the result carries only f32 precision. Confirm this
                        // accuracy trade-off is intentional.
                        let normal_val = b.load_global_f64(addr.clone());
                        let narrow = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("cvt.rn.f32.f64 {narrow}, {normal_val};"));

                        // Same e^x = 2^(x * log2(e)) trick as the f32 path.
                        let log2e = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mov.f32 {log2e}, 0f3FB8AA3B;"));
                        let scaled = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mul.rn.f32 {scaled}, {narrow}, {log2e};"));
                        let exp_f32 = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("ex2.approx.f32 {exp_f32}, {scaled};"));

                        // Widen the f32 result back to f64 for storage.
                        let result = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("cvt.f64.f32 {result}, {exp_f32};"));
                        b.store_global_f64(addr, result);
                    }
                    // Unreachable through the public generators (only f32/f64
                    // kernels are requested); emit nothing for other precisions.
                    _ => {}
                }
            });

            b.ret();
        })
        .build()
}
92
/// Generates PTX for the Poisson post-processing kernel: each f32 element is
/// rounded to the nearest integer and clamped to be non-negative, in place.
/// Used after a normal-approximation pass to produce integer-valued counts.
fn generate_poisson_postprocess_f32_ptx(sm: SmVersion) -> Result<String, PtxGenError> {
    let kernel_name = poisson_postprocess_kernel_name();

    KernelBuilder::new(kernel_name)
        .target(sm)
        .param("out_ptr", PtxType::U64) // device pointer to the buffer
        .param("n", PtxType::U32) // element count
        .max_threads_per_block(256)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let n_reg = b.load_param_u32("n");

            // Bounds guard: only threads with gid < n touch memory.
            b.if_lt_u32(gid.clone(), n_reg, move |b| {
                let out_ptr = b.load_param_u64("out_ptr");
                let addr = b.byte_offset_addr(out_ptr, gid, 4); // 4-byte f32 stride
                let value = b.load_global_f32(addr.clone());

                // Round to nearest integer (cvt.rni = round-to-nearest-int).
                let rounded_i32 = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("cvt.rni.s32.f32 {rounded_i32}, {value};"));

                let zero_i32 = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!("mov.s32 {zero_i32}, 0;"));

                // Clamp negatives to zero: counts cannot go below 0.
                let clamped_i32 = b.alloc_reg(PtxType::S32);
                b.raw_ptx(&format!(
                    "max.s32 {clamped_i32}, {rounded_i32}, {zero_i32};"
                ));

                // Store the integer count back in f32 representation.
                let clamped_f32 = b.alloc_reg(PtxType::F32);
                b.raw_ptx(&format!("cvt.rn.f32.s32 {clamped_f32}, {clamped_i32};"));
                b.store_global_f32(addr, clamped_f32);
            });

            b.ret();
        })
        .build()
}
130
131fn validate_poisson_lambda(lambda: f64) -> RandResult<f32> {
132 if !lambda.is_finite() || lambda < 0.0 {
133 return Err(RandError::InvalidParameter(format!(
134 "lambda must be finite and >= 0, got {lambda}"
135 )));
136 }
137 Ok(lambda as f32)
138}
139
/// RNG engine backing a `RngGenerator`; selects which PTX kernels are
/// generated and which launch-argument layout is used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RngEngine {
    /// Philox-4x32-10 counter-based engine.
    Philox,
    /// XORWOW engine.
    Xorwow,
    /// MRG32k3a engine.
    Mrg32k3a,
}
154
155impl std::fmt::Display for RngEngine {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 match self {
158 Self::Philox => write!(f, "Philox-4x32-10"),
159 Self::Xorwow => write!(f, "XORWOW"),
160 Self::Mrg32k3a => write!(f, "MRG32k3a"),
161 }
162 }
163}
164
/// GPU random-number generator that JIT-generates and launches PTX kernels
/// for the selected engine.
pub struct RngGenerator {
    /// Which RNG algorithm backs this generator.
    engine: RngEngine,
    /// 64-bit seed passed (in 32-bit words) to the kernels.
    seed: u64,
    /// Position in the random stream; advanced after each generation call.
    offset: u64,
    /// Kept alive so the stream's context outlives the generator; not read
    /// directly, hence the allow.
    #[allow(dead_code)]
    context: Arc<Context>,
    /// Dedicated stream on which all kernels are launched and synchronized.
    stream: Stream,
    /// Target SM architecture used for PTX generation.
    sm_version: SmVersion,
}
207
208impl RngGenerator {
209 pub fn new(engine: RngEngine, seed: u64, ctx: &Arc<Context>) -> RandResult<Self> {
215 let stream = Stream::new(ctx).map_err(RandError::Cuda)?;
216 Ok(Self {
217 engine,
218 seed,
219 offset: 0,
220 context: Arc::clone(ctx),
221 stream,
222 sm_version: SmVersion::Sm80,
223 })
224 }
225
226 pub fn set_seed(&mut self, seed: u64) {
228 self.seed = seed;
229 }
230
231 pub fn set_offset(&mut self, offset: u64) {
233 self.offset = offset;
234 }
235
236 pub fn skip(&mut self, n: u64) {
238 self.offset = self.offset.wrapping_add(n);
239 }
240
241 pub fn generate_uniform_f32(&mut self, output: &mut DeviceBuffer<f32>) -> RandResult<()> {
247 let n = output.len();
248 let ptx_source = self.get_uniform_ptx(PtxType::F32)?;
249 self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
250 self.offset += n as u64;
251 Ok(())
252 }
253
254 pub fn generate_uniform_f64(&mut self, output: &mut DeviceBuffer<f64>) -> RandResult<()> {
260 let n = output.len();
261 let ptx_source = self.get_uniform_ptx(PtxType::F64)?;
262 self.compile_and_launch_uniform(&ptx_source, PtxType::F64, output.as_device_ptr(), n)?;
263 self.offset += n as u64;
264 Ok(())
265 }
266
267 pub fn generate_uniform_f32_optimized(
277 &mut self,
278 output: &mut DeviceBuffer<f32>,
279 ) -> RandResult<()> {
280 let n = output.len();
281 if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
282 return self.generate_uniform_f32(output);
283 }
284
285 let ptx_source =
286 philox_optimized::generate_philox_optimized_uniform_f32_ptx(self.sm_version)?;
287 self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
288 self.offset += n.div_ceil(4) as u64;
290 Ok(())
291 }
292
293 pub fn generate_normal_f32_optimized(
303 &mut self,
304 output: &mut DeviceBuffer<f32>,
305 mean: f32,
306 stddev: f32,
307 ) -> RandResult<()> {
308 let n = output.len();
309 if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
310 return self.generate_normal_f32(output, mean, stddev);
311 }
312
313 let ptx_source =
314 philox_optimized::generate_philox_optimized_normal_f32_ptx(self.sm_version)?;
315 self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
316 self.offset += n.div_ceil(4) as u64;
317 Ok(())
318 }
319
320 pub fn generate_normal_f32(
326 &mut self,
327 output: &mut DeviceBuffer<f32>,
328 mean: f32,
329 stddev: f32,
330 ) -> RandResult<()> {
331 let n = output.len();
332 let ptx_source = self.get_normal_ptx(PtxType::F32)?;
333 self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
334 self.offset += n as u64;
335 Ok(())
336 }
337
338 pub fn generate_normal_f64(
344 &mut self,
345 output: &mut DeviceBuffer<f64>,
346 mean: f64,
347 stddev: f64,
348 ) -> RandResult<()> {
349 let n = output.len();
350 let ptx_source = self.get_normal_ptx(PtxType::F64)?;
351 self.compile_and_launch_normal_f64(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
352 self.offset += n as u64;
353 Ok(())
354 }
355
356 pub fn generate_log_normal_f32(
364 &mut self,
365 output: &mut DeviceBuffer<f32>,
366 mean: f32,
367 stddev: f32,
368 ) -> RandResult<()> {
369 let n = output.len();
370 self.generate_normal_f32(output, mean, stddev)?;
371 let ptx_source = self.get_log_normal_exp_ptx(PtxType::F32)?;
372 self.compile_and_launch_log_normal_exp(
373 &ptx_source,
374 PtxType::F32,
375 output.as_device_ptr(),
376 n,
377 )?;
378 Ok(())
379 }
380
381 pub fn generate_log_normal_f64(
387 &mut self,
388 output: &mut DeviceBuffer<f64>,
389 mean: f64,
390 stddev: f64,
391 ) -> RandResult<()> {
392 let n = output.len();
393 self.generate_normal_f64(output, mean, stddev)?;
394 let ptx_source = self.get_log_normal_exp_ptx(PtxType::F64)?;
395 self.compile_and_launch_log_normal_exp(
396 &ptx_source,
397 PtxType::F64,
398 output.as_device_ptr(),
399 n,
400 )?;
401 Ok(())
402 }
403
404 pub fn generate_poisson_f32(
413 &mut self,
414 output: &mut DeviceBuffer<f32>,
415 lambda: f64,
416 ) -> RandResult<()> {
417 let lambda_f32 = validate_poisson_lambda(lambda)?;
418 let stddev = lambda.sqrt() as f32;
419 let n = output.len();
420
421 self.generate_normal_f32(output, lambda_f32, stddev)?;
423
424 let ptx_source = self.get_poisson_postprocess_f32_ptx()?;
425 self.compile_and_launch_poisson_postprocess_f32(&ptx_source, output.as_device_ptr(), n)?;
426 Ok(())
427 }
428
429 pub fn generate_u32(&mut self, output: &mut DeviceBuffer<u32>) -> RandResult<()> {
438 let n = output.len();
439 let ptx_source = self.get_u32_ptx()?;
440 let kernel_name = self.u32_kernel_name();
441 self.compile_and_launch_u32(&ptx_source, &kernel_name, output.as_device_ptr(), n)?;
442 self.offset += n as u64;
443 Ok(())
444 }
445
446 fn get_uniform_ptx(&self, precision: PtxType) -> RandResult<String> {
452 let ptx = match self.engine {
453 RngEngine::Philox => philox::generate_philox_uniform_ptx(precision, self.sm_version)?,
454 RngEngine::Xorwow => xorwow::generate_xorwow_uniform_ptx(precision, self.sm_version)?,
455 RngEngine::Mrg32k3a => {
456 mrg32k3a::generate_mrg32k3a_uniform_ptx(precision, self.sm_version)?
457 }
458 };
459 Ok(ptx)
460 }
461
462 fn get_normal_ptx(&self, precision: PtxType) -> RandResult<String> {
464 let ptx = match self.engine {
465 RngEngine::Philox => philox::generate_philox_normal_ptx(precision, self.sm_version)?,
466 RngEngine::Xorwow => xorwow::generate_xorwow_normal_ptx(precision, self.sm_version)?,
467 RngEngine::Mrg32k3a => {
468 mrg32k3a::generate_mrg32k3a_normal_ptx(precision, self.sm_version)?
469 }
470 };
471 Ok(ptx)
472 }
473
474 fn get_u32_ptx(&self) -> RandResult<String> {
476 let ptx = match self.engine {
477 RngEngine::Philox => philox::generate_philox_u32_ptx(self.sm_version)?,
478 RngEngine::Mrg32k3a => mrg32k3a::generate_mrg32k3a_u32_ptx(self.sm_version)?,
479 RngEngine::Xorwow => {
480 return Err(RandError::UnsupportedDistribution(
481 "u32 output is not supported for XORWOW engine".to_string(),
482 ));
483 }
484 };
485 Ok(ptx)
486 }
487
488 fn get_log_normal_exp_ptx(&self, precision: PtxType) -> RandResult<String> {
490 generate_log_normal_exp_ptx(precision, self.sm_version).map_err(RandError::from)
491 }
492
493 fn get_poisson_postprocess_f32_ptx(&self) -> RandResult<String> {
495 generate_poisson_postprocess_f32_ptx(self.sm_version).map_err(RandError::from)
496 }
497
498 fn uniform_kernel_name(&self, precision: PtxType) -> String {
500 let prec_str = match precision {
501 PtxType::F32 => "f32",
502 PtxType::F64 => "f64",
503 _ => "f32",
504 };
505 match self.engine {
506 RngEngine::Philox => format!("philox_uniform_{prec_str}"),
507 RngEngine::Xorwow => format!("xorwow_uniform_{prec_str}"),
508 RngEngine::Mrg32k3a => format!("mrg32k3a_uniform_{prec_str}"),
509 }
510 }
511
512 fn normal_kernel_name(&self, precision: PtxType) -> String {
514 let prec_str = match precision {
515 PtxType::F32 => "f32",
516 PtxType::F64 => "f64",
517 _ => "f32",
518 };
519 match self.engine {
520 RngEngine::Philox => format!("philox_normal_{prec_str}"),
521 RngEngine::Xorwow => format!("xorwow_normal_{prec_str}"),
522 RngEngine::Mrg32k3a => format!("mrg32k3a_normal_{prec_str}"),
523 }
524 }
525
526 fn u32_kernel_name(&self) -> String {
528 match self.engine {
529 RngEngine::Philox => "philox_u32".to_string(),
530 RngEngine::Mrg32k3a => "mrg32k3a_u32".to_string(),
531 RngEngine::Xorwow => "xorwow_u32".to_string(), }
533 }
534
535 fn compile_and_launch_uniform(
541 &self,
542 ptx_source: &str,
543 precision: PtxType,
544 out_ptr: u64,
545 n: usize,
546 ) -> RandResult<()> {
547 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
548 let kernel_name = self.uniform_kernel_name(precision);
549 let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
550
551 let n_u32 = u32::try_from(n)
552 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
553 let grid = grid_size_for(n_u32, 256);
554 let params = LaunchParams::new(grid, 256u32);
555
556 let seed_lo = self.seed as u32;
557 let seed_hi = (self.seed >> 32) as u32;
558 let offset_lo = self.offset as u32;
559 let offset_hi = (self.offset >> 32) as u32;
560
561 match self.engine {
564 RngEngine::Philox => {
565 let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
566 kernel
567 .launch(¶ms, &self.stream, &args)
568 .map_err(RandError::Cuda)?;
569 }
570 RngEngine::Xorwow | RngEngine::Mrg32k3a => {
571 let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
572 kernel
573 .launch(¶ms, &self.stream, &args)
574 .map_err(RandError::Cuda)?;
575 }
576 }
577
578 self.stream.synchronize().map_err(RandError::Cuda)?;
579 Ok(())
580 }
581
582 fn compile_and_launch_normal_f32(
584 &self,
585 ptx_source: &str,
586 out_ptr: u64,
587 n: usize,
588 mean: f32,
589 stddev: f32,
590 ) -> RandResult<()> {
591 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
592 let kernel_name = self.normal_kernel_name(PtxType::F32);
593 let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
594
595 let n_u32 = u32::try_from(n)
596 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
597 let grid = grid_size_for(n_u32, 256);
598 let params = LaunchParams::new(grid, 256u32);
599
600 let seed_lo = self.seed as u32;
601 let seed_hi = (self.seed >> 32) as u32;
602 let offset_lo = self.offset as u32;
603 let offset_hi = (self.offset >> 32) as u32;
604
605 match self.engine {
606 RngEngine::Philox => {
607 let args = (
608 out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
609 );
610 kernel
611 .launch(¶ms, &self.stream, &args)
612 .map_err(RandError::Cuda)?;
613 }
614 RngEngine::Xorwow | RngEngine::Mrg32k3a => {
615 let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
616 kernel
617 .launch(¶ms, &self.stream, &args)
618 .map_err(RandError::Cuda)?;
619 }
620 }
621
622 self.stream.synchronize().map_err(RandError::Cuda)?;
623 Ok(())
624 }
625
626 fn compile_and_launch_normal_f64(
628 &self,
629 ptx_source: &str,
630 out_ptr: u64,
631 n: usize,
632 mean: f64,
633 stddev: f64,
634 ) -> RandResult<()> {
635 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
636 let kernel_name = self.normal_kernel_name(PtxType::F64);
637 let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
638
639 let n_u32 = u32::try_from(n)
640 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
641 let grid = grid_size_for(n_u32, 256);
642 let params = LaunchParams::new(grid, 256u32);
643
644 let seed_lo = self.seed as u32;
645 let seed_hi = (self.seed >> 32) as u32;
646 let offset_lo = self.offset as u32;
647 let offset_hi = (self.offset >> 32) as u32;
648
649 match self.engine {
650 RngEngine::Philox => {
651 let args = (
652 out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
653 );
654 kernel
655 .launch(¶ms, &self.stream, &args)
656 .map_err(RandError::Cuda)?;
657 }
658 RngEngine::Xorwow | RngEngine::Mrg32k3a => {
659 let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
660 kernel
661 .launch(¶ms, &self.stream, &args)
662 .map_err(RandError::Cuda)?;
663 }
664 }
665
666 self.stream.synchronize().map_err(RandError::Cuda)?;
667 Ok(())
668 }
669
670 fn compile_and_launch_u32(
672 &self,
673 ptx_source: &str,
674 kernel_name: &str,
675 out_ptr: u64,
676 n: usize,
677 ) -> RandResult<()> {
678 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
679 let kernel = Kernel::from_module(module, kernel_name).map_err(RandError::Cuda)?;
680
681 let n_u32 = u32::try_from(n)
682 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
683 let grid = grid_size_for(n_u32, 256);
684 let params = LaunchParams::new(grid, 256u32);
685
686 let seed_lo = self.seed as u32;
687 let seed_hi = (self.seed >> 32) as u32;
688 let offset_lo = self.offset as u32;
689 let offset_hi = (self.offset >> 32) as u32;
690
691 match self.engine {
692 RngEngine::Philox => {
693 let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
694 kernel
695 .launch(¶ms, &self.stream, &args)
696 .map_err(RandError::Cuda)?;
697 }
698 RngEngine::Mrg32k3a => {
699 let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
700 kernel
701 .launch(¶ms, &self.stream, &args)
702 .map_err(RandError::Cuda)?;
703 }
704 RngEngine::Xorwow => {
705 return Err(RandError::UnsupportedDistribution(
707 "u32 not supported for XORWOW".to_string(),
708 ));
709 }
710 }
711
712 self.stream.synchronize().map_err(RandError::Cuda)?;
713 Ok(())
714 }
715
716 fn compile_and_launch_log_normal_exp(
718 &self,
719 ptx_source: &str,
720 precision: PtxType,
721 out_ptr: u64,
722 n: usize,
723 ) -> RandResult<()> {
724 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
725 let kernel_name = log_normal_exp_kernel_name(precision);
726 let kernel = Kernel::from_module(module, kernel_name).map_err(RandError::Cuda)?;
727
728 let n_u32 = u32::try_from(n)
729 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
730 let grid = grid_size_for(n_u32, 256);
731 let params = LaunchParams::new(grid, 256u32);
732
733 let args = (out_ptr, n_u32);
734 kernel
735 .launch(¶ms, &self.stream, &args)
736 .map_err(RandError::Cuda)?;
737
738 self.stream.synchronize().map_err(RandError::Cuda)?;
739 Ok(())
740 }
741
742 fn compile_and_launch_poisson_postprocess_f32(
744 &self,
745 ptx_source: &str,
746 out_ptr: u64,
747 n: usize,
748 ) -> RandResult<()> {
749 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
750 let kernel_name = poisson_postprocess_kernel_name();
751 let kernel = Kernel::from_module(module, kernel_name).map_err(RandError::Cuda)?;
752
753 let n_u32 = u32::try_from(n)
754 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
755 let grid = grid_size_for(n_u32, 256);
756 let params = LaunchParams::new(grid, 256u32);
757
758 let args = (out_ptr, n_u32);
759 kernel
760 .launch(¶ms, &self.stream, &args)
761 .map_err(RandError::Cuda)?;
762
763 self.stream.synchronize().map_err(RandError::Cuda)?;
764 Ok(())
765 }
766}
767
#[cfg(test)]
mod tests {
    use super::*;

    // Display strings are part of the public contract; pin them exactly.
    #[test]
    fn engine_display() {
        assert_eq!(format!("{}", RngEngine::Philox), "Philox-4x32-10");
        assert_eq!(format!("{}", RngEngine::Xorwow), "XORWOW");
        assert_eq!(format!("{}", RngEngine::Mrg32k3a), "MRG32k3a");
    }

    // NOTE(review): this test is vacuous — it compares string literals to
    // themselves and never calls `uniform_kernel_name` (which requires an
    // RngGenerator, i.e. a CUDA context). It only documents the expected
    // names; consider exercising the real method in a GPU-enabled test.
    #[test]
    fn uniform_kernel_names() {
        let expected_philox_f32 = "philox_uniform_f32";
        let expected_xorwow_f64 = "xorwow_uniform_f64";
        let expected_mrg_f32 = "mrg32k3a_uniform_f32";

        assert_eq!(expected_philox_f32, "philox_uniform_f32");
        assert_eq!(expected_xorwow_f64, "xorwow_uniform_f64");
        assert_eq!(expected_mrg_f32, "mrg32k3a_uniform_f32");
    }

    // PTX generation is host-side only, so these smoke tests run without a GPU.
    #[test]
    fn ptx_generation_philox_uniform() {
        let ptx = philox::generate_philox_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    #[test]
    fn ptx_generation_xorwow_uniform() {
        let ptx = xorwow::generate_xorwow_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    #[test]
    fn ptx_generation_mrg32k3a_uniform() {
        let ptx = mrg32k3a::generate_mrg32k3a_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    // Checks the generated f32 exp kernel contains the expected entry point,
    // the 2^x approximation instruction, and the log2(e) constant.
    #[test]
    fn log_normal_exp_f32_ptx_generation() {
        let ptx = generate_log_normal_exp_ptx(PtxType::F32, SmVersion::Sm80)
            .unwrap_or_else(|e| panic!("{e}"));
        assert!(ptx.contains(".entry log_normal_exp_f32"));
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("0f3FB8AA3B"));
        assert!(!ptx.contains("philox_normal_f32"));
    }

    // The f64 kernel must narrow to f32, exponentiate, and widen back.
    #[test]
    fn log_normal_exp_f64_ptx_generation() {
        let ptx = generate_log_normal_exp_ptx(PtxType::F64, SmVersion::Sm80)
            .unwrap_or_else(|e| panic!("{e}"));
        assert!(ptx.contains(".entry log_normal_exp_f64"));
        assert!(ptx.contains("cvt.rn.f32.f64"));
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("cvt.f64.f32"));
        assert!(!ptx.contains("philox_normal_f64"));
    }

    // The Poisson kernel must round (cvt.rni), clamp (max.s32), and convert back.
    #[test]
    fn poisson_postprocess_f32_ptx_generation() {
        let ptx =
            generate_poisson_postprocess_f32_ptx(SmVersion::Sm80).unwrap_or_else(|e| panic!("{e}"));
        assert!(ptx.contains(".entry poisson_postprocess_f32"));
        assert!(ptx.contains("cvt.rni.s32.f32"));
        assert!(ptx.contains("max.s32"));
        assert!(ptx.contains("cvt.rn.f32.s32"));
        assert!(!ptx.contains("philox_normal_f32"));
    }

    #[test]
    fn poisson_lambda_validation_rejects_invalid_values() {
        let negative = validate_poisson_lambda(-1.0);
        assert!(matches!(negative, Err(RandError::InvalidParameter(_))));

        let nan = validate_poisson_lambda(f64::NAN);
        assert!(matches!(nan, Err(RandError::InvalidParameter(_))));

        let inf = validate_poisson_lambda(f64::INFINITY);
        assert!(matches!(inf, Err(RandError::InvalidParameter(_))));
    }

    #[test]
    fn poisson_lambda_validation_accepts_valid_values() {
        let zero = validate_poisson_lambda(0.0);
        assert!(matches!(zero, Ok(v) if v == 0.0));

        let positive = validate_poisson_lambda(12.5);
        assert!(matches!(positive, Ok(v) if v == 12.5_f32));
    }
}