use core::fmt;
use crate::{Flat, list::Segment};
/// Reasons a byte buffer is rejected during validation.
///
/// Every `addr` field is a byte offset into the buffer being validated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidateError {
    /// Reading `need` bytes starting at `addr` would run past the end of a buffer
    /// of length `buf_len`, or `addr + need` would overflow `usize`.
    OutOfBounds {
        addr: usize,
        need: usize,
        buf_len: usize,
    },
    /// `addr` is not a multiple of the required alignment `align`.
    Misaligned { addr: usize, align: usize },
    /// A null `Near` offset was found at `addr`.
    NullNear { addr: usize },
    /// The `NearList` header at `addr` is malformed.
    InvalidListHeader { addr: usize },
    /// The list whose header is at `addr` records `expected` elements, but its
    /// segments account for `actual`.
    ListLenMismatch {
        addr: usize,
        expected: u32,
        actual: u32,
    },
    /// The discriminant byte at `addr` holds `value`; `max` is the largest
    /// accepted discriminant.
    InvalidDiscriminant { addr: usize, value: u8, max: u8 },
    /// The byte at `addr`, read as a `bool`, holds the invalid value `value`.
    InvalidBool { addr: usize, value: u8 },
    /// Validation was requested for an uninhabited type, which no bytes can encode.
    Uninhabited,
}

impl fmt::Display for ValidateError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All payload fields are `Copy`, so bind them by value.
        match *self {
            Self::OutOfBounds { addr, need, buf_len } => write!(
                f,
                "out of bounds: addr {addr}, need {need} bytes, buf_len {buf_len}"
            ),
            Self::Misaligned { addr, align } => {
                write!(f, "misaligned: addr {addr} not aligned to {align}")
            }
            Self::NullNear { addr } => write!(f, "null Near offset at addr {addr}"),
            Self::InvalidListHeader { addr } => {
                write!(f, "invalid NearList header at addr {addr}")
            }
            Self::ListLenMismatch { addr, expected, actual } => write!(
                f,
                "list len mismatch at addr {addr}: expected {expected}, got {actual}"
            ),
            Self::InvalidDiscriminant { addr, value, max } => write!(
                f,
                "invalid discriminant at addr {addr}: value {value}, max {max}"
            ),
            Self::InvalidBool { addr, value } => {
                write!(f, "invalid bool at addr {addr}: value {value}")
            }
            Self::Uninhabited => f.write_str("uninhabited type can never be valid"),
        }
    }
}

impl core::error::Error for ValidateError {}

impl ValidateError {
    /// Ensures `size_of::<T>()` bytes starting at offset `addr` lie inside `buf`.
    ///
    /// # Errors
    ///
    /// Returns [`ValidateError::OutOfBounds`] when the range would extend past
    /// `buf.len()` or its end offset would overflow `usize`.
    #[inline]
    pub fn check_bounds<T>(addr: usize, buf: &[u8]) -> Result<(), Self> {
        let need = size_of::<T>();
        let buf_len = buf.len();
        match addr.checked_add(need) {
            Some(end) if end <= buf_len => Ok(()),
            _ => Err(Self::OutOfBounds { addr, need, buf_len }),
        }
    }

    /// Ensures offset `addr` is a multiple of `align_of::<T>()`.
    ///
    /// # Errors
    ///
    /// Returns [`ValidateError::Misaligned`] when `addr` is not suitably aligned.
    #[inline]
    pub const fn check_align<T>(addr: usize) -> Result<(), Self> {
        let align = align_of::<T>();
        // Byte alignment is satisfied by every offset, so only larger alignments
        // need the divisibility test.
        if align <= 1 || addr.is_multiple_of(align) {
            Ok(())
        } else {
            Err(Self::Misaligned { addr, align })
        }
    }

    /// Runs [`Self::check_bounds`] and then [`Self::check_align`] for a `T` at `addr`.
    ///
    /// # Errors
    ///
    /// Reports the bounds failure first; the alignment check runs only if the
    /// range fits in `buf`.
    #[inline]
    pub fn check<T>(addr: usize, buf: &[u8]) -> Result<(), Self> {
        Self::check_bounds::<T>(addr, buf).and_then(|()| Self::check_align::<T>(addr))
    }
}
/// Validates the `NearList` header at offset `hdr_addr` in `buf` and every element
/// reachable from it.
///
/// The header is read as two native-endian 4-byte words: an `i32` offset, relative
/// to `hdr_addr`, to the first [`Segment`], followed by the `u32` total element count.
/// An empty list has both words equal to zero.
///
/// Each segment is likewise read as an `i32` offset, relative to that segment, to the
/// next segment (`0` ends the chain), followed by a `u32` element count. Its elements
/// follow contiguously after `size_of::<Segment<T>>()` bytes and are checked with
/// `T::validate`.
///
/// A segment's count is checked against the header total *before* its elements are
/// validated, so at most `total_len` elements are ever passed to `T::validate`.
///
/// # Errors
///
/// - [`ValidateError::OutOfBounds`] / [`ValidateError::Misaligned`] if the 8-byte
///   header or a segment header does not fit in `buf` or is misaligned.
/// - [`ValidateError::InvalidListHeader`] if exactly one of the header's offset and
///   count is zero, or if the segment chain is longer than the number of segment
///   headers that fit in `buf` (e.g. the offsets form a cycle).
/// - [`ValidateError::ListLenMismatch`] if the segment counts sum to more or fewer than
///   the header count. `actual` is `u32::MAX` if the running sum overflows `u32`.
/// - Any error returned by `T::validate` for an element is propagated unchanged.
pub fn validate_list_impl<T: Flat>(hdr_addr: usize, buf: &[u8]) -> Result<(), ValidateError> {
    // The header is two 4-byte words, so the full 8 bytes must be in bounds before
    // they are sliced. Checking only an `i32` let a truncated buffer panic on the
    // read of the length word instead of returning `OutOfBounds`.
    ValidateError::check::<[u32; 2]>(hdr_addr, buf)?;
    let (head, total_len) = read_offset_and_len(buf, hdr_addr);
    if (total_len == 0) != (head == 0) {
        return Err(ValidateError::InvalidListHeader { addr: hdr_addr });
    }
    if total_len == 0 {
        return Ok(());
    }

    let mismatch = |actual: u32| ValidateError::ListLenMismatch {
        addr: hdr_addr,
        expected: total_len,
        actual,
    };

    let seg_hdr_size = size_of::<Segment<T>>();
    let elem_size = size_of::<T>();
    // Cap the walk at roughly how many segment headers fit in `buf`. This bounds the
    // loop even when the `next` offsets form a cycle. The `unwrap_or` fallback only
    // applies if `Segment<T>` were zero-sized.
    let max_iters = buf.len().checked_div(seg_hdr_size).unwrap_or(total_len as usize) + 1;

    let mut counted: u32 = 0;
    let mut seg_addr = offset_addr(hdr_addr, head);
    for _ in 0..max_iters {
        ValidateError::check::<Segment<T>>(seg_addr, buf)?;
        // NOTE(review): assumes `Segment<T>` is at least 8 bytes and starts with
        // `next: i32, len: u32`, matching this read — confirm against `list::Segment`.
        let (seg_next, seg_len) = read_offset_and_len(buf, seg_addr);

        // Account for this segment before touching its elements, so a corrupt
        // `seg_len` cannot drive an unbounded number of `T::validate` calls.
        counted = match counted.checked_add(seg_len) {
            Some(sum) if sum <= total_len => sum,
            Some(sum) => return Err(mismatch(sum)),
            None => return Err(mismatch(u32::MAX)),
        };

        // `check` above guarantees `seg_addr + seg_hdr_size <= buf.len()`.
        let values_start = seg_addr + seg_hdr_size;
        for i in 0..seg_len as usize {
            T::validate(values_start + i * elem_size, buf)?;
        }

        if seg_next == 0 {
            return if counted == total_len {
                Ok(())
            } else {
                Err(mismatch(counted))
            };
        }
        seg_addr = offset_addr(seg_addr, seg_next);
    }

    // The chain did not terminate within `max_iters` segments.
    Err(ValidateError::InvalidListHeader { addr: hdr_addr })
}

/// Reads the pair of native-endian words `(i32, u32)` stored at offset `addr`.
///
/// # Panics
///
/// Panics if `buf` holds fewer than 8 bytes starting at `addr`; callers must
/// bounds-check first.
fn read_offset_and_len(buf: &[u8], addr: usize) -> (i32, u32) {
    let word = |at: usize| -> [u8; 4] {
        buf[at..at + 4]
            .try_into()
            .expect("a 4-byte range always converts to [u8; 4]")
    };
    (i32::from_ne_bytes(word(addr)), u32::from_ne_bytes(word(addr + 4)))
}

/// Applies a signed relative `offset` to `base`, wrapping on overflow.
///
/// A result that would be negative wraps to a huge `usize`, which the subsequent
/// bounds check rejects as `OutOfBounds` rather than panicking here.
fn offset_addr(base: usize, offset: i32) -> usize {
    base.cast_signed().wrapping_add(offset as isize).cast_unsigned()
}