#![cfg_attr(not(unix), allow(unused_imports, unused_variables))]
use crate::MmapVec;
use anyhow::Result;
use libc::c_void;
use std::fs::File;
use std::sync::Arc;
use std::{convert::TryFrom, ops::Range};
use wasmtime_environ::{DefinedMemoryIndex, MemoryInitialization, MemoryStyle, Module, PrimaryMap};
/// The memory images recorded for a module, keyed by defined memory index.
///
/// Each entry is `Some(image)` when an image was built for that memory and
/// `None` when the memory had no initialization data to turn into an image.
pub struct ModuleMemoryImages {
/// One slot per defined memory; images are shared through `Arc`.
memories: PrimaryMap<DefinedMemoryIndex, Option<Arc<MemoryImage>>>,
}
impl ModuleMemoryImages {
    /// Looks up the image recorded for the defined memory `defined_index`.
    ///
    /// Yields `None` when that memory has no image associated with it.
    pub fn get_memory_image(&self, defined_index: DefinedMemoryIndex) -> Option<&Arc<MemoryImage>> {
        let slot = &self.memories[defined_index];
        slot.as_ref()
    }
}
/// A file-backed image of (part of) a linear memory's initial contents.
///
/// The image covers `len` bytes of `fd` starting at `fd_offset`, and is mapped
/// into a memory at `linear_memory_offset` bytes past that memory's base.
#[derive(Debug, PartialEq)]
pub struct MemoryImage {
/// File the image bytes are mapped from.
fd: FdSource,
/// Length of the image in bytes (asserted page-aligned by `MemoryImage::new`).
len: usize,
/// Byte offset within `fd` at which the image data begins.
fd_offset: u64,
/// Byte offset within linear memory at which the image is placed (asserted
/// page-aligned by `MemoryImage::new`).
linear_memory_offset: usize,
}
/// Where the bytes of a `MemoryImage` come from.
///
/// Every variant is conditionally compiled, so on non-unix targets this enum
/// has no variants at all.
#[derive(Debug)]
enum FdSource {
/// An already-open file shared with the mapping it was obtained from (see
/// the `original_file()` path in `MemoryImage::new`).
#[cfg(unix)]
Mmap(Arc<File>),
/// A memfd produced by `create_memfd`; only available on Linux.
#[cfg(target_os = "linux")]
Memfd(memfd::Memfd),
}
impl FdSource {
    /// Borrows the [`File`] behind this source, whichever variant holds it.
    #[cfg(unix)]
    fn as_file(&self) -> &File {
        match *self {
            #[cfg(target_os = "linux")]
            FdSource::Memfd(ref memfd) => memfd.as_file(),
            // Explicitly peel the `Arc` to reach the shared `File`.
            FdSource::Mmap(ref file) => &**file,
        }
    }
}
impl PartialEq for FdSource {
    /// Two sources compare equal exactly when their underlying files report
    /// the same raw file descriptor.
    fn eq(&self, other: &FdSource) -> bool {
        cfg_if::cfg_if! {
            if #[cfg(unix)] {
                use rustix::fd::AsRawFd;
                let lhs = self.as_file().as_raw_fd();
                let rhs = other.as_file().as_raw_fd();
                lhs == rhs
            } else {
                // `FdSource` has no variants on this target, so no value can
                // exist and the empty match is statically exhaustive.
                drop(other);
                match *self {}
            }
        }
    }
}
impl MemoryImage {
/// Builds a `MemoryImage` for `data`, to be placed at linear-memory offset
/// `offset`.
///
/// `page_size` is the alignment that `offset`, `data.len()` and (when `mmap`
/// is supplied) the addresses derived from `mmap` must satisfy; a violation
/// trips an assertion.
///
/// Returns:
/// * `Ok(Some(image))` backed by the file behind `mmap` when `mmap` is given
///   (non-Windows targets only) and reports an original file; otherwise, on
///   Linux, backed by a newly created and sealed memfd holding a copy of
///   `data`.
/// * `Ok(None)` when `offset` does not fit in `usize`, or when neither
///   backing is available on this platform.
/// * `Err(_)` when creating, writing to, or sealing the memfd fails.
fn new(
page_size: u32,
offset: u64,
data: &[u8],
mmap: Option<&MmapVec>,
) -> Result<Option<MemoryImage>> {
let len = data.len();
// The image must start and end on page boundaries of linear memory.
assert_eq!(offset % u64::from(page_size), 0);
// NOTE(review): `len as u32` truncates on 64-bit targets; the remainder is
// only unaffected if `page_size` divides 2^32 (e.g. a power of two) — confirm.
assert_eq!((len as u32) % page_size, 0);
// An offset that cannot be represented as `usize` cannot be mapped here, so
// report "no image" rather than an error.
let linear_memory_offset = match usize::try_from(offset) {
Ok(offset) => offset,
Err(_) => return Ok(None),
};
// Preferred path: reuse the file that `mmap` itself was created from, so the
// image can be mapped straight out of that file.
#[cfg(not(windows))]
if let Some(mmap) = mmap {
let start = mmap.as_ptr() as usize;
let end = start + mmap.len();
let data_start = data.as_ptr() as usize;
let data_end = data_start + data.len();
// `data` is required to be a sub-slice of `mmap`, with everything
// page-aligned so the file offset computed below is mappable.
assert!(start <= data_start && data_end <= end);
assert_eq!((start as u32) % page_size, 0);
assert_eq!((data_start as u32) % page_size, 0);
assert_eq!((data_end as u32) % page_size, 0);
assert_eq!((mmap.original_offset() as u32) % page_size, 0);
if let Some(file) = mmap.original_file() {
return Ok(Some(MemoryImage {
fd: FdSource::Mmap(file.clone()),
// Offset of `data` in the file = offset of the mapping in the
// file + offset of `data` within the mapping.
fd_offset: u64::try_from(mmap.original_offset() + (data_start - start))
.unwrap(),
linear_memory_offset,
len,
}));
}
}
// Fallback when no file backs `mmap` (or no `mmap` was given).
cfg_if::cfg_if! {
if #[cfg(target_os = "linux")] {
use std::io::Write;
// Copy `data` into a fresh memfd so there is a file to map from.
let memfd = create_memfd()?;
memfd.as_file().write_all(data)?;
// Seal against growing, shrinking and writing, and against adding
// further seals, so the image contents stay fixed from here on.
memfd.add_seals(&[
memfd::FileSeal::SealGrow,
memfd::FileSeal::SealShrink,
memfd::FileSeal::SealWrite,
memfd::FileSeal::SealSeal,
])?;
Ok(Some(MemoryImage {
fd: FdSource::Memfd(memfd),
// The memfd holds only `data`, so the image starts at its beginning.
fd_offset: 0,
linear_memory_offset,
len,
}))
} else {
// No file-like backing is available on this platform.
Ok(None)
}
}
}
/// Maps this image's file contents over the memory that starts at
/// `base + self.linear_memory_offset`, as a private, read/write mapping
/// placed at exactly that address.
///
/// # Safety
///
/// The mapping is created with `MapFlags::FIXED`, so whatever currently
/// occupies the `self.len` bytes at the target address is replaced. The
/// caller must ensure that range is one it is allowed to overwrite.
unsafe fn map_at(&self, base: usize) -> Result<()> {
    cfg_if::cfg_if! {
        if #[cfg(unix)] {
            let target = base + self.linear_memory_offset;
            let access = rustix::mm::ProtFlags::READ | rustix::mm::ProtFlags::WRITE;
            let kind = rustix::mm::MapFlags::PRIVATE | rustix::mm::MapFlags::FIXED;
            let mapped = rustix::mm::mmap(
                target as *mut c_void,
                self.len,
                access,
                kind,
                self.fd.as_file(),
                self.fd_offset,
            )?;
            // A fixed mapping is expected to land exactly where requested.
            assert_eq!(mapped as usize, target);
            Ok(())
        } else {
            // `FdSource` is uninhabited here, so `self` cannot exist.
            match self.fd {}
        }
    }
}
/// Replaces the region this image occupies, starting at
/// `base + self.linear_memory_offset`, with a private anonymous read/write
/// mapping of `self.len` bytes.
///
/// # Safety
///
/// The mapping is created with `MapFlags::FIXED`, so the existing contents of
/// the target range are discarded. The caller must ensure that range is one
/// it is allowed to overwrite.
unsafe fn remap_as_zeros_at(&self, base: usize) -> Result<()> {
    cfg_if::cfg_if! {
        if #[cfg(unix)] {
            let target = base + self.linear_memory_offset;
            let access = rustix::mm::ProtFlags::READ | rustix::mm::ProtFlags::WRITE;
            let kind = rustix::mm::MapFlags::PRIVATE | rustix::mm::MapFlags::FIXED;
            let mapped = rustix::mm::mmap_anonymous(
                target as *mut c_void,
                self.len,
                access,
                kind,
            )?;
            // A fixed mapping is expected to land exactly where requested.
            assert_eq!(mapped as usize, target);
            Ok(())
        } else {
            // `FdSource` is uninhabited here, so `self` cannot exist.
            match self.fd {}
        }
    }
}
}
/// Creates a new memfd named `wasm-memory-image` with sealing enabled.
///
/// Any failure from the `memfd` crate is converted into an `anyhow` error.
#[cfg(target_os = "linux")]
fn create_memfd() -> Result<memfd::Memfd> {
    let options = memfd::MemfdOptions::new().allow_sealing(true);
    let fd = options.create("wasm-memory-image")?;
    Ok(fd)
}
impl ModuleMemoryImages {
    /// Tries to build one memory image per defined memory of `module`.
    ///
    /// `wasm_data` is the byte buffer that each initializer's `data` range
    /// indexes into, and `mmap` is forwarded unchanged to `MemoryImage::new`.
    ///
    /// Returns `Ok(None)` when images cannot be used for this module: its
    /// memory initialization is not `Static`, one of the initialized memories
    /// is not a defined memory, or `MemoryImage::new` declines to build an
    /// image. Errors from `MemoryImage::new` are propagated.
    pub fn new(
        module: &Module,
        wasm_data: &[u8],
        mmap: Option<&MmapVec>,
    ) -> Result<Option<ModuleMemoryImages>> {
        let static_map =
            if let MemoryInitialization::Static { map } = &module.memory_initialization {
                map
            } else {
                return Ok(None);
            };

        let mut memories = PrimaryMap::with_capacity(static_map.len());
        let page_size = crate::page_size() as u32;

        for (memory_index, maybe_init) in static_map {
            // Bail out entirely if this memory has no defined index.
            let defined_memory =
                if let Some(index) = module.defined_memory_index(memory_index) {
                    index
                } else {
                    return Ok(None);
                };

            // Nothing to initialize: record the absence of an image.
            let initializer = match maybe_init {
                None => {
                    memories.push(None);
                    continue;
                }
                Some(initializer) => initializer,
            };

            let start = initializer.data.start as usize;
            let end = initializer.data.end as usize;
            let data = &wasm_data[start..end];

            let image = match MemoryImage::new(page_size, initializer.offset, data, mmap)? {
                None => return Ok(None),
                Some(image) => Arc::new(image),
            };

            // The images are pushed in order, so the slot just filled must
            // line up with this memory's defined index.
            let pushed_at = memories.push(Some(image));
            assert_eq!(pushed_at, defined_memory);
        }

        Ok(Some(ModuleMemoryImages { memories }))
    }
}
/// Bookkeeping for one region of address space that hosts a linear memory and,
/// optionally, a `MemoryImage` mapped into it.
///
/// The slot does not own the address range; it records where it is (`base`),
/// how large it may become (`static_size`) and how much of it is currently
/// read/write (`accessible`).
#[derive(Debug)]
pub struct MemoryImageSlot {
/// Address of the first byte of the slot.
base: usize,
/// Upper bound, in bytes, for anything this slot maps or makes accessible.
static_size: usize,
/// Image currently mapped into the slot, if any.
image: Option<Arc<MemoryImage>>,
/// Number of bytes, counted from `base`, currently made read/write.
accessible: usize,
/// Set by `instantiate` and cleared by `clear_and_remain_ready`;
/// `instantiate` asserts it is unset on entry.
dirty: bool,
/// Whether `Drop` resets the whole slot via `reset_with_anon_memory`.
clear_on_drop: bool,
}
impl MemoryImageSlot {
pub(crate) fn create(base_addr: *mut c_void, accessible: usize, static_size: usize) -> Self {
let base = base_addr as usize;
MemoryImageSlot {
base,
static_size,
accessible,
image: None,
dirty: false,
clear_on_drop: true,
}
}
/// Produces a placeholder slot: zero base, zero sizes, no image, not dirty,
/// and with clear-on-drop disabled so dropping it touches no memory.
#[cfg(feature = "pooling-allocator")]
pub(crate) fn dummy() -> MemoryImageSlot {
    Self {
        base: 0,
        static_size: 0,
        accessible: 0,
        image: None,
        dirty: false,
        clear_on_drop: false,
    }
}
/// Disables the reset that `Drop` would otherwise perform on this slot's
/// memory (see the `Drop` impl, which checks `clear_on_drop`).
pub(crate) fn no_clear_on_drop(&mut self) {
self.clear_on_drop = false;
}
/// Ensures at least `size_bytes` bytes from the start of the slot are
/// read/write.
///
/// This never shrinks the accessible region: a request at or below the
/// current `accessible` size is a no-op.
///
/// # Panics
///
/// Panics if `size_bytes` exceeds the slot's `static_size`.
pub(crate) fn set_heap_limit(&mut self, size_bytes: usize) -> Result<()> {
    assert!(size_bytes <= self.static_size);

    // Only growth needs work; the newly exposed tail is made read/write.
    if size_bytes > self.accessible {
        let newly_accessible = self.accessible..size_bytes;
        self.set_protection(newly_accessible, true)?;
        self.accessible = size_bytes;
    }

    Ok(())
}
/// Prepares this slot for a new instance with `initial_size_bytes` of
/// accessible memory and, optionally, `maybe_image` mapped into it.
///
/// In order, this:
/// 1. unmaps the currently recorded image if it differs from `maybe_image`;
/// 2. grows the accessible region up to `initial_size_bytes` if needed;
/// 3. for `MemoryStyle::Static` only, shrinks the accessible region back down
///    to `initial_size_bytes` (a `Dynamic` memory keeps any extra);
/// 4. maps `maybe_image` if it is not already the recorded image;
/// 5. marks the slot dirty.
///
/// # Panics
///
/// Panics if the slot is still dirty, if `initial_size_bytes` exceeds
/// `static_size`, or if the image would not fit within `initial_size_bytes`.
pub(crate) fn instantiate(
&mut self,
initial_size_bytes: usize,
maybe_image: Option<&Arc<MemoryImage>>,
style: &MemoryStyle,
) -> Result<()> {
// The slot must have been cleared since its previous instantiation.
assert!(!self.dirty);
assert!(initial_size_bytes <= self.static_size);
// Remove a previously mapped image that is not the one requested. After
// this, `self.image` is `None` unless it already equals `maybe_image`.
if self.image.as_ref() != maybe_image {
self.remove_image()?;
}
// Grow: make the bytes up to the requested initial size read/write.
if self.accessible < initial_size_bytes {
self.set_protection(self.accessible..initial_size_bytes, true)?;
self.accessible = initial_size_bytes;
}
// Shrink: only static memories have the excess made inaccessible again;
// dynamic memories deliberately leave it accessible.
if initial_size_bytes < self.accessible {
match style {
MemoryStyle::Static { .. } => {
self.set_protection(initial_size_bytes..self.accessible, false)?;
self.accessible = initial_size_bytes;
}
MemoryStyle::Dynamic { .. } => {}
}
}
assert!(initial_size_bytes <= self.accessible);
// Re-evaluated on purpose: if the old image was removed above, this now
// compares `None` against the request and maps the new image (if any).
if self.image.as_ref() != maybe_image {
if let Some(image) = maybe_image.as_ref() {
// The whole image has to lie inside the initially accessible heap.
assert!(
image.linear_memory_offset.checked_add(image.len).unwrap()
<= initial_size_bytes
);
// An empty image needs no mapping call at all.
if image.len > 0 {
unsafe {
image.map_at(self.base)?;
}
}
}
self.image = maybe_image.cloned();
}
// From here on the slot must be cleared before it can be instantiated again.
self.dirty = true;
Ok(())
}
/// Unmaps the currently recorded image, if there is one, by replacing its
/// range with anonymous memory, and forgets it.
///
/// If the remap fails the error is returned and the image stays recorded.
pub(crate) fn remove_image(&mut self) -> Result<()> {
    let current = match self.image.as_ref() {
        Some(image) => image,
        None => return Ok(()),
    };

    unsafe {
        current.remap_as_zeros_at(self.base)?;
    }

    self.image = None;
    Ok(())
}
/// Resets the contents of a dirty slot so it can be instantiated again.
///
/// `keep_resident` is forwarded to `reset_all_memory_contents`. The dirty
/// flag is cleared only if the reset succeeds; on failure the error is
/// returned and the slot stays dirty.
///
/// # Panics
///
/// Panics if the slot is not dirty.
#[allow(dead_code)]
pub(crate) fn clear_and_remain_ready(&mut self, keep_resident: usize) -> Result<()> {
    assert!(self.dirty);

    let outcome = unsafe { self.reset_all_memory_contents(keep_resident) };
    if outcome.is_ok() {
        self.dirty = false;
    }
    outcome
}
/// Resets the accessible part of the slot, using a mix of in-place zeroing
/// (`write_bytes`) and `madvise_reset`.
///
/// `keep_resident` is a byte budget: roughly that many bytes outside the image
/// are zeroed with `write_bytes` instead of being handed to `madvise_reset`.
/// The image's own range is always reset with `madvise_reset`.
///
/// On non-Linux targets none of that applies and the whole slot is reset with
/// `reset_with_anon_memory`, which also drops the recorded image.
///
/// # Safety
///
/// Writes through raw pointers derived from `self.base`; the caller must
/// guarantee that `self.accessible` bytes starting at `self.base` are mapped
/// and writable.
// NOTE(review): relies on Linux `MADV_DONTNEED` returning pages to their
// pristine state (image bytes / zeros), as the tests in this file assert.
#[allow(dead_code)] unsafe fn reset_all_memory_contents(&mut self, keep_resident: usize) -> Result<()> {
if !cfg!(target_os = "linux") {
return self.reset_with_anon_memory();
}
match &self.image {
Some(image) => {
// The image must lie entirely within the accessible region.
assert!(self.accessible >= image.linear_memory_offset + image.len);
if image.linear_memory_offset < keep_resident {
// The image starts inside the keep-resident budget. Layout handled:
//   [0, image start)            -> memset
//   [image start, image end)    -> madvise
//   next `remaining_memset`     -> memset
//   rest up to `accessible`     -> madvise
let image_end = image.linear_memory_offset + image.len;
let mem_after_image = self.accessible - image_end;
// Budget left after the pre-image bytes, capped by what actually
// exists after the image.
let remaining_memset =
(keep_resident - image.linear_memory_offset).min(mem_after_image);
std::ptr::write_bytes(self.base as *mut u8, 0u8, image.linear_memory_offset);
self.madvise_reset(image.linear_memory_offset, image.len)?;
std::ptr::write_bytes(
(self.base + image_end) as *mut u8,
0u8,
remaining_memset,
);
self.madvise_reset(
image_end + remaining_memset,
mem_after_image - remaining_memset,
)?;
} else {
// The image starts at or beyond the budget: memset the first
// `keep_resident` bytes and madvise everything after, image included.
std::ptr::write_bytes(self.base as *mut u8, 0u8, keep_resident);
self.madvise_reset(keep_resident, self.accessible - keep_resident)?;
}
}
None => {
// No image: memset up to the budget (capped at `accessible`) and
// madvise the remainder.
let size_to_memset = keep_resident.min(self.accessible);
std::ptr::write_bytes(self.base as *mut u8, 0u8, size_to_memset);
self.madvise_reset(size_to_memset, self.accessible - size_to_memset)?;
}
}
Ok(())
}
/// Applies `madvise` with `Advice::LinuxDontNeed` to `len` bytes starting
/// `base` bytes into the slot. A zero-length request does nothing.
///
/// Only implemented for Linux; reaching the syscall path on any other target
/// hits `unreachable!()`.
///
/// # Safety
///
/// The range must be mapped memory belonging to this slot whose current
/// contents may be discarded.
///
/// # Panics
///
/// Panics if the range extends past the slot's accessible region.
#[allow(dead_code)]
unsafe fn madvise_reset(&self, base: usize, len: usize) -> Result<()> {
    assert!(base + len <= self.accessible);
    if len == 0 {
        return Ok(());
    }
    cfg_if::cfg_if! {
        if #[cfg(target_os = "linux")] {
            let region_start = (self.base + base) as *mut c_void;
            rustix::mm::madvise(region_start, len, rustix::mm::Advice::LinuxDontNeed)?;
            Ok(())
        } else {
            unreachable!();
        }
    }
}
/// Changes the accessibility of the slot-relative byte range `range`.
///
/// With `readwrite == true` the range becomes readable and writable
/// (`mprotect` on unix, `VirtualAlloc` with `MEM_COMMIT` elsewhere); with
/// `false` all access is removed (`mprotect` with empty flags on unix,
/// `VirtualFree` with `MEM_DECOMMIT` elsewhere). An empty range is a no-op.
///
/// # Panics
///
/// Panics if the range is inverted, ends past `static_size`, or if
/// `base + range.start` overflows.
fn set_protection(&self, range: Range<usize>, readwrite: bool) -> Result<()> {
    assert!(range.start <= range.end);
    assert!(range.end <= self.static_size);
    let start = self.base.checked_add(range.start).unwrap();
    let len = range.len();
    if len == 0 {
        return Ok(());
    }

    unsafe {
        cfg_if::cfg_if! {
            if #[cfg(unix)] {
                let flags = if !readwrite {
                    rustix::mm::MprotectFlags::empty()
                } else {
                    rustix::mm::MprotectFlags::READ | rustix::mm::MprotectFlags::WRITE
                };
                rustix::mm::mprotect(start as *mut _, len, flags)?;
            } else {
                use windows_sys::Win32::System::Memory::*;

                let failed = if !readwrite {
                    VirtualFree(start as _, len, MEM_DECOMMIT) == 0
                } else {
                    VirtualAlloc(start as _, len, MEM_COMMIT, PAGE_READWRITE).is_null()
                };
                if failed {
                    return Err(std::io::Error::last_os_error().into());
                }
            }
        }
    }

    Ok(())
}
/// Reports whether an image is currently recorded for this slot.
pub(crate) fn has_image(&self) -> bool {
self.image.is_some()
}
/// Reports whether the slot has been instantiated since it was last cleared
/// (the flag is set by `instantiate` and reset by `clear_and_remain_ready`).
// NOTE(review): `dead_code` is presumably allowed because this is only used by
// tests and/or feature-gated callers — confirm.
#[allow(dead_code)] pub(crate) fn is_dirty(&self) -> bool {
self.dirty
}
/// Resets the entire slot (`static_size` bytes from `base`) to inaccessible
/// memory and forgets any image.
///
/// On unix the range is replaced by a private anonymous mapping with no
/// access permissions; elsewhere it is decommitted with `VirtualFree`.
/// Afterwards the slot has no image and zero accessible bytes. A slot whose
/// `static_size` is zero is left untouched, after checking that it is empty.
fn reset_with_anon_memory(&mut self) -> Result<()> {
    if self.static_size != 0 {
        unsafe {
            cfg_if::cfg_if! {
                if #[cfg(unix)] {
                    let mapped = rustix::mm::mmap_anonymous(
                        self.base as *mut c_void,
                        self.static_size,
                        rustix::mm::ProtFlags::empty(),
                        rustix::mm::MapFlags::PRIVATE | rustix::mm::MapFlags::FIXED,
                    )?;
                    // A fixed mapping is expected to land exactly at `base`.
                    assert_eq!(mapped as usize, self.base);
                } else {
                    use windows_sys::Win32::System::Memory::*;
                    if VirtualFree(self.base as _, self.static_size, MEM_DECOMMIT) == 0 {
                        return Err(std::io::Error::last_os_error().into());
                    }
                }
            }
        }

        self.image = None;
        self.accessible = 0;
    } else {
        // Nothing is mapped for an empty slot; just verify it really is empty.
        assert!(self.image.is_none());
        assert_eq!(self.accessible, 0);
    }

    Ok(())
}
}
impl Drop for MemoryImageSlot {
    /// Resets the slot's memory on drop unless that was disabled through
    /// `no_clear_on_drop` (or the slot was built with `clear_on_drop: false`).
    ///
    /// A failure to reset the memory panics.
    fn drop(&mut self) {
        if !self.clear_on_drop {
            return;
        }
        self.reset_with_anon_memory().unwrap();
    }
}
#[cfg(all(test, target_os = "linux"))]
mod test {
use std::sync::Arc;
use super::{create_memfd, FdSource, MemoryImage, MemoryImageSlot, MemoryStyle};
use crate::mmap::Mmap;
use anyhow::Result;
use std::io::Write;
/// Test helper: builds a memfd-backed `MemoryImage` containing `data`, padded
/// out to a whole number of pages, to be placed at linear-memory `offset`.
///
/// `offset` must be page-aligned.
fn create_memfd_with_data(offset: usize, data: &[u8]) -> Result<MemoryImage> {
    let page_size = crate::page_size();
    assert_eq!(offset & (page_size - 1), 0);

    let memfd = create_memfd()?;
    let mut file = memfd.as_file();
    file.write_all(data)?;

    // Round the length up to the next page boundary and extend the file to
    // match, so the image length is a whole number of pages.
    let rounded_up = data.len() + page_size - 1;
    let image_len = rounded_up & !(page_size - 1);
    file.set_len(image_len as u64)?;

    Ok(MemoryImage {
        fd: FdSource::Memfd(memfd),
        len: image_len,
        fd_offset: 0,
        linear_memory_offset: offset,
    })
}
// Instantiating without an image yields zeroed, writable memory, which can be
// grown and which reads back as zero again after a clear + re-instantiate.
#[test]
fn instantiate_no_image() {
let style = MemoryStyle::Static { bound: 4 << 30 };
// 4 MiB reserved, none of it accessible yet.
let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap();
let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20);
// `mmap` owns the memory, so the slot must not reset it on drop.
memfd.no_clear_on_drop();
assert!(!memfd.is_dirty());
// Instantiate with a 64 KiB initial size.
memfd.instantiate(64 << 10, None, &style).unwrap();
assert!(memfd.is_dirty());
// Both ends of the 64 KiB are accessible and zero; writes stick.
let slice = mmap.as_mut_slice();
assert_eq!(0, slice[0]);
assert_eq!(0, slice[65535]);
slice[1024] = 42;
assert_eq!(42, slice[1024]);
// Growing to 128 KiB keeps existing data and exposes zeroed memory.
memfd.set_heap_limit(128 << 10).unwrap();
let slice = mmap.as_slice();
assert_eq!(42, slice[1024]);
assert_eq!(0, slice[131071]);
// After clearing and re-instantiating, the earlier write is gone.
memfd.clear_and_remain_ready(0).unwrap();
assert!(!memfd.is_dirty());
memfd.instantiate(64 << 10, None, &style).unwrap();
let slice = mmap.as_slice();
assert_eq!(0, slice[1024]);
}
// Cycles one slot through: image -> same image -> no image -> image ->
// a different image -> the first image, checking the visible bytes each time.
#[test]
fn instantiate_image() {
let style = MemoryStyle::Static { bound: 4 << 30 };
// 4 MiB reserved, none of it accessible yet.
let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap();
let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20);
// `mmap` owns the memory, so the slot must not reset it on drop.
memfd.no_clear_on_drop();
// Image with bytes 1..=4 placed at offset 4096.
let image = Arc::new(create_memfd_with_data(4096, &[1, 2, 3, 4]).unwrap());
memfd.instantiate(64 << 10, Some(&image), &style).unwrap();
assert!(memfd.has_image());
let slice = mmap.as_mut_slice();
assert_eq!(&[1, 2, 3, 4], &slice[4096..4100]);
// Dirty the image region, then reuse the slot with the same image: the
// original image bytes must be visible again.
slice[4096] = 5;
memfd.clear_and_remain_ready(0).unwrap();
memfd.instantiate(64 << 10, Some(&image), &style).unwrap();
let slice = mmap.as_slice();
assert_eq!(&[1, 2, 3, 4], &slice[4096..4100]);
// Reuse the slot with no image: the image is removed and the region is zero.
memfd.clear_and_remain_ready(0).unwrap();
memfd.instantiate(64 << 10, None, &style).unwrap();
assert!(!memfd.has_image());
let slice = mmap.as_slice();
assert_eq!(&[0, 0, 0, 0], &slice[4096..4100]);
// Going back from "no image" to the image works.
memfd.clear_and_remain_ready(0).unwrap();
memfd.instantiate(64 << 10, Some(&image), &style).unwrap();
let slice = mmap.as_slice();
assert_eq!(&[1, 2, 3, 4], &slice[4096..4100]);
// Switching to a different image (and a larger initial size) shows the new bytes.
let image2 = Arc::new(create_memfd_with_data(4096, &[10, 11, 12, 13]).unwrap());
memfd.clear_and_remain_ready(0).unwrap();
memfd.instantiate(128 << 10, Some(&image2), &style).unwrap();
let slice = mmap.as_slice();
assert_eq!(&[10, 11, 12, 13], &slice[4096..4100]);
// And switching back to the first image shows its bytes again.
memfd.clear_and_remain_ready(0).unwrap();
memfd.instantiate(64 << 10, Some(&image), &style).unwrap();
let slice = mmap.as_slice();
assert_eq!(&[1, 2, 3, 4], &slice[4096..4100]);
}
// Exercises `clear_and_remain_ready` with a range of `keep_resident` amounts
// (including ones larger than the instantiated 64 KiB) and image offsets, so
// the memset-based and madvise-based reset paths are both covered.
#[test]
#[cfg(target_os = "linux")]
fn memset_instead_of_madvise() {
let style = MemoryStyle::Static { bound: 100 };
let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap();
let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20);
// `mmap` owns the memory, so the slot must not reset it on drop.
memfd.no_clear_on_drop();
// With an image, at several offsets relative to the keep-resident amounts.
for image_off in [0, 4096, 8 << 10] {
let image = Arc::new(create_memfd_with_data(image_off, &[1, 2, 3, 4]).unwrap());
for amt_to_memset in [0, 4096, 10 << 12, 1 << 20, 10 << 20] {
memfd.instantiate(64 << 10, Some(&image), &style).unwrap();
assert!(memfd.has_image());
let slice = mmap.as_mut_slice();
// Bytes just outside the image data are zero ...
if image_off > 0 {
assert_eq!(slice[image_off - 1], 0);
}
assert_eq!(slice[image_off + 5], 0);
// ... and the image data itself is pristine, proving the previous
// iteration's write below was undone by the reset.
assert_eq!(&[1, 2, 3, 4], &slice[image_off..][..4]);
slice[image_off] = 5;
assert_eq!(&[5, 2, 3, 4], &slice[image_off..][..4]);
memfd.clear_and_remain_ready(amt_to_memset).unwrap();
}
}
// Without an image: dirty one byte per KiB, reset, and check on the next
// iteration that everything reads back as zero.
for amt_to_memset in [0, 4096, 10 << 12, 1 << 20, 10 << 20] {
memfd.instantiate(64 << 10, None, &style).unwrap();
for chunk in mmap.as_mut_slice()[..64 << 10].chunks_mut(1024) {
assert_eq!(chunk[0], 0);
chunk[0] = 5;
}
memfd.clear_and_remain_ready(amt_to_memset).unwrap();
}
}
// Exercises a `Dynamic`-style memory, where memory grown beyond the initial
// size stays accessible across clear + re-instantiate but is still reset.
#[test]
#[cfg(target_os = "linux")]
fn dynamic() {
let style = MemoryStyle::Dynamic { reserve: 200 };
let mut mmap = Mmap::accessible_reserved(0, 4 << 20).unwrap();
let mut memfd = MemoryImageSlot::create(mmap.as_mut_ptr() as *mut _, 0, 4 << 20);
// `mmap` owns the memory, so the slot must not reset it on drop.
memfd.no_clear_on_drop();
let image = Arc::new(create_memfd_with_data(4096, &[1, 2, 3, 4]).unwrap());
let initial = 64 << 10;
// Instantiate with the image and dirty its first byte.
memfd.instantiate(initial, Some(&image), &style).unwrap();
assert!(memfd.has_image());
let slice = mmap.as_mut_slice();
assert_eq!(&[1, 2, 3, 4], &slice[4096..4100]);
slice[4096] = 5;
assert_eq!(&[5, 2, 3, 4], &slice[4096..4100]);
// The reset alone already restores the image bytes, before re-instantiation.
memfd.clear_and_remain_ready(0).unwrap();
assert_eq!(&[1, 2, 3, 4], &slice[4096..4100]);
// Re-instantiate with the same image, then grow and write past `initial`.
memfd.instantiate(initial, Some(&image), &style).unwrap();
assert_eq!(&[1, 2, 3, 4], &slice[4096..4100]);
memfd.set_heap_limit(initial * 2).unwrap();
assert_eq!(&[0, 0], &slice[initial..initial + 2]);
slice[initial] = 100;
assert_eq!(&[100, 0], &slice[initial..initial + 2]);
memfd.clear_and_remain_ready(0).unwrap();
// The grown region is still readable after the reset, and reads as zero.
assert_eq!(&[0, 0], &slice[initial..initial + 2]);
// Instantiating again leaves it readable and zero; growing into it again works.
memfd.instantiate(initial, Some(&image), &style).unwrap();
assert_eq!(&[0, 0], &slice[initial..initial + 2]);
memfd.set_heap_limit(initial * 2).unwrap();
assert_eq!(&[0, 0], &slice[initial..initial + 2]);
slice[initial] = 100;
assert_eq!(&[100, 0], &slice[initial..initial + 2]);
memfd.clear_and_remain_ready(0).unwrap();
// Finally reuse the slot without an image: both the former image region
// and the grown region read as zero.
memfd.instantiate(64 << 10, None, &style).unwrap();
assert!(!memfd.has_image());
assert_eq!(&[0, 0, 0, 0], &slice[4096..4100]);
assert_eq!(&[0, 0], &slice[initial..initial + 2]);
}
}