use crate::driver::{CublasHandle, CudaContext, CudaStream, GpuBuffer};
/// Exercises the full create → configure → destroy path of a cuBLAS handle.
///
/// Verifies that a handle can be created against a fresh CUDA context,
/// bound to a newly created stream, and dropped without error.
#[test]
fn test_cublas_handle_lifecycle() {
    let ctx = CudaContext::new(0).expect("CUDA context required");
    let stream = CudaStream::new(&ctx).expect("stream required");
    let handle = CublasHandle::new(&ctx).expect("cuBLAS handle creation must succeed");
    handle.set_stream(&stream).expect("set_stream must succeed");
    // Explicit drop makes the destruction half of the lifecycle visible.
    drop(handle);
}
/// Checks a 2x2 FP32 row-major GEMM against a hand-computed product.
///
/// C = A * B with A = [[1,2],[3,4]] and B = [[5,6],[7,8]], so the expected
/// row-major result is [19, 22, 43, 50].
#[test]
fn test_cublas_gemm_f32_small() {
    let ctx = CudaContext::new(0).expect("CUDA context required");
    let handle = CublasHandle::new(&ctx).expect("cuBLAS handle required");
    let stream = CudaStream::new(&ctx).expect("stream required");
    handle.set_stream(&stream).expect("set_stream must succeed");

    // Row-major 2x2 operands and a zeroed output matrix.
    let host_a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let host_b: Vec<f32> = vec![5.0, 6.0, 7.0, 8.0];
    let host_c: Vec<f32> = vec![0.0; 4];
    let dev_a = GpuBuffer::from_host(&ctx, &host_a).expect("A upload");
    let dev_b = GpuBuffer::from_host(&ctx, &host_b).expect("B upload");
    let mut dev_c = GpuBuffer::from_host(&ctx, &host_c).expect("C upload");

    // C = 1.0 * A * B + 0.0 * C with m = n = k = 2.
    handle
        .gemm_f32_row_major(2, 2, 2, 1.0, dev_a.as_ptr(), dev_b.as_ptr(), 0.0, dev_c.as_ptr())
        .expect("gemm_f32_row_major must succeed");
    stream.synchronize().expect("sync");

    let mut result = vec![0.0f32; 4];
    dev_c.copy_to_host(&mut result).expect("D2H");

    // Hand-computed reference product, row-major order.
    let expected = [19.0f32, 22.0, 43.0, 50.0];
    for (i, (&got, &want)) in result.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - want).abs() < 1e-3,
            "C[{},{}] = {got} expected {want}",
            i / 2,
            i % 2
        );
    }
}
/// Times FP16 GEMM at a representative training shape and verifies both
/// numerical correctness and a minimum throughput floor.
///
/// A (m x k) and B (k x n) are filled with FP16 1.0, so every element of
/// C = A*B must be exactly k = 1024.0 (0x6400 in FP16; 1024 is a power of
/// two and exactly representable in half precision). After 5 warmup
/// launches, 100 timed iterations must sustain more than 50 TFLOP/s.
#[test]
fn test_cublas_gemm_f16_training_shape() {
    let ctx = CudaContext::new(0).expect("CUDA context required");
    let handle = CublasHandle::new(&ctx).expect("cuBLAS handle required");
    let stream = CudaStream::new(&ctx).expect("stream required");
    handle.set_stream(&stream).expect("set_stream must succeed");
    let m: usize = 4096;
    let k: usize = 1024;
    let n: usize = 4096;
    // 0x3C00 is 1.0 in IEEE 754 half precision.
    let fp16_one: u16 = 0x3C00;
    let a_data: Vec<u16> = vec![fp16_one; m * k];
    let b_data: Vec<u16> = vec![fp16_one; k * n];
    let c_data: Vec<u16> = vec![0u16; m * n];
    let a_buf = GpuBuffer::from_host(&ctx, &a_data).expect("A upload");
    let b_buf = GpuBuffer::from_host(&ctx, &b_data).expect("B upload");
    let mut c_buf = GpuBuffer::from_host(&ctx, &c_data).expect("C upload");
    // Single GEMM dispatch (C = 1.0 * A * B + 0.0 * C), shared between the
    // warmup and timed loops so both provably launch the identical call.
    let launch = || {
        handle.gemm_f16_row_major(
            m as i32,
            n as i32,
            k as i32,
            1.0,
            a_buf.as_ptr(),
            b_buf.as_ptr(),
            0.0,
            c_buf.as_ptr(),
        )
    };
    for _ in 0..5 {
        launch().expect("warmup GEMM");
    }
    stream.synchronize().expect("warmup sync");
    let iters = 100;
    let start = std::time::Instant::now();
    for _ in 0..iters {
        launch().expect("timed GEMM");
    }
    // Synchronize before reading the clock so queued kernels are included.
    stream.synchronize().expect("timed sync");
    let elapsed = start.elapsed();
    let mut result = vec![0u16; m * n];
    c_buf.copy_to_host(&mut result).expect("D2H");
    // 0x6400 is 1024.0 in FP16.
    let expected_fp16: u16 = 0x6400;
    assert_eq!(
        result[0], expected_fp16,
        "C[0,0] should be 1024.0 (0x6400), got 0x{:04X}",
        result[0]
    );
    assert_eq!(
        result[m * n - 1],
        expected_fp16,
        "C[last] should be 1024.0 (0x6400), got 0x{:04X}",
        result[m * n - 1]
    );
    let flops_per_gemm = 2.0 * m as f64 * n as f64 * k as f64;
    let total_flops = flops_per_gemm * iters as f64;
    let tflops = total_flops / elapsed.as_secs_f64() / 1e12;
    // Fix: `{:.1}` precision is silently ignored for integer Display impls
    // (`as_millis` returns u128), so derive fractional ms from the f64 form.
    eprintln!(
        "cuBLAS FP16 GEMM [{m}x{k}] x [{k}x{n}]: {tflops:.1} TFLOP/s ({iters} iters, {:.1}ms)",
        elapsed.as_secs_f64() * 1e3
    );
    assert!(tflops > 50.0, "cuBLAS FP16 GEMM must exceed 50 TFLOP/s, got {tflops:.1} TFLOP/s");
}
/// Runs FP16 GEMM across every matmul shape used in a training step and
/// spot-checks that each produces a nonzero C[0,0].
///
/// With both operands filled with FP16 1.0, C[0,0] should accumulate to ~k.
/// The exact FP16 bit pattern depends on the accumulation/compute type
/// (not visible here — TODO confirm against the gemm_f16_row_major impl),
/// so only nonzero-ness is asserted while the observed value is logged.
#[test]
fn test_cublas_all_training_shapes() {
    let ctx = CudaContext::new(0).expect("CUDA context required");
    let handle = CublasHandle::new(&ctx).expect("cuBLAS handle required");
    let stream = CudaStream::new(&ctx).expect("stream required");
    handle.set_stream(&stream).expect("set_stream must succeed");
    // (m, n, k, name) for each GEMM shape exercised during training.
    let shapes: Vec<(usize, usize, usize, &str)> = vec![
        (4096, 3072, 1024, "attn_qkv"),
        (1024, 3072, 4096, "attn_qkv_backward"),
        (4096, 1024, 1024, "attn_output"),
        (4096, 8192, 1024, "ffn_up_gate"),
        (4096, 1024, 4096, "ffn_down"),
        (4096, 256, 1024, "gqa_kv"),
    ];
    // 0x3C00 is 1.0 in IEEE 754 half precision.
    let fp16_one: u16 = 0x3C00;
    for (m, n, k, name) in &shapes {
        let a = GpuBuffer::from_host(&ctx, &vec![fp16_one; m * k]).expect("A");
        let b = GpuBuffer::from_host(&ctx, &vec![fp16_one; k * n]).expect("B");
        let c = GpuBuffer::from_host(&ctx, &vec![0u16; m * n]).expect("C");
        handle
            .gemm_f16_row_major(
                *m as i32,
                *n as i32,
                *k as i32,
                1.0,
                a.as_ptr(),
                b.as_ptr(),
                0.0,
                c.as_ptr(),
            )
            // `unwrap_or_else` avoids building the panic message on the
            // success path (clippy::expect_fun_call).
            .unwrap_or_else(|e| panic!("GEMM {name} [{m}x{k}] x [{k}x{n}] must succeed: {e:?}"));
        stream.synchronize().expect("sync");
        // Read back only C[0,0] via a 1-element view over `c`'s device memory.
        // SAFETY: `c` holds m*n >= 1 elements and outlives this view, so the
        // pointer is valid for a 1-element buffer for the view's whole life.
        let first_elem_view = unsafe { GpuBuffer::<u16>::from_raw_parts(c.as_ptr(), 1) };
        let mut check = vec![0u16; 1];
        first_elem_view.copy_to_host(&mut check).expect("D2H check");
        // The view aliases memory owned by `c`; forget it so its Drop does
        // not free the allocation a second time.
        std::mem::forget(first_elem_view);
        // All-ones operands must produce a nonzero dot product regardless of
        // the accumulation precision.
        assert_ne!(
            check[0], 0,
            "Shape {name}: C[0,0] must be nonzero after GEMM of all-ones operands"
        );
        eprintln!(
            "Shape {name} [{m}x{k}] x [{k}x{n}]: C[0,0] = 0x{:04X} (expected ~{k}.0)",
            check[0]
        );
    }
}