#![cfg(all(unix, target_pointer_width = "64"))]
#![deny(missing_docs)]
#![deny(clippy::pedantic)]
use std::{
collections::{BTreeMap, btree_map::Entry},
fs::OpenOptions,
mem::{self, MaybeUninit},
ops::{Deref, DerefMut},
os::fd::AsRawFd,
path::PathBuf,
sync::{Once, OnceLock, RwLock},
{io, process, ptr, slice},
};
use libc::{
MAP_ANONYMOUS, MAP_FAILED, MAP_FIXED, MAP_NORESERVE, MAP_PRIVATE,
PROT_NONE, PROT_READ, PROT_WRITE, SA_SIGINFO, c_int, sigaction,
sigemptyset, siginfo_t, sigset_t, ucontext_t,
};
pub struct Mmap(&'static mut MmapInner);
impl Mmap {
pub fn new(page_number: usize, page_size: usize) -> io::Result<Self> {
unsafe { Self::with_files(page_number, page_size, |_| None) }
}
pub unsafe fn with_files<FL>(
page_number: usize,
page_size: usize,
file_locator: FL,
) -> io::Result<Self>
where
FL: 'static + LocateFile,
{
unsafe {
let inner = MmapInner::new(page_number, page_size, file_locator)?;
with_global_map_mut(|global_map| {
let inner = Box::leak(Box::new(inner));
let start_addr = inner.bytes.as_mut_ptr() as usize;
let end_addr = start_addr + inner.bytes.len();
let inner_ptr = inner as *mut _;
global_map.insert(start_addr..end_addr, inner_ptr as _);
Ok(Self(inner))
})
}
}
pub fn snap(&mut self) -> io::Result<()> {
unsafe { self.0.snap() }
}
pub fn revert(&mut self) -> io::Result<()> {
unsafe { self.0.revert() }
}
pub fn apply(&mut self) -> io::Result<()> {
unsafe { self.0.apply() }
}
pub fn dirty_pages(&self) -> impl Iterator<Item = (&[u8], &[u8], &usize)> {
self.0.last_snapshot().clean_pages.iter().map(
move |(page_index, clean_page)| {
let page_size = self.0.page_size;
let offset = page_index * page_size;
(
&self.0.bytes[offset..][..page_size],
&clean_page[..],
page_index,
)
},
)
}
}
impl AsRef<[u8]> for Mmap {
    /// Borrows the whole region as bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0.bytes[..]
    }
}

impl AsMut<[u8]> for Mmap {
    /// Mutably borrows the whole region as bytes.
    fn as_mut(&mut self) -> &mut [u8] {
        &mut self.0.bytes[..]
    }
}

impl Deref for Mmap {
    type Target = [u8];

    /// Derefs to the whole region, so an `Mmap` can be indexed like a slice.
    fn deref(&self) -> &Self::Target {
        &self.0.bytes[..]
    }
}

impl DerefMut for Mmap {
    /// Mutable counterpart of `deref`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0.bytes[..]
    }
}

impl Drop for Mmap {
    /// Unregisters the region from the global range map and frees its state,
    /// all while holding the map's write lock.
    fn drop(&mut self) {
        let inner_ptr: *mut MmapInner = &mut *self.0;
        with_global_map_mut(|global_map| {
            // SAFETY: `inner_ptr` came from `Box::leak` in `with_files` and is
            // reclaimed exactly once, here; `self.0` is not used afterwards.
            let inner = unsafe { Box::from_raw(inner_ptr) };
            let start_addr = inner.bytes.as_ptr() as usize;
            let end_addr = start_addr + inner.bytes.len();
            global_map.remove(start_addr..end_addr);
            drop(inner);
        });
    }
}
/// Maps each live region's address range to the address of its leaked
/// `MmapInner`.
type InnerMap = rangemap::RangeMap<usize, usize>;

/// Process-wide registry consulted by the fault handler; created on first use.
static INNER_MAP: OnceLock<RwLock<InnerMap>> = OnceLock::new();

/// Returns the registry's lock, creating an empty registry on first use.
fn inner_map_lock() -> &'static RwLock<InnerMap> {
    INNER_MAP.get_or_init(|| RwLock::new(InnerMap::new()))
}

/// Runs `closure` with shared access to the registry, holding the read lock
/// for the duration of the call.
///
/// # Panics
///
/// Panics if the lock is poisoned.
fn with_global_map<T, F>(closure: F) -> T
where
    F: FnOnce(&InnerMap) -> T,
{
    let guard = inner_map_lock().read().unwrap();
    closure(&guard)
}

/// Runs `closure` with exclusive access to the registry, holding the write
/// lock for the duration of the call.
///
/// # Panics
///
/// Panics if the lock is poisoned.
fn with_global_map_mut<T, F>(closure: F) -> T
where
    F: FnOnce(&mut InnerMap) -> T,
{
    let mut guard = inner_map_lock().write().unwrap();
    closure(&mut guard)
}
/// Returns the operating system's page size, queried once and then cached.
// The casts are allowed because `sysconf(_SC_PAGESIZE)` yields a small
// positive value on supported systems.
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
fn system_page_size() -> usize {
    static PAGE_SIZE: OnceLock<usize> = OnceLock::new();
    *PAGE_SIZE.get_or_init(|| {
        // SAFETY: `sysconf` has no preconditions.
        let raw = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        raw as usize
    })
}
/// Bookkeeping for one level of the snapshot stack.
struct Snapshot {
    /// Contents each page had before it was first written in this snapshot
    /// (after `apply`, the oldest such copy among the merged snapshots),
    /// keyed by page index.
    clean_pages: BTreeMap<usize, Vec<u8>>,
    /// One bit per page, set once the page has faulted in this snapshot.
    hit_pages: PageBits,
}

impl Snapshot {
    /// Creates an empty snapshot able to track `page_number` pages.
    ///
    /// # Errors
    ///
    /// Propagates a failure to allocate the hit-page bitmap.
    fn new(page_number: usize) -> io::Result<Self> {
        let hit_pages = PageBits::new(page_number)?;
        Ok(Self {
            clean_pages: BTreeMap::new(),
            hit_pages,
        })
    }
}
/// A bitmap with one bit per page, stored in its own anonymous mapping.
struct PageBits(&'static mut [u8]);

impl PageBits {
    /// Allocates a zero-filled bitmap with room for `page_number` bits.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `mmap` fails (it rejects zero-length requests,
    /// so `page_number == 0` is an error).
    fn new(page_number: usize) -> io::Result<Self> {
        let len = page_number.div_ceil(8);
        // SAFETY: requests a fresh private anonymous mapping; nothing existing
        // is affected.
        let addr = unsafe {
            libc::mmap(
                ptr::null_mut(),
                len,
                PROT_READ | PROT_WRITE,
                MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE,
                -1,
                0,
            )
        };
        if addr == MAP_FAILED {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: on success `addr` points to `len` readable and writable
        // bytes that this value owns until `Drop` unmaps them.
        let bits = unsafe { slice::from_raw_parts_mut(addr.cast::<u8>(), len) };
        Ok(Self(bits))
    }

    /// Runs `closure` with whether bit `page_index` was already set. When the
    /// bit was clear and `closure` succeeds, the bit is set; an error leaves it
    /// clear. Returns `closure`'s result.
    fn set_and_exec<T, E, F>(
        &mut self,
        page_index: usize,
        closure: F,
    ) -> Result<T, E>
    where
        F: FnOnce(bool) -> Result<T, E>,
    {
        let mask = 1u8 << (page_index % 8);
        let byte = &mut self.0[page_index / 8];
        if (*byte & mask) != 0 {
            return closure(true);
        }
        let outcome = closure(false);
        if outcome.is_ok() {
            *byte |= mask;
        }
        outcome
    }

    /// Returns whether bit `page_index` is set.
    pub fn is_page_hit(&self, page_index: usize) -> bool {
        ((self.0[page_index / 8] >> (page_index % 8)) & 1) == 1
    }
}

impl Drop for PageBits {
    /// Unmaps the bitmap's storage.
    fn drop(&mut self) {
        let len = self.0.len();
        let addr = self.0.as_mut_ptr();
        // SAFETY: `addr`/`len` describe the mapping created in `new`, which is
        // not used after this point.
        unsafe {
            libc::munmap(addr.cast(), len);
        }
    }
}
/// Chooses the file, if any, whose contents back a given page of a region.
///
/// The locator is consulted when a page is first touched. A page for which it
/// returns `None` is plain anonymous memory.
pub trait LocateFile: Send + Sync {
    /// Returns the path of the file to map at page `page_index`, or `None` to
    /// leave that page unbacked.
    fn locate_file(&mut self, page_index: usize) -> Option<PathBuf>;
}

/// Every `Send + Sync` closure `FnMut(usize) -> Option<PathBuf>` is a locator.
impl<F> LocateFile for F
where
    F: FnMut(usize) -> Option<PathBuf> + Send + Sync,
{
    fn locate_file(&mut self, page_index: usize) -> Option<PathBuf> {
        (self)(page_index)
    }
}
struct MmapInner {
bytes: &'static mut [u8],
page_size: usize,
page_number: usize,
mapped_pages: PageBits,
snapshots: Vec<Snapshot>,
file_locator: Box<dyn LocateFile>,
}
impl MmapInner {
unsafe fn new<FL>(
page_number: usize,
page_size: usize,
file_locator: FL,
) -> io::Result<Self>
where
FL: 'static + LocateFile,
{
unsafe {
setup_action();
let system_page_size = system_page_size();
if page_size % system_page_size != 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"Page size {page_size} must be a multiple \
of the system page size {system_page_size}"
),
));
}
let mapped_pages = PageBits::new(page_number)?;
let snapshot = Snapshot::new(page_number)?;
let bytes = {
let len = page_number * page_size;
let ptr = libc::mmap(
ptr::null_mut(),
len,
PROT_NONE,
MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE,
-1,
0,
);
if ptr == MAP_FAILED {
return Err(io::Error::last_os_error());
}
slice::from_raw_parts_mut(ptr.cast(), len)
};
Ok(Self {
bytes,
page_size,
page_number,
mapped_pages,
snapshots: vec![snapshot],
file_locator: Box::new(file_locator),
})
}
}
unsafe fn process_segv(&mut self, si_addr: usize) -> io::Result<()> {
let start_addr = self.bytes.as_mut_ptr() as usize;
let page_size = self.page_size;
let page_index = (si_addr - start_addr) / page_size;
let start_addr = self.bytes.as_ptr() as usize;
let page_offset = page_index * self.page_size;
let page_addr = start_addr + page_offset;
let page_size = self.page_size;
self.mapped_pages.set_and_exec(
page_index,
|is_bit_set| -> io::Result<()> {
if is_bit_set {
return Ok(());
}
if let Some(path) = self.file_locator.locate_file(page_index) {
let file =
OpenOptions::new().read(true).write(true).open(path)?;
let ptr = unsafe {
libc::mmap(
page_addr as _,
page_size,
PROT_NONE,
MAP_PRIVATE | MAP_FIXED | MAP_NORESERVE,
file.as_raw_fd(),
0,
)
};
if ptr == MAP_FAILED {
return Err(io::Error::last_os_error());
}
}
Ok(())
},
)?;
let snapshot = self
.snapshots
.last_mut()
.expect("There should always be at least one snapshot");
snapshot.hit_pages.set_and_exec(page_index, |is_bit_set| {
let mut prot = PROT_READ;
if is_bit_set {
prot |= PROT_WRITE;
if let Entry::Vacant(e) = snapshot.clean_pages.entry(page_index)
{
if unsafe {
libc::mprotect(page_addr as _, page_size, PROT_READ)
} != 0
{
return Err(io::Error::last_os_error());
}
let mut clean_page = vec![0; page_size];
clean_page.copy_from_slice(
&self.bytes[page_offset..][..page_size],
);
e.insert(clean_page);
}
}
if unsafe { libc::mprotect(page_addr as _, page_size, prot) } != 0 {
return Err(io::Error::last_os_error());
}
Ok(())
})?;
Ok(())
}
unsafe fn snap(&mut self) -> io::Result<()> {
unsafe {
let len = self.bytes.len();
if libc::mprotect(self.bytes.as_mut_ptr().cast(), len, PROT_NONE)
!= 0
{
return Err(io::Error::last_os_error());
}
self.snapshots.push(Snapshot::new(self.page_number)?);
Ok(())
}
}
unsafe fn apply(&mut self) -> io::Result<()> {
unsafe {
let len = self.bytes.len();
if libc::mprotect(self.bytes.as_mut_ptr().cast(), len, PROT_NONE)
!= 0
{
return Err(io::Error::last_os_error());
}
let popped_snapshot = self
.snapshots
.pop()
.expect("There should always be at least one snapshot");
if self.snapshots.is_empty() {
self.snapshots.push(Snapshot::new(self.page_number)?);
}
let snapshot = self.last_snapshot_mut();
for (page_index, clean_page) in popped_snapshot.clean_pages {
snapshot.clean_pages.entry(page_index).or_insert(clean_page);
}
Ok(())
}
}
unsafe fn revert(&mut self) -> io::Result<()> {
unsafe {
let popped_snapshot = self
.snapshots
.pop()
.expect("There should always be at least one snapshot");
if self.snapshots.is_empty() {
self.snapshots.push(Snapshot::new(self.page_number)?);
} else {
}
let page_size = self.page_size;
for (page_index, clean_page) in popped_snapshot.clean_pages {
let page_offset = page_index * page_size;
let start_addr = self.bytes.as_mut_ptr() as usize;
let page_addr = start_addr + page_offset;
if libc::mprotect(page_addr as _, page_size, PROT_WRITE) != 0 {
return Err(io::Error::last_os_error());
}
self.bytes[page_offset..][..page_size]
.copy_from_slice(&clean_page[..]);
let prot = match &self
.last_snapshot()
.hit_pages
.is_page_hit(page_index)
{
false => PROT_NONE,
true => PROT_READ,
};
if libc::mprotect(page_addr as _, page_size, prot) != 0 {
return Err(io::Error::last_os_error());
}
}
Ok(())
}
}
fn last_snapshot(&self) -> &Snapshot {
self.snapshots
.last()
.expect("There should always be at least one snapshot")
}
fn last_snapshot_mut(&mut self) -> &mut Snapshot {
self.snapshots
.last_mut()
.expect("There should always be at least one snapshot")
}
}
impl Drop for MmapInner {
fn drop(&mut self) {
unsafe {
let ptr = self.bytes.as_mut_ptr();
let len = self.bytes.len();
libc::munmap(ptr.cast(), len);
}
}
}
static SIGNAL_HANDLER: Once = Once::new();
unsafe fn setup_action() -> sigaction {
static OLD_ACTION: OnceLock<sigaction> = OnceLock::new();
SIGNAL_HANDLER.call_once(|| {
unsafe {
let mut sa_mask = MaybeUninit::<sigset_t>::uninit();
sigemptyset(sa_mask.as_mut_ptr());
let act = sigaction {
sa_sigaction: segfault_handler as _,
sa_mask: sa_mask.assume_init(),
sa_flags: SA_SIGINFO,
#[cfg(target_os = "linux")]
sa_restorer: None,
};
let mut old_act = MaybeUninit::<sigaction>::uninit();
if libc::sigaction(libc::SIGSEGV, &act, old_act.as_mut_ptr()) != 0 {
process::exit(1);
}
#[cfg(target_os = "macos")]
if libc::sigaction(libc::SIGBUS, &act, old_act.as_mut_ptr()) != 0 {
process::exit(2);
}
OLD_ACTION.get_or_init(move || old_act.assume_init());
}
});
*OLD_ACTION.get().unwrap()
}
unsafe fn call_old_action(
sig: c_int,
info: *mut siginfo_t,
ctx: *mut ucontext_t,
) {
unsafe {
let old_act = setup_action();
if old_act.sa_flags & SA_SIGINFO == 0 {
let act: fn(c_int) = mem::transmute(old_act.sa_sigaction);
act(sig);
} else {
let act: fn(c_int, *mut siginfo_t, *mut ucontext_t) =
mem::transmute(old_act.sa_sigaction);
act(sig, info, ctx);
}
}
}
/// Process-wide fault handler installed by `setup_action`.
///
/// If the faulting address lies inside a registered region, that region's
/// `MmapInner::process_segv` resolves the fault. Faults outside every region,
/// and faults it fails to resolve, are forwarded to the previously installed
/// action.
///
/// Declared `extern "C"` because the kernel calls it with the C calling
/// convention; a Rust-ABI function is not guaranteed to be callable that way.
/// A panic that unwinds out of an `extern "C"` function aborts the process.
///
/// # Safety
///
/// Only to be invoked by the kernel as an `SA_SIGINFO` signal handler.
unsafe extern "C" fn segfault_handler(
    sig: c_int,
    info: *mut siginfo_t,
    ctx: *mut ucontext_t,
) {
    with_global_map(move |global_map| {
        // SAFETY: `info` is supplied by the kernel for an SA_SIGINFO handler;
        // map values are addresses of leaked `MmapInner`s that stay alive
        // while registered (removal happens under the write lock).
        unsafe {
            let si_addr = (*info).si_addr() as usize;
            let Some(&inner_addr) = global_map.get(&si_addr) else {
                call_old_action(sig, info, ctx);
                return;
            };
            let inner = &mut *(inner_addr as *mut MmapInner);
            if inner.process_segv(si_addr).is_err() {
                call_old_action(sig, info, ctx);
            }
        }
    });
}
#[cfg(test)]
mod tests {
use super::*;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::thread;
// Number of pages in the regions most tests create (4 GiB of address space
// at PAGE_SIZE, reserved with MAP_NORESERVE).
const N_PAGES: usize = 65536;
// Page size used by the tests (64 KiB).
const PAGE_SIZE: usize = 65536;
// Two pages' worth of filler bytes (0x2a).
const DIRT: [u8; 2 * PAGE_SIZE] = [42; 2 * PAGE_SIZE];
// A second, distinguishable two-page filler (0x2b).
const DIRT2: [u8; 2 * PAGE_SIZE] = [43; 2 * PAGE_SIZE];
// Starts half-way into page 1, so a DIRT-sized write touches pages 1, 2 and 3.
const OFFSET: usize = PAGE_SIZE / 2 + PAGE_SIZE;
/// A single write spanning pages 1-3 is readable back and marks exactly those
/// three pages dirty.
#[test]
fn write() {
    let mut mem = Mmap::new(N_PAGES, PAGE_SIZE)
        .expect("Instantiating new memory should succeed");
    let slice = &mut mem[OFFSET..][..DIRT.len()];
    slice.copy_from_slice(&DIRT);
    assert_eq!(slice, DIRT, "Slice should be dirt just written");
    // OFFSET is 1.5 pages and DIRT is 2 pages long: pages 1, 2 and 3.
    assert_eq!(mem.dirty_pages().count(), 3);
}
/// Runs the `write` scenario concurrently on several threads, each with its
/// own region, to exercise the shared fault handler and region registry.
#[test]
fn write_multi_thread() {
    const NUM_THREADS: usize = 8;
    // `collect` spawns every thread before any of them is joined.
    let handles: Vec<_> = (0..NUM_THREADS)
        .map(|_| {
            thread::spawn(|| {
                let mut mem = Mmap::new(N_PAGES, PAGE_SIZE)
                    .expect("Instantiating new memory should succeed");
                let slice = &mut mem[OFFSET..][..DIRT.len()];
                slice.copy_from_slice(&DIRT);
                assert_eq!(slice, DIRT, "Slice should be dirt just written");
                assert_eq!(mem.dirty_pages().count(), 3);
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("Thread should exit cleanly");
    }
}
/// Reverting a snapshot restores what was written before it, and the base
/// snapshot's dirty pages are visible again afterwards.
#[test]
fn revert() {
    let mut mem = Mmap::new(N_PAGES, PAGE_SIZE)
        .expect("Instantiating new memory should succeed");
    let slice = &mut mem[OFFSET..][..DIRT.len()];
    slice.copy_from_slice(&DIRT);
    mem.snap().expect("Snapshotting should succeed");
    // A fresh snapshot starts with no dirty pages.
    assert_eq!(mem.dirty_pages().count(), 0);
    let slice = &mem[OFFSET..][..DIRT.len()];
    assert_eq!(slice, DIRT, "Slice should be dirt just written");
    let slice = &mut mem[OFFSET..][..DIRT.len()];
    slice.copy_from_slice(&[0; 2 * PAGE_SIZE]);
    mem.revert().expect("Reverting should succeed");
    // Back in the base snapshot, where the first write dirtied three pages.
    assert_eq!(mem.dirty_pages().count(), 3);
    let slice = &mut mem[OFFSET..][..DIRT.len()];
    assert_eq!(
        slice, DIRT,
        "Slice should be the dirt that was written before"
    );
}
/// Two stacked snapshots: reverting an untouched snapshot changes nothing;
/// reverting the one underneath restores the first write.
#[test]
fn multi_revert() {
    let mut mem = Mmap::new(N_PAGES, PAGE_SIZE)
        .expect("Instantiating new memory should succeed");
    let slice = &mut mem[OFFSET..][..DIRT.len()];
    slice.copy_from_slice(&DIRT);
    mem.snap().expect("Snapshotting should succeed");
    assert_eq!(mem.dirty_pages().count(), 0);
    let slice = &mem[OFFSET..][..DIRT.len()];
    assert_eq!(slice, DIRT, "Slice should be dirt just written");
    let slice = &mut mem[OFFSET..][..DIRT2.len()];
    slice.copy_from_slice(&DIRT2);
    mem.snap().expect("Snapshotting should succeed");
    assert_eq!(mem.dirty_pages().count(), 0);
    let slice = &mem[OFFSET..][..DIRT2.len()];
    assert_eq!(slice, DIRT2, "Slice should be dirt just written");
    // The top snapshot saw no writes, so reverting it keeps DIRT2; the
    // snapshot below (which holds the DIRT2 write) becomes current.
    mem.revert().expect("Reverting should succeed");
    assert_eq!(mem.dirty_pages().count(), 3);
    let slice = &mem[OFFSET..][..DIRT2.len()];
    assert_eq!(slice, DIRT2, "Slice should be dirt written second");
    // Reverting that snapshot restores DIRT; the base snapshot's three dirty
    // pages become current.
    mem.revert().expect("Reverting should succeed");
    assert_eq!(mem.dirty_pages().count(), 3);
    let slice = &mem[OFFSET..][..DIRT.len()];
    assert_eq!(slice, DIRT, "Slice should be dirt written first");
}
/// Applying a snapshot keeps its writes, and its dirty pages are merged into
/// the base snapshot.
#[test]
fn apply() {
    let mut mem = Mmap::new(N_PAGES, PAGE_SIZE)
        .expect("Instantiating new memory should succeed");
    let slice = &mut mem[OFFSET..][..DIRT.len()];
    slice.copy_from_slice(&DIRT);
    mem.snap().expect("Snapshotting should succeed");
    assert_eq!(mem.dirty_pages().count(), 0);
    let slice = &mem[OFFSET..][..DIRT.len()];
    assert_eq!(slice, DIRT, "Slice should be dirt just written");
    let slice = &mut mem[OFFSET..][..DIRT.len()];
    slice.copy_from_slice(&[0; 2 * PAGE_SIZE]);
    mem.apply().expect("Applying should succeed");
    assert_eq!(mem.dirty_pages().count(), 3);
    let slice = &mut mem[OFFSET..][..DIRT.len()];
    assert_eq!(
        slice,
        &[0; 2 * PAGE_SIZE],
        "Slice should be the zeros written afterwards"
    );
}
/// Drives two regions through the same random writes, snapshots and applies.
/// `mem` additionally makes a batch of writes that it then reverts. The
/// reverted batch must leave no trace, so both regions end with identical
/// dirty pages (same indices, same current and clean contents).
#[test]
fn apply_revert_apply() {
    const N_WRITES: usize = 64;
    // Writes are byte offsets drawn from the whole region so that they land
    // on many different pages. Drawing from `0..N_PAGES` would confine every
    // write to page 0, because N_PAGES == PAGE_SIZE.
    const REGION_LEN: usize = N_PAGES * PAGE_SIZE;
    let mut rng = StdRng::seed_from_u64(0xDEAD_BEEF);
    let mut mem = Mmap::new(N_PAGES, PAGE_SIZE)
        .expect("Instantiating new memory should succeed");
    let mut mem_alt = Mmap::new(N_PAGES, PAGE_SIZE)
        .expect("Instantiating new memory should succeed");
    mem.snap().expect("Snapshotting should succeed");
    mem_alt.snap().expect("Snapshotting should succeed");
    for _ in 0..N_WRITES {
        let i = rng.gen_range(0..REGION_LEN);
        let byte = rng.r#gen();
        mem[i] = byte;
        mem_alt[i] = byte;
    }
    mem.apply().expect("Applying should succeed");
    mem_alt.apply().expect("Applying should succeed");
    // Writes made only to `mem`, and then reverted.
    mem.snap().expect("Snapshotting should succeed");
    for _ in 0..N_WRITES {
        let i = rng.gen_range(0..REGION_LEN);
        let byte = rng.r#gen();
        mem[i] = byte;
    }
    mem.revert().expect("Reverting should succeed");
    mem.snap().expect("Snapshotting should succeed");
    mem_alt.snap().expect("Snapshotting should succeed");
    for _ in 0..N_WRITES {
        let i = rng.gen_range(0..REGION_LEN);
        let byte = rng.r#gen();
        mem[i] = byte;
        mem_alt[i] = byte;
    }
    mem.apply().expect("Applying should succeed");
    mem_alt.apply().expect("Applying should succeed");
    let dirty_pages: Vec<_> = mem.dirty_pages().collect();
    let alt_dirty_pages: Vec<_> = mem_alt.dirty_pages().collect();
    // `zip` stops at the shorter side, so compare the counts explicitly.
    assert_eq!(
        dirty_pages.len(),
        alt_dirty_pages.len(),
        "Both regions should have the same number of dirty pages"
    );
    dirty_pages.into_iter().zip(alt_dirty_pages).for_each(
        |((dirty, clean, index), (alt_dirty, alt_clean, alt_index))| {
            let hash_dirty = blake3::hash(dirty);
            let hash_alt_dirty = blake3::hash(alt_dirty);
            let hash_dirty = hex::encode(hash_dirty.as_bytes());
            let hash_alt_dirty = hex::encode(hash_alt_dirty.as_bytes());
            assert_eq!(
                hash_dirty, hash_alt_dirty,
                "Dirty state should be the same"
            );
            let hash_clean = blake3::hash(clean);
            let hash_alt_clean = blake3::hash(alt_clean);
            let hash_clean = hex::encode(hash_clean.as_bytes());
            let hash_alt_clean = hex::encode(hash_alt_clean.as_bytes());
            assert_eq!(
                hash_clean, hash_alt_clean,
                "Clean state should be the same"
            );
            assert_eq!(index, alt_index, "Index should be the same");
        },
    );
}
/// Three nested snapshots over a 50-page pattern: reverting the two that saw
/// writes restores first the intermediate values and then the pattern.
/// Reverting an untouched snapshot changes nothing.
#[test]
fn tc_snaps() {
    let mut mem = Mmap::new(N_PAGES, PAGE_SIZE)
        .expect("Instantiating new memory should succeed");
    // Fill pages 0..50 of the base snapshot with a repeating byte pattern.
    for i in 0..((50) * PAGE_SIZE) {
        mem[i] = (i % 256) as u8;
    }
    println!("Initial memory state after filling:");
    mem.snap().expect("call_inner: Snap 1 should succeed");
    println!("Memory state after snap 1:");
    mem.snap().expect("fn c query Snap 2 should succeed");
    println!("Memory state after snap 2:");
    mem.snap()
        .expect("fn c contract_to_contract Snap 3 should succeed");
    // Writes in snapshot 3 to pages 1, 2, 3, 27, 28 and 30.
    mem[1 * PAGE_SIZE] = 0xAB;
    mem[2 * PAGE_SIZE] = 0xCD;
    mem[3 * PAGE_SIZE] = 0xEF;
    mem[27 * PAGE_SIZE] = 0x12;
    mem[28 * PAGE_SIZE] = 0x34;
    mem[30 * PAGE_SIZE] = 0x56;
    println!("Memory state after snap 3 & writing:");
    mem.snap()
        .expect("fn c contract_to_contract Snap 4 should succeed");
    assert_eq!(mem[1 * PAGE_SIZE], 0xAB);
    assert_eq!(mem[2 * PAGE_SIZE], 0xCD);
    assert_eq!(mem[3 * PAGE_SIZE], 0xEF);
    assert_eq!(mem[27 * PAGE_SIZE], 0x12);
    assert_eq!(mem[28 * PAGE_SIZE], 0x34);
    assert_eq!(mem[30 * PAGE_SIZE], 0x56);
    // Writes in snapshot 4 to pages 1, 2 and 3 only.
    mem[1 * PAGE_SIZE] = 0x11;
    mem[2 * PAGE_SIZE] = 0x22;
    mem[3 * PAGE_SIZE] = 0x33;
    println!("Memory state after snap 4 & writing:");
    assert_eq!(mem[1 * PAGE_SIZE], 0x11);
    assert_eq!(mem[2 * PAGE_SIZE], 0x22);
    assert_eq!(mem[3 * PAGE_SIZE], 0x33);
    // Dropping snapshot 4 restores the snapshot-3 values.
    mem.revert()
        .expect("fn c contract_to_contract Revert 1 should succeed");
    println!("Memory state after Revert 1:");
    assert_eq!(mem[1 * PAGE_SIZE], 0xAB);
    assert_eq!(mem[2 * PAGE_SIZE], 0xCD);
    assert_eq!(mem[3 * PAGE_SIZE], 0xEF);
    assert_eq!(mem[27 * PAGE_SIZE], 0x12);
    assert_eq!(mem[28 * PAGE_SIZE], 0x34);
    assert_eq!(mem[30 * PAGE_SIZE], 0x56);
    // Dropping snapshot 3 restores the original pattern.
    mem.revert()
        .expect("fn c contract_to_contract Revert 2 should succeed");
    println!("Memory state after Revert 2:");
    for i in 0..((50) * PAGE_SIZE) {
        assert_eq!(
            mem[i],
            (i % 256) as u8,
            "Memory should match initial state on page num {}",
            i / PAGE_SIZE + 1
        );
    }
    // Snapshot 2 saw no writes, so dropping it changes nothing.
    mem.revert().expect("fn c query Revert 3 should succeed");
    println!("Memory state after Revert 3:");
    for i in 0..((50) * PAGE_SIZE) {
        assert_eq!(
            mem[i],
            (i % 256) as u8,
            "Memory should match initial state on page num {}",
            i / PAGE_SIZE + 1
        );
    }
}
/// Mixed sequence of snaps, an apply, two reverts and a final apply: the
/// region must end holding the values written right after the first snapshot.
#[test]
fn snap_revert_revert_apply_scenario() {
    use blake3::Hasher;
    // Local copies shadow the module constants; OFFSET is 0 here, so the
    // region under test is pages 0 and 1.
    const N_PAGES: usize = 65536;
    const PAGE_SIZE: usize = 65536;
    const OFFSET: usize = 0;
    // Sets `len` bytes starting at `offset` to `value`.
    fn fill_region(mem: &mut Mmap, offset: usize, len: usize, value: u8) {
        let slice = &mut mem[offset..][..len];
        for b in slice {
            *b = value;
        }
    }
    // Asserts that `len` bytes starting at `offset` all equal `value`.
    fn assert_region_eq(
        mem: &Mmap,
        offset: usize,
        len: usize,
        value: u8,
        msg: &str,
    ) {
        let slice = &mem[offset..][..len];
        assert!(
            slice.iter().all(|&b| b == value),
            "{msg}: expected all {:#x}, found: first few bytes = {:?}",
            value,
            &slice[..std::cmp::min(16, slice.len())]
        );
    }
    // Prints up to the first 16 bytes of the given range.
    fn print_region(mem: &Mmap, offset: usize, len: usize, msg: &str) {
        let slice = &mem[offset..][..len];
        println!(
            "memory region at {msg}: {:?}",
            &slice[..std::cmp::min(16, slice.len())]
        );
    }
    let mut mem = Mmap::new(N_PAGES, PAGE_SIZE)
        .expect("Instantiating new memory should succeed");
    let len = 2 * PAGE_SIZE;
    print_region(&mem, OFFSET, len, "beginning");
    mem.snap().expect("Snapshot 1 should succeed");
    fill_region(&mut mem, OFFSET, len, 0x11);
    print_region(&mem, OFFSET, len, "After modify #1");
    mem.snap().expect("Snapshot 2 should succeed");
    fill_region(&mut mem, OFFSET, len, 0x22);
    print_region(&mem, OFFSET, len, "After modify #2");
    // Snapshot 3 writes page 2 only; applying it folds that write into
    // snapshot 2.
    mem.snap().expect("Snapshot 3 should succeed");
    fill_region(&mut mem, len, PAGE_SIZE, 0x33);
    print_region(&mem, len, PAGE_SIZE, "After modify #3");
    mem.apply().expect("Apply snapshot 3 should succeed");
    assert_region_eq(&mem, len, PAGE_SIZE, 0x33, "After apply #3");
    // Still in snapshot 2: overwrite pages 0-1 again.
    fill_region(&mut mem, OFFSET, len, 0x44);
    print_region(&mem, OFFSET, len, "After modify #4");
    mem.snap().expect("Snapshot 4 should succeed");
    fill_region(&mut mem, OFFSET, len, 0x55);
    print_region(&mem, OFFSET, len, "After modify #5");
    // Dropping snapshot 4 brings back 0x44.
    mem.revert().expect("Revert from snapshot 4 should succeed");
    print_region(&mem, OFFSET, len, "After revert #5");
    // Dropping snapshot 2 (which absorbed snapshot 3) restores pages 0-1 to
    // 0x11 and page 2 to zeros.
    mem.revert().expect("Revert from snapshot 2 should succeed");
    // Folding snapshot 1 into the base snapshot keeps 0x11.
    mem.apply().expect("Apply snapshot 1 should succeed");
    mem.dirty_pages().for_each(|(dirty, clean, page_index)| {
        println!(
            "Dirty page index: {page_index} - dirty {} - clean {}",
            hex::encode(Hasher::new().update(dirty).finalize().as_bytes()),
            hex::encode(Hasher::new().update(clean).finalize().as_bytes())
        );
    });
    print_region(&mem, OFFSET, len, "After apply #1");
    assert_region_eq(&mem, OFFSET, len, 0x11, "After apply #1");
    mem.snap().expect("Snapshot 5 should succeed");
    print_region(&mem, OFFSET, len, "new call");
    print_region(&mem, len, PAGE_SIZE, "new call");
    assert_region_eq(
        &mem,
        OFFSET,
        len,
        0x11,
        "After apply #1 (final state must be memory #1)",
    );
}
/// When a page is dirtied in two stacked snapshots, applying the upper one
/// keeps the lower snapshot's (earlier) clean copy, so a later revert goes all
/// the way back to it.
#[test]
fn apply_preserves_earliest_clean_state() {
    let mut mem = Mmap::new(N_PAGES, PAGE_SIZE)
        .expect("Instantiating new memory should succeed");
    mem[2 * PAGE_SIZE] = 0x11;
    mem.snap().expect("Snap S1 should succeed");
    // In S1 page 2 goes from 0x11 to 0x22, so S1's clean copy holds 0x11.
    mem[2 * PAGE_SIZE] = 0x22;
    mem.snap().expect("Snap S2 should succeed");
    // In S2 page 2 goes from 0x22 to 0x33, so S2's clean copy holds 0x22.
    mem[2 * PAGE_SIZE] = 0x33;
    mem.apply().expect("Apply should succeed");
    let dirty: Vec<_> = mem.dirty_pages().collect();
    let (dirty_page, clean_page, &page_index) = dirty[0];
    assert_eq!(page_index, 2);
    assert_eq!(dirty_page[0], 0x33);
    assert_eq!(
        clean_page[0], 0x11,
        "or_insert keeps earliest state (0x11), not immediate pre-mod (0x22)"
    );
    mem.revert().expect("Revert should succeed");
    assert_eq!(mem[2 * PAGE_SIZE], 0x11);
}
/// Applies and reverts more times than snapshots exist. Surplus applies must
/// keep the current contents; surplus reverts must leave the region at its
/// initial all-zero state.
#[test]
fn more_reverts_than_snaps() {
    let mut mem = Mmap::new(N_PAGES, PAGE_SIZE).unwrap();
    println!("dirty pages at start: {}", mem.dirty_pages().count());
    // A read alone only grants read access; it records no clean copy.
    let a = mem[0];
    println!("initial value at mem[0]: {}", a);
    println!(
        "dirty pages after reading mem[0]: {}",
        mem.dirty_pages().count()
    );
    mem[0] = 1;
    println!(
        "dirty pages after writing mem[0]: {}",
        mem.dirty_pages().count()
    );
    mem[0] = 1;
    // Build a stack of five snapshots on top of the base one; mem[0] ends at 6.
    mem.snap().unwrap();
    println!("-----------");
    mem[0] = 2;
    mem.snap().unwrap();
    println!("-----------");
    mem[0] = 3;
    mem.snap().unwrap();
    println!("-----------");
    mem.snap().unwrap();
    println!("-----------");
    mem[0] = 5;
    mem.snap().unwrap();
    println!("-----------");
    mem[0] = 6;
    println!("-----------");
    // Four applies fold the top four snapshots down, leaving the base
    // snapshot and the first snapshot taken (whose clean copy holds 1).
    mem.apply().unwrap();
    mem.apply().unwrap();
    mem.apply().unwrap();
    mem.apply().unwrap();
    // The first revert restores 1, the second drops the base snapshot and
    // restores 0; the remaining reverts have nothing left to undo.
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("CHANGING VALUE BEFORE EXTRA REVERTS");
    mem.snap().unwrap();
    println!("SNAPPING");
    mem[0] = 10;
    println!("CHANGED VALUE TO 10");
    // The first apply folds the snapshot into the base one; the rest apply the
    // lone base snapshot, which is replaced by a fresh one carrying its clean
    // pages, so the value stays 10.
    mem.apply().unwrap();
    mem.apply().unwrap();
    mem.apply().unwrap();
    mem.apply().unwrap();
    mem.apply().unwrap();
    mem.apply().unwrap();
    mem.apply().unwrap();
    println!("--- APPLIED AFTER CHANGING VALUE ---");
    assert_eq!(mem[0], 10, "should be 10");
    println!("----------- ADDITIONAL REVERTS -----------");
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- AFTER REVERT -----------");
    assert_eq!(mem[0], 0, "should be 0 after extra revert");
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    assert_eq!(mem[0], 0, "should be 0 after extra revert");
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    assert_eq!(mem[0], 0, "should be 0 after extra revert");
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- REVERT -----------");
    mem.revert().unwrap();
    println!("----------- ASSERT -----------");
    assert_eq!(mem[0], 0, "Further reverts stay at initial state");
}
/// Prints a short view of `mem`: its BLAKE3 hash (hex) when it is longer than
/// one page, otherwise its first six bytes.
fn print_mem(mem: &[u8]) {
    if mem.len() > PAGE_SIZE {
        let digest = blake3::hash(mem);
        println!("Memory hash: {}", hex::encode(digest.as_bytes()));
    } else {
        println!("Memory: {:?}", &mem[0..6]);
    }
}
/// Single-page region: snapshots without writes are transparent, and each
/// revert undoes only the writes made in the dropped snapshot.
#[test]
fn tc_snaps_min() {
    let mut mem = Mmap::new(1, PAGE_SIZE).unwrap();
    mem[0] = 1;
    println!("Initial memory state after filling:");
    print_mem(mem.as_ref());
    mem.snap().unwrap();
    assert_eq!(mem[0], 1);
    println!("After snap 1:");
    print_mem(mem.as_ref());
    mem.snap().unwrap();
    println!("After snap 2:");
    assert_eq!(mem[0], 1);
    print_mem(mem.as_ref());
    // Written in snapshot 2.
    mem[0] = 2;
    mem.snap().unwrap();
    println!("After snap 3:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 2);
    assert_eq!(mem[1..PAGE_SIZE], [0; PAGE_SIZE - 1]);
    // Snapshot 3 saw no writes: the value stays 2.
    mem.revert().unwrap();
    assert_eq!(mem[0], 2);
    println!("After revert 1:");
    print_mem(mem.as_ref());
    // Dropping snapshot 2 undoes the write of 2.
    mem.revert().unwrap();
    assert_eq!(mem[0], 1);
    println!("After revert 2:");
    print_mem(mem.as_ref());
    // Snapshot 1 saw no writes: the value stays 1.
    mem.revert().unwrap();
    println!("After reverts 3:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 1);
}
/// Single-page region: applies keep the current value, a revert between them
/// drops one level, and a final revert of the lone base snapshot goes back to
/// the initial zeros.
#[test]
fn tc_apply2() {
    let mut mem = Mmap::new(1, PAGE_SIZE).unwrap();
    println!("Initial memory state after filling:");
    print_mem(mem.as_ref());
    mem.snap().unwrap();
    println!("After snap 1:");
    print_mem(mem.as_ref());
    mem[0] = 1;
    mem.snap().unwrap();
    assert_eq!(mem[0], 1);
    println!("After snap 1:");
    print_mem(mem.as_ref());
    mem[0] = 2;
    mem.snap().unwrap();
    println!("After snap 2:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 2);
    mem[0] = 3;
    mem.snap().unwrap();
    println!("After snap 3:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 3);
    // The top snapshot saw no writes.
    mem.apply().unwrap();
    assert_eq!(mem[0], 3);
    println!("After apply:");
    print_mem(mem.as_ref());
    // Drops the snapshot in which 3 was written.
    mem.revert().unwrap();
    assert_eq!(mem[0], 2);
    println!("After revert:");
    print_mem(mem.as_ref());
    mem.apply().unwrap();
    assert_eq!(mem[0], 2);
    println!("After apply 2:");
    print_mem(mem.as_ref());
    mem.apply().unwrap();
    assert_eq!(mem[0], 2);
    println!("After apply 3:");
    print_mem(mem.as_ref());
    // Only the base snapshot is left; applying it keeps its clean page.
    mem.apply().unwrap();
    assert_eq!(mem[0], 2);
    println!("After apply 4:");
    print_mem(mem.as_ref());
    // Reverting the base snapshot restores the original zeros.
    mem.revert().unwrap();
    assert_eq!(mem[0], 0);
    println!("After revert it goes to 0");
    print_mem(mem.as_ref());
}
/// Single-page region: after two applies, a revert drops a snapshot whose
/// merged clean copy holds 1 (not 2), because apply keeps the older copy.
#[test]
fn tc_apply3() {
    let mut mem = Mmap::new(1, PAGE_SIZE).unwrap();
    println!("Initial memory state after filling:");
    print_mem(mem.as_ref());
    mem.snap().unwrap();
    println!("After snap 1:");
    print_mem(mem.as_ref());
    mem[0] = 1;
    mem.snap().unwrap();
    assert_eq!(mem[0], 1);
    println!("After snap 1:");
    print_mem(mem.as_ref());
    mem[0] = 2;
    mem.snap().unwrap();
    println!("After snap 2:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 2);
    mem[0] = 3;
    mem.snap().unwrap();
    println!("After snap 3:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 3);
    mem.apply().unwrap();
    assert_eq!(mem[0], 3);
    println!("After apply:");
    print_mem(mem.as_ref());
    // Folds the snapshot that wrote 3 into the one that wrote 2; the older
    // clean copy (value 1) is kept.
    mem.apply().unwrap();
    // NOTE(review): from here the printed labels lag the operations by one
    // step (e.g. "After revert:" is printed after an apply).
    assert_eq!(mem[0], 3);
    println!("After revert:");
    print_mem(mem.as_ref());
    mem.revert().unwrap();
    assert_eq!(mem[0], 1);
    println!("After apply 2:");
    print_mem(mem.as_ref());
    mem.apply().unwrap();
    assert_eq!(mem[0], 1);
    println!("After apply 3:");
    print_mem(mem.as_ref());
    mem.apply().unwrap();
    assert_eq!(mem[0], 1);
    println!("After apply 4:");
    print_mem(mem.as_ref());
}
/// Single-page region starting at zero: only the revert of the snapshot that
/// held the write changes the value.
#[test]
fn tc_snaps_min_neo() {
    let mut mem = Mmap::new(1, PAGE_SIZE).unwrap();
    mem.snap().unwrap();
    assert_eq!(mem[0], 0);
    println!("After snap 1:");
    print_mem(mem.as_ref());
    mem.snap().unwrap();
    println!("After snap 2:");
    assert_eq!(mem[0], 0);
    print_mem(mem.as_ref());
    // Written in snapshot 2.
    mem[0] = 1;
    mem.snap().unwrap();
    println!("After snap 3:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 1);
    assert_eq!(mem[1..PAGE_SIZE], [0; PAGE_SIZE - 1]);
    mem.revert().unwrap();
    assert_eq!(mem[0], 1);
    println!("After revert 1:");
    print_mem(mem.as_ref());
    // Dropping snapshot 2 undoes the write.
    mem.revert().unwrap();
    assert_eq!(mem[0], 0);
    println!("After revert 2:");
    print_mem(mem.as_ref());
    mem.revert().unwrap();
    println!("After reverts 3:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 0);
}
/// Same scenario as `tc_apply2`, with the repeated applies written as a loop.
#[test]
fn tc_apply_neo() {
    let mut mem = Mmap::new(1, PAGE_SIZE).unwrap();
    println!("Initial memory state after filling:");
    print_mem(mem.as_ref());
    mem.snap().unwrap();
    println!("After snap 1:");
    print_mem(mem.as_ref());
    mem[0] = 1;
    mem.snap().unwrap();
    assert_eq!(mem[0], 1);
    println!("After snap 1:");
    print_mem(mem.as_ref());
    mem[0] = 2;
    mem.snap().unwrap();
    println!("After snap 2:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 2);
    mem[0] = 3;
    mem.snap().unwrap();
    println!("After snap 3:");
    print_mem(mem.as_ref());
    assert_eq!(mem[0], 3);
    mem.apply().unwrap();
    assert_eq!(mem[0], 3);
    println!("After apply:");
    print_mem(mem.as_ref());
    // Drops the snapshot in which 3 was written.
    mem.revert().unwrap();
    assert_eq!(mem[0], 2);
    println!("After revert:");
    print_mem(mem.as_ref());
    // Fold everything down to a lone base snapshot; the value stays 2.
    for i in 2..=4 {
        mem.apply().unwrap();
        assert_eq!(mem[0], 2);
        println!("After apply {}:", i);
        print_mem(mem.as_ref());
    }
    // Reverting the base snapshot restores the original zeros.
    mem.revert().unwrap();
    assert_eq!(mem[0], 0);
    println!("After revert it goes to 0");
    print_mem(mem.as_ref());
}
/// Deep stack on a single page: 99 snapshots, each preceded by a write of an
/// increasing value, then a mix of applies and reverts that must walk the
/// value back down to 0.
#[test]
fn tc_apply2_neo() {
    let mut mem = Mmap::new(1, PAGE_SIZE).unwrap();
    println!("Initial memory state after filling:");
    print_mem(mem.as_ref());
    // Write i + 1, then snapshot: snapshot k (k >= 1) starts with value k.
    for i in 0..99 {
        mem[0] = i + 1;
        mem.snap().unwrap();
        assert_eq!(mem[0], i + 1);
        println!("After snap {}:", i + 1);
        print_mem(mem.as_ref());
    }
    // Base snapshot plus the 99 taken above.
    assert_eq!(mem.0.snapshots.len(), 100);
    // Ten applies pop the top ten snapshots; the value is unchanged.
    for i in 0usize..10 {
        assert_eq!(mem[0], 99);
        assert_eq!(mem.0.snapshots.len(), 100 - i);
        mem.apply().unwrap();
        println!("After apply {}:", i);
        print_mem(mem.as_ref());
    }
    assert_eq!(mem[0], 99);
    // The current snapshot started at 89 (the applied ones kept its copy).
    mem.revert().unwrap();
    assert_eq!(mem[0], 89);
    mem.revert().unwrap();
    assert_eq!(mem[0], 88);
    println!("After apply revert:");
    print_mem(mem.as_ref());
    // Each apply/revert pair consumes two snapshots and lowers the value by
    // two; the last revert drops the base snapshot and restores 0.
    for i in 0..88 {
        if i % 2 == 0 {
            mem.apply().unwrap();
            assert_eq!(mem[0], 88 - i);
        } else {
            mem.revert().unwrap();
            assert_eq!(mem[0], 88 - (i + 1));
        }
    }
    assert_eq!(mem[0], 0);
}
}