/// Settings that decide whether a batch evaluation may be routed to a GPU backend.
///
/// Consulted by `select_dispatch`: the GPU is chosen only when `allow_gpu` is
/// `true` and the batch holds at least `min_gpu_size` elements.
///
/// Derives `PartialEq`/`Eq` (like `DispatchTarget`) so configurations can be
/// compared directly, e.g. in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuDispatchConfig {
    /// Smallest batch length (inclusive) that is eligible for GPU dispatch.
    pub min_gpu_size: usize,
    /// Master switch; when `false` every batch is evaluated on the CPU.
    pub allow_gpu: bool,
}
28
29impl Default for GpuDispatchConfig {
30 fn default() -> Self {
31 Self {
32 min_gpu_size: 1024,
33 allow_gpu: false,
34 }
35 }
36}
37
38impl GpuDispatchConfig {
39 pub fn cpu_only() -> Self {
41 Self {
42 min_gpu_size: usize::MAX,
43 allow_gpu: false,
44 }
45 }
46
47 pub fn gpu_at(min_size: usize) -> Self {
49 Self {
50 min_gpu_size: min_size,
51 allow_gpu: true,
52 }
53 }
54}
55
/// Execution backend chosen for a batch by `select_dispatch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchTarget {
    /// Evaluate the batch on the host CPU.
    Cpu,
    /// Attempt to evaluate the batch with a GPU kernel.
    Gpu,
}
62
63pub fn select_dispatch(n: usize, config: &GpuDispatchConfig) -> DispatchTarget {
68 if config.allow_gpu && n >= config.min_gpu_size {
69 DispatchTarget::Gpu
70 } else {
71 DispatchTarget::Cpu
72 }
73}
74
75#[inline]
80fn gamma_cpu(x: f64) -> f64 {
81 crate::gamma::gamma(x)
82}
83
84#[inline]
85fn erf_cpu(x: f64) -> f64 {
86 crate::erf::erf(x)
87}
88
89#[inline]
90fn bessel_j0_cpu(x: f64) -> f64 {
91 crate::bessel::j0(x)
92}
93
94#[inline]
95fn lgamma_cpu(x: f64) -> f64 {
96 crate::gamma::gammaln(x)
97}
98
99#[inline]
100fn erfc_cpu(x: f64) -> f64 {
101 crate::erf::erfc(x)
102}
103
104#[inline]
105fn erfinv_cpu(x: f64) -> f64 {
106 crate::erf::erfinv(x)
107}
108
109pub fn batch_gamma(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
123 match select_dispatch(xs.len(), config) {
124 DispatchTarget::Cpu => xs.iter().map(|&x| gamma_cpu(x)).collect(),
125 DispatchTarget::Gpu => {
126 if let Ok(result) = crate::gpu_kernels::wgsl::gamma_batch_wgpu(xs) {
128 return result;
129 }
130 if let Ok(result) = crate::gpu_kernels::cuda::gamma_batch_cuda(xs) {
131 return result;
132 }
133 xs.iter().map(|&x| gamma_cpu(x)).collect()
134 }
135 }
136}
137
138pub fn batch_erf(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
142 match select_dispatch(xs.len(), config) {
143 DispatchTarget::Cpu => xs.iter().map(|&x| erf_cpu(x)).collect(),
144 DispatchTarget::Gpu => {
145 if let Ok(result) = crate::gpu_kernels::wgsl::erf_batch_wgpu(xs) {
146 return result;
147 }
148 if let Ok(result) = crate::gpu_kernels::cuda::erf_batch_cuda(xs) {
149 return result;
150 }
151 xs.iter().map(|&x| erf_cpu(x)).collect()
152 }
153 }
154}
155
156pub fn batch_bessel_j0(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
160 match select_dispatch(xs.len(), config) {
161 DispatchTarget::Cpu => xs.iter().map(|&x| bessel_j0_cpu(x)).collect(),
162 DispatchTarget::Gpu => {
163 if let Ok(result) = crate::gpu_kernels::wgsl::bessel_j0_batch_wgpu(xs) {
164 return result;
165 }
166 if let Ok(result) = crate::gpu_kernels::cuda::bessel_j0_batch_cuda(xs) {
167 return result;
168 }
169 xs.iter().map(|&x| bessel_j0_cpu(x)).collect()
170 }
171 }
172}
173
174pub fn batch_lgamma(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
179 match select_dispatch(xs.len(), config) {
180 DispatchTarget::Cpu => xs.iter().map(|&x| lgamma_cpu(x)).collect(),
181 DispatchTarget::Gpu => {
182 if let Ok(result) = crate::gpu_kernels::wgsl::lgamma_batch_wgpu(xs) {
183 return result;
184 }
185 xs.iter().map(|&x| lgamma_cpu(x)).collect()
186 }
187 }
188}
189
190pub fn batch_erfc(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
196 match select_dispatch(xs.len(), config) {
197 DispatchTarget::Cpu => xs.iter().map(|&x| erfc_cpu(x)).collect(),
198 DispatchTarget::Gpu => {
199 if let Ok(result) = crate::gpu_kernels::wgsl::erfc_batch_wgpu(xs) {
200 return result;
201 }
202 xs.iter().map(|&x| erfc_cpu(x)).collect()
203 }
204 }
205}
206
207pub fn batch_erfinv(xs: &[f64], config: &GpuDispatchConfig) -> Vec<f64> {
214 match select_dispatch(xs.len(), config) {
215 DispatchTarget::Cpu => xs.iter().map(|&x| erfinv_cpu(x)).collect(),
216 DispatchTarget::Gpu => {
217 if let Ok(result) = crate::gpu_kernels::wgsl::erfinv_batch_wgpu(xs) {
218 return result;
219 }
220 xs.iter().map(|&x| erfinv_cpu(x)).collect()
221 }
222 }
223}
224
225pub fn batch_eval<F>(xs: &[f64], f: F, config: &GpuDispatchConfig) -> Vec<f64>
232where
233 F: Fn(f64) -> f64,
234{
235 let _target = select_dispatch(xs.len(), config);
237 xs.iter().map(|&x| f(x)).collect()
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Gamma at 1..=5 on the default (CPU) path matches the factorials 0!..4!.
    #[test]
    fn test_batch_gamma_cpu() {
        let config = GpuDispatchConfig::default();
        let xs = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let expected = [1.0, 1.0, 2.0, 6.0, 24.0];

        let results = batch_gamma(&xs, &config);

        assert_eq!(results.len(), expected.len());
        for (r, e) in results.iter().zip(expected.iter()) {
            assert!(
                (r - e).abs() < 1e-10,
                "batch_gamma mismatch: got {r}, expected {e}"
            );
        }
    }

    /// With the default configuration a small batch stays on the CPU.
    #[test]
    fn test_dispatch_small_array() {
        let config = GpuDispatchConfig::default();
        let target = select_dispatch(10, &config);
        assert_eq!(target, DispatchTarget::Cpu);
    }

    /// A batch above the threshold still runs on the CPU when the GPU is disabled.
    #[test]
    fn test_dispatch_large_array_cpu() {
        let config = GpuDispatchConfig {
            allow_gpu: false,
            min_gpu_size: 1024,
        };
        let target = select_dispatch(10_000, &config);
        assert_eq!(target, DispatchTarget::Cpu);
    }

    /// A batch above the threshold goes to the GPU when it is enabled.
    #[test]
    fn test_dispatch_large_array_gpu_enabled() {
        let config = GpuDispatchConfig {
            allow_gpu: true,
            min_gpu_size: 1024,
        };
        let target = select_dispatch(10_000, &config);
        assert_eq!(target, DispatchTarget::Gpu);
    }

    /// The threshold is inclusive: 1024 selects the GPU, 1023 does not.
    #[test]
    fn test_dispatch_exactly_at_threshold() {
        let config = GpuDispatchConfig {
            allow_gpu: true,
            min_gpu_size: 1024,
        };
        let at_threshold = select_dispatch(1024, &config);
        let just_below = select_dispatch(1023, &config);
        assert_eq!(at_threshold, DispatchTarget::Gpu);
        assert_eq!(just_below, DispatchTarget::Cpu);
    }

    /// erf reference values at 0, 1 and 2, plus odd symmetry at ±1.
    #[test]
    fn test_batch_erf() {
        let config = GpuDispatchConfig::default();
        let xs = [0.0_f64, 1.0, -1.0, 2.0];

        let results = batch_erf(&xs, &config);
        assert_eq!(results.len(), 4);

        let (at_zero, at_one, at_minus_one, at_two) =
            (results[0], results[1], results[2], results[3]);

        assert!(at_zero.abs() < 1e-15);
        assert!(
            (at_one - 0.842_700_792_949_715).abs() < 2e-7,
            "erf(1.0) got {:.10}, expected ~0.842700793",
            at_one
        );
        assert!(
            (at_minus_one + at_one).abs() < 1e-12,
            "erf should be odd: erf(-1)+erf(1)={}",
            at_minus_one + at_one
        );
        assert!(
            (at_two - 0.995_322_265_019).abs() < 2e-7,
            "erf(2.0) got {:.10}, expected ~0.995322265",
            at_two
        );
    }

    /// `batch_eval` applies an arbitrary closure element-wise.
    #[test]
    fn test_batch_eval_custom() {
        let config = GpuDispatchConfig::default();
        let xs = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let square = |x: f64| x * x;

        let expected: Vec<f64> = xs.iter().copied().map(square).collect();
        let results = batch_eval(&xs, square, &config);

        assert_eq!(results, expected);
    }

    /// J0 reference values at 0 and 1 (the value at 2 is only length-checked).
    #[test]
    fn test_batch_bessel_j0() {
        let config = GpuDispatchConfig::default();
        let xs = [0.0_f64, 1.0, 2.0];

        let results = batch_bessel_j0(&xs, &config);

        assert_eq!(results.len(), 3);
        let (at_zero, at_one) = (results[0], results[1]);
        assert!((at_zero - 1.0).abs() < 1e-12);
        assert!((at_one - 0.765_197_686_6).abs() < 1e-8);
    }

    /// An empty input produces an empty output.
    #[test]
    fn test_batch_gamma_empty() {
        let config = GpuDispatchConfig::default();
        let xs: [f64; 0] = [];

        let results = batch_gamma(&xs, &config);

        assert!(results.is_empty());
    }

    /// erfc reference values at 0, 1 and -1.
    #[test]
    fn test_batch_erfc() {
        let config = GpuDispatchConfig::default();
        let xs = [0.0_f64, 1.0, -1.0];

        let results = batch_erfc(&xs, &config);
        assert_eq!(results.len(), 3);

        let (at_zero, at_one, at_minus_one) = (results[0], results[1], results[2]);

        assert!((at_zero - 1.0).abs() < 1e-14);
        assert!(
            (at_one - 0.157_299_207_05).abs() < 2e-7,
            "erfc(1.0) got {:.12}, expected ~0.15729920705",
            at_one
        );
        assert!(
            (at_minus_one - 1.842_700_792_95).abs() < 2e-7,
            "erfc(-1.0) got {:.12}, expected ~1.842700793",
            at_minus_one
        );
    }

    /// erfinv at 0 and 0.5 (loose tolerance), plus odd symmetry at ±0.5.
    #[test]
    fn test_batch_erfinv() {
        let config = GpuDispatchConfig::default();
        let xs = [0.0_f64, 0.5, -0.5];

        let results = batch_erfinv(&xs, &config);
        assert_eq!(results.len(), 3);

        let (at_zero, at_half, at_minus_half) = (results[0], results[1], results[2]);

        assert!(at_zero.abs() < 1e-14);
        assert!(
            (at_half - 0.476_936_276_2).abs() < 0.01,
            "erfinv(0.5) got {:.12}, expected ~0.4769362762",
            at_half
        );
        assert!(
            (at_minus_half + at_half).abs() < 1e-12,
            "erfinv should be odd: erfinv(-0.5)+erfinv(0.5)={}",
            at_minus_half + at_half
        );
    }
}
405}