1use std::io::{Seek, SeekFrom, Write};
14
15use crate::{atomic::AtomicWriter, dirty::DirtyBitmap, error::SnapshotError};
16
/// Kind of guest-memory snapshot.
///
/// NOTE(review): this enum is not referenced by the code visible here;
/// presumably `Full` corresponds to `MemoryWriter::write_full` and `Diff` to
/// `MemoryWriter::write_diff` — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemorySnapshotKind {
    /// A complete dump of guest memory (presumably produced by `write_full`).
    Full,
    /// Only dirty pages are written (presumably produced by `write_diff`).
    Diff,
}
26
/// Source of guest-memory bytes, addressed by offset from the start of RAM.
///
/// After a successful [`PageReader::read_at`], `MemoryWriter` writes the whole
/// buffer to the snapshot file. Implementors must therefore fill `buf` entirely
/// or return an error.
pub trait PageReader {
    /// Fills `buf` with the `buf.len()` bytes that start at `offset_from_ram_start`.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when the requested range cannot be read.
    fn read_at(&self, offset_from_ram_start: u64, buf: &mut [u8]) -> std::io::Result<()>;
}

/// In-memory [`PageReader`] backed by an owned byte vector.
///
/// Offset 0 maps to the first byte of the vector.
#[derive(Debug)]
pub struct VecPageReader {
    bytes: Vec<u8>,
}

impl VecPageReader {
    /// Wraps `bytes` so they can be served through [`PageReader`].
    #[must_use]
    pub fn new(bytes: Vec<u8>) -> Self {
        VecPageReader { bytes }
    }

    /// Mutable view of the backing bytes, e.g. for seeding a test pattern.
    pub fn bytes_mut(&mut self) -> &mut [u8] {
        self.bytes.as_mut_slice()
    }
}

impl PageReader for VecPageReader {
    /// Copies `buf.len()` bytes starting at `offset` out of the backing vector.
    ///
    /// Fails with `InvalidInput` when the offset (or offset + length) cannot
    /// be represented as `usize`. Fails with `UnexpectedEof` when the range
    /// extends past the end of the vector.
    fn read_at(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<()> {
        use std::io::{Error, ErrorKind};

        let Ok(first) = usize::try_from(offset) else {
            return Err(Error::new(ErrorKind::InvalidInput, "offset > usize::MAX"));
        };
        let Some(past_last) = first.checked_add(buf.len()) else {
            return Err(Error::new(ErrorKind::InvalidInput, "offset overflow"));
        };
        // `get` yields `None` exactly when `past_last` exceeds the vector length.
        match self.bytes.get(first..past_last) {
            Some(source) => {
                buf.copy_from_slice(source);
                Ok(())
            }
            None => Err(Error::new(
                ErrorKind::UnexpectedEof,
                "read past end of memory",
            )),
        }
    }
}
80
81#[derive(Debug)]
84pub struct MemoryWriter {
85 inner: AtomicWriter,
86 ram_size: u64,
87 page_size: u64,
88}
89
90impl MemoryWriter {
91 pub fn open(
101 dest: &std::path::Path,
102 ram_size: u64,
103 page_size: u64,
104 ) -> Result<Self, SnapshotError> {
105 Ok(Self {
106 inner: AtomicWriter::open(dest)?,
107 ram_size,
108 page_size,
109 })
110 }
111
112 #[must_use]
114 pub fn ram_size(&self) -> u64 {
115 self.ram_size
116 }
117
118 #[must_use]
120 pub fn page_size(&self) -> u64 {
121 self.page_size
122 }
123
124 pub fn write_full<R: PageReader>(&mut self, reader: &R) -> Result<(), SnapshotError> {
133 let mut buf = vec![
134 0u8;
135 usize::try_from(self.page_size).map_err(|_| {
136 SnapshotError::MemoryIo(std::io::Error::new(
137 std::io::ErrorKind::InvalidInput,
138 "page_size > usize::MAX",
139 ))
140 })?
141 ];
142 let mut offset = 0u64;
143 while offset < self.ram_size {
144 let chunk = (self.ram_size - offset).min(self.page_size);
145 let chunk_usize = usize::try_from(chunk).map_err(|_| {
146 SnapshotError::MemoryIo(std::io::Error::new(
147 std::io::ErrorKind::InvalidInput,
148 "chunk > usize::MAX",
149 ))
150 })?;
151 let buf_slice = &mut buf[..chunk_usize];
152 reader
153 .read_at(offset, buf_slice)
154 .map_err(SnapshotError::MemoryIo)?;
155 self.inner
156 .file_mut()
157 .write_all(buf_slice)
158 .map_err(SnapshotError::MemoryIo)?;
159 offset += chunk;
160 }
161 Ok(())
162 }
163
164 pub fn write_diff<R: PageReader>(
178 &mut self,
179 reader: &R,
180 dirty: &DirtyBitmap,
181 ) -> Result<u64, SnapshotError> {
182 if dirty.ram_size() != self.ram_size {
183 return Err(SnapshotError::InvalidPath(format!(
184 "diff bitmap covers {} bytes, memory file expects {}",
185 dirty.ram_size(),
186 self.ram_size
187 )));
188 }
189 let bitmap_page = dirty.page_size();
190 let target_page = self.page_size;
191 if bitmap_page < target_page || !bitmap_page.is_multiple_of(target_page) {
192 return Err(SnapshotError::InvalidPath(format!(
193 "diff bitmap page ({bitmap_page}) must be a multiple of memory-file page \
194 ({target_page})",
195 )));
196 }
197 self.inner
201 .file_mut()
202 .set_len(self.ram_size)
203 .map_err(SnapshotError::MemoryIo)?;
204 let pages_per_block = bitmap_page / target_page;
205 let buf_len = usize::try_from(target_page).map_err(|_| {
206 SnapshotError::MemoryIo(std::io::Error::new(
207 std::io::ErrorKind::InvalidInput,
208 "page_size > usize::MAX",
209 ))
210 })?;
211 let mut buf = vec![0u8; buf_len];
212 let mut pages_written: u64 = 0;
213 for bitmap_page_idx in 0..dirty.page_count() {
216 if !dirty.is_dirty_by_index(bitmap_page_idx) {
217 continue;
218 }
219 let block_byte_offset = bitmap_page_idx * bitmap_page;
220 for sub in 0..pages_per_block {
221 let target_offset = block_byte_offset + sub * target_page;
222 if target_offset >= self.ram_size {
223 break;
224 }
225 let chunk = (self.ram_size - target_offset).min(target_page);
226 let chunk_usize = usize::try_from(chunk).map_err(|_| {
227 SnapshotError::MemoryIo(std::io::Error::new(
228 std::io::ErrorKind::InvalidInput,
229 "chunk > usize::MAX",
230 ))
231 })?;
232 let slice = &mut buf[..chunk_usize];
233 reader
234 .read_at(target_offset, slice)
235 .map_err(SnapshotError::MemoryIo)?;
236 self.inner
237 .file_mut()
238 .seek(SeekFrom::Start(target_offset))
239 .map_err(SnapshotError::MemoryIo)?;
240 self.inner
241 .file_mut()
242 .write_all(slice)
243 .map_err(SnapshotError::MemoryIo)?;
244 pages_written += 1;
245 }
246 }
247 Ok(pages_written)
248 }
249
250 pub fn commit(self) -> Result<(), SnapshotError> {
255 self.inner.commit()
256 }
257}
258
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::path::Path;

    use tempfile::TempDir;

    use super::*;

    /// Byte value `build_reader` stores at `offset` (pattern repeats every 256 bytes).
    fn expected_byte(offset: usize) -> u8 {
        (offset % 256) as u8
    }

    /// Builds a reader of `size` bytes where byte `i` equals `expected_byte(i)`.
    fn build_reader(size: usize) -> VecPageReader {
        let mut r = VecPageReader::new(vec![0u8; size]);
        for (i, byte) in r.bytes_mut().iter_mut().enumerate() {
            *byte = expected_byte(i);
        }
        r
    }

    fn dest_in(dir: &Path, name: &str) -> std::path::PathBuf {
        dir.join(name)
    }

    /// Asserts that every byte of `written[range]` is zero (an unwritten region).
    fn assert_zero(written: &[u8], range: Range<usize>) {
        assert!(
            written[range.clone()].iter().all(|&b| b == 0),
            "region {range:?} should be zero",
        );
    }

    /// Asserts that `written[range]` holds the `build_reader` pattern for those offsets.
    fn assert_pattern(written: &[u8], range: Range<usize>) {
        let start = range.start;
        assert!(
            written[range.clone()]
                .iter()
                .enumerate()
                .all(|(i, &b)| b == expected_byte(start + i)),
            "region {range:?} does not match the source pattern",
        );
    }

    #[test]
    fn test_should_write_full_dump_dense() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.mem");
        let ram_size = 64 * 1024;
        let reader = build_reader(ram_size);
        let mut w = MemoryWriter::open(&dest, ram_size as u64, 16 * 1024).unwrap();
        w.write_full(&reader).unwrap();
        w.commit().unwrap();
        let written = std::fs::read(&dest).unwrap();
        assert_eq!(written.len(), ram_size);
        assert_eq!(written, reader.bytes);
    }

    #[test]
    fn test_should_write_full_dump_with_partial_final_page() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.mem");
        // Not a multiple of the page size: the last chunk is shortened.
        let ram_size = 40 * 1024 + 123;
        let reader = build_reader(ram_size);
        let mut w = MemoryWriter::open(&dest, ram_size as u64, 16 * 1024).unwrap();
        w.write_full(&reader).unwrap();
        w.commit().unwrap();
        let written = std::fs::read(&dest).unwrap();
        assert_eq!(written.len(), ram_size);
        assert_eq!(written, reader.bytes);
    }

    #[test]
    fn test_should_write_diff_only_dirty_pages() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.mem");
        let ram_size: u64 = 64 * 1024;
        let page_size: u64 = 16 * 1024;
        let reader = build_reader(usize::try_from(ram_size).unwrap());
        let bm = DirtyBitmap::new(0, ram_size, page_size).unwrap();
        bm.set_dirty_by_index(1);
        bm.set_dirty_by_index(3);
        let mut w = MemoryWriter::open(&dest, ram_size, page_size).unwrap();
        let pages = w.write_diff(&reader, &bm).unwrap();
        w.commit().unwrap();
        assert_eq!(pages, 2);

        let written = std::fs::read(&dest).unwrap();
        assert_eq!(written.len() as u64, ram_size);
        let page = usize::try_from(page_size).unwrap();
        assert_zero(&written, 0..page);
        assert_pattern(&written, page..2 * page);
        assert_zero(&written, 2 * page..3 * page);
        assert_pattern(&written, 3 * page..4 * page);
    }

    #[test]
    fn test_should_write_diff_when_bitmap_covers_finer_units() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.mem");
        let ram_size: u64 = 4 * 1024 * 1024;
        let target_page = 16 * 1024u64;
        let bitmap_page = 2 * 1024 * 1024u64;
        let reader = build_reader(usize::try_from(ram_size).unwrap());
        let bm = DirtyBitmap::new(0, ram_size, bitmap_page).unwrap();
        bm.set_dirty_by_index(1);
        let mut w = MemoryWriter::open(&dest, ram_size, target_page).unwrap();
        let pages = w.write_diff(&reader, &bm).unwrap();
        w.commit().unwrap();
        assert_eq!(pages, bitmap_page / target_page);

        let written = std::fs::read(&dest).unwrap();
        assert_eq!(written.len() as u64, ram_size);
        let block = usize::try_from(bitmap_page).unwrap();
        assert_zero(&written, 0..block);
        assert_pattern(&written, block..2 * block);
    }

    #[test]
    fn test_should_reject_diff_when_bitmap_size_mismatches() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.mem");
        let ram_size: u64 = 64 * 1024;
        let page_size: u64 = 16 * 1024;
        let reader = build_reader(usize::try_from(ram_size).unwrap());
        let bm = DirtyBitmap::new(0, ram_size / 2, page_size).unwrap();
        let mut w = MemoryWriter::open(&dest, ram_size, page_size).unwrap();
        let err = w.write_diff(&reader, &bm).unwrap_err();
        assert!(matches!(err, SnapshotError::InvalidPath(_)));
    }
}