Skip to main content

scirs2_special/
gpu_dispatch.rs

//! GPU auto-dispatch for batch evaluation of special functions.
//! Falls back to CPU when GPU is unavailable or array is small.
//!
//! The dispatch logic is intentionally simple: a minimum array size threshold
//! controls whether to attempt GPU execution. When `allow_gpu` is false (the
//! default), all evaluation is performed on CPU regardless of array size.
//!
//! # Example
//!
//! ```rust
//! use scirs2_special::gpu_dispatch::{GpuDispatchConfig, batch_gamma, batch_erf};
//!
//! let xs = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
//! let config = GpuDispatchConfig::default();
//! let results = batch_gamma(&xs, &config);
//! // Γ(1)=1, Γ(2)=1, Γ(3)=2, Γ(4)=6, Γ(5)=24
//! assert!((results[4] - 24.0).abs() < 1e-10);
//! ```
19
/// Configuration controlling when batch evaluation may be sent to a GPU.
///
/// [`select_dispatch`] picks the GPU only when `allow_gpu` is `true` *and* the
/// input length is at least `min_gpu_size`; in every other case the CPU is used.
///
/// The type is comparable (`PartialEq`/`Eq`) so configurations can be checked
/// in tests and deduplicated by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuDispatchConfig {
    /// Minimum array size (inclusive) to trigger GPU execution.
    pub min_gpu_size: usize,
    /// Use GPU if available; always use CPU if false.
    pub allow_gpu: bool,
}
28
29impl Default for GpuDispatchConfig {
30    fn default() -> Self {
31        Self {
32            min_gpu_size: 1024,
33            allow_gpu: false,
34        }
35    }
36}
37
38impl GpuDispatchConfig {
39    /// Create a config that always uses CPU regardless of array size.
40    pub fn cpu_only() -> Self {
41        Self {
42            min_gpu_size: usize::MAX,
43            allow_gpu: false,
44        }
45    }
46
47    /// Create a config that allows GPU dispatch at the given threshold.
48    pub fn gpu_at(min_size: usize) -> Self {
49        Self {
50            min_gpu_size: min_size,
51            allow_gpu: true,
52        }
53    }
54}
55
/// Outcome of a dispatch decision: where a batch should be evaluated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DispatchTarget {
    /// Evaluate on the CPU.
    Cpu,
    /// Attempt evaluation on a GPU backend.
    Gpu,
}
62
63/// Decide whether to dispatch to GPU based on array size.
64///
65/// Returns `DispatchTarget::Gpu` only when `config.allow_gpu` is true
66/// and `n >= config.min_gpu_size`.
67pub fn select_dispatch(n: usize, config: &GpuDispatchConfig) -> DispatchTarget {
68    if config.allow_gpu && n >= config.min_gpu_size {
69        DispatchTarget::Gpu
70    } else {
71        DispatchTarget::Cpu
72    }
73}
74
75// ─────────────────────────────────────────────────────────────────────────────
76// CPU implementations delegate to the existing crate functions
77// ─────────────────────────────────────────────────────────────────────────────
78
79#[inline]
80fn gamma_cpu(x: f64) -> f64 {
81    crate::gamma::gamma(x)
82}
83
84#[inline]
85fn erf_cpu(x: f64) -> f64 {
86    crate::erf::erf(x)
87}
88
89#[inline]
90fn bessel_j0_cpu(x: f64) -> f64 {
91    crate::bessel::j0(x)
92}
93
94#[inline]
95fn lgamma_cpu(x: f64) -> f64 {
96    crate::gamma::gammaln(x)
97}
98
99#[inline]
100fn erfc_cpu(x: f64) -> f64 {
101    crate::erf::erfc(x)
102}
103
104#[inline]
105fn erfinv_cpu(x: f64) -> f64 {
106    crate::erf::erfinv(x)
107}
108
109// ─────────────────────────────────────────────────────────────────────────────
110// Public batch APIs
111// ─────────────────────────────────────────────────────────────────────────────
112
113/// Batch evaluate gamma function with auto-dispatch.
114///
115/// When `config.allow_gpu` is false (the default), all computation is on CPU.
116/// When `allow_gpu` is true and the array exceeds `min_gpu_size`, the GPU
117/// path is attempted via the WGSL WebGPU backend
118/// ([`crate::gpu_kernels::wgsl::gamma_batch_wgpu`]) and, if that returns
119/// `GpuNotAvailable`, via the CUDA backend
120/// ([`crate::gpu_kernels::cuda::gamma_batch_cuda`]).  If neither backend is
121/// available the function falls back to the CPU path silently.
122pub fn batch_gamma(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
123    match select_dispatch(xs.len(), config) {
124        DispatchTarget::Cpu => xs.iter().map(|&x| gamma_cpu(x)).collect(),
125        DispatchTarget::Gpu => {
126            // Try WGSL (WebGPU) first, then CUDA, then fall back to CPU.
127            if let Ok(result) = crate::gpu_kernels::wgsl::gamma_batch_wgpu(xs) {
128                return result;
129            }
130            if let Ok(result) = crate::gpu_kernels::cuda::gamma_batch_cuda(xs) {
131                return result;
132            }
133            xs.iter().map(|&x| gamma_cpu(x)).collect()
134        }
135    }
136}
137
138/// Batch evaluate erf function with auto-dispatch.
139///
140/// GPU path attempts WGSL then CUDA before falling back to CPU.
141pub fn batch_erf(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
142    match select_dispatch(xs.len(), config) {
143        DispatchTarget::Cpu => xs.iter().map(|&x| erf_cpu(x)).collect(),
144        DispatchTarget::Gpu => {
145            if let Ok(result) = crate::gpu_kernels::wgsl::erf_batch_wgpu(xs) {
146                return result;
147            }
148            if let Ok(result) = crate::gpu_kernels::cuda::erf_batch_cuda(xs) {
149                return result;
150            }
151            xs.iter().map(|&x| erf_cpu(x)).collect()
152        }
153    }
154}
155
156/// Batch evaluate Bessel J₀ with auto-dispatch.
157///
158/// GPU path attempts WGSL then CUDA before falling back to CPU.
159pub fn batch_bessel_j0(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
160    match select_dispatch(xs.len(), config) {
161        DispatchTarget::Cpu => xs.iter().map(|&x| bessel_j0_cpu(x)).collect(),
162        DispatchTarget::Gpu => {
163            if let Ok(result) = crate::gpu_kernels::wgsl::bessel_j0_batch_wgpu(xs) {
164                return result;
165            }
166            if let Ok(result) = crate::gpu_kernels::cuda::bessel_j0_batch_cuda(xs) {
167                return result;
168            }
169            xs.iter().map(|&x| bessel_j0_cpu(x)).collect()
170        }
171    }
172}
173
174/// Batch evaluate log-gamma with auto-dispatch.
175///
176/// GPU path attempts WGSL WebGPU backend ([`crate::gpu_kernels::wgsl::lgamma_batch_wgpu`])
177/// before falling back to the scalar CPU path.  CUDA is not yet available for lgamma.
178pub fn batch_lgamma(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
179    match select_dispatch(xs.len(), config) {
180        DispatchTarget::Cpu => xs.iter().map(|&x| lgamma_cpu(x)).collect(),
181        DispatchTarget::Gpu => {
182            if let Ok(result) = crate::gpu_kernels::wgsl::lgamma_batch_wgpu(xs) {
183                return result;
184            }
185            xs.iter().map(|&x| lgamma_cpu(x)).collect()
186        }
187    }
188}
189
190/// Batch evaluate erfc function with auto-dispatch.
191///
192/// Computes `erfc(x) = 1 - erf(x)` for each element.  GPU path attempts the
193/// WGSL WebGPU backend ([`crate::gpu_kernels::wgsl::erfc_batch_wgpu`]) before
194/// falling back to the scalar CPU path using [`crate::erf::erfc`].
195pub fn batch_erfc(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
196    match select_dispatch(xs.len(), config) {
197        DispatchTarget::Cpu => xs.iter().map(|&x| erfc_cpu(x)).collect(),
198        DispatchTarget::Gpu => {
199            if let Ok(result) = crate::gpu_kernels::wgsl::erfc_batch_wgpu(xs) {
200                return result;
201            }
202            xs.iter().map(|&x| erfc_cpu(x)).collect()
203        }
204    }
205}
206
207/// Batch evaluate the inverse error function with auto-dispatch.
208///
209/// Computes `erfinv(p)` such that `erf(erfinv(p)) == p` for |p| < 1.
210/// GPU path attempts the WGSL WebGPU backend
211/// ([`crate::gpu_kernels::wgsl::erfinv_batch_wgpu`]) before falling back to
212/// the scalar CPU path using [`crate::erf::erfinv`].
213pub fn batch_erfinv(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
214    match select_dispatch(xs.len(), config) {
215        DispatchTarget::Cpu => xs.iter().map(|&x| erfinv_cpu(x)).collect(),
216        DispatchTarget::Gpu => {
217            if let Ok(result) = crate::gpu_kernels::wgsl::erfinv_batch_wgpu(xs) {
218                return result;
219            }
220            xs.iter().map(|&x| erfinv_cpu(x)).collect()
221        }
222    }
223}
224
225/// Batch evaluate with a custom function and auto-dispatch.
226///
227/// The function `f` is always called on CPU; the `config` controls whether
228/// a GPU-accelerated path would be preferred for built-in functions.  This
229/// generic variant always runs on CPU because user functions cannot be
230/// dispatched to GPU without additional codegen infrastructure.
231pub fn batch_eval<F>(xs: &[f64], f: F, config: &GpuDispatchConfig) -> Vec<f64>
232where
233    F: Fn(f64) -> f64,
234{
235    // User-provided functions always run on CPU; dispatch info is recorded but unused.
236    let _target = select_dispatch(xs.len(), config);
237    xs.iter().map(|&x| f(x)).collect()
238}
239
240// ─────────────────────────────────────────────────────────────────────────────
241// Tests
242// ─────────────────────────────────────────────────────────────────────────────
243
#[cfg(test)]
mod tests {
    //! Unit tests for the dispatch decision and the CPU paths of the batch
    //! APIs.  Every batch test uses a config with `allow_gpu = false`, so no
    //! GPU backend is exercised here.

    use super::*;

    #[test]
    fn test_batch_gamma_cpu() {
        let xs = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let config = GpuDispatchConfig::default();
        let results = batch_gamma(&xs, &config);
        // Γ(n) = (n-1)!
        let expected = [1.0, 1.0, 2.0, 6.0, 24.0];
        assert_eq!(results.len(), expected.len());
        for (r, e) in results.iter().zip(expected.iter()) {
            assert!(
                (r - e).abs() < 1e-10,
                "batch_gamma mismatch: got {r}, expected {e}"
            );
        }
    }

    #[test]
    fn test_dispatch_small_array() {
        // Array size 10 with default config (allow_gpu=false) → always CPU
        let config = GpuDispatchConfig::default();
        assert_eq!(select_dispatch(10, &config), DispatchTarget::Cpu);
    }

    #[test]
    fn test_dispatch_large_array_cpu() {
        // allow_gpu=false, size 10000 → still CPU
        let config = GpuDispatchConfig {
            min_gpu_size: 1024,
            allow_gpu: false,
        };
        assert_eq!(select_dispatch(10_000, &config), DispatchTarget::Cpu);
    }

    #[test]
    fn test_dispatch_large_array_gpu_enabled() {
        // allow_gpu=true, size 10000 → GPU (when threshold is 1024)
        let config = GpuDispatchConfig {
            min_gpu_size: 1024,
            allow_gpu: true,
        };
        assert_eq!(select_dispatch(10_000, &config), DispatchTarget::Gpu);
    }

    #[test]
    fn test_dispatch_exactly_at_threshold() {
        let config = GpuDispatchConfig {
            min_gpu_size: 1024,
            allow_gpu: true,
        };
        assert_eq!(select_dispatch(1024, &config), DispatchTarget::Gpu);
        assert_eq!(select_dispatch(1023, &config), DispatchTarget::Cpu);
    }

    #[test]
    fn test_cpu_only_never_selects_gpu() {
        // cpu_only() disables GPU dispatch, so even the largest size stays on CPU.
        let config = GpuDispatchConfig::cpu_only();
        assert_eq!(config.min_gpu_size, usize::MAX);
        assert!(!config.allow_gpu);
        assert_eq!(select_dispatch(0, &config), DispatchTarget::Cpu);
        assert_eq!(select_dispatch(usize::MAX, &config), DispatchTarget::Cpu);
    }

    #[test]
    fn test_gpu_at_sets_threshold() {
        // gpu_at(n) enables GPU dispatch with an inclusive threshold of n.
        let config = GpuDispatchConfig::gpu_at(16);
        assert_eq!(config.min_gpu_size, 16);
        assert!(config.allow_gpu);
        assert_eq!(select_dispatch(15, &config), DispatchTarget::Cpu);
        assert_eq!(select_dispatch(16, &config), DispatchTarget::Gpu);
    }

    #[test]
    fn test_batch_erf() {
        let xs = vec![0.0_f64, 1.0, -1.0, 2.0];
        let config = GpuDispatchConfig::default();
        let results = batch_erf(&xs, &config);
        assert_eq!(results.len(), 4);
        // erf(0) = 0
        assert!(results[0].abs() < 1e-15);
        // erf(1) ≈ 0.8427007929497148
        // The crate implementation uses A&S 7.1.26 with max error 1.5e-7.
        assert!(
            (results[1] - 0.842_700_792_949_715).abs() < 2e-7,
            "erf(1.0) got {:.10}, expected ~0.842700793",
            results[1]
        );
        // erf is odd
        assert!(
            (results[2] + results[1]).abs() < 1e-12,
            "erf should be odd: erf(-1)+erf(1)={}",
            results[2] + results[1]
        );
        // erf(2) ≈ 0.9953222650189527
        assert!(
            (results[3] - 0.995_322_265_019).abs() < 2e-7,
            "erf(2.0) got {:.10}, expected ~0.995322265",
            results[3]
        );
    }

    #[test]
    fn test_batch_eval_custom() {
        // Custom f(x) = x^2
        let xs: Vec<f64> = (1..=5).map(|i| i as f64).collect();
        let config = GpuDispatchConfig::default();
        let results = batch_eval(&xs, |x| x * x, &config);
        let expected: Vec<f64> = xs.iter().map(|&x| x * x).collect();
        assert_eq!(results, expected);
    }

    #[test]
    fn test_batch_bessel_j0() {
        let xs = vec![0.0_f64, 1.0, 2.0];
        let config = GpuDispatchConfig::default();
        let results = batch_bessel_j0(&xs, &config);
        assert_eq!(results.len(), 3);
        // J₀(0) = 1
        assert!((results[0] - 1.0).abs() < 1e-12);
        // J₀(1) ≈ 0.7651976866
        assert!((results[1] - 0.765_197_686_6).abs() < 1e-8);
        // J₀(2) ≈ 0.2238907791
        // This value was computed but never checked before; the tolerance is
        // deliberately looser than for J₀(1) because its accuracy was not
        // previously pinned.
        assert!(
            (results[2] - 0.223_890_779_1).abs() < 1e-6,
            "J0(2.0) got {:.12}, expected ~0.2238907791",
            results[2]
        );
    }

    #[test]
    fn test_batch_gamma_empty() {
        let xs: Vec<f64> = vec![];
        let config = GpuDispatchConfig::default();
        let results = batch_gamma(&xs, &config);
        assert!(results.is_empty());
    }

    #[test]
    fn test_batch_erfc() {
        let xs = vec![0.0_f64, 1.0, -1.0];
        let config = GpuDispatchConfig::default();
        let results = batch_erfc(&xs, &config);
        assert_eq!(results.len(), 3);
        // erfc(0) = 1
        assert!((results[0] - 1.0).abs() < 1e-14);
        // erfc(1) ≈ 0.15729920705028516
        // The crate erfc uses A&S 7.1.26 with max error ~1.5e-7
        assert!(
            (results[1] - 0.157_299_207_05).abs() < 2e-7,
            "erfc(1.0) got {:.12}, expected ~0.15729920705",
            results[1]
        );
        // erfc(-1) = 2 - erfc(1) ≈ 1.84270079295
        assert!(
            (results[2] - 1.842_700_792_95).abs() < 2e-7,
            "erfc(-1.0) got {:.12}, expected ~1.842700793",
            results[2]
        );
    }

    #[test]
    fn test_batch_erfinv() {
        let xs = vec![0.0_f64, 0.5, -0.5];
        let config = GpuDispatchConfig::default();
        let results = batch_erfinv(&xs, &config);
        assert_eq!(results.len(), 3);
        // erfinv(0) = 0
        assert!(results[0].abs() < 1e-14);
        // erfinv(0.5) ≈ 0.47693627620448
        // Tolerance is relaxed because erfinv uses a rough approximation.
        assert!(
            (results[1] - 0.476_936_276_2).abs() < 0.01,
            "erfinv(0.5) got {:.12}, expected ~0.4769362762",
            results[1]
        );
        // erfinv is odd
        assert!(
            (results[2] + results[1]).abs() < 1e-12,
            "erfinv should be odd: erfinv(-0.5)+erfinv(0.5)={}",
            results[2] + results[1]
        );
    }
}