// axonml_core/backends/gpu_tests.rs

//! GPU Backend Testing Infrastructure
//!
//! CPU reference implementations, tolerance-based comparison helpers, result
//! reporting, and (behind the `cuda` feature) a CUDA correctness-test runner
//! for validating GPU kernels.
//!
//! # File
//! `crates/axonml-core/src/backends/gpu_tests.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
16
17use crate::device::DeviceCapabilities;
18
// =============================================================================
// Test Configuration
// =============================================================================
22
/// Tunable parameters for the GPU correctness tests and benchmarks.
#[derive(Debug, Clone)]
pub struct GpuTestConfig {
    /// Absolute tolerance for comparing GPU output against the CPU reference.
    pub atol: f32,
    /// Relative tolerance for the same comparisons.
    pub rtol: f32,
    /// Vector lengths exercised by the correctness tests.
    pub test_sizes: Vec<usize>,
    /// Vector lengths intended for benchmarks.
    pub benchmark_sizes: Vec<usize>,
    /// Number of warmup iterations per benchmark.
    pub warmup_iters: usize,
    /// Number of measured iterations per benchmark.
    pub bench_iters: usize,
}

impl Default for GpuTestConfig {
    /// Tolerances 1e-5 (absolute) / 1e-4 (relative); correctness sizes from 1
    /// to 4096 (a mix of tiny, odd and power-of-two lengths); benchmark sizes
    /// growing by 4x from 1024 to 1048576; 5 warmup and 100 benchmark iterations.
    fn default() -> Self {
        let test_sizes = vec![1, 7, 16, 64, 256, 1024, 4096];
        let benchmark_sizes = vec![1 << 10, 1 << 12, 1 << 14, 1 << 16, 1 << 18, 1 << 20];

        Self {
            atol: 1.0e-5,
            rtol: 1.0e-4,
            test_sizes,
            benchmark_sizes,
            warmup_iters: 5,
            bench_iters: 100,
        }
    }
}
52
// =============================================================================
// Test Results
// =============================================================================
56
/// Outcome of one GPU test, optionally annotated with accuracy and
/// performance metrics.
#[derive(Debug, Clone)]
pub struct GpuTestResult {
    /// Test name as shown in reports.
    pub name: String,
    /// `true` if the test passed.
    pub passed: bool,
    /// Failure description; `None` for results built with [`GpuTestResult::pass`].
    pub error: Option<String>,
    /// Largest absolute difference from the reference (correctness tests).
    pub max_abs_error: Option<f32>,
    /// Throughput in GB/s (benchmarks).
    pub throughput_gbps: Option<f64>,
    /// Latency in microseconds.
    pub latency_us: Option<f64>,
}

impl GpuTestResult {
    /// A passing result named `name`, with no error and no metrics attached.
    pub fn pass(name: &str) -> Self {
        Self {
            name: name.to_owned(),
            passed: true,
            error: None,
            max_abs_error: None,
            throughput_gbps: None,
            latency_us: None,
        }
    }

    /// A failing result named `name` that carries the message `error`, with
    /// no metrics attached.
    pub fn fail(name: &str, error: &str) -> Self {
        Self {
            passed: false,
            error: Some(error.to_owned()),
            ..Self::pass(name)
        }
    }

    /// Record the maximum absolute error observed against the reference.
    ///
    /// Despite its name, this does not mark the result as failed; it only
    /// stores the accuracy metric.
    pub fn with_error(self, max_abs_error: f32) -> Self {
        Self {
            max_abs_error: Some(max_abs_error),
            ..self
        }
    }

    /// Record throughput (GB/s) and latency (microseconds) measurements.
    pub fn with_perf(self, throughput_gbps: f64, latency_us: f64) -> Self {
        Self {
            throughput_gbps: Some(throughput_gbps),
            latency_us: Some(latency_us),
            ..self
        }
    }
}
112
113/// Collection of test results.
114#[derive(Debug, Default)]
115pub struct GpuTestReport {
116    /// Backend name
117    pub backend: String,
118    /// Device capabilities
119    pub capabilities: Option<DeviceCapabilities>,
120    /// Individual test results
121    pub results: Vec<GpuTestResult>,
122}
123
124impl GpuTestReport {
125    /// Create a new report for a backend.
126    pub fn new(backend: &str) -> Self {
127        Self {
128            backend: backend.to_string(),
129            capabilities: None,
130            results: Vec::new(),
131        }
132    }
133
134    /// Set device capabilities.
135    pub fn with_capabilities(mut self, caps: DeviceCapabilities) -> Self {
136        self.capabilities = Some(caps);
137        self
138    }
139
140    /// Add a test result.
141    pub fn add_result(&mut self, result: GpuTestResult) {
142        self.results.push(result);
143    }
144
145    /// Get number of passed tests.
146    pub fn passed_count(&self) -> usize {
147        self.results.iter().filter(|r| r.passed).count()
148    }
149
150    /// Get number of failed tests.
151    pub fn failed_count(&self) -> usize {
152        self.results.iter().filter(|r| !r.passed).count()
153    }
154
155    /// Print a summary of the report.
156    pub fn print_summary(&self) {
157        println!("\n========================================");
158        println!("GPU Test Report: {}", self.backend);
159        println!("========================================");
160
161        if let Some(caps) = &self.capabilities {
162            println!("Device: {}", caps.name);
163            println!(
164                "Memory: {:.1} GB total, {:.1} GB available",
165                caps.total_memory as f64 / 1e9,
166                caps.available_memory as f64 / 1e9
167            );
168            if let Some(cc) = &caps.compute_capability {
169                println!("Compute Capability: {}.{}", cc.0, cc.1);
170            }
171            println!();
172        }
173
174        println!(
175            "Results: {} passed, {} failed",
176            self.passed_count(),
177            self.failed_count()
178        );
179        println!();
180
181        for result in &self.results {
182            let status = if result.passed { "PASS" } else { "FAIL" };
183            print!("[{}] {}", status, result.name);
184
185            if let Some(err) = &result.error {
186                print!(" - {}", err);
187            }
188            if let Some(mae) = result.max_abs_error {
189                print!(" (max_err: {:.2e})", mae);
190            }
191            if let Some(tp) = result.throughput_gbps {
192                print!(" [{:.2} GB/s]", tp);
193            }
194            if let Some(lat) = result.latency_us {
195                print!(" [{:.1} us]", lat);
196            }
197            println!();
198        }
199
200        if self.failed_count() > 0 {
201            println!("\nFailed tests:");
202            for result in self.results.iter().filter(|r| !r.passed) {
203                println!(
204                    "  - {}: {}",
205                    result.name,
206                    result.error.as_deref().unwrap_or("Unknown")
207                );
208            }
209        }
210    }
211}
212
// =============================================================================
// Test Utilities
// =============================================================================
216
/// Compare two `f32` slices element-wise for approximate equality.
///
/// A pair `(e, a)` matches when `|e - a| <= atol + rtol * max(|e|, |a|)`.
/// Non-finite values are checked explicitly, because the plain tolerance test
/// would let them through. A NaN difference never compares greater than the
/// tolerance, and an infinite operand makes the tolerance itself infinite.
/// Two NaNs match each other; an infinity matches only the same infinity.
///
/// Despite its name, this function does not panic.
///
/// # Returns
/// `Ok(max_abs_error)`: the largest `|e - a|` over all pairs. This is `0.0`
/// for empty slices, and matching NaN/infinity pairs do not contribute.
///
/// # Errors
/// A message describing the length mismatch, or the first mismatching index.
pub fn assert_close(expected: &[f32], actual: &[f32], atol: f32, rtol: f32) -> Result<f32, String> {
    if expected.len() != actual.len() {
        return Err(format!(
            "Length mismatch: expected {}, got {}",
            expected.len(),
            actual.len()
        ));
    }

    let mut max_abs_error = 0.0f32;
    for (i, (&e, &a)) in expected.iter().zip(actual).enumerate() {
        let abs_err = (e - a).abs();
        let tol = atol + rtol * e.abs().max(a.abs());

        let within_tolerance = if e.is_nan() || a.is_nan() {
            // A NaN only matches another NaN; otherwise a kernel emitting NaN would pass.
            e.is_nan() && a.is_nan()
        } else if e.is_infinite() || a.is_infinite() {
            // The tolerance is infinite here, so demand exact (same-sign) equality.
            e == a
        } else {
            abs_err <= tol
        };

        if !within_tolerance {
            return Err(format!(
                "Mismatch at index {}: expected {}, got {} (abs_err: {:.2e}, tol: {:.2e})",
                i, e, a, abs_err, tol
            ));
        }

        // `f32::max` ignores a NaN argument, so matched NaN/inf pairs (whose
        // difference is NaN) leave the running maximum untouched.
        max_abs_error = max_abs_error.max(abs_err);
    }

    Ok(max_abs_error)
}
247
/// Generate `len` reproducible pseudo-random values in `[-1, 1)`.
///
/// Uses the linear congruential generator from the C standard's example
/// `rand()` (multiplier 1103515245, increment 12345) and keeps 15 bits per
/// draw. The output therefore depends only on `seed` and is identical on
/// every platform. Every value is a multiple of 2^-14, so it is exactly
/// representable in `f32`.
pub fn random_vec(len: usize, seed: u64) -> Vec<f32> {
    const MULTIPLIER: u64 = 1_103_515_245;
    const INCREMENT: u64 = 12_345;
    const LOW_15_BITS: u64 = 0x7FFF;

    let mut values = Vec::with_capacity(len);
    let mut state = seed;
    for _ in 0..len {
        state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        let draw = (state >> 16) & LOW_15_BITS; // 0..=32767
        values.push(draw as f32 / 16384.0 - 1.0);
    }
    values
}
260
/// CPU reference for element-wise addition: `out[i] = a[i] + b[i]`.
///
/// The output has `min(a.len(), b.len())` elements; any excess in the longer
/// input is ignored.
pub fn cpu_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    let mut sums = Vec::with_capacity(a.len().min(b.len()));
    for (&lhs, &rhs) in a.iter().zip(b) {
        sums.push(lhs + rhs);
    }
    sums
}
265
/// CPU reference for element-wise multiplication: `out[i] = a[i] * b[i]`.
///
/// The output has `min(a.len(), b.len())` elements.
pub fn cpu_mul(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| lhs * rhs)
        .collect()
}
270
/// CPU reference for scaling by a constant: `out[i] = a[i] * alpha`.
pub fn cpu_scale(a: &[f32], alpha: f32) -> Vec<f32> {
    a.iter().copied().map(|value| value * alpha).collect()
}
275
/// CPU reference for ReLU: `out[i] = max(a[i], 0)`.
///
/// Uses `f32::max`, which returns the non-NaN operand, so NaN inputs map to `0.0`.
pub fn cpu_relu(a: &[f32]) -> Vec<f32> {
    let mut activated = a.to_vec();
    for value in &mut activated {
        *value = value.max(0.0);
    }
    activated
}
280
/// CPU reference for the logistic sigmoid, `1 / (1 + e^(-x))`, element-wise.
///
/// For large negative `x`, `e^(-x)` overflows to infinity and the result
/// saturates to `0.0`.
pub fn cpu_sigmoid(a: &[f32]) -> Vec<f32> {
    a.iter()
        .map(|&x| {
            let denominator = 1.0 + (-x).exp();
            1.0 / denominator
        })
        .collect()
}
285
/// CPU reference for the hyperbolic tangent, element-wise (`f32::tanh`).
pub fn cpu_tanh(a: &[f32]) -> Vec<f32> {
    a.iter().copied().map(f32::tanh).collect()
}
290
/// CPU reference for single-precision matrix multiplication `C = A @ B`.
///
/// All matrices are dense and row-major: `a` is `m x k`, `b` is `k x n`, and
/// the returned `C` is `m x n`. Each output element is accumulated in `f32`,
/// starting from `0.0` and adding products in increasing inner-index order.
///
/// # Panics
/// Panics (index out of bounds) if `a` or `b` is shorter than the dimensions
/// require and a missing element is read.
pub fn cpu_gemm(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut c = Vec::with_capacity(m * n);
    for row in 0..m {
        for col in 0..n {
            let dot = (0..k).fold(0.0f32, |acc, p| acc + a[row * k + p] * b[p * n + col]);
            c.push(dot);
        }
    }
    c
}
306
// =============================================================================
// CUDA Tests
// =============================================================================
310
/// Correctness tests for the CUDA backend (compiled only with the `cuda` feature).
///
/// Each test uploads deterministic inputs from [`random_vec`], runs a single
/// kernel, downloads the output and compares it with the matching `cpu_*`
/// reference using [`assert_close`].
#[cfg(feature = "cuda")]
pub mod cuda_tests {
    use super::*;
    use crate::backends::Backend;
    use crate::backends::cuda::{CudaBackend, is_available};

    /// Absolute-tolerance floor for kernels that use fast approximations
    /// (sigmoid, tanh) or long floating-point reductions (GEMM).
    const APPROX_ATOL: f32 = 1e-3;
    /// Relative-tolerance floor for the same kernels.
    const APPROX_RTOL: f32 = 1e-2;

    /// Run every CUDA test on device 0 and collect the outcomes.
    ///
    /// If CUDA is unavailable, or the backend for device 0 cannot be created,
    /// the report contains a single failed result explaining why.
    pub fn run_all_tests(config: &GpuTestConfig) -> GpuTestReport {
        let mut report = GpuTestReport::new("CUDA");

        if !is_available() {
            report.add_result(GpuTestResult::fail(
                "cuda_availability",
                "CUDA not available on this system",
            ));
            return report;
        }

        let backend = match CudaBackend::new(0) {
            Some(b) => b,
            None => {
                report.add_result(GpuTestResult::fail(
                    "backend_creation",
                    "Failed to create CUDA backend",
                ));
                return report;
            }
        };

        report = report.with_capabilities(backend.capabilities());

        // Memory operations
        report.add_result(test_memory_roundtrip(&backend, config));

        // Element-wise operations
        for &size in &config.test_sizes {
            report.add_result(test_add(&backend, size, config));
            report.add_result(test_mul(&backend, size, config));
            report.add_result(test_scale(&backend, size, config));
        }

        // Activation functions
        for &size in &config.test_sizes {
            report.add_result(test_relu(&backend, size, config));
            report.add_result(test_sigmoid(&backend, size, config));
            report.add_result(test_tanh(&backend, size, config));
        }

        // Matrix multiplication: square shapes, then a non-square one so that
        // m/n/k or leading-dimension mix-ups show up.
        report.add_result(test_gemm_square(&backend, 64, config));
        report.add_result(test_gemm_square(&backend, 256, config));
        report.add_result(test_gemm_rectangular(&backend, 128, 64, 96, config));

        report
    }

    /// Tolerances for approximate kernels: the configured values, but never
    /// tighter than [`APPROX_ATOL`] / [`APPROX_RTOL`].
    fn approx_tolerances(config: &GpuTestConfig) -> (f32, f32) {
        (config.atol.max(APPROX_ATOL), config.rtol.max(APPROX_RTOL))
    }

    /// Turn a device-to-host download into a test result.
    ///
    /// A download error fails the test with a `dtoh_copy:` prefix. Otherwise
    /// the data is compared with `expected` via [`assert_close`], and the
    /// maximum absolute error is recorded on success.
    fn check_download<T, E>(
        name: &str,
        expected: &[f32],
        downloaded: Result<T, E>,
        atol: f32,
        rtol: f32,
    ) -> GpuTestResult
    where
        T: std::ops::Deref<Target = [f32]>,
        E: std::fmt::Display,
    {
        match downloaded {
            Ok(result) => match assert_close(expected, &result, atol, rtol) {
                Ok(max_err) => GpuTestResult::pass(name).with_error(max_err),
                Err(e) => GpuTestResult::fail(name, &e),
            },
            Err(e) => GpuTestResult::fail(name, &format!("dtoh_copy: {}", e)),
        }
    }

    /// Upload 1024 values, download them again and require an exact match.
    fn test_memory_roundtrip(backend: &CudaBackend, _config: &GpuTestConfig) -> GpuTestResult {
        let name = "memory_roundtrip";
        let data: Vec<f32> = (0..1024).map(|i| i as f32).collect();

        match backend.htod_copy(&data) {
            Ok(gpu_data) => match backend.dtoh_copy(&gpu_data) {
                Ok(result) => {
                    if result == data {
                        GpuTestResult::pass(name)
                    } else {
                        GpuTestResult::fail(name, "Data mismatch after roundtrip")
                    }
                }
                Err(e) => GpuTestResult::fail(name, &format!("dtoh_copy failed: {}", e)),
            },
            Err(e) => GpuTestResult::fail(name, &format!("htod_copy failed: {}", e)),
        }
    }

    /// `c = a + b` on the device versus [`cpu_add`].
    fn test_add(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("add_f32_{}", size);
        let a = random_vec(size, 42);
        let b = random_vec(size, 123);
        let expected = cpu_add(&a, &b);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(a): {}", e)),
        };
        let gpu_b = match backend.htod_copy(&b) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(b): {}", e)),
        };
        let mut gpu_c = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.add_f32(&mut gpu_c, &gpu_a, &gpu_b, size) {
            return GpuTestResult::fail(&name, &format!("add_f32: {}", e));
        }

        backend.synchronize();

        check_download(&name, &expected, backend.dtoh_copy(&gpu_c), config.atol, config.rtol)
    }

    /// `c = a * b` on the device versus [`cpu_mul`].
    fn test_mul(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("mul_f32_{}", size);
        let a = random_vec(size, 42);
        let b = random_vec(size, 123);
        let expected = cpu_mul(&a, &b);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(a): {}", e)),
        };
        let gpu_b = match backend.htod_copy(&b) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(b): {}", e)),
        };
        let mut gpu_c = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.mul_f32(&mut gpu_c, &gpu_a, &gpu_b, size) {
            return GpuTestResult::fail(&name, &format!("mul_f32: {}", e));
        }

        backend.synchronize();

        check_download(&name, &expected, backend.dtoh_copy(&gpu_c), config.atol, config.rtol)
    }

    /// In-place `a *= 2.5` on the device versus [`cpu_scale`].
    fn test_scale(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("scale_f32_{}", size);
        let a = random_vec(size, 42);
        let alpha = 2.5f32;
        let expected = cpu_scale(&a, alpha);

        let mut gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy: {}", e)),
        };

        if let Err(e) = backend.scale_f32(&mut gpu_a, alpha, size) {
            return GpuTestResult::fail(&name, &format!("scale_f32: {}", e));
        }

        backend.synchronize();

        check_download(&name, &expected, backend.dtoh_copy(&gpu_a), config.atol, config.rtol)
    }

    /// ReLU on the device versus [`cpu_relu`].
    fn test_relu(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("relu_f32_{}", size);
        let a = random_vec(size, 42);
        let expected = cpu_relu(&a);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy: {}", e)),
        };
        let mut gpu_b = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.relu_f32(&mut gpu_b, &gpu_a, size) {
            return GpuTestResult::fail(&name, &format!("relu_f32: {}", e));
        }

        backend.synchronize();

        check_download(&name, &expected, backend.dtoh_copy(&gpu_b), config.atol, config.rtol)
    }

    /// Sigmoid on the device versus [`cpu_sigmoid`], with relaxed tolerances.
    fn test_sigmoid(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("sigmoid_f32_{}", size);
        let a = random_vec(size, 42);
        let expected = cpu_sigmoid(&a);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy: {}", e)),
        };
        let mut gpu_b = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.sigmoid_f32(&mut gpu_b, &gpu_a, size) {
            return GpuTestResult::fail(&name, &format!("sigmoid_f32: {}", e));
        }

        backend.synchronize();

        // Sigmoid uses fast approximations, so allow higher tolerance.
        let (atol, rtol) = approx_tolerances(config);
        check_download(&name, &expected, backend.dtoh_copy(&gpu_b), atol, rtol)
    }

    /// Tanh on the device versus [`cpu_tanh`], with relaxed tolerances.
    fn test_tanh(backend: &CudaBackend, size: usize, config: &GpuTestConfig) -> GpuTestResult {
        let name = format!("tanh_f32_{}", size);
        let a = random_vec(size, 42);
        let expected = cpu_tanh(&a);

        let gpu_a = match backend.htod_copy(&a) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy: {}", e)),
        };
        let mut gpu_b = match backend.alloc::<f32>(size) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        if let Err(e) = backend.tanh_f32(&mut gpu_b, &gpu_a, size) {
            return GpuTestResult::fail(&name, &format!("tanh_f32: {}", e));
        }

        backend.synchronize();

        // Tanh uses fast approximations.
        let (atol, rtol) = approx_tolerances(config);
        check_download(&name, &expected, backend.dtoh_copy(&gpu_b), atol, rtol)
    }

    /// `n x n` by `n x n` GEMM; see [`test_gemm_rectangular`].
    fn test_gemm_square(backend: &CudaBackend, n: usize, config: &GpuTestConfig) -> GpuTestResult {
        test_gemm_rectangular(backend, n, n, n, config)
    }

    /// `(m x k) @ (k x n)` GEMM on the device versus [`cpu_gemm`].
    ///
    /// Inputs are generated row-major, converted to column-major for
    /// `gemm_f32`, and the column-major result is converted back before the
    /// comparison.
    fn test_gemm_rectangular(
        backend: &CudaBackend,
        m: usize,
        n: usize,
        k: usize,
        config: &GpuTestConfig,
    ) -> GpuTestResult {
        let name = format!("gemm_f32_{}x{}x{}", m, n, k);

        // Row-major inputs and expected output.
        let a = random_vec(m * k, 42);
        let b = random_vec(k * n, 123);
        let expected = cpu_gemm(&a, &b, m, n, k);

        // Convert to column-major for cuBLAS.
        let a_col = row_to_col_major(&a, m, k);
        let b_col = row_to_col_major(&b, k, n);

        let gpu_a = match backend.htod_copy(&a_col) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(a): {}", e)),
        };
        let gpu_b = match backend.htod_copy(&b_col) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("htod_copy(b): {}", e)),
        };
        let mut gpu_c = match backend.alloc::<f32>(m * n) {
            Ok(d) => d,
            Err(e) => return GpuTestResult::fail(&name, &format!("alloc: {}", e)),
        };

        // cuBLAS GEMM: C = alpha * A @ B + beta * C (all column-major)
        if let Err(e) = backend.gemm_f32(
            false, // A: no transpose
            false, // B: no transpose
            m,
            n,
            k,
            1.0, // alpha
            &gpu_a,
            m, // lda: A is m x k, column-major
            &gpu_b,
            k, // ldb: B is k x n, column-major
            0.0, // beta = 0, so C = A @ B
            &mut gpu_c,
            m, // ldc: C is m x n, column-major
        ) {
            return GpuTestResult::fail(&name, &format!("gemm_f32: {}", e));
        }

        backend.synchronize();

        // GEMM accumulates long sums, so it gets the relaxed tolerances too.
        let (atol, rtol) = approx_tolerances(config);
        let downloaded = backend
            .dtoh_copy(&gpu_c)
            .map(|col_major| col_to_row_major(&col_major, m, n));
        check_download(&name, &expected, downloaded, atol, rtol)
    }

    /// Convert a row-major `rows x cols` matrix to column-major storage.
    fn row_to_col_major(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut result = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                result[j * rows + i] = data[i * cols + j];
            }
        }
        result
    }

    /// Convert a column-major `rows x cols` matrix to row-major storage.
    fn col_to_row_major(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut result = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                result[i * cols + j] = data[j * rows + i];
            }
        }
        result
    }
}
664
// =============================================================================
// Hardware Detection
// =============================================================================
668
/// List the GPU backends this build reports.
///
/// CUDA is listed, with its device count, only when
/// `crate::backends::cuda::is_available()` returns true. Vulkan, Metal (macOS
/// targets only) and WebGPU are listed whenever their cargo feature is
/// enabled; this function makes no runtime check for them.
///
/// The list is never empty: if nothing is reported, it contains the single
/// entry `"None (CPU only)"`.
pub fn detect_gpu_backends() -> Vec<String> {
    let mut detected: Vec<String> = Vec::new();

    #[cfg(feature = "cuda")]
    {
        if crate::backends::cuda::is_available() {
            let count = crate::backends::cuda::device_count();
            detected.push(format!("CUDA ({} device(s))", count));
        }
    }

    if cfg!(feature = "vulkan") {
        detected.push(String::from("Vulkan"));
    }
    if cfg!(all(feature = "metal", target_os = "macos")) {
        detected.push(String::from("Metal"));
    }
    if cfg!(feature = "wgpu") {
        detected.push(String::from("WebGPU"));
    }

    if detected.is_empty() {
        detected.push(String::from("None (CPU only)"));
    }
    detected
}
705
706/// Print GPU detection information.
707pub fn print_gpu_info() {
708    println!("GPU Backend Detection");
709    println!("=====================");
710
711    let backends = detect_gpu_backends();
712    for backend in &backends {
713        println!("  - {}", backend);
714    }
715
716    #[cfg(feature = "cuda")]
717    {
718        if crate::backends::cuda::is_available() {
719            println!("\nCUDA Devices:");
720            for i in 0..crate::backends::cuda::device_count() {
721                let caps = crate::backends::cuda::get_capabilities(i);
722                println!("  [{}] {}", i, caps.name);
723                println!("      Memory: {:.1} GB", caps.total_memory as f64 / 1e9);
724                if let Some(cc) = caps.compute_capability {
725                    println!("      Compute: {}.{}", cc.0, cc.1);
726                }
727            }
728        }
729    }
730}
731
// =============================================================================
// Tests
// =============================================================================
735
/// Unit tests for the CPU-side helpers. `test_cuda_all` also runs the full
/// CUDA suite when the `cuda` feature is enabled.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_vec_reproducibility() {
        let a = random_vec(100, 42);
        let b = random_vec(100, 42);
        assert_eq!(a, b, "Same seed should produce same output");
    }

    #[test]
    fn test_random_vec_range_and_seeds() {
        let v = random_vec(1000, 7);
        assert_eq!(v.len(), 1000);
        assert!(
            v.iter().all(|&x| (-1.0f32..1.0).contains(&x)),
            "values must lie in [-1, 1)"
        );
        assert_ne!(random_vec(16, 1), random_vec(16, 2));
        assert!(random_vec(0, 42).is_empty());
    }

    #[test]
    fn test_cpu_add() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = cpu_add(&a, &b);
        assert_eq!(c, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_cpu_mul() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = cpu_mul(&a, &b);
        assert_eq!(c, vec![4.0, 10.0, 18.0]);
    }

    #[test]
    fn test_cpu_scale() {
        assert_eq!(cpu_scale(&[1.0, -2.0, 0.5], 2.5), vec![2.5, -5.0, 1.25]);
    }

    #[test]
    fn test_cpu_relu() {
        let a = vec![-1.0, 0.0, 1.0, 2.0];
        let b = cpu_relu(&a);
        assert_eq!(b, vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_cpu_sigmoid() {
        // exp(100) overflows f32 to infinity, so the -100 input saturates to exactly 0.
        assert_eq!(cpu_sigmoid(&[0.0, 100.0, -100.0]), vec![0.5, 1.0, 0.0]);
    }

    #[test]
    fn test_cpu_tanh() {
        let t = cpu_tanh(&[-0.5, 0.0, 0.5]);
        assert_eq!(t[1], 0.0);
        assert!(t[0] < 0.0 && t[2] > 0.0);
        assert!(t.iter().all(|x| x.abs() < 1.0));
    }

    #[test]
    fn test_cpu_gemm() {
        // 2x3 @ 3x2 = 2x2
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let c = cpu_gemm(&a, &b, 2, 2, 3);
        // Expected:
        // [1*1+2*3+3*5, 1*2+2*4+3*6] = [22, 28]
        // [4*1+5*3+6*5, 4*2+5*4+6*6] = [49, 64]
        assert_eq!(c, vec![22.0, 28.0, 49.0, 64.0]);
    }

    #[test]
    fn test_cpu_gemm_rectangular() {
        // 3x1 @ 1x2 = 3x2 outer product; exercises m != n != k indexing.
        let c = cpu_gemm(&[1.0, 2.0, 3.0], &[4.0, 5.0], 3, 2, 1);
        assert_eq!(c, vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0]);
    }

    #[test]
    fn test_assert_close_pass() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.00001, 2.00001, 3.00001];
        assert!(assert_close(&a, &b, 1e-4, 1e-4).is_ok());
    }

    #[test]
    fn test_assert_close_fail() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.1, 2.0, 3.0];
        assert!(assert_close(&a, &b, 1e-4, 1e-4).is_err());
    }

    #[test]
    fn test_assert_close_length_mismatch() {
        assert!(assert_close(&[1.0], &[1.0, 2.0], 1e-4, 1e-4).is_err());
    }

    #[test]
    fn test_assert_close_reports_max_error() {
        let max_err = assert_close(&[0.0, 1.0], &[0.5, 1.0], 1.0, 0.0).unwrap();
        assert_eq!(max_err, 0.5);
    }

    #[test]
    fn test_detect_backends() {
        let backends = detect_gpu_backends();
        assert!(!backends.is_empty());
    }

    #[test]
    fn test_gpu_test_result() {
        let pass = GpuTestResult::pass("test").with_error(0.0001);
        assert!(pass.passed);
        assert!(pass.max_abs_error.is_some());

        let fail = GpuTestResult::fail("test", "error");
        assert!(!fail.passed);
        assert_eq!(fail.error, Some("error".to_string()));

        let perf = GpuTestResult::pass("bench").with_perf(12.5, 3.0);
        assert_eq!(perf.throughput_gbps, Some(12.5));
        assert_eq!(perf.latency_us, Some(3.0));
    }

    #[test]
    fn test_report_counts() {
        let mut report = GpuTestReport::new("mock");
        assert_eq!((report.passed_count(), report.failed_count()), (0, 0));

        report.add_result(GpuTestResult::pass("a"));
        report.add_result(GpuTestResult::fail("b", "boom"));
        report.add_result(GpuTestResult::pass("c").with_perf(1.0, 2.0));

        assert_eq!(report.passed_count(), 2);
        assert_eq!(report.failed_count(), 1);
        assert_eq!(report.backend, "mock");
        assert!(report.capabilities.is_none());
        // Smoke test: printing without capabilities and with failures must not panic.
        report.print_summary();
    }

    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_all() {
        let config = GpuTestConfig::default();
        let report = cuda_tests::run_all_tests(&config);
        report.print_summary();

        // If CUDA is available, all tests should pass
        if crate::backends::cuda::is_available() {
            assert_eq!(report.failed_count(), 0, "Some CUDA tests failed");
        }
    }
}