use crate::{Archive, Archived, Offset, RelPtr};
use bytecheck::{CheckBytes, Unreachable};
use core::{fmt, mem};
use std::error;
/// A half-open byte range `[start, end)` expressed as positions within an archive.
///
/// The derived ordering compares `start` first and `end` second, which is what
/// lets a sorted list of intervals be searched with a binary search.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct Interval {
    start: usize,
    end: usize,
}

impl Interval {
    /// Builds the interval that begins at `start` and spans `len` bytes.
    fn new(start: usize, len: usize) -> Self {
        let end = start + len;
        Self { start, end }
    }

    /// Returns `true` when the two ranges share at least one position.
    ///
    /// Ranges that merely touch (one ends exactly where the other begins) do
    /// not overlap.
    fn overlaps(&self, other: &Self) -> bool {
        let ends_before_other = self.end <= other.start;
        let starts_after_other = other.end <= self.start;
        !(ends_before_other || starts_after_other)
    }
}
/// Errors reported when a requested region of an archive cannot be claimed.
#[derive(Debug)]
pub enum ArchiveMemoryError {
/// The pointer target (`base + offset`) lies outside the archive buffer.
OutOfBounds {
/// Position of the base pointer relative to the start of the archive.
base: usize,
/// Offset that was applied to the base pointer.
offset: isize,
/// Total length of the archive in bytes.
archive_len: usize,
},
/// The target starts inside the archive but the requested size runs past its end.
Overrun {
/// Position of the target relative to the start of the archive.
pos: usize,
/// Number of bytes that were requested.
size: usize,
/// Total length of the archive in bytes.
archive_len: usize,
},
/// The target position is not a multiple of the required alignment.
Unaligned {
/// Position of the target relative to the start of the archive.
pos: usize,
/// Alignment that was required.
align: usize,
},
/// The requested region overlaps a region that was already claimed.
ClaimOverlap {
/// The previously claimed region that conflicts with the request.
previous: Interval,
/// The region that was being claimed.
current: Interval,
},
}
impl fmt::Display for ArchiveMemoryError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ArchiveMemoryError::OutOfBounds {
base,
offset,
archive_len,
} => write!(
f,
"out of bounds pointer: base {} offset {} in archive len {}",
base, offset, archive_len
),
ArchiveMemoryError::Overrun {
pos,
size,
archive_len,
} => write!(
f,
"archive overrun: pos {} size {} in archive len {}",
pos, size, archive_len
),
ArchiveMemoryError::Unaligned { pos, align } => write!(
f,
"unaligned pointer: pos {} unaligned for alignment {}",
pos, align
),
ArchiveMemoryError::ClaimOverlap { previous, current } => write!(
f,
"memory claim overlap: current [{}..{}] overlaps previous [{}..{}]",
current.start, current.end, previous.start, previous.end
),
}
}
}
// Marker impl: no methods are overridden, so the trait's default methods are used.
impl error::Error for ArchiveMemoryError {}
/// Error returned by `check_archive`, generic over the checked type's own
/// validation error `T`.
#[derive(Debug)]
pub enum CheckArchiveError<T> {
/// Claiming memory in the archive failed.
MemoryError(ArchiveMemoryError),
/// The type's `check_bytes` validation failed.
CheckBytes(T),
}
impl<T> From<ArchiveMemoryError> for CheckArchiveError<T> {
fn from(e: ArchiveMemoryError) -> Self {
Self::MemoryError(e)
}
}
impl<T: fmt::Display> fmt::Display for CheckArchiveError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CheckArchiveError::MemoryError(e) => write!(f, "archive memory error: {}", e),
CheckArchiveError::CheckBytes(e) => write!(f, "check bytes error: {}", e),
}
}
}
// Marker impl: available whenever the inner error `T` is both `Debug` and `Display`;
// no methods are overridden, so the trait's default methods are used.
impl<T: fmt::Debug + fmt::Display> error::Error for CheckArchiveError<T> {}
/// Validation context that records which byte ranges of an archive buffer
/// have already been claimed.
pub struct ArchiveContext {
// Pointer to the first byte of the buffer being validated.
begin: *const u8,
// Length of that buffer in bytes.
len: usize,
// Claimed ranges; new claims are inserted at the position found by binary
// search, which keeps the list sorted.
intervals: Vec<Interval>,
}
impl ArchiveContext {
    /// Creates a context for validating `bytes`, with no regions claimed yet.
    pub fn new(bytes: &[u8]) -> Self {
        Self {
            begin: bytes.as_ptr(),
            len: bytes.len(),
            intervals: Vec::new(),
        }
    }

    /// Claims room for `count` consecutive values of type `T` located at
    /// `base + offset`, using `T`'s size and alignment.
    ///
    /// See [`claim_bytes`](Self::claim_bytes) for the checks performed and the
    /// errors returned.
    ///
    /// # Safety
    ///
    /// Same requirements as [`claim_bytes`](Self::claim_bytes).
    pub unsafe fn claim<T: CheckBytes<ArchiveContext>>(
        &mut self,
        base: *const u8,
        offset: isize,
        count: usize,
    ) -> Result<*const u8, ArchiveMemoryError> {
        // `count` may come from untrusted archive data. A plain multiplication
        // could wrap around to a small size that passes the bounds check, so
        // saturate instead: a saturated size can never fit in a buffer (slice
        // lengths never exceed `isize::MAX`) and is therefore always rejected.
        let size = count.saturating_mul(mem::size_of::<T>());
        self.claim_bytes(base, offset, size, mem::align_of::<T>())
    }

    /// Claims `count` bytes located at `base + offset` and returns a pointer to
    /// the start of the claimed region.
    ///
    /// `align` is used as a bit mask and must therefore be a nonzero power of
    /// two. Zero-length claims are bounds- and alignment-checked but are not
    /// recorded, since an empty region cannot conflict with any other claim.
    ///
    /// # Errors
    ///
    /// * `OutOfBounds` if `base + offset` lies outside the buffer.
    /// * `Overrun` if fewer than `count` bytes remain after the target position.
    /// * `Unaligned` if the target position is not a multiple of `align`.
    /// * `ClaimOverlap` if the region overlaps a previously claimed region.
    ///
    /// # Safety
    ///
    /// `base` must point into (or one past the end of) the buffer this context
    /// was created from, because its position is computed with `offset_from`
    /// relative to the start of that buffer.
    pub unsafe fn claim_bytes(
        &mut self,
        base: *const u8,
        offset: isize,
        count: usize,
        align: usize,
    ) -> Result<*const u8, ArchiveMemoryError> {
        let base_pos = base.offset_from(self.begin);
        if offset < -base_pos || offset > self.len as isize - base_pos {
            return Err(ArchiveMemoryError::OutOfBounds {
                base: base_pos as usize,
                offset,
                archive_len: self.len,
            });
        }

        // The bounds check above guarantees `0 <= base_pos + offset <= self.len`.
        let target_pos = (base_pos + offset) as usize;
        if self.len - target_pos < count {
            return Err(ArchiveMemoryError::Overrun {
                pos: target_pos,
                size: count,
                archive_len: self.len,
            });
        }

        // NOTE(review): alignment is checked on the position relative to the
        // start of the buffer, not on the absolute address, so this is only
        // meaningful when the buffer itself is suitably aligned — confirm that
        // callers guarantee this.
        if target_pos & (align - 1) != 0 {
            return Err(ArchiveMemoryError::Unaligned {
                pos: target_pos,
                align,
            });
        }

        // An empty region occupies no bytes; recording it would make later (or
        // earlier) claims that merely contain its position fail spuriously.
        if count == 0 {
            return Ok(base.offset(offset));
        }

        let interval = Interval::new(target_pos, count);
        let index = match self.intervals.binary_search(&interval) {
            // Exactly the same region has been claimed before.
            Ok(index) => {
                return Err(ArchiveMemoryError::ClaimOverlap {
                    previous: self.intervals[index],
                    current: interval,
                });
            }
            Err(index) => index,
        };

        // Recorded intervals are sorted and mutually disjoint, so only the
        // immediate neighbours of the insertion point can overlap the new one.
        let following = self.intervals.get(index);
        let preceding = index.checked_sub(1).map(|i| &self.intervals[i]);
        let conflict = following
            .into_iter()
            .chain(preceding)
            .find(|claimed| claimed.overlaps(&interval))
            .copied();

        match conflict {
            Some(previous) => Err(ArchiveMemoryError::ClaimOverlap {
                previous,
                current: interval,
            }),
            None => {
                self.intervals.insert(index, interval);
                Ok(base.offset(offset))
            }
        }
    }
}
/// Validates that `buf` holds a well-formed `Archived<T>` at byte position
/// `pos` and, if so, returns a reference to it.
///
/// The returned reference borrows from `buf`, so it cannot outlive the buffer
/// it points into.
///
/// # Errors
///
/// * [`CheckArchiveError::MemoryError`] if the root value cannot be claimed
///   within `buf` (out of bounds, overrun, or misaligned position).
/// * [`CheckArchiveError::CheckBytes`] if the archived type's own
///   `check_bytes` validation fails.
pub fn check_archive<'a, T: Archive>(
    buf: &'a [u8],
    pos: usize,
) -> Result<&'a Archived<T>, CheckArchiveError<<Archived<T> as CheckBytes<ArchiveContext>>::Error>>
where
    T::Archived: CheckBytes<ArchiveContext>,
{
    let mut context = ArchiveContext::new(buf);
    // SAFETY: the base pointer passed to `claim` is the start of the very
    // buffer the context was built from, and the pointer is only dereferenced
    // after both the claim and `check_bytes` have succeeded.
    // NOTE(review): the claim checks alignment relative to the start of `buf`;
    // this presumes `buf` itself is aligned for `Archived<T>` — confirm.
    unsafe {
        // A `pos` above `isize::MAX` becomes a negative offset here and is
        // rejected by the claim as out of bounds.
        let bytes = context.claim::<Archived<T>>(buf.as_ptr(), pos as isize, 1)?;
        Archived::<T>::check_bytes(bytes, &mut context).map_err(CheckArchiveError::CheckBytes)?;
        Ok(&*bytes.cast())
    }
}
/// Byte-level validation for relative pointers.
impl CheckBytes<ArchiveContext> for RelPtr {
    type Error = Unreachable;

    /// Validates the bytes as an `Offset` and then reinterprets them as a
    /// `RelPtr`.
    ///
    /// Nothing is claimed from `context` by this function itself; it only
    /// forwards the context to `Offset::check_bytes`.
    unsafe fn check_bytes<'a>(
        bytes: *const u8,
        context: &mut ArchiveContext,
    ) -> Result<&'a Self, Self::Error> {
        let rel_ptr = bytes.cast::<Self>();
        Offset::check_bytes(bytes, context)?;
        Ok(&*rel_ptr)
    }
}