use std::{
collections::{btree_map as map, BTreeMap},
fmt,
};
use crate::{AllocZeroed, Event, Pointer, Request, Violation};
/// A contiguous span of memory described by a start pointer, a size and an
/// alignment.
///
/// The derived `PartialOrd`/`Ord` compare fields in declaration order:
/// `ptr` first, then `size`, then `align`.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate must build
/// it through [`Region::new`] rather than a struct literal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[non_exhaustive]
pub struct Region {
/// First address covered by the region.
pub ptr: Pointer,
/// Extent of the region. The region spans `ptr` up to, but not including,
/// `ptr.saturating_add(size)`.
pub size: usize,
/// Requested alignment. `Machine` checks that `ptr` honors it when the
/// region is allocated, and that a free repeats the same value.
pub align: usize,
}
impl Region {
pub fn new(ptr: Pointer, size: usize, align: usize) -> Self {
Self { ptr, size, align }
}
pub fn overlaps(self, other: Self) -> bool {
self.ptr <= other.ptr && other.ptr < self.ptr.saturating_add(self.size)
}
pub fn is_same_region_as(self, other: Self) -> bool {
self.ptr == other.ptr && self.size == other.size
}
}
/// Renders a region as `start-end (size: N, align: M)`, where `end` is the
/// saturating `start + size`.
impl fmt::Display for Region {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let end = self.ptr.saturating_add(self.size);
        write!(
            f,
            "{}-{} (size: {}, align: {})",
            self.ptr, end, self.size, self.align,
        )
    }
}
/// Replays allocator [`Event`]s and reports the first [`Violation`] each one
/// causes.
///
/// A default-constructed machine tracks no allocations and has
/// `memory_used == 0`.
#[derive(Default)]
pub struct Machine {
/// Live allocations, keyed by the start pointer of their region.
regions: BTreeMap<Pointer, Request>,
/// Sum of the sizes of the tracked allocations. It is increased on allocation
/// and decreased on free, using saturating arithmetic.
pub memory_used: usize,
}
impl Machine {
    /// Feeds one allocator [`Event`] into the machine.
    ///
    /// Allocation, zeroed-allocation, free and reallocation events update the
    /// tracked regions and `memory_used`. The three `*Failed` events are
    /// accepted without changing anything.
    ///
    /// # Errors
    ///
    /// * `NonZeroedAlloc` if a zeroed allocation reports `is_zeroed` as
    ///   `Some(false)`. The region is then *not* tracked.
    /// * `NonCopiedRealloc` if a reallocation reports `is_relocated` as
    ///   `Some(false)`.
    /// * `ReallocNull` for every `Event::ReallocNull`.
    /// * Any violation raised by the underlying allocate/free checks. For a
    ///   reallocation, the old region is freed before the new one is
    ///   allocated, so a failing allocation leaves the old region already
    ///   released.
    pub fn push(&mut self, event: &Event) -> Result<(), Violation> {
        match event {
            Event::Alloc(request) => self.alloc(request),
            Event::Free(request) => self.free(request),
            Event::AllocZeroed(AllocZeroed { is_zeroed, request }) => {
                if matches!(is_zeroed, Some(false)) {
                    Err(Violation::NonZeroedAlloc {
                        alloc: request.clone(),
                    })
                } else {
                    self.alloc(request)
                }
            }
            Event::Realloc(realloc) => {
                if matches!(realloc.is_relocated, Some(false)) {
                    return Err(Violation::NonCopiedRealloc {
                        realloc: realloc.clone(),
                    });
                }
                // Model the reallocation as "release the old block, then claim
                // the new one" so both halves get the regular checks.
                self.free(&realloc.free())?;
                self.alloc(&realloc.alloc())
            }
            Event::ReallocNull(realloc) => Err(Violation::ReallocNull {
                realloc: realloc.clone(),
            }),
            Event::AllocFailed | Event::AllocZeroedFailed | Event::ReallocFailed => Ok(()),
        }
    }

    /// Starts tracking `request` as a live allocation.
    ///
    /// # Errors
    ///
    /// * `MisalignedAlloc` if the pointer does not honor the requested
    ///   alignment.
    /// * `ConflictingAlloc` if the region collides with an allocation that is
    ///   already tracked. The first collision found is reported.
    fn alloc(&mut self, request: &Request) -> Result<(), Violation> {
        let region = request.region;

        if !region.ptr.is_aligned_with(region.align) {
            return Err(Violation::MisalignedAlloc {
                alloc: request.clone(),
            });
        }

        // Bind the lookup result first, so the iterator's borrow of
        // `self.regions` ends before the map is mutated below.
        let conflict = find_region_overlaps(&self.regions, region).next();

        match conflict {
            Some(existing) => Err(Violation::ConflictingAlloc {
                request: request.clone(),
                existing,
            }),
            None => {
                self.memory_used = self.memory_used.saturating_add(region.size);
                let replaced = self.regions.insert(region.ptr, request.clone());
                // The collision check above is meant to rule out an entry that
                // already starts at this pointer.
                debug_assert!(replaced.is_none());
                Ok(())
            }
        }
    }

    /// Stops tracking the allocation described by `request`.
    ///
    /// # Errors
    ///
    /// * `MissingFree` if no tracked allocation starts at `request.region.ptr`.
    /// * `IncompleteFree` if one does, but its pointer/size pair differs from
    ///   the request.
    /// * `MisalignedFree` if pointer and size match but the alignment differs.
    ///
    /// On error, the tracked allocation (if any) is left in place.
    fn free(&mut self, request: &Request) -> Result<(), Violation> {
        let entry = match self.regions.entry(request.region.ptr) {
            map::Entry::Occupied(entry) => entry,
            map::Entry::Vacant(_) => {
                return Err(Violation::MissingFree {
                    request: request.clone(),
                });
            }
        };

        let tracked = entry.get();

        if !tracked.region.is_same_region_as(request.region) {
            return Err(Violation::IncompleteFree {
                request: request.clone(),
                existing: tracked.clone(),
            });
        }

        if tracked.region.align != request.region.align {
            return Err(Violation::MisalignedFree {
                request: request.clone(),
                existing: tracked.clone(),
            });
        }

        let released = entry.remove();
        self.memory_used = self.memory_used.saturating_sub(released.region.size);
        Ok(())
    }

    /// Returns a copy of every allocation that is still tracked, in ascending
    /// start-pointer order.
    pub fn trailing_regions(&self) -> Vec<Request> {
        self.regions
            .iter()
            .map(|(_, request)| request.clone())
            .collect()
    }
}
/// Finds the tracked allocations whose memory collides with `needle`.
///
/// `regions` is keyed by each allocation's start pointer. `Machine::alloc`
/// only inserts a region after this search finds no collision, so tracked
/// regions never overlap one another. A collision can therefore only come
/// from one of three places:
///
/// * the nearest region starting *below* `needle.ptr`, if it extends past
///   `needle.ptr`;
/// * a region starting *exactly* at `needle.ptr`. This is always reported,
///   whatever the sizes, because the map can hold only one entry per start
///   pointer;
/// * regions starting *above* `needle.ptr` but before the needle's end.
///
/// Each colliding region is yielded once: the one below first (if any), then
/// the rest in ascending pointer order.
fn find_region_overlaps(
    regions: &BTreeMap<Pointer, Request>,
    needle: Region,
) -> impl Iterator<Item = Request> + '_ {
    // Walk *downwards* from just below the needle. If the nearest region does
    // not reach `needle.ptr`, no farther one can either: it would have to
    // contain the nearer region's start, which the no-overlap invariant forbids.
    let head = regions
        .range(..needle.ptr)
        .rev()
        .take_while(move |(_, r)| r.region.overlaps(needle));
    // Walk upwards from the needle and keep regions that start inside it.
    // An entry at exactly `needle.ptr` always counts as a collision.
    let tail = regions
        .range(needle.ptr..)
        .take_while(move |(_, r)| r.region.ptr == needle.ptr || needle.overlaps(r.region));
    head.chain(tail).map(|(_, r)| r.clone())
}