/// Unit tests for `CudaBackend`'s builder, sizing helpers, launch geometry and
/// PTX generation.
///
/// Compiled only under `cfg(test)` with the `cuda` feature enabled. The tests
/// assert on values returned by `CudaBackend` methods; magic expected constants
/// are annotated with the arithmetic they reduce to so they can be re-derived.
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;

    /// Backend shared by most tests: `m = 1`, `n = k = 4096`, `head_dim = 64`.
    fn baseline_backend() -> CudaBackend {
        CudaBackend::new(1, 4096, 4096, 64)
    }

    // ---- construction & builder -------------------------------------------

    #[test]
    fn test_cuda_backend_new() {
        let backend = baseline_backend();
        assert_eq!(backend.m, 1);
        assert_eq!(backend.n, 4096);
        assert_eq!(backend.k, 4096);
        assert_eq!(backend.head_dim, 64);
        // Values `new` leaves in place when the `with_*` builders are not used.
        assert_eq!(backend.num_heads, 32);
        assert_eq!(backend.max_seq_len, 2048);
    }

    #[test]
    fn test_cuda_backend_with_num_heads() {
        let backend = baseline_backend().with_num_heads(16);
        assert_eq!(backend.num_heads, 16);
    }

    #[test]
    fn test_cuda_backend_with_max_seq_len() {
        let backend = baseline_backend().with_max_seq_len(4096);
        assert_eq!(backend.max_seq_len, 4096);
    }

    #[test]
    fn test_cuda_backend_builder_chain() {
        let backend = CudaBackend::new(1, 4096, 4096, 128)
            .with_num_heads(8)
            .with_max_seq_len(1024);
        assert_eq!(backend.num_heads, 8);
        assert_eq!(backend.max_seq_len, 1024);
        // Builder calls must not disturb the constructor-supplied head_dim.
        assert_eq!(backend.head_dim, 128);
    }

    #[test]
    fn test_cuda_backend_clone() {
        let backend = CudaBackend::new(2, 1024, 2048, 64);
        let cloned = backend.clone();
        assert_eq!(cloned.m, 2);
        assert_eq!(cloned.n, 1024);
        assert_eq!(cloned.k, 2048);
        // The source value is still usable and unchanged after cloning.
        assert_eq!(backend.m, cloned.m);
    }

    #[test]
    fn test_cuda_backend_debug() {
        let backend = baseline_backend();
        let debug_str = format!("{:?}", backend);
        assert!(debug_str.contains("CudaBackend"));
        assert!(debug_str.contains("4096"));
    }

    // ---- Q4_K GEMM sizing & PTX -------------------------------------------

    #[test]
    fn test_q4k_gemm_kernel_name() {
        assert_eq!(baseline_backend().q4k_gemm_kernel_name(), "q4k_gemm_fused");
    }

    #[test]
    fn test_q4k_blocks_per_row() {
        // 4096 / 32
        assert_eq!(baseline_backend().q4k_blocks_per_row(), 128);
    }

    #[test]
    fn test_q4k_blocks_per_row_small() {
        let backend = CudaBackend::new(1, 1024, 256, 64);
        // 256 / 32
        assert_eq!(backend.q4k_blocks_per_row(), 8);
    }

    #[test]
    fn test_q4k_weight_bytes() {
        let backend = CudaBackend::new(1, 1024, 1024, 64);
        // Equals n × (k / 32) × 18 = 1024 × 32 × 18
        // (i.e. 18 bytes per 32-element block — assumed layout, confirm in impl).
        assert_eq!(backend.q4k_weight_bytes(), 589_824);
    }

    #[test]
    fn test_q4k_weight_bytes_large() {
        // Equals n × (k / 32) × 18 = 4096 × 128 × 18.
        assert_eq!(baseline_backend().q4k_weight_bytes(), 9_437_184);
    }

    #[test]
    fn test_q4k_gemm_ptx_generation() {
        let backend = CudaBackend::new(1, 1024, 256, 64);
        let ptx = backend.q4k_gemm_ptx();
        assert!(!ptx.is_empty());
        assert!(ptx.contains(".version"));
    }

    #[test]
    fn test_q4k_gemm_ptx_caching() {
        // NOTE: this only checks that repeated calls produce equal PTX; it cannot
        // observe whether the second call was actually served from a cache.
        let backend = CudaBackend::new(1, 1024, 256, 64);
        let first = backend.q4k_gemm_ptx();
        let second = backend.q4k_gemm_ptx();
        assert_eq!(first, second);
    }

    // ---- flash attention ---------------------------------------------------

    #[test]
    fn test_flash_attention_kernel_name_causal() {
        assert_eq!(
            baseline_backend().flash_attention_kernel_name(true),
            "flash_attention_causal"
        );
    }

    #[test]
    fn test_flash_attention_kernel_name_non_causal() {
        assert_eq!(
            baseline_backend().flash_attention_kernel_name(false),
            "flash_attention"
        );
    }

    #[test]
    fn test_flash_attention_smem_bytes() {
        // 64 × 768 = 49152 bytes, exactly 48 KiB (CUDA's default per-block
        // shared-memory limit without opt-in).
        assert_eq!(baseline_backend().flash_attention_smem_bytes(), 49152);
    }

    #[test]
    fn test_flash_attention_smem_bytes_large_head_dim() {
        let backend = CudaBackend::new(1, 4096, 4096, 128);
        // 128 × 768 = 98304 bytes: above 48 KiB, so a real launch needs the
        // kernel to opt in to larger dynamic shared memory (sm_89 allows up to
        // 99 KiB per block).
        assert_eq!(backend.flash_attention_smem_bytes(), 98304);
    }

    #[test]
    fn test_flash_attention_ptx_generation() {
        let ptx = baseline_backend().flash_attention_ptx(512, 64, true);
        assert!(!ptx.is_empty());
        assert!(ptx.contains(".version"));
    }

    #[test]
    fn test_flash_attention_causal_ptx_caching() {
        // NOTE: like the Q4_K variant, this only checks equality of two calls.
        let backend = baseline_backend().with_max_seq_len(512);
        let first = backend.flash_attention_causal_ptx();
        let second = backend.flash_attention_causal_ptx();
        assert_eq!(first, second);
    }

    // ---- KV cache sizing ---------------------------------------------------

    #[test]
    fn test_kv_cache_bytes_per_layer() {
        let backend = baseline_backend().with_num_heads(32).with_max_seq_len(2048);
        // Equals 2 × num_heads × head_dim × max_seq_len × 4 = 2 × 32 × 64 × 2048 × 4
        // (presumably K and V buffers of 4-byte elements — confirm in impl).
        assert_eq!(backend.kv_cache_bytes_per_layer(), 33_554_432);
    }

    #[test]
    fn test_kv_cache_total_bytes() {
        let backend = baseline_backend().with_num_heads(32).with_max_seq_len(2048);
        let per_layer = backend.kv_cache_bytes_per_layer();
        assert_eq!(backend.kv_cache_total_bytes(22), per_layer * 22);
    }

    #[test]
    fn test_kv_cache_page_tokens() {
        assert_eq!(baseline_backend().kv_cache_page_tokens(), 64);
    }

    #[test]
    fn test_kv_cache_pages_needed_exact() {
        // Exact multiples of the 64-token page size need no extra page.
        let backend = baseline_backend();
        assert_eq!(backend.kv_cache_pages_needed(64), 1);
        assert_eq!(backend.kv_cache_pages_needed(128), 2);
    }

    #[test]
    fn test_kv_cache_pages_needed_round_up() {
        // Partial pages round up: ceil(tokens / 64).
        let backend = baseline_backend();
        assert_eq!(backend.kv_cache_pages_needed(65), 2);
        assert_eq!(backend.kv_cache_pages_needed(100), 2);
        assert_eq!(backend.kv_cache_pages_needed(129), 3);
    }

    #[test]
    fn test_kv_cache_pages_needed_large() {
        // 2048 / 64
        assert_eq!(baseline_backend().kv_cache_pages_needed(2048), 32);
    }

    // ---- launch geometry ---------------------------------------------------

    #[test]
    fn test_q4k_gemm_launch_config() {
        let backend = CudaBackend::new(1, 1024, 4096, 64);
        let (grid, block) = backend.q4k_gemm_launch_config();
        // grid.x = n / 32 = 1024 / 32
        assert_eq!(grid, (32, 1, 1));
        assert_eq!(block, (1024, 1, 1));
    }

    #[test]
    fn test_flash_attention_launch_config() {
        let backend = baseline_backend().with_num_heads(8);
        let (grid, block) = backend.flash_attention_launch_config(256);
        // grid = (256 / 64, num_heads, 1)
        assert_eq!(grid, (4, 8, 1));
        // NOTE(review): 4096 = 64 × 64 exceeds CUDA's limit of 1024 threads per
        // block, so launching with this configuration would be rejected by the
        // driver. Pinned here as current behavior; the implementation should be
        // revisited.
        assert_eq!(block, (4096, 1, 1));
    }

    // ---- dimension validation ---------------------------------------------

    #[test]
    fn test_validate_dimensions_valid() {
        assert!(baseline_backend().validate_dimensions());
    }

    #[test]
    fn test_validate_dimensions_k_not_divisible() {
        let backend = CudaBackend::new(1, 4096, 4095, 64);
        assert!(!backend.validate_dimensions());
    }

    #[test]
    fn test_validate_dimensions_head_dim_not_power_of_two() {
        let backend = CudaBackend::new(1, 4096, 4096, 80);
        assert!(!backend.validate_dimensions());
    }

    #[test]
    fn test_validate_dimensions_zero_m() {
        let backend = CudaBackend::new(0, 4096, 4096, 64);
        assert!(!backend.validate_dimensions());
    }

    #[test]
    fn test_validate_dimensions_zero_n() {
        let backend = CudaBackend::new(1, 0, 4096, 64);
        assert!(!backend.validate_dimensions());
    }

    #[test]
    fn test_validate_dimensions_zero_k() {
        let backend = CudaBackend::new(1, 4096, 0, 64);
        assert!(!backend.validate_dimensions());
    }

    #[test]
    fn test_validate_dimensions_zero_head_dim() {
        let backend = CudaBackend::new(1, 4096, 4096, 0);
        assert!(!backend.validate_dimensions());
    }

    // ---- PTX target metadata ----------------------------------------------

    #[test]
    fn test_ptx_target() {
        assert_eq!(baseline_backend().ptx_target(), "sm_89");
    }

    #[test]
    fn test_ptx_version() {
        assert_eq!(baseline_backend().ptx_version(), (8, 0));
    }
}