use super::*;
use crate::cuda::memory::{SizeClass, TransferMode};
use crate::cuda::pipeline::{
presets, AsyncPipeline, BankConflictStrategy, MemoryPattern, PtxOptimizationHints,
PtxOptimizer, RegisterTiling,
};
use proptest::prelude::*;
use serial_test::serial;
/// Returns `true` when the CUDA runtime is available and at least one
/// device is visible; used by every GPU test to skip gracefully on
/// machines without hardware.
fn has_cuda() -> bool {
    if !CudaExecutor::is_available() {
        return false;
    }
    CudaExecutor::num_devices() > 0
}
// Property-based tests. The `proptest!` macro rewrites each body to return
// `Result<(), TestCaseError>` — hence the `return Ok(())` early-outs, the
// `?` on `map_err(TestCaseError::fail)`, and `prop_assert!` instead of
// plain `assert!`. All properties skip (pass trivially) on machines without
// a CUDA device, and `#[serial]` serializes them against other GPU tests.
proptest! {
#[test]
#[serial]
// Repeated create/drop cycles of an executor must always succeed,
// i.e. no device state leaks from one lifecycle to the next.
fn prop_lifecycle_cycles_always_succeed(cycles in 1..5usize) {
if !has_cuda() {
return Ok(());
}
for i in 0..cycles {
// Executor is dropped at the end of each iteration.
let executor = CudaExecutor::new(0)
.map_err(|e| TestCaseError::fail(format!("Cycle {}: {}", i, e)))?;
prop_assert!(executor.device_name().is_ok());
}
}
#[test]
#[serial]
// For any square size in [4, 16), an all-ones GEMM must succeed and
// every output element must equal `size` (the inner dimension).
fn prop_gemm_valid_dims_succeed(size in 4..16u32) {
if !has_cuda() {
return Ok(());
}
let mut executor = CudaExecutor::new(0)
.map_err(|e| TestCaseError::fail(format!("{}", e)))?;
let n = size * size;
let a = vec![1.0f32; n as usize];
let b = vec![1.0f32; n as usize];
let mut c = vec![0.0f32; n as usize];
let result = executor.gemm(&a, &b, &mut c, size, size, size);
prop_assert!(result.is_ok(), "GEMM should succeed for {}x{}", size, size);
// ones(size,size) x ones(size,size) => every entry is the dot product
// of two all-ones vectors of length `size`.
let expected = size as f32;
for (i, &val) in c.iter().enumerate() {
prop_assert!(
(val - expected).abs() < 1e-3,
"c[{}] = {}, expected {}",
i,
val,
expected
);
}
}
#[test]
#[serial]
// Creating several executors one after another (each dropped before the
// next) must not interfere: each can run a small GEMM independently.
fn prop_sequential_executors_independent(count in 1..3usize) {
if !has_cuda() {
return Ok(());
}
for i in 0..count {
let mut executor = CudaExecutor::new(0)
.map_err(|e| TestCaseError::fail(format!("Executor {}: {}", i, e)))?;
let a = vec![1.0f32; 16];
let b = vec![1.0f32; 16];
let mut c = vec![0.0f32; 16];
let result = executor.gemm(&a, &b, &mut c, 4, 4, 4);
prop_assert!(result.is_ok(), "Executor {} GEMM failed", i);
}
}
}
#[test]
#[serial]
/// GEMM must reject any call where a buffer is smaller than its declared
/// 4x4x4 dimensions imply (16 elements each for A, B, and C).
fn test_gemm_invalid_size_always_rejected() {
    if !has_cuda() {
        return;
    }
    let mut executor = CudaExecutor::new(0).expect("test");

    // A undersized (10 < 16).
    let a_short = vec![1.0f32; 10];
    let b_full = vec![1.0f32; 16];
    let mut c_full = vec![0.0f32; 16];
    assert!(executor.gemm(&a_short, &b_full, &mut c_full, 4, 4, 4).is_err());

    // B undersized.
    let a_full = vec![1.0f32; 16];
    let b_short = vec![1.0f32; 10];
    let mut c_full = vec![0.0f32; 16];
    assert!(executor.gemm(&a_full, &b_short, &mut c_full, 4, 4, 4).is_err());

    // C undersized.
    let a_full = vec![1.0f32; 16];
    let b_full = vec![1.0f32; 16];
    let mut c_short = vec![0.0f32; 10];
    assert!(executor.gemm(&a_full, &b_full, &mut c_short, 4, 4, 4).is_err());
}
#[test]
/// PTX generated for the FP16 tensor-core GEMM kernel must declare the
/// expected entry point, its pointer parameters, a dimension parameter,
/// and shared-memory usage; the kernel name accessor must agree.
fn test_imp_1000a_fp16_tensor_core_ptx_generation() {
    let kernels = CudaKernels::new();
    let kernel = KernelType::GemmFp16TensorCore { m: 64, n: 64, k: 64 };
    let ptx = kernels.generate_ptx(&kernel);
    let required = [
        ".visible .entry gemm_wmma_fp16",
        ".param .u64 a_ptr",
        ".param .u64 b_ptr",
        ".param .u64 c_ptr",
        ".shared",
    ];
    for needle in required {
        assert!(ptx.contains(needle));
    }
    // The m dimension may be emitted under either naming scheme.
    assert!(ptx.contains(".param .u32 m") || ptx.contains("m_param"));
    assert_eq!(kernels.kernel_name(&kernel), "gemm_wmma_fp16");
}
#[test]
/// Non-square but 16-aligned dimensions (16x32x48) still produce a
/// non-empty PTX module containing the WMMA kernel.
fn test_imp_1000a_fp16_dimension_requirements() {
    let kernels = CudaKernels::new();
    let kernel = KernelType::GemmFp16TensorCore { m: 16, n: 32, k: 48 };
    let ptx = kernels.generate_ptx(&kernel);
    assert!(!ptx.is_empty());
    assert!(ptx.contains("gemm_wmma_fp16"));
}
#[test]
#[serial]
/// `gemm_fp16` accepts dimensions that are multiples of 16 and rejects
/// any dimension that is not (WMMA tile alignment requirement).
fn test_imp_1000a_fp16_gemm_alignment_validation() {
    if !has_cuda() {
        return;
    }
    let mut executor = CudaExecutor::new(0).expect("test");

    // Allocates correctly-sized all-ones/zeroed buffers for (m, n, k)
    // and runs the FP16 GEMM, returning its result.
    let mut run_fp16 = |m: u32, n: u32, k: u32| {
        let a = vec![1.0f32; (m * k) as usize];
        let b = vec![1.0f32; (k * n) as usize];
        let mut c = vec![0.0f32; (m * n) as usize];
        executor.gemm_fp16(&a, &b, &mut c, m, n, k)
    };

    // All dims 16-aligned: accepted.
    assert!(run_fp16(16, 16, 32).is_ok());
    // m, n, or k not a multiple of 16: rejected.
    assert!(run_fp16(15, 16, 32).is_err());
    assert!(run_fp16(16, 17, 32).is_err());
    assert!(run_fp16(16, 16, 33).is_err());
}
#[test]
#[serial]
/// A(ones) x B(identity) should reproduce A, so each output row sums to
/// roughly n; a tolerance of 1.0 absorbs FP16 accumulation error.
fn test_imp_1000a_fp16_gemm_correctness() {
    if !has_cuda() {
        return;
    }
    let mut executor = CudaExecutor::new(0).expect("test");
    let (m, n, k) = (16u32, 16u32, 16u32);

    let a = vec![1.0f32; (m * k) as usize];
    // B = identity matrix.
    let mut b = vec![0.0f32; (k * n) as usize];
    for d in 0..k.min(n) {
        b[(d * n + d) as usize] = 1.0;
    }
    let mut c = vec![0.0f32; (m * n) as usize];
    executor.gemm_fp16(&a, &b, &mut c, m, n, k).expect("test");

    for row in 0..m {
        let row_sum: f32 = (0..n).map(|col| c[(row * n + col) as usize]).sum();
        assert!(
            (row_sum - n as f32).abs() < 1.0,
            "Row {} sum {} != {}",
            row,
            row_sum,
            n
        );
    }
}
#[test]
/// The fused Q4_K quantized GEMM kernel's PTX must expose the expected
/// entry point and parameters, and contain the dequantize (mul.f32),
/// accumulate (add.f32), and warp-shuffle reduction instructions.
fn test_imp_1000b_q4k_fused_ptx_generation() {
    let kernels = CudaKernels::new();
    let kernel_type = KernelType::QuantizedGemm {
        m: 1,
        n: 4096,
        k: 4096,
    };
    let ptx = kernels.generate_ptx(&kernel_type);
    assert!(ptx.contains(".visible .entry q4k_gemm_fused"));
    assert!(ptx.contains(".param .u64 a_ptr"));
    assert!(ptx.contains(".param .u64 b_quant_ptr"));
    assert!(ptx.contains(".param .u64 c_ptr"));
    assert!(ptx.contains("mul.f32"), "Missing mul.f32 for dequant");
    assert!(ptx.contains("add.f32"), "Missing add.f32 for accumulate");
    // "shfl" is a substring of "shfl.down", so the previous
    // `contains("shfl") || contains("shfl.down")` disjunction was
    // redundant — the single check below is logically equivalent.
    assert!(ptx.contains("shfl"), "Missing warp shuffle for reduction");
}
#[test]
/// K must be compatible with the 32-element quantization block layout and
/// the generated PTX must contain the fused kernel entry.
fn test_imp_1000b_q4k_block_layout() {
    // Bind K once so the block-alignment invariant is checked against the
    // same value that is fed into the kernel descriptor (the original
    // asserted on a detached literal `4096 % 32`).
    let k = 4096;
    assert_eq!(k % 32, 0, "K must be a multiple of the Q4_K block size");
    let kernels = CudaKernels::new();
    let kernel_type = KernelType::QuantizedGemm { m: 1, n: 128, k };
    let ptx = kernels.generate_ptx(&kernel_type);
    assert!(!ptx.is_empty());
    assert!(ptx.contains("q4k_gemm_fused"));
}
#[test]
#[serial]
/// Smoke test: a 32x32x128 GEMM on the device completes without error.
fn test_imp_1000b_q4k_gemm_integration() {
    if !has_cuda() {
        return;
    }
    let mut executor = CudaExecutor::new(0).expect("test");
    let (m, n, k) = (32u32, 32u32, 128u32);
    let a = vec![1.0f32; (m * k) as usize];
    let b = vec![1.0f32; (k * n) as usize];
    let mut c = vec![0.0f32; (m * n) as usize];
    let result = executor.gemm(&a, &b, &mut c, m, n, k);
    assert!(result.is_ok(), "GEMM failed: {:?}", result);
}
#[test]
/// The q4k_inference preset must pass its arguments straight through into
/// a QuantizedGemm descriptor.
fn test_imp_1000b_q4k_preset() {
    let kernel = presets::q4k_inference(1, 4096, 4096);
    let (m, n, k) = match kernel {
        KernelType::QuantizedGemm { m, n, k } => (m, n, k),
        _ => panic!("Expected QuantizedGemm kernel type"),
    };
    assert_eq!(m, 1, "Batch size should be 1");
    assert_eq!(n, 4096, "Hidden dim should be 4096");
    assert_eq!(k, 4096, "K dim should be 4096");
}
#[test]
#[serial]
/// A freshly constructed AsyncPipeline is inactive with zero layers queued.
fn test_imp_1000c_async_pipeline_creation() {
    if !has_cuda() {
        return;
    }
    let ctx = CudaContext::new(0).expect("test");
    let created = AsyncPipeline::new(&ctx);
    assert!(created.is_ok(), "AsyncPipeline creation failed");
    let pipeline = created.expect("test");
    assert!(!pipeline.is_active());
    assert_eq!(pipeline.layers_queued(), 0);
}
#[test]
#[serial]
/// begin/enqueue/end lifecycle: activation flag toggles correctly and
/// layer ids are handed out sequentially from 0.
fn test_imp_1000c_async_pipeline_lifecycle() {
    if !has_cuda() {
        return;
    }
    let ctx = CudaContext::new(0).expect("test");
    let mut pipeline = AsyncPipeline::new(&ctx).expect("test");

    pipeline.begin();
    assert!(pipeline.is_active());

    // Three enqueues must yield ids 0, 1, 2 in order.
    for expected_id in 0..3 {
        assert_eq!(pipeline.enqueue_layer(), expected_id);
    }
    assert_eq!(pipeline.layers_queued(), 3);

    assert!(pipeline.end().is_ok());
    assert!(!pipeline.is_active());
}
#[test]
#[serial]
/// Synchronizing an idle pipeline (both compute and transfer streams)
/// must succeed.
fn test_imp_1000c_async_dual_stream_sync() {
    if !has_cuda() {
        return;
    }
    let ctx = CudaContext::new(0).expect("test");
    let idle_pipeline = AsyncPipeline::new(&ctx).expect("test");
    let sync_result = idle_pipeline.sync();
    assert!(sync_result.is_ok(), "Dual-stream sync failed");
}
#[test]
#[serial]
/// Both stream accessors are callable, and each returned stream can be
/// synchronized successfully.
fn test_imp_1000c_async_stream_accessors() {
    if !has_cuda() {
        return;
    }
    let ctx = CudaContext::new(0).expect("test");
    let pipeline = AsyncPipeline::new(&ctx).expect("test");
    // Exercise the accessors themselves first.
    let _ = pipeline.compute_stream();
    let _ = pipeline.transfer_stream();
    // Then verify the streams they expose are functional.
    assert!(pipeline.compute_stream().synchronize().is_ok());
    assert!(pipeline.transfer_stream().synchronize().is_ok());
}
#[test]
/// Default optimization hints: scalar memory access, a 4x4 register tile,
/// no bank-conflict mitigation, ILP disabled, and no vectorized loads
/// (vector width 1).
fn test_imp_1000d_optimization_hints_default() {
    let hints = PtxOptimizationHints::default();

    assert_eq!(hints.memory_pattern, MemoryPattern::Scalar);

    let tiling = &hints.register_tiling;
    assert_eq!(tiling.width, 4);
    assert_eq!(tiling.height, 4);

    assert_eq!(hints.bank_conflict_strategy, BankConflictStrategy::None);
    assert!(!hints.enable_ilp);
    assert!(!hints.uses_vectorized_loads());
    assert_eq!(hints.vector_width(), 1);
}
include!("proptests_imp_1000d.rs");
include!("proptests_kernels.rs");
include!("proptests_tqa012d_forward.rs");
include!("proptests_tqa013e_batched.rs");
include!("proptests_tcov001s_tcov001t_tcov001u.rs");
include!("proptests_tcov002_kernel.rs");
include!("proptests_tcov010_more.rs");