//! SciRS2 GPU foundation example: backend detection, context creation,
//! buffer round-trips, kernel lookup and specialization, and error handling.
//! The GPU paths are gated behind the `gpu` feature (see `main` below).

#[cfg(feature = "gpu")]
use scirs2_core::gpu::kernels::{DataType, KernelParams};
#[cfg(feature = "gpu")]
use scirs2_core::gpu::{GpuBackend, GpuContext};
use scirs2_core::ndarray_ext::Array2;

#[allow(dead_code)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== SciRS2 GPU Foundation Example ===\n");

    #[cfg(feature = "gpu")]
    run_gpu_foundation_demo()?;

    #[cfg(not(feature = "gpu"))]
    println!("GPU feature not enabled. Run with --features=\"gpu\" to see the GPU examples.");

    Ok(())
}
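
// A minimal sketch of explicit backend selection with a first-available
// fallback order. It relies only on `GpuBackend::is_available()`, which the
// demo below also uses; the preference order itself is an illustrative
// assumption, and `GpuBackend::preferred()` is the library's own equivalent.
#[cfg(feature = "gpu")]
#[allow(dead_code)]
fn pick_backend() -> GpuBackend {
    [
        GpuBackend::Cuda,
        GpuBackend::Metal,
        GpuBackend::Wgpu,
        GpuBackend::OpenCL,
    ]
    .into_iter()
    .find(|b| b.is_available())
    .unwrap_or(GpuBackend::Cpu)
}
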
#[cfg(feature = "gpu")]
#[allow(dead_code)]
fn run_gpu_foundation_demo() -> Result<(), Box<dyn std::error::Error>> {
    // 1. Probe every compiled-in backend and report which ones this machine
    //    can actually use.
    println!("1. GPU Backend Detection");
    println!("------------------------");
    let backends = [
        GpuBackend::Cuda,
        GpuBackend::Wgpu,
        GpuBackend::Metal,
        GpuBackend::OpenCL,
        GpuBackend::Cpu,
    ];
    for backend in backends.iter() {
        println!(
            "Backend {}: Available = {}",
            backend,
            backend.is_available()
        );
    }

    let preferred = GpuBackend::preferred();
    println!("Preferred backend: {}", preferred);
println!("\n2. GPU Context Creation");
println!("-----------------------");
let ctx = match GpuContext::new(preferred) {
Ok(ctx) => {
println!(
"✓ Successfully created GPU context with {} backend",
ctx.backend_name()
);
ctx
}
Err(e) => {
println!("✗ Failed to create GPU context: {}", e);
println!("Falling back to CPU backend...");
GpuContext::new(GpuBackend::Cpu)?
}
};
println!("\n3. Memory Management");
println!("--------------------");
let size = 1024;
let host_data: Vec<f32> = (0..size).map(|i| (i as f32).sin()).collect();
let buffer = ctx.create_buffer_from_slice(&host_data);
println!("✓ Created GPU buffer with {} elements", buffer.len());
let retrieved_data = buffer.to_vec();
let data_matches = host_data
.iter()
.zip(retrieved_data.iter())
.all(|(a, b)| (a - b).abs() < 1e-6);
if data_matches {
println!("✓ Buffer data transfer successful");
} else {
println!("✗ Buffer data transfer failed");
}
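
    // Hedged sketch: upload a 2-D ndarray by viewing its contiguous storage
    // as a flat slice. `as_slice()` returning `Some` is an assumption that
    // holds for the standard-layout array constructed here.
    let grid = Array2::from_shape_fn((8, 8), |(i, j)| (i * 8 + j) as f32);
    if let Some(flat) = grid.as_slice() {
        let grid_buffer = ctx.create_buffer_from_slice(flat);
        println!(
            "✓ Uploaded 8x8 array as {} contiguous elements",
            grid_buffer.len()
        );
    }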
println!("\n4. Kernel Registry");
println!("------------------");
let kernel_names = ["gemm_standard", "sum_kernel", "relu_kernel"];
for name in kernel_names.iter() {
match ctx.get_kernel(name) {
Ok(kernel) => {
println!("✓ Found kernel: {}", name);
kernel.set_f32("test_param", 1.0);
kernel.dispatch([1, 1, 1]);
println!(" - Kernel executed successfully");
}
Err(e) => {
println!("✗ Kernel '{}' not found: {}", name, e);
}
}
}
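
    // Hedged sketch: derive a 1-D dispatch size from the element count,
    // assuming a workgroup width of 256 threads (an illustrative choice, not
    // a library constant). A real launch would then look like
    // `kernel.dispatch([groups as u32, 1, 1])`.
    let workgroup_size = 256usize;
    let groups = (size + workgroup_size - 1) / workgroup_size; // ceiling division
    println!(
        "Dispatch geometry for {} elements: {} workgroups of {} threads",
        size, groups, workgroup_size
    );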
println!("\n5. Kernel Specialization");
println!("-------------------------");
let params = KernelParams::new(DataType::Float32)
.with_input_dims(vec![128, 256])
.with_output_dims(vec![128, 512])
.with_numeric_param("alpha", 1.0)
.with_numeric_param("beta", 0.0);
match ctx.get_specialized_kernel("gemm_standard", ¶ms) {
Ok(_kernel) => {
println!("✓ Created specialized GEMM kernel for 128x256 * 256x512 matrices");
println!(" - Kernel compiled successfully");
println!(" - Optimized for the specified matrix dimensions");
println!(" - Ready for execution");
}
Err(e) => {
println!("✗ Failed to create specialized kernel: {}", e);
}
}
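
    // Hedged sketch: request a second Float32 specialization with different
    // shapes, reusing only the builder API shown above. Whether the registry
    // caches or recompiles per shape is an assumption about its internals.
    let params_small = KernelParams::new(DataType::Float32)
        .with_input_dims(vec![64, 64])
        .with_output_dims(vec![64, 64]);
    if ctx
        .get_specialized_kernel("gemm_standard", &params_small)
        .is_ok()
    {
        println!("✓ Also specialized GEMM for 64x64 * 64x64 matrices");
    }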
println!("\n6. Matrix Operations Example");
println!("-----------------------------");
let a = Array2::from_shape_fn((4, 4), |(i, j)| (i * 4 + j) as f32);
let b = Array2::from_shape_fn((4, 4), |(i, j)| ((i + j) as f32).cos());
println!("Matrix A (4x4):");
for row in a.outer_iter() {
println!(" {:?}", row.to_vec());
}
println!("Matrix B (4x4):");
for row in b.outer_iter() {
println!(" {:?}", row.to_vec());
}
let c = a.dot(&b);
println!("Result C = A * B (4x4) [CPU computation]:");
for row in c.outer_iter() {
println!(
" {:?}",
row.iter().map(|x| format!("{:.2}", x)).collect::<Vec<_>>()
);
}
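
    // Sanity check: recompute C[0][0] as a manual dot product of A's first
    // row with B's first column and compare against `dot` within f32
    // rounding tolerance.
    let c00: f32 = (0..4).map(|k| a[[0, k]] * b[[k, 0]]).sum();
    assert!(
        (c00 - c[[0, 0]]).abs() < 1e-4,
        "CPU GEMM spot check failed"
    );
    println!("✓ Spot-checked C[0][0] = {:.4} against a manual dot product", c00);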
println!("\n7. Error Handling");
println!("-----------------");
match ctx.get_kernel("nonexistent_kernel") {
Ok(_) => println!("✗ Unexpected success for nonexistent kernel"),
Err(e) => println!("✓ Proper error handling: {}", e),
}
match GpuContext::new(GpuBackend::Metal) {
Ok(_) => println!("✓ Metal backend available"),
Err(e) => println!("✓ Proper error for unavailable backend: {}", e),
}
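
    // Hedged sketch: rather than relying on the error path above,
    // backend-specific work can be gated on `is_available()` so the demo
    // degrades gracefully on machines without that backend.
    if !GpuBackend::Metal.is_available() {
        println!("  (skipping Metal-specific paths on this machine)");
    }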
println!("\n=== GPU Foundation Demo Complete ===");
println!("The GPU foundation provides:");
println!("• Multi-backend support (CUDA, WebGPU, Metal, OpenCL, CPU fallback)");
println!("• Kernel library with optimized implementations");
println!("• Automatic kernel specialization");
println!("• Memory management abstraction");
println!("• Comprehensive error handling");
println!("\nNext steps for production:");
println!("• Implement actual CUDA/WebGPU/Metal backends");
println!("• Add JIT compilation for dynamic kernels");
println!("• Integrate with linear algebra operations");
println!("• Add performance monitoring and auto-tuning");
Ok(())
}