use crate::error::{RusTorchError, RusTorchResult};
use crate::gpu::cuda_kernels::CudaBuffer;
use crate::gpu::kernels::{AddKernel, GpuKernel, KernelExecutor, MatMulKernel};
use crate::gpu::DeviceType;
/// Outcome of a single validation check run against one device.
#[derive(Debug, Clone)]
pub struct ValidationResult {
/// Device the check was run for.
pub device: DeviceType,
/// Label identifying the check (for example `"ElementwiseAdd"`).
pub operation: String,
/// `true` when the check completed and its error stayed within tolerance.
pub passed: bool,
/// Debug-formatted error when the check itself returned an error; `None` otherwise.
pub error_message: Option<String>,
/// Elapsed wall-clock time of the check in milliseconds.
pub execution_time_ms: f64,
/// Largest absolute difference between produced and expected values;
/// `f32::INFINITY` when the check returned an error.
pub max_error: f32,
}
/// Runs small reference computations on each device and compares the
/// results against known-good values.
pub struct GpuValidator {
// Largest absolute error a check may report and still count as passed.
tolerance: f32,
}
impl GpuValidator {
pub fn new(tolerance: f32) -> Self {
GpuValidator { tolerance }
}
/// Runs every validation check on each candidate device that reports itself
/// as available.
///
/// The CPU is always a candidate; CUDA, Metal and OpenCL (device index 0)
/// are candidates only when the matching Cargo feature is enabled. Devices
/// for which `is_available()` returns `false` are skipped.
///
/// Returns the results in device order: element-wise add, matrix
/// multiplication, then the memory checks for that device.
pub fn validate_all_devices(&self) -> Vec<ValidationResult> {
    let candidates = vec![
        DeviceType::Cpu,
        #[cfg(feature = "cuda")]
        DeviceType::Cuda(0),
        #[cfg(feature = "metal")]
        DeviceType::Metal(0),
        #[cfg(feature = "opencl")]
        DeviceType::OpenCL(0),
    ];

    let mut collected = Vec::new();
    for device in candidates.into_iter().filter(|d| d.is_available()) {
        collected.push(self.validate_elementwise_add(device));
        collected.push(self.validate_matrix_multiplication(device));
        collected.extend(self.validate_memory_operations(device));
    }
    collected
}
/// Validates `AddKernel` on `device` by adding a vector of ones to a vector
/// of twos (1024 elements each) and comparing the output against threes.
///
/// The check passes when the kernel runs without error and the maximum
/// absolute error is within the validator's tolerance. If execution returns
/// an error, the result is marked failed, carries the debug-formatted error
/// and reports `f32::INFINITY` as its maximum error.
pub fn validate_elementwise_add(&self, device: DeviceType) -> ValidationResult {
    const LEN: usize = 1024;
    const OPERATION: &str = "ElementwiseAdd";

    // The timer starts before buffer setup, so setup cost is part of the
    // reported time.
    let timer = std::time::Instant::now();

    let lhs = vec![1.0f32; LEN];
    let rhs = vec![2.0f32; LEN];
    let mut sum = vec![0.0f32; LEN];
    let expected = vec![3.0f32; LEN];

    let executor = KernelExecutor::new(device);
    let outcome = self.execute_and_validate(
        &executor,
        &AddKernel,
        &[lhs.as_slice(), rhs.as_slice()],
        &mut [sum.as_mut_slice()],
        &expected,
    );
    let execution_time_ms = timer.elapsed().as_secs_f64() * 1000.0;

    match outcome {
        Ok(max_error) => ValidationResult {
            device,
            operation: OPERATION.to_string(),
            passed: max_error <= self.tolerance,
            error_message: None,
            execution_time_ms,
            max_error,
        },
        Err(e) => ValidationResult {
            device,
            operation: OPERATION.to_string(),
            passed: false,
            error_message: Some(format!("{:?}", e)),
            execution_time_ms,
            max_error: f32::INFINITY,
        },
    }
}
/// Validates `MatMulKernel` on `device` by multiplying a flattened 4x4
/// matrix holding the values 1 through 16 by the 4x4 identity matrix; the
/// expected output is the first matrix unchanged.
///
/// The check passes when the kernel runs without error and the maximum
/// absolute error is within the validator's tolerance. If execution returns
/// an error, the result is marked failed, carries the debug-formatted error
/// and reports `f32::INFINITY` as its maximum error.
pub fn validate_matrix_multiplication(&self, device: DeviceType) -> ValidationResult {
    const OPERATION: &str = "MatrixMultiplication";

    let timer = std::time::Instant::now();

    let dim: usize = 4;
    let cells = dim * dim;

    // Element at flat index k holds k + 1.
    let lhs: Vec<f32> = (1..=cells).map(|value| value as f32).collect();
    // Identity: ones where the row index (k / dim) equals the column index (k % dim).
    let identity: Vec<f32> = (0..cells)
        .map(|k| if k / dim == k % dim { 1.0 } else { 0.0 })
        .collect();
    let mut product = vec![0.0f32; cells];
    let expected = lhs.clone();

    let executor = KernelExecutor::new(device);
    let outcome = self.execute_and_validate(
        &executor,
        &MatMulKernel,
        &[lhs.as_slice(), identity.as_slice()],
        &mut [product.as_mut_slice()],
        &expected,
    );
    let execution_time_ms = timer.elapsed().as_secs_f64() * 1000.0;

    match outcome {
        Ok(max_error) => ValidationResult {
            device,
            operation: OPERATION.to_string(),
            passed: max_error <= self.tolerance,
            error_message: None,
            execution_time_ms,
            max_error,
        },
        Err(e) => ValidationResult {
            device,
            operation: OPERATION.to_string(),
            passed: false,
            error_message: Some(format!("{:?}", e)),
            execution_time_ms,
            max_error: f32::INFINITY,
        },
    }
}
/// Runs the memory checks that apply to `device`.
///
/// * `Cpu`, `Auto` and, when compiled in, `CoreML` / `MacHybrid` perform no
///   work and yield a single passing placeholder entry.
/// * `Cuda`, `Metal` and `OpenCL` yield the backend's memory check when the
///   matching Cargo feature is enabled, and no entries otherwise.
///
/// NOTE(review): the backend memory checks label their result with device
/// index 0 regardless of the index carried by `device`.
pub fn validate_memory_operations(&self, device: DeviceType) -> Vec<ValidationResult> {
    // Entry for devices that have no dedicated memory check.
    let placeholder = |operation: &str| ValidationResult {
        device,
        operation: operation.to_string(),
        passed: true,
        error_message: None,
        execution_time_ms: 0.0,
        max_error: 0.0,
    };

    match device {
        DeviceType::Cpu => vec![placeholder("MemoryOperations")],
        #[cfg(feature = "cuda")]
        DeviceType::Cuda(_) => vec![self.validate_cuda_memory()],
        #[cfg(not(feature = "cuda"))]
        DeviceType::Cuda(_) => Vec::new(),
        #[cfg(feature = "metal")]
        DeviceType::Metal(_) => vec![self.validate_metal_memory()],
        #[cfg(not(feature = "metal"))]
        DeviceType::Metal(_) => Vec::new(),
        #[cfg(feature = "opencl")]
        DeviceType::OpenCL(_) => vec![self.validate_opencl_memory()],
        #[cfg(not(feature = "opencl"))]
        DeviceType::OpenCL(_) => Vec::new(),
        #[cfg(any(
            feature = "coreml",
            feature = "coreml-hybrid",
            feature = "coreml-fallback"
        ))]
        DeviceType::CoreML(_) => vec![placeholder("CoreMLOperations")],
        DeviceType::Auto => vec![placeholder("AutoOperations")],
        #[cfg(feature = "mac-hybrid")]
        DeviceType::MacHybrid => vec![placeholder("MacHybridOperations")],
    }
}
/// Executes `kernel` through `executor` and returns the largest absolute
/// difference between the first output buffer and `expected`.
///
/// # Errors
///
/// Propagates any error returned by `executor.execute_kernel`.
///
/// # Returns
///
/// `Ok(max_error)` on successful execution. `f32::INFINITY` is returned as
/// the error when the comparison cannot be trusted:
///
/// * there is no output buffer,
/// * the first output buffer and `expected` differ in length (a plain `zip`
///   would silently ignore the unmatched tail), or
/// * any element difference is NaN (`f32::max` discards NaN operands, so a
///   NaN output would otherwise be reported as zero error).
fn execute_and_validate<K: GpuKernel<f32>>(
    &self,
    executor: &KernelExecutor,
    kernel: &K,
    inputs: &[&[f32]],
    outputs: &mut [&mut [f32]],
    expected: &[f32],
) -> RusTorchResult<f32> {
    executor.execute_kernel(kernel, inputs, outputs)?;

    let produced: &[f32] = match outputs.first() {
        Some(buffer) if buffer.len() == expected.len() => &**buffer,
        _ => return Ok(f32::INFINITY),
    };

    let max_error = produced
        .iter()
        .zip(expected)
        .map(|(got, want)| {
            let diff = (got - want).abs();
            // Surface NaN as an unbounded error instead of letting the
            // fold below drop it.
            if diff.is_nan() {
                f32::INFINITY
            } else {
                diff
            }
        })
        .fold(0.0f32, f32::max);
    Ok(max_error)
}
/// Checks that a 1024-element `f32` `CudaBuffer` can be allocated on CUDA
/// device 0.
///
/// Only the allocation is exercised: no data is copied to or from the
/// device, so a successful allocation is reported with a maximum error of
/// `0.0`. If allocation fails, the result is marked failed, carries the
/// debug-formatted error and reports `f32::INFINITY` as its maximum error.
///
/// TODO(review): add a host -> device -> host round trip and compare the
/// data once the `CudaBuffer` copy API is confirmed; until then this check
/// cannot detect transfer corruption.
#[cfg(feature = "cuda")]
fn validate_cuda_memory(&self) -> ValidationResult {
    let start_time = std::time::Instant::now();
    let size = 1024;

    let result = (|| -> RusTorchResult<f32> {
        // Bound to `_buffer` so the allocation lives until the closure ends
        // without triggering an unused-variable warning.
        let _buffer: CudaBuffer<f32> = CudaBuffer::new(size, 0)?;
        // Nothing is transferred, so there is no data to compare.
        Ok(0.0)
    })();

    match result {
        Ok(max_error) => ValidationResult {
            device: DeviceType::Cuda(0),
            operation: "CudaMemoryOperations".to_string(),
            passed: max_error <= self.tolerance,
            error_message: None,
            execution_time_ms: start_time.elapsed().as_secs_f64() * 1000.0,
            max_error,
        },
        Err(e) => ValidationResult {
            device: DeviceType::Cuda(0),
            operation: "CudaMemoryOperations".to_string(),
            passed: false,
            error_message: Some(format!("{:?}", e)),
            execution_time_ms: start_time.elapsed().as_secs_f64() * 1000.0,
            max_error: f32::INFINITY,
        },
    }
}
/// Checks a host -> device -> host round trip through a `MetalBuffer` on the
/// system default Metal device.
///
/// 1024 `f32` values (0.0, 1.0, ...) are copied into the buffer, copied back
/// out, and compared element-wise with the originals. The check passes when
/// the maximum absolute difference is within the validator's tolerance. A
/// NaN difference is reported as `f32::INFINITY`, because `f32::max` would
/// otherwise discard it and hide the corruption.
///
/// If no Metal device is found, or any buffer operation returns an error,
/// the result is marked failed, carries the debug-formatted error and
/// reports `f32::INFINITY` as its maximum error.
#[cfg(feature = "metal")]
fn validate_metal_memory(&self) -> ValidationResult {
    use crate::gpu::metal_kernels::MetalBuffer;

    let start_time = std::time::Instant::now();
    let size = 1024;
    let test_data: Vec<f32> = (0..size).map(|i| i as f32).collect();

    // The whole function is already gated on the `metal` feature, so no
    // further per-statement gating is needed inside it.
    let result = (|| -> RusTorchResult<f32> {
        let device = metal::Device::system_default().ok_or_else(|| {
            crate::error::RusTorchError::UnsupportedDevice(
                "No Metal device available".to_string(),
            )
        })?;
        let mut buffer = MetalBuffer::new(size, &device)?;
        buffer.copy_from_host(&test_data)?;

        let mut result_data = vec![0.0f32; size];
        buffer.copy_to_host(&mut result_data)?;

        let max_error = test_data
            .iter()
            .zip(result_data.iter())
            .map(|(sent, received)| {
                let diff = (sent - received).abs();
                if diff.is_nan() {
                    f32::INFINITY
                } else {
                    diff
                }
            })
            .fold(0.0f32, f32::max);
        Ok(max_error)
    })();

    match result {
        Ok(max_error) => ValidationResult {
            device: DeviceType::Metal(0),
            operation: "MetalMemoryOperations".to_string(),
            passed: max_error <= self.tolerance,
            error_message: None,
            execution_time_ms: start_time.elapsed().as_secs_f64() * 1000.0,
            max_error,
        },
        Err(e) => ValidationResult {
            device: DeviceType::Metal(0),
            operation: "MetalMemoryOperations".to_string(),
            passed: false,
            error_message: Some(format!("{:?}", e)),
            execution_time_ms: start_time.elapsed().as_secs_f64() * 1000.0,
            max_error: f32::INFINITY,
        },
    }
}
/// Checks a host -> device -> host round trip through an `OpenClBuffer`.
///
/// A buffer is created from 1024 `f32` values (0.0, 1.0, ...), copied back
/// to the host, and compared element-wise with the originals. The check
/// passes when the maximum absolute difference is within the validator's
/// tolerance. A NaN difference is reported as `f32::INFINITY`, because
/// `f32::max` would otherwise discard it and hide the corruption.
///
/// If any buffer operation returns an error, the result is marked failed,
/// carries the debug-formatted error and reports `f32::INFINITY` as its
/// maximum error.
#[cfg(feature = "opencl")]
fn validate_opencl_memory(&self) -> ValidationResult {
    use crate::gpu::opencl_kernels::OpenClBuffer;

    let start_time = std::time::Instant::now();
    let size = 1024;
    let test_data: Vec<f32> = (0..size).map(|i| i as f32).collect();

    let result = (|| -> RusTorchResult<f32> {
        let buffer = OpenClBuffer::from_host_data(&test_data)?;
        let mut result_data = vec![0.0f32; size];
        buffer.copy_to_host(&mut result_data)?;

        let max_error = test_data
            .iter()
            .zip(result_data.iter())
            .map(|(sent, received)| {
                let diff = (sent - received).abs();
                if diff.is_nan() {
                    f32::INFINITY
                } else {
                    diff
                }
            })
            .fold(0.0f32, f32::max);
        Ok(max_error)
    })();

    match result {
        Ok(max_error) => ValidationResult {
            device: DeviceType::OpenCL(0),
            operation: "OpenClMemoryOperations".to_string(),
            passed: max_error <= self.tolerance,
            error_message: None,
            execution_time_ms: start_time.elapsed().as_secs_f64() * 1000.0,
            max_error,
        },
        Err(e) => ValidationResult {
            device: DeviceType::OpenCL(0),
            operation: "OpenClMemoryOperations".to_string(),
            passed: false,
            error_message: Some(format!("{:?}", e)),
            execution_time_ms: start_time.elapsed().as_secs_f64() * 1000.0,
            max_error: f32::INFINITY,
        },
    }
}
/// Renders `results` as a human-readable text report.
///
/// The report starts with overall totals (test count, passed, failed,
/// success rate) followed by one section per device listing each check with
/// its status, elapsed time, maximum error and, if present, error message.
///
/// Device sections appear in the order each device first occurs in
/// `results`, so the same input always produces the same report. When
/// `results` is empty the success rate is printed as `N/A` rather than the
/// `NaN%` a 0/0 division would produce.
pub fn generate_report(&self, results: &[ValidationResult]) -> String {
    let mut report = String::new();
    report.push_str("=== GPU Kernel Validation Report ===\n\n");

    let total_tests = results.len();
    let passed_tests = results.iter().filter(|r| r.passed).count();
    let failed_tests = total_tests - passed_tests;
    report.push_str(&format!("Total Tests: {}\n", total_tests));
    report.push_str(&format!("Passed: {}\n", passed_tests));
    report.push_str(&format!("Failed: {}\n", failed_tests));
    if total_tests == 0 {
        report.push_str("Success Rate: N/A\n\n");
    } else {
        report.push_str(&format!(
            "Success Rate: {:.1}%\n\n",
            (passed_tests as f64 / total_tests as f64) * 100.0
        ));
    }

    // Group results per device. `HashMap` iteration order is arbitrary, so
    // the map only stores each device's slot in `grouped`, which keeps the
    // first-seen order used for output.
    let mut slot_of: std::collections::HashMap<DeviceType, usize> =
        std::collections::HashMap::new();
    let mut grouped: Vec<(DeviceType, Vec<&ValidationResult>)> = Vec::new();
    for result in results {
        let slot = *slot_of.entry(result.device).or_insert_with(|| {
            grouped.push((result.device, Vec::new()));
            grouped.len() - 1
        });
        grouped[slot].1.push(result);
    }

    for (device, entries) in grouped {
        report.push_str(&format!("--- {} ---\n", device));
        for result in entries {
            let status = if result.passed { "PASS" } else { "FAIL" };
            report.push_str(&format!(
                " {}: {} ({:.2}ms, max_error: {:.6})\n",
                result.operation, status, result.execution_time_ms, result.max_error
            ));
            if let Some(ref error) = result.error_message {
                report.push_str(&format!(" Error: {}\n", error));
            }
        }
        report.push('\n');
    }
    report
}
}
/// Validates all available devices with a tolerance of `1e-5` and returns
/// the individual results.
pub fn run_gpu_validation() -> Vec<ValidationResult> {
    GpuValidator::new(1e-5).validate_all_devices()
}
/// Validates all available devices with a tolerance of `1e-5` and prints the
/// resulting report to standard output.
pub fn print_gpu_validation_report() {
    let validator = GpuValidator::new(1e-5);
    let outcomes = validator.validate_all_devices();
    println!("{}", validator.generate_report(&outcomes));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used by every test in this module.
    const TOLERANCE: f32 = 1e-5;

    /// The constructor stores the tolerance it is given.
    #[test]
    fn test_gpu_validator_creation() {
        let validator = GpuValidator::new(TOLERANCE);
        assert_eq!(validator.tolerance, TOLERANCE);
    }

    /// Element-wise add on the CPU passes within tolerance.
    #[test]
    fn test_cpu_validation() {
        let outcome = GpuValidator::new(TOLERANCE).validate_elementwise_add(DeviceType::Cpu);
        assert!(outcome.passed);
        assert_eq!(outcome.device, DeviceType::Cpu);
        assert_eq!(outcome.operation, "ElementwiseAdd");
        assert!(outcome.max_error <= TOLERANCE);
    }

    /// Matrix multiplication on the CPU passes within tolerance.
    #[test]
    fn test_cpu_matrix_multiplication_validation() {
        let outcome =
            GpuValidator::new(TOLERANCE).validate_matrix_multiplication(DeviceType::Cpu);
        assert!(outcome.passed);
        assert_eq!(outcome.device, DeviceType::Cpu);
        assert_eq!(outcome.operation, "MatrixMultiplication");
        assert!(outcome.max_error <= TOLERANCE);
    }

    /// A single passing result yields the expected summary lines.
    #[test]
    fn test_validation_report_generation() {
        let single_pass = ValidationResult {
            device: DeviceType::Cpu,
            operation: "Test".to_string(),
            passed: true,
            error_message: None,
            execution_time_ms: 1.0,
            max_error: 0.0,
        };
        let report = GpuValidator::new(TOLERANCE).generate_report(&[single_pass]);
        for needle in ["Total Tests: 1", "Passed: 1", "Success Rate: 100.0%"] {
            assert!(report.contains(needle));
        }
    }
}