use crate::{
offset_of, Archive, ArchivePointee, Archived, ArchivedIsize, Fallible, RawRelPtr, RelPtr,
};
use bytecheck::{CheckBytes, Unreachable};
use core::{
alloc::Layout,
any::TypeId,
fmt,
marker::{PhantomData, PhantomPinned},
};
use ptr_meta::{DynMetadata, Pointee};
use std::{collections::HashMap, error::Error};
impl RawRelPtr {
    /// Checks the bytes of a `RawRelPtr` field by field instead of through a derived
    /// `CheckBytes` impl.
    ///
    /// Both field checks have error type `Unreachable`, so their results are unwrapped.
    ///
    /// # Safety
    ///
    /// `value` is dereferenced to build the returned reference, so it must be valid for reads
    /// of a `RawRelPtr` for the lifetime `'a`.
    #[inline]
    pub unsafe fn manual_check_bytes<'a, C: Fallible + ?Sized>(
        value: *const RawRelPtr,
        context: &mut C,
    ) -> Result<&'a Self, Unreachable> {
        let field_base = value.cast::<u8>();
        let offset_field = field_base.add(offset_of!(Self, offset));
        let phantom_field = field_base.add(offset_of!(Self, _phantom));

        ArchivedIsize::check_bytes(offset_field.cast(), context).unwrap();
        PhantomPinned::check_bytes(phantom_field.cast(), context).unwrap();
        Ok(&*value)
    }
}
impl<T: ArchivePointee + ?Sized> RelPtr<T> {
    /// Checks the bytes of a `RelPtr<T>` one field at a time, in this order: the raw relative
    /// pointer, the archived metadata, then the phantom marker.
    ///
    /// Only the metadata check propagates an error. The raw-pointer check has error type
    /// `Unreachable` and, like the phantom check, is unwrapped.
    ///
    /// # Safety
    ///
    /// `value` is dereferenced to build the returned reference, so it must be valid for reads
    /// of a `RelPtr<T>` for the lifetime `'a`.
    ///
    /// # Errors
    ///
    /// Returns the metadata's `CheckBytes` error if the archived metadata is invalid.
    #[inline]
    pub unsafe fn manual_check_bytes<'a, C: Fallible + ?Sized>(
        value: *const RelPtr<T>,
        context: &mut C,
    ) -> Result<&'a Self, <T::ArchivedMetadata as CheckBytes<C>>::Error>
    where
        T: CheckBytes<C>,
        T::ArchivedMetadata: CheckBytes<C>,
    {
        let field_base = value.cast::<u8>();
        let raw_ptr_field = field_base.add(offset_of!(Self, raw_ptr));
        let metadata_field = field_base.add(offset_of!(Self, metadata));
        let phantom_field = field_base.add(offset_of!(Self, _phantom));

        RawRelPtr::manual_check_bytes(raw_ptr_field.cast(), context).unwrap();
        T::ArchivedMetadata::check_bytes(metadata_field.cast(), context)?;
        PhantomData::<T>::check_bytes(phantom_field.cast(), context).unwrap();
        Ok(&*value)
    }
}
/// Computes the memory [`Layout`] of a pointee of type `T` from its pointer metadata.
pub trait LayoutMetadata<T: ?Sized> {
    /// Returns the layout of a `T` described by this metadata value.
    fn layout(self) -> Layout;
}

/// Sized pointees carry `()` metadata; their layout is determined by the type alone.
impl<T> LayoutMetadata<T> for () {
    #[inline]
    fn layout(self) -> Layout {
        Layout::new::<T>()
    }
}

/// Slice metadata is the element count.
impl<T> LayoutMetadata<[T]> for usize {
    #[inline]
    fn layout(self) -> Layout {
        let element_count = self;
        // NOTE(review): `Layout::array` fails (and this panics) when the total size overflows.
        // The count may originate from untrusted archive bytes, so validation can panic here
        // rather than return an error. Consider a fallible path.
        Layout::array::<T>(element_count).unwrap()
    }
}

/// String metadata is the byte length, so a `str` is laid out exactly like a `[u8]`.
impl LayoutMetadata<str> for usize {
    #[inline]
    fn layout(self) -> Layout {
        <usize as LayoutMetadata<[u8]>>::layout(self)
    }
}
impl<T: ?Sized> LayoutMetadata<T> for DynMetadata<T> {
#[inline]
fn layout(self) -> Layout {
self.layout()
}
}
/// Errors reported when a pointer or the archive buffer fails a bounds or alignment check.
#[derive(Debug)]
pub enum ArchiveBoundsError {
    /// The archive buffer's start address is not aligned to the alignment a value requires.
    Underaligned {
        /// Alignment required by the value being checked.
        expected_align: usize,
        /// Largest power of two dividing the buffer's start address.
        actual_align: usize,
    },
    /// A relative pointer's target falls outside the archive.
    OutOfBounds {
        /// Position of the pointer's base within the archive.
        base: usize,
        /// Offset that was applied to `base`.
        offset: isize,
        /// Total archive length in bytes.
        archive_len: usize,
    },
    /// A value starting at `pos` does not fit in the remaining bytes of the archive.
    Overrun {
        /// Position of the value within the archive.
        pos: usize,
        /// Size of the value in bytes.
        size: usize,
        /// Total archive length in bytes.
        archive_len: usize,
    },
    /// A value's position is not a multiple of its required alignment.
    Unaligned {
        /// Position of the value within the archive.
        pos: usize,
        /// Required alignment.
        align: usize,
    },
}

impl fmt::Display for ArchiveBoundsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Underaligned {
                expected_align,
                actual_align,
            } => write!(
                f,
                "archive underaligned: need alignment {} but have alignment {}",
                expected_align, actual_align
            ),
            Self::OutOfBounds {
                base,
                offset,
                archive_len,
            } => write!(
                f,
                "out of bounds pointer: base {} offset {} in archive len {}",
                base, offset, archive_len
            ),
            Self::Overrun {
                pos,
                size,
                archive_len,
            } => write!(
                f,
                "archive overrun: pos {} size {} in archive len {}",
                pos, size, archive_len
            ),
            Self::Unaligned { pos, align } => write!(
                f,
                "unaligned pointer: pos {} unaligned for alignment {}",
                pos, align
            ),
        }
    }
}

/// Leaf error: it has no underlying `source`.
impl Error for ArchiveBoundsError {}
/// A validation context that resolves relative pointers and checks that pointed-to values
/// are acceptable (for `ArchiveBoundsValidator`: inside the archive and properly aligned).
pub trait ArchiveBoundsContext: Fallible {
    /// Resolves `base + offset` to a pointer, returning an error if this context rejects the
    /// target.
    ///
    /// # Safety
    ///
    /// NOTE(review): `ArchiveBoundsValidator`'s implementation calls `offset_from` on `base`.
    /// `base` therefore presumably must point into the buffer being validated. Confirm for
    /// other implementors.
    unsafe fn check_rel_ptr(
        &mut self,
        base: *const u8,
        offset: isize,
    ) -> Result<*const u8, Self::Error>;
    /// Checks that a value with `layout` starting at `ptr` is acceptable to this context
    /// (for `ArchiveBoundsValidator`: aligned, and fully contained in the archive).
    ///
    /// # Safety
    ///
    /// NOTE(review): as with `check_rel_ptr`, `ptr` presumably must be derived from the buffer
    /// being validated. Confirm.
    unsafe fn bounds_check_ptr(
        &mut self,
        ptr: *const u8,
        layout: &Layout,
    ) -> Result<(), Self::Error>;
}
/// Bounds-checking context over a single byte buffer, identified by its start address and
/// length.
pub struct ArchiveBoundsValidator {
    begin: *const u8,
    len: usize,
}

impl ArchiveBoundsValidator {
    /// Creates a validator whose valid region is exactly `bytes`.
    #[inline]
    pub fn new(bytes: &[u8]) -> Self {
        let (begin, len) = (bytes.as_ptr(), bytes.len());
        Self { begin, len }
    }

    /// Start address of the validated buffer.
    #[inline]
    pub fn begin(&self) -> *const u8 {
        self.begin
    }

    /// Length of the validated buffer in bytes.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the validated buffer has no bytes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl Fallible for ArchiveBoundsValidator {
    type Error = ArchiveBoundsError;
}

impl ArchiveBoundsContext for ArchiveBoundsValidator {
    /// Resolves `base + offset` and checks that the target position lies in `0..=len`.
    /// One-past-the-end is allowed; `bounds_check_ptr` later rejects values that don't fit.
    ///
    /// # Safety
    ///
    /// `base` must point into (or one past the end of) the buffer this validator was built
    /// from, because its position is computed with `offset_from`.
    ///
    /// # Errors
    ///
    /// Returns `ArchiveBoundsError::OutOfBounds` if the target falls outside the archive.
    unsafe fn check_rel_ptr(
        &mut self,
        base: *const u8,
        offset: isize,
    ) -> Result<*const u8, Self::Error> {
        let base_pos = base.offset_from(self.begin);
        // Valid targets are positions 0..=len, i.e. offsets in -base_pos..=len - base_pos.
        if offset < -base_pos || offset > self.len as isize - base_pos {
            Err(ArchiveBoundsError::OutOfBounds {
                base: base_pos as usize,
                offset,
                archive_len: self.len,
            })
        } else {
            Ok(base.offset(offset))
        }
    }

    /// Checks that a value with `layout` starting at `ptr` is aligned and fits entirely inside
    /// the archive.
    ///
    /// # Safety
    ///
    /// `ptr` must be derived from the same allocation as the archive buffer, because its
    /// position is computed with `offset_from`.
    ///
    /// # Errors
    ///
    /// - `Underaligned` if the buffer start is not aligned to `layout.align()`.
    /// - `OutOfBounds` if `ptr` lies before the start of the buffer.
    /// - `Unaligned` if `ptr`'s position is not a multiple of `layout.align()`.
    /// - `Overrun` if fewer than `layout.size()` bytes remain at `ptr` (including when `ptr` is
    ///   past the end of the buffer).
    unsafe fn bounds_check_ptr(
        &mut self,
        ptr: *const u8,
        layout: &Layout,
    ) -> Result<(), Self::Error> {
        let align_mask = layout.align() - 1;
        if (self.begin as usize) & align_mask != 0 {
            return Err(ArchiveBoundsError::Underaligned {
                expected_align: layout.align(),
                actual_align: 1 << (self.begin as usize).trailing_zeros(),
            });
        }

        let signed_pos = ptr.offset_from(self.begin);
        if signed_pos < 0 {
            // Casting a negative position to `usize` would yield a huge value. That value could
            // wrap the overrun arithmetic below and wrongly accept the pointer.
            return Err(ArchiveBoundsError::OutOfBounds {
                base: 0,
                offset: signed_pos,
                archive_len: self.len,
            });
        }
        let target_pos = signed_pos as usize;

        if target_pos & align_mask != 0 {
            Err(ArchiveBoundsError::Unaligned {
                pos: target_pos,
                align: layout.align(),
            })
        } else if self
            .len
            .checked_sub(target_pos)
            // `None` means `ptr` is past the end: reject instead of underflowing.
            .map_or(true, |remaining| remaining < layout.size())
        {
            Err(ArchiveBoundsError::Overrun {
                pos: target_pos,
                size: layout.size(),
                archive_len: self.len,
            })
        } else {
            Ok(())
        }
    }
}
/// A half-open address range `[start, end)` within an archive.
///
/// The derived ordering compares `start` first, then `end`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct Interval {
    pub start: *const u8,
    pub end: *const u8,
}

impl Interval {
    /// Returns `true` when `self.start < other.end` and `other.start < self.end`, i.e. when the
    /// two half-open ranges intersect.
    #[inline]
    pub fn overlaps(&self, other: &Self) -> bool {
        !(self.start >= other.end || other.start >= self.end)
    }
}

/// Errors produced by the memory-claiming validator: either an error from the wrapped
/// context, or an attempt to claim bytes that overlap an earlier claim.
#[derive(Debug)]
pub enum ArchiveMemoryError<E> {
    /// Error from the inner context.
    Inner(E),
    /// `current` overlaps the already-claimed range `previous`.
    ClaimOverlap {
        previous: Interval,
        current: Interval,
    },
}

impl<E: fmt::Display> fmt::Display for ArchiveMemoryError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Inner(inner) => fmt::Display::fmt(inner, f),
            Self::ClaimOverlap {
                previous: prev,
                current: cur,
            } => write!(
                f,
                "memory claim overlap: current [{:#?}..{:#?}] overlaps previous [{:#?}..{:#?}]",
                cur.start, cur.end, prev.start, prev.end
            ),
        }
    }
}

impl<E: Error + 'static> Error for ArchiveMemoryError<E> {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let Self::Inner(inner) = self {
            Some(inner as &dyn Error)
        } else {
            None
        }
    }
}
/// A validation context that tracks which archive bytes have been claimed by owned values.
pub trait ArchiveMemoryContext: Fallible {
    /// Claims `len` bytes starting at `start` for a single owner.
    ///
    /// # Safety
    ///
    /// NOTE(review): implementations compute `start.add(len)`, so the range presumably must lie
    /// within the buffer being validated. Confirm.
    unsafe fn claim_bytes(&mut self, start: *const u8, len: usize) -> Result<(), Self::Error>;

    /// Bounds-checks the pointee of `ptr` (using the layout derived from its metadata) and then
    /// claims its bytes.
    ///
    /// # Safety
    ///
    /// Same requirements as `bounds_check_ptr` and `claim_bytes` for the pointee's address.
    unsafe fn claim_owned_ptr<T: ArchivePointee + ?Sized>(
        &mut self,
        ptr: *const T,
    ) -> Result<(), Self::Error>
    where
        Self: ArchiveBoundsContext,
        <T as Pointee>::Metadata: LayoutMetadata<T>,
    {
        let layout = LayoutMetadata::<T>::layout(ptr_meta::metadata(ptr));
        let start = ptr.cast::<u8>();
        self.bounds_check_ptr(start, &layout)?;
        self.claim_bytes(start, layout.size())
    }

    /// Resolves `rel_ptr`, claims the bytes of its target and returns the resolved pointer.
    fn claim_owned_rel_ptr<T: ArchivePointee + ?Sized>(
        &mut self,
        rel_ptr: &RelPtr<T>,
    ) -> Result<*const T, Self::Error>
    where
        Self: ArchiveBoundsContext,
        <T as Pointee>::Metadata: LayoutMetadata<T>,
    {
        // NOTE(review): this is a safe fn that calls the unsafe `check_rel_ptr`. Its soundness
        // presumably relies on `rel_ptr` living inside the buffer this context validates. Confirm.
        unsafe {
            let target = self.check_rel_ptr(rel_ptr.base(), rel_ptr.offset())?;
            let metadata = T::pointer_metadata(rel_ptr.metadata());
            let resolved = ptr_meta::from_raw_parts::<T>(target.cast(), metadata);
            self.claim_owned_ptr(resolved).map(|()| resolved)
        }
    }
}
/// Validation context that wraps an inner bounds context and records every byte range claimed
/// through `claim_bytes`, so overlapping claims can be rejected.
pub struct ArchiveValidator<C> {
    inner: C,
    intervals: Vec<Interval>,
}

impl<C> ArchiveValidator<C> {
    /// Wraps `inner` with no bytes claimed yet.
    #[inline]
    pub fn new(inner: C) -> Self {
        let intervals = Vec::new();
        Self { inner, intervals }
    }

    /// Unwraps the validator, discarding the claim records and returning the inner context.
    #[inline]
    pub fn into_inner(self) -> C {
        let Self { inner, .. } = self;
        inner
    }
}

/// Inner-context errors are wrapped in `ArchiveMemoryError::Inner`.
impl<C: Fallible> Fallible for ArchiveValidator<C> {
    type Error = ArchiveMemoryError<C::Error>;
}
/// Bounds checks are delegated unchanged to the inner context; its errors are wrapped in
/// `ArchiveMemoryError::Inner`.
impl<C: ArchiveBoundsContext> ArchiveBoundsContext for ArchiveValidator<C> {
    #[inline]
    unsafe fn check_rel_ptr(
        &mut self,
        base: *const u8,
        offset: isize,
    ) -> Result<*const u8, Self::Error> {
        match self.inner.check_rel_ptr(base, offset) {
            Ok(target) => Ok(target),
            Err(inner_err) => Err(ArchiveMemoryError::Inner(inner_err)),
        }
    }

    #[inline]
    unsafe fn bounds_check_ptr(
        &mut self,
        ptr: *const u8,
        layout: &Layout,
    ) -> Result<(), Self::Error> {
        match self.inner.bounds_check_ptr(ptr, layout) {
            Ok(()) => Ok(()),
            Err(inner_err) => Err(ArchiveMemoryError::Inner(inner_err)),
        }
    }
}
impl<C: ArchiveBoundsContext> ArchiveMemoryContext for ArchiveValidator<C> {
    /// Claims the byte range `[start, start + len)`, failing if it overlaps a range claimed
    /// earlier.
    ///
    /// `self.intervals` is kept sorted and pairwise disjoint. A new range that touches a
    /// neighbour is merged into it, so the list stays short.
    ///
    /// # Safety
    ///
    /// `start.add(len)` must stay within the allocation `start` belongs to.
    ///
    /// # Errors
    ///
    /// Returns `ArchiveMemoryError::ClaimOverlap` naming the recorded range the new range
    /// overlaps.
    unsafe fn claim_bytes(&mut self, start: *const u8, len: usize) -> Result<(), Self::Error> {
        let interval = Interval {
            start,
            end: start.add(len),
        };
        let index = match self.intervals.binary_search(&interval) {
            Ok(index) => {
                return Err(ArchiveMemoryError::ClaimOverlap {
                    previous: self.intervals[index],
                    current: interval,
                })
            }
            Err(index) => index,
        };

        // The list is sorted and disjoint, so only the two neighbours of the insertion point
        // can overlap the new range. Check BOTH before merging anything. Merging into the next
        // range first and returning early would miss an overlap with the previous range.
        if let Some(next) = self.intervals.get(index) {
            if next.overlaps(&interval) {
                return Err(ArchiveMemoryError::ClaimOverlap {
                    previous: *next,
                    current: interval,
                });
            }
        }
        if index > 0 {
            let prev = self.intervals[index - 1];
            if prev.overlaps(&interval) {
                return Err(ArchiveMemoryError::ClaimOverlap {
                    previous: prev,
                    current: interval,
                });
            }
        }

        let joins_next = self
            .intervals
            .get(index)
            .map_or(false, |next| next.start == interval.end);
        let joins_prev = index > 0 && self.intervals[index - 1].end == interval.start;
        match (joins_prev, joins_next) {
            (true, true) => {
                // The new range bridges the gap between both neighbours: fold all three.
                let next = self.intervals.remove(index);
                self.intervals[index - 1].end = next.end;
            }
            (true, false) => self.intervals[index - 1].end = interval.end,
            (false, true) => self.intervals[index].start = interval.start,
            (false, false) => self.intervals.insert(index, interval),
        }
        Ok(())
    }
}
/// Errors produced by the shared-pointer validator: either an error from the wrapped context,
/// or the same memory being claimed as two different types.
#[derive(Debug)]
pub enum SharedArchiveError<E> {
    /// Error from the inner context.
    Inner(E),
    /// A shared block first claimed as `previous` was claimed again as `current`.
    TypeMismatch {
        previous: TypeId,
        current: TypeId,
    },
}

impl<E: fmt::Display> fmt::Display for SharedArchiveError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Inner(inner) => fmt::Display::fmt(inner, f),
            Self::TypeMismatch {
                previous: first,
                current: second,
            } => write!(
                f,
                "the same memory region has been claimed as two different types ({:?} and {:?})",
                first, second
            ),
        }
    }
}

impl<E: Error + 'static> Error for SharedArchiveError<E> {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let Self::Inner(inner) = self {
            Some(inner as &dyn Error)
        } else {
            None
        }
    }
}
/// A validation context that tracks memory shared between several pointers.
pub trait SharedArchiveContext: Fallible {
    /// Claims `len` bytes at `start` as a shared block of the type identified by `type_id`.
    ///
    /// The returned flag reports whether this call newly claimed the block. In
    /// `SharedArchiveValidator`, `Ok(false)` means the block was already claimed with the same
    /// type.
    ///
    /// # Safety
    ///
    /// NOTE(review): the range is forwarded to `claim_bytes`, so it presumably must lie within
    /// the buffer being validated. Confirm.
    unsafe fn claim_shared_bytes(
        &mut self,
        start: *const u8,
        len: usize,
        type_id: TypeId,
    ) -> Result<bool, Self::Error>;

    /// Resolves `rel_ptr`, bounds-checks its target and claims it as shared memory of type
    /// `type_id`.
    ///
    /// Returns `Some(ptr)` when the target was newly claimed and `None` when
    /// `claim_shared_bytes` reported it as already claimed.
    fn claim_shared_ptr<T: ArchivePointee + CheckBytes<Self> + ?Sized>(
        &mut self,
        rel_ptr: &RelPtr<T>,
        type_id: TypeId,
    ) -> Result<Option<*const T>, Self::Error>
    where
        Self: ArchiveBoundsContext,
        <T as Pointee>::Metadata: LayoutMetadata<T>,
    {
        unsafe {
            let target = self.check_rel_ptr(rel_ptr.base(), rel_ptr.offset())?;
            let metadata = T::pointer_metadata(rel_ptr.metadata());
            let resolved = ptr_meta::from_raw_parts::<T>(target.cast(), metadata);
            let layout = LayoutMetadata::<T>::layout(metadata);
            let start = resolved.cast::<u8>();
            self.bounds_check_ptr(start, &layout)?;
            let newly_claimed = self.claim_shared_bytes(start, layout.size(), type_id)?;
            Ok(if newly_claimed { Some(resolved) } else { None })
        }
    }
}
/// Validation context that wraps an inner memory context and remembers shared blocks by start
/// address, together with the `TypeId` each was first claimed as.
pub struct SharedArchiveValidator<C> {
    inner: C,
    shared_blocks: HashMap<*const u8, TypeId>,
}

impl<C> SharedArchiveValidator<C> {
    /// Wraps `inner` with an empty shared-block table.
    #[inline]
    pub fn new(inner: C) -> Self {
        let shared_blocks = HashMap::new();
        Self {
            inner,
            shared_blocks,
        }
    }

    /// Unwraps the validator, discarding the shared-block table and returning the inner
    /// context.
    #[inline]
    pub fn into_inner(self) -> C {
        let Self { inner, .. } = self;
        inner
    }
}

/// Inner-context errors are wrapped in `SharedArchiveError::Inner`.
impl<C: Fallible> Fallible for SharedArchiveValidator<C> {
    type Error = SharedArchiveError<C::Error>;
}
/// Bounds checks are delegated to the inner context; its errors are wrapped in
/// `SharedArchiveError::Inner`.
impl<C: ArchiveBoundsContext> ArchiveBoundsContext for SharedArchiveValidator<C> {
    #[inline]
    unsafe fn check_rel_ptr(
        &mut self,
        base: *const u8,
        offset: isize,
    ) -> Result<*const u8, Self::Error> {
        match self.inner.check_rel_ptr(base, offset) {
            Ok(target) => Ok(target),
            Err(inner_err) => Err(SharedArchiveError::Inner(inner_err)),
        }
    }

    #[inline]
    unsafe fn bounds_check_ptr(
        &mut self,
        ptr: *const u8,
        layout: &Layout,
    ) -> Result<(), Self::Error> {
        match self.inner.bounds_check_ptr(ptr, layout) {
            Ok(()) => Ok(()),
            Err(inner_err) => Err(SharedArchiveError::Inner(inner_err)),
        }
    }
}

/// Owned-memory claims are delegated to the inner context; its errors are wrapped in
/// `SharedArchiveError::Inner`.
impl<C: ArchiveMemoryContext> ArchiveMemoryContext for SharedArchiveValidator<C> {
    #[inline]
    unsafe fn claim_bytes(&mut self, start: *const u8, len: usize) -> Result<(), Self::Error> {
        match self.inner.claim_bytes(start, len) {
            Ok(()) => Ok(()),
            Err(inner_err) => Err(SharedArchiveError::Inner(inner_err)),
        }
    }
}
impl<C: ArchiveMemoryContext> SharedArchiveContext for SharedArchiveValidator<C> {
    /// Claims `len` bytes at `start` as a shared block of type `type_id`.
    ///
    /// Returns `Ok(true)` the first time `start` is claimed, after the inner context has
    /// successfully claimed the bytes. Returns `Ok(false)` if `start` was already claimed with
    /// the same `type_id`.
    ///
    /// NOTE(review): blocks are keyed by start address only. A second claim with the same type
    /// but a larger `len` (e.g. a longer slice) returns `Ok(false)` without checking the extra
    /// bytes. Consider also recording `len`.
    ///
    /// # Safety
    ///
    /// The range is forwarded to the inner `claim_bytes`; its safety requirements apply.
    ///
    /// # Errors
    ///
    /// - `SharedArchiveError::TypeMismatch` if `start` was previously claimed with a different
    ///   type.
    /// - `SharedArchiveError::Inner` if the inner context rejects the claim.
    unsafe fn claim_shared_bytes(
        &mut self,
        start: *const u8,
        len: usize,
        type_id: TypeId,
    ) -> Result<bool, Self::Error> {
        if let Some(&previous) = self.shared_blocks.get(&start) {
            return if previous == type_id {
                Ok(false)
            } else {
                Err(SharedArchiveError::TypeMismatch {
                    previous,
                    current: type_id,
                })
            };
        }

        // Claim first and record second. If the inner claim fails, the block must not be
        // remembered; otherwise a later claim would see it as already validated.
        self.inner
            .claim_bytes(start, len)
            .map_err(SharedArchiveError::Inner)?;
        self.shared_blocks.insert(start, type_id);
        Ok(true)
    }
}
/// Error returned by the archive checking functions: either the checked type's own
/// `CheckBytes` error (`T`) or an error from the validation context (`C`).
#[derive(Debug)]
pub enum CheckArchiveError<T, C> {
    /// The bytes failed the type's `CheckBytes` validation.
    CheckBytesError(T),
    /// The validation context rejected a pointer or memory claim.
    ContextError(C),
}

impl<T: fmt::Display, C: fmt::Display> fmt::Display for CheckArchiveError<T, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CheckBytesError(check_err) => write!(f, "check bytes error: {}", check_err),
            Self::ContextError(context_err) => write!(f, "context error: {}", context_err),
        }
    }
}

impl<T: Error + 'static, C: Error + 'static> Error for CheckArchiveError<T, C> {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let underlying: &(dyn Error + 'static) = match self {
            Self::CheckBytesError(check_err) => check_err,
            Self::ContextError(context_err) => context_err,
        };
        Some(underlying)
    }
}
/// The validator stack built by `check_archived_value`: shared-pointer tracking, wrapping
/// overlap tracking, wrapping buffer bounds checking.
pub type DefaultArchiveValidator = SharedArchiveValidator<ArchiveValidator<ArchiveBoundsValidator>>;
/// Error type for checking a `T` with context `C`: either `T`'s `CheckBytes<C>` error or `C`'s
/// `Fallible` error.
pub type CheckTypeError<T, C> =
    CheckArchiveError<<T as CheckBytes<C>>::Error, <C as Fallible>::Error>;
/// Validates the archived `T` located at byte position `pos` of `buf` using a fresh
/// `DefaultArchiveValidator`, returning a reference to it on success.
///
/// # Errors
///
/// Returns a `CheckArchiveError` if the bytes fail validation or the validator rejects a
/// pointer or memory claim.
#[inline]
pub fn check_archived_value<T: Archive>(
    buf: &[u8],
    pos: usize,
) -> Result<&T::Archived, CheckTypeError<T::Archived, DefaultArchiveValidator>>
where
    T::Archived: CheckBytes<DefaultArchiveValidator>,
{
    let bounds = ArchiveBoundsValidator::new(buf);
    let memory = ArchiveValidator::new(bounds);
    let mut validator: DefaultArchiveValidator = SharedArchiveValidator::new(memory);
    check_archived_value_with_context::<T, DefaultArchiveValidator>(buf, pos, &mut validator)
}
/// Validates the root object of `buf` with a fresh `DefaultArchiveValidator`. The root is
/// stored in the last `size_of::<T::Archived>()` bytes of the buffer.
///
/// # Errors
///
/// Returns a `CheckArchiveError` if validation fails. A buffer too short to hold the root
/// yields a context error instead of a subtraction-overflow panic.
#[inline]
pub fn check_archived_root<T: Archive>(
    buf: &[u8],
) -> Result<&T::Archived, CheckTypeError<T::Archived, DefaultArchiveValidator>>
where
    T::Archived: CheckBytes<DefaultArchiveValidator>,
{
    // `saturating_sub` avoids underflow on short (possibly untrusted) buffers. Position 0 is
    // then rejected by the validator's bounds check because the root does not fit.
    let root_pos = buf
        .len()
        .saturating_sub(core::mem::size_of::<T::Archived>());
    check_archived_value::<T>(buf, root_pos)
}
/// Validates the archived `T` at byte position `pos` of `buf` using the caller's `context`.
///
/// In order, the value's location is:
/// 1. resolved relative to the start of `buf`,
/// 2. bounds-checked,
/// 3. claimed as owned memory, and
/// 4. checked with `CheckBytes`.
///
/// # Errors
///
/// Context failures map to `CheckArchiveError::ContextError`. Failures of the type's own
/// validation map to `CheckArchiveError::CheckBytesError`.
#[inline]
pub fn check_archived_value_with_context<
    'a,
    T: Archive,
    C: ArchiveBoundsContext + ArchiveMemoryContext + ?Sized,
>(
    buf: &'a [u8],
    pos: usize,
    context: &mut C,
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, C>>
where
    T::Archived: CheckBytes<C> + Pointee<Metadata = ()>,
{
    unsafe {
        let target = context
            .check_rel_ptr(buf.as_ptr(), pos as isize)
            .map_err(CheckArchiveError::ContextError)?;
        let root = ptr_meta::from_raw_parts::<<T as Archive>::Archived>(target.cast(), ());
        let layout = LayoutMetadata::<T::Archived>::layout(());
        let start = root.cast::<u8>();
        context
            .bounds_check_ptr(start, &layout)
            .map_err(CheckArchiveError::ContextError)?;
        context
            .claim_bytes(start, layout.size())
            .map_err(CheckArchiveError::ContextError)?;
        Archived::<T>::check_bytes(root, context).map_err(CheckArchiveError::CheckBytesError)
    }
}
/// Validates the root object of `buf` using the caller's `context`. The root is stored in the
/// last `size_of::<T::Archived>()` bytes of the buffer.
///
/// # Errors
///
/// Returns a `CheckArchiveError` if validation fails. A buffer too short to hold the root no
/// longer panics on subtraction overflow. Position 0 is checked instead, which a correct
/// bounds context rejects because the root does not fit.
#[inline]
pub fn check_archived_root_with_context<
    'a,
    T: Archive,
    C: ArchiveBoundsContext + ArchiveMemoryContext + ?Sized,
>(
    buf: &'a [u8],
    context: &mut C,
) -> Result<&'a T::Archived, CheckTypeError<T::Archived, C>>
where
    T::Archived: CheckBytes<C> + Pointee<Metadata = ()>,
{
    // `saturating_sub` keeps short (possibly untrusted) buffers from underflowing.
    let root_pos = buf
        .len()
        .saturating_sub(core::mem::size_of::<T::Archived>());
    check_archived_value_with_context::<T, C>(buf, root_pos, context)
}