1use std::sync::Arc;
8
9use oxicuda_driver::context::Context;
10use oxicuda_driver::module::Module;
11use oxicuda_driver::stream::Stream;
12use oxicuda_launch::grid::grid_size_for;
13use oxicuda_launch::kernel::Kernel;
14use oxicuda_launch::params::LaunchParams;
15use oxicuda_memory::DeviceBuffer;
16use oxicuda_ptx::arch::SmVersion;
17use oxicuda_ptx::ir::PtxType;
18
19use crate::engines::{mrg32k3a, philox, philox_optimized, xorwow};
20use crate::error::{RandError, RandResult};
21
/// Random-number-generation algorithm selector.
///
/// Each variant corresponds to a family of PTX kernels generated by the
/// matching module in `crate::engines`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RngEngine {
    /// Counter-based Philox generator (4x32 state, 10 rounds).
    Philox,
    /// XORWOW xorshift-family generator.
    Xorwow,
    /// L'Ecuyer combined multiple-recursive generator.
    Mrg32k3a,
}

impl std::fmt::Display for RngEngine {
    /// Writes the conventional human-readable name of the algorithm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Philox => "Philox-4x32-10",
            Self::Xorwow => "XORWOW",
            Self::Mrg32k3a => "MRG32k3a",
        };
        f.write_str(name)
    }
}
46
/// GPU random-number generator: compiles engine-specific PTX kernels and
/// launches them on a dedicated CUDA stream to fill device buffers.
pub struct RngGenerator {
    /// Algorithm used for every generation call.
    engine: RngEngine,
    /// 64-bit seed passed (split into 32-bit halves) to every kernel launch.
    seed: u64,
    /// Position in the random sequence; advanced after each generation call.
    offset: u64,
    /// Owning handle that keeps the CUDA context alive for the lifetime of
    /// `stream` — presumably required even though never read; confirm.
    #[allow(dead_code)]
    context: Arc<Context>,
    /// Dedicated stream on which all kernels launch and synchronize.
    stream: Stream,
    /// Target SM architecture for PTX generation.
    /// NOTE(review): hard-coded to Sm80 in `new` — see note there.
    sm_version: SmVersion,
}
89
90impl RngGenerator {
91 pub fn new(engine: RngEngine, seed: u64, ctx: &Arc<Context>) -> RandResult<Self> {
97 let stream = Stream::new(ctx).map_err(RandError::Cuda)?;
98 Ok(Self {
99 engine,
100 seed,
101 offset: 0,
102 context: Arc::clone(ctx),
103 stream,
104 sm_version: SmVersion::Sm80,
105 })
106 }
107
108 pub fn set_seed(&mut self, seed: u64) {
110 self.seed = seed;
111 }
112
113 pub fn set_offset(&mut self, offset: u64) {
115 self.offset = offset;
116 }
117
118 pub fn skip(&mut self, n: u64) {
120 self.offset = self.offset.wrapping_add(n);
121 }
122
123 pub fn generate_uniform_f32(&mut self, output: &mut DeviceBuffer<f32>) -> RandResult<()> {
129 let n = output.len();
130 let ptx_source = self.get_uniform_ptx(PtxType::F32)?;
131 self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
132 self.offset += n as u64;
133 Ok(())
134 }
135
136 pub fn generate_uniform_f64(&mut self, output: &mut DeviceBuffer<f64>) -> RandResult<()> {
142 let n = output.len();
143 let ptx_source = self.get_uniform_ptx(PtxType::F64)?;
144 self.compile_and_launch_uniform(&ptx_source, PtxType::F64, output.as_device_ptr(), n)?;
145 self.offset += n as u64;
146 Ok(())
147 }
148
149 pub fn generate_uniform_f32_optimized(
159 &mut self,
160 output: &mut DeviceBuffer<f32>,
161 ) -> RandResult<()> {
162 let n = output.len();
163 if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
164 return self.generate_uniform_f32(output);
165 }
166
167 let ptx_source =
168 philox_optimized::generate_philox_optimized_uniform_f32_ptx(self.sm_version)?;
169 self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
170 self.offset += n.div_ceil(4) as u64;
172 Ok(())
173 }
174
175 pub fn generate_normal_f32_optimized(
185 &mut self,
186 output: &mut DeviceBuffer<f32>,
187 mean: f32,
188 stddev: f32,
189 ) -> RandResult<()> {
190 let n = output.len();
191 if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
192 return self.generate_normal_f32(output, mean, stddev);
193 }
194
195 let ptx_source =
196 philox_optimized::generate_philox_optimized_normal_f32_ptx(self.sm_version)?;
197 self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
198 self.offset += n.div_ceil(4) as u64;
199 Ok(())
200 }
201
202 pub fn generate_normal_f32(
208 &mut self,
209 output: &mut DeviceBuffer<f32>,
210 mean: f32,
211 stddev: f32,
212 ) -> RandResult<()> {
213 let n = output.len();
214 let ptx_source = self.get_normal_ptx(PtxType::F32)?;
215 self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
216 self.offset += n as u64;
217 Ok(())
218 }
219
220 pub fn generate_normal_f64(
226 &mut self,
227 output: &mut DeviceBuffer<f64>,
228 mean: f64,
229 stddev: f64,
230 ) -> RandResult<()> {
231 let n = output.len();
232 let ptx_source = self.get_normal_ptx(PtxType::F64)?;
233 self.compile_and_launch_normal_f64(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
234 self.offset += n as u64;
235 Ok(())
236 }
237
238 pub fn generate_log_normal_f32(
246 &mut self,
247 output: &mut DeviceBuffer<f32>,
248 mean: f32,
249 stddev: f32,
250 ) -> RandResult<()> {
251 self.generate_normal_f32(output, mean, stddev)
257 }
258
259 pub fn generate_log_normal_f64(
265 &mut self,
266 output: &mut DeviceBuffer<f64>,
267 mean: f64,
268 stddev: f64,
269 ) -> RandResult<()> {
270 self.generate_normal_f64(output, mean, stddev)
271 }
272
273 pub fn generate_poisson_f32(
282 &mut self,
283 output: &mut DeviceBuffer<f32>,
284 lambda: f64,
285 ) -> RandResult<()> {
286 let _lambda_f32 = lambda as f32;
290 let _n = output.len();
291 self.generate_uniform_f32(output)
294 }
295
296 pub fn generate_u32(&mut self, output: &mut DeviceBuffer<u32>) -> RandResult<()> {
305 let n = output.len();
306 let ptx_source = self.get_u32_ptx()?;
307 let kernel_name = self.u32_kernel_name();
308 self.compile_and_launch_u32(&ptx_source, &kernel_name, output.as_device_ptr(), n)?;
309 self.offset += n as u64;
310 Ok(())
311 }
312
313 fn get_uniform_ptx(&self, precision: PtxType) -> RandResult<String> {
319 let ptx = match self.engine {
320 RngEngine::Philox => philox::generate_philox_uniform_ptx(precision, self.sm_version)?,
321 RngEngine::Xorwow => xorwow::generate_xorwow_uniform_ptx(precision, self.sm_version)?,
322 RngEngine::Mrg32k3a => {
323 mrg32k3a::generate_mrg32k3a_uniform_ptx(precision, self.sm_version)?
324 }
325 };
326 Ok(ptx)
327 }
328
329 fn get_normal_ptx(&self, precision: PtxType) -> RandResult<String> {
331 let ptx = match self.engine {
332 RngEngine::Philox => philox::generate_philox_normal_ptx(precision, self.sm_version)?,
333 RngEngine::Xorwow => xorwow::generate_xorwow_normal_ptx(precision, self.sm_version)?,
334 RngEngine::Mrg32k3a => {
335 mrg32k3a::generate_mrg32k3a_normal_ptx(precision, self.sm_version)?
336 }
337 };
338 Ok(ptx)
339 }
340
341 fn get_u32_ptx(&self) -> RandResult<String> {
343 let ptx = match self.engine {
344 RngEngine::Philox => philox::generate_philox_u32_ptx(self.sm_version)?,
345 RngEngine::Mrg32k3a => mrg32k3a::generate_mrg32k3a_u32_ptx(self.sm_version)?,
346 RngEngine::Xorwow => {
347 return Err(RandError::UnsupportedDistribution(
348 "u32 output is not supported for XORWOW engine".to_string(),
349 ));
350 }
351 };
352 Ok(ptx)
353 }
354
355 fn uniform_kernel_name(&self, precision: PtxType) -> String {
357 let prec_str = match precision {
358 PtxType::F32 => "f32",
359 PtxType::F64 => "f64",
360 _ => "f32",
361 };
362 match self.engine {
363 RngEngine::Philox => format!("philox_uniform_{prec_str}"),
364 RngEngine::Xorwow => format!("xorwow_uniform_{prec_str}"),
365 RngEngine::Mrg32k3a => format!("mrg32k3a_uniform_{prec_str}"),
366 }
367 }
368
369 fn normal_kernel_name(&self, precision: PtxType) -> String {
371 let prec_str = match precision {
372 PtxType::F32 => "f32",
373 PtxType::F64 => "f64",
374 _ => "f32",
375 };
376 match self.engine {
377 RngEngine::Philox => format!("philox_normal_{prec_str}"),
378 RngEngine::Xorwow => format!("xorwow_normal_{prec_str}"),
379 RngEngine::Mrg32k3a => format!("mrg32k3a_normal_{prec_str}"),
380 }
381 }
382
383 fn u32_kernel_name(&self) -> String {
385 match self.engine {
386 RngEngine::Philox => "philox_u32".to_string(),
387 RngEngine::Mrg32k3a => "mrg32k3a_u32".to_string(),
388 RngEngine::Xorwow => "xorwow_u32".to_string(), }
390 }
391
392 fn compile_and_launch_uniform(
398 &self,
399 ptx_source: &str,
400 precision: PtxType,
401 out_ptr: u64,
402 n: usize,
403 ) -> RandResult<()> {
404 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
405 let kernel_name = self.uniform_kernel_name(precision);
406 let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
407
408 let n_u32 = u32::try_from(n)
409 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
410 let grid = grid_size_for(n_u32, 256);
411 let params = LaunchParams::new(grid, 256u32);
412
413 let seed_lo = self.seed as u32;
414 let seed_hi = (self.seed >> 32) as u32;
415 let offset_lo = self.offset as u32;
416 let offset_hi = (self.offset >> 32) as u32;
417
418 match self.engine {
421 RngEngine::Philox => {
422 let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
423 kernel
424 .launch(¶ms, &self.stream, &args)
425 .map_err(RandError::Cuda)?;
426 }
427 RngEngine::Xorwow | RngEngine::Mrg32k3a => {
428 let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
429 kernel
430 .launch(¶ms, &self.stream, &args)
431 .map_err(RandError::Cuda)?;
432 }
433 }
434
435 self.stream.synchronize().map_err(RandError::Cuda)?;
436 Ok(())
437 }
438
439 fn compile_and_launch_normal_f32(
441 &self,
442 ptx_source: &str,
443 out_ptr: u64,
444 n: usize,
445 mean: f32,
446 stddev: f32,
447 ) -> RandResult<()> {
448 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
449 let kernel_name = self.normal_kernel_name(PtxType::F32);
450 let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
451
452 let n_u32 = u32::try_from(n)
453 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
454 let grid = grid_size_for(n_u32, 256);
455 let params = LaunchParams::new(grid, 256u32);
456
457 let seed_lo = self.seed as u32;
458 let seed_hi = (self.seed >> 32) as u32;
459 let offset_lo = self.offset as u32;
460 let offset_hi = (self.offset >> 32) as u32;
461
462 match self.engine {
463 RngEngine::Philox => {
464 let args = (
465 out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
466 );
467 kernel
468 .launch(¶ms, &self.stream, &args)
469 .map_err(RandError::Cuda)?;
470 }
471 RngEngine::Xorwow | RngEngine::Mrg32k3a => {
472 let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
473 kernel
474 .launch(¶ms, &self.stream, &args)
475 .map_err(RandError::Cuda)?;
476 }
477 }
478
479 self.stream.synchronize().map_err(RandError::Cuda)?;
480 Ok(())
481 }
482
483 fn compile_and_launch_normal_f64(
485 &self,
486 ptx_source: &str,
487 out_ptr: u64,
488 n: usize,
489 mean: f64,
490 stddev: f64,
491 ) -> RandResult<()> {
492 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
493 let kernel_name = self.normal_kernel_name(PtxType::F64);
494 let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
495
496 let n_u32 = u32::try_from(n)
497 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
498 let grid = grid_size_for(n_u32, 256);
499 let params = LaunchParams::new(grid, 256u32);
500
501 let seed_lo = self.seed as u32;
502 let seed_hi = (self.seed >> 32) as u32;
503 let offset_lo = self.offset as u32;
504 let offset_hi = (self.offset >> 32) as u32;
505
506 match self.engine {
507 RngEngine::Philox => {
508 let args = (
509 out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
510 );
511 kernel
512 .launch(¶ms, &self.stream, &args)
513 .map_err(RandError::Cuda)?;
514 }
515 RngEngine::Xorwow | RngEngine::Mrg32k3a => {
516 let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
517 kernel
518 .launch(¶ms, &self.stream, &args)
519 .map_err(RandError::Cuda)?;
520 }
521 }
522
523 self.stream.synchronize().map_err(RandError::Cuda)?;
524 Ok(())
525 }
526
527 fn compile_and_launch_u32(
529 &self,
530 ptx_source: &str,
531 kernel_name: &str,
532 out_ptr: u64,
533 n: usize,
534 ) -> RandResult<()> {
535 let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
536 let kernel = Kernel::from_module(module, kernel_name).map_err(RandError::Cuda)?;
537
538 let n_u32 = u32::try_from(n)
539 .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
540 let grid = grid_size_for(n_u32, 256);
541 let params = LaunchParams::new(grid, 256u32);
542
543 let seed_lo = self.seed as u32;
544 let seed_hi = (self.seed >> 32) as u32;
545 let offset_lo = self.offset as u32;
546 let offset_hi = (self.offset >> 32) as u32;
547
548 match self.engine {
549 RngEngine::Philox => {
550 let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
551 kernel
552 .launch(¶ms, &self.stream, &args)
553 .map_err(RandError::Cuda)?;
554 }
555 RngEngine::Mrg32k3a => {
556 let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
557 kernel
558 .launch(¶ms, &self.stream, &args)
559 .map_err(RandError::Cuda)?;
560 }
561 RngEngine::Xorwow => {
562 return Err(RandError::UnsupportedDistribution(
564 "u32 not supported for XORWOW".to_string(),
565 ));
566 }
567 }
568
569 self.stream.synchronize().map_err(RandError::Cuda)?;
570 Ok(())
571 }
572}
573
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn engine_display() {
        assert_eq!(format!("{}", RngEngine::Philox), "Philox-4x32-10");
        assert_eq!(format!("{}", RngEngine::Xorwow), "XORWOW");
        assert_eq!(format!("{}", RngEngine::Mrg32k3a), "MRG32k3a");
    }

    // Replaces a tautological test that compared string literals to
    // themselves. The launcher resolves kernels by name in the compiled
    // module, so the generated PTX must actually contain these entry names.
    #[test]
    fn uniform_kernel_names_appear_in_ptx() {
        let philox_ptx = philox::generate_philox_uniform_ptx(PtxType::F32, SmVersion::Sm80)
            .expect("philox uniform PTX should generate");
        assert!(philox_ptx.contains("philox_uniform_f32"));

        let xorwow_ptx = xorwow::generate_xorwow_uniform_ptx(PtxType::F64, SmVersion::Sm80)
            .expect("xorwow uniform PTX should generate");
        assert!(xorwow_ptx.contains("xorwow_uniform_f64"));

        let mrg_ptx = mrg32k3a::generate_mrg32k3a_uniform_ptx(PtxType::F32, SmVersion::Sm80)
            .expect("mrg32k3a uniform PTX should generate");
        assert!(mrg_ptx.contains("mrg32k3a_uniform_f32"));
    }

    #[test]
    fn ptx_generation_philox_uniform() {
        let ptx = philox::generate_philox_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    #[test]
    fn ptx_generation_xorwow_uniform() {
        let ptx = xorwow::generate_xorwow_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    #[test]
    fn ptx_generation_mrg32k3a_uniform() {
        let ptx = mrg32k3a::generate_mrg32k3a_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }
}