use std::io::{Seek, SeekFrom, Write};
use crate::{atomic::AtomicWriter, dirty::DirtyBitmap, error::SnapshotError};
/// Which flavour of guest-memory snapshot is being produced.
///
/// NOTE(review): the variants presumably select between
/// `MemoryWriter::write_full` and `MemoryWriter::write_diff`; nothing in this
/// file enforces that mapping — confirm at the call site.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemorySnapshotKind {
    /// Every byte of guest RAM is captured.
    Full,
    /// Only pages flagged in a dirty bitmap are captured.
    Diff,
}
/// Random-access source of guest RAM contents consumed by `MemoryWriter`.
pub trait PageReader {
    /// Reads `buf.len()` bytes starting at byte `offset_from_ram_start`
    /// (measured from the start of guest RAM) into `buf`.
    ///
    /// On success the whole of `buf` must have been filled: `MemoryWriter`
    /// writes the entire slice to the snapshot file after this returns `Ok`.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when the requested range cannot be read in full.
    fn read_at(&self, offset_from_ram_start: u64, buf: &mut [u8]) -> std::io::Result<()>;
}
/// `PageReader` backed by an in-memory byte vector; RAM offset `i` maps to
/// `bytes[i]`.
#[derive(Debug)]
pub struct VecPageReader {
    /// The complete RAM image served to readers.
    bytes: Vec<u8>,
}
impl VecPageReader {
#[must_use]
pub fn new(bytes: Vec<u8>) -> Self {
Self { bytes }
}
pub fn bytes_mut(&mut self) -> &mut [u8] {
&mut self.bytes
}
}
impl PageReader for VecPageReader {
    /// Copies `buf.len()` bytes starting at `offset` out of the backing vector.
    ///
    /// Fails with `InvalidInput` when `offset` or the end of the range does not
    /// fit in `usize`, and with `UnexpectedEof` when the range runs past the
    /// end of the image.
    fn read_at(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<()> {
        use std::io::{Error, ErrorKind};

        let invalid = |msg: &'static str| Error::new(ErrorKind::InvalidInput, msg);
        let first = usize::try_from(offset).map_err(|_| invalid("offset > usize::MAX"))?;
        let last = first
            .checked_add(buf.len())
            .ok_or_else(|| invalid("offset overflow"))?;
        // `first <= last` always holds here, so `get` yields `None` only when
        // `last` lies beyond the image.
        let source = self
            .bytes
            .get(first..last)
            .ok_or_else(|| Error::new(ErrorKind::UnexpectedEof, "read past end of memory"))?;
        buf.copy_from_slice(source);
        Ok(())
    }
}
/// Writes guest RAM contents into a snapshot memory file through an
/// `AtomicWriter`.
#[derive(Debug)]
pub struct MemoryWriter {
    /// Destination file wrapper; finalised by `MemoryWriter::commit`.
    inner: AtomicWriter,
    /// Number of guest-RAM bytes the file represents.
    ram_size: u64,
    /// Copy granularity, in bytes, used when transferring memory.
    page_size: u64,
}
impl MemoryWriter {
pub fn open(
dest: &std::path::Path,
ram_size: u64,
page_size: u64,
) -> Result<Self, SnapshotError> {
Ok(Self {
inner: AtomicWriter::open(dest)?,
ram_size,
page_size,
})
}
#[must_use]
pub fn ram_size(&self) -> u64 {
self.ram_size
}
#[must_use]
pub fn page_size(&self) -> u64 {
self.page_size
}
pub fn write_full<R: PageReader>(&mut self, reader: &R) -> Result<(), SnapshotError> {
let mut buf = vec![
0u8;
usize::try_from(self.page_size).map_err(|_| {
SnapshotError::MemoryIo(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"page_size > usize::MAX",
))
})?
];
let mut offset = 0u64;
while offset < self.ram_size {
let chunk = (self.ram_size - offset).min(self.page_size);
let chunk_usize = usize::try_from(chunk).map_err(|_| {
SnapshotError::MemoryIo(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"chunk > usize::MAX",
))
})?;
let buf_slice = &mut buf[..chunk_usize];
reader
.read_at(offset, buf_slice)
.map_err(SnapshotError::MemoryIo)?;
self.inner
.file_mut()
.write_all(buf_slice)
.map_err(SnapshotError::MemoryIo)?;
offset += chunk;
}
Ok(())
}
pub fn write_diff<R: PageReader>(
&mut self,
reader: &R,
dirty: &DirtyBitmap,
) -> Result<u64, SnapshotError> {
if dirty.ram_size() != self.ram_size {
return Err(SnapshotError::InvalidPath(format!(
"diff bitmap covers {} bytes, memory file expects {}",
dirty.ram_size(),
self.ram_size
)));
}
let bitmap_page = dirty.page_size();
let target_page = self.page_size;
if bitmap_page < target_page || !bitmap_page.is_multiple_of(target_page) {
return Err(SnapshotError::InvalidPath(format!(
"diff bitmap page ({bitmap_page}) must be a multiple of memory-file page \
({target_page})",
)));
}
self.inner
.file_mut()
.set_len(self.ram_size)
.map_err(SnapshotError::MemoryIo)?;
let pages_per_block = bitmap_page / target_page;
let buf_len = usize::try_from(target_page).map_err(|_| {
SnapshotError::MemoryIo(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"page_size > usize::MAX",
))
})?;
let mut buf = vec![0u8; buf_len];
let mut pages_written: u64 = 0;
for bitmap_page_idx in 0..dirty.page_count() {
if !dirty.is_dirty_by_index(bitmap_page_idx) {
continue;
}
let block_byte_offset = bitmap_page_idx * bitmap_page;
for sub in 0..pages_per_block {
let target_offset = block_byte_offset + sub * target_page;
if target_offset >= self.ram_size {
break;
}
let chunk = (self.ram_size - target_offset).min(target_page);
let chunk_usize = usize::try_from(chunk).map_err(|_| {
SnapshotError::MemoryIo(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"chunk > usize::MAX",
))
})?;
let slice = &mut buf[..chunk_usize];
reader
.read_at(target_offset, slice)
.map_err(SnapshotError::MemoryIo)?;
self.inner
.file_mut()
.seek(SeekFrom::Start(target_offset))
.map_err(SnapshotError::MemoryIo)?;
self.inner
.file_mut()
.write_all(slice)
.map_err(SnapshotError::MemoryIo)?;
pages_written += 1;
}
}
Ok(pages_written)
}
pub fn commit(self) -> Result<(), SnapshotError> {
self.inner.commit()
}
}
#[cfg(test)]
mod tests {
    use std::path::Path;

    use tempfile::TempDir;

    use super::*;

    /// Builds a reader whose byte at offset `i` is `i % 256`, so any misplaced
    /// page is detectable.
    fn build_reader(size: usize) -> VecPageReader {
        let mut r = VecPageReader::new(vec![0u8; size]);
        let bytes = r.bytes_mut();
        for (i, byte) in bytes.iter_mut().enumerate() {
            *byte = (i % 256) as u8;
        }
        r
    }

    fn dest_in(dir: &Path, name: &str) -> std::path::PathBuf {
        dir.join(name)
    }

    #[test]
    fn test_should_write_full_dump_dense() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.mem");
        let ram_size = 64 * 1024;
        let reader = build_reader(ram_size);
        let mut w = MemoryWriter::open(&dest, ram_size as u64, 16 * 1024).unwrap();
        w.write_full(&reader).unwrap();
        w.commit().unwrap();
        let written = std::fs::read(&dest).unwrap();
        assert_eq!(written.len(), ram_size);
        assert_eq!(written, reader.bytes);
    }

    #[test]
    fn test_should_write_full_dump_with_partial_tail_page() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.mem");
        // Not a multiple of the page size, so the final chunk is partial.
        let ram_size: u64 = 40 * 1024 + 123;
        let reader = build_reader(usize::try_from(ram_size).unwrap());
        let mut w = MemoryWriter::open(&dest, ram_size, 16 * 1024).unwrap();
        w.write_full(&reader).unwrap();
        w.commit().unwrap();
        let written = std::fs::read(&dest).unwrap();
        assert_eq!(written.len() as u64, ram_size);
        assert!(written == reader.bytes, "partial tail page was not copied");
    }

    #[test]
    fn test_should_write_diff_only_dirty_pages() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.mem");
        let ram_size: u64 = 64 * 1024;
        let page_size: u64 = 16 * 1024;
        let reader = build_reader(usize::try_from(ram_size).unwrap());
        let bm = DirtyBitmap::new(0, ram_size, page_size).unwrap();
        bm.set_dirty_by_index(1);
        bm.set_dirty_by_index(3);
        let mut w = MemoryWriter::open(&dest, ram_size, page_size).unwrap();
        let pages = w.write_diff(&reader, &bm).unwrap();
        w.commit().unwrap();
        assert_eq!(pages, 2);
        let written = std::fs::read(&dest).unwrap();
        assert_eq!(written.len() as u64, ram_size);
        let page = usize::try_from(page_size).unwrap();
        assert!(
            written[..page].iter().all(|&b| b == 0),
            "clean page 0 was written",
        );
        assert!(
            written[page..2 * page] == reader.bytes[page..2 * page],
            "diff did not preserve dirty page 1's markers",
        );
        assert!(
            written[2 * page..3 * page].iter().all(|&b| b == 0),
            "clean page 2 was written",
        );
        assert!(
            written[3 * page..] == reader.bytes[3 * page..],
            "diff did not preserve dirty page 3's markers",
        );
    }

    #[test]
    fn test_should_expand_coarse_bitmap_block_into_target_pages() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.mem");
        let ram_size: u64 = 4 * 1024 * 1024;
        let target_page = 16 * 1024u64;
        // One bitmap bit covers many memory-file pages.
        let bitmap_page = 2 * 1024 * 1024u64;
        let reader = build_reader(usize::try_from(ram_size).unwrap());
        let bm = DirtyBitmap::new(0, ram_size, bitmap_page).unwrap();
        bm.set_dirty_by_index(1);
        let mut w = MemoryWriter::open(&dest, ram_size, target_page).unwrap();
        let pages = w.write_diff(&reader, &bm).unwrap();
        w.commit().unwrap();
        assert_eq!(pages, bitmap_page / target_page);
        let written = std::fs::read(&dest).unwrap();
        assert_eq!(written.len() as u64, ram_size);
        let block = usize::try_from(bitmap_page).unwrap();
        assert!(
            written[..block].iter().all(|&b| b == 0),
            "clean bitmap block 0 was written",
        );
        assert!(
            written[block..2 * block] == reader.bytes[block..2 * block],
            "dirty bitmap block 1 was not fully copied",
        );
    }
}