#![cfg(feature = "cuda")]
use trueno_gpu::ptx::optimize::tile_validation::{
validate_shape, validate_wmma_shape, TileError, MAX_TILE_DIM, MAX_TILE_ELEMENTS,
};
use trueno_gpu::ptx::{PtxInstruction, PtxOp, PtxType, WmmaShape};
#[test]
fn f034_shared_memory_sizing() {
    /// Hypothesis F034: an optimal shared-memory tile fits within L1 cache.
    // Heuristic: split a 48 KiB L1 three ways and take the square root to
    // estimate a square tile edge in elements.
    const L1_CACHE_SIZE: usize = 48 * 1024;
    let smem_edge = ((L1_CACHE_SIZE / 3) as f64).sqrt() as usize;
    assert!(
        (64..=16384).contains(&smem_edge),
        "F034 FALSIFIED: Optimal shared memory {} not in valid range",
        smem_edge
    );
    // A 32x32 tile must pass shape validation.
    let tile_edge: u32 = 32;
    let element_count = tile_edge * tile_edge;
    assert!(
        validate_shape(&[tile_edge as usize, tile_edge as usize]).is_ok(),
        "F034 FALSIFIED: Tile {}x{} should be valid",
        tile_edge,
        tile_edge
    );
    // 4 bytes per f32 element; the whole tile must fit in L1.
    let byte_footprint = element_count as usize * 4;
    assert!(
        byte_footprint <= L1_CACHE_SIZE,
        "F034 FALSIFIED: Shared memory {} exceeds L1 cache {}",
        byte_footprint,
        L1_CACHE_SIZE
    );
    println!(
        "F034 PASSED: Shared memory sizing (optimal={}, tile={}x{}, bytes={})",
        smem_edge, tile_edge, tile_edge, byte_footprint
    );
}
#[test]
fn f035_coalesced_vs_strided_bandwidth() {
    /// Hypothesis F035: strided access costs several times the bandwidth of
    /// coalesced access, because each lane touches its own cache line.
    const WARP_SIZE: usize = 32;
    const ELEMENT_SIZE: usize = 4;
    const CACHE_LINE_SIZE: usize = 128;
    // One coalesced warp transaction: 32 lanes * 4 bytes = one cache line.
    let coalesced_per_warp = WARP_SIZE * ELEMENT_SIZE;
    assert_eq!(
        coalesced_per_warp, CACHE_LINE_SIZE,
        "F035: Coalesced access should fill one cache line"
    );
    // Worst-case strided access: every lane pulls a distinct cache line.
    let lines_touched = WARP_SIZE;
    let strided_per_warp = lines_touched * CACHE_LINE_SIZE;
    let ratio = strided_per_warp / coalesced_per_warp;
    assert!(
        ratio >= 4,
        "F035 FALSIFIED: Bandwidth ratio {} should be >= 4",
        ratio
    );
    assert!(
        WARP_SIZE.is_power_of_two(),
        "F035: Warp size must be power of two"
    );
    println!(
        "F035 PASSED: Coalesced vs strided bandwidth ratio = {}x",
        ratio
    );
}
#[test]
fn f036_power_of_two_tile_occupancy() {
    /// Hypothesis F036: only power-of-two tile dimensions are accepted.
    let accepted: Vec<usize> = vec![16, 32, 64, 128, 256];
    let rejected: Vec<usize> = vec![17, 33, 65, 100, 200];
    // Every power-of-two dimension must validate.
    for &dim in &accepted {
        assert!(
            validate_shape(&[dim]).is_ok(),
            "F036 FALSIFIED: Power-of-two tile {} should be valid",
            dim
        );
    }
    // Every non-power-of-two dimension must be rejected.
    for &dim in &rejected {
        assert!(
            validate_shape(&[dim]).is_err(),
            "F036 FALSIFIED: Non-power-of-two tile {} should be rejected",
            dim
        );
    }
    println!(
        "F036 PASSED: Power-of-two tiles validated (valid={}, rejected={})",
        accepted.len(),
        rejected.len()
    );
}
#[test]
fn f037_max_tile_elements() {
    /// Hypothesis F037: total element count is capped at MAX_TILE_ELEMENTS.
    // 4096 * 4096 = 16_777_216 elements sits exactly at the cap.
    assert!(
        validate_shape(&[4096, 4096]).is_ok(),
        "F037 FALSIFIED: Tile at limit should pass"
    );
    // Doubling one dimension overflows the element budget.
    assert!(
        matches!(
            validate_shape(&[8192, 4096]),
            Err(TileError::TooManyElements { .. })
        ),
        "F037 FALSIFIED: Tile over limit should fail"
    );
    assert_eq!(
        MAX_TILE_ELEMENTS, 16_777_216,
        "F037: MAX_TILE_ELEMENTS should be 16M"
    );
    println!(
        "F037 PASSED: Maximum tile elements = {} enforced",
        MAX_TILE_ELEMENTS
    );
}
#[test]
fn f038_max_single_dimension() {
    /// Hypothesis F038: no single dimension may exceed MAX_TILE_DIM.
    // Exactly at the per-dimension limit: accepted.
    assert!(
        validate_shape(&[4096]).is_ok(),
        "F038 FALSIFIED: Dimension at limit should pass"
    );
    // One step past the limit: rejected with the dimension-specific error.
    assert!(
        matches!(
            validate_shape(&[8192]),
            Err(TileError::DimensionTooLarge { .. })
        ),
        "F038 FALSIFIED: Dimension over limit should fail"
    );
    assert_eq!(MAX_TILE_DIM, 4096, "F038: MAX_TILE_DIM should be 4096");
    println!("F038 PASSED: Maximum dimension = {} enforced", MAX_TILE_DIM);
}
#[test]
fn f039_stride_aware_offsets() {
    /// Hypothesis F039: sequential, strided, and per-block byte offsets are
    /// all computed correctly and keep their alignment guarantees.
    let warp_size = 32usize;
    let elem_bytes = 4usize;
    // Sequential access: lane i reads byte offset i * 4.
    for lane in 0..warp_size {
        assert_eq!(
            lane * elem_bytes,
            lane * 4,
            "F039: Sequential offset for thread {} should be {}",
            lane,
            lane * 4
        );
    }
    // Strided access: offsets stay element-aligned regardless of stride.
    let stride = 128usize;
    for lane in 0..warp_size {
        let byte_offset = lane * stride * elem_bytes;
        assert_eq!(
            byte_offset % elem_bytes,
            0,
            "F039: Strided offset must be aligned"
        );
    }
    // Per-block base offset for a 16x16 f32 tile: 5 * 16 * 16 * 4 = 5120,
    // which is a multiple of the 128-byte cache line.
    let tile_edge = 16usize;
    let block_idx = 5usize;
    let base_offset = block_idx * tile_edge * tile_edge * elem_bytes;
    assert!(
        base_offset % 128 == 0 || tile_edge < 32,
        "F039: Block offset should be cache-aligned for large tiles"
    );
    println!("F039 PASSED: Stride-aware offsets computed correctly");
}
#[test]
fn test_wmma_shapes() {
    /// The three canonical Tensor Core fragment geometries must validate;
    /// an ad-hoc geometry outside that set must not.
    for (shape, label) in [
        (WmmaShape::M16N16K16, "16x16x16 should be valid"),
        (WmmaShape::M8N32K16, "8x32x16 should be valid"),
        (WmmaShape::M32N8K16, "32x8x16 should be valid"),
    ] {
        assert!(validate_wmma_shape(&shape).is_ok(), "{}", label);
    }
    // A hand-built shape that matches none of the supported geometries.
    let bogus = WmmaShape {
        m: 16,
        n: 32,
        k: 16,
    };
    assert!(
        validate_wmma_shape(&bogus).is_err(),
        "16x32x16 should be invalid"
    );
    println!("WMMA shape validation verified");
}
#[test]
fn test_error_messages_actionable() {
    /// Every TileError Display string must name the offending value or the
    /// violated constraint so the caller can act on it.
    // Non-power-of-two dimension: message carries the value and the rule.
    let msg = validate_shape(&[17]).unwrap_err().to_string();
    assert!(
        msg.contains("17") && msg.contains("power of two"),
        "Error should mention the value and constraint: {}",
        msg
    );
    // Element-count overflow.
    let msg = validate_shape(&[8192, 4096]).unwrap_err().to_string();
    assert!(
        msg.contains("too many elements"),
        "Error should mention element count: {}",
        msg
    );
    // Single-dimension overflow.
    let msg = validate_shape(&[8192]).unwrap_err().to_string();
    assert!(
        msg.contains("exceeds maximum"),
        "Error should mention size limit: {}",
        msg
    );
    println!("Error messages are actionable");
}
#[test]
fn test_memory_access_constants() {
    /// Sanity-check the architectural constants the tests above rely on.
    const WARP_SIZE: usize = 32;
    const L1_CACHE_LINE: usize = 128;
    const L2_CACHE_LINE: usize = 128;
    // A full warp of f32 loads covers exactly one cache line.
    assert_eq!(
        WARP_SIZE * 4,
        L2_CACHE_LINE,
        "Warp f32 access should fill cache line"
    );
    // Both cache-line sizes must be powers of two for mask arithmetic.
    for line in [L1_CACHE_LINE, L2_CACHE_LINE] {
        assert!(
            line.is_power_of_two(),
            "Cache line must be power of two"
        );
    }
    println!("Memory access constants verified");
}
#[test]
fn test_tile_edge_cases() {
    /// Boundary shapes: degenerate inputs plus shapes at and just past the
    /// element budget.
    // Degenerate shapes are valid.
    assert!(validate_shape(&[]).is_ok(), "Empty shape should be valid");
    assert!(
        validate_shape(&[1]).is_ok(),
        "Single element should be valid"
    );
    // Rectangular and square shapes up to the cap are valid.
    assert!(validate_shape(&[4096, 4]).is_ok(), "4096x4 should be valid");
    assert!(
        validate_shape(&[4096, 4096]).is_ok(),
        "4096x4096 should be valid"
    );
    // One extra factor of 2 pushes past the element limit.
    assert!(
        validate_shape(&[4096, 4096, 2]).is_err(),
        "4096x4096x2 should exceed element limit"
    );
    println!("Tile edge cases verified");
}
#[test]
fn test_wmma_instruction_validation() {
    /// `validate` must accept empty streams, streams without WMMA ops, and a
    /// complete load-A/load-B/mma/store-D WMMA sequence.
    use trueno_gpu::ptx::optimize::tile_validation::validate;
    assert!(validate(&[]).is_ok(), "Empty should be valid");
    // A stream of plain arithmetic/memory ops contains nothing to reject.
    let plain: Vec<PtxInstruction> = [PtxOp::Add, PtxOp::Mul, PtxOp::Ld, PtxOp::St]
        .into_iter()
        .map(|op| PtxInstruction::new(op, PtxType::F32))
        .collect();
    assert!(validate(&plain).is_ok(), "Non-WMMA should be valid");
    // The canonical four-step WMMA sequence also passes.
    let wmma_seq: Vec<PtxInstruction> = [
        PtxOp::WmmaLoadA,
        PtxOp::WmmaLoadB,
        PtxOp::WmmaMma,
        PtxOp::WmmaStoreD,
    ]
    .into_iter()
    .map(|op| PtxInstruction::new(op, PtxType::F32))
    .collect();
    assert!(validate(&wmma_seq).is_ok(), "WMMA should be valid");
    println!("WMMA instruction validation verified");
}