use super::super::instructions::{PtxInstruction, PtxOp, WmmaShape};
use crate::error::{GpuError, Result};
/// Upper bound on the total number of elements in a tile: 2^24 = 16_777_216,
/// i.e. exactly one `MAX_TILE_DIM` x `MAX_TILE_DIM` square tile.
pub const MAX_TILE_ELEMENTS: usize = 1 << 24;
/// Upper bound on any single tile dimension: 2^12 = 4096.
pub const MAX_TILE_DIM: usize = 1 << 12;
/// Reasons a tile shape or WMMA shape can be rejected by validation.
#[derive(Debug, Clone, PartialEq)]
pub enum TileError {
    /// The product of all dimensions exceeds the element limit.
    TooManyElements {
        /// Computed element count of the rejected shape.
        actual: usize,
        /// Maximum element count allowed.
        max: usize,
    },
    /// A dimension is neither zero nor a power of two.
    NonPowerOfTwo {
        /// The offending dimension.
        dim: usize,
    },
    /// A single dimension exceeds the per-dimension limit.
    DimensionTooLarge {
        /// The offending dimension.
        actual: usize,
        /// Maximum size allowed for one dimension.
        max: usize,
    },
    /// The `m`/`n`/`k` combination is not a supported WMMA shape.
    InvalidWmmaShape {
        /// Textual rendering of the rejected shape, e.g. `"m24n24k8"`.
        shape: String,
    },
}

/// Human-readable, single-line description of each validation failure.
impl std::fmt::Display for TileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TileError::TooManyElements { actual: count, max: limit } => {
                write!(f, "Tile has too many elements: {} > {}", count, limit)
            }
            TileError::NonPowerOfTwo { dim: bad_dim } => {
                write!(f, "Tile dimension {} is not a power of two", bad_dim)
            }
            TileError::DimensionTooLarge { actual: size, max: limit } => {
                write!(f, "Tile dimension {} exceeds maximum {}", size, limit)
            }
            TileError::InvalidWmmaShape { shape: text } => {
                write!(f, "Invalid WMMA shape: {}", text)
            }
        }
    }
}

/// `TileError` has no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for TileError {}
impl From<TileError> for GpuError {
fn from(err: TileError) -> Self {
GpuError::InvalidParameter(err.to_string())
}
}
/// Validates a tile shape against the tile constraints.
///
/// Checks run in this order and the first failure is returned:
/// 1. The element count (product of all dimensions) must not exceed
///    [`MAX_TILE_ELEMENTS`].
/// 2. Every non-zero dimension must be a power of two (zero is accepted).
/// 3. Every dimension must not exceed [`MAX_TILE_DIM`].
///
/// Because of this order, e.g. `[8192, 4096]` is reported as
/// `TooManyElements` rather than `DimensionTooLarge`. An empty shape has an
/// element count of 1 (the empty product) and is accepted.
///
/// # Errors
///
/// - [`TileError::TooManyElements`] when the element count exceeds the limit.
///   If the true product does not fit in `usize`, `actual` is `usize::MAX`.
/// - [`TileError::NonPowerOfTwo`] for the first non-zero, non-power-of-two dim.
/// - [`TileError::DimensionTooLarge`] for the first dimension above the cap.
pub fn validate_shape(shape: &[usize]) -> std::result::Result<(), TileError> {
    // Saturating product: a plain `product()` panics on overflow in debug
    // builds and wraps in release builds (e.g. `[4096; 6]` = 2^72 wraps to 0),
    // which would let an enormous shape pass the element limit. With all dims
    // >= 1 the partial products are non-decreasing, so saturation only occurs
    // when the true product overflows; a later zero dim still yields 0,
    // matching the true product.
    let total_elements = shape
        .iter()
        .fold(1usize, |acc, &dim| acc.saturating_mul(dim));
    if total_elements > MAX_TILE_ELEMENTS {
        return Err(TileError::TooManyElements {
            actual: total_elements,
            max: MAX_TILE_ELEMENTS,
        });
    }
    // Separate passes preserve the documented error priority: a
    // non-power-of-two dim is reported before any oversized dim.
    for &dim in shape {
        if dim != 0 && !dim.is_power_of_two() {
            return Err(TileError::NonPowerOfTwo { dim });
        }
    }
    for &dim in shape {
        if dim > MAX_TILE_DIM {
            return Err(TileError::DimensionTooLarge {
                actual: dim,
                max: MAX_TILE_DIM,
            });
        }
    }
    Ok(())
}
/// Checks that a WMMA shape is one of the supported `m`/`n`/`k` combinations.
///
/// Accepted shapes: `m16n16k16`, `m8n32k16` and `m32n8k16`.
///
/// # Errors
///
/// Returns [`TileError::InvalidWmmaShape`] carrying the shape rendered as
/// `m{m}n{n}k{k}` (e.g. `"m24n24k8"`) for any other combination.
pub fn validate_wmma_shape(shape: &WmmaShape) -> std::result::Result<(), TileError> {
    let supported = matches!(
        (shape.m, shape.n, shape.k),
        (16, 16, 16) | (8, 32, 16) | (32, 8, 16)
    );
    if supported {
        Ok(())
    } else {
        Err(TileError::InvalidWmmaShape {
            shape: format!("m{}n{}k{}", shape.m, shape.n, shape.k),
        })
    }
}
/// Validates the WMMA usage of an instruction stream.
///
/// If any instruction is a WMMA load (A/B/C), MMA, or store (D), the WMMA
/// shape `M16N16K16` is validated; otherwise nothing is checked.
///
/// NOTE(review): the shape validated here is the hard-coded
/// `WmmaShape::M16N16K16`, not a shape read from the instructions, so this
/// currently always succeeds. Confirm whether `PtxInstruction` carries its own
/// WMMA shape that should be checked instead.
///
/// # Errors
///
/// Propagates a [`TileError`] (converted to `GpuError`) from
/// [`validate_wmma_shape`].
pub fn validate(instructions: &[PtxInstruction]) -> Result<()> {
    let uses_wmma = instructions.iter().any(|instr| {
        matches!(
            instr.op,
            PtxOp::WmmaLoadA
                | PtxOp::WmmaLoadB
                | PtxOp::WmmaLoadC
                | PtxOp::WmmaMma
                | PtxOp::WmmaStoreD
        )
    });
    // The validated shape does not depend on which instruction matched, so a
    // single check is equivalent to checking once per WMMA instruction.
    if uses_wmma {
        validate_wmma_shape(&WmmaShape::M16N16K16)?;
    }
    Ok(())
}
// Unit tests for the tile error type, shape validation, WMMA shape
// validation, and instruction-stream validation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::GpuError;

    #[test]
    fn test_tile_error_display_too_many_elements() {
        let err = TileError::TooManyElements {
            actual: 20_000_000,
            max: 16_777_216,
        };
        let msg = err.to_string();
        assert!(
            msg.contains("20000000"),
            "Should contain actual count: {}",
            msg
        );
        assert!(
            msg.contains("16777216"),
            "Should contain max count: {}",
            msg
        );
        assert!(
            msg.contains("too many elements"),
            "Should describe the error: {}",
            msg
        );
    }

    #[test]
    fn test_tile_error_display_dimension_too_large() {
        let err = TileError::DimensionTooLarge {
            actual: 8192,
            max: 4096,
        };
        let msg = err.to_string();
        assert!(
            msg.contains("8192"),
            "Should contain actual dimension: {}",
            msg
        );
        assert!(
            msg.contains("4096"),
            "Should contain max dimension: {}",
            msg
        );
        assert!(
            msg.contains("exceeds"),
            "Should describe the error: {}",
            msg
        );
    }

    #[test]
    fn test_tile_error_display_invalid_wmma_shape() {
        let err = TileError::InvalidWmmaShape {
            shape: "m64n64k64".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("m64n64k64"), "Should contain shape: {}", msg);
        assert!(
            msg.contains("Invalid WMMA"),
            "Should describe the error: {}",
            msg
        );
    }

    #[test]
    fn test_tile_error_display_non_power_of_two() {
        let err = TileError::NonPowerOfTwo { dim: 123 };
        let msg = err.to_string();
        assert!(msg.contains("123"), "Should contain dimension: {}", msg);
        assert!(
            msg.contains("power of two"),
            "Should describe the error: {}",
            msg
        );
    }

    #[test]
    fn test_tile_error_implements_std_error() {
        let err = TileError::NonPowerOfTwo { dim: 42 };
        let std_err: &dyn std::error::Error = &err;
        assert!(std_err.source().is_none());
    }

    #[test]
    fn test_tile_error_to_gpu_error_conversion() {
        let tile_err = TileError::TooManyElements {
            actual: 100,
            max: 50,
        };
        let gpu_err: GpuError = tile_err.clone().into();
        match gpu_err {
            GpuError::InvalidParameter(msg) => {
                assert!(msg.contains("100"), "Should contain actual: {}", msg);
                assert!(msg.contains("50"), "Should contain max: {}", msg);
            }
            _ => panic!("Expected InvalidParameter variant"),
        }
    }

    #[test]
    fn test_tile_error_conversion_non_power_of_two() {
        let tile_err = TileError::NonPowerOfTwo { dim: 37 };
        let gpu_err: GpuError = tile_err.into();
        match gpu_err {
            GpuError::InvalidParameter(msg) => {
                assert!(msg.contains("37"));
                assert!(msg.contains("power of two"));
            }
            _ => panic!("Expected InvalidParameter"),
        }
    }

    #[test]
    fn test_tile_error_conversion_dimension_too_large() {
        let tile_err = TileError::DimensionTooLarge {
            actual: 10000,
            max: 4096,
        };
        let gpu_err: GpuError = tile_err.into();
        match gpu_err {
            GpuError::InvalidParameter(msg) => {
                assert!(msg.contains("10000"));
                assert!(msg.contains("4096"));
            }
            _ => panic!("Expected InvalidParameter"),
        }
    }

    #[test]
    fn test_tile_error_conversion_invalid_wmma() {
        let tile_err = TileError::InvalidWmmaShape {
            shape: "m99n99k99".to_string(),
        };
        let gpu_err: GpuError = tile_err.into();
        match gpu_err {
            GpuError::InvalidParameter(msg) => {
                assert!(msg.contains("m99n99k99"));
            }
            _ => panic!("Expected InvalidParameter"),
        }
    }

    #[test]
    fn test_validate_shape_multiple_dims_with_non_power_of_two() {
        let result = validate_shape(&[32, 100]);
        assert!(matches!(result, Err(TileError::NonPowerOfTwo { dim: 100 })));
    }

    #[test]
    fn test_validate_shape_first_dim_non_power_of_two() {
        let result = validate_shape(&[13, 16]);
        assert!(matches!(result, Err(TileError::NonPowerOfTwo { dim: 13 })));
    }

    #[test]
    fn test_validate_shape_too_many_elements_exact_boundary() {
        assert!(validate_shape(&[4096, 4096]).is_ok());
        assert!(matches!(
            validate_shape(&[4096, 4096, 2]),
            Err(TileError::TooManyElements { .. })
        ));
    }

    #[test]
    fn test_validate_shape_single_element() {
        assert!(validate_shape(&[1]).is_ok());
        assert!(validate_shape(&[1, 1]).is_ok());
        assert!(validate_shape(&[1, 1, 1]).is_ok());
    }

    #[test]
    fn test_validate_shape_large_valid_multidimensional() {
        // 64^3 = 2^18 elements.
        assert!(validate_shape(&[64, 64, 64]).is_ok());
        // 256^3 = 2^24 elements: exactly MAX_TILE_ELEMENTS.
        assert!(validate_shape(&[256, 256, 256]).is_ok());
    }

    #[test]
    fn test_validate_shape_dimension_boundary() {
        assert!(validate_shape(&[4096]).is_ok());
        assert!(matches!(
            validate_shape(&[8192]),
            Err(TileError::DimensionTooLarge {
                actual: 8192,
                max: 4096
            })
        ));
    }

    #[test]
    fn test_validate_shape_zero_dimension_is_ok() {
        let result = validate_shape(&[0, 16]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_wmma_shape_validation_all_valid() {
        assert!(validate_wmma_shape(&WmmaShape::M16N16K16).is_ok());
        assert!(validate_wmma_shape(&WmmaShape::M8N32K16).is_ok());
        assert!(validate_wmma_shape(&WmmaShape::M32N8K16).is_ok());
    }

    #[test]
    fn test_wmma_shape_validation_invalid_combinations() {
        let cases = [
            WmmaShape { m: 8, n: 8, k: 8 },
            WmmaShape {
                m: 32,
                n: 32,
                k: 32,
            },
            WmmaShape {
                m: 16,
                n: 32,
                k: 16,
            },
            WmmaShape { m: 8, n: 16, k: 16 },
            WmmaShape { m: 1, n: 1, k: 1 },
            WmmaShape {
                m: 64,
                n: 64,
                k: 64,
            },
        ];
        for shape in cases {
            assert!(
                validate_wmma_shape(&shape).is_err(),
                "Shape m{}n{}k{} should be invalid",
                shape.m,
                shape.n,
                shape.k
            );
        }
    }

    #[test]
    fn test_wmma_invalid_error_message_format() {
        let shape = WmmaShape { m: 24, n: 24, k: 8 };
        let result = validate_wmma_shape(&shape);
        match result {
            Err(TileError::InvalidWmmaShape { shape: s }) => {
                assert_eq!(s, "m24n24k8");
            }
            _ => panic!("Expected InvalidWmmaShape error"),
        }
    }

    #[test]
    fn test_validate_wmma_load_a() {
        let instructions = vec![PtxInstruction::new(
            PtxOp::WmmaLoadA,
            crate::ptx::types::PtxType::F16,
        )];
        assert!(validate(&instructions).is_ok());
    }

    #[test]
    fn test_validate_wmma_load_b() {
        let instructions = vec![PtxInstruction::new(
            PtxOp::WmmaLoadB,
            crate::ptx::types::PtxType::F16,
        )];
        assert!(validate(&instructions).is_ok());
    }

    #[test]
    fn test_validate_wmma_load_c() {
        let instructions = vec![PtxInstruction::new(
            PtxOp::WmmaLoadC,
            crate::ptx::types::PtxType::F32,
        )];
        assert!(validate(&instructions).is_ok());
    }

    #[test]
    fn test_validate_wmma_store_d() {
        let instructions = vec![PtxInstruction::new(
            PtxOp::WmmaStoreD,
            crate::ptx::types::PtxType::F32,
        )];
        assert!(validate(&instructions).is_ok());
    }

    #[test]
    fn test_validate_mixed_instructions_with_wmma() {
        let instructions = vec![
            PtxInstruction::new(PtxOp::Add, crate::ptx::types::PtxType::F32),
            PtxInstruction::new(PtxOp::WmmaLoadA, crate::ptx::types::PtxType::F16),
            PtxInstruction::new(PtxOp::Mul, crate::ptx::types::PtxType::F32),
            PtxInstruction::new(PtxOp::WmmaLoadB, crate::ptx::types::PtxType::F16),
            PtxInstruction::new(PtxOp::WmmaLoadC, crate::ptx::types::PtxType::F32),
            PtxInstruction::new(PtxOp::WmmaMma, crate::ptx::types::PtxType::F32),
            PtxInstruction::new(PtxOp::WmmaStoreD, crate::ptx::types::PtxType::F32),
            PtxInstruction::new(PtxOp::Sub, crate::ptx::types::PtxType::F32),
        ];
        assert!(validate(&instructions).is_ok());
    }

    #[test]
    fn test_validate_all_wmma_ops_in_sequence() {
        let instructions = vec![
            PtxInstruction::new(PtxOp::WmmaLoadA, crate::ptx::types::PtxType::F16),
            PtxInstruction::new(PtxOp::WmmaLoadB, crate::ptx::types::PtxType::F16),
            PtxInstruction::new(PtxOp::WmmaLoadC, crate::ptx::types::PtxType::F32),
            PtxInstruction::new(PtxOp::WmmaMma, crate::ptx::types::PtxType::F32),
            PtxInstruction::new(PtxOp::WmmaStoreD, crate::ptx::types::PtxType::F32),
        ];
        assert!(validate(&instructions).is_ok());
    }

    #[test]
    fn test_tile_error_clone() {
        let err1 = TileError::TooManyElements {
            actual: 1000,
            max: 500,
        };
        let err2 = err1.clone();
        assert_eq!(err1, err2);
    }

    #[test]
    fn test_tile_error_partial_eq() {
        let err1 = TileError::NonPowerOfTwo { dim: 17 };
        let err2 = TileError::NonPowerOfTwo { dim: 17 };
        let err3 = TileError::NonPowerOfTwo { dim: 19 };
        assert_eq!(err1, err2);
        assert_ne!(err1, err3);
    }

    #[test]
    fn test_tile_error_debug() {
        let err = TileError::DimensionTooLarge {
            actual: 5000,
            max: 4096,
        };
        let debug_str = format!("{:?}", err);
        assert!(debug_str.contains("DimensionTooLarge"));
        assert!(debug_str.contains("5000"));
        assert!(debug_str.contains("4096"));
    }

    #[test]
    fn test_constants_values() {
        assert_eq!(MAX_TILE_ELEMENTS, 16_777_216);
        assert_eq!(MAX_TILE_DIM, 4096);
    }

    #[test]
    fn test_power_of_two_tiles_valid() {
        assert!(validate_shape(&[8, 16, 32, 64]).is_ok());
        assert!(validate_shape(&[128, 128]).is_ok());
        assert!(validate_shape(&[1024, 1024]).is_ok());
        assert!(validate_shape(&[4096]).is_ok());
    }

    #[test]
    fn test_non_power_of_two_rejected() {
        assert!(matches!(
            validate_shape(&[7]),
            Err(TileError::NonPowerOfTwo { dim: 7 })
        ));
        assert!(matches!(
            validate_shape(&[100]),
            Err(TileError::NonPowerOfTwo { dim: 100 })
        ));
        assert!(validate_shape(&[17]).is_err());
        assert!(validate_shape(&[1000]).is_err());
    }

    #[test]
    fn test_max_tile_elements_enforced() {
        assert!(validate_shape(&[4096, 4096]).is_ok());
        assert!(matches!(
            validate_shape(&[8192, 4096]),
            Err(TileError::TooManyElements { .. })
        ));
    }

    #[test]
    fn test_max_dimension_enforced() {
        assert!(validate_shape(&[4096]).is_ok());
        assert!(matches!(
            validate_shape(&[8192]),
            Err(TileError::DimensionTooLarge { .. })
        ));
    }

    #[test]
    fn test_validation_catches_invalid_at_build_time() {
        let result = validate_shape(&[12345]);
        assert!(result.is_err());
    }

    #[test]
    fn test_constraints_backend_agnostic() {
        let shape = [32, 32];
        assert!(validate_shape(&shape).is_ok());
    }

    #[test]
    fn test_small_tiles_valid() {
        assert!(validate_shape(&[4]).is_ok());
        assert!(validate_shape(&[8]).is_ok());
        assert!(validate_shape(&[2, 2]).is_ok());
    }

    #[test]
    fn test_empty_shape_valid() {
        assert!(validate_shape(&[]).is_ok());
    }

    #[test]
    fn test_zero_dimension() {
        // A zero dimension yields zero elements and is exempt from the
        // power-of-two rule ...
        assert_eq!(validate_shape(&[0]), Ok(()));
        // ... but it must not let an oversized sibling dimension slip through.
        assert_eq!(
            validate_shape(&[0, 8192]),
            Err(TileError::DimensionTooLarge {
                actual: 8192,
                max: 4096
            })
        );
    }

    #[test]
    fn test_wmma_valid_shapes() {
        assert!(validate_wmma_shape(&WmmaShape::M16N16K16).is_ok());
        assert!(validate_wmma_shape(&WmmaShape::M8N32K16).is_ok());
        assert!(validate_wmma_shape(&WmmaShape::M32N8K16).is_ok());
    }

    #[test]
    fn test_wmma_invalid_shapes() {
        let invalid = WmmaShape {
            m: 32,
            n: 32,
            k: 16,
        };
        assert!(validate_wmma_shape(&invalid).is_err());
    }

    #[test]
    fn test_error_messages_actionable() {
        let err = validate_shape(&[17]).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("17") && msg.contains("power of two"),
            "Error message should be actionable: {}",
            msg
        );
    }

    #[test]
    fn test_validate_instructions_empty() {
        assert!(validate(&[]).is_ok());
    }

    #[test]
    fn test_validate_instructions_no_wmma() {
        let instructions = vec![
            PtxInstruction::new(PtxOp::Add, crate::ptx::types::PtxType::F32),
            PtxInstruction::new(PtxOp::Mul, crate::ptx::types::PtxType::F32),
        ];
        assert!(validate(&instructions).is_ok());
    }

    #[test]
    fn test_validate_instructions_with_wmma() {
        let instructions = vec![PtxInstruction::new(
            PtxOp::WmmaMma,
            crate::ptx::types::PtxType::F32,
        )];
        assert!(validate(&instructions).is_ok());
    }
}
// Property-based tests for shape and WMMA validation.
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Powers of two from 2^0 up to and including 2^12 (= `MAX_TILE_DIM`), so
    /// the upper boundary dimension is actually exercised.
    fn power_of_two() -> impl Strategy<Value = usize> {
        (0u32..13).prop_map(|exp| 1usize << exp)
    }

    /// Non-power-of-two values in `3..1000`. These are all below
    /// `MAX_TILE_DIM`, so the power-of-two rule is the only one they violate.
    fn non_power_of_two() -> impl Strategy<Value = usize> {
        (3usize..1000).prop_filter("not power of two", |&n| !n.is_power_of_two())
    }

    proptest! {
        #[test]
        fn power_of_two_single_dim_valid(dim in power_of_two()) {
            // Guard only matters if MAX_TILE_DIM is ever lowered below 2^12.
            if dim <= MAX_TILE_DIM {
                prop_assert!(validate_shape(&[dim]).is_ok(),
                    "Power of two {} should be valid", dim);
            }
        }

        #[test]
        fn non_power_of_two_rejected(dim in non_power_of_two()) {
            // Pin the exact error: an `if let` would silently accept any
            // other variant being returned.
            prop_assert_eq!(
                validate_shape(&[dim]),
                Err(TileError::NonPowerOfTwo { dim })
            );
        }

        #[test]
        fn total_elements_within_limit(exp1 in 0u32..13, exp2 in 0u32..13) {
            // Up to 2^12 x 2^12 = 2^24 elements: reaches the element limit exactly.
            let d1 = 1usize << exp1;
            let d2 = 1usize << exp2;
            let total = d1.saturating_mul(d2);
            if d1 <= MAX_TILE_DIM && d2 <= MAX_TILE_DIM && total <= MAX_TILE_ELEMENTS {
                prop_assert!(validate_shape(&[d1, d2]).is_ok(),
                    "{}x{} = {} should be valid", d1, d2, total);
            }
        }

        #[test]
        fn tile_error_display_contains_values(dim in non_power_of_two()) {
            let err = TileError::NonPowerOfTwo { dim };
            let msg = err.to_string();
            prop_assert!(msg.contains(&dim.to_string()),
                "Error message should contain dimension: {}", msg);
        }

        #[test]
        fn validate_never_panics(dims in prop::collection::vec(0usize..10000, 0..5)) {
            let _ = validate_shape(&dims);
        }

        #[test]
        fn wmma_invalid_shapes_fail(m in 1u32..100, n in 1u32..100, k in 1u32..100) {
            let shape = WmmaShape { m, n, k };
            let is_valid = matches!(
                (m, n, k),
                (16, 16, 16) | (8, 32, 16) | (32, 8, 16)
            );
            let result = validate_wmma_shape(&shape);
            prop_assert_eq!(result.is_ok(), is_valid,
                "WMMA shape m{}n{}k{} validity mismatch", m, n, k);
        }
    }

    // Deterministic check over a fixed list; it needs no generated input, so
    // it is a plain test rather than a proptest over an unused dummy value.
    #[test]
    fn wmma_valid_shapes_pass() {
        let shapes = [
            WmmaShape::M16N16K16,
            WmmaShape::M8N32K16,
            WmmaShape::M32N8K16,
        ];
        for shape in shapes {
            assert!(
                validate_wmma_shape(&shape).is_ok(),
                "Valid WMMA shape {:?} should pass",
                shape
            );
        }
    }
}