use std::ops::Index;
use std::os::fd::{AsRawFd, BorrowedFd};
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Arc, RwLock};
use std::{io, ptr};
use vm_memory::bitmap::{Bitmap, BitmapSlice, WithBitmapSlice};
use vm_memory::mmap::NewBitmap;
use vm_memory::{Address, GuestMemoryRegion};
// Size in bytes of a guest page as tracked by the vhost-user dirty log.
const LOG_PAGE_SIZE: usize = 0x1000;
// Number of pages tracked per log word: one bit per page in an AtomicU8.
const LOG_WORD_SIZE: usize = u8::BITS as usize;
/// Extension of [`Bitmap`] that allows installing (replacing) an inner
/// bitmap once the frontend shares a dirty-log memory region.
pub trait BitmapReplace: Bitmap {
    type InnerBitmap: MemRegionBitmap;

    /// Replaces the current inner bitmap with `bitmap`.
    fn replace(&self, bitmap: Self::InnerBitmap);
}
/// A per-memory-region bitmap backed by a shared log memory mapping.
pub trait MemRegionBitmap: Sized {
    /// Creates a new bitmap covering `region`, backed by `logmem`.
    fn new<R: GuestMemoryRegion>(region: &R, logmem: Arc<MmapLogReg>) -> io::Result<Self>;
}
// The unit bitmap is a placeholder used when dirty logging is disabled;
// replacing its inner bitmap must never be attempted.
impl BitmapReplace for () {
    type InnerBitmap = ();

    // Reaching this is a logic error: the log-shmfd feature was negotiated
    // but no real bitmap type is in use.
    fn replace(&self, _bitmap: ()) {
        panic!("The unit bitmap () must not be used if VHOST_USER_PROTOCOL_F_LOG_SHMFD is set");
    }
}
// The unit bitmap cannot be backed by a log region: construction always
// fails with `ErrorKind::Unsupported`.
impl MemRegionBitmap for () {
    fn new<R: GuestMemoryRegion>(_region: &R, _logmem: Arc<MmapLogReg>) -> io::Result<Self> {
        Err(io::ErrorKind::Unsupported.into())
    }
}
/// A [`Bitmap`] for tracking pages dirtied in a guest memory region,
/// backed by a shared mmap of the vhost-user dirty log.
///
/// Cloning is shallow: clones (and slices) share the same inner bitmap.
#[derive(Default, Debug, Clone)]
pub struct BitmapMmapRegion {
    // `None` until `replace` installs a real bitmap; while `None`,
    // `mark_dirty` is a no-op and `dirty_at` reports clean.
    inner: Arc<RwLock<Option<AtomicBitmapMmap>>>,
    // Byte offset added to every incoming offset; used by `slice_at`.
    base_address: usize,
}
impl Bitmap for BitmapMmapRegion {
    /// Marks `len` bytes starting at `offset` (relative to this view) as dirty.
    fn mark_dirty(&self, offset: usize, len: usize) {
        let guard = self.inner.read().unwrap();
        let Some(bitmap) = guard.as_ref() else {
            // No backing log installed yet: nothing to record.
            return;
        };
        // An overflowing absolute offset is unrepresentable; skip silently,
        // just like an absent bitmap.
        let Some(absolute_offset) = self.base_address.checked_add(offset) else {
            return;
        };
        bitmap.mark_dirty(absolute_offset, len);
    }

    /// Reports whether the byte at `offset` (relative to this view) is dirty.
    fn dirty_at(&self, offset: usize) -> bool {
        match self.inner.read().unwrap().as_ref() {
            Some(bitmap) => bitmap.dirty_at(self.base_address.saturating_add(offset)),
            None => false,
        }
    }

    /// Returns a view of this bitmap shifted forward by `offset` bytes;
    /// the view shares the same underlying log.
    fn slice_at(&self, offset: usize) -> <Self as WithBitmapSlice<'_>>::S {
        let inner = Arc::clone(&self.inner);
        Self {
            inner,
            base_address: self.base_address.saturating_add(offset),
        }
    }
}
impl BitmapReplace for BitmapMmapRegion {
    type InnerBitmap = AtomicBitmapMmap;

    /// Installs `bitmap` as the backing log, dropping any previous one.
    fn replace(&self, bitmap: AtomicBitmapMmap) {
        *self.inner.write().unwrap() = Some(bitmap);
    }
}
// `BitmapMmapRegion` acts as its own slice type: a slice is simply a
// clone with a shifted base address (see `slice_at`).
impl BitmapSlice for BitmapMmapRegion {}

impl WithBitmapSlice<'_> for BitmapMmapRegion {
    type S = Self;
}
impl NewBitmap for BitmapMmapRegion {
    // The length is ignored: the bitmap starts with no backing log and
    // only becomes active once `replace` installs an `AtomicBitmapMmap`.
    fn with_len(_len: usize) -> Self {
        Self::default()
    }
}
/// Dirty-page tracker for a single guest memory region, writing into the
/// shared vhost-user log mapping.
#[derive(Debug)]
pub struct AtomicBitmapMmap {
    // Shared mapping of the whole log area (one bit per page).
    logmem: Arc<MmapLogReg>,
    // `pages_before_region`: number of log pages preceding this region's
    // first page; `number_of_pages`: pages covered by this region.
    pages_before_region: usize, number_of_pages: usize, }
impl MemRegionBitmap for AtomicBitmapMmap {
    /// Creates a bitmap for `region`, validating that the log mapping is
    /// large enough to track every page of the region.
    ///
    /// Returns `InvalidData` if the region is empty, its end address
    /// overflows `usize`, or the log is too small.
    fn new<R: GuestMemoryRegion>(region: &R, logmem: Arc<MmapLogReg>) -> io::Result<Self> {
        let start_addr: usize = region.start_addr().raw_value().io_try_into()?;
        let len: usize = region.len().io_try_into()?;
        if len == 0 {
            return Err(io::Error::from(io::ErrorKind::InvalidData));
        }
        // Address of the region's last byte; overflow would mean the
        // region wraps around the address space.
        let last_addr = start_addr
            .checked_add(len - 1)
            .ok_or(io::Error::from(io::ErrorKind::InvalidData))?;
        // The log must contain the word covering the region's last page.
        if page_word(page_number(last_addr)) >= logmem.len() {
            return Err(io::Error::from(io::ErrorKind::InvalidData));
        }
        // NOTE(review): `page_number(len)` truncates any trailing partial
        // page, so this assumes region start/len are page-aligned — TODO confirm.
        Ok(Self {
            logmem,
            pages_before_region: page_number(start_addr),
            number_of_pages: page_number(len),
        })
    }
}
impl AtomicBitmapMmap {
    // Marks as dirty every page touched by `[offset, offset + len)`, with
    // `offset` relative to the start of this region. Pages beyond the end
    // of the region are silently ignored.
    fn mark_dirty(&self, offset: usize, len: usize) {
        if len == 0 {
            return;
        }
        let first = page_number(offset);
        let last = page_number(offset.saturating_add(len - 1));
        // Clamp to the region: stop as soon as a page falls outside it.
        for page in (first..=last).take_while(|&p| p < self.number_of_pages) {
            let log_page = self.pages_before_region + page;
            self.logmem[page_word(log_page)].fetch_or(1 << page_bit(log_page), Ordering::Relaxed);
        }
    }

    // Returns whether the page containing `offset` (region-relative) is
    // dirty; offsets past the end of the region report clean.
    fn dirty_at(&self, offset: usize) -> bool {
        let page = page_number(offset);
        if page >= self.number_of_pages {
            return false;
        }
        let log_page = self.pages_before_region + page;
        let word = self.logmem[page_word(log_page)].load(Ordering::Relaxed);
        (word & (1 << page_bit(log_page))) != 0
    }
}
/// A shared (`MAP_SHARED`) memory mapping of the vhost-user dirty log,
/// accessed as a sequence of `AtomicU8` log words via `Index`.
#[derive(Debug)]
pub struct MmapLogReg {
    // Base address of the mapping; valid for `len` bytes until `Drop`.
    addr: *const AtomicU8,
    len: usize,
}
// SAFETY: the raw pointer is only ever dereferenced through `Index`, which
// yields `&AtomicU8`; all reads and writes of the log therefore go through
// atomic operations, making cross-thread sharing sound.
unsafe impl Send for MmapLogReg {}
unsafe impl Sync for MmapLogReg {}
impl MmapLogReg {
    /// Maps `len` bytes of the dirty-log file `fd`, starting at `offset`,
    /// into this process as a shared read-write mapping.
    ///
    /// Fails with `InvalidData` when `offset`/`len` do not fit the native
    /// integer types, or with the OS error if `mmap(2)` itself fails.
    pub(crate) fn from_file(fd: BorrowedFd, offset: u64, len: u64) -> io::Result<Self> {
        let offset: isize = offset.io_try_into()?;
        let len: usize = len.io_try_into()?;
        // mmap requires a length representable as isize.
        if len > isize::MAX as usize {
            return Err(io::Error::from(io::ErrorKind::InvalidData));
        }
        // SAFETY: we request a fresh shared mapping (NULL address hint)
        // over a caller-provided fd; no existing Rust memory is aliased.
        // The return value is checked against MAP_FAILED below.
        let addr = unsafe {
            libc::mmap(
                ptr::null_mut(),
                len as libc::size_t,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_SHARED,
                fd.as_raw_fd(),
                offset as libc::off_t,
            )
        };
        if addr == libc::MAP_FAILED {
            return Err(io::Error::last_os_error());
        }
        Ok(Self {
            addr: addr as *const AtomicU8,
            len,
        })
    }

    // Length of the mapping in bytes (= number of log words).
    fn len(&self) -> usize {
        self.len
    }
}
impl Index<usize> for MmapLogReg {
    type Output = AtomicU8;

    // Returns the `index`-th log word; panics on out-of-bounds access.
    fn index(&self, index: usize) -> &Self::Output {
        // This bounds check upholds the safety of the deref below.
        assert!(index < self.len);
        // SAFETY: `addr` points to a live mapping of `len` bytes (see
        // `from_file`), and `index < len` was just asserted.
        unsafe { &*self.addr.add(index) }
    }
}
impl Drop for MmapLogReg {
    fn drop(&mut self) {
        // SAFETY: `addr`/`len` describe the mapping created in `from_file`,
        // and it is unmapped exactly once here. Any munmap error is
        // ignored, as is conventional in Drop.
        unsafe {
            libc::munmap(self.addr as *mut libc::c_void, self.len as libc::size_t);
        }
    }
}
// Fallible integer conversion that maps any `TryFrom` failure to an
// `io::Error` of kind `InvalidData`, so callers can use `?` in
// `io::Result` contexts.
trait IoTryInto<T: TryFrom<Self>>: Sized {
    fn io_try_into(self) -> io::Result<T>;
}

impl<TySrc, TyDst> IoTryInto<TyDst> for TySrc
where
    TyDst: TryFrom<TySrc>,
    <TyDst as TryFrom<TySrc>>::Error: Send + Sync + std::error::Error + 'static,
{
    fn io_try_into(self) -> io::Result<TyDst> {
        self.try_into()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
    }
}
// Index of the log page containing byte address `addr`.
#[inline]
fn page_number(addr: usize) -> usize {
    addr / LOG_PAGE_SIZE
}
// Index of the log word (byte) holding the bit for `page`.
#[inline]
fn page_word(page: usize) -> usize {
    page / LOG_WORD_SIZE
}
// Bit position of `page` within its log word.
#[inline]
fn page_bit(page: usize) -> usize {
    page % LOG_WORD_SIZE
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use std::os::fd::AsFd;
    use vm_memory::{GuestAddress, GuestRegionMmap};
    use vmm_sys_util::tempfile::TempFile;

    /// True iff no offset in `[start, start + len)` is dirty.
    pub fn range_is_clean<B: Bitmap>(b: &B, start: usize, len: usize) -> bool {
        (start..start + len).all(|offset| !b.dirty_at(offset))
    }

    /// True iff every offset in `[start, start + len)` is dirty.
    pub fn range_is_dirty<B: Bitmap>(b: &B, start: usize, len: usize) -> bool {
        (start..start + len).all(|offset| b.dirty_at(offset))
    }

    /// Creates a zero-filled temporary file of `len` bytes to back the log mmap.
    fn tmp_file(len: usize) -> File {
        let mut f = TempFile::new().unwrap().into_file();
        let buf = vec![0; len];
        f.write_all(buf.as_ref()).unwrap();
        f
    }

    /// Checks the clean -> mark_dirty -> dirty round trip over the whole region.
    fn test_all(b: &BitmapMmapRegion, len: usize) {
        assert!(range_is_clean(b, 0, len), "The bitmap should be clean");
        b.mark_dirty(0, len);
        assert!(range_is_dirty(b, 0, len), "The bitmap should be dirty");
    }

    #[test]
    #[cfg(not(miri))] // mmap FFI is not supported under Miri
    fn test_bitmap_region_bigger_than_log() {
        let mmap_offset: u64 = 0;
        // 1 byte of log tracks only 8 pages, but the region spans 16.
        let mmap_size = 1;
        let f = tmp_file(mmap_size);

        let region_start_addr = GuestAddress::new(mmap_offset);
        let region_len = LOG_PAGE_SIZE * 16;
        let region: GuestRegionMmap<()> =
            GuestRegionMmap::from_range(region_start_addr, region_len, None).unwrap();

        let logmem =
            Arc::new(MmapLogReg::from_file(f.as_fd(), mmap_offset, mmap_size as u64).unwrap());

        let log = AtomicBitmapMmap::new(&region, logmem);
        assert!(log.is_err());
    }

    #[test]
    #[cfg(not(miri))] // mmap FFI is not supported under Miri
    fn test_bitmap_log_and_region_same_size() {
        let mmap_offset: u64 = 0;
        // 4 bytes of log track exactly the region's 32 pages.
        let mmap_size = 4;
        let f = tmp_file(mmap_size);

        let region_start_addr = GuestAddress::new(mmap_offset);
        let region_len = LOG_PAGE_SIZE * 32;
        let region: GuestRegionMmap<()> =
            GuestRegionMmap::from_range(region_start_addr, region_len, None).unwrap();

        let logmem =
            Arc::new(MmapLogReg::from_file(f.as_fd(), mmap_offset, mmap_size as u64).unwrap());

        let log = AtomicBitmapMmap::new(&region, logmem);
        assert!(log.is_ok());
        let log = log.unwrap();

        let bitmap = BitmapMmapRegion::default();
        bitmap.replace(log);

        test_all(&bitmap, region_len);
    }

    #[test]
    #[cfg(not(miri))] // mmap FFI is not supported under Miri
    fn test_bitmap_region_smaller_than_log() {
        let mmap_offset: u64 = 0;
        // 4 bytes of log (32 pages) back a 16-page region.
        let mmap_size = 4;
        let f = tmp_file(mmap_size);

        let region_start_addr = GuestAddress::new(mmap_offset);
        let region_len = LOG_PAGE_SIZE * 16;
        let region: GuestRegionMmap<()> =
            GuestRegionMmap::from_range(region_start_addr, region_len, None).unwrap();

        let logmem =
            Arc::new(MmapLogReg::from_file(f.as_fd(), mmap_offset, mmap_size as u64).unwrap());

        let log = AtomicBitmapMmap::new(&region, logmem);
        assert!(log.is_ok());
        let log = log.unwrap();

        let bitmap = BitmapMmapRegion::default();
        bitmap.replace(log);

        test_all(&bitmap, region_len);
    }

    #[test]
    #[cfg(not(miri))] // mmap FFI is not supported under Miri
    fn test_bitmap_region_smaller_than_one_word() {
        let mmap_offset: u64 = 0;
        // A 6-page region occupies less than one 8-bit log word.
        let mmap_size = 4;
        let f = tmp_file(mmap_size);

        let region_start_addr = GuestAddress::new(mmap_offset);
        let region_len = LOG_PAGE_SIZE * 6;
        let region: GuestRegionMmap<()> =
            GuestRegionMmap::from_range(region_start_addr, region_len, None).unwrap();

        let logmem =
            Arc::new(MmapLogReg::from_file(f.as_fd(), mmap_offset, mmap_size as u64).unwrap());

        let log = AtomicBitmapMmap::new(&region, logmem);
        assert!(log.is_ok());
        let log = log.unwrap();

        let bitmap = BitmapMmapRegion::default();
        bitmap.replace(log);

        test_all(&bitmap, region_len);
    }

    #[test]
    #[cfg(not(miri))] // mmap FFI is not supported under Miri
    fn test_bitmap_two_regions_overlapping_word_first_dirty() {
        let mmap_offset: u64 = 0;
        let mmap_size = 4;
        let f = tmp_file(mmap_size);
        let logmem =
            Arc::new(MmapLogReg::from_file(f.as_fd(), mmap_offset, mmap_size as u64).unwrap());

        // Region 0 covers pages 0..=10 and region 1 covers page 14, so
        // pages 10 and 14 share log word 1: dirtying one region must not
        // leak into the other.
        let region0_start_addr = GuestAddress::new(mmap_offset);
        let region0_len = LOG_PAGE_SIZE * 11;
        let region0: GuestRegionMmap<()> =
            GuestRegionMmap::from_range(region0_start_addr, region0_len, None).unwrap();
        let log0 = AtomicBitmapMmap::new(&region0, Arc::clone(&logmem));
        assert!(log0.is_ok());
        let log0 = log0.unwrap();
        let bitmap0 = BitmapMmapRegion::default();
        bitmap0.replace(log0);

        let region1_start_addr = GuestAddress::new(mmap_offset + LOG_PAGE_SIZE as u64 * 14);
        let region1_len = LOG_PAGE_SIZE;
        let region1: GuestRegionMmap<()> =
            GuestRegionMmap::from_range(region1_start_addr, region1_len, None).unwrap();
        let log1 = AtomicBitmapMmap::new(&region1, Arc::clone(&logmem));
        assert!(log1.is_ok());
        let log1 = log1.unwrap();
        let bitmap1 = BitmapMmapRegion::default();
        bitmap1.replace(log1);

        assert!(
            range_is_clean(&bitmap0, 0, region0_len),
            "The bitmap0 should be clean"
        );
        assert!(
            range_is_clean(&bitmap1, 0, region1_len),
            "The bitmap1 should be clean"
        );

        bitmap0.mark_dirty(0, region0_len);
        assert!(
            range_is_dirty(&bitmap0, 0, region0_len),
            "The bitmap0 should be dirty"
        );
        assert!(
            range_is_clean(&bitmap1, 0, region1_len),
            "The bitmap1 should be clean"
        );
    }

    #[test]
    #[cfg(not(miri))] // mmap FFI is not supported under Miri
    fn test_bitmap_two_regions_overlapping_word_second_dirty() {
        let mmap_offset: u64 = 0;
        let mmap_size = 4;
        let f = tmp_file(mmap_size);
        let logmem =
            Arc::new(MmapLogReg::from_file(f.as_fd(), mmap_offset, mmap_size as u64).unwrap());

        // Same layout as above; this time region 1 is dirtied first.
        let region0_start_addr = GuestAddress::new(mmap_offset);
        let region0_len = LOG_PAGE_SIZE * 11;
        let region0: GuestRegionMmap<()> =
            GuestRegionMmap::from_range(region0_start_addr, region0_len, None).unwrap();
        let log0 = AtomicBitmapMmap::new(&region0, Arc::clone(&logmem));
        assert!(log0.is_ok());
        let log0 = log0.unwrap();
        let bitmap0 = BitmapMmapRegion::default();
        bitmap0.replace(log0);

        let region1_start_addr = GuestAddress::new(mmap_offset + LOG_PAGE_SIZE as u64 * 14);
        let region1_len = LOG_PAGE_SIZE;
        let region1: GuestRegionMmap<()> =
            GuestRegionMmap::from_range(region1_start_addr, region1_len, None).unwrap();
        let log1 = AtomicBitmapMmap::new(&region1, Arc::clone(&logmem));
        assert!(log1.is_ok());
        let log1 = log1.unwrap();
        let bitmap1 = BitmapMmapRegion::default();
        bitmap1.replace(log1);

        assert!(
            range_is_clean(&bitmap0, 0, region0_len),
            "The bitmap0 should be clean"
        );
        assert!(
            range_is_clean(&bitmap1, 0, region1_len),
            "The bitmap1 should be clean"
        );

        bitmap1.mark_dirty(0, region1_len);
        // Fixed: these two messages were swapped (bitmap0/bitmap1).
        assert!(
            range_is_dirty(&bitmap1, 0, region1_len),
            "The bitmap1 should be dirty"
        );
        assert!(
            range_is_clean(&bitmap0, 0, region0_len),
            "The bitmap0 should be clean"
        );
    }

    #[test]
    #[cfg(not(miri))] // mmap FFI is not supported under Miri
    fn test_bitmap_region_slice() {
        let mmap_offset: u64 = 0;
        let mmap_size = 4;
        let f = tmp_file(mmap_size);

        let region_start_addr = GuestAddress::new(mmap_offset);
        let region_len = LOG_PAGE_SIZE * 32;
        let region: GuestRegionMmap<()> =
            GuestRegionMmap::from_range(region_start_addr, region_len, None).unwrap();

        let logmem =
            Arc::new(MmapLogReg::from_file(f.as_fd(), mmap_offset, mmap_size as u64).unwrap());

        let log = AtomicBitmapMmap::new(&region, logmem);
        assert!(log.is_ok());
        let log = log.unwrap();

        let bitmap = BitmapMmapRegion::default();
        bitmap.replace(log);
        assert!(
            range_is_clean(&bitmap, 0, region_len),
            "The bitmap should be clean"
        );

        // A slice at the region's midpoint shares the same log: dirtying
        // the slice dirties the second half of the parent bitmap.
        let slice_len = region_len / 2;
        let slice = bitmap.slice_at(slice_len);
        assert!(
            range_is_clean(&slice, 0, slice_len),
            "The slice should be clean"
        );
        slice.mark_dirty(0, slice_len);
        assert!(
            range_is_dirty(&slice, 0, slice_len),
            "The slice should be dirty"
        );
        assert!(
            range_is_clean(&bitmap, 0, slice_len),
            "The first half of the bitmap should be clean"
        );
        assert!(
            range_is_dirty(&bitmap, slice_len, region_len - slice_len),
            "The last half of the bitmap should be dirty"
        );
    }
}