use std::fmt;
/// Number of bytes in one SPIR-V word, i.e. the size of a `u32`.
const SPIRV_WORD_SIZE: usize = std::mem::size_of::<u32>();
/// Errors produced when reinterpreting a SPIR-V byte buffer as 32-bit words.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BytecodeError {
    /// The byte length (carried here) is not a multiple of 4.
    InvalidLength(usize),
    /// The buffer's start address is not 4-byte aligned.
    MisalignedPointer,
}

impl fmt::Display for BytecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            BytecodeError::MisalignedPointer => {
                f.write_str("SPIR-V byte slice pointer is not 4-byte aligned")
            }
            BytecodeError::InvalidLength(len) => {
                write!(f, "SPIR-V byte length {len} is not a multiple of 4")
            }
        }
    }
}

// No wrapped cause, so the default `source()` (which returns `None`) is used.
impl std::error::Error for BytecodeError {}
/// Reinterprets a SPIR-V byte buffer as a slice of 32-bit words without copying.
///
/// The returned slice borrows `bytes` and views the same memory, so each word is
/// read in the host's native byte order; no byte swapping is performed. An empty
/// input always yields an empty slice, regardless of where its pointer points.
///
/// # Errors
///
/// * [`BytecodeError::InvalidLength`] if `bytes.len()` is not a multiple of 4.
///   This is checked before alignment.
/// * [`BytecodeError::MisalignedPointer`] if `bytes` does not start on a 4-byte
///   boundary.
pub fn cast_to_u32(bytes: &[u8]) -> Result<&[u32], BytecodeError> {
    // The address check below tests for `SPIRV_WORD_SIZE`-byte alignment. That only
    // implies `u32` alignment if `align_of::<u32>()` divides it, so enforce this at
    // compile time rather than letting the `unsafe` cast silently become unsound.
    const _: () = assert!(SPIRV_WORD_SIZE % std::mem::align_of::<u32>() == 0);

    // An empty slice's pointer need not be aligned. Short-circuit so empty input
    // never reports `MisalignedPointer` and never reaches `from_raw_parts`.
    if bytes.is_empty() {
        return Ok(&[]);
    }
    if bytes.len() % SPIRV_WORD_SIZE != 0 {
        return Err(BytecodeError::InvalidLength(bytes.len()));
    }
    let ptr = bytes.as_ptr();
    if (ptr as usize) % SPIRV_WORD_SIZE != 0 {
        return Err(BytecodeError::MisalignedPointer);
    }

    let word_count = bytes.len() / SPIRV_WORD_SIZE;
    // SAFETY: `ptr` comes from a live `&[u8]` covering exactly
    // `word_count * SPIRV_WORD_SIZE` initialized bytes, and it is non-null. It is
    // aligned for `u32`: the address check above plus the const assertion at the top
    // establish this. Every bit pattern is a valid `u32`. The total size equals
    // `bytes.len()`, so it is within `isize::MAX`. The returned slice inherits the
    // lifetime of the shared borrow of `bytes`, so the memory stays valid and
    // unmutated for as long as the slice lives.
    let words = unsafe { std::slice::from_raw_parts(ptr.cast::<u32>(), word_count) };
    Ok(words)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte buffer whose start is 4-byte aligned. A sub-slice beginning at offset
    /// `k` is therefore aligned exactly when `k % 4 == 0`.
    #[repr(align(4))]
    struct Aligned([u8; 8]);

    #[test]
    fn misaligned_pointer_display() {
        let err = BytecodeError::MisalignedPointer;
        assert_eq!(
            err.to_string(),
            "SPIR-V byte slice pointer is not 4-byte aligned"
        );
    }

    #[test]
    fn invalid_length_display() {
        let err = BytecodeError::InvalidLength(7);
        assert_eq!(
            err.to_string(),
            "SPIR-V byte length 7 is not a multiple of 4"
        );
    }

    #[test]
    fn misaligned_pointer_returns_error() {
        let data = Aligned([0; 8]);
        let misaligned = &data.0[1..5];
        assert_eq!(
            cast_to_u32(misaligned),
            Err(BytecodeError::MisalignedPointer)
        );
    }

    #[test]
    fn invalid_length_returns_error() {
        let data = Aligned([0; 8]);
        // The length is validated before alignment, so an aligned 7-byte slice
        // reports its length.
        assert_eq!(
            cast_to_u32(&data.0[..7]),
            Err(BytecodeError::InvalidLength(7))
        );
    }

    #[test]
    fn aligned_input_is_reinterpreted_in_native_byte_order() {
        let data = Aligned([1, 0, 0, 0, 0xff, 0xee, 0xdd, 0xcc]);
        let words = cast_to_u32(&data.0).expect("aligned buffer with length multiple of 4");
        assert_eq!(
            words,
            &[
                u32::from_ne_bytes([1, 0, 0, 0]),
                u32::from_ne_bytes([0xff, 0xee, 0xdd, 0xcc]),
            ]
        );
    }

    #[test]
    fn empty_input_is_ok_even_when_misaligned() {
        let data = Aligned([0; 8]);
        assert_eq!(cast_to_u32(&data.0[1..1]).map(|w| w.len()), Ok(0));
        assert_eq!(cast_to_u32(&[]).map(|w| w.len()), Ok(0));
    }

    #[test]
    fn bytecode_error_is_std_error() {
        let err: &dyn std::error::Error = &BytecodeError::InvalidLength(3);
        assert!(err.source().is_none());
    }
}