use super::super::instructions::{PtxInstruction, PtxOp, WmmaShape};
use crate::error::{GpuError, Result};
/// Largest total element count (product of all dimensions) accepted by
/// [`validate_shape`]: 2^24 = 16_777_216, which equals `MAX_TILE_DIM * MAX_TILE_DIM`.
pub const MAX_TILE_ELEMENTS: usize = 1 << 24;
/// Largest value any single tile dimension may take in [`validate_shape`]: 2^12 = 4096.
pub const MAX_TILE_DIM: usize = 1 << 12;
/// Reasons a tile shape or WMMA shape fails validation.
///
/// Every field is `usize` or `String`, so the type is fully comparable (`Eq`),
/// not just `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TileError {
    /// The total element count `actual` exceeds the limit `max`.
    TooManyElements {
        actual: usize,
        max: usize,
    },
    /// The dimension `dim` is not a power of two.
    NonPowerOfTwo {
        dim: usize,
    },
    /// A single dimension `actual` exceeds the limit `max`.
    DimensionTooLarge {
        actual: usize,
        max: usize,
    },
    /// The WMMA shape (pre-rendered into `shape` for display) is not supported.
    InvalidWmmaShape {
        shape: String,
    },
}

impl std::fmt::Display for TileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TooManyElements { actual, max } => {
                write!(f, "Tile has too many elements: {} > {}", actual, max)
            }
            Self::NonPowerOfTwo { dim } => {
                write!(f, "Tile dimension {} is not a power of two", dim)
            }
            Self::DimensionTooLarge { actual, max } => {
                write!(f, "Tile dimension {} exceeds maximum {}", actual, max)
            }
            Self::InvalidWmmaShape { shape } => {
                write!(f, "Invalid WMMA shape: {}", shape)
            }
        }
    }
}

/// `Display` supplies the message; there is no underlying `source()` error.
impl std::error::Error for TileError {}
impl From<TileError> for GpuError {
fn from(err: TileError) -> Self {
GpuError::InvalidParameter(err.to_string())
}
}
/// Validates a tile shape against the global tile limits.
///
/// Checks run in a fixed order, and the first failure is returned:
/// 1. the product of all dimensions must not exceed [`MAX_TILE_ELEMENTS`];
/// 2. every non-zero dimension must be a power of two (a zero dimension is accepted);
/// 3. every dimension must be at most [`MAX_TILE_DIM`].
///
/// An empty shape has an element count of 1 and is accepted.
///
/// # Errors
///
/// Returns [`TileError::TooManyElements`], [`TileError::NonPowerOfTwo`], or
/// [`TileError::DimensionTooLarge`] for the first violated rule. If the true
/// element count does not fit in `usize`, `actual` is reported as `usize::MAX`.
pub fn validate_shape(shape: &[usize]) -> std::result::Result<(), TileError> {
    // A plain `product()` panics on overflow in debug builds and wraps in release
    // builds. Wrapping lets e.g. `[4096; 6]` (2^72 elements) collapse to 0 and pass
    // every check. Saturating multiplication yields exactly
    // min(true product, usize::MAX): once saturated the value stays at MAX, unless a
    // zero factor appears, in which case the true product is 0 as well.
    let total_elements = shape.iter().fold(1usize, |acc, &dim| acc.saturating_mul(dim));
    if total_elements > MAX_TILE_ELEMENTS {
        return Err(TileError::TooManyElements {
            actual: total_elements,
            max: MAX_TILE_ELEMENTS,
        });
    }
    // Power-of-two is checked across all dims before any size check, so that
    // shapes violating both rules keep reporting NonPowerOfTwo first.
    if let Some(dim) = shape.iter().copied().find(|&dim| dim != 0 && !dim.is_power_of_two()) {
        return Err(TileError::NonPowerOfTwo { dim });
    }
    if let Some(dim) = shape.iter().copied().find(|&dim| dim > MAX_TILE_DIM) {
        return Err(TileError::DimensionTooLarge { actual: dim, max: MAX_TILE_DIM });
    }
    Ok(())
}
/// Checks that `shape` is one of the WMMA tile shapes this module accepts:
/// m16n16k16, m8n32k16, or m32n8k16.
///
/// # Errors
///
/// Returns [`TileError::InvalidWmmaShape`] when the `(m, n, k)` combination is
/// not in that set. The shape is rendered as `m{m}n{n}k{k}`.
pub fn validate_wmma_shape(shape: &WmmaShape) -> std::result::Result<(), TileError> {
    let (m, n, k) = (shape.m, shape.n, shape.k);
    if matches!((m, n, k), (16, 16, 16) | (8, 32, 16) | (32, 8, 16)) {
        Ok(())
    } else {
        Err(TileError::InvalidWmmaShape {
            shape: format!("m{}n{}k{}", m, n, k),
        })
    }
}
/// Runs tile-level validation over a PTX instruction stream.
///
/// If any instruction is a WMMA operation (`WmmaLoadA`, `WmmaLoadB`, `WmmaLoadC`,
/// `WmmaMma`, or `WmmaStoreD`), `WmmaShape::M16N16K16` is checked with
/// [`validate_wmma_shape`]. Any [`TileError`] is propagated with `?`.
///
/// NOTE(review): the shape checked is the fixed `WmmaShape::M16N16K16` constant,
/// not a shape read from the instruction. As written, this can only fail if that
/// constant itself is rejected. Confirm whether instructions carry their own WMMA
/// shape that should be validated here instead.
pub fn validate(instructions: &[PtxInstruction]) -> Result<()> {
    let uses_wmma = instructions.iter().any(|instr| {
        matches!(
            instr.op,
            PtxOp::WmmaLoadA
                | PtxOp::WmmaLoadB
                | PtxOp::WmmaLoadC
                | PtxOp::WmmaMma
                | PtxOp::WmmaStoreD
        )
    });
    // The checked shape is constant, so one check decides the outcome for
    // every WMMA instruction in the stream.
    if uses_wmma {
        validate_wmma_shape(&WmmaShape::M16N16K16)?;
    }
    Ok(())
}
// Unit tests, declared here and defined in a separate module file; compiled only for test builds.
#[cfg(test)]
mod tests;
// Additional `property_tests` module in a separate file; compiled only for test builds.
#[cfg(test)]
mod property_tests;