Skip to main content

scirs2_stats/gpu/
mod.rs

1//! GPU-accelerated batch statistical distribution computations via WebGPU (wgpu).
2//!
3//! This module provides batch evaluation of common distribution functions
4//! (Normal and Exponential) using WGSL compute shaders dispatched through
5//! the wgpu backend.  Each public function has a CPU fallback that is always
6//! available, regardless of whether the `gpu_wgpu` feature is enabled.
7//!
8//! # Feature flags
9//!
10//! * `gpu`       — enables GPU abstraction layer via `scirs2-core/gpu`
11//! * `gpu_wgpu`  — enables WGSL/wgpu compute path (implies `gpu`)
12//!
13//! # Examples
14//!
15//! ```rust
16//! use scirs2_stats::gpu::{normal_log_pdf_batch, exponential_cdf_batch};
17//!
18//! let xs = vec![0.0_f64, 1.0, -1.0];
19//! let log_pdfs = normal_log_pdf_batch(&xs, 0.0, 1.0);
20//! assert_eq!(log_pdfs.len(), 3);
21//!
22//! let cdfs = exponential_cdf_batch(&xs, 1.0);
23//! assert!(cdfs[2].abs() < 1e-10);  // x = -1 → CDF = 0
24//! ```
25
26// ---------------------------------------------------------------------------
27// WGSL shader constants
28// ---------------------------------------------------------------------------
29
/// WGSL compute shader: batch log-PDF of the Normal distribution.
///
/// Computes `log_pdf(xᵢ; µ, σ) = -½((xᵢ−µ)/σ)² − ln(σ) − ln(√(2π))`
/// for each element in the input array, given scalar µ and σ in a uniform
/// buffer.  All arithmetic inside the shader is `f32`.
///
/// The entry point is `main` with a workgroup size of 64; each invocation
/// handles the element at its global x index and returns without writing
/// when that index is `>= params.n`, so a dispatch may be rounded up to a
/// whole number of workgroups.  The shader does not validate `sigma`.
///
/// Bindings:
///   - `@group(0) @binding(0)`: read-only `array<f32>` input x values
///   - `@group(0) @binding(1)`: read-write `array<f32>` output values
///   - `@group(0) @binding(2)`: uniform `NormalParams { mu, sigma, n, _pad }`
const NORMAL_LOG_PDF_WGSL: &str = r#"
struct NormalParams {
    mu: f32,
    sigma: f32,
    n: u32,
    _pad: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;
@group(0) @binding(2) var<uniform> params: NormalParams;

const LOG_SQRT_2PI: f32 = 0.9189385332046727;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= params.n { return; }
    let z = (x[i] - params.mu) / params.sigma;
    out[i] = -0.5 * z * z - LOG_SQRT_2PI - log(params.sigma);
}
"#;
62
/// WGSL compute shader: batch CDF of the Normal distribution.
///
/// Computes `Φ((xᵢ−µ)/σ) = ½·(1 + erf((xᵢ−µ)/(σ√2)))`, where `erf` is the
/// Horner-form A&S approximation (Abramowitz & Stegun 7.1.26).  The quoted
/// maximum absolute error of that formula is ≈ 1.5×10⁻⁷; the shader
/// evaluates it in `f32`, so rounding error comes on top of that.
///
/// `approx_erf` evaluates the polynomial on `abs(v)` and restores the sign
/// with `select`, making the approximation odd in `v`.  The entry point is
/// `main` with a workgroup size of 64; invocations whose index is
/// `>= params.n` return without writing.
///
/// Bindings: same layout as [`NORMAL_LOG_PDF_WGSL`].
const NORMAL_CDF_WGSL: &str = r#"
struct NormalParams {
    mu: f32,
    sigma: f32,
    n: u32,
    _pad: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;
@group(0) @binding(2) var<uniform> params: NormalParams;

fn approx_erf(v: f32) -> f32 {
    let t = 1.0 / (1.0 + 0.3275911 * abs(v));
    let y = 1.0 - (((((1.061405429 * t - 1.453152027) * t
        + 1.421413741) * t - 0.284496736) * t + 0.254829592) * t * exp(-v * v));
    return select(-y, y, v >= 0.0);
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= params.n { return; }
    let z = (x[i] - params.mu) / (params.sigma * 1.41421356237f);
    out[i] = 0.5 * (1.0 + approx_erf(z));
}
"#;
96
/// WGSL compute shader: batch log-PDF of the Exponential distribution.
///
/// For `xᵢ ≥ 0`: `log_pdf(xᵢ; λ) = ln(λ) − λ·xᵢ`.
/// For every other input (the `xᵢ >= 0.0` comparison is false): outputs the
/// finite sentinel `-1e30`, which stands in for −∞.  Callers that need a
/// true `-inf` must translate the sentinel themselves.
///
/// The entry point is `main` with a workgroup size of 64; invocations whose
/// index is `>= params.n` return without writing.  The shader does not
/// validate `lambda`.
///
/// Bindings:
///   - `@group(0) @binding(0)`: read-only `array<f32>` input x values
///   - `@group(0) @binding(1)`: read-write `array<f32>` output values
///   - `@group(0) @binding(2)`: uniform `ExponParams { lambda, n, _pad0, _pad1 }`
const EXPONENTIAL_LOG_PDF_WGSL: &str = r#"
struct ExponParams {
    lambda: f32,
    n: u32,
    _pad0: u32,
    _pad1: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;
@group(0) @binding(2) var<uniform> params: ExponParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= params.n { return; }
    let xi = x[i];
    out[i] = select(-1e30, log(params.lambda) - params.lambda * xi, xi >= 0.0);
}
"#;
126
/// WGSL compute shader: batch CDF of the Exponential distribution.
///
/// For `xᵢ ≥ 0`: `CDF(xᵢ; λ) = 1 − exp(−λ·xᵢ)`.
/// For every other input (the `xᵢ >= 0.0` comparison is false): outputs `0.0`.
///
/// The entry point is `main` with a workgroup size of 64; invocations whose
/// index is `>= params.n` return without writing.  Arithmetic is `f32`, and
/// the shader does not validate `lambda`.
///
/// Bindings: same layout as [`EXPONENTIAL_LOG_PDF_WGSL`].
const EXPONENTIAL_CDF_WGSL: &str = r#"
struct ExponParams {
    lambda: f32,
    n: u32,
    _pad0: u32,
    _pad1: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;
@group(0) @binding(2) var<uniform> params: ExponParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= params.n { return; }
    let xi = x[i];
    out[i] = select(0.0, 1.0 - exp(-params.lambda * xi), xi >= 0.0);
}
"#;
153
154// ---------------------------------------------------------------------------
155// Error type
156// ---------------------------------------------------------------------------
157
/// Errors that can arise from GPU-accelerated distribution evaluation.
///
/// The type is comparable (`PartialEq`/`Eq`) so callers and tests can match
/// on a specific failure with `==` / `assert_eq!` instead of string-matching
/// the `Display` output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuStatsError {
    /// No wgpu-capable GPU adapter is available on this system.
    GpuNotAvailable,
    /// A runtime error occurred during GPU pipeline setup or dispatch.
    /// The payload is a human-readable description of the failure.
    RuntimeError(String),
    /// The `gpu_wgpu` feature flag is not enabled in this build.
    FeatureNotEnabled,
}

impl std::fmt::Display for GpuStatsError {
    /// Writes a short, user-facing description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::GpuNotAvailable => f.write_str("wgpu GPU adapter not available on this system"),
            Self::RuntimeError(msg) => write!(f, "GPU runtime error: {msg}"),
            Self::FeatureNotEnabled => {
                f.write_str("gpu_wgpu feature is not enabled in this build")
            }
        }
    }
}

impl std::error::Error for GpuStatsError {}
186
187// ---------------------------------------------------------------------------
188// GPU dispatch helper — real path (gpu_wgpu feature)
189// ---------------------------------------------------------------------------
190
/// Upload `xs` as `f32`, dispatch the WGSL shader with a uniform parameter
/// buffer (`params_bytes`), and return the resulting `f32` values.
///
/// The shader is expected to declare exactly three bindings:
///   - `@group(0) @binding(0)` — read-only storage `array<f32>` (input)
///   - `@group(0) @binding(1)` — read-write storage `array<f32>` (output)
///   - `@group(0) @binding(2)` — uniform buffer (distribution parameters)
///
/// and a compute entry point named `main` with `@workgroup_size(64)`.
///
/// `params_bytes` is uploaded verbatim as the uniform buffer, so it must
/// match the byte layout of the shader's parameter struct; all parameter
/// structs in this module use 4 × f32/u32 fields (= 16 bytes).
///
/// A fresh wgpu instance, adapter and device are created on every call and
/// the call blocks until the results have been read back.
///
/// # Errors
///
/// * [`GpuStatsError::GpuNotAvailable`] — no adapter could be obtained.
/// * [`GpuStatsError::RuntimeError`] — the input is too large for a single
///   one-dimensional dispatch (element count does not fit in `u32`, or the
///   required workgroup count exceeds the device limit), or device creation,
///   polling or buffer mapping failed.
///
/// An empty `xs` returns `Ok(vec![])` without touching the GPU.
#[cfg(feature = "gpu_wgpu")]
fn dispatch_with_params_f32(
    wgsl: &str,
    xs: &[f32],
    params_bytes: &[u8],
) -> Result<Vec<f32>, GpuStatsError> {
    use wgpu::{
        util::{BufferInitDescriptor, DeviceExt as _},
        Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor,
        BindGroupLayoutEntry, BindingType, BufferBindingType, BufferDescriptor, BufferUsages,
        CommandEncoderDescriptor, ComputePassDescriptor, DeviceDescriptor, Features, Instance,
        InstanceDescriptor, Limits, MapMode, PowerPreference, RequestAdapterOptions,
        ShaderModuleDescriptor, ShaderSource, ShaderStages,
    };

    // Must agree with `@workgroup_size(64)` in every shader of this module.
    const WORKGROUP_SIZE: u32 = 64;

    if xs.is_empty() {
        return Ok(Vec::new());
    }

    // The shaders index elements with a `u32`; refuse inputs that would be
    // silently truncated instead of computing only part of the batch.
    let element_count = u32::try_from(xs.len()).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input length {} does not fit the shader's u32 element count",
            xs.len()
        ))
    })?;
    // Ceiling division written so that it cannot overflow near `u32::MAX`.
    let workgroups =
        element_count / WORKGROUP_SIZE + u32::from(element_count % WORKGROUP_SIZE != 0);

    // ── Adapter / device ─────────────────────────────────────────────────────
    let instance = Instance::new(InstanceDescriptor {
        backends: Backends::all(),
        flags: wgpu::InstanceFlags::default(),
        memory_budget_thresholds: Default::default(),
        backend_options: Default::default(),
        display: None,
    });

    let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
        power_preference: PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    }))
    .map_err(|_| GpuStatsError::GpuNotAvailable)?;

    let (device, queue) = pollster::block_on(adapter.request_device(&DeviceDescriptor {
        label: Some("scirs2-stats-gpu"),
        required_features: Features::empty(),
        required_limits: Limits::default(),
        ..Default::default()
    }))
    .map_err(|e| GpuStatsError::RuntimeError(e.to_string()))?;

    // A one-dimensional dispatch is bounded by the device limit; report an
    // error (so callers can fall back to the CPU) rather than submitting a
    // dispatch that fails wgpu validation.
    let max_workgroups = device.limits().max_compute_workgroups_per_dimension;
    if workgroups > max_workgroups {
        return Err(GpuStatsError::RuntimeError(format!(
            "{} elements need {workgroups} workgroups, exceeding the device limit of \
             {max_workgroups} per dimension",
            xs.len()
        )));
    }

    // ── Shader / pipeline ─────────────────────────────────────────────────────
    let shader_module = device.create_shader_module(ShaderModuleDescriptor {
        label: Some("scirs2-stats-shader"),
        source: ShaderSource::Wgsl(wgsl.into()),
    });

    // Three-binding layout: storage-readonly, storage-readwrite, uniform
    let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("scirs2-stats-bgl"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 1,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 2,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Uniform,
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
        ],
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("scirs2-stats-layout"),
        bind_group_layouts: &[Some(&bgl)],
        ..Default::default()
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("scirs2-stats-pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader_module,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });

    // ── Buffers ───────────────────────────────────────────────────────────────
    let input_bytes: Vec<u8> = xs.iter().flat_map(|v| v.to_le_bytes()).collect();
    // Four bytes per f32 element; computed in u64 so it cannot overflow.
    let byte_len = u64::from(element_count) * 4;

    let buf_input = device.create_buffer_init(&BufferInitDescriptor {
        label: Some("scirs2-stats-input"),
        contents: &input_bytes,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
    });

    let buf_output = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-stats-output"),
        size: byte_len,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let buf_params = device.create_buffer_init(&BufferInitDescriptor {
        label: Some("scirs2-stats-params"),
        contents: params_bytes,
        usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
    });

    // The output storage buffer is not mappable; results are copied into
    // this staging buffer, which is then mapped for reading.
    let buf_staging = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-stats-staging"),
        size: byte_len,
        usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    // ── Bind group ────────────────────────────────────────────────────────────
    let bind_group = device.create_bind_group(&BindGroupDescriptor {
        label: Some("scirs2-stats-bg"),
        layout: &bgl,
        entries: &[
            BindGroupEntry {
                binding: 0,
                resource: buf_input.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 1,
                resource: buf_output.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 2,
                resource: buf_params.as_entire_binding(),
            },
        ],
    });

    // ── Encode / dispatch ─────────────────────────────────────────────────────
    let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor {
        label: Some("scirs2-stats-encoder"),
    });
    {
        // Scoped so the compute pass is ended before the copy is recorded.
        let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
            label: Some("scirs2-stats-pass"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        cpass.dispatch_workgroups(workgroups, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&buf_output, 0, &buf_staging, 0, byte_len);
    queue.submit(Some(encoder.finish()));

    // ── Readback ──────────────────────────────────────────────────────────────
    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| GpuStatsError::RuntimeError(format!("GPU poll error: {e:?}")))?;

    let slice = buf_staging.slice(0..byte_len);
    // The map callback reports its result through this channel.
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(MapMode::Read, move |r| {
        let _ = tx.send(r);
    });

    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| GpuStatsError::RuntimeError(format!("GPU poll during map: {e:?}")))?;

    rx.recv()
        .map_err(|_| GpuStatsError::RuntimeError("channel closed in map_async".into()))?
        .map_err(|e| GpuStatsError::RuntimeError(format!("map_async failed: {e:?}")))?;

    let mapped = slice.get_mapped_range();
    let result: Vec<f32> = mapped
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    // The mapped view must be dropped before the buffer is unmapped.
    drop(mapped);
    buf_staging.unmap();

    Ok(result)
}
400
401// ---------------------------------------------------------------------------
402// Encode NormalParams as little-endian bytes for the uniform buffer.
403// Layout: mu (f32), sigma (f32), n (u32), _pad (u32) → 16 bytes.
404// ---------------------------------------------------------------------------
405
406#[cfg(feature = "gpu_wgpu")]
/// Packs `NormalParams { mu, sigma, n, _pad }` into 16 little-endian bytes.
///
/// Each field occupies one 4-byte word in declaration order; the trailing
/// `_pad` word is left as zero.
fn encode_normal_params(mu: f32, sigma: f32, n: u32) -> [u8; 16] {
    let words = [
        mu.to_le_bytes(),
        sigma.to_le_bytes(),
        n.to_le_bytes(),
        [0u8; 4], // _pad
    ];
    let mut packed = [0u8; 16];
    for (slot, word) in packed.chunks_exact_mut(4).zip(words.iter()) {
        slot.copy_from_slice(word);
    }
    packed
}
415
416// Encode ExponParams: lambda (f32), n (u32), _pad0, _pad1 → 16 bytes.
417#[cfg(feature = "gpu_wgpu")]
/// Packs `ExponParams { lambda, n, _pad0, _pad1 }` into 16 little-endian bytes.
///
/// `lambda` fills the first 4-byte word and `n` the second; the two padding
/// words stay zero.
fn encode_expon_params(lambda: f32, n: u32) -> [u8; 16] {
    let mut packed = [0u8; 16];
    let (lambda_word, rest) = packed.split_at_mut(4);
    lambda_word.copy_from_slice(&lambda.to_le_bytes());
    rest[..4].copy_from_slice(&n.to_le_bytes());
    // rest[4..] holds _pad0 and _pad1, already zero.
    packed
}
425
426// ---------------------------------------------------------------------------
427// WGPU dispatch wrappers (gpu_wgpu feature)
428// ---------------------------------------------------------------------------
429
/// Attempt batch Normal log-PDF evaluation via WebGPU.
///
/// Inputs and parameters are narrowed to `f32` for the shader and the
/// results widened back to `f64`, so the output carries `f32` precision.
///
/// # Errors
///
/// Returns `GpuStatsError::RuntimeError` when `xs.len()` does not fit the
/// shader's `u32` element count, and otherwise propagates whatever
/// `dispatch_with_params_f32` reports (e.g. no adapter available).
#[cfg(feature = "gpu_wgpu")]
fn normal_log_pdf_wgpu(xs: &[f64], mu: f64, sigma: f64) -> Result<Vec<f64>, GpuStatsError> {
    // Checked conversion: a plain `as u32` would silently wrap for huge
    // inputs and the shader would then process only part of the batch.
    let count = u32::try_from(xs.len()).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input length {} does not fit the shader's u32 element count",
            xs.len()
        ))
    })?;
    // Intentional precision loss: the shader operates on `array<f32>`.
    let xs_f32: Vec<f32> = xs.iter().map(|&v| v as f32).collect();
    let params = encode_normal_params(mu as f32, sigma as f32, count);
    let out_f32 = dispatch_with_params_f32(NORMAL_LOG_PDF_WGSL, &xs_f32, &params)?;
    Ok(out_f32.into_iter().map(f64::from).collect())
}
440
/// Attempt batch Normal CDF evaluation via WebGPU.
///
/// Inputs and parameters are narrowed to `f32` for the shader and the
/// results widened back to `f64`, so the output carries `f32` precision.
///
/// # Errors
///
/// Returns `GpuStatsError::RuntimeError` when `xs.len()` does not fit the
/// shader's `u32` element count, and otherwise propagates whatever
/// `dispatch_with_params_f32` reports (e.g. no adapter available).
#[cfg(feature = "gpu_wgpu")]
fn normal_cdf_wgpu(xs: &[f64], mu: f64, sigma: f64) -> Result<Vec<f64>, GpuStatsError> {
    // Checked conversion: a plain `as u32` would silently wrap for huge
    // inputs and the shader would then process only part of the batch.
    let count = u32::try_from(xs.len()).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input length {} does not fit the shader's u32 element count",
            xs.len()
        ))
    })?;
    // Intentional precision loss: the shader operates on `array<f32>`.
    let xs_f32: Vec<f32> = xs.iter().map(|&v| v as f32).collect();
    let params = encode_normal_params(mu as f32, sigma as f32, count);
    let out_f32 = dispatch_with_params_f32(NORMAL_CDF_WGSL, &xs_f32, &params)?;
    Ok(out_f32.into_iter().map(f64::from).collect())
}
449
/// Attempt batch Exponential log-PDF evaluation via WebGPU.
///
/// Inputs and `lambda` are narrowed to `f32` for the shader and the results
/// widened back to `f64`, so in-support outputs carry `f32` precision.
///
/// The shader emits the finite sentinel `-1e30` for inputs outside the
/// support.  To keep the GPU path consistent with the CPU path (which
/// returns `f64::NEG_INFINITY` for `x < 0`), every input that is not
/// `>= 0.0` in `f64` — negative values, including ones that round to
/// `-0.0` in `f32`, and NaN — is evaluated with the CPU scalar instead of
/// taking the shader's value.
///
/// # Errors
///
/// Returns `GpuStatsError::RuntimeError` when `xs.len()` does not fit the
/// shader's `u32` element count, and otherwise propagates whatever
/// `dispatch_with_params_f32` reports (e.g. no adapter available).
#[cfg(feature = "gpu_wgpu")]
fn exponential_log_pdf_wgpu(xs: &[f64], lambda: f64) -> Result<Vec<f64>, GpuStatsError> {
    // Checked conversion: a plain `as u32` would silently wrap for huge
    // inputs and the shader would then process only part of the batch.
    let count = u32::try_from(xs.len()).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input length {} does not fit the shader's u32 element count",
            xs.len()
        ))
    })?;
    // Intentional precision loss: the shader operates on `array<f32>`.
    let xs_f32: Vec<f32> = xs.iter().map(|&v| v as f32).collect();
    let params = encode_expon_params(lambda as f32, count);
    let out_f32 = dispatch_with_params_f32(EXPONENTIAL_LOG_PDF_WGSL, &xs_f32, &params)?;
    Ok(xs
        .iter()
        .zip(out_f32)
        .map(|(&x, gpu_value)| {
            if x >= 0.0 {
                f64::from(gpu_value)
            } else {
                // Out-of-support or NaN input: defer to the CPU definition
                // so the sentinel never reaches the caller.
                exponential_log_pdf_scalar(x, lambda)
            }
        })
        .collect())
}
458
/// Attempt batch Exponential CDF evaluation via WebGPU.
///
/// Inputs and `lambda` are narrowed to `f32` for the shader and the results
/// widened back to `f64`, so the output carries `f32` precision.
///
/// # Errors
///
/// Returns `GpuStatsError::RuntimeError` when `xs.len()` does not fit the
/// shader's `u32` element count, and otherwise propagates whatever
/// `dispatch_with_params_f32` reports (e.g. no adapter available).
#[cfg(feature = "gpu_wgpu")]
fn exponential_cdf_wgpu(xs: &[f64], lambda: f64) -> Result<Vec<f64>, GpuStatsError> {
    // Checked conversion: a plain `as u32` would silently wrap for huge
    // inputs and the shader would then process only part of the batch.
    let count = u32::try_from(xs.len()).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input length {} does not fit the shader's u32 element count",
            xs.len()
        ))
    })?;
    // Intentional precision loss: the shader operates on `array<f32>`.
    let xs_f32: Vec<f32> = xs.iter().map(|&v| v as f32).collect();
    let params = encode_expon_params(lambda as f32, count);
    let out_f32 = dispatch_with_params_f32(EXPONENTIAL_CDF_WGSL, &xs_f32, &params)?;
    Ok(out_f32.into_iter().map(f64::from).collect())
}
467
468// ---------------------------------------------------------------------------
469// CPU implementations
470// ---------------------------------------------------------------------------
471
/// Approximates `erf(x)` with the Abramowitz & Stegun 7.1.26 polynomial.
///
/// The formula is evaluated on `|x|` and the sign is restored afterwards,
/// so the result is exactly odd in `x`.
///
/// Maximum absolute error of the approximation: ≈ 1.5 × 10⁻⁷.
#[inline]
fn erf_cpu(x: f64) -> f64 {
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * magnitude);

    // Horner scheme, starting from the highest-order coefficient.
    let mut acc = 1.061_405_429;
    for coeff in [-1.453_152_027, 1.421_413_741, -0.284_496_736, 0.254_829_592] {
        acc = coeff + t * acc;
    }
    let poly = t * acc;

    let erf_of_magnitude = 1.0 - poly * (-magnitude * magnitude).exp();
    if x < 0.0 {
        -erf_of_magnitude
    } else {
        erf_of_magnitude
    }
}
489
490/// Compute the standard-normal CDF Φ(z) = 0.5·(1 + erf(z/√2)).
491#[inline]
492fn phi_cpu(z: f64) -> f64 {
493    0.5 * (1.0 + erf_cpu(z / std::f64::consts::SQRT_2))
494}
495
/// Log-density of Normal(µ, σ) at a single point, in full `f64` precision.
///
/// Evaluates `-½·z² − ln(√(2π)) − ln(σ)` with `z = (x − µ)/σ`.  `sigma` is
/// not validated here.
#[inline]
fn normal_log_pdf_scalar(x: f64, mu: f64, sigma: f64) -> f64 {
    let ln_sqrt_two_pi = (2.0 * std::f64::consts::PI).sqrt().ln();
    let standardized = (x - mu) / sigma;
    let quadratic_term = -0.5 * standardized * standardized;
    quadratic_term - ln_sqrt_two_pi - sigma.ln()
}
502
503/// CPU scalar Normal CDF.
504#[inline]
505fn normal_cdf_scalar(x: f64, mu: f64, sigma: f64) -> f64 {
506    phi_cpu((x - mu) / sigma)
507}
508
/// Log-density of Exponential(λ) at a single point.
///
/// Returns `f64::NEG_INFINITY` for `x < 0` (outside the support) and
/// `ln(λ) − λ·x` otherwise.  `lambda` is not validated here.
#[inline]
fn exponential_log_pdf_scalar(x: f64, lambda: f64) -> f64 {
    // Guard clause: negative inputs have zero density.
    if x < 0.0 {
        return f64::NEG_INFINITY;
    }
    lambda.ln() - lambda * x
}
521
/// CDF of Exponential(λ) at a single point.
///
/// Returns `0.0` for `x < 0` and `1 − exp(−λ·x)` otherwise.  `lambda` is
/// not validated here.
///
/// The in-support branch is computed as `-expm1(−λ·x)`, which is
/// mathematically identical to `1 − exp(−λ·x)` but avoids the catastrophic
/// cancellation that makes the naive form collapse to `0.0` when `λ·x` is
/// tiny.
#[inline]
fn exponential_cdf_scalar(x: f64, lambda: f64) -> f64 {
    if x < 0.0 {
        0.0
    } else {
        -(-lambda * x).exp_m1()
    }
}
531
/// Minimum array length to trigger GPU dispatch.
///
/// The public batch functions compare `xs.len()` against this threshold
/// (`>=`) before attempting the GPU path.  Arrays smaller than this always
/// use the CPU path: GPU launch overhead is not worth it for short inputs,
/// and the CPU path computes in `f64` whereas the shaders work in `f32`.
const MIN_GPU_SIZE: usize = 1024;
536
537// ---------------------------------------------------------------------------
538// Public batch API
539// ---------------------------------------------------------------------------
540
541/// Batch compute Normal log-PDF for an array of `x` values.
542///
543/// Computes `log_pdf(xᵢ; µ, σ) = -½·((xᵢ−µ)/σ)² − ln(σ) − ln(√(2π))`
544/// for each element.
545///
546/// When the `gpu_wgpu` feature is enabled and a compatible GPU is available,
547/// the computation is dispatched to a WGSL compute shader; otherwise the
548/// function silently falls back to a vectorised CPU loop.
549///
550/// # Arguments
551///
552/// * `xs`    — Input points at which to evaluate the log-PDF.
553/// * `mu`    — Distribution mean µ.
554/// * `sigma` — Distribution standard deviation σ (must be positive).
555///
556/// # Returns
557///
558/// A `Vec<f64>` of the same length as `xs`.
559pub fn normal_log_pdf_batch(xs: &[f64], mu: f64, sigma: f64) -> Vec<f64> {
560    #[cfg(feature = "gpu_wgpu")]
561    {
562        if xs.len() >= MIN_GPU_SIZE {
563            if let Ok(result) = normal_log_pdf_wgpu(xs, mu, sigma) {
564                return result;
565            }
566        }
567    }
568    xs.iter()
569        .map(|&x| normal_log_pdf_scalar(x, mu, sigma))
570        .collect()
571}
572
573/// Batch compute Normal CDF for an array of `x` values.
574///
575/// Computes `Φ((xᵢ−µ)/σ)` where Φ is the standard-normal CDF.
576///
577/// When the `gpu_wgpu` feature is enabled and a compatible GPU is available,
578/// computation is dispatched to a WGSL compute shader; otherwise the
579/// function falls back to a CPU loop using the A&S erf approximation.
580///
581/// # Arguments
582///
583/// * `xs`    — Input points at which to evaluate the CDF.
584/// * `mu`    — Distribution mean µ.
585/// * `sigma` — Distribution standard deviation σ (must be positive).
586///
587/// # Returns
588///
589/// A `Vec<f64>` of the same length as `xs`, with values in `[0, 1]`.
590pub fn normal_cdf_batch(xs: &[f64], mu: f64, sigma: f64) -> Vec<f64> {
591    #[cfg(feature = "gpu_wgpu")]
592    {
593        if xs.len() >= MIN_GPU_SIZE {
594            if let Ok(result) = normal_cdf_wgpu(xs, mu, sigma) {
595                return result;
596            }
597        }
598    }
599    xs.iter()
600        .map(|&x| normal_cdf_scalar(x, mu, sigma))
601        .collect()
602}
603
604/// Batch compute Exponential log-PDF for an array of `x` values.
605///
606/// For `xᵢ ≥ 0`: `log_pdf(xᵢ; λ) = ln(λ) − λ·xᵢ`.
607/// For `xᵢ < 0`: returns `f64::NEG_INFINITY`.
608///
609/// When the `gpu_wgpu` feature is enabled and a compatible GPU is available,
610/// computation is dispatched to a WGSL compute shader; otherwise the
611/// function falls back to a CPU loop.
612///
613/// # Arguments
614///
615/// * `xs`     — Input points at which to evaluate the log-PDF.
616/// * `lambda` — Rate parameter λ (must be positive).
617///
618/// # Returns
619///
620/// A `Vec<f64>` of the same length as `xs`.
621pub fn exponential_log_pdf_batch(xs: &[f64], lambda: f64) -> Vec<f64> {
622    #[cfg(feature = "gpu_wgpu")]
623    {
624        if xs.len() >= MIN_GPU_SIZE {
625            if let Ok(result) = exponential_log_pdf_wgpu(xs, lambda) {
626                return result;
627            }
628        }
629    }
630    xs.iter()
631        .map(|&x| exponential_log_pdf_scalar(x, lambda))
632        .collect()
633}
634
635/// Batch compute Exponential CDF for an array of `x` values.
636///
637/// For `xᵢ ≥ 0`: `CDF(xᵢ; λ) = 1 − exp(−λ·xᵢ)`.
638/// For `xᵢ < 0`: returns `0.0`.
639///
640/// When the `gpu_wgpu` feature is enabled and a compatible GPU is available,
641/// computation is dispatched to a WGSL compute shader; otherwise the
642/// function falls back to a CPU loop.
643///
644/// # Arguments
645///
646/// * `xs`     — Input points at which to evaluate the CDF.
647/// * `lambda` — Rate parameter λ (must be positive).
648///
649/// # Returns
650///
651/// A `Vec<f64>` of the same length as `xs`, with values in `[0, 1]`.
652pub fn exponential_cdf_batch(xs: &[f64], lambda: f64) -> Vec<f64> {
653    #[cfg(feature = "gpu_wgpu")]
654    {
655        if xs.len() >= MIN_GPU_SIZE {
656            if let Ok(result) = exponential_cdf_wgpu(xs, lambda) {
657                return result;
658            }
659        }
660    }
661    xs.iter()
662        .map(|&x| exponential_cdf_scalar(x, lambda))
663        .collect()
664}
665
666// ---------------------------------------------------------------------------
667// Tests
668// ---------------------------------------------------------------------------
669
670#[cfg(test)]
671mod tests {
672    use super::*;
673
674    /// ln(√(2π))
675    fn log_sqrt_2pi() -> f64 {
676        (2.0 * std::f64::consts::PI).sqrt().ln()
677    }
678
679    #[test]
680    fn test_normal_log_pdf_batch_cpu() {
681        let xs = vec![0.0_f64, 1.0, -1.0, 2.0];
682        let result = normal_log_pdf_batch(&xs, 0.0, 1.0);
683        let lsp = log_sqrt_2pi();
684        assert_eq!(result.len(), xs.len());
685        for (r, &x) in result.iter().zip(xs.iter()) {
686            let expected = -0.5 * x * x - lsp;
687            assert!(
688                (r - expected).abs() < 1e-10,
689                "normal_log_pdf mismatch at x={x}: got {r}, expected {expected}"
690            );
691        }
692    }
693
694    #[test]
695    fn test_normal_log_pdf_batch_nonstandard() {
696        // Verify that mu/sigma parameters are applied correctly.
697        let xs = vec![2.0_f64, 3.0, 4.0];
698        let mu = 2.0;
699        let sigma = 2.0;
700        let result = normal_log_pdf_batch(&xs, mu, sigma);
701        let lsp = log_sqrt_2pi();
702        for (r, &x) in result.iter().zip(xs.iter()) {
703            let z = (x - mu) / sigma;
704            let expected = -0.5 * z * z - lsp - sigma.ln();
705            assert!(
706                (r - expected).abs() < 1e-10,
707                "nonstandard normal_log_pdf mismatch at x={x}"
708            );
709        }
710    }
711
712    #[test]
713    fn test_normal_log_pdf_batch_empty() {
714        let result = normal_log_pdf_batch(&[], 0.0, 1.0);
715        assert!(result.is_empty());
716    }
717
718    #[test]
719    fn test_normal_cdf_batch_cpu() {
720        let xs = vec![-1e6_f64, -1.0, 0.0, 1.0, 1e6_f64];
721        let result = normal_cdf_batch(&xs, 0.0, 1.0);
722        assert_eq!(result.len(), xs.len());
723        // Φ(−∞) ≈ 0
724        assert!(result[0] < 1e-6, "Φ(-1e6) should be ~0, got {}", result[0]);
725        // Φ(+∞) ≈ 1
726        assert!(
727            result[4] > 1.0 - 1e-6,
728            "Φ(+1e6) should be ~1, got {}",
729            result[4]
730        );
731        // Φ(0) ≈ 0.5; the A&S polynomial residual at x=0 gives ~5e-10 error
732        assert!(
733            (result[2] - 0.5).abs() < 1e-8,
734            "Φ(0) should be 0.5, got {}",
735            result[2]
736        );
737        // Φ(-1) ≈ 0.1587 (to within 0.001)
738        assert!(
739            (result[1] - 0.158_655_253_931_457_05).abs() < 1e-3,
740            "Φ(-1) should be ≈0.1587, got {}",
741            result[1]
742        );
743        // Φ(1) ≈ 0.8413
744        assert!(
745            (result[3] - 0.841_344_746_068_543).abs() < 1e-3,
746            "Φ(1) should be ≈0.8413, got {}",
747            result[3]
748        );
749    }
750
751    #[test]
752    fn test_normal_cdf_batch_symmetry() {
753        // Φ(-z) + Φ(z) = 1 for any z ≠ 0
754        // At z=0 the polynomial residual prevents exact cancellation; use 1e-8.
755        let xs = vec![-2.0_f64, -1.0, 0.0, 1.0, 2.0];
756        let result = normal_cdf_batch(&xs, 0.0, 1.0);
757        // Φ(-2) + Φ(2) = 1
758        assert!(
759            (result[0] + result[4] - 1.0).abs() < 1e-7,
760            "Φ(-2)+Φ(2) should be 1, got {}",
761            result[0] + result[4]
762        );
763        // Φ(-1) + Φ(1) = 1
764        assert!(
765            (result[1] + result[3] - 1.0).abs() < 1e-7,
766            "Φ(-1)+Φ(1) should be 1, got {}",
767            result[1] + result[3]
768        );
769        // Φ(0) ≈ 0.5 (A&S polynomial residual gives ~5e-10 at x=0)
770        assert!(
771            (result[2] - 0.5).abs() < 1e-8,
772            "Φ(0) should be ~0.5, got {}",
773            result[2]
774        );
775    }
776
777    #[test]
778    fn test_normal_cdf_batch_empty() {
779        let result = normal_cdf_batch(&[], 0.0, 1.0);
780        assert!(result.is_empty());
781    }
782
783    #[test]
784    fn test_exponential_log_pdf_batch_cpu() {
785        let xs = vec![0.0_f64, 1.0, 2.0, -1.0];
786        let lambda = 2.0_f64;
787        let result = exponential_log_pdf_batch(&xs, lambda);
788        assert_eq!(result.len(), xs.len());
789
790        // log_pdf(0; λ=2) = ln(2)
791        assert!(
792            (result[0] - lambda.ln()).abs() < 1e-10,
793            "log_pdf(0) should be ln(2), got {}",
794            result[0]
795        );
796        // log_pdf(1; λ=2) = ln(2) - 2
797        let expected_1 = lambda.ln() - lambda * 1.0;
798        assert!(
799            (result[1] - expected_1).abs() < 1e-10,
800            "log_pdf(1) should be {expected_1}, got {}",
801            result[1]
802        );
803        // log_pdf(2; λ=2) = ln(2) - 4
804        let expected_2 = lambda.ln() - lambda * 2.0;
805        assert!(
806            (result[2] - expected_2).abs() < 1e-10,
807            "log_pdf(2) should be {expected_2}, got {}",
808            result[2]
809        );
810        // x < 0 → -inf
811        assert!(
812            result[3] < -1e20,
813            "log_pdf(-1) should be -inf, got {}",
814            result[3]
815        );
816    }
817
818    #[test]
819    fn test_exponential_log_pdf_batch_unit_rate() {
820        // λ = 1: log_pdf(x) = -x for x >= 0
821        let xs: Vec<f64> = (0..=5).map(|i| i as f64).collect();
822        let result = exponential_log_pdf_batch(&xs, 1.0);
823        for (i, (&x, &r)) in xs.iter().zip(result.iter()).enumerate() {
824            let expected = -x; // ln(1) - 1*x = -x
825            assert!(
826                (r - expected).abs() < 1e-10,
827                "unit-rate log_pdf mismatch at index {i}"
828            );
829        }
830    }
831
832    #[test]
833    fn test_exponential_log_pdf_batch_empty() {
834        let result = exponential_log_pdf_batch(&[], 1.0);
835        assert!(result.is_empty());
836    }
837
/// Spot-checks the exponential CDF with λ = 1 at the origin, at x = 1, and at
/// a negative input.
#[test]
fn test_exponential_cdf_batch_cpu() {
    let inputs = [0.0_f64, 1.0, -1.0];
    let cdfs = exponential_cdf_batch(&inputs, 1.0);
    assert_eq!(cdfs.len(), inputs.len());

    // The length check above guarantees these three indices exist.
    let (at_zero, at_one, at_minus_one) = (cdfs[0], cdfs[1], cdfs[2]);

    // At the origin the CDF is exactly 0.
    assert!(at_zero.abs() < 1e-10, "CDF(0) should be 0, got {}", at_zero);

    // At x = 1 the closed form is 1 - e^{-1}.
    let expected_1 = 1.0 - (-1.0_f64).exp();
    assert!(
        (at_one - expected_1).abs() < 1e-10,
        "CDF(1) should be {expected_1}, got {}",
        at_one
    );

    // Negative inputs are expected to map to 0.
    assert!(
        at_minus_one.abs() < 1e-10,
        "CDF(-1) should be 0, got {}",
        at_minus_one
    );
}
864
/// For large positive inputs the exponential CDF must be within 1e-10 of 1.
#[test]
fn test_exponential_cdf_batch_large_x() {
    let cdfs = exponential_cdf_batch(&[100.0_f64, 1000.0], 1.0);
    // Both evaluation points are far enough out that 1 - CDF is below 1e-10.
    let lower_bound = 1.0 - 1e-10;
    assert!(cdfs[0] > lower_bound);
    assert!(cdfs[1] > lower_bound);
}
873
/// An empty input slice must yield an empty batch of CDF values.
#[test]
fn test_exponential_cdf_batch_empty() {
    assert!(exponential_cdf_batch(&[], 1.0).is_empty());
}
879
/// `erf` is an odd function, so `erf(x) + erf(-x)` must vanish (to 1e-12) at
/// every sample point.
#[test]
fn test_erf_cpu_symmetry() {
    const SAMPLE_POINTS: [f64; 5] = [0.5, 1.0, 1.5, 2.0, 3.0];
    for x in SAMPLE_POINTS {
        let (pos, neg) = (erf_cpu(x), erf_cpu(-x));
        let asymmetry = (pos + neg).abs();
        assert!(
            asymmetry < 1e-12,
            "erf symmetry failed at x={x}: erf(x)={pos}, erf(-x)={neg}"
        );
    }
}
892
/// Compares `erf_cpu` with reference values at 0, 1 and 2.
///
/// The tolerances are deliberately loose: the original test notes that the
/// A&S 7.1.26 polynomial leaves a residual of about 1e-9 at the origin and has
/// a maximum absolute error of about 1.5e-7 elsewhere.
#[test]
fn test_erf_cpu_known_values() {
    // erf(0) = 0, up to the polynomial's residual.
    let at_zero = erf_cpu(0.0);
    assert!(at_zero.abs() < 1e-8, "erf(0) should be ~0, got {}", at_zero);

    // erf(1) ≈ 0.8427007929
    let at_one = erf_cpu(1.0);
    let err_one = (at_one - 0.842_700_792_949_715).abs();
    assert!(err_one < 2e-7, "erf(1) mismatch: {}", at_one);

    // erf(2) ≈ 0.9953222650
    let at_two = erf_cpu(2.0);
    let err_two = (at_two - 0.995_322_265_018_953).abs();
    assert!(err_two < 2e-7, "erf(2) mismatch: {}", at_two);
}
915
916    // ── GPU test (skipped when wgpu adapter unavailable) ─────────────────────
917
/// Cross-checks `normal_log_pdf_wgpu` against the CPU scalar reference
/// `normal_log_pdf_scalar` for µ = 0, σ = 1.
///
/// `GpuStatsError::GpuNotAvailable` is treated as a skip so the test stays
/// green on machines without an adapter; any other error fails the test.
#[cfg(feature = "gpu_wgpu")]
#[test]
fn test_normal_log_pdf_wgpu_or_skip() {
    let xs = vec![0.0_f64, 1.0, -1.0];
    match normal_log_pdf_wgpu(&xs, 0.0, 1.0) {
        Err(GpuStatsError::GpuNotAvailable) => {
            // No GPU adapter available — acceptable in CI
            eprintln!("test_normal_log_pdf_wgpu_or_skip: GPU not available, skipping");
        }
        Err(e) => panic!("GPU error: {e}"),
        Ok(gpu) => {
            // `zip` stops at the shorter side, so without this check a short
            // or empty GPU result would make the comparison pass vacuously.
            assert_eq!(
                gpu.len(),
                xs.len(),
                "GPU output length does not match input length"
            );
            for (g, &x) in gpu.iter().zip(xs.iter()) {
                let c = normal_log_pdf_scalar(x, 0.0, 1.0);
                // f32 GPU vs f64 CPU: allow 1e-4 absolute tolerance
                assert!((g - c).abs() < 1e-4, "GPU/CPU mismatch: gpu={g}, cpu={c}");
            }
        }
    }
}
941
/// Cross-checks `normal_cdf_wgpu` against the CPU scalar reference
/// `normal_cdf_scalar` for µ = 0, σ = 1.
///
/// `GpuStatsError::GpuNotAvailable` is treated as a skip so the test stays
/// green on machines without an adapter; any other error fails the test.
#[cfg(feature = "gpu_wgpu")]
#[test]
fn test_normal_cdf_wgpu_or_skip() {
    let xs = vec![-1.0_f64, 0.0, 1.0];
    match normal_cdf_wgpu(&xs, 0.0, 1.0) {
        Err(GpuStatsError::GpuNotAvailable) => {
            eprintln!("test_normal_cdf_wgpu_or_skip: GPU not available, skipping");
        }
        Err(e) => panic!("GPU error: {e}"),
        Ok(gpu) => {
            // `zip` stops at the shorter side, so without this check a short
            // or empty GPU result would make the comparison pass vacuously.
            assert_eq!(
                gpu.len(),
                xs.len(),
                "GPU output length does not match input length"
            );
            for (g, &x) in gpu.iter().zip(xs.iter()) {
                let c = normal_cdf_scalar(x, 0.0, 1.0);
                // Absolute tolerance of 1e-4 between GPU and CPU results.
                assert!((g - c).abs() < 1e-4, "GPU/CPU mismatch: gpu={g}, cpu={c}");
            }
        }
    }
}
960
/// Cross-checks `exponential_log_pdf_wgpu` against the CPU scalar reference
/// `exponential_log_pdf_scalar` for λ = 2 on non-negative inputs.
///
/// `GpuStatsError::GpuNotAvailable` is treated as a skip so the test stays
/// green on machines without an adapter; any other error fails the test.
#[cfg(feature = "gpu_wgpu")]
#[test]
fn test_exponential_log_pdf_wgpu_or_skip() {
    let xs = vec![0.0_f64, 1.0, 2.0];
    let lambda = 2.0_f64;
    match exponential_log_pdf_wgpu(&xs, lambda) {
        Err(GpuStatsError::GpuNotAvailable) => {
            eprintln!("test_exponential_log_pdf_wgpu_or_skip: GPU not available, skipping");
        }
        Err(e) => panic!("GPU error: {e}"),
        Ok(gpu) => {
            // `zip` stops at the shorter side, so without this check a short
            // or empty GPU result would make the comparison pass vacuously.
            assert_eq!(
                gpu.len(),
                xs.len(),
                "GPU output length does not match input length"
            );
            for (g, &x) in gpu.iter().zip(xs.iter()) {
                let c = exponential_log_pdf_scalar(x, lambda);
                // Absolute tolerance of 1e-4 between GPU and CPU results.
                assert!((g - c).abs() < 1e-4, "GPU/CPU mismatch: gpu={g}, cpu={c}");
            }
        }
    }
}
983
/// Cross-checks `exponential_cdf_wgpu` against the CPU scalar reference
/// `exponential_cdf_scalar` for λ = 1 on non-negative inputs.
///
/// `GpuStatsError::GpuNotAvailable` is treated as a skip so the test stays
/// green on machines without an adapter; any other error fails the test.
#[cfg(feature = "gpu_wgpu")]
#[test]
fn test_exponential_cdf_wgpu_or_skip() {
    let xs = vec![0.0_f64, 1.0, 2.0];
    match exponential_cdf_wgpu(&xs, 1.0) {
        Err(GpuStatsError::GpuNotAvailable) => {
            eprintln!("test_exponential_cdf_wgpu_or_skip: GPU not available, skipping");
        }
        Err(e) => panic!("GPU error: {e}"),
        Ok(gpu) => {
            // `zip` stops at the shorter side, so without this check a short
            // or empty GPU result would make the comparison pass vacuously.
            assert_eq!(
                gpu.len(),
                xs.len(),
                "GPU output length does not match input length"
            );
            for (g, &x) in gpu.iter().zip(xs.iter()) {
                let c = exponential_cdf_scalar(x, 1.0);
                // Absolute tolerance of 1e-4 between GPU and CPU results.
                assert!((g - c).abs() < 1e-4, "GPU/CPU mismatch: gpu={g}, cpu={c}");
            }
        }
    }
}
1002}