/// WGSL compute shader that evaluates the normal log-density element-wise.
///
/// Bind group 0 layout: binding 0 is the read-only input array `x`, binding 1 is
/// the read-write output array `out`, and binding 2 is the uniform block
/// `NormalParams { mu, sigma, n, _pad }` (four 32-bit words).
///
/// Each invocation `i` with `i < params.n` writes
/// `-0.5 * z * z - LOG_SQRT_2PI - log(sigma)` where `z = (x[i] - mu) / sigma`;
/// invocations with `i >= params.n` return without writing. The entry point is
/// `main` with a workgroup size of 64.
const NORMAL_LOG_PDF_WGSL: &str = r#"
struct NormalParams {
    mu: f32,
    sigma: f32,
    n: u32,
    _pad: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;
@group(0) @binding(2) var<uniform> params: NormalParams;

const LOG_SQRT_2PI: f32 = 0.9189385332046727;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= params.n { return; }
    let z = (x[i] - params.mu) / params.sigma;
    out[i] = -0.5 * z * z - LOG_SQRT_2PI - log(params.sigma);
}
"#;
62
/// WGSL compute shader that evaluates the normal CDF element-wise.
///
/// Uses the same bind group layout and `NormalParams` uniform block as the
/// normal log-density shader (binding 0 = input `x`, binding 1 = output `out`,
/// binding 2 = parameters).
///
/// Each invocation `i < params.n` writes `0.5 * (1 + approx_erf(z))` with
/// `z = (x[i] - mu) / (sigma * sqrt(2))`. `approx_erf` is a five-coefficient
/// polynomial in `t = 1 / (1 + 0.3275911 * |v|)` scaled by `exp(-v * v)`; the
/// sign of the input is restored with `select`. The same coefficients are used
/// by the CPU-side `erf_cpu` in this file.
const NORMAL_CDF_WGSL: &str = r#"
struct NormalParams {
    mu: f32,
    sigma: f32,
    n: u32,
    _pad: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;
@group(0) @binding(2) var<uniform> params: NormalParams;

fn approx_erf(v: f32) -> f32 {
    let t = 1.0 / (1.0 + 0.3275911 * abs(v));
    let y = 1.0 - (((((1.061405429 * t - 1.453152027) * t
        + 1.421413741) * t - 0.284496736) * t + 0.254829592) * t * exp(-v * v));
    return select(-y, y, v >= 0.0);
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= params.n { return; }
    let z = (x[i] - params.mu) / (params.sigma * 1.41421356237f);
    out[i] = 0.5 * (1.0 + approx_erf(z));
}
"#;
96
/// WGSL compute shader that evaluates the exponential log-density element-wise.
///
/// Bind group 0 layout: binding 0 is the read-only input array `x`, binding 1 is
/// the read-write output array `out`, and binding 2 is the uniform block
/// `ExponParams { lambda, n, _pad0, _pad1 }` (four 32-bit words).
///
/// Each invocation `i < params.n` writes `log(lambda) - lambda * x[i]` when
/// `x[i] >= 0`, and the finite sentinel `-1e30` otherwise. Note that the CPU
/// scalar implementation in this file returns negative infinity for negative
/// inputs instead of this sentinel.
const EXPONENTIAL_LOG_PDF_WGSL: &str = r#"
struct ExponParams {
    lambda: f32,
    n: u32,
    _pad0: u32,
    _pad1: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;
@group(0) @binding(2) var<uniform> params: ExponParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= params.n { return; }
    let xi = x[i];
    out[i] = select(-1e30, log(params.lambda) - params.lambda * xi, xi >= 0.0);
}
"#;
126
/// WGSL compute shader that evaluates the exponential CDF element-wise.
///
/// Uses the same bind group layout and `ExponParams` uniform block as the
/// exponential log-density shader (binding 0 = input `x`, binding 1 = output
/// `out`, binding 2 = parameters).
///
/// Each invocation `i < params.n` writes `1 - exp(-lambda * x[i])` when
/// `x[i] >= 0`, and `0` otherwise.
const EXPONENTIAL_CDF_WGSL: &str = r#"
struct ExponParams {
    lambda: f32,
    n: u32,
    _pad0: u32,
    _pad1: u32,
}

@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> out: array<f32>;
@group(0) @binding(2) var<uniform> params: ExponParams;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if i >= params.n { return; }
    let xi = x[i];
    out[i] = select(0.0, 1.0 - exp(-params.lambda * xi), xi >= 0.0);
}
"#;
153
/// Errors reported by the GPU-accelerated statistics kernels.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuStatsError {
    /// No wgpu adapter could be obtained on this system.
    GpuNotAvailable,
    /// A GPU operation failed; the payload is a human-readable description.
    RuntimeError(String),
    /// The `gpu_wgpu` feature is not enabled in this build.
    FeatureNotEnabled,
}
168
169impl std::fmt::Display for GpuStatsError {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 match self {
172 GpuStatsError::GpuNotAvailable => {
173 write!(f, "wgpu GPU adapter not available on this system")
174 }
175 GpuStatsError::RuntimeError(msg) => {
176 write!(f, "GPU runtime error: {msg}")
177 }
178 GpuStatsError::FeatureNotEnabled => {
179 write!(f, "gpu_wgpu feature is not enabled in this build")
180 }
181 }
182 }
183}
184
// Marker impl: all `Error` methods keep their defaults, so no underlying source
// error is exposed; the message comes from the `Display` impl.
impl std::error::Error for GpuStatsError {}
186
/// Runs an element-wise WGSL compute shader over `xs` and reads the result back.
///
/// The shader is expected to expose an entry point named `main` and to use bind
/// group 0 as follows: binding 0 = read-only storage input, binding 1 =
/// read-write storage output (sized like the input), binding 2 = uniform
/// parameter block. One workgroup is dispatched per 64 input elements, which
/// must match the `@workgroup_size(64)` declared by the shaders in this file.
///
/// A fresh wgpu instance, adapter and device are created on every call.
///
/// # Arguments
/// * `wgsl` - WGSL source of the compute shader.
/// * `xs` - input values uploaded to binding 0.
/// * `params_bytes` - raw bytes uploaded as the uniform buffer at binding 2.
///
/// # Returns
/// One `f32` per input element, decoded from the output buffer. An empty input
/// returns an empty vector without touching the GPU.
///
/// # Errors
/// * [`GpuStatsError::GpuNotAvailable`] if no adapter can be obtained.
/// * [`GpuStatsError::RuntimeError`] if the device request, polling or buffer
///   mapping fails, or if the input needs more workgroups than the device
///   allows in a single dispatch dimension (callers can then fall back to CPU).
#[cfg(feature = "gpu_wgpu")]
fn dispatch_with_params_f32(
    wgsl: &str,
    xs: &[f32],
    params_bytes: &[u8],
) -> Result<Vec<f32>, GpuStatsError> {
    use wgpu::{
        util::{BufferInitDescriptor, DeviceExt as _},
        Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor,
        BindGroupLayoutEntry, BindingType, BufferBindingType, BufferDescriptor, BufferUsages,
        CommandEncoderDescriptor, ComputePassDescriptor, DeviceDescriptor, Features, Instance,
        InstanceDescriptor, Limits, MapMode, PowerPreference, RequestAdapterOptions,
        ShaderModuleDescriptor, ShaderSource, ShaderStages,
    };

    /// Elements handled per workgroup; must equal `@workgroup_size` in the shaders.
    const WORKGROUP_SIZE: usize = 64;

    let n = xs.len();
    if n == 0 {
        return Ok(Vec::new());
    }

    // Compute the workgroup count in `usize` and convert with a check, so a very
    // large input cannot wrap around in 32-bit arithmetic.
    let workgroups = u32::try_from(n.div_ceil(WORKGROUP_SIZE)).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input of {n} elements is too large for a single GPU dispatch"
        ))
    })?;

    let instance = Instance::new(InstanceDescriptor {
        backends: Backends::all(),
        flags: wgpu::InstanceFlags::default(),
        memory_budget_thresholds: Default::default(),
        backend_options: Default::default(),
        display: None,
    });

    let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
        power_preference: PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    }))
    .map_err(|_| GpuStatsError::GpuNotAvailable)?;

    let (device, queue) = pollster::block_on(adapter.request_device(&DeviceDescriptor {
        label: Some("scirs2-stats-gpu"),
        required_features: Features::empty(),
        required_limits: Limits::default(),
        ..Default::default()
    }))
    .map_err(|e| GpuStatsError::RuntimeError(e.to_string()))?;

    // Dispatching more workgroups than the device permits is a validation
    // failure; report it as an error instead so the caller can use the CPU path.
    let max_workgroups = device.limits().max_compute_workgroups_per_dimension;
    if workgroups > max_workgroups {
        return Err(GpuStatsError::RuntimeError(format!(
            "input of {n} elements needs {workgroups} workgroups, \
             but the device allows at most {max_workgroups} per dispatch dimension"
        )));
    }

    let shader_module = device.create_shader_module(ShaderModuleDescriptor {
        label: Some("scirs2-stats-shader"),
        source: ShaderSource::Wgsl(wgsl.into()),
    });

    // Layout mirrors the `@group(0)` declarations in the shaders:
    // 0 = read-only storage, 1 = read-write storage, 2 = uniform.
    let bgl = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("scirs2-stats-bgl"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 1,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 2,
                visibility: ShaderStages::COMPUTE,
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Uniform,
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
        ],
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("scirs2-stats-layout"),
        bind_group_layouts: &[Some(&bgl)],
        ..Default::default()
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("scirs2-stats-pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader_module,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });

    // Inputs are serialized as little-endian `f32`, four bytes per element.
    let input_bytes: Vec<u8> = xs.iter().flat_map(|v| v.to_le_bytes()).collect();
    let byte_len = (n as u64) * 4;

    let buf_input = device.create_buffer_init(&BufferInitDescriptor {
        label: Some("scirs2-stats-input"),
        contents: &input_bytes,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
    });

    let buf_output = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-stats-output"),
        size: byte_len,
        usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let buf_params = device.create_buffer_init(&BufferInitDescriptor {
        label: Some("scirs2-stats-params"),
        contents: params_bytes,
        usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
    });

    // Separate mappable buffer: the output is copied here so the host can read it.
    let buf_staging = device.create_buffer(&BufferDescriptor {
        label: Some("scirs2-stats-staging"),
        size: byte_len,
        usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let bind_group = device.create_bind_group(&BindGroupDescriptor {
        label: Some("scirs2-stats-bg"),
        layout: &bgl,
        entries: &[
            BindGroupEntry {
                binding: 0,
                resource: buf_input.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 1,
                resource: buf_output.as_entire_binding(),
            },
            BindGroupEntry {
                binding: 2,
                resource: buf_params.as_entire_binding(),
            },
        ],
    });

    let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor {
        label: Some("scirs2-stats-encoder"),
    });
    {
        // The pass borrows the encoder; the inner scope ends it before the copy.
        let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor {
            label: Some("scirs2-stats-pass"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        cpass.dispatch_workgroups(workgroups, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&buf_output, 0, &buf_staging, 0, byte_len);
    queue.submit(Some(encoder.finish()));

    // Wait for the submitted work before requesting the mapping.
    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| GpuStatsError::RuntimeError(format!("GPU poll error: {e:?}")))?;

    let slice = buf_staging.slice(0..byte_len);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(MapMode::Read, move |r| {
        // The receiver may already be gone on an early return; ignore send errors.
        let _ = tx.send(r);
    });

    // Poll again so the mapping callback above gets a chance to run.
    device
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| GpuStatsError::RuntimeError(format!("GPU poll during map: {e:?}")))?;

    rx.recv()
        .map_err(|_| GpuStatsError::RuntimeError("channel closed in map_async".into()))?
        .map_err(|e| GpuStatsError::RuntimeError(format!("map_async failed: {e:?}")))?;

    let mapped = slice.get_mapped_range();
    let result: Vec<f32> = mapped
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    // The mapped view must be released before the buffer can be unmapped.
    drop(mapped);
    buf_staging.unmap();

    Ok(result)
}
400
401#[cfg(feature = "gpu_wgpu")]
/// Packs the normal-distribution shader parameters into a 16-byte uniform block.
///
/// Layout (little-endian): bytes 0..4 = `mu`, 4..8 = `sigma`, 8..12 = `n`,
/// 12..16 = zero padding, matching the `NormalParams` struct in the shaders.
fn encode_normal_params(mu: f32, sigma: f32, n: u32) -> [u8; 16] {
    let words = [mu.to_le_bytes(), sigma.to_le_bytes(), n.to_le_bytes()];
    let mut block = [0u8; 16];
    // Fill consecutive 4-byte slots; the final slot stays zero as padding.
    for (slot, word) in block.chunks_exact_mut(4).zip(words) {
        slot.copy_from_slice(&word);
    }
    block
}
415
416#[cfg(feature = "gpu_wgpu")]
/// Packs the exponential-distribution shader parameters into a 16-byte uniform block.
///
/// Layout (little-endian): bytes 0..4 = `lambda`, 4..8 = `n`, 8..16 = zero
/// padding, matching the `ExponParams` struct in the shaders.
fn encode_expon_params(lambda: f32, n: u32) -> [u8; 16] {
    let words = [lambda.to_le_bytes(), n.to_le_bytes()];
    let mut block = [0u8; 16];
    // Fill the first two 4-byte slots; the remaining two stay zero as padding.
    for (slot, word) in block.chunks_exact_mut(4).zip(words) {
        slot.copy_from_slice(&word);
    }
    block
}
425
/// Evaluates the normal log-density of every element of `xs` on the GPU.
///
/// Inputs and parameters are narrowed to `f32` for the shader and the results
/// are widened back to `f64`, so the output has single precision.
///
/// # Errors
/// Returns [`GpuStatsError::RuntimeError`] if `xs` has more elements than fit
/// in the shader's `u32` element count, and propagates any error from the GPU
/// dispatch.
#[cfg(feature = "gpu_wgpu")]
fn normal_log_pdf_wgpu(xs: &[f64], mu: f64, sigma: f64) -> Result<Vec<f64>, GpuStatsError> {
    // A plain `as u32` cast would silently wrap for huge inputs and make the
    // shader process only a prefix of the data.
    let n = u32::try_from(xs.len()).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input length {} exceeds the u32 element count supported by the shader",
            xs.len()
        ))
    })?;
    let xs_f32: Vec<f32> = xs.iter().map(|&v| v as f32).collect();
    let params = encode_normal_params(mu as f32, sigma as f32, n);
    let out_f32 = dispatch_with_params_f32(NORMAL_LOG_PDF_WGSL, &xs_f32, &params)?;
    Ok(out_f32.into_iter().map(f64::from).collect())
}
440
/// Evaluates the normal CDF of every element of `xs` on the GPU.
///
/// Inputs and parameters are narrowed to `f32` for the shader and the results
/// are widened back to `f64`, so the output has single precision.
///
/// # Errors
/// Returns [`GpuStatsError::RuntimeError`] if `xs` has more elements than fit
/// in the shader's `u32` element count, and propagates any error from the GPU
/// dispatch.
#[cfg(feature = "gpu_wgpu")]
fn normal_cdf_wgpu(xs: &[f64], mu: f64, sigma: f64) -> Result<Vec<f64>, GpuStatsError> {
    // A plain `as u32` cast would silently wrap for huge inputs and make the
    // shader process only a prefix of the data.
    let n = u32::try_from(xs.len()).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input length {} exceeds the u32 element count supported by the shader",
            xs.len()
        ))
    })?;
    let xs_f32: Vec<f32> = xs.iter().map(|&v| v as f32).collect();
    let params = encode_normal_params(mu as f32, sigma as f32, n);
    let out_f32 = dispatch_with_params_f32(NORMAL_CDF_WGSL, &xs_f32, &params)?;
    Ok(out_f32.into_iter().map(f64::from).collect())
}
449
/// Evaluates the exponential log-density of every element of `xs` on the GPU.
///
/// Inputs and parameters are narrowed to `f32` for the shader and the results
/// are widened back to `f64`, so the output has single precision. The shader
/// marks negative inputs with a finite `-1e30` sentinel; this wrapper reports
/// negative infinity for them instead, matching the CPU scalar implementation.
///
/// # Errors
/// Returns [`GpuStatsError::RuntimeError`] if `xs` has more elements than fit
/// in the shader's `u32` element count, and propagates any error from the GPU
/// dispatch.
#[cfg(feature = "gpu_wgpu")]
fn exponential_log_pdf_wgpu(xs: &[f64], lambda: f64) -> Result<Vec<f64>, GpuStatsError> {
    // A plain `as u32` cast would silently wrap for huge inputs and make the
    // shader process only a prefix of the data.
    let n = u32::try_from(xs.len()).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input length {} exceeds the u32 element count supported by the shader",
            xs.len()
        ))
    })?;
    let xs_f32: Vec<f32> = xs.iter().map(|&v| v as f32).collect();
    let params = encode_expon_params(lambda as f32, n);
    let out_f32 = dispatch_with_params_f32(EXPONENTIAL_LOG_PDF_WGSL, &xs_f32, &params)?;
    Ok(xs
        .iter()
        .zip(out_f32)
        .map(|(&x, y)| {
            // Decide on the original f64 input so the result does not depend on
            // which backend handled the batch.
            if x < 0.0 {
                f64::NEG_INFINITY
            } else {
                f64::from(y)
            }
        })
        .collect())
}
458
/// Evaluates the exponential CDF of every element of `xs` on the GPU.
///
/// Inputs and parameters are narrowed to `f32` for the shader and the results
/// are widened back to `f64`, so the output has single precision.
///
/// # Errors
/// Returns [`GpuStatsError::RuntimeError`] if `xs` has more elements than fit
/// in the shader's `u32` element count, and propagates any error from the GPU
/// dispatch.
#[cfg(feature = "gpu_wgpu")]
fn exponential_cdf_wgpu(xs: &[f64], lambda: f64) -> Result<Vec<f64>, GpuStatsError> {
    // A plain `as u32` cast would silently wrap for huge inputs and make the
    // shader process only a prefix of the data.
    let n = u32::try_from(xs.len()).map_err(|_| {
        GpuStatsError::RuntimeError(format!(
            "input length {} exceeds the u32 element count supported by the shader",
            xs.len()
        ))
    })?;
    let xs_f32: Vec<f32> = xs.iter().map(|&v| v as f32).collect();
    let params = encode_expon_params(lambda as f32, n);
    let out_f32 = dispatch_with_params_f32(EXPONENTIAL_CDF_WGSL, &xs_f32, &params)?;
    Ok(out_f32.into_iter().map(f64::from).collect())
}
467
/// Polynomial approximation of the error function `erf(x)`.
///
/// Evaluates a five-coefficient polynomial in `t = 1 / (1 + 0.3275911 * |x|)`
/// scaled by `exp(-x²)` for the magnitude of `x`, then restores the sign so the
/// result is odd in `x`. The in-file tests check agreement with reference
/// values to within `2e-7`.
#[inline]
fn erf_cpu(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Polynomial coefficients, highest order first (Horner evaluation order).
    const COEFFS: [f64; 5] = [
        1.061_405_429,
        -1.453_152_027,
        1.421_413_741,
        -0.284_496_736,
        0.254_829_592,
    ];

    let negative = x < 0.0;
    let magnitude = if negative { -x } else { x };

    let t = 1.0 / (1.0 + P * magnitude);
    let horner = COEFFS.iter().fold(0.0_f64, |acc, &c| c + t * acc);
    let value = 1.0 - (t * horner) * (-magnitude * magnitude).exp();

    // erf is odd: mirror the result for negative inputs.
    if negative {
        -value
    } else {
        value
    }
}
489
490#[inline]
492fn phi_cpu(z: f64) -> f64 {
493 0.5 * (1.0 + erf_cpu(z / std::f64::consts::SQRT_2))
494}
495
/// Log-density of the normal distribution `N(mu, sigma²)` at `x`.
///
/// Computes `-0.5 * z² - ln(√(2π)) - ln(sigma)` with `z = (x - mu) / sigma`.
/// No validation is performed on `sigma`.
#[inline]
fn normal_log_pdf_scalar(x: f64, mu: f64, sigma: f64) -> f64 {
    let standardized = (x - mu) / sigma;
    let quadratic_term = -0.5 * standardized * standardized;
    let log_norm_const = (2.0 * std::f64::consts::PI).sqrt().ln();
    let log_scale = sigma.ln();
    quadratic_term - log_norm_const - log_scale
}
502
503#[inline]
505fn normal_cdf_scalar(x: f64, mu: f64, sigma: f64) -> f64 {
506 phi_cpu((x - mu) / sigma)
507}
508
/// Log-density of the exponential distribution with rate `lambda` at `x`.
///
/// Returns negative infinity for `x < 0` and `ln(lambda) - lambda * x`
/// otherwise. No validation is performed on `lambda`.
#[inline]
fn exponential_log_pdf_scalar(x: f64, lambda: f64) -> f64 {
    // Outside the support the density is zero, so its log is -inf.
    if x < 0.0 {
        return f64::NEG_INFINITY;
    }
    let log_rate = lambda.ln();
    log_rate - lambda * x
}
521
/// CDF of the exponential distribution with rate `lambda` at `x`.
///
/// Returns `0` for `x < 0` and `1 - exp(-lambda * x)` otherwise. No validation
/// is performed on `lambda`.
#[inline]
fn exponential_cdf_scalar(x: f64, lambda: f64) -> f64 {
    if x < 0.0 {
        0.0
    } else {
        // `exp_m1` computes `exp(t) - 1` without the cancellation that makes
        // `1.0 - exp(t)` collapse to 0 when `lambda * x` is tiny.
        -(-lambda * x).exp_m1()
    }
}
531
/// Minimum batch length at which the `*_batch` functions attempt the GPU path
/// (only when the `gpu_wgpu` feature is enabled); shorter inputs always use the
/// CPU scalar loop.
const MIN_GPU_SIZE: usize = 1024;
536
537pub fn normal_log_pdf_batch(xs: &[f64], mu: f64, sigma: f64) -> Vec<f64> {
560 #[cfg(feature = "gpu_wgpu")]
561 {
562 if xs.len() >= MIN_GPU_SIZE {
563 if let Ok(result) = normal_log_pdf_wgpu(xs, mu, sigma) {
564 return result;
565 }
566 }
567 }
568 xs.iter()
569 .map(|&x| normal_log_pdf_scalar(x, mu, sigma))
570 .collect()
571}
572
573pub fn normal_cdf_batch(xs: &[f64], mu: f64, sigma: f64) -> Vec<f64> {
591 #[cfg(feature = "gpu_wgpu")]
592 {
593 if xs.len() >= MIN_GPU_SIZE {
594 if let Ok(result) = normal_cdf_wgpu(xs, mu, sigma) {
595 return result;
596 }
597 }
598 }
599 xs.iter()
600 .map(|&x| normal_cdf_scalar(x, mu, sigma))
601 .collect()
602}
603
604pub fn exponential_log_pdf_batch(xs: &[f64], lambda: f64) -> Vec<f64> {
622 #[cfg(feature = "gpu_wgpu")]
623 {
624 if xs.len() >= MIN_GPU_SIZE {
625 if let Ok(result) = exponential_log_pdf_wgpu(xs, lambda) {
626 return result;
627 }
628 }
629 }
630 xs.iter()
631 .map(|&x| exponential_log_pdf_scalar(x, lambda))
632 .collect()
633}
634
635pub fn exponential_cdf_batch(xs: &[f64], lambda: f64) -> Vec<f64> {
653 #[cfg(feature = "gpu_wgpu")]
654 {
655 if xs.len() >= MIN_GPU_SIZE {
656 if let Ok(result) = exponential_cdf_wgpu(xs, lambda) {
657 return result;
658 }
659 }
660 }
661 xs.iter()
662 .map(|&x| exponential_cdf_scalar(x, lambda))
663 .collect()
664}
665
// Unit tests. The `*_batch_*` tests use inputs far shorter than `MIN_GPU_SIZE`,
// so they exercise the CPU scalar path. The `*_wgpu_or_skip` tests call the GPU
// wrappers directly and are skipped (with a message) when no adapter exists.
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference value of `ln(sqrt(2π))`, computed the same way as the scalar code.
    fn log_sqrt_2pi() -> f64 {
        (2.0 * std::f64::consts::PI).sqrt().ln()
    }

    /// Standard normal log-density against the closed form `-x²/2 - ln(sqrt(2π))`.
    #[test]
    fn test_normal_log_pdf_batch_cpu() {
        let xs = vec![0.0_f64, 1.0, -1.0, 2.0];
        let result = normal_log_pdf_batch(&xs, 0.0, 1.0);
        let lsp = log_sqrt_2pi();
        assert_eq!(result.len(), xs.len());
        for (r, &x) in result.iter().zip(xs.iter()) {
            let expected = -0.5 * x * x - lsp;
            assert!(
                (r - expected).abs() < 1e-10,
                "normal_log_pdf mismatch at x={x}: got {r}, expected {expected}"
            );
        }
    }

    /// Non-standard parameters: checks the `mu` shift and the `-ln(sigma)` term.
    #[test]
    fn test_normal_log_pdf_batch_nonstandard() {
        let xs = vec![2.0_f64, 3.0, 4.0];
        let mu = 2.0;
        let sigma = 2.0;
        let result = normal_log_pdf_batch(&xs, mu, sigma);
        let lsp = log_sqrt_2pi();
        for (r, &x) in result.iter().zip(xs.iter()) {
            let z = (x - mu) / sigma;
            let expected = -0.5 * z * z - lsp - sigma.ln();
            assert!(
                (r - expected).abs() < 1e-10,
                "nonstandard normal_log_pdf mismatch at x={x}"
            );
        }
    }

    /// Empty input yields empty output.
    #[test]
    fn test_normal_log_pdf_batch_empty() {
        let result = normal_log_pdf_batch(&[], 0.0, 1.0);
        assert!(result.is_empty());
    }

    /// Standard normal CDF at the tails, the centre and at ±1.
    #[test]
    fn test_normal_cdf_batch_cpu() {
        let xs = vec![-1e6_f64, -1.0, 0.0, 1.0, 1e6_f64];
        let result = normal_cdf_batch(&xs, 0.0, 1.0);
        assert_eq!(result.len(), xs.len());
        // Far left tail.
        assert!(result[0] < 1e-6, "Φ(-1e6) should be ~0, got {}", result[0]);
        // Far right tail.
        assert!(
            result[4] > 1.0 - 1e-6,
            "Φ(+1e6) should be ~1, got {}",
            result[4]
        );
        // Centre of the distribution.
        assert!(
            (result[2] - 0.5).abs() < 1e-8,
            "Φ(0) should be 0.5, got {}",
            result[2]
        );
        // Loose tolerance: the CDF is built on an approximate erf.
        assert!(
            (result[1] - 0.158_655_253_931_457_05).abs() < 1e-3,
            "Φ(-1) should be ≈0.1587, got {}",
            result[1]
        );
        assert!(
            (result[3] - 0.841_344_746_068_543).abs() < 1e-3,
            "Φ(1) should be ≈0.8413, got {}",
            result[3]
        );
    }

    /// Checks the identity Φ(-x) + Φ(x) = 1 on symmetric inputs.
    #[test]
    fn test_normal_cdf_batch_symmetry() {
        let xs = vec![-2.0_f64, -1.0, 0.0, 1.0, 2.0];
        let result = normal_cdf_batch(&xs, 0.0, 1.0);
        assert!(
            (result[0] + result[4] - 1.0).abs() < 1e-7,
            "Φ(-2)+Φ(2) should be 1, got {}",
            result[0] + result[4]
        );
        assert!(
            (result[1] + result[3] - 1.0).abs() < 1e-7,
            "Φ(-1)+Φ(1) should be 1, got {}",
            result[1] + result[3]
        );
        assert!(
            (result[2] - 0.5).abs() < 1e-8,
            "Φ(0) should be ~0.5, got {}",
            result[2]
        );
    }

    /// Empty input yields empty output.
    #[test]
    fn test_normal_cdf_batch_empty() {
        let result = normal_cdf_batch(&[], 0.0, 1.0);
        assert!(result.is_empty());
    }

    /// Exponential log-density with rate 2, including one point outside the support.
    #[test]
    fn test_exponential_log_pdf_batch_cpu() {
        let xs = vec![0.0_f64, 1.0, 2.0, -1.0];
        let lambda = 2.0_f64;
        let result = exponential_log_pdf_batch(&xs, lambda);
        assert_eq!(result.len(), xs.len());

        // At x = 0 the log-density is ln(lambda).
        assert!(
            (result[0] - lambda.ln()).abs() < 1e-10,
            "log_pdf(0) should be ln(2), got {}",
            result[0]
        );
        let expected_1 = lambda.ln() - lambda * 1.0;
        assert!(
            (result[1] - expected_1).abs() < 1e-10,
            "log_pdf(1) should be {expected_1}, got {}",
            result[1]
        );
        let expected_2 = lambda.ln() - lambda * 2.0;
        assert!(
            (result[2] - expected_2).abs() < 1e-10,
            "log_pdf(2) should be {expected_2}, got {}",
            result[2]
        );
        // Loose bound: satisfied by both -inf (CPU path) and the shader's -1e30 sentinel.
        assert!(
            result[3] < -1e20,
            "log_pdf(-1) should be -inf, got {}",
            result[3]
        );
    }

    /// With rate 1 the log-density reduces to `-x`.
    #[test]
    fn test_exponential_log_pdf_batch_unit_rate() {
        let xs: Vec<f64> = (0..=5).map(|i| i as f64).collect();
        let result = exponential_log_pdf_batch(&xs, 1.0);
        for (i, (&x, &r)) in xs.iter().zip(result.iter()).enumerate() {
            let expected = -x;
            assert!(
                (r - expected).abs() < 1e-10,
                "unit-rate log_pdf mismatch at index {i}"
            );
        }
    }

    /// Empty input yields empty output.
    #[test]
    fn test_exponential_log_pdf_batch_empty() {
        let result = exponential_log_pdf_batch(&[], 1.0);
        assert!(result.is_empty());
    }

    /// Exponential CDF with rate 1 at the origin, inside and outside the support.
    #[test]
    fn test_exponential_cdf_batch_cpu() {
        let xs = vec![0.0_f64, 1.0, -1.0];
        let result = exponential_cdf_batch(&xs, 1.0);
        assert_eq!(result.len(), xs.len());

        assert!(
            (result[0] - 0.0).abs() < 1e-10,
            "CDF(0) should be 0, got {}",
            result[0]
        );
        let expected_1 = 1.0 - (-1.0_f64).exp();
        assert!(
            (result[1] - expected_1).abs() < 1e-10,
            "CDF(1) should be {expected_1}, got {}",
            result[1]
        );
        // Negative inputs lie outside the support.
        assert!(
            (result[2] - 0.0).abs() < 1e-10,
            "CDF(-1) should be 0, got {}",
            result[2]
        );
    }

    /// For large x the CDF must be within 1e-10 of 1.
    #[test]
    fn test_exponential_cdf_batch_large_x() {
        let xs = vec![100.0_f64, 1000.0];
        let result = exponential_cdf_batch(&xs, 1.0);
        assert!(result[0] > 1.0 - 1e-10);
        assert!(result[1] > 1.0 - 1e-10);
    }

    /// Empty input yields empty output.
    #[test]
    fn test_exponential_cdf_batch_empty() {
        let result = exponential_cdf_batch(&[], 1.0);
        assert!(result.is_empty());
    }

    /// erf is odd: erf(x) + erf(-x) must cancel.
    #[test]
    fn test_erf_cpu_symmetry() {
        for &x in &[0.5_f64, 1.0, 1.5, 2.0, 3.0] {
            let pos = erf_cpu(x);
            let neg = erf_cpu(-x);
            assert!(
                (pos + neg).abs() < 1e-12,
                "erf symmetry failed at x={x}: erf(x)={pos}, erf(-x)={neg}"
            );
        }
    }

    /// Compares `erf_cpu` against reference values at 0, 1 and 2.
    #[test]
    fn test_erf_cpu_known_values() {
        // The polynomial coefficients sum to just under 1, so the approximation
        // is a tiny nonzero value at 0 rather than exactly 0.
        assert!(
            erf_cpu(0.0).abs() < 1e-8,
            "erf(0) should be ~0, got {}",
            erf_cpu(0.0)
        );
        assert!(
            (erf_cpu(1.0) - 0.842_700_792_949_715).abs() < 2e-7,
            "erf(1) mismatch: {}",
            erf_cpu(1.0)
        );
        assert!(
            (erf_cpu(2.0) - 0.995_322_265_018_953).abs() < 2e-7,
            "erf(2) mismatch: {}",
            erf_cpu(2.0)
        );
    }

    /// GPU normal log-density vs. the CPU scalar; skipped without an adapter.
    #[cfg(feature = "gpu_wgpu")]
    #[test]
    fn test_normal_log_pdf_wgpu_or_skip() {
        let xs = vec![0.0_f64, 1.0, -1.0];
        let gpu_result = normal_log_pdf_wgpu(&xs, 0.0, 1.0);
        match gpu_result {
            Err(GpuStatsError::GpuNotAvailable) => {
                eprintln!("test_normal_log_pdf_wgpu_or_skip: GPU not available, skipping");
            }
            Err(e) => panic!("GPU error: {e}"),
            Ok(gpu) => {
                let cpu: Vec<f64> = xs
                    .iter()
                    .map(|&x| normal_log_pdf_scalar(x, 0.0, 1.0))
                    .collect();
                // 1e-4 tolerance: the GPU path computes in f32.
                for (g, c) in gpu.iter().zip(cpu.iter()) {
                    assert!((g - c).abs() < 1e-4, "GPU/CPU mismatch: gpu={g}, cpu={c}");
                }
            }
        }
    }

    /// GPU normal CDF vs. the CPU scalar; skipped without an adapter.
    #[cfg(feature = "gpu_wgpu")]
    #[test]
    fn test_normal_cdf_wgpu_or_skip() {
        let xs = vec![-1.0_f64, 0.0, 1.0];
        let gpu_result = normal_cdf_wgpu(&xs, 0.0, 1.0);
        match gpu_result {
            Err(GpuStatsError::GpuNotAvailable) => {
                eprintln!("test_normal_cdf_wgpu_or_skip: GPU not available, skipping");
            }
            Err(e) => panic!("GPU error: {e}"),
            Ok(gpu) => {
                let cpu: Vec<f64> = xs.iter().map(|&x| normal_cdf_scalar(x, 0.0, 1.0)).collect();
                // 1e-4 tolerance: the GPU path computes in f32.
                for (g, c) in gpu.iter().zip(cpu.iter()) {
                    assert!((g - c).abs() < 1e-4, "GPU/CPU mismatch: gpu={g}, cpu={c}");
                }
            }
        }
    }

    /// GPU exponential log-density vs. the CPU scalar; skipped without an adapter.
    #[cfg(feature = "gpu_wgpu")]
    #[test]
    fn test_exponential_log_pdf_wgpu_or_skip() {
        let xs = vec![0.0_f64, 1.0, 2.0];
        let lambda = 2.0_f64;
        let gpu_result = exponential_log_pdf_wgpu(&xs, lambda);
        match gpu_result {
            Err(GpuStatsError::GpuNotAvailable) => {
                eprintln!("test_exponential_log_pdf_wgpu_or_skip: GPU not available, skipping");
            }
            Err(e) => panic!("GPU error: {e}"),
            Ok(gpu) => {
                let cpu: Vec<f64> = xs
                    .iter()
                    .map(|&x| exponential_log_pdf_scalar(x, lambda))
                    .collect();
                // 1e-4 tolerance: the GPU path computes in f32.
                for (g, c) in gpu.iter().zip(cpu.iter()) {
                    assert!((g - c).abs() < 1e-4, "GPU/CPU mismatch: gpu={g}, cpu={c}");
                }
            }
        }
    }

    /// GPU exponential CDF vs. the CPU scalar; skipped without an adapter.
    #[cfg(feature = "gpu_wgpu")]
    #[test]
    fn test_exponential_cdf_wgpu_or_skip() {
        let xs = vec![0.0_f64, 1.0, 2.0];
        let gpu_result = exponential_cdf_wgpu(&xs, 1.0);
        match gpu_result {
            Err(GpuStatsError::GpuNotAvailable) => {
                eprintln!("test_exponential_cdf_wgpu_or_skip: GPU not available, skipping");
            }
            Err(e) => panic!("GPU error: {e}"),
            Ok(gpu) => {
                let cpu: Vec<f64> = xs.iter().map(|&x| exponential_cdf_scalar(x, 1.0)).collect();
                // 1e-4 tolerance: the GPU path computes in f32.
                for (g, c) in gpu.iter().zip(cpu.iter()) {
                    assert!((g - c).abs() < 1e-4, "GPU/CPU mismatch: gpu={g}, cpu={c}");
                }
            }
        }
    }
}