use rayon::iter::{IntoParallelIterator, ParallelExtend, ParallelIterator};
use std::convert::TryFrom;
use wasmtime::{AsContext, AsContextMut};
/// Number of bytes in one WebAssembly linear-memory page; used to convert
/// page indices into byte offsets when scanning memory contents.
const WASM_PAGE_SIZE: u64 = 65_536;
/// Upper bound on the number of data segments kept in a snapshot; once the
/// merged segment list exceeds it, `remove_excess_segments` merges segments
/// together to get back under the limit.
const MAX_DATA_SEGMENTS: usize = 10_000;
/// The state captured from an instance by [`snapshot`].
pub struct Snapshot {
/// Values of the `__wizer_global_{i}` exports, in increasing `i` order.
pub globals: Vec<wasmtime::Val>,
/// For each `__wizer_memory_{i}` export, in increasing `i` order, the size
/// reported by `Memory::size` when the snapshot was taken (the scan treats
/// this value as a count of Wasm pages).
pub memory_mins: Vec<u64>,
/// Regions of memory that contained non-zero bytes, sorted by
/// `(memory_index, offset)`. Nearby regions may have been merged, so a
/// segment can include some zero bytes.
pub data_segments: Vec<DataSegment>,
}
/// A contiguous byte range inside one of the instance's memories.
///
/// Only the location is stored; the bytes themselves are read from the
/// memory on demand via [`DataSegment::data`].
#[derive(Clone, Copy)]
pub struct DataSegment {
/// The `i` of the `__wizer_memory_{i}` export this segment belongs to.
pub memory_index: u32,
/// Handle to the memory the segment's bytes are read from.
pub memory: wasmtime::Memory,
/// Byte offset of the segment's first byte within the memory.
pub offset: u32,
/// Length of the segment in bytes.
pub len: u32,
}
impl DataSegment {
    /// Borrow this segment's bytes out of its memory.
    ///
    /// The returned slice covers `offset .. offset + len` of the memory's
    /// current contents as seen through `ctx`.
    ///
    /// # Panics
    ///
    /// Panics if `offset` or `len` does not fit in a `usize`, or if the range
    /// lies outside the memory's data.
    pub fn data<'a>(&self, ctx: &'a impl AsContext) -> &'a [u8] {
        let first = usize::try_from(self.offset).unwrap();
        let byte_count = usize::try_from(self.len).unwrap();
        let past_last = first + byte_count;

        let contents = self.memory.data(ctx);
        &contents[first..past_last]
    }
}
impl DataSegment {
    /// Number of bytes separating the end of `self` from the start of `other`.
    ///
    /// Both segments must belong to the same memory, and `self` must end at
    /// or before the point where `other` begins; both conditions are checked
    /// only in debug builds.
    fn gap(&self, other: &Self) -> u32 {
        debug_assert_eq!(self.memory_index, other.memory_index);
        let self_end = self.offset + self.len;
        debug_assert!(self_end <= other.offset);
        other.offset - self_end
    }

    /// Build a single segment spanning `self`, the gap after it, and `other`.
    ///
    /// The result starts at `self.offset` and keeps `self`'s memory; the same
    /// preconditions as [`DataSegment::gap`] apply.
    fn merge(&self, other: &Self) -> DataSegment {
        let between = self.gap(other);
        let combined_len = self.len + between + other.len;
        DataSegment {
            len: combined_len,
            ..*self
        }
    }
}
/// Capture the current state of `instance`: its `__wizer_global_{i}` values
/// and the sizes and non-zero contents of its `__wizer_memory_{i}` memories.
///
/// Globals are captured first, then memories.
pub fn snapshot(ctx: &mut impl AsContextMut, instance: &wasmtime::Instance) -> Snapshot {
    log::debug!("Snapshotting the initialized state");

    let globals = snapshot_globals(&mut *ctx, instance);
    let memories = snapshot_memories(&mut *ctx, instance);

    Snapshot {
        globals,
        memory_mins: memories.0,
        data_segments: memories.1,
    }
}
/// Read the current value of every `__wizer_global_{i}` export.
///
/// Exports are looked up by name for `i = 0, 1, 2, ...`; the first name that
/// is not an exported global ends the scan. The returned values are in
/// increasing `i` order.
fn snapshot_globals(
    ctx: &mut impl AsContextMut,
    instance: &wasmtime::Instance,
) -> Vec<wasmtime::Val> {
    log::debug!("Snapshotting global values");

    let mut values = Vec::new();
    for index in 0.. {
        match instance.get_global(&mut *ctx, &format!("__wizer_global_{}", index)) {
            Some(global) => values.push(global.get(&mut *ctx)),
            None => break,
        }
    }
    values
}
/// Record the size of every `__wizer_memory_{i}` export and locate the
/// non-zero regions of its contents.
///
/// Memories are looked up by name for `i = 0, 1, 2, ...`; the first name that
/// is not an exported memory ends the scan. For each memory, the value of
/// `Memory::size` is appended to the first returned vector and the contents
/// are scanned page by page (in parallel) for runs of non-zero bytes.
///
/// The runs are then sorted by `(memory_index, offset)`, neighbouring runs of
/// the same memory that are at most `MIN_ACTIVE_SEGMENT_OVERHEAD` bytes apart
/// are merged, and the list is passed through `remove_excess_segments`.
///
/// # Panics
///
/// Panics if a run's offset or length does not fit in a `u32`.
fn snapshot_memories(
    ctx: &mut impl AsContextMut,
    instance: &wasmtime::Instance,
) -> (Vec<u64>, Vec<DataSegment>) {
    log::debug!("Snapshotting memories");

    let mut memory_mins = Vec::new();
    let mut data_segments = Vec::new();

    let mut memory_index = 0;
    while let Some(memory) =
        instance.get_memory(&mut *ctx, &format!("__wizer_memory_{}", memory_index))
    {
        let num_wasm_pages = memory.size(&*ctx);
        memory_mins.push(num_wasm_pages);

        let memory_data = memory.data(&*ctx);

        // Each page is scanned independently, so a run that crosses a page
        // boundary is first reported as two pieces; the merge pass below
        // joins them again (their gap is zero).
        data_segments.par_extend((0..num_wasm_pages).into_par_iter().flat_map(|page| {
            let page_start = (page * WASM_PAGE_SIZE) as usize;
            let page_end = ((page + 1) * WASM_PAGE_SIZE) as usize;
            let page_bytes = &memory_data[page_start..page_end];

            let mut runs = Vec::new();
            // `cursor` is relative to the start of the page.
            let mut cursor = 0;
            while let Some(skip) = page_bytes[cursor..].iter().position(|&byte| byte != 0) {
                let run_start = cursor + skip;
                // The run extends to the next zero byte, or to the end of the page.
                let run_len = page_bytes[run_start..]
                    .iter()
                    .position(|&byte| byte == 0)
                    .unwrap_or(page_bytes.len() - run_start);
                runs.push(DataSegment {
                    memory_index,
                    memory,
                    offset: u32::try_from(page_start + run_start).unwrap(),
                    len: u32::try_from(run_len).unwrap(),
                });
                cursor = run_start + run_len;
            }
            runs
        }));

        memory_index += 1;
    }

    if data_segments.is_empty() {
        return (memory_mins, data_segments);
    }

    // The parallel scan yields runs in no particular order.
    data_segments.sort_by_key(|segment| (segment.memory_index, segment.offset));

    // Gaps of at most this many bytes are bridged by merging. Going by the
    // constant's name this is the minimum cost of declaring a separate active
    // data segment (NOTE(review): confirm against the Wasm binary encoding).
    const MIN_ACTIVE_SEGMENT_OVERHEAD: u32 = 4;

    let mut merged_data_segments = Vec::with_capacity(data_segments.len());
    merged_data_segments.push(data_segments[0]);
    for next in &data_segments[1..] {
        let current = merged_data_segments.last_mut().unwrap();
        // `gap` is only meaningful (and only evaluated) within a single memory.
        let close_enough = current.memory_index == next.memory_index
            && current.gap(next) <= MIN_ACTIVE_SEGMENT_OVERHEAD;
        if close_enough {
            // Grow the tail segment in place so it can keep absorbing followers.
            *current = current.merge(next);
        } else {
            merged_data_segments.push(*next);
        }
    }

    remove_excess_segments(&mut merged_data_segments);
    (memory_mins, merged_data_segments)
}
/// Merge segments together until at most `MAX_DATA_SEGMENTS` remain, where
/// possible.
///
/// `merged_data_segments` must be sorted by `(memory_index, offset)` with no
/// overlapping segments. If the list is longer than the limit, the adjacent
/// same-memory pairs with the smallest gaps between them are merged, since
/// those add the fewest extra bytes. Segments of different memories are never
/// merged, so the list can stay above the limit when there are not enough
/// same-memory neighbours. On return the list is again sorted by
/// `(memory_index, offset)`.
///
/// Ties between equal gaps are broken in favour of the lower index, so the
/// result does not depend on how the unstable sort orders equal keys.
fn remove_excess_segments(merged_data_segments: &mut Vec<DataSegment>) {
    // At or below the limit there is nothing to remove; returning here also
    // avoids collecting and sorting every gap just to merge zero segments.
    if merged_data_segments.len() <= MAX_DATA_SEGMENTS {
        return;
    }

    // Number of segments that have to disappear.
    let excess = merged_data_segments.len() - MAX_DATA_SEGMENTS;

    #[derive(Clone, Copy, PartialEq, Eq)]
    struct GapIndex {
        gap: u32,
        // Index of the first segment of the adjacent pair.
        index: u32,
    }

    // Collect the gap after every segment whose successor is in the same memory.
    let mut smallest_gaps = Vec::with_capacity(merged_data_segments.len() - 1);
    for (index, pair) in merged_data_segments.windows(2).enumerate() {
        if pair[0].memory_index != pair[1].memory_index {
            continue;
        }
        let gap = pair[0].gap(&pair[1]);
        let index = u32::try_from(index).unwrap();
        smallest_gaps.push(GapIndex { gap, index });
    }

    // Keep only the `excess` smallest gaps. Including `index` in the key makes
    // the choice among equal gaps fully specified.
    smallest_gaps.sort_unstable_by_key(|g| (g.gap, g.index));
    smallest_gaps.truncate(excess);

    // Merge from the highest index down so earlier indices stay valid.
    smallest_gaps.sort_unstable_by(|a, b| b.index.cmp(&a.index));
    for GapIndex { index, .. } in smallest_gaps {
        let index = usize::try_from(index).unwrap();
        let merged = merged_data_segments[index].merge(&merged_data_segments[index + 1]);
        merged_data_segments[index] = merged;

        // `swap_remove` moves the last element into `index + 1`, which breaks
        // the ordering only at positions above `index`; all remaining merges
        // touch positions `<= index`, so they still see the right neighbours.
        merged_data_segments.swap_remove(index + 1);
    }

    // Restore the sorted order disturbed by `swap_remove`.
    merged_data_segments.sort_by_key(|s| (s.memory_index, s.offset));
}