wasmtime_wizer/snapshot.rs

use crate::InstanceState;
use crate::info::ModuleContext;
use rayon::iter::{IntoParallelIterator, ParallelExtend, ParallelIterator};
use std::convert::TryFrom;
use std::ops::Range;

/// The maximum number of data segments that we will emit. Most
/// engines support more than this, but we want to leave some
/// headroom.
const MAX_DATA_SEGMENTS: usize = 10_000;

12/// A "snapshot" of Wasm state from its default value after having been initialized.
13pub struct Snapshot {
14    /// Maps global index to its initialized value.
15    ///
16    /// Note that this only tracks defined mutable globals, not all globals.
17    pub globals: Vec<(u32, SnapshotVal)>,
18
19    /// A new minimum size for each memory (in units of pages).
20    pub memory_mins: Vec<u64>,
21
22    /// Segments of non-zero memory.
23    pub data_segments: Vec<DataSegment>,
24}
25
/// A value from a snapshot, currently a subset of wasm types that aren't
/// reference types.
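///
/// Floating-point values are stored as their raw bit patterns (`u32` and
/// `u64`) so that, for example, NaN payloads survive the round trip exactly.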
#[expect(missing_docs, reason = "self-describing variants")]
pub enum SnapshotVal {
    I32(i32),
    I64(i64),
    F32(u32),
    F64(u64),
    V128(u128),
}

/// A data segment initializer for a memory.
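///
/// For example, a segment with `memory_index: 0`, `offset: 1024`,
/// `data: b"hi".to_vec()`, and `is64: false` corresponds to the text-format
/// segment `(data (memory 0) (offset (i32.const 1024)) "hi")`.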
#[derive(Clone)]
pub struct DataSegment {
    /// The index of this data segment's memory.
    pub memory_index: u32,

    /// The bytes of initialized memory that this data segment carries.
    pub data: Vec<u8>,

    /// The offset within the memory that `data` should be copied to.
    pub offset: u64,

    /// Whether or not `memory_index` is a 64-bit memory.
    pub is64: bool,
}

/// Snapshot how the given instance's globals and memories differ from the
/// Wasm defaults.
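///
/// # Example
///
/// A minimal sketch of a call site; `module_context` and `instance` stand in
/// for values the embedder already has, so this is illustrative rather than a
/// runnable doctest:
///
/// ```ignore
/// let snapshot = snapshot(&module_context, &mut instance).await;
/// for seg in &snapshot.data_segments {
///     log::trace!(
///         "memory {}: {} bytes at offset {:#x}",
///         seg.memory_index,
///         seg.data.len(),
///         seg.offset,
///     );
/// }
/// ```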
//
// TODO: when we support reference types, we will have to snapshot tables.
pub async fn snapshot(module: &ModuleContext<'_>, ctx: &mut impl InstanceState) -> Snapshot {
    log::debug!("Snapshotting the initialized state");

    let globals = snapshot_globals(module, ctx).await;
    let (memory_mins, data_segments) = snapshot_memories(module, ctx).await;

    Snapshot {
        globals,
        memory_mins,
        data_segments,
    }
}

/// Get the initialized values of all defined globals via their exports.
async fn snapshot_globals(
    module: &ModuleContext<'_>,
    ctx: &mut impl InstanceState,
) -> Vec<(u32, SnapshotVal)> {
    log::debug!("Snapshotting global values");

    let mut ret = Vec::new();
    for (i, name) in module.defined_global_exports.as_ref().unwrap().iter() {
        let val = ctx.global_get(&name).await;
        ret.push((*i, val));
    }
    ret
}

#[derive(Clone)]
struct DataSegmentRange {
    memory_index: u32,
    range: Range<usize>,
}

impl DataSegmentRange {
    /// What is the gap between two consecutive data segments?
    ///
    /// `self` must be in front of `other` and they must not overlap with each
    /// other.
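    ///
    /// For example, the gap between the ranges `0..4` and `6..8` is `2`.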
    fn gap(&self, other: &Self) -> usize {
        debug_assert_eq!(self.memory_index, other.memory_index);
        debug_assert!(self.range.end <= other.range.start);
        other.range.start - self.range.end
    }

    /// Merge two consecutive data segments.
    ///
    /// `self` must be in front of `other` and they must not overlap with each
    /// other.
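    ///
    /// For example, merging `0..4` with `6..8` yields `0..8`; the two gap
    /// bytes are later copied out of memory, where they are known to be zero.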
    fn merge(&mut self, other: &Self) {
        debug_assert_eq!(self.memory_index, other.memory_index);
        debug_assert!(self.range.end <= other.range.start);
        self.range.end = other.range.end;
    }
}

/// Find the initialized minimum size (in pages) of each memory, as well as
/// all regions of non-zero memory.
async fn snapshot_memories(
    module: &ModuleContext<'_>,
    instance: &mut impl InstanceState,
) -> (Vec<u64>, Vec<DataSegment>) {
    log::debug!("Snapshotting memories");

    // Find and record non-zero regions of memory (in parallel).
    let mut memory_mins = vec![];
    let mut data_segments = vec![];
    let iter = module
        .defined_memories()
        .zip(module.defined_memory_exports.as_ref().unwrap());
    for ((memory_index, ty), name) in iter {
        instance
            .memory_contents(&name, |memory| {
                let page_size = 1 << ty.page_size_log2.unwrap_or(16);
                let num_wasm_pages = memory.len() / page_size;
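                // For example, a 1 MiB memory with the default 64 KiB page
                // size has a new minimum of 16 pages.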
                memory_mins.push(num_wasm_pages as u64);

                let memory_data = &memory[..];

                // Consider each Wasm page in parallel. Create data segments for each
                // region of non-zero memory.
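                //
                // For example, a page whose bytes begin `[0, 0, 7, 7, 0, 5]`
                // and are all zero afterwards yields the ranges `2..4` and
                // `5..6` (range offsets are relative to the start of the
                // memory, not the page).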
                data_segments.par_extend((0..num_wasm_pages).into_par_iter().flat_map(|i| {
                    let page_end = (i + 1) * page_size;
                    let mut start = i * page_size;
                    let mut segments = vec![];
                    while start < page_end {
                        let nonzero = match memory_data[start..page_end]
                            .iter()
                            .position(|byte| *byte != 0)
                        {
                            None => break,
                            Some(offset) => offset,
                        };
                        start += nonzero;
                        let end = memory_data[start..page_end]
                            .iter()
                            .position(|byte| *byte == 0)
                            .map_or(page_end, |zero| start + zero);
                        segments.push(DataSegmentRange {
                            memory_index,
                            range: start..end,
                        });
                        start = end;
                    }
                    segments
                }));
            })
            .await;
    }

    if data_segments.is_empty() {
        return (memory_mins, Vec::new());
    }

    // Sort data segments to enforce determinism in the face of the
    // parallelism above.
    data_segments.sort_by_key(|s| (s.memory_index, s.range.start));

    // Merge any contiguous segments (caused by spanning a Wasm page boundary,
    // and therefore created in separate logical threads above) or segments
    // that are within four bytes of each other. Four because this is the
    // minimum overhead of defining a new active data segment: one byte for the
    // memory index LEB, two for the memory offset init expression (one for the
    // `i32.const` opcode and another for the constant immediate LEB), and
    // finally one for the data length LEB.
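    //
    // For example, merging `0..4` with `6..10` (a gap of two) adds only two
    // zero bytes to the merged segment's data, which is cheaper than the four
    // or more bytes of metadata that a separate segment would cost.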
    const MIN_ACTIVE_SEGMENT_OVERHEAD: usize = 4;
    let mut merged_data_segments = Vec::with_capacity(data_segments.len());
    merged_data_segments.push(data_segments[0].clone());
    for b in &data_segments[1..] {
        let a = merged_data_segments.last_mut().unwrap();

        // Only merge segments for the same memory.
        if a.memory_index != b.memory_index {
            merged_data_segments.push(b.clone());
            continue;
        }

        // Only merge segments if they are contiguous or if it is definitely
        // more size efficient than leaving them apart.
        let gap = a.gap(b);
        if gap > MIN_ACTIVE_SEGMENT_OVERHEAD {
            merged_data_segments.push(b.clone());
            continue;
        }

        // Okay, merge them together into `a` (so that the next iteration can
        // merge its successor into it as well) and then omit `b`!
        a.merge(b);
    }

    remove_excess_segments(&mut merged_data_segments);

    // With the final set of data segments chosen, extract the actual data of
    // each memory, copying it into a `DataSegment`, to produce the final list
    // of segments.
    //
    // Here the memories are iterated over again and, in tandem, the
    // `merged_data_segments` list is traversed to extract a `DataSegment` for
    // each range that `merged_data_segments` indicates. This relies on
    // `merged_data_segments` being sorted, at least by `memory_index`.
    let mut final_data_segments = Vec::with_capacity(merged_data_segments.len());
    let mut merged = merged_data_segments.iter().peekable();
    let iter = module
        .defined_memories()
        .zip(module.defined_memory_exports.as_ref().unwrap());
    for ((memory_index, ty), name) in iter {
        instance
            .memory_contents(&name, |memory| {
                while let Some(segment) = merged.next_if(|s| s.memory_index == memory_index) {
                    final_data_segments.push(DataSegment {
                        memory_index,
                        data: memory[segment.range.clone()].to_vec(),
                        offset: segment.range.start.try_into().unwrap(),
                        is64: ty.memory64,
                    });
                }
            })
            .await;
    }
    assert!(merged.next().is_none());

    (memory_mins, final_data_segments)
}

/// Engines apply a limit on how many segments a module may contain, and Wizer
/// can run afoul of it. When that happens, we need to merge data segments
/// together until our number of data segments fits within the limit.
fn remove_excess_segments(merged_data_segments: &mut Vec<DataSegmentRange>) {
    if merged_data_segments.len() < MAX_DATA_SEGMENTS {
        return;
    }

    // We need to remove `excess` data segments.
    let excess = merged_data_segments.len() - MAX_DATA_SEGMENTS;

    #[derive(Clone, Copy, PartialEq, Eq)]
    struct GapIndex {
        gap: u32,
        // Use a `u32` instead of `usize` to fit `GapIndex` within a word on
        // 64-bit systems, using less memory.
        index: u32,
    }

    // Find the gaps between the end of one segment and the start of the next
    // (if they are both in the same memory). We will merge the `excess`
    // segments with the smallest gaps together. Because they are the smallest
    // gaps, this will bloat the size of our data segments the least.
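    //
    // For example, if we needed to remove one segment from `0..1`, `5..6`,
    // `20..21`, and `100..101` (gaps of 4, 14, and 79), we would merge the
    // first two into `0..6`.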
    let mut smallest_gaps = Vec::with_capacity(merged_data_segments.len() - 1);
    for (index, w) in merged_data_segments.windows(2).enumerate() {
        if w[0].memory_index != w[1].memory_index {
            continue;
        }
        let gap = match u32::try_from(w[0].gap(&w[1])) {
            Ok(gap) => gap,
            // If the gap is larger than 4 GiB then don't consider these two
            // data segments for merging; assume there are enough other data
            // segments close enough together that merging them will get us
            // under the limit.
            Err(_) => continue,
        };
        let index = u32::try_from(index).unwrap();
        smallest_gaps.push(GapIndex { gap, index });
    }
    smallest_gaps.sort_unstable_by_key(|g| g.gap);
    smallest_gaps.truncate(excess);

    // Now merge the chosen segments together in reverse index order so that
    // merging two segments doesn't mess up the index of the next segments we
    // will merge.
    smallest_gaps.sort_unstable_by(|a, b| a.index.cmp(&b.index).reverse());
    for GapIndex { index, .. } in smallest_gaps {
        let index = usize::try_from(index).unwrap();
        let [a, b] = merged_data_segments
            .get_disjoint_mut([index, index + 1])
            .unwrap();
        a.merge(b);

        // Okay to use `swap_remove` here because, even though it makes
        // `merged_data_segments` unsorted, the segments are still sorted
        // within the range `0..=index` and future iterations will only
        // operate within that subregion because we are iterating over largest
        // to smallest indices.
        merged_data_segments.swap_remove(index + 1);
    }

    // Finally, sort the data segments again so that our output is
    // deterministic.
    merged_data_segments.sort_by_key(|s| (s.memory_index, s.range.start));
}