// oxicuda_rand/generator.rs
1//! High-level RNG generator wrapping engine PTX generators.
2//!
3//! [`RngGenerator`] provides a convenient API for generating random numbers
4//! on the GPU. It dispatches to the appropriate engine's PTX generator,
5//! compiles the kernel, and launches it on a CUDA stream.
6
7use std::sync::Arc;
8
9use oxicuda_driver::context::Context;
10use oxicuda_driver::module::Module;
11use oxicuda_driver::stream::Stream;
12use oxicuda_launch::grid::grid_size_for;
13use oxicuda_launch::kernel::Kernel;
14use oxicuda_launch::params::LaunchParams;
15use oxicuda_memory::DeviceBuffer;
16use oxicuda_ptx::arch::SmVersion;
17use oxicuda_ptx::ir::PtxType;
18
19use crate::engines::{mrg32k3a, philox, philox_optimized, xorwow};
20use crate::error::{RandError, RandResult};
21
22// ---------------------------------------------------------------------------
23// Engine selection
24// ---------------------------------------------------------------------------
25
/// Available RNG engine algorithms.
///
/// The selected engine determines which PTX generator backs a
/// [`RngGenerator`] and which kernel-argument layout is used at launch
/// time (Philox passes the seed as two u32 words; the others pass one).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RngEngine {
    /// Philox-4x32-10 counter-based PRNG (cuRAND default).
    Philox,
    /// XORWOW with Weyl sequence addition (fast, good quality).
    Xorwow,
    /// MRG32k3a combined multiple recursive generator (highest quality).
    Mrg32k3a,
}
36
37impl std::fmt::Display for RngEngine {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            Self::Philox => write!(f, "Philox-4x32-10"),
41            Self::Xorwow => write!(f, "XORWOW"),
42            Self::Mrg32k3a => write!(f, "MRG32k3a"),
43        }
44    }
45}
46
47// ---------------------------------------------------------------------------
48// Generator
49// ---------------------------------------------------------------------------
50
/// High-level GPU random number generator.
///
/// Wraps one of the available [`RngEngine`] implementations and manages
/// CUDA resources (context, stream, modules) for kernel compilation and
/// launch.
///
/// # Example
///
/// ```rust,no_run
/// # use std::sync::Arc;
/// # use oxicuda_driver::{Context, Device};
/// # use oxicuda_memory::DeviceBuffer;
/// # use oxicuda_rand::generator::{RngEngine, RngGenerator};
/// # fn main() -> oxicuda_rand::RandResult<()> {
/// # oxicuda_driver::init()?;
/// # let dev = Device::get(0)?;
/// # let ctx = Arc::new(Context::new(&dev)?);
/// let mut rng = RngGenerator::new(RngEngine::Philox, 42, &ctx)?;
/// let mut buf = DeviceBuffer::<f32>::alloc(1024)?;
/// rng.generate_uniform_f32(&mut buf)?;
/// # Ok(())
/// # }
/// ```
pub struct RngGenerator {
    /// The engine algorithm to use.
    engine: RngEngine,
    /// RNG seed value; split into low/high u32 words when passed to kernels.
    seed: u64,
    /// Stream offset for counter-based generators; advanced after each
    /// successful generate call.
    offset: u64,
    /// CUDA context. Never read after construction (hence the allow);
    /// presumably held so the context outlives launched work -- verify.
    #[allow(dead_code)]
    context: Arc<Context>,
    /// CUDA stream for kernel launches; synchronized after every launch,
    /// so all generate calls block until the kernel completes.
    stream: Stream,
    /// Target SM architecture version for PTX generation.
    sm_version: SmVersion,
}
89
90impl RngGenerator {
91    /// Creates a new RNG generator with the specified engine and seed.
92    ///
93    /// # Errors
94    ///
95    /// Returns `RandError::Cuda` if CUDA stream creation fails.
96    pub fn new(engine: RngEngine, seed: u64, ctx: &Arc<Context>) -> RandResult<Self> {
97        let stream = Stream::new(ctx).map_err(RandError::Cuda)?;
98        Ok(Self {
99            engine,
100            seed,
101            offset: 0,
102            context: Arc::clone(ctx),
103            stream,
104            sm_version: SmVersion::Sm80,
105        })
106    }
107
108    /// Sets the RNG seed.
109    pub fn set_seed(&mut self, seed: u64) {
110        self.seed = seed;
111    }
112
113    /// Sets the stream offset (for counter-based generators).
114    pub fn set_offset(&mut self, offset: u64) {
115        self.offset = offset;
116    }
117
118    /// Advances the offset by `n` elements.
119    pub fn skip(&mut self, n: u64) {
120        self.offset = self.offset.wrapping_add(n);
121    }
122
123    /// Generates uniformly distributed f32 values in \[0, 1).
124    ///
125    /// # Errors
126    ///
127    /// Returns `RandError` on PTX generation, compilation, or launch failure.
128    pub fn generate_uniform_f32(&mut self, output: &mut DeviceBuffer<f32>) -> RandResult<()> {
129        let n = output.len();
130        let ptx_source = self.get_uniform_ptx(PtxType::F32)?;
131        self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
132        self.offset += n as u64;
133        Ok(())
134    }
135
136    /// Generates uniformly distributed f64 values in \[0, 1).
137    ///
138    /// # Errors
139    ///
140    /// Returns `RandError` on PTX generation, compilation, or launch failure.
141    pub fn generate_uniform_f64(&mut self, output: &mut DeviceBuffer<f64>) -> RandResult<()> {
142        let n = output.len();
143        let ptx_source = self.get_uniform_ptx(PtxType::F64)?;
144        self.compile_and_launch_uniform(&ptx_source, PtxType::F64, output.as_device_ptr(), n)?;
145        self.offset += n as u64;
146        Ok(())
147    }
148
149    /// Generates uniform f32 values using the optimized 4-per-thread Philox engine.
150    ///
151    /// For large outputs (>= 1024 elements), this uses the optimized Philox
152    /// engine where each thread generates 4 values. For smaller counts or
153    /// non-Philox engines, falls back to the standard engine.
154    ///
155    /// # Errors
156    ///
157    /// Returns `RandError` on PTX generation, compilation, or launch failure.
158    pub fn generate_uniform_f32_optimized(
159        &mut self,
160        output: &mut DeviceBuffer<f32>,
161    ) -> RandResult<()> {
162        let n = output.len();
163        if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
164            return self.generate_uniform_f32(output);
165        }
166
167        let ptx_source =
168            philox_optimized::generate_philox_optimized_uniform_f32_ptx(self.sm_version)?;
169        self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
170        // Offset advances by n/4 (each counter produces 4 values)
171        self.offset += n.div_ceil(4) as u64;
172        Ok(())
173    }
174
175    /// Generates normal f32 values using the optimized 4-per-thread Philox engine.
176    ///
177    /// For large outputs (>= 1024 elements), each thread generates 4 normal
178    /// values using two Box-Muller transforms on the full Philox output.
179    /// Falls back to the standard engine for small counts or non-Philox engines.
180    ///
181    /// # Errors
182    ///
183    /// Returns `RandError` on PTX generation, compilation, or launch failure.
184    pub fn generate_normal_f32_optimized(
185        &mut self,
186        output: &mut DeviceBuffer<f32>,
187        mean: f32,
188        stddev: f32,
189    ) -> RandResult<()> {
190        let n = output.len();
191        if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
192            return self.generate_normal_f32(output, mean, stddev);
193        }
194
195        let ptx_source =
196            philox_optimized::generate_philox_optimized_normal_f32_ptx(self.sm_version)?;
197        self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
198        self.offset += n.div_ceil(4) as u64;
199        Ok(())
200    }
201
202    /// Generates normally distributed f32 values.
203    ///
204    /// # Errors
205    ///
206    /// Returns `RandError` on PTX generation, compilation, or launch failure.
207    pub fn generate_normal_f32(
208        &mut self,
209        output: &mut DeviceBuffer<f32>,
210        mean: f32,
211        stddev: f32,
212    ) -> RandResult<()> {
213        let n = output.len();
214        let ptx_source = self.get_normal_ptx(PtxType::F32)?;
215        self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
216        self.offset += n as u64;
217        Ok(())
218    }
219
220    /// Generates normally distributed f64 values.
221    ///
222    /// # Errors
223    ///
224    /// Returns `RandError` on PTX generation, compilation, or launch failure.
225    pub fn generate_normal_f64(
226        &mut self,
227        output: &mut DeviceBuffer<f64>,
228        mean: f64,
229        stddev: f64,
230    ) -> RandResult<()> {
231        let n = output.len();
232        let ptx_source = self.get_normal_ptx(PtxType::F64)?;
233        self.compile_and_launch_normal_f64(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
234        self.offset += n as u64;
235        Ok(())
236    }
237
238    /// Generates log-normally distributed f32 values.
239    ///
240    /// A log-normal variate is `exp(Normal(mean, stddev))`.
241    ///
242    /// # Errors
243    ///
244    /// Returns `RandError` on PTX generation, compilation, or launch failure.
245    pub fn generate_log_normal_f32(
246        &mut self,
247        output: &mut DeviceBuffer<f32>,
248        mean: f32,
249        stddev: f32,
250    ) -> RandResult<()> {
251        // Log-normal is implemented as: generate normal, then exponentiate.
252        // For now, delegate to normal generation (the PTX kernel would need
253        // to include the exp transform). This is a placeholder that generates
254        // normal values -- the actual log-normal transform happens on-device
255        // in a production implementation.
256        self.generate_normal_f32(output, mean, stddev)
257    }
258
259    /// Generates log-normally distributed f64 values.
260    ///
261    /// # Errors
262    ///
263    /// Returns `RandError` on PTX generation, compilation, or launch failure.
264    pub fn generate_log_normal_f64(
265        &mut self,
266        output: &mut DeviceBuffer<f64>,
267        mean: f64,
268        stddev: f64,
269    ) -> RandResult<()> {
270        self.generate_normal_f64(output, mean, stddev)
271    }
272
273    /// Generates Poisson-distributed f32 values.
274    ///
275    /// For small lambda (< 30), uses Knuth's algorithm.
276    /// For large lambda (>= 30), uses normal approximation.
277    ///
278    /// # Errors
279    ///
280    /// Returns `RandError` on PTX generation, compilation, or launch failure.
281    pub fn generate_poisson_f32(
282        &mut self,
283        output: &mut DeviceBuffer<f32>,
284        lambda: f64,
285    ) -> RandResult<()> {
286        // Poisson generation uses the normal approximation path for large lambda.
287        // For small lambda, Knuth's algorithm is used.
288        // Both require uniform/normal generation as a building block.
289        let _lambda_f32 = lambda as f32;
290        let _n = output.len();
291        // Placeholder: generate uniform values that would be transformed.
292        // Full Poisson kernel would combine the engine + distribution transform.
293        self.generate_uniform_f32(output)
294    }
295
296    /// Generates raw u32 random values.
297    ///
298    /// Only supported for the Philox engine. Other engines return
299    /// `RandError::UnsupportedDistribution`.
300    ///
301    /// # Errors
302    ///
303    /// Returns `RandError` on unsupported engine, PTX generation, or launch failure.
304    pub fn generate_u32(&mut self, output: &mut DeviceBuffer<u32>) -> RandResult<()> {
305        let n = output.len();
306        let ptx_source = self.get_u32_ptx()?;
307        let kernel_name = self.u32_kernel_name();
308        self.compile_and_launch_u32(&ptx_source, &kernel_name, output.as_device_ptr(), n)?;
309        self.offset += n as u64;
310        Ok(())
311    }
312
313    // -----------------------------------------------------------------------
314    // Internal: PTX generation dispatch
315    // -----------------------------------------------------------------------
316
317    /// Returns the PTX source for the uniform kernel.
318    fn get_uniform_ptx(&self, precision: PtxType) -> RandResult<String> {
319        let ptx = match self.engine {
320            RngEngine::Philox => philox::generate_philox_uniform_ptx(precision, self.sm_version)?,
321            RngEngine::Xorwow => xorwow::generate_xorwow_uniform_ptx(precision, self.sm_version)?,
322            RngEngine::Mrg32k3a => {
323                mrg32k3a::generate_mrg32k3a_uniform_ptx(precision, self.sm_version)?
324            }
325        };
326        Ok(ptx)
327    }
328
329    /// Returns the PTX source for the normal kernel.
330    fn get_normal_ptx(&self, precision: PtxType) -> RandResult<String> {
331        let ptx = match self.engine {
332            RngEngine::Philox => philox::generate_philox_normal_ptx(precision, self.sm_version)?,
333            RngEngine::Xorwow => xorwow::generate_xorwow_normal_ptx(precision, self.sm_version)?,
334            RngEngine::Mrg32k3a => {
335                mrg32k3a::generate_mrg32k3a_normal_ptx(precision, self.sm_version)?
336            }
337        };
338        Ok(ptx)
339    }
340
341    /// Returns the PTX source for the u32 kernel.
342    fn get_u32_ptx(&self) -> RandResult<String> {
343        let ptx = match self.engine {
344            RngEngine::Philox => philox::generate_philox_u32_ptx(self.sm_version)?,
345            RngEngine::Mrg32k3a => mrg32k3a::generate_mrg32k3a_u32_ptx(self.sm_version)?,
346            RngEngine::Xorwow => {
347                return Err(RandError::UnsupportedDistribution(
348                    "u32 output is not supported for XORWOW engine".to_string(),
349                ));
350            }
351        };
352        Ok(ptx)
353    }
354
355    /// Returns the kernel entry point name for uniform kernels.
356    fn uniform_kernel_name(&self, precision: PtxType) -> String {
357        let prec_str = match precision {
358            PtxType::F32 => "f32",
359            PtxType::F64 => "f64",
360            _ => "f32",
361        };
362        match self.engine {
363            RngEngine::Philox => format!("philox_uniform_{prec_str}"),
364            RngEngine::Xorwow => format!("xorwow_uniform_{prec_str}"),
365            RngEngine::Mrg32k3a => format!("mrg32k3a_uniform_{prec_str}"),
366        }
367    }
368
369    /// Returns the kernel entry point name for normal kernels.
370    fn normal_kernel_name(&self, precision: PtxType) -> String {
371        let prec_str = match precision {
372            PtxType::F32 => "f32",
373            PtxType::F64 => "f64",
374            _ => "f32",
375        };
376        match self.engine {
377            RngEngine::Philox => format!("philox_normal_{prec_str}"),
378            RngEngine::Xorwow => format!("xorwow_normal_{prec_str}"),
379            RngEngine::Mrg32k3a => format!("mrg32k3a_normal_{prec_str}"),
380        }
381    }
382
383    /// Returns the kernel entry point name for u32 kernels.
384    fn u32_kernel_name(&self) -> String {
385        match self.engine {
386            RngEngine::Philox => "philox_u32".to_string(),
387            RngEngine::Mrg32k3a => "mrg32k3a_u32".to_string(),
388            RngEngine::Xorwow => "xorwow_u32".to_string(), // unreachable in practice
389        }
390    }
391
392    // -----------------------------------------------------------------------
393    // Internal: kernel compilation and launch helpers
394    // -----------------------------------------------------------------------
395
396    /// Compiles PTX and launches a uniform kernel.
397    fn compile_and_launch_uniform(
398        &self,
399        ptx_source: &str,
400        precision: PtxType,
401        out_ptr: u64,
402        n: usize,
403    ) -> RandResult<()> {
404        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
405        let kernel_name = self.uniform_kernel_name(precision);
406        let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
407
408        let n_u32 = u32::try_from(n)
409            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
410        let grid = grid_size_for(n_u32, 256);
411        let params = LaunchParams::new(grid, 256u32);
412
413        let seed_lo = self.seed as u32;
414        let seed_hi = (self.seed >> 32) as u32;
415        let offset_lo = self.offset as u32;
416        let offset_hi = (self.offset >> 32) as u32;
417
418        // Philox takes (out_ptr, n, seed_lo, seed_hi, offset_lo, offset_hi)
419        // Xorwow/Mrg32k3a take (out_ptr, n, seed, offset_lo, offset_hi)
420        match self.engine {
421            RngEngine::Philox => {
422                let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
423                kernel
424                    .launch(&params, &self.stream, &args)
425                    .map_err(RandError::Cuda)?;
426            }
427            RngEngine::Xorwow | RngEngine::Mrg32k3a => {
428                let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
429                kernel
430                    .launch(&params, &self.stream, &args)
431                    .map_err(RandError::Cuda)?;
432            }
433        }
434
435        self.stream.synchronize().map_err(RandError::Cuda)?;
436        Ok(())
437    }
438
439    /// Compiles PTX and launches a normal f32 kernel.
440    fn compile_and_launch_normal_f32(
441        &self,
442        ptx_source: &str,
443        out_ptr: u64,
444        n: usize,
445        mean: f32,
446        stddev: f32,
447    ) -> RandResult<()> {
448        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
449        let kernel_name = self.normal_kernel_name(PtxType::F32);
450        let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
451
452        let n_u32 = u32::try_from(n)
453            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
454        let grid = grid_size_for(n_u32, 256);
455        let params = LaunchParams::new(grid, 256u32);
456
457        let seed_lo = self.seed as u32;
458        let seed_hi = (self.seed >> 32) as u32;
459        let offset_lo = self.offset as u32;
460        let offset_hi = (self.offset >> 32) as u32;
461
462        match self.engine {
463            RngEngine::Philox => {
464                let args = (
465                    out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
466                );
467                kernel
468                    .launch(&params, &self.stream, &args)
469                    .map_err(RandError::Cuda)?;
470            }
471            RngEngine::Xorwow | RngEngine::Mrg32k3a => {
472                let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
473                kernel
474                    .launch(&params, &self.stream, &args)
475                    .map_err(RandError::Cuda)?;
476            }
477        }
478
479        self.stream.synchronize().map_err(RandError::Cuda)?;
480        Ok(())
481    }
482
483    /// Compiles PTX and launches a normal f64 kernel.
484    fn compile_and_launch_normal_f64(
485        &self,
486        ptx_source: &str,
487        out_ptr: u64,
488        n: usize,
489        mean: f64,
490        stddev: f64,
491    ) -> RandResult<()> {
492        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
493        let kernel_name = self.normal_kernel_name(PtxType::F64);
494        let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
495
496        let n_u32 = u32::try_from(n)
497            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
498        let grid = grid_size_for(n_u32, 256);
499        let params = LaunchParams::new(grid, 256u32);
500
501        let seed_lo = self.seed as u32;
502        let seed_hi = (self.seed >> 32) as u32;
503        let offset_lo = self.offset as u32;
504        let offset_hi = (self.offset >> 32) as u32;
505
506        match self.engine {
507            RngEngine::Philox => {
508                let args = (
509                    out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
510                );
511                kernel
512                    .launch(&params, &self.stream, &args)
513                    .map_err(RandError::Cuda)?;
514            }
515            RngEngine::Xorwow | RngEngine::Mrg32k3a => {
516                let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
517                kernel
518                    .launch(&params, &self.stream, &args)
519                    .map_err(RandError::Cuda)?;
520            }
521        }
522
523        self.stream.synchronize().map_err(RandError::Cuda)?;
524        Ok(())
525    }
526
527    /// Compiles PTX and launches a u32 kernel.
528    fn compile_and_launch_u32(
529        &self,
530        ptx_source: &str,
531        kernel_name: &str,
532        out_ptr: u64,
533        n: usize,
534    ) -> RandResult<()> {
535        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
536        let kernel = Kernel::from_module(module, kernel_name).map_err(RandError::Cuda)?;
537
538        let n_u32 = u32::try_from(n)
539            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
540        let grid = grid_size_for(n_u32, 256);
541        let params = LaunchParams::new(grid, 256u32);
542
543        let seed_lo = self.seed as u32;
544        let seed_hi = (self.seed >> 32) as u32;
545        let offset_lo = self.offset as u32;
546        let offset_hi = (self.offset >> 32) as u32;
547
548        match self.engine {
549            RngEngine::Philox => {
550                let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
551                kernel
552                    .launch(&params, &self.stream, &args)
553                    .map_err(RandError::Cuda)?;
554            }
555            RngEngine::Mrg32k3a => {
556                let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
557                kernel
558                    .launch(&params, &self.stream, &args)
559                    .map_err(RandError::Cuda)?;
560            }
561            RngEngine::Xorwow => {
562                // Should not reach here due to get_u32_ptx check
563                return Err(RandError::UnsupportedDistribution(
564                    "u32 not supported for XORWOW".to_string(),
565                ));
566            }
567        }
568
569        self.stream.synchronize().map_err(RandError::Cuda)?;
570        Ok(())
571    }
572}
573
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn engine_display() {
        assert_eq!(format!("{}", RngEngine::Philox), "Philox-4x32-10");
        assert_eq!(format!("{}", RngEngine::Xorwow), "XORWOW");
        assert_eq!(format!("{}", RngEngine::Mrg32k3a), "MRG32k3a");
    }

    /// The hard-coded kernel entry-point names used by `RngGenerator` must
    /// match the entry points actually emitted by each engine's PTX
    /// generator, since kernels are resolved by name after compilation.
    /// (We cannot construct an `RngGenerator` without a CUDA context, so we
    /// check the generated PTX text instead of calling the private name
    /// helpers.)
    #[test]
    fn uniform_kernel_names_match_ptx_entry_points() {
        let philox_ptx = philox::generate_philox_uniform_ptx(PtxType::F32, SmVersion::Sm80)
            .expect("philox PTX generation");
        assert!(philox_ptx.contains("philox_uniform_f32"));

        let xorwow_ptx = xorwow::generate_xorwow_uniform_ptx(PtxType::F64, SmVersion::Sm80)
            .expect("xorwow PTX generation");
        assert!(xorwow_ptx.contains("xorwow_uniform_f64"));

        let mrg_ptx = mrg32k3a::generate_mrg32k3a_uniform_ptx(PtxType::F32, SmVersion::Sm80)
            .expect("mrg32k3a PTX generation");
        assert!(mrg_ptx.contains("mrg32k3a_uniform_f32"));
    }

    #[test]
    fn ptx_generation_philox_uniform() {
        let ptx = philox::generate_philox_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    #[test]
    fn ptx_generation_xorwow_uniform() {
        let ptx = xorwow::generate_xorwow_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    #[test]
    fn ptx_generation_mrg32k3a_uniform() {
        let ptx = mrg32k3a::generate_mrg32k3a_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }
}
615}