use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_gpu_operation_dispatcher() {
let dispatcher = super::dispatcher::GpuOperationDispatcher::<f64>::new();
assert_eq!(
dispatcher.threshold(),
super::dispatcher::DEFAULT_GPU_THRESHOLD
);
let mut dispatcher = super::dispatcher::GpuOperationDispatcher::<f64>::with_threshold(1000);
assert_eq!(dispatcher.threshold(), 1000);
dispatcher.set_threshold(2000);
assert_eq!(dispatcher.threshold(), 2000);
}
#[test]
fn test_cpu_fallback_operations() {
let dispatcher = super::dispatcher::GpuOperationDispatcher::<f64>::new();
let a = array![[1.0, 2.0], [3.0, 4.0]];
let x = array![1.0, 2.0];
let result = dispatcher
.cpu_matvec(&a.view(), &x.view())
.expect("Test: operation failed");
assert_eq!(result, array![5.0, 11.0]);
let b = array![[1.0, 0.0], [0.0, 1.0]];
let result = dispatcher
.cpu_matmul(&a.view(), &b.view())
.expect("Test: operation failed");
assert_eq!(result, a);
}
#[test]
fn test_kernel_manager() {
let mut manager = super::kernels::GpuKernelManager::new();
manager
.load_kernel("test_kernel", "kernel void test() {}")
.expect("Test: operation failed");
let kernels = manager.list_kernels();
assert!(kernels.contains(&"test_kernel".to_string()));
}
#[test]
fn test_performance_profiler() {
let mut profiler = super::profiling::GpuPerformanceProfiler::new();
profiler.record("matmul", 0.1);
profiler.record("matmul", 0.2);
profiler.record("matvec", 0.05);
let avg_matmul = profiler
.average_time("matmul")
.expect("Test: operation failed");
assert!((avg_matmul - 0.15).abs() < 1e-10);
assert_eq!(profiler.best_time("matmul"), Some(0.1));
assert_eq!(profiler.average_time("matvec"), Some(0.05));
let ops = profiler.operations();
assert!(ops.contains(&"matmul"));
assert!(ops.contains(&"matvec"));
profiler.clear();
assert!(profiler.operations().is_empty());
}
#[test]
fn test_batch_size_optimizer() {
let mut optimizer = super::optimization::BatchSizeOptimizer::new(1024 * 1024);
let optimal = optimizer.optimize_batchsize("matrix_multiply", 1000);
assert!(optimal > 0 && optimal <= 1000);
let record = super::optimization::BatchPerformanceRecord {
operation: "matrix_multiply".to_string(),
batchsize: 64,
execution_time: 0.1,
memory_usage: 1000,
throughput: 640.0,
};
optimizer.record_performance(record);
let history = optimizer.get_performance_history("matrix_multiply");
assert_eq!(history.len(), 1);
}
#[test]
fn test_metrics() {
let mut metrics = super::metrics::MultiDimensionalMetrics::new();
metrics.record_execution_time("matmul", 0.1, 0.05, 0.01);
metrics.record_memory_metrics("matmul", 1000.0, 500.0, 0.8);
metrics.record_energy_metrics("matmul", 10.0, 25.0);
metrics.record_throughput_metrics("matmul", 100.0, 1000.0, 50.0);
assert!(metrics.get_time_metrics("matmul").is_some());
assert!(metrics.get_memory_metrics("matmul").is_some());
assert!(metrics.get_energy_metrics("matmul").is_some());
assert!(metrics.get_throughput_metrics("matmul").is_some());
let ops = metrics.get_operations();
assert!(ops.contains(&"matmul"));
}
#[test]
fn test_hardware_profiler() {
let mut profiler = super::hardware::HardwareCapabilityProfiler::new();
let profile = super::hardware::DeviceProfile {
peak_flops_sp: 1000.0,
peak_flops_dp: 500.0,
memory_bandwidth: 100.0,
l1_cachesize: 32 * 1024,
l2_cachesize: 1024 * 1024,
shared_memory: 32 * 1024,
register_count: 32768,
tensor_core_support: true,
mixed_precision_support: true,
};
profiler.add_device_profile("gpu_0".to_string(), profile.clone());
assert!(profiler.get_device_profile("gpu_0").is_some());
profiler.record_benchmark("gpu_0".to_string(), "matmul".to_string(), 800.0);
assert_eq!(profiler.get_benchmark("gpu_0", "matmul"), Some(800.0));
profiler.add_capability_flag("gpu_0".to_string(), "tensor_cores".to_string());
assert!(profiler.has_capability("gpu_0", "tensor_cores"));
let devices = profiler.get_available_devices();
assert!(devices.contains(&"gpu_0"));
}
#[test]
fn test_running_stats() {
let mut stats = super::metrics::RunningStats::new();
stats.update(1.0);
stats.update(2.0);
stats.update(3.0);
assert_eq!(stats.count, 3);
assert_eq!(stats.mean, 2.0);
assert_eq!(stats.min, 1.0);
assert_eq!(stats.max, 3.0);
assert!((stats.std_dev() - (1.0_f64)).abs() < 1e-10);
}