// oxicuda_rand/generator.rs

//! High-level RNG generator wrapping engine PTX generators.
//!
//! [`RngGenerator`] provides a convenient API for generating random numbers
//! on the GPU. It dispatches to the appropriate engine's PTX generator,
//! compiles the kernel, and launches it on a CUDA stream.

use std::sync::Arc;

use oxicuda_driver::context::Context;
use oxicuda_driver::module::Module;
use oxicuda_driver::stream::Stream;
use oxicuda_launch::grid::grid_size_for;
use oxicuda_launch::kernel::Kernel;
use oxicuda_launch::params::LaunchParams;
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::error::PtxGenError;
use oxicuda_ptx::ir::PtxType;

use crate::engines::{mrg32k3a, philox, philox_optimized, xorwow};
use crate::error::{RandError, RandResult};

/// Entry-point name of the in-place `exp` kernel used for `f32` log-normal output.
const LOG_NORMAL_EXP_KERNEL_F32: &str = "log_normal_exp_f32";
/// Entry-point name of the in-place `exp` kernel used for `f64` log-normal output.
const LOG_NORMAL_EXP_KERNEL_F64: &str = "log_normal_exp_f64";
/// Entry-point name of the in-place round-and-clamp kernel used for `f32` Poisson output.
const POISSON_POSTPROCESS_KERNEL_F32: &str = "poisson_postprocess_f32";

28fn log_normal_exp_kernel_name(precision: PtxType) -> &'static str {
29    match precision {
30        PtxType::F32 => LOG_NORMAL_EXP_KERNEL_F32,
31        PtxType::F64 => LOG_NORMAL_EXP_KERNEL_F64,
32        _ => LOG_NORMAL_EXP_KERNEL_F32,
33    }
34}
35
36fn poisson_postprocess_kernel_name() -> &'static str {
37    POISSON_POSTPROCESS_KERNEL_F32
38}
39
40fn generate_log_normal_exp_ptx(precision: PtxType, sm: SmVersion) -> Result<String, PtxGenError> {
41    let kernel_name = log_normal_exp_kernel_name(precision);
42    let stride_bytes = precision.size_bytes() as u32;
43
44    KernelBuilder::new(kernel_name)
45        .target(sm)
46        .param("out_ptr", PtxType::U64)
47        .param("n", PtxType::U32)
48        .max_threads_per_block(256)
49        .body(move |b| {
50            let gid = b.global_thread_id_x();
51            let n_reg = b.load_param_u32("n");
52
53            b.if_lt_u32(gid.clone(), n_reg, move |b| {
54                let out_ptr = b.load_param_u64("out_ptr");
55                let addr = b.byte_offset_addr(out_ptr, gid.clone(), stride_bytes);
56
57                match precision {
58                    PtxType::F32 => {
59                        let normal_val = b.load_global_f32(addr.clone());
60                        let log2e = b.alloc_reg(PtxType::F32);
61                        b.raw_ptx(&format!("mov.f32 {log2e}, 0f3FB8AA3B;"));
62                        let scaled = b.alloc_reg(PtxType::F32);
63                        b.raw_ptx(&format!("mul.rn.f32 {scaled}, {normal_val}, {log2e};"));
64                        let result = b.alloc_reg(PtxType::F32);
65                        b.raw_ptx(&format!("ex2.approx.f32 {result}, {scaled};"));
66                        b.store_global_f32(addr, result);
67                    }
68                    PtxType::F64 => {
69                        let normal_val = b.load_global_f64(addr.clone());
70                        let narrow = b.alloc_reg(PtxType::F32);
71                        b.raw_ptx(&format!("cvt.rn.f32.f64 {narrow}, {normal_val};"));
72
73                        let log2e = b.alloc_reg(PtxType::F32);
74                        b.raw_ptx(&format!("mov.f32 {log2e}, 0f3FB8AA3B;"));
75                        let scaled = b.alloc_reg(PtxType::F32);
76                        b.raw_ptx(&format!("mul.rn.f32 {scaled}, {narrow}, {log2e};"));
77                        let exp_f32 = b.alloc_reg(PtxType::F32);
78                        b.raw_ptx(&format!("ex2.approx.f32 {exp_f32}, {scaled};"));
79
80                        let result = b.alloc_reg(PtxType::F64);
81                        b.raw_ptx(&format!("cvt.f64.f32 {result}, {exp_f32};"));
82                        b.store_global_f64(addr, result);
83                    }
84                    _ => {}
85                }
86            });
87
88            b.ret();
89        })
90        .build()
91}
92
93fn generate_poisson_postprocess_f32_ptx(sm: SmVersion) -> Result<String, PtxGenError> {
94    let kernel_name = poisson_postprocess_kernel_name();
95
96    KernelBuilder::new(kernel_name)
97        .target(sm)
98        .param("out_ptr", PtxType::U64)
99        .param("n", PtxType::U32)
100        .max_threads_per_block(256)
101        .body(move |b| {
102            let gid = b.global_thread_id_x();
103            let n_reg = b.load_param_u32("n");
104
105            b.if_lt_u32(gid.clone(), n_reg, move |b| {
106                let out_ptr = b.load_param_u64("out_ptr");
107                let addr = b.byte_offset_addr(out_ptr, gid, 4);
108                let value = b.load_global_f32(addr.clone());
109
110                let rounded_i32 = b.alloc_reg(PtxType::S32);
111                b.raw_ptx(&format!("cvt.rni.s32.f32 {rounded_i32}, {value};"));
112
113                let zero_i32 = b.alloc_reg(PtxType::S32);
114                b.raw_ptx(&format!("mov.s32 {zero_i32}, 0;"));
115
116                let clamped_i32 = b.alloc_reg(PtxType::S32);
117                b.raw_ptx(&format!(
118                    "max.s32 {clamped_i32}, {rounded_i32}, {zero_i32};"
119                ));
120
121                let clamped_f32 = b.alloc_reg(PtxType::F32);
122                b.raw_ptx(&format!("cvt.rn.f32.s32 {clamped_f32}, {clamped_i32};"));
123                b.store_global_f32(addr, clamped_f32);
124            });
125
126            b.ret();
127        })
128        .build()
129}
130
131fn validate_poisson_lambda(lambda: f64) -> RandResult<f32> {
132    if !lambda.is_finite() || lambda < 0.0 {
133        return Err(RandError::InvalidParameter(format!(
134            "lambda must be finite and >= 0, got {lambda}"
135        )));
136    }
137    Ok(lambda as f32)
138}
139
// ---------------------------------------------------------------------------
// Engine selection
// ---------------------------------------------------------------------------

/// Available RNG engine algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RngEngine {
    /// Philox-4x32-10 counter-based PRNG (cuRAND default).
    Philox,
    /// XORWOW with Weyl sequence addition (fast, good quality).
    Xorwow,
    /// MRG32k3a combined multiple recursive generator (highest quality).
    Mrg32k3a,
}

impl std::fmt::Display for RngEngine {
    /// Writes the engine's conventional algorithm name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Philox => "Philox-4x32-10",
            Self::Xorwow => "XORWOW",
            Self::Mrg32k3a => "MRG32k3a",
        };
        f.write_str(label)
    }
}

// ---------------------------------------------------------------------------
// Generator
// ---------------------------------------------------------------------------

169/// High-level GPU random number generator.
170///
171/// Wraps one of the available [`RngEngine`] implementations and manages
172/// CUDA resources (context, stream, modules) for kernel compilation and
173/// launch.
174///
175/// # Example
176///
177/// ```rust,no_run
178/// # use std::sync::Arc;
179/// # use oxicuda_driver::{Context, Device};
180/// # use oxicuda_memory::DeviceBuffer;
181/// # use oxicuda_rand::generator::{RngEngine, RngGenerator};
182/// # fn main() -> oxicuda_rand::RandResult<()> {
183/// # oxicuda_driver::init()?;
184/// # let dev = Device::get(0)?;
185/// # let ctx = Arc::new(Context::new(&dev)?);
186/// let mut rng = RngGenerator::new(RngEngine::Philox, 42, &ctx)?;
187/// let mut buf = DeviceBuffer::<f32>::alloc(1024)?;
188/// rng.generate_uniform_f32(&mut buf)?;
189/// # Ok(())
190/// # }
191/// ```
192pub struct RngGenerator {
193    /// The engine algorithm to use.
194    engine: RngEngine,
195    /// RNG seed value.
196    seed: u64,
197    /// Stream offset for counter-based generators.
198    offset: u64,
199    /// CUDA context.
200    #[allow(dead_code)]
201    context: Arc<Context>,
202    /// CUDA stream for kernel launches.
203    stream: Stream,
204    /// Target SM architecture version.
205    sm_version: SmVersion,
206}
207
208impl RngGenerator {
209    /// Creates a new RNG generator with the specified engine and seed.
210    ///
211    /// # Errors
212    ///
213    /// Returns `RandError::Cuda` if CUDA stream creation fails.
214    pub fn new(engine: RngEngine, seed: u64, ctx: &Arc<Context>) -> RandResult<Self> {
215        let stream = Stream::new(ctx).map_err(RandError::Cuda)?;
216        Ok(Self {
217            engine,
218            seed,
219            offset: 0,
220            context: Arc::clone(ctx),
221            stream,
222            sm_version: SmVersion::Sm80,
223        })
224    }
225
226    /// Sets the RNG seed.
227    pub fn set_seed(&mut self, seed: u64) {
228        self.seed = seed;
229    }
230
231    /// Sets the stream offset (for counter-based generators).
232    pub fn set_offset(&mut self, offset: u64) {
233        self.offset = offset;
234    }
235
236    /// Advances the offset by `n` elements.
237    pub fn skip(&mut self, n: u64) {
238        self.offset = self.offset.wrapping_add(n);
239    }
240
241    /// Generates uniformly distributed f32 values in \[0, 1).
242    ///
243    /// # Errors
244    ///
245    /// Returns `RandError` on PTX generation, compilation, or launch failure.
246    pub fn generate_uniform_f32(&mut self, output: &mut DeviceBuffer<f32>) -> RandResult<()> {
247        let n = output.len();
248        let ptx_source = self.get_uniform_ptx(PtxType::F32)?;
249        self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
250        self.offset += n as u64;
251        Ok(())
252    }
253
254    /// Generates uniformly distributed f64 values in \[0, 1).
255    ///
256    /// # Errors
257    ///
258    /// Returns `RandError` on PTX generation, compilation, or launch failure.
259    pub fn generate_uniform_f64(&mut self, output: &mut DeviceBuffer<f64>) -> RandResult<()> {
260        let n = output.len();
261        let ptx_source = self.get_uniform_ptx(PtxType::F64)?;
262        self.compile_and_launch_uniform(&ptx_source, PtxType::F64, output.as_device_ptr(), n)?;
263        self.offset += n as u64;
264        Ok(())
265    }
266
267    /// Generates uniform f32 values using the optimized 4-per-thread Philox engine.
268    ///
269    /// For large outputs (>= 1024 elements), this uses the optimized Philox
270    /// engine where each thread generates 4 values. For smaller counts or
271    /// non-Philox engines, falls back to the standard engine.
272    ///
273    /// # Errors
274    ///
275    /// Returns `RandError` on PTX generation, compilation, or launch failure.
276    pub fn generate_uniform_f32_optimized(
277        &mut self,
278        output: &mut DeviceBuffer<f32>,
279    ) -> RandResult<()> {
280        let n = output.len();
281        if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
282            return self.generate_uniform_f32(output);
283        }
284
285        let ptx_source =
286            philox_optimized::generate_philox_optimized_uniform_f32_ptx(self.sm_version)?;
287        self.compile_and_launch_uniform(&ptx_source, PtxType::F32, output.as_device_ptr(), n)?;
288        // Offset advances by n/4 (each counter produces 4 values)
289        self.offset += n.div_ceil(4) as u64;
290        Ok(())
291    }
292
293    /// Generates normal f32 values using the optimized 4-per-thread Philox engine.
294    ///
295    /// For large outputs (>= 1024 elements), each thread generates 4 normal
296    /// values using two Box-Muller transforms on the full Philox output.
297    /// Falls back to the standard engine for small counts or non-Philox engines.
298    ///
299    /// # Errors
300    ///
301    /// Returns `RandError` on PTX generation, compilation, or launch failure.
302    pub fn generate_normal_f32_optimized(
303        &mut self,
304        output: &mut DeviceBuffer<f32>,
305        mean: f32,
306        stddev: f32,
307    ) -> RandResult<()> {
308        let n = output.len();
309        if self.engine != RngEngine::Philox || n < philox_optimized::OPTIMIZED_THRESHOLD {
310            return self.generate_normal_f32(output, mean, stddev);
311        }
312
313        let ptx_source =
314            philox_optimized::generate_philox_optimized_normal_f32_ptx(self.sm_version)?;
315        self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
316        self.offset += n.div_ceil(4) as u64;
317        Ok(())
318    }
319
320    /// Generates normally distributed f32 values.
321    ///
322    /// # Errors
323    ///
324    /// Returns `RandError` on PTX generation, compilation, or launch failure.
325    pub fn generate_normal_f32(
326        &mut self,
327        output: &mut DeviceBuffer<f32>,
328        mean: f32,
329        stddev: f32,
330    ) -> RandResult<()> {
331        let n = output.len();
332        let ptx_source = self.get_normal_ptx(PtxType::F32)?;
333        self.compile_and_launch_normal_f32(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
334        self.offset += n as u64;
335        Ok(())
336    }
337
338    /// Generates normally distributed f64 values.
339    ///
340    /// # Errors
341    ///
342    /// Returns `RandError` on PTX generation, compilation, or launch failure.
343    pub fn generate_normal_f64(
344        &mut self,
345        output: &mut DeviceBuffer<f64>,
346        mean: f64,
347        stddev: f64,
348    ) -> RandResult<()> {
349        let n = output.len();
350        let ptx_source = self.get_normal_ptx(PtxType::F64)?;
351        self.compile_and_launch_normal_f64(&ptx_source, output.as_device_ptr(), n, mean, stddev)?;
352        self.offset += n as u64;
353        Ok(())
354    }
355
356    /// Generates log-normally distributed f32 values.
357    ///
358    /// A log-normal variate is `exp(Normal(mean, stddev))`.
359    ///
360    /// # Errors
361    ///
362    /// Returns `RandError` on PTX generation, compilation, or launch failure.
363    pub fn generate_log_normal_f32(
364        &mut self,
365        output: &mut DeviceBuffer<f32>,
366        mean: f32,
367        stddev: f32,
368    ) -> RandResult<()> {
369        let n = output.len();
370        self.generate_normal_f32(output, mean, stddev)?;
371        let ptx_source = self.get_log_normal_exp_ptx(PtxType::F32)?;
372        self.compile_and_launch_log_normal_exp(
373            &ptx_source,
374            PtxType::F32,
375            output.as_device_ptr(),
376            n,
377        )?;
378        Ok(())
379    }
380
381    /// Generates log-normally distributed f64 values.
382    ///
383    /// # Errors
384    ///
385    /// Returns `RandError` on PTX generation, compilation, or launch failure.
386    pub fn generate_log_normal_f64(
387        &mut self,
388        output: &mut DeviceBuffer<f64>,
389        mean: f64,
390        stddev: f64,
391    ) -> RandResult<()> {
392        let n = output.len();
393        self.generate_normal_f64(output, mean, stddev)?;
394        let ptx_source = self.get_log_normal_exp_ptx(PtxType::F64)?;
395        self.compile_and_launch_log_normal_exp(
396            &ptx_source,
397            PtxType::F64,
398            output.as_device_ptr(),
399            n,
400        )?;
401        Ok(())
402    }
403
404    /// Generates Poisson-distributed f32 values.
405    ///
406    /// Uses a normal approximation: `Normal(lambda, sqrt(lambda))` followed by
407    /// in-place rounding to nearest integer and clamping to `>= 0`.
408    ///
409    /// # Errors
410    ///
411    /// Returns `RandError` on PTX generation, compilation, or launch failure.
412    pub fn generate_poisson_f32(
413        &mut self,
414        output: &mut DeviceBuffer<f32>,
415        lambda: f64,
416    ) -> RandResult<()> {
417        let lambda_f32 = validate_poisson_lambda(lambda)?;
418        let stddev = lambda.sqrt() as f32;
419        let n = output.len();
420
421        // Consume RNG state using normal generation; postprocessing is deterministic.
422        self.generate_normal_f32(output, lambda_f32, stddev)?;
423
424        let ptx_source = self.get_poisson_postprocess_f32_ptx()?;
425        self.compile_and_launch_poisson_postprocess_f32(&ptx_source, output.as_device_ptr(), n)?;
426        Ok(())
427    }
428
429    /// Generates raw u32 random values.
430    ///
431    /// Only supported for the Philox engine. Other engines return
432    /// `RandError::UnsupportedDistribution`.
433    ///
434    /// # Errors
435    ///
436    /// Returns `RandError` on unsupported engine, PTX generation, or launch failure.
437    pub fn generate_u32(&mut self, output: &mut DeviceBuffer<u32>) -> RandResult<()> {
438        let n = output.len();
439        let ptx_source = self.get_u32_ptx()?;
440        let kernel_name = self.u32_kernel_name();
441        self.compile_and_launch_u32(&ptx_source, &kernel_name, output.as_device_ptr(), n)?;
442        self.offset += n as u64;
443        Ok(())
444    }
445
446    // -----------------------------------------------------------------------
447    // Internal: PTX generation dispatch
448    // -----------------------------------------------------------------------
449
450    /// Returns the PTX source for the uniform kernel.
451    fn get_uniform_ptx(&self, precision: PtxType) -> RandResult<String> {
452        let ptx = match self.engine {
453            RngEngine::Philox => philox::generate_philox_uniform_ptx(precision, self.sm_version)?,
454            RngEngine::Xorwow => xorwow::generate_xorwow_uniform_ptx(precision, self.sm_version)?,
455            RngEngine::Mrg32k3a => {
456                mrg32k3a::generate_mrg32k3a_uniform_ptx(precision, self.sm_version)?
457            }
458        };
459        Ok(ptx)
460    }
461
462    /// Returns the PTX source for the normal kernel.
463    fn get_normal_ptx(&self, precision: PtxType) -> RandResult<String> {
464        let ptx = match self.engine {
465            RngEngine::Philox => philox::generate_philox_normal_ptx(precision, self.sm_version)?,
466            RngEngine::Xorwow => xorwow::generate_xorwow_normal_ptx(precision, self.sm_version)?,
467            RngEngine::Mrg32k3a => {
468                mrg32k3a::generate_mrg32k3a_normal_ptx(precision, self.sm_version)?
469            }
470        };
471        Ok(ptx)
472    }
473
474    /// Returns the PTX source for the u32 kernel.
475    fn get_u32_ptx(&self) -> RandResult<String> {
476        let ptx = match self.engine {
477            RngEngine::Philox => philox::generate_philox_u32_ptx(self.sm_version)?,
478            RngEngine::Mrg32k3a => mrg32k3a::generate_mrg32k3a_u32_ptx(self.sm_version)?,
479            RngEngine::Xorwow => {
480                return Err(RandError::UnsupportedDistribution(
481                    "u32 output is not supported for XORWOW engine".to_string(),
482                ));
483            }
484        };
485        Ok(ptx)
486    }
487
488    /// Returns PTX for the in-place exp transform used by log-normal generation.
489    fn get_log_normal_exp_ptx(&self, precision: PtxType) -> RandResult<String> {
490        generate_log_normal_exp_ptx(precision, self.sm_version).map_err(RandError::from)
491    }
492
493    /// Returns PTX for in-place Poisson approximation postprocessing.
494    fn get_poisson_postprocess_f32_ptx(&self) -> RandResult<String> {
495        generate_poisson_postprocess_f32_ptx(self.sm_version).map_err(RandError::from)
496    }
497
498    /// Returns the kernel entry point name for uniform kernels.
499    fn uniform_kernel_name(&self, precision: PtxType) -> String {
500        let prec_str = match precision {
501            PtxType::F32 => "f32",
502            PtxType::F64 => "f64",
503            _ => "f32",
504        };
505        match self.engine {
506            RngEngine::Philox => format!("philox_uniform_{prec_str}"),
507            RngEngine::Xorwow => format!("xorwow_uniform_{prec_str}"),
508            RngEngine::Mrg32k3a => format!("mrg32k3a_uniform_{prec_str}"),
509        }
510    }
511
512    /// Returns the kernel entry point name for normal kernels.
513    fn normal_kernel_name(&self, precision: PtxType) -> String {
514        let prec_str = match precision {
515            PtxType::F32 => "f32",
516            PtxType::F64 => "f64",
517            _ => "f32",
518        };
519        match self.engine {
520            RngEngine::Philox => format!("philox_normal_{prec_str}"),
521            RngEngine::Xorwow => format!("xorwow_normal_{prec_str}"),
522            RngEngine::Mrg32k3a => format!("mrg32k3a_normal_{prec_str}"),
523        }
524    }
525
526    /// Returns the kernel entry point name for u32 kernels.
527    fn u32_kernel_name(&self) -> String {
528        match self.engine {
529            RngEngine::Philox => "philox_u32".to_string(),
530            RngEngine::Mrg32k3a => "mrg32k3a_u32".to_string(),
531            RngEngine::Xorwow => "xorwow_u32".to_string(), // unreachable in practice
532        }
533    }
534
535    // -----------------------------------------------------------------------
536    // Internal: kernel compilation and launch helpers
537    // -----------------------------------------------------------------------
538
539    /// Compiles PTX and launches a uniform kernel.
540    fn compile_and_launch_uniform(
541        &self,
542        ptx_source: &str,
543        precision: PtxType,
544        out_ptr: u64,
545        n: usize,
546    ) -> RandResult<()> {
547        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
548        let kernel_name = self.uniform_kernel_name(precision);
549        let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
550
551        let n_u32 = u32::try_from(n)
552            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
553        let grid = grid_size_for(n_u32, 256);
554        let params = LaunchParams::new(grid, 256u32);
555
556        let seed_lo = self.seed as u32;
557        let seed_hi = (self.seed >> 32) as u32;
558        let offset_lo = self.offset as u32;
559        let offset_hi = (self.offset >> 32) as u32;
560
561        // Philox takes (out_ptr, n, seed_lo, seed_hi, offset_lo, offset_hi)
562        // Xorwow/Mrg32k3a take (out_ptr, n, seed, offset_lo, offset_hi)
563        match self.engine {
564            RngEngine::Philox => {
565                let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
566                kernel
567                    .launch(&params, &self.stream, &args)
568                    .map_err(RandError::Cuda)?;
569            }
570            RngEngine::Xorwow | RngEngine::Mrg32k3a => {
571                let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
572                kernel
573                    .launch(&params, &self.stream, &args)
574                    .map_err(RandError::Cuda)?;
575            }
576        }
577
578        self.stream.synchronize().map_err(RandError::Cuda)?;
579        Ok(())
580    }
581
582    /// Compiles PTX and launches a normal f32 kernel.
583    fn compile_and_launch_normal_f32(
584        &self,
585        ptx_source: &str,
586        out_ptr: u64,
587        n: usize,
588        mean: f32,
589        stddev: f32,
590    ) -> RandResult<()> {
591        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
592        let kernel_name = self.normal_kernel_name(PtxType::F32);
593        let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
594
595        let n_u32 = u32::try_from(n)
596            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
597        let grid = grid_size_for(n_u32, 256);
598        let params = LaunchParams::new(grid, 256u32);
599
600        let seed_lo = self.seed as u32;
601        let seed_hi = (self.seed >> 32) as u32;
602        let offset_lo = self.offset as u32;
603        let offset_hi = (self.offset >> 32) as u32;
604
605        match self.engine {
606            RngEngine::Philox => {
607                let args = (
608                    out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
609                );
610                kernel
611                    .launch(&params, &self.stream, &args)
612                    .map_err(RandError::Cuda)?;
613            }
614            RngEngine::Xorwow | RngEngine::Mrg32k3a => {
615                let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
616                kernel
617                    .launch(&params, &self.stream, &args)
618                    .map_err(RandError::Cuda)?;
619            }
620        }
621
622        self.stream.synchronize().map_err(RandError::Cuda)?;
623        Ok(())
624    }
625
626    /// Compiles PTX and launches a normal f64 kernel.
627    fn compile_and_launch_normal_f64(
628        &self,
629        ptx_source: &str,
630        out_ptr: u64,
631        n: usize,
632        mean: f64,
633        stddev: f64,
634    ) -> RandResult<()> {
635        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
636        let kernel_name = self.normal_kernel_name(PtxType::F64);
637        let kernel = Kernel::from_module(module, &kernel_name).map_err(RandError::Cuda)?;
638
639        let n_u32 = u32::try_from(n)
640            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
641        let grid = grid_size_for(n_u32, 256);
642        let params = LaunchParams::new(grid, 256u32);
643
644        let seed_lo = self.seed as u32;
645        let seed_hi = (self.seed >> 32) as u32;
646        let offset_lo = self.offset as u32;
647        let offset_hi = (self.offset >> 32) as u32;
648
649        match self.engine {
650            RngEngine::Philox => {
651                let args = (
652                    out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi, mean, stddev,
653                );
654                kernel
655                    .launch(&params, &self.stream, &args)
656                    .map_err(RandError::Cuda)?;
657            }
658            RngEngine::Xorwow | RngEngine::Mrg32k3a => {
659                let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi, mean, stddev);
660                kernel
661                    .launch(&params, &self.stream, &args)
662                    .map_err(RandError::Cuda)?;
663            }
664        }
665
666        self.stream.synchronize().map_err(RandError::Cuda)?;
667        Ok(())
668    }
669
670    /// Compiles PTX and launches a u32 kernel.
671    fn compile_and_launch_u32(
672        &self,
673        ptx_source: &str,
674        kernel_name: &str,
675        out_ptr: u64,
676        n: usize,
677    ) -> RandResult<()> {
678        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
679        let kernel = Kernel::from_module(module, kernel_name).map_err(RandError::Cuda)?;
680
681        let n_u32 = u32::try_from(n)
682            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
683        let grid = grid_size_for(n_u32, 256);
684        let params = LaunchParams::new(grid, 256u32);
685
686        let seed_lo = self.seed as u32;
687        let seed_hi = (self.seed >> 32) as u32;
688        let offset_lo = self.offset as u32;
689        let offset_hi = (self.offset >> 32) as u32;
690
691        match self.engine {
692            RngEngine::Philox => {
693                let args = (out_ptr, n_u32, seed_lo, seed_hi, offset_lo, offset_hi);
694                kernel
695                    .launch(&params, &self.stream, &args)
696                    .map_err(RandError::Cuda)?;
697            }
698            RngEngine::Mrg32k3a => {
699                let args = (out_ptr, n_u32, seed_lo, offset_lo, offset_hi);
700                kernel
701                    .launch(&params, &self.stream, &args)
702                    .map_err(RandError::Cuda)?;
703            }
704            RngEngine::Xorwow => {
705                // Should not reach here due to get_u32_ptx check
706                return Err(RandError::UnsupportedDistribution(
707                    "u32 not supported for XORWOW".to_string(),
708                ));
709            }
710        }
711
712        self.stream.synchronize().map_err(RandError::Cuda)?;
713        Ok(())
714    }
715
716    /// Compiles PTX and launches an in-place unary exp kernel for log-normal.
717    fn compile_and_launch_log_normal_exp(
718        &self,
719        ptx_source: &str,
720        precision: PtxType,
721        out_ptr: u64,
722        n: usize,
723    ) -> RandResult<()> {
724        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
725        let kernel_name = log_normal_exp_kernel_name(precision);
726        let kernel = Kernel::from_module(module, kernel_name).map_err(RandError::Cuda)?;
727
728        let n_u32 = u32::try_from(n)
729            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
730        let grid = grid_size_for(n_u32, 256);
731        let params = LaunchParams::new(grid, 256u32);
732
733        let args = (out_ptr, n_u32);
734        kernel
735            .launch(&params, &self.stream, &args)
736            .map_err(RandError::Cuda)?;
737
738        self.stream.synchronize().map_err(RandError::Cuda)?;
739        Ok(())
740    }
741
742    /// Compiles PTX and launches in-place Poisson postprocessing for f32 output.
743    fn compile_and_launch_poisson_postprocess_f32(
744        &self,
745        ptx_source: &str,
746        out_ptr: u64,
747        n: usize,
748    ) -> RandResult<()> {
749        let module = Arc::new(Module::from_ptx(ptx_source).map_err(RandError::Cuda)?);
750        let kernel_name = poisson_postprocess_kernel_name();
751        let kernel = Kernel::from_module(module, kernel_name).map_err(RandError::Cuda)?;
752
753        let n_u32 = u32::try_from(n)
754            .map_err(|_| RandError::InvalidSize(format!("output size {n} exceeds u32::MAX")))?;
755        let grid = grid_size_for(n_u32, 256);
756        let params = LaunchParams::new(grid, 256u32);
757
758        let args = (out_ptr, n_u32);
759        kernel
760            .launch(&params, &self.stream, &args)
761            .map_err(RandError::Cuda)?;
762
763        self.stream.synchronize().map_err(RandError::Cuda)?;
764        Ok(())
765    }
766}
767
#[cfg(test)]
mod tests {
    //! Host-only tests: PTX generation, name helpers and parameter validation.
    //! None of them needs a CUDA device.

    use super::*;

    #[test]
    fn engine_display() {
        assert_eq!(format!("{}", RngEngine::Philox), "Philox-4x32-10");
        assert_eq!(format!("{}", RngEngine::Xorwow), "XORWOW");
        assert_eq!(format!("{}", RngEngine::Mrg32k3a), "MRG32k3a");
    }

    #[test]
    fn uniform_kernel_names() {
        // We cannot construct RngGenerator without a CUDA context, so the
        // name-building methods cannot be called from here.
        // NOTE(review): the assertions below compare literals with themselves
        // and only record the expected naming scheme; they verify no code.
        let expected_philox_f32 = "philox_uniform_f32";
        let expected_xorwow_f64 = "xorwow_uniform_f64";
        let expected_mrg_f32 = "mrg32k3a_uniform_f32";

        assert_eq!(expected_philox_f32, "philox_uniform_f32");
        assert_eq!(expected_xorwow_f64, "xorwow_uniform_f64");
        assert_eq!(expected_mrg_f32, "mrg32k3a_uniform_f32");
    }

    #[test]
    fn postprocess_kernel_names() {
        // The free name helpers need no CUDA context, so they can be checked directly.
        assert_eq!(
            log_normal_exp_kernel_name(PtxType::F32),
            "log_normal_exp_f32"
        );
        assert_eq!(
            log_normal_exp_kernel_name(PtxType::F64),
            "log_normal_exp_f64"
        );
        // Any other precision falls back to the f32 kernel name.
        assert_eq!(
            log_normal_exp_kernel_name(PtxType::U32),
            "log_normal_exp_f32"
        );
        assert_eq!(poisson_postprocess_kernel_name(), "poisson_postprocess_f32");
    }

    #[test]
    fn ptx_generation_philox_uniform() {
        let ptx = philox::generate_philox_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    #[test]
    fn ptx_generation_xorwow_uniform() {
        let ptx = xorwow::generate_xorwow_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    #[test]
    fn ptx_generation_mrg32k3a_uniform() {
        let ptx = mrg32k3a::generate_mrg32k3a_uniform_ptx(PtxType::F32, SmVersion::Sm80);
        assert!(ptx.is_ok());
    }

    #[test]
    fn log_normal_exp_f32_ptx_generation() {
        let ptx = generate_log_normal_exp_ptx(PtxType::F32, SmVersion::Sm80)
            .unwrap_or_else(|e| panic!("{e}"));
        assert!(ptx.contains(".entry log_normal_exp_f32"));
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("0f3FB8AA3B"));
        assert!(!ptx.contains("philox_normal_f32"));
    }

    #[test]
    fn log_normal_exp_f64_ptx_generation() {
        let ptx = generate_log_normal_exp_ptx(PtxType::F64, SmVersion::Sm80)
            .unwrap_or_else(|e| panic!("{e}"));
        assert!(ptx.contains(".entry log_normal_exp_f64"));
        assert!(ptx.contains("cvt.rn.f32.f64"));
        assert!(ptx.contains("ex2.approx.f32"));
        assert!(ptx.contains("cvt.f64.f32"));
        assert!(!ptx.contains("philox_normal_f64"));
    }

    #[test]
    fn poisson_postprocess_f32_ptx_generation() {
        let ptx =
            generate_poisson_postprocess_f32_ptx(SmVersion::Sm80).unwrap_or_else(|e| panic!("{e}"));
        assert!(ptx.contains(".entry poisson_postprocess_f32"));
        assert!(ptx.contains("cvt.rni.s32.f32"));
        assert!(ptx.contains("max.s32"));
        assert!(ptx.contains("cvt.rn.f32.s32"));
        assert!(!ptx.contains("philox_normal_f32"));
    }

    #[test]
    fn poisson_lambda_validation_rejects_invalid_values() {
        let negative = validate_poisson_lambda(-1.0);
        assert!(matches!(negative, Err(RandError::InvalidParameter(_))));

        let nan = validate_poisson_lambda(f64::NAN);
        assert!(matches!(nan, Err(RandError::InvalidParameter(_))));

        let inf = validate_poisson_lambda(f64::INFINITY);
        assert!(matches!(inf, Err(RandError::InvalidParameter(_))));
    }

    #[test]
    fn poisson_lambda_validation_accepts_valid_values() {
        let zero = validate_poisson_lambda(0.0);
        assert!(matches!(zero, Ok(v) if v == 0.0));

        let positive = validate_poisson_lambda(12.5);
        assert!(matches!(positive, Ok(v) if v == 12.5_f32));
    }
}