Skip to main content

axonml_core/backends/
gpu_tests.rs

1//! GPU Backend Testing Infrastructure
2//!
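//! Testing infrastructure for the GPU backends (CUDA, Vulkan, Metal, WebGPU).
//! Correctness is verified by comparing GPU results against CPU reference
//! implementations. In this file only CUDA currently has a full test suite
//! (`cuda_tests`); the other backends appear only in backend detection.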
5//!
6//! # Running Tests
7//!
8//! ```bash
9//! # Run all GPU tests (requires hardware)
10//! cargo test -p axonml-core --features cuda gpu_tests
11//!
12//! # Run specific backend tests
13//! cargo test -p axonml-core --features cuda test_cuda
14//! cargo test -p axonml-core --features vulkan test_vulkan
15//! cargo test -p axonml-core --features wgpu test_wgpu
16//! ```
17//!
18//! @version 0.1.0
19
20use crate::device::DeviceCapabilities;
21
22// =============================================================================
23// Test Configuration
24// =============================================================================
25
/// Tunable parameters shared by the GPU correctness tests and benchmarks.
#[derive(Debug, Clone)]
pub struct GpuTestConfig {
    /// Absolute tolerance used when comparing GPU output to the CPU reference
    pub atol: f32,
    /// Relative tolerance (scaled by the larger magnitude of the two values)
    pub rtol: f32,
    /// Element counts exercised by the correctness tests
    pub test_sizes: Vec<usize>,
    /// Element counts exercised by the benchmarks
    pub benchmark_sizes: Vec<usize>,
    /// Untimed iterations run before each benchmark measurement
    pub warmup_iters: usize,
    /// Timed iterations per benchmark measurement
    pub bench_iters: usize,
}

impl Default for GpuTestConfig {
    /// Tight tolerances, a mix of odd and power-of-two test sizes, and
    /// benchmark sizes growing by 4x from 1 Ki to 1 Mi elements.
    fn default() -> Self {
        let test_sizes = [1, 7, 16, 64, 256, 1024, 4096].to_vec();
        let benchmark_sizes = [1 << 10, 1 << 12, 1 << 14, 1 << 16, 1 << 18, 1 << 20].to_vec();

        Self {
            bench_iters: 100,
            warmup_iters: 5,
            benchmark_sizes,
            test_sizes,
            rtol: 1e-4,
            atol: 1e-5,
        }
    }
}
55
56// =============================================================================
57// Test Results
58// =============================================================================
59
/// Outcome of a single GPU test or benchmark, optionally decorated with
/// accuracy and performance metrics through the builder-style `with_*` methods.
#[derive(Debug, Clone)]
pub struct GpuTestResult {
    /// Test name
    pub name: String,
    /// Whether the test passed
    pub passed: bool,
    /// Error message if failed
    pub error: Option<String>,
    /// Maximum absolute error (for correctness tests)
    pub max_abs_error: Option<f32>,
    /// Throughput in GB/s (for benchmarks)
    pub throughput_gbps: Option<f64>,
    /// Latency in microseconds
    pub latency_us: Option<f64>,
}

impl GpuTestResult {
    /// Builds a passing result named `name` with no message and no metrics.
    pub fn pass(name: &str) -> Self {
        Self {
            name: String::from(name),
            passed: true,
            error: None,
            max_abs_error: None,
            throughput_gbps: None,
            latency_us: None,
        }
    }

    /// Builds a failing result named `name` whose message is `error`.
    pub fn fail(name: &str, error: &str) -> Self {
        Self {
            passed: false,
            error: Some(String::from(error)),
            ..Self::pass(name)
        }
    }

    /// Attaches the maximum absolute error observed against the reference.
    ///
    /// Despite the name, this does not modify the `error` message field.
    pub fn with_error(self, max_abs_error: f32) -> Self {
        Self {
            max_abs_error: Some(max_abs_error),
            ..self
        }
    }

    /// Attaches benchmark throughput (GB/s) and latency (microseconds).
    pub fn with_perf(self, throughput_gbps: f64, latency_us: f64) -> Self {
        Self {
            throughput_gbps: Some(throughput_gbps),
            latency_us: Some(latency_us),
            ..self
        }
    }
}
115
116/// Collection of test results.
117#[derive(Debug, Default)]
118pub struct GpuTestReport {
119    /// Backend name
120    pub backend: String,
121    /// Device capabilities
122    pub capabilities: Option<DeviceCapabilities>,
123    /// Individual test results
124    pub results: Vec<GpuTestResult>,
125}
126
127impl GpuTestReport {
128    /// Create a new report for a backend.
129    pub fn new(backend: &str) -> Self {
130        Self {
131            backend: backend.to_string(),
132            capabilities: None,
133            results: Vec::new(),
134        }
135    }
136
137    /// Set device capabilities.
138    pub fn with_capabilities(mut self, caps: DeviceCapabilities) -> Self {
139        self.capabilities = Some(caps);
140        self
141    }
142
143    /// Add a test result.
144    pub fn add_result(&mut self, result: GpuTestResult) {
145        self.results.push(result);
146    }
147
148    /// Get number of passed tests.
149    pub fn passed_count(&self) -> usize {
150        self.results.iter().filter(|r| r.passed).count()
151    }
152
153    /// Get number of failed tests.
154    pub fn failed_count(&self) -> usize {
155        self.results.iter().filter(|r| !r.passed).count()
156    }
157
158    /// Print a summary of the report.
159    pub fn print_summary(&self) {
160        println!("\n========================================");
161        println!("GPU Test Report: {}", self.backend);
162        println!("========================================");
163
164        if let Some(caps) = &self.capabilities {
165            println!("Device: {}", caps.name);
166            println!(
167                "Memory: {:.1} GB total, {:.1} GB available",
168                caps.total_memory as f64 / 1e9,
169                caps.available_memory as f64 / 1e9
170            );
171            if let Some(cc) = &caps.compute_capability {
172                println!("Compute Capability: {}.{}", cc.0, cc.1);
173            }
174            println!();
175        }
176
177        println!(
178            "Results: {} passed, {} failed",
179            self.passed_count(),
180            self.failed_count()
181        );
182        println!();
183
184        for result in &self.results {
185            let status = if result.passed { "PASS" } else { "FAIL" };
186            print!("[{}] {}", status, result.name);
187
188            if let Some(err) = &result.error {
189                print!(" - {}", err);
190            }
191            if let Some(mae) = result.max_abs_error {
192                print!(" (max_err: {:.2e})", mae);
193            }
194            if let Some(tp) = result.throughput_gbps {
195                print!(" [{:.2} GB/s]", tp);
196            }
197            if let Some(lat) = result.latency_us {
198                print!(" [{:.1} us]", lat);
199            }
200            println!();
201        }
202
203        if self.failed_count() > 0 {
204            println!("\nFailed tests:");
205            for result in self.results.iter().filter(|r| !r.passed) {
206                println!(
207                    "  - {}: {}",
208                    result.name,
209                    result.error.as_deref().unwrap_or("Unknown")
210                );
211            }
212        }
213    }
214}
215
216// =============================================================================
217// Test Utilities
218// =============================================================================
219
220/// Compare two float slices for approximate equality.
221pub fn assert_close(expected: &[f32], actual: &[f32], atol: f32, rtol: f32) -> Result<f32, String> {
222    if expected.len() != actual.len() {
223        return Err(format!(
224            "Length mismatch: expected {}, got {}",
225            expected.len(),
226            actual.len()
227        ));
228    }
229
230    let mut max_abs_error = 0.0f32;
231    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
232        let abs_err = (e - a).abs();
233        let rel_tol = rtol * e.abs().max(a.abs());
234        max_abs_error = max_abs_error.max(abs_err);
235
236        if abs_err > atol + rel_tol {
237            return Err(format!(
238                "Mismatch at index {}: expected {}, got {} (abs_err: {:.2e}, tol: {:.2e})",
239                i,
240                e,
241                a,
242                abs_err,
243                atol + rel_tol
244            ));
245        }
246    }
247
248    Ok(max_abs_error)
249}
250
/// Generate `len` reproducible pseudo-random values in the half-open range `[-1, 1)`.
///
/// The generator is a 64-bit linear congruential generator using the
/// multiplier/increment of the C standard's sample `rand()`. Each step takes the
/// 15 bits at positions 16..=30 of the state and maps them linearly onto `[-1, 1)`.
/// The same `seed` always yields the same sequence.
pub fn random_vec(len: usize, seed: u64) -> Vec<f32> {
    const MULTIPLIER: u64 = 1103515245;
    const INCREMENT: u64 = 12345;

    let mut values = Vec::with_capacity(len);
    let mut x = seed;
    for _ in 0..len {
        x = x.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        let bits = (x >> 16) & 0x7FFF;
        values.push(bits as f32 / 16384.0 - 1.0);
    }
    values
}
263
/// CPU reference for element-wise addition: `out[i] = a[i] + b[i]`.
///
/// Only the common prefix of the two slices is used (output length is the shorter length).
pub fn cpu_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(&lhs, &rhs)| lhs + rhs).collect()
}
268
/// CPU reference for element-wise multiplication: `out[i] = a[i] * b[i]`.
///
/// Only the common prefix of the two slices is used (output length is the shorter length).
pub fn cpu_mul(a: &[f32], b: &[f32]) -> Vec<f32> {
    let len = a.len().min(b.len());
    (0..len).map(|i| a[i] * b[i]).collect()
}
273
/// CPU reference for scalar multiplication: `out[i] = a[i] * alpha`.
pub fn cpu_scale(a: &[f32], alpha: f32) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(a.len());
    for &value in a {
        scaled.push(value * alpha);
    }
    scaled
}
278
/// CPU reference for ReLU: `out[i] = max(a[i], 0)`.
///
/// Uses `f32::max`, so a `NaN` input maps to `0.0`.
pub fn cpu_relu(a: &[f32]) -> Vec<f32> {
    a.iter().copied().map(|v| f32::max(v, 0.0)).collect()
}
283
/// CPU reference for the logistic sigmoid: `out[i] = 1 / (1 + exp(-a[i]))`.
pub fn cpu_sigmoid(a: &[f32]) -> Vec<f32> {
    a.iter()
        .map(|&v| {
            let denom = 1.0 + (-v).exp();
            1.0 / denom
        })
        .collect()
}
288
/// CPU reference for the hyperbolic tangent, applied element-wise.
pub fn cpu_tanh(a: &[f32]) -> Vec<f32> {
    a.iter().copied().map(f32::tanh).collect()
}
293
/// CPU reference for matrix multiplication `C = A * B`, all row-major.
///
/// `a` is `m x k`, `b` is `k x n`, and the returned `C` is `m x n`. Each output
/// is accumulated sequentially over `p = 0..k`, starting from `0.0`.
pub fn cpu_gemm(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(m * n);
    for row in 0..m {
        let a_base = row * k;
        out.extend((0..n).map(|col| {
            (0..k).fold(0.0f32, |acc, p| acc + a[a_base + p] * b[p * n + col])
        }));
    }
    out
}
309
310// =============================================================================
311// CUDA Tests
312// =============================================================================
313
#[cfg(feature = "cuda")]
pub mod cuda_tests {
    use super::*;
    use crate::backends::Backend;
    use crate::backends::cuda::{is_available, CudaBackend};

    // NOTE(review): every test below discards the value returned by
    // `backend.synchronize()`. If it reports errors, checking it would attribute
    // asynchronous kernel failures to the right test. Confirm its signature.

    /// Lower bound on `atol` for kernels that may use fast-math approximations
    /// (sigmoid, tanh) or long accumulation chains (GEMM).
    const APPROX_ATOL: f32 = 1e-3;
    /// Lower bound on `rtol` for the same kernels as [`APPROX_ATOL`].
    const APPROX_RTOL: f32 = 1e-2;

    /// Run all CUDA tests on device 0 and collect the results.
    ///
    /// If CUDA is unavailable or the backend cannot be created, no kernel tests
    /// run. The report then holds a single failed entry explaining why.
    pub fn run_all_tests(config: &GpuTestConfig) -> GpuTestReport {
        let mut report = GpuTestReport::new("CUDA");

        if !is_available() {
            report.add_result(GpuTestResult::fail(
                "cuda_availability",
                "CUDA not available on this system",
            ));
            return report;
        }

        let backend = match CudaBackend::new(0) {
            Some(b) => b,
            None => {
                report.add_result(GpuTestResult::fail(
                    "backend_creation",
                    "Failed to create CUDA backend",
                ));
                return report;
            }
        };

        report = report.with_capabilities(backend.capabilities());

        // Memory operations
        report.add_result(test_memory_roundtrip(&backend, config));

        // Element-wise operations
        for &size in &config.test_sizes {
            report.add_result(test_add(&backend, size, config));
            report.add_result(test_mul(&backend, size, config));
            report.add_result(test_scale(&backend, size, config));
        }

        // Activation functions
        for &size in &config.test_sizes {
            report.add_result(test_relu(&backend, size, config));
            report.add_result(test_sigmoid(&backend, size, config));
            report.add_result(test_tanh(&backend, size, config));
        }

        // Matrix multiplication
        report.add_result(test_gemm_square(&backend, 64, config));
        report.add_result(test_gemm_square(&backend, 256, config));
        report.add_result(test_gemm_rectangular(&backend, 128, 64, 96, config));

        report
    }

    /// Turns the outcome of a device-to-host copy into a test result.
    ///
    /// `downloaded` is what `backend.dtoh_copy(..)` returned, possibly post-processed.
    /// A copy error becomes a failure prefixed with `dtoh_copy`. Otherwise the data
    /// is compared against `expected` with [`assert_close`].
    fn check_download<E: std::fmt::Display>(
        name: &str,
        downloaded: Result<Vec<f32>, E>,
        expected: &[f32],
        atol: f32,
        rtol: f32,
    ) -> GpuTestResult {
        match downloaded {
            Ok(actual) => match assert_close(expected, &actual, atol, rtol) {
                Ok(max_err) => GpuTestResult::pass(name).with_error(max_err),
                Err(e) => GpuTestResult::fail(name, &e),
            },
            Err(e) => GpuTestResult::fail(name, &format!("dtoh_copy: {}", e)),
        }
    }

    /// Uploads 1024 floats and downloads them again, requiring an exact match.
    fn test_memory_roundtrip(backend: &CudaBackend, _config: &GpuTestConfig) -> GpuTestResult {
        let name = "memory_roundtrip";
        let data: Vec<f32> = (0..1024).map(|i| i as f32).collect();

        match backend.htod_copy(&data) {
            Ok(gpu_data) => match backend.dtoh_copy(&gpu_data) {
                Ok(result) => {
                    if result == data {
                        GpuTestResult::pass(name)
                    } else {
                        GpuTestResult::fail(name, "Data mismatch after roundtrip")
                    }
                }
                Err(e) => GpuTestResult::fail(name, &format!("dtoh_copy failed: {}", e)),
            },
            Err(e) => GpuTestResult::fail(name, &format!("htod_copy failed: {}", e)),
        }
    }

    /// Checks `add_f32` against [`cpu_add`] for `size` elements.
    fn test_add(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("add_f32_{}", size);
        let a = random_vec(size, 42);
        let b = random_vec(size, 123);
        let expected = cpu_add(&a, &b);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(a): {}", e)),
        };
        let gpu_b = match backend.htod_copy(&b) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(b): {}", e)),
        };
        let mut gpu_c = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.add_f32(&mut gpu_c, &gpu_a, &gpu_b, size) {
            return GpuTestResult::fail(&name, &format!("add_f32: {}", e));
        }

        backend.synchronize();

        check_download(&name, backend.dtoh_copy(&gpu_c), &expected, config.atol, config.rtol)
    }

    /// Checks `mul_f32` against [`cpu_mul`] for `size` elements.
    fn test_mul(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("mul_f32_{}", size);
        let a = random_vec(size, 42);
        let b = random_vec(size, 123);
        let expected = cpu_mul(&a, &b);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(a): {}", e)),
        };
        let gpu_b = match backend.htod_copy(&b) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(b): {}", e)),
        };
        let mut gpu_c = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.mul_f32(&mut gpu_c, &gpu_a, &gpu_b, size) {
            return GpuTestResult::fail(&name, &format!("mul_f32: {}", e));
        }

        backend.synchronize();

        check_download(&name, backend.dtoh_copy(&gpu_c), &expected, config.atol, config.rtol)
    }

    /// Checks in-place `scale_f32` against [`cpu_scale`] for `size` elements.
    fn test_scale(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("scale_f32_{}", size);
        let a = random_vec(size, 42);
        let alpha = 2.5f32;
        let expected = cpu_scale(&a, alpha);

        let mut gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy: {}", e)),
        };

        if let Err(e) = backend.scale_f32(&mut gpu_a, alpha, size) {
            return GpuTestResult::fail(&name, &format!("scale_f32: {}", e));
        }

        backend.synchronize();

        check_download(&name, backend.dtoh_copy(&gpu_a), &expected, config.atol, config.rtol)
    }

    /// Checks `relu_f32` against [`cpu_relu`] for `size` elements.
    fn test_relu(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("relu_f32_{}", size);
        let a = random_vec(size, 42);
        let expected = cpu_relu(&a);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy: {}", e)),
        };
        let mut gpu_b = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.relu_f32(&mut gpu_b, &gpu_a, size) {
            return GpuTestResult::fail(&name, &format!("relu_f32: {}", e));
        }

        backend.synchronize();

        check_download(&name, backend.dtoh_copy(&gpu_b), &expected, config.atol, config.rtol)
    }

    /// Checks `sigmoid_f32` against [`cpu_sigmoid`] for `size` elements.
    fn test_sigmoid(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("sigmoid_f32_{}", size);
        let a = random_vec(size, 42);
        let expected = cpu_sigmoid(&a);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy: {}", e)),
        };
        let mut gpu_b = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.sigmoid_f32(&mut gpu_b, &gpu_a, size) {
            return GpuTestResult::fail(&name, &format!("sigmoid_f32: {}", e));
        }

        backend.synchronize();

        // Sigmoid may use fast approximations, so never check tighter than the
        // approximation floor. A looser user-configured tolerance is still honoured.
        let atol = config.atol.max(APPROX_ATOL);
        let rtol = config.rtol.max(APPROX_RTOL);
        check_download(&name, backend.dtoh_copy(&gpu_b), &expected, atol, rtol)
    }

    /// Checks `tanh_f32` against [`cpu_tanh`] for `size` elements.
    fn test_tanh(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("tanh_f32_{}", size);
        let a = random_vec(size, 42);
        let expected = cpu_tanh(&a);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy: {}", e)),
        };
        let mut gpu_b = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.tanh_f32(&mut gpu_b, &gpu_a, size) {
            return GpuTestResult::fail(&name, &format!("tanh_f32: {}", e));
        }

        backend.synchronize();

        // Tanh may use fast approximations; same tolerance policy as sigmoid.
        let atol = config.atol.max(APPROX_ATOL);
        let rtol = config.rtol.max(APPROX_RTOL);
        check_download(&name, backend.dtoh_copy(&gpu_b), &expected, atol, rtol)
    }

    /// Square GEMM (`n x n` times `n x n`).
    fn test_gemm_square(backend: &CudaBackend, n: usize, config: &GpuTestConfig) -> GpuTestResult {
        test_gemm_rectangular(backend, n, n, n, config)
    }

    /// Checks `gemm_f32` (`m x k` times `k x n`) against [`cpu_gemm`].
    ///
    /// Inputs are converted to column-major before upload, with leading
    /// dimensions `lda = m`, `ldb = k` and `ldc = m`. The result is converted
    /// back to row-major before comparison.
    fn test_gemm_rectangular(
        backend: &CudaBackend,
        m: usize,
        n: usize,
        k: usize,
        config: &GpuTestConfig,
    ) -> GpuTestResult {
        let name = format!("gemm_f32_{}x{}x{}", m, n, k);

        // Generate test data
        let a = random_vec(m * k, 42);
        let b = random_vec(k * n, 123);
        let expected = cpu_gemm(&a, &b, m, n, k);

        // Convert to column-major for cuBLAS
        let a_col = row_to_col_major(&a, m, k);
        let b_col = row_to_col_major(&b, k, n);

        let gpu_a = match backend.htod_copy(&a_col) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(a): {}", e)),
        };
        let gpu_b = match backend.htod_copy(&b_col) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(b): {}", e)),
        };
        let mut gpu_c = match backend.alloc::<f32>(m * n) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        // cuBLAS GEMM: C = alpha * A @ B + beta * C
        if let Err(e) = backend.gemm_f32(
            false, false, // no transpose
            m, n, k, 1.0, // alpha
            &gpu_a, m, // A, lda
            &gpu_b, k, // B, ldb
            0.0, // beta
            &mut gpu_c, m, // C, ldc
        ) {
            return GpuTestResult::fail(&name, &format!("gemm_f32: {}", e));
        }

        backend.synchronize();

        // GEMM accumulates k products per output, so allow at least the approximation floor.
        let atol = config.atol.max(APPROX_ATOL);
        let rtol = config.rtol.max(APPROX_RTOL);
        // cuBLAS wrote C in column-major order; convert back before comparing.
        let downloaded = backend
            .dtoh_copy(&gpu_c)
            .map(|result_col| col_to_row_major(&result_col, m, n));
        check_download(&name, downloaded, &expected, atol, rtol)
    }

    /// Converts a `rows x cols` row-major matrix to column-major order.
    fn row_to_col_major(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut result = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                result[j * rows + i] = data[i * cols + j];
            }
        }
        result
    }

    /// Converts a `rows x cols` column-major matrix to row-major order.
    fn col_to_row_major(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut result = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                result[i * cols + j] = data[j * rows + i];
            }
        }
        result
    }
}
667
668// =============================================================================
669// Hardware Detection
670// =============================================================================
671
/// List the GPU backends compiled into this build, as human-readable labels.
///
/// CUDA is listed only when the runtime reports it as available, and its label
/// includes the device count. Vulkan and WebGPU are listed whenever their cargo
/// feature is enabled. Metal is listed when its feature is enabled on macOS. If
/// nothing qualifies, the single entry `"None (CPU only)"` is returned.
pub fn detect_gpu_backends() -> Vec<String> {
    let mut found: Vec<String> = Vec::new();

    #[cfg(feature = "cuda")]
    {
        let cuda_ok = crate::backends::cuda::is_available();
        if cuda_ok {
            let count = crate::backends::cuda::device_count();
            found.push(format!("CUDA ({} device(s))", count));
        }
    }

    #[cfg(feature = "vulkan")]
    {
        found.push(String::from("Vulkan"));
    }

    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        found.push(String::from("Metal"));
    }

    #[cfg(feature = "wgpu")]
    {
        found.push(String::from("WebGPU"));
    }

    if found.is_empty() {
        found.push(String::from("None (CPU only)"));
    }
    found
}
708
709/// Print GPU detection information.
710pub fn print_gpu_info() {
711    println!("GPU Backend Detection");
712    println!("=====================");
713
714    let backends = detect_gpu_backends();
715    for backend in &backends {
716        println!("  - {}", backend);
717    }
718
719    #[cfg(feature = "cuda")]
720    {
721        if crate::backends::cuda::is_available() {
722            println!("\nCUDA Devices:");
723            for i in 0..crate::backends::cuda::device_count() {
724                let caps = crate::backends::cuda::get_capabilities(i);
725                println!("  [{}] {}", i, caps.name);
726                println!("      Memory: {:.1} GB", caps.total_memory as f64 / 1e9);
727                if let Some(cc) = caps.compute_capability {
728                    println!("      Compute: {}.{}", cc.0, cc.1);
729                }
730            }
731        }
732    }
733}
734
735// =============================================================================
736// Tests
737// =============================================================================
738
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_vec_reproducibility() {
        let a = random_vec(100, 42);
        let b = random_vec(100, 42);
        assert_eq!(a, b, "Same seed should produce same output");
    }

    #[test]
    fn test_random_vec_range_and_seed_sensitivity() {
        let v = random_vec(1000, 7);
        assert_eq!(v.len(), 1000);
        // Values are mapped into the half-open range [-1, 1).
        assert!(v.iter().all(|x| (-1.0..1.0).contains(x)));
        assert_ne!(v, random_vec(1000, 8), "Different seeds should differ");
        assert!(random_vec(0, 7).is_empty());
    }

    #[test]
    fn test_cpu_add() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = cpu_add(&a, &b);
        assert_eq!(c, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_cpu_mul() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = cpu_mul(&a, &b);
        assert_eq!(c, vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_cpu_scale() {
        assert_eq!(cpu_scale(&[1.0, -2.0, 0.5], 2.5), vec![2.5, -5.0, 1.25]);
    }

    #[test]
    fn test_cpu_relu() {
        let a = vec![-1.0, 0.0, 1.0, 2.0];
        let b = cpu_relu(&a);
        assert_eq!(b, vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_cpu_sigmoid_and_tanh() {
        let s = cpu_sigmoid(&[0.0, 100.0, -100.0]);
        assert_eq!(s[0], 0.5);
        assert!((s[1] - 1.0).abs() < 1e-6);
        assert!(s[2].abs() < 1e-6);

        let t = cpu_tanh(&[0.0, 20.0, -20.0]);
        assert_eq!(t[0], 0.0);
        assert!((t[1] - 1.0).abs() < 1e-6);
        assert!((t[2] + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cpu_gemm() {
        // 2x3 @ 3x2 = 2x2
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let c = cpu_gemm(&a, &b, 2, 2, 3);
        // Expected:
        // [1*1+2*3+3*5, 1*2+2*4+3*6] = [22, 28]
        // [4*1+5*3+6*5, 4*2+5*4+6*6] = [49, 64]
        assert_eq!(c, vec![22.0, 28.0, 49.0, 64.0]);
    }

    #[test]
    fn test_assert_close_pass() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.00001, 2.00001, 3.00001];
        assert!(assert_close(&a, &b, 1e-4, 1e-4).is_ok());
    }

    #[test]
    fn test_assert_close_fail() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.1, 2.0, 3.0];
        assert!(assert_close(&a, &b, 1e-4, 1e-4).is_err());
    }

    #[test]
    fn test_assert_close_length_mismatch() {
        assert!(assert_close(&[1.0, 2.0], &[1.0], 1e-4, 1e-4).is_err());
    }

    #[test]
    fn test_assert_close_reports_max_error() {
        let max_err = assert_close(&[1.0, 2.0, 3.0], &[1.25, 2.0, 2.5], 1.0, 0.0)
            .expect("differences are within atol = 1.0");
        assert_eq!(max_err, 0.5);
    }

    #[test]
    fn test_detect_backends() {
        let backends = detect_gpu_backends();
        assert!(!backends.is_empty());
    }

    #[test]
    fn test_gpu_test_result() {
        let pass = GpuTestResult::pass("test").with_error(0.0001);
        assert!(pass.passed);
        assert!(pass.max_abs_error.is_some());

        let fail = GpuTestResult::fail("test", "error");
        assert!(!fail.passed);
        assert_eq!(fail.error, Some("error".to_string()));
    }

    #[test]
    fn test_report_counts() {
        let mut report = GpuTestReport::new("CPU");
        report.add_result(GpuTestResult::pass("a"));
        report.add_result(GpuTestResult::fail("b", "boom"));
        report.add_result(GpuTestResult::pass("c"));
        assert_eq!(report.backend, "CPU");
        assert!(report.capabilities.is_none());
        assert_eq!(report.passed_count(), 2);
        assert_eq!(report.failed_count(), 1);
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_all() {
        let config = GpuTestConfig::default();
        let report = cuda_tests::run_all_tests(&config);
        report.print_summary();

        // If CUDA is available, all tests should pass
        if crate::backends::cuda::is_available() {
            assert_eq!(report.failed_count(), 0, "Some CUDA tests failed");
        }
    }
}