Skip to main content

gear_wasm_instrument/
module.rs

1// Copyright (C) Gear Technologies Inc.
2// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
3
4use alloc::{
5    borrow::Cow,
6    format,
7    string::{String, ToString},
8    vec,
9    vec::Vec,
10};
11use core::convert::Infallible;
12use wasm_encoder::{
13    reencode,
14    reencode::{Reencode, RoundtripReencoder},
15};
16use wasmparser::{
17    BinaryReaderError, Encoding, ExternalKind, FuncType, FunctionBody, GlobalType, KnownCustom,
18    MemoryType, Payload, RefType, TableType, TypeRef, ValType, WasmFeatures,
19};
20
/// Feature set built from [`WasmFeatures::WASM1`] with
/// [`WasmFeatures::SIGN_EXTENSION`] added and [`WasmFeatures::FLOATS`] removed.
pub const GEAR_SUPPORTED_FEATURES: WasmFeatures = WasmFeatures::WASM1
    .union(WasmFeatures::SIGN_EXTENSION)
    .difference(WasmFeatures::FLOATS);
24
// Master list of the instructions modelled by this file, grouped by proposal
// (`@mvp` and `@sign_extension`).
//
// The only argument, `$mac`, is the name of a callback macro; it is invoked
// once and receives the whole grouped list as its input.
//
// The layout is based on `wasmparser::_for_each_operator_group`, and reading
// its documentation is recommended to understand the logic.
//
// Float instructions are removed.
macro_rules! for_each_instruction_group {
    ($mac:ident) => {
        $mac! {
            @mvp {
                Unreachable
                Nop
                Block { blockty: $crate::BlockType }
                Loop { blockty: $crate::BlockType }
                If { blockty: $crate::BlockType }
                Else
                End
                Br { relative_depth: u32 }
                BrIf { relative_depth: u32 }
                BrTable { targets: $crate::BrTable }
                Return
                Call { function_index: u32 }
                CallIndirect { type_index: u32, table_index: u32 }
                Drop
                Select
                LocalGet { local_index: u32 }
                LocalSet { local_index: u32 }
                LocalTee { local_index: u32 }
                GlobalGet { global_index: u32 }
                GlobalSet { global_index: u32 }
                I32Load { memarg: $crate::MemArg }
                I64Load { memarg: $crate::MemArg }
                I32Load8S { memarg: $crate::MemArg }
                I32Load8U { memarg: $crate::MemArg }
                I32Load16S { memarg: $crate::MemArg }
                I32Load16U { memarg: $crate::MemArg }
                I64Load8S { memarg: $crate::MemArg }
                I64Load8U { memarg: $crate::MemArg }
                I64Load16S { memarg: $crate::MemArg }
                I64Load16U { memarg: $crate::MemArg }
                I64Load32S { memarg: $crate::MemArg }
                I64Load32U { memarg: $crate::MemArg }
                I32Store { memarg: $crate::MemArg }
                I64Store { memarg: $crate::MemArg }
                I32Store8 { memarg: $crate::MemArg }
                I32Store16 { memarg: $crate::MemArg }
                I64Store8 { memarg: $crate::MemArg }
                I64Store16 { memarg: $crate::MemArg }
                I64Store32 { memarg: $crate::MemArg }
                MemorySize { mem: u32 }
                MemoryGrow { mem: u32 }
                I32Const { value: i32 }
                I64Const { value: i64 }
                I32Eqz
                I32Eq
                I32Ne
                I32LtS
                I32LtU
                I32GtS
                I32GtU
                I32LeS
                I32LeU
                I32GeS
                I32GeU
                I64Eqz
                I64Eq
                I64Ne
                I64LtS
                I64LtU
                I64GtS
                I64GtU
                I64LeS
                I64LeU
                I64GeS
                I64GeU
                I32Clz
                I32Ctz
                I32Popcnt
                I32Add
                I32Sub
                I32Mul
                I32DivS
                I32DivU
                I32RemS
                I32RemU
                I32And
                I32Or
                I32Xor
                I32Shl
                I32ShrS
                I32ShrU
                I32Rotl
                I32Rotr
                I64Clz
                I64Ctz
                I64Popcnt
                I64Add
                I64Sub
                I64Mul
                I64DivS
                I64DivU
                I64RemS
                I64RemU
                I64And
                I64Or
                I64Xor
                I64Shl
                I64ShrS
                I64ShrU
                I64Rotl
                I64Rotr
                I32WrapI64
                I64ExtendI32S
                I64ExtendI32U
            }

            @sign_extension {
                I32Extend8S
                I32Extend16S
                I64Extend8S
                I64Extend16S
                I64Extend32S
            }
        }
    };
}
149
// Callback for `for_each_instruction_group`.
//
// It accepts the grouped list (`@proposal { Op { field: Type, .. } .. }`) and
// defines a second macro, `for_each_instruction`, which hands exactly the same
// instructions to its own callback `$mac`, but as one flat list, i.e. without
// the proposals info.
macro_rules! define_for_each_instruction {
    ($(
        @$proposal:ident {
            $($op:ident $( { $( $arg:ident: $argty:ty ),+ } )?)+
        }
    )+) => {
        macro_rules! for_each_instruction {
            ($mac:ident) => {
                $mac! {
                    $(
                        $(
                            $op $( { $( $arg: $argty ),+ } )?
                        )+
                    )+
                }
            };
        }
    };
}

// Expanding this invocation is what defines `for_each_instruction`.
for_each_instruction_group!(define_for_each_instruction);
173
// Callback for `for_each_instruction`.
//
// From the flat instruction list it generates:
// * the `Instruction` enum (tuple variants),
// * `Instruction::parse`, converting from `wasmparser::Operator`,
// * `Instruction::reencode`, converting to `wasm_encoder::Instruction`.
//
// The work is done in two phases: `@convert` walks the list one instruction at
// a time and collects entries of the form
// `Op { original fields } => { kept fields }` after the `@accum` marker; once
// the input is exhausted the last `@convert` arm emits the actual items.
// Every step puts the new entry in front of the accumulated ones, so the
// accumulated list ends up in reverse order of the input list.
macro_rules! define_instruction {
    // entry point: start the `@convert` phase with an empty accumulator
    ($( $op:ident $( { $( $arg:ident: $argty:ty ),+ } )? )+) => {
        define_instruction!(@convert $( $op $( { $( $arg: $argty ),+ } )? )+ @accum);
    };
    // omit the `table_index` field of the `call_indirect` instruction because it's always zero,
    // but still keep the original fields so they can be used for the `wasmparser` and
    // `wasm-encoder` types during parsing and reencoding
    (
        @convert
        CallIndirect { $type_index_arg:ident: $type_index_argty:ty, $table_index_arg:ident: $table_index_argty:ty }
        $( $ops:ident $( { $($args:ident: $argtys:ty),+ } )? )*
        @accum
        $( $accum_op:ident $( { $($original_arg:ident: $original_argty:ty),+ } => { $($accum_arg:ident: $accum_argty:ty),+ } )? )*
    ) => {
        define_instruction!(
            @convert
            $( $ops $( { $($args: $argtys),+ } )? )*
            @accum
            CallIndirect { $type_index_arg: $type_index_argty, $table_index_arg: $table_index_argty } => { $type_index_arg: $type_index_argty }
            $( $accum_op $( { $($original_arg: $original_argty),+ } => { $($accum_arg: $accum_argty),+ } )? )*
        );
    };
    // every other instruction is collected unchanged: kept fields == original fields
    (
        @convert
        $op:ident $( { $($arg:ident: $argty:ty),+ } )?
        $( $ops:ident $( { $($args:ident: $argtys:ty),+ } )? )*
        @accum
        $( $accum_op:ident $( { $($original_arg:ident: $original_argty:ty),+ } => { $($accum_arg:ident: $accum_argty:ty),+ } )? )*
    ) => {
        define_instruction!(
            @convert
            $( $ops $( { $($args: $argtys),+ } )? )*
            @accum
            $op $( { $($arg: $argty),+ } => { $($arg: $argty),+ } )?
            $( $accum_op $( { $($original_arg: $original_argty),+ } => { $($accum_arg: $accum_argty),+ } )? )*
        );
    };
    // collection is done (nothing left between `@convert` and `@accum`),
    // so we define `Instruction` itself now
    (
        @convert
        @accum
        $( $op:ident $( { $( $original_arg:ident: $original_argty:ty ),+ } => { $( $arg:ident: $argty:ty ),+ } )? )+
    ) => {
        #[derive(Debug, Clone, Eq, PartialEq)]
        pub enum Instruction {
            $(
                $op $(( $( $argty ),+ ))?,
            )+
        }

        impl Instruction {
            fn parse(op: wasmparser::Operator) -> Result<Self> {
                match op {
                    $(
                        wasmparser::Operator::$op $({ $($original_arg),* })? => {
                            define_instruction!(@parse $op $(( $($original_arg $original_arg),* ))?)
                        }
                    )*
                    op => Err(ModuleError::UnsupportedInstruction(format!("{op:?}"))),
                }
            }

            fn reencode(&self) -> Result<wasm_encoder::Instruction<'_>> {
                Ok(match self {
                    $(
                        Self::$op $(( $($arg),+ ))? => {
                            $(
                                $(let $arg = define_instruction!(@arg $arg $arg);)*
                            )?
                            define_instruction!(@build $op $($($arg)*)?)
                        }
                    )*
                })
            }
        }
    };

    // further macro branches are based on `wasm_encoder::reencode` module

    // `@parse`: build an `Instruction` value from the fields bound by the
    // `wasmparser::Operator` pattern. Every field is passed twice: the first
    // identifier is used as the bound variable, the second one only allows an
    // arm to match on the literal field name.
    (@parse CallIndirect($type_index:ident type_index, $table_index:ident table_index)) => {{
        // already verified by wasmparser
        debug_assert_eq!($table_index, 0);

        Ok(Self::CallIndirect(<_>::try_from($type_index)?))
    }};
    (@parse $op:ident $(( $($arg:ident $_arg:ident),* ))?) => {
        Ok(Self::$op $(( $(<_>::try_from($arg)?),* ))?)
    };

    // `@arg`: convert one borrowed field of `Instruction` into the value the
    // matching `wasm_encoder::Instruction` field expects; the second
    // identifier is the field name the arm is selected by.
    (@arg $arg:ident blockty) => (RoundtripReencoder.block_type(*$arg)?);
    (@arg $arg:ident targets) => ((
        ($arg).targets.clone().into(),
        ($arg).default,
    ));
    (@arg $arg:ident memarg) => ((*$arg).reencode());
    (@arg $arg:ident $_arg:ident) => (*$arg);

    // `@build`: assemble the `wasm_encoder::Instruction` from the converted fields.
    (@build $op:ident) => (wasm_encoder::Instruction::$op);
    (@build BrTable $arg:ident) => (wasm_encoder::Instruction::BrTable($arg.0, $arg.1));
    (@build I32Const $arg:ident) => (wasm_encoder::Instruction::I32Const($arg));
    (@build I64Const $arg:ident) => (wasm_encoder::Instruction::I64Const($arg));
    // NOTE(review): `for_each_instruction_group` lists no float instructions, so the
    // two float arms below look unused — confirm before relying on or removing them.
    (@build F32Const $arg:ident) => (wasm_encoder::Instruction::F32Const(f32::from_bits($arg.bits())));
    (@build F64Const $arg:ident) => (wasm_encoder::Instruction::F64Const(f64::from_bits($arg.bits())));
    // the omitted `table_index` is restored as zero
    (@build CallIndirect $arg:ident) => (wasm_encoder::Instruction::CallIndirect { type_index: $arg, table_index: 0 });
    (@build $op:ident $arg:ident) => (wasm_encoder::Instruction::$op($arg));
    (@build $op:ident $($arg:ident)*) => (wasm_encoder::Instruction::$op { $($arg),* });
}

// Expanding this invocation defines `Instruction` and its `parse`/`reencode` methods.
for_each_instruction!(define_instruction);
284
285impl Instruction {
286    /// Returns `true` if instruction is forbidden to be used by user
287    /// but allowed to be used by instrumentation stage.
288    pub fn is_user_forbidden(&self) -> bool {
289        matches!(self, Self::MemoryGrow { .. })
290    }
291}
292
/// `Result` alias whose error type defaults to [`ModuleError`].
pub type Result<T, E = ModuleError> = core::result::Result<T, E>;
294
/// Errors produced while parsing or reencoding a module.
#[derive(Debug, derive_more::Display, derive_more::From)]
pub enum ModuleError {
    /// Error reported by the `wasmparser` binary reader.
    #[display("Binary reader error: {_0}")]
    BinaryReader(BinaryReaderError),
    /// Error reported by `wasm_encoder::reencode`.
    #[display("Reencode error: {_0}")]
    Reencode(reencode::Error),
    /// An integer did not fit into the target integer type.
    #[display("Int conversion error: {_0}")]
    TryFromInt(core::num::TryFromIntError),
    /// An operator outside of the supported instruction list; carries its `Debug` output.
    #[display("Unsupported instruction: {_0}")]
    UnsupportedInstruction(String),
    /// More than one table.
    #[display("Multiple tables")]
    MultipleTables,
    /// More than one memory.
    #[display("Multiple memories")]
    MultipleMemories,
    /// A non-zero memory index; carries the actual index.
    #[from(skip)]
    #[display("Memory index must be zero (actual: {_0})")]
    NonZeroMemoryIdx(u32),
    /// An element segment with an explicit table index; carries that index.
    #[from(skip)]
    #[display("Optional table index of element segment is not supported (index: {_0})")]
    ElementTableIdx(u32),
    /// A passive data segment.
    #[display("Passive data is not supported")]
    PassiveDataKind,
    /// An element segment whose items are expressions.
    #[display("Element expressions are not supported")]
    ElementExpressions,
    /// A passive or declared element segment.
    #[display("Only active element is supported")]
    NonActiveElementKind,
    /// A table with an init expression.
    #[display("Table init expression is not supported")]
    TableInitExpr,
}
324
325impl core::error::Error for ModuleError {
326    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
327        match self {
328            ModuleError::BinaryReader(e) => Some(e),
329            ModuleError::Reencode(e) => Some(e),
330            ModuleError::TryFromInt(e) => Some(e),
331            ModuleError::UnsupportedInstruction(_) => None,
332            ModuleError::MultipleTables => None,
333            ModuleError::MultipleMemories => None,
334            ModuleError::NonZeroMemoryIdx(_) => None,
335            ModuleError::ElementTableIdx(_) => None,
336            ModuleError::PassiveDataKind => None,
337            ModuleError::ElementExpressions => None,
338            ModuleError::NonActiveElementKind => None,
339            ModuleError::TableInitExpr => None,
340        }
341    }
342}
343
/// Conversion from [`Infallible`]; since `Infallible` has no values,
/// the body is an empty `match` and can never actually run.
impl From<Infallible> for ModuleError {
    fn from(value: Infallible) -> Self {
        match value {}
    }
}
349
/// Memory immediate of load/store instructions: alignment and static offset.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct MemArg {
    /// The expected alignment of the instruction's dynamic address operand
    /// (expressed as the exponent of a power of two).
    pub align: u8,
    /// A static offset to add to the instruction's dynamic address operand.
    pub offset: u32,
}
358
359impl TryFrom<wasmparser::MemArg> for MemArg {
360    type Error = ModuleError;
361
362    fn try_from(
363        wasmparser::MemArg {
364            align,
365            max_align: _,
366            offset,
367            memory,
368        }: wasmparser::MemArg,
369    ) -> Result<Self, Self::Error> {
370        // always zero if multi-memory is not enabled
371        debug_assert_eq!(memory, 0);
372        Ok(Self {
373            align,
374            offset: offset.try_into()?,
375        })
376    }
377}
378
379impl MemArg {
380    pub fn zero() -> Self {
381        Self {
382            align: 0,
383            offset: 0,
384        }
385    }
386
387    pub fn i32() -> Self {
388        Self::i32_offset(0)
389    }
390
391    pub fn i64() -> Self {
392        Self::i64_offset(0)
393    }
394
395    pub fn i32_offset(offset: u32) -> Self {
396        Self { align: 2, offset }
397    }
398
399    pub fn i64_offset(offset: u32) -> Self {
400        Self { align: 3, offset }
401    }
402
403    fn reencode(self) -> wasm_encoder::MemArg {
404        wasm_encoder::MemArg {
405            offset: self.offset as u64,
406            align: self.align as u32,
407            memory_index: 0,
408        }
409    }
410}
411
/// Owned `br_table` operand: the default label plus the list of other labels.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct BrTable {
    /// The default label.
    pub default: u32,
    /// All labels except the default one.
    pub targets: Vec<u32>,
}

impl BrTable {
    /// Number of `br_table` entries; the default label is not counted.
    pub fn len(&self) -> u32 {
        let entries: &[u32] = &self.targets;
        entries.len() as u32
    }

    /// `true` when the default label is the only label of this `BrTable`.
    pub fn is_empty(&self) -> bool {
        self.targets.as_slice().is_empty()
    }
}
429
430impl TryFrom<wasmparser::BrTable<'_>> for BrTable {
431    type Error = ModuleError;
432
433    fn try_from(targets: wasmparser::BrTable) -> Result<Self> {
434        Ok(Self {
435            default: targets.default(),
436            targets: targets
437                .targets()
438                .collect::<Result<Vec<_>, BinaryReaderError>>()?,
439        })
440    }
441}
442
/// A constant expression stored as a list of instructions.
///
/// The `Debug` output is the fixed text `ConstExpr { .. }`; instructions are not printed.
#[derive(Default, Clone, derive_more::Debug, Eq, PartialEq)]
#[debug("ConstExpr {{ .. }}")]
pub struct ConstExpr {
    /// Instructions of the expression.
    pub instructions: Vec<Instruction>,
}
448
449impl ConstExpr {
450    pub fn empty() -> Self {
451        Self {
452            instructions: Vec::new(),
453        }
454    }
455
456    pub fn i32_value(value: i32) -> Self {
457        Self {
458            instructions: vec![Instruction::I32Const(value)],
459        }
460    }
461
462    pub fn i64_value(value: i64) -> Self {
463        Self {
464            instructions: vec![Instruction::I64Const(value)],
465        }
466    }
467
468    fn parse(expr: wasmparser::ConstExpr) -> Result<Self> {
469        let mut instructions = Vec::new();
470        let mut ops = expr.get_operators_reader();
471        while !ops.is_end_then_eof() {
472            instructions.push(Instruction::parse(ops.read()?)?);
473        }
474
475        Ok(Self { instructions })
476    }
477
478    fn reencode(&self) -> Result<wasm_encoder::ConstExpr> {
479        Ok(wasm_encoder::ConstExpr::extended(
480            self.instructions
481                .iter()
482                .map(Instruction::reencode)
483                .collect::<Result<Vec<_>>>()?,
484        ))
485    }
486}
487
/// An import entry.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Import {
    /// Name of the module the entity is imported from.
    pub module: Cow<'static, str>,
    /// Name of the imported entity.
    pub name: Cow<'static, str>,
    /// Type of the imported entity.
    pub ty: TypeRef,
}
494
495impl Import {
496    pub fn func(
497        module: impl Into<Cow<'static, str>>,
498        name: impl Into<Cow<'static, str>>,
499        index: u32,
500    ) -> Self {
501        Self {
502            module: module.into(),
503            name: name.into(),
504            ty: TypeRef::Func(index),
505        }
506    }
507
508    pub fn memory(initial: u32, maximum: Option<u32>) -> Self {
509        Self {
510            module: "env".into(),
511            name: "memory".into(),
512            ty: TypeRef::Memory(MemoryType {
513                memory64: false,
514                shared: false,
515                initial: initial as u64,
516                maximum: maximum.map(|v| v as u64),
517                page_size_log2: None,
518            }),
519        }
520    }
521
522    fn parse(import: wasmparser::Import) -> Self {
523        Self {
524            module: import.module.to_string().into(),
525            name: import.name.to_string().into(),
526            ty: import.ty,
527        }
528    }
529
530    pub fn reencode(&self, imports: &mut wasm_encoder::ImportSection) -> Result<()> {
531        imports.import(
532            &self.module,
533            &self.name,
534            RoundtripReencoder.entity_type(self.ty)?,
535        );
536        Ok(())
537    }
538}
539
/// Initial value of table elements; only the `RefNull` form is represented.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TableInit {
    /// Counterpart of `wasmparser::TableInit::RefNull`.
    RefNull,
}
544
/// A table: its type together with its initializer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Table {
    /// Type of the table.
    pub ty: TableType,
    /// Initializer of the table.
    pub init: TableInit,
}
550
551impl Table {
552    pub fn funcref(initial: u32, maximum: Option<u32>) -> Self {
553        Table {
554            ty: TableType {
555                element_type: RefType::FUNCREF,
556                table64: false,
557                initial: initial as u64,
558                maximum: maximum.map(|v| v as u64),
559                shared: false,
560            },
561            init: TableInit::RefNull,
562        }
563    }
564
565    fn parse(table: wasmparser::Table) -> Result<Self> {
566        Ok(Self {
567            ty: table.ty,
568            init: match table.init {
569                wasmparser::TableInit::RefNull => TableInit::RefNull,
570                wasmparser::TableInit::Expr(_expr) => return Err(ModuleError::TableInitExpr),
571            },
572        })
573    }
574
575    fn reencode(&self, tables: &mut wasm_encoder::TableSection) -> Result<()> {
576        let ty = RoundtripReencoder.table_type(self.ty)?;
577        match &self.init {
578            TableInit::RefNull => {
579                tables.table(ty);
580            }
581        }
582        Ok(())
583    }
584}
585
/// A global: its type together with its initializer expression.
#[derive(Debug, Clone)]
pub struct Global {
    /// Type of the global.
    pub ty: GlobalType,
    /// Expression producing the initial value.
    pub init_expr: ConstExpr,
}
591
592impl Global {
593    pub fn i32_value(value: i32) -> Self {
594        Self {
595            ty: GlobalType {
596                content_type: ValType::I32,
597                mutable: false,
598                shared: false,
599            },
600            init_expr: ConstExpr::i32_value(value),
601        }
602    }
603
604    pub fn i64_value(value: i64) -> Self {
605        Self {
606            ty: GlobalType {
607                content_type: ValType::I64,
608                mutable: false,
609                shared: false,
610            },
611            init_expr: ConstExpr::i64_value(value),
612        }
613    }
614
615    pub fn i64_value_mut(value: i64) -> Self {
616        Self {
617            ty: GlobalType {
618                content_type: ValType::I64,
619                mutable: true,
620                shared: false,
621            },
622            init_expr: ConstExpr::i64_value(value),
623        }
624    }
625
626    fn parse(global: wasmparser::Global) -> Result<Self> {
627        Ok(Self {
628            ty: global.ty,
629            init_expr: ConstExpr::parse(global.init_expr)?,
630        })
631    }
632}
633
/// An export entry.
#[derive(Debug, Clone)]
pub struct Export {
    /// Name of the export.
    pub name: Cow<'static, str>,
    /// Kind of the exported entity.
    pub kind: ExternalKind,
    /// Index of the exported entity.
    pub index: u32,
}
640
641impl Export {
642    pub fn func(name: impl Into<Cow<'static, str>>, index: u32) -> Self {
643        Self {
644            name: name.into(),
645            kind: ExternalKind::Func,
646            index,
647        }
648    }
649
650    pub fn global(name: impl Into<Cow<'static, str>>, index: u32) -> Self {
651        Self {
652            name: name.into(),
653            kind: ExternalKind::Global,
654            index,
655        }
656    }
657
658    fn parse(export: wasmparser::Export) -> Self {
659        Self {
660            name: export.name.to_string().into(),
661            kind: export.kind,
662            index: export.index,
663        }
664    }
665}
666
/// Kind of an element segment; only the active kind is represented.
#[derive(Clone)]
pub enum ElementKind {
    /// Active segment placed at the offset produced by `offset_expr`.
    Active { offset_expr: ConstExpr },
}
671
672impl ElementKind {
673    fn parse(kind: wasmparser::ElementKind) -> Result<Self> {
674        match kind {
675            wasmparser::ElementKind::Passive => Err(ModuleError::NonActiveElementKind),
676            wasmparser::ElementKind::Active {
677                table_index,
678                offset_expr,
679            } => {
680                if let Some(table_index) = table_index {
681                    return Err(ModuleError::ElementTableIdx(table_index));
682                }
683
684                Ok(Self::Active {
685                    offset_expr: ConstExpr::parse(offset_expr)?,
686                })
687            }
688            wasmparser::ElementKind::Declared => Err(ModuleError::NonActiveElementKind),
689        }
690    }
691}
692
/// Items of an element segment; only lists of function indices are represented.
#[derive(Clone)]
pub enum ElementItems {
    /// Function indices.
    Functions(Vec<u32>),
}
697
698impl ElementItems {
699    fn parse(elements: wasmparser::ElementItems) -> Result<Self> {
700        match elements {
701            wasmparser::ElementItems::Functions(f) => {
702                let mut funcs = Vec::new();
703                for func in f {
704                    funcs.push(func?);
705                }
706                Ok(Self::Functions(funcs))
707            }
708            wasmparser::ElementItems::Expressions(_ty, _e) => Err(ModuleError::ElementExpressions),
709        }
710    }
711}
712
/// An element segment: its kind together with its items.
#[derive(Clone)]
pub struct Element {
    /// Kind of the segment.
    pub kind: ElementKind,
    /// Items of the segment.
    pub items: ElementItems,
}
718
719impl Element {
720    pub fn functions(funcs: Vec<u32>) -> Self {
721        Self {
722            kind: ElementKind::Active {
723                offset_expr: ConstExpr::i32_value(0),
724            },
725            items: ElementItems::Functions(funcs),
726        }
727    }
728
729    fn parse(element: wasmparser::Element) -> Result<Self> {
730        Ok(Self {
731            kind: ElementKind::parse(element.kind)?,
732            items: ElementItems::parse(element.items)?,
733        })
734    }
735
736    fn reencode(&self, encoder_section: &mut wasm_encoder::ElementSection) -> Result<()> {
737        let items = match &self.items {
738            ElementItems::Functions(funcs) => {
739                wasm_encoder::Elements::Functions(funcs.clone().into())
740            }
741        };
742
743        match &self.kind {
744            ElementKind::Active { offset_expr } => {
745                encoder_section.active(None, &offset_expr.reencode()?, items);
746            }
747        }
748
749        Ok(())
750    }
751}
752
/// A data segment: offset expression plus the bytes themselves.
#[derive(Debug, Clone)]
pub struct Data {
    /// Expression producing the offset of the segment.
    pub offset_expr: ConstExpr,
    /// Contents of the segment.
    pub data: Cow<'static, [u8]>,
}
758
759impl Data {
760    pub fn with_offset(data: impl Into<Cow<'static, [u8]>>, offset: u32) -> Self {
761        Self {
762            offset_expr: ConstExpr::i32_value(offset as i32),
763            data: data.into(),
764        }
765    }
766
767    fn parse(data: wasmparser::Data) -> Result<Self> {
768        Ok(Self {
769            offset_expr: match data.kind {
770                wasmparser::DataKind::Passive => return Err(ModuleError::PassiveDataKind),
771                wasmparser::DataKind::Active {
772                    memory_index,
773                    offset_expr,
774                } => {
775                    if memory_index != 0 {
776                        return Err(ModuleError::NonZeroMemoryIdx(memory_index));
777                    }
778
779                    ConstExpr::parse(offset_expr)?
780                }
781            },
782            data: data.data.to_vec().into(),
783        })
784    }
785}
786
/// A function body: local declarations and instructions.
#[derive(Debug, Clone, Default)]
pub struct Function {
    /// Local declarations as `(count, type)` pairs.
    pub locals: Vec<(u32, ValType)>,
    /// Instructions of the body.
    pub instructions: Vec<Instruction>,
}
792
793impl Function {
794    pub fn from_instructions(instructions: impl Into<Vec<Instruction>>) -> Self {
795        Self {
796            locals: Vec::new(),
797            instructions: instructions.into(),
798        }
799    }
800
801    fn from_entry(func: FunctionBody) -> Result<Self> {
802        let mut locals = Vec::new();
803        for pair in func.get_locals_reader()? {
804            let (cnt, ty) = pair?;
805            locals.push((cnt, ty));
806        }
807
808        let mut instructions = Vec::new();
809        let mut reader = func.get_operators_reader()?;
810        while !reader.eof() {
811            instructions.push(Instruction::parse(reader.read()?)?);
812        }
813
814        Ok(Self {
815            locals,
816            instructions,
817        })
818    }
819
820    fn reencode(&self) -> Result<wasm_encoder::Function> {
821        let mut encoder_func = wasm_encoder::Function::new(
822            self.locals
823                .iter()
824                .map(|&(cnt, ty)| Ok((cnt, RoundtripReencoder.val_type(ty)?)))
825                .collect::<Result<Vec<_>, reencode::Error>>()?,
826        );
827
828        for op in &self.instructions {
829            encoder_func.instruction(&op.reencode()?);
830        }
831
832        if self.instructions.is_empty() {
833            encoder_func.instruction(&wasm_encoder::Instruction::End);
834        }
835
836        Ok(encoder_func)
837    }
838}
839
/// A list of [`Naming`] entries.
pub type NameMap = Vec<Naming>;

/// Represents a name for an index from the names section.
#[derive(Debug, Clone)]
pub struct Naming {
    /// The index being named.
    pub index: u32,
    /// The name for the index.
    pub name: Cow<'static, str>,
}
850
/// A list of [`IndirectNaming`] entries.
pub type IndirectNameMap = Vec<IndirectNaming>;

/// Represents an indirect name in the names custom section.
#[derive(Debug, Clone)]
pub struct IndirectNaming {
    /// The indirect index of the name.
    pub index: u32,
    /// The map of names that belong to `index`.
    pub names: NameMap,
}
861
/// One subsection of the names custom section, with owned contents.
///
/// The variants mirror the variants of `wasmparser::Name` one to one.
#[derive(Debug, Clone)]
pub enum Name {
    /// The name is for the module.
    Module(Cow<'static, str>),
    /// The name is for the functions.
    Function(NameMap),
    /// The name is for the function locals.
    Local(IndirectNameMap),
    /// The name is for the function labels.
    Label(IndirectNameMap),
    /// The name is for the types.
    Type(NameMap),
    /// The name is for the tables.
    Table(NameMap),
    /// The name is for the memories.
    Memory(NameMap),
    /// The name is for the globals.
    Global(NameMap),
    /// The name is for the element segments.
    Element(NameMap),
    /// The name is for the data segments.
    Data(NameMap),
    /// The name is for fields.
    Field(IndirectNameMap),
    /// The name is for tags.
    Tag(NameMap),
    /// An unknown [name subsection](https://webassembly.github.io/spec/core/appendix/custom.html#subsections).
    Unknown {
        /// The identifier for this subsection.
        ty: u8,
        /// The contents of this subsection.
        data: Cow<'static, [u8]>,
    },
}
896
897impl Name {
898    fn parse(name: wasmparser::Name) -> Result<Self> {
899        let name_map = |map: wasmparser::NameMap| {
900            map.into_iter()
901                .map(|n| {
902                    n.map(|n| Naming {
903                        index: n.index,
904                        name: n.name.to_string().into(),
905                    })
906                })
907                .collect::<Result<Vec<_>, BinaryReaderError>>()
908        };
909
910        let indirect_name_map = |map: wasmparser::IndirectNameMap| {
911            map.into_iter()
912                .map(|n| {
913                    n.and_then(|n| {
914                        Ok(IndirectNaming {
915                            index: n.index,
916                            names: name_map(n.names)?,
917                        })
918                    })
919                })
920                .collect::<Result<Vec<_>, BinaryReaderError>>()
921        };
922
923        Ok(match name {
924            wasmparser::Name::Module {
925                name,
926                name_range: _,
927            } => Self::Module(name.to_string().into()),
928            wasmparser::Name::Function(map) => Self::Function(name_map(map)?),
929            wasmparser::Name::Local(map) => Self::Local(indirect_name_map(map)?),
930            wasmparser::Name::Label(map) => Self::Label(indirect_name_map(map)?),
931            wasmparser::Name::Type(map) => Self::Type(name_map(map)?),
932            wasmparser::Name::Table(map) => Self::Table(name_map(map)?),
933            wasmparser::Name::Memory(map) => Self::Memory(name_map(map)?),
934            wasmparser::Name::Global(map) => Self::Global(name_map(map)?),
935            wasmparser::Name::Element(map) => Self::Element(name_map(map)?),
936            wasmparser::Name::Data(map) => Self::Data(name_map(map)?),
937            wasmparser::Name::Field(map) => Self::Field(indirect_name_map(map)?),
938            wasmparser::Name::Tag(map) => Self::Tag(name_map(map)?),
939            wasmparser::Name::Unknown { ty, data, range: _ } => Self::Unknown {
940                ty,
941                data: data.to_vec().into(),
942            },
943        })
944    }
945
946    fn reencode(&self, section: &mut wasm_encoder::NameSection) {
947        let name_map = |map: &NameMap| {
948            map.iter()
949                .fold(wasm_encoder::NameMap::new(), |mut map, naming| {
950                    map.append(naming.index, &naming.name);
951                    map
952                })
953        };
954
955        let indirect_name_map = |map: &IndirectNameMap| {
956            map.iter()
957                .fold(wasm_encoder::IndirectNameMap::new(), |mut map, naming| {
958                    map.append(naming.index, &name_map(&naming.names));
959                    map
960                })
961        };
962
963        match self {
964            Name::Module(name) => {
965                section.module(name);
966            }
967            Name::Function(map) => section.functions(&name_map(map)),
968            Name::Local(map) => section.locals(&indirect_name_map(map)),
969            Name::Label(map) => section.labels(&indirect_name_map(map)),
970            Name::Type(map) => section.types(&name_map(map)),
971            Name::Table(map) => section.tables(&name_map(map)),
972            Name::Memory(map) => section.memories(&name_map(map)),
973            Name::Global(map) => section.globals(&name_map(map)),
974            Name::Element(map) => section.elements(&name_map(map)),
975            Name::Data(map) => section.data(&name_map(map)),
976            Name::Field(map) => section.fields(&indirect_name_map(map)),
977            Name::Tag(map) => section.tags(&name_map(map)),
978            Name::Unknown { ty, data } => section.raw(*ty, data),
979        }
980    }
981}
982
983pub struct ModuleFuncIndexShifter {
984    builder: ModuleBuilder,
985    inserted_at: u32,
986    code_section: bool,
987    export_section: bool,
988    element_section: bool,
989    start_section: bool,
990    name_section: bool,
991}
992
993impl ModuleFuncIndexShifter {
994    pub fn with_code_section(mut self) -> Self {
995        self.code_section = true;
996        self
997    }
998
999    pub fn with_export_section(mut self) -> Self {
1000        self.export_section = true;
1001        self
1002    }
1003
1004    pub fn with_element_section(mut self) -> Self {
1005        self.element_section = true;
1006        self
1007    }
1008
1009    pub fn with_start_section(mut self) -> Self {
1010        self.start_section = true;
1011        self
1012    }
1013
1014    pub fn with_name_section(mut self) -> Self {
1015        self.name_section = true;
1016        self
1017    }
1018
1019    /// Shift function indices in every section
1020    pub fn with_all(self) -> Self {
1021        self.with_code_section()
1022            .with_export_section()
1023            .with_element_section()
1024            .with_start_section()
1025            .with_name_section()
1026    }
1027
1028    pub fn shift_all(self) -> ModuleBuilder {
1029        self.with_all().shift()
1030    }
1031
1032    /// Do actual shifting
1033    pub fn shift(mut self) -> ModuleBuilder {
1034        if let Some(section) = self
1035            .builder
1036            .module
1037            .code_section
1038            .as_mut()
1039            .filter(|_| self.code_section)
1040        {
1041            for func in section {
1042                for instruction in &mut func.instructions {
1043                    if let Instruction::Call(function_index) = instruction
1044                        && *function_index >= self.inserted_at
1045                    {
1046                        *function_index += 1
1047                    }
1048                }
1049            }
1050        }
1051
1052        if let Some(section) = self
1053            .builder
1054            .module
1055            .export_section
1056            .as_mut()
1057            .filter(|_| self.export_section)
1058        {
1059            for export in section {
1060                if let ExternalKind::Func = export.kind
1061                    && export.index >= self.inserted_at
1062                {
1063                    export.index += 1
1064                }
1065            }
1066        }
1067
1068        if let Some(section) = self
1069            .builder
1070            .module
1071            .element_section
1072            .as_mut()
1073            .filter(|_| self.element_section)
1074        {
1075            for segment in section {
1076                // update all indirect call addresses initial values
1077                match &mut segment.items {
1078                    ElementItems::Functions(funcs) => {
1079                        for func_index in funcs.iter_mut() {
1080                            if *func_index >= self.inserted_at {
1081                                *func_index += 1
1082                            }
1083                        }
1084                    }
1085                }
1086            }
1087        }
1088
1089        if let Some(start_idx) = self
1090            .builder
1091            .module
1092            .start_section
1093            .as_mut()
1094            .filter(|_| self.start_section)
1095            && *start_idx >= self.inserted_at
1096        {
1097            *start_idx += 1
1098        }
1099
1100        if let Some(section) = self
1101            .builder
1102            .module
1103            .name_section
1104            .as_mut()
1105            .filter(|_| self.name_section)
1106        {
1107            for name in section {
1108                if let Name::Function(map) = name {
1109                    for naming in map {
1110                        if naming.index >= self.inserted_at {
1111                            naming.index += 1;
1112                        }
1113                    }
1114                }
1115            }
1116        }
1117
1118        self.builder
1119    }
1120}
1121
1122#[derive(Debug, Default)]
1123pub struct ModuleBuilder {
1124    module: Module,
1125}
1126
1127impl ModuleBuilder {
1128    pub fn from_module(module: Module) -> Self {
1129        Self { module }
1130    }
1131
1132    pub fn shift_func_index(self, inserted_at: u32) -> ModuleFuncIndexShifter {
1133        ModuleFuncIndexShifter {
1134            builder: self,
1135            inserted_at,
1136            code_section: false,
1137            export_section: false,
1138            element_section: false,
1139            start_section: false,
1140            name_section: false,
1141        }
1142    }
1143
1144    pub fn build(self) -> Module {
1145        self.module
1146    }
1147
1148    fn type_section(&mut self) -> &mut TypeSection {
1149        self.module
1150            .type_section
1151            .get_or_insert_with(Default::default)
1152    }
1153
1154    fn import_section(&mut self) -> &mut Vec<Import> {
1155        self.module.import_section.get_or_insert_with(Vec::new)
1156    }
1157
1158    fn func_section(&mut self) -> &mut Vec<u32> {
1159        self.module.function_section.get_or_insert_with(Vec::new)
1160    }
1161
1162    fn global_section(&mut self) -> &mut Vec<Global> {
1163        self.module.global_section.get_or_insert_with(Vec::new)
1164    }
1165
1166    fn export_section(&mut self) -> &mut Vec<Export> {
1167        self.module.export_section.get_or_insert_with(Vec::new)
1168    }
1169
1170    fn element_section(&mut self) -> &mut Vec<Element> {
1171        self.module.element_section.get_or_insert_with(Vec::new)
1172    }
1173
1174    fn code_section(&mut self) -> &mut CodeSection {
1175        self.module.code_section.get_or_insert_with(Vec::new)
1176    }
1177
1178    fn data_section(&mut self) -> &mut DataSection {
1179        self.module.data_section.get_or_insert_with(Vec::new)
1180    }
1181
1182    fn custom_sections(&mut self) -> &mut Vec<CustomSection> {
1183        self.module.custom_sections.get_or_insert_with(Vec::new)
1184    }
1185
1186    pub fn push_custom_section(
1187        &mut self,
1188        name: impl Into<Cow<'static, str>>,
1189        data: impl Into<Vec<u8>>,
1190    ) {
1191        self.custom_sections().push((name.into(), data.into()));
1192    }
1193
1194    /// Adds a new function to the module.
1195    ///
1196    /// Returns index from function section
1197    pub fn add_func(&mut self, ty: FuncType, function: Function) -> u32 {
1198        let type_idx = self.push_type(ty);
1199        self.func_section().push(type_idx);
1200        let func_idx = self.func_section().len() as u32 - 1;
1201        self.code_section().push(function);
1202        func_idx
1203    }
1204
1205    pub fn push_type(&mut self, ty: FuncType) -> u32 {
1206        let idx = self.type_section().iter().position(|vec_ty| *vec_ty == ty);
1207        idx.map(|pos| pos as u32).unwrap_or_else(|| {
1208            self.type_section().push(ty);
1209            self.type_section().len() as u32 - 1
1210        })
1211    }
1212
1213    pub fn push_import(&mut self, import: Import) -> u32 {
1214        self.import_section().push(import);
1215        self.import_section().len() as u32 - 1
1216    }
1217
1218    pub fn set_table(&mut self, table: Table) {
1219        debug_assert_eq!(self.module.table_section, None);
1220        self.module.table_section = Some(table);
1221    }
1222
1223    pub fn push_global(&mut self, global: Global) -> u32 {
1224        self.global_section().push(global);
1225        self.global_section().len() as u32 - 1
1226    }
1227
1228    pub fn push_export(&mut self, export: Export) {
1229        self.export_section().push(export);
1230    }
1231
1232    pub fn push_element(&mut self, element: Element) {
1233        self.element_section().push(element);
1234    }
1235
1236    pub fn push_data(&mut self, data: Data) {
1237        self.data_section().push(data);
1238    }
1239}
1240
/// Function signatures of the type section.
pub type TypeSection = Vec<FuncType>;
/// Function section: one type index per function.
pub type FuncSection = Vec<u32>;
/// Function bodies of the code section.
pub type CodeSection = Vec<Function>;
/// Data segments of the data section.
pub type DataSection = Vec<Data>;
/// A custom section stored as a `(name, raw data)` pair.
pub type CustomSection = (Cow<'static, str>, Vec<u8>);
1246
/// In-memory representation of a WebAssembly module, one field per section.
///
/// A field is `None` when the corresponding section is absent. The `Debug`
/// output is deliberately fixed to `Module { .. }` and does not print the
/// section contents.
#[derive(derive_more::Debug, Clone, Default)]
#[debug("Module {{ .. }}")]
pub struct Module {
    pub type_section: Option<TypeSection>,
    pub import_section: Option<Vec<Import>>,
    pub function_section: Option<FuncSection>,
    /// At most one table is stored.
    pub table_section: Option<Table>,
    /// At most one memory is stored.
    pub memory_section: Option<MemoryType>,
    pub global_section: Option<Vec<Global>>,
    pub export_section: Option<Vec<Export>>,
    /// Index of the start function.
    pub start_section: Option<u32>,
    pub element_section: Option<Vec<Element>>,
    pub code_section: Option<CodeSection>,
    pub data_section: Option<DataSection>,
    /// Parsed contents of the `name` custom section.
    pub name_section: Option<Vec<Name>>,
    /// All custom sections other than `name`.
    pub custom_sections: Option<Vec<CustomSection>>,
}
1264
1265impl Module {
1266    pub fn new(code: &[u8]) -> Result<Self> {
1267        let mut type_section = None;
1268        let mut import_section = None;
1269        let mut function_section = None;
1270        let mut table_section = None;
1271        let mut memory_section = None;
1272        let mut global_section = None;
1273        let mut export_section = None;
1274        let mut start_section = None;
1275        let mut element_section = None;
1276        let mut code_section = None;
1277        let mut data_section = None;
1278        let mut name_section = None;
1279        let mut custom_sections = None;
1280
1281        let mut parser = wasmparser::Parser::new(0);
1282        parser.set_features(GEAR_SUPPORTED_FEATURES);
1283        for payload in parser.parse_all(code) {
1284            match payload? {
1285                Payload::Version {
1286                    num: _,
1287                    encoding,
1288                    range: _,
1289                } => {
1290                    debug_assert_eq!(encoding, Encoding::Module);
1291                }
1292                Payload::TypeSection(section) => {
1293                    debug_assert!(type_section.is_none());
1294                    type_section = Some(
1295                        section
1296                            .into_iter_err_on_gc_types()
1297                            .collect::<Result<_, _>>()?,
1298                    );
1299                }
1300                Payload::ImportSection(section) => {
1301                    debug_assert!(import_section.is_none());
1302                    import_section = Some(
1303                        section
1304                            .into_iter()
1305                            .map(|import| import.map(Import::parse))
1306                            .collect::<Result<_, _>>()?,
1307                    );
1308                }
1309                Payload::FunctionSection(section) => {
1310                    debug_assert!(function_section.is_none());
1311                    function_section = Some(section.into_iter().collect::<Result<_, _>>()?);
1312                }
1313                Payload::TableSection(section) => {
1314                    debug_assert!(table_section.is_none());
1315                    let mut section = section.into_iter();
1316
1317                    table_section = section
1318                        .next()
1319                        .map(|table| table.map_err(Into::into).and_then(Table::parse))
1320                        .transpose()?;
1321
1322                    if section.next().is_some() {
1323                        return Err(ModuleError::MultipleTables);
1324                    }
1325                }
1326                Payload::MemorySection(section) => {
1327                    debug_assert!(memory_section.is_none());
1328                    let mut section = section.into_iter();
1329
1330                    memory_section = section.next().transpose()?;
1331
1332                    if section.next().is_some() {
1333                        return Err(ModuleError::MultipleMemories);
1334                    }
1335                }
1336                Payload::TagSection(_) => {}
1337                Payload::GlobalSection(section) => {
1338                    debug_assert!(global_section.is_none());
1339                    global_section = Some(
1340                        section
1341                            .into_iter()
1342                            .map(|element| element.map_err(Into::into).and_then(Global::parse))
1343                            .collect::<Result<_, _>>()?,
1344                    );
1345                }
1346                Payload::ExportSection(section) => {
1347                    debug_assert!(export_section.is_none());
1348                    export_section = Some(
1349                        section
1350                            .into_iter()
1351                            .map(|e| e.map(Export::parse))
1352                            .collect::<Result<_, _>>()?,
1353                    );
1354                }
1355                Payload::StartSection { func, range: _ } => {
1356                    start_section = Some(func);
1357                }
1358                Payload::ElementSection(section) => {
1359                    debug_assert!(element_section.is_none());
1360                    element_section = Some(
1361                        section
1362                            .into_iter()
1363                            .map(|element| element.map_err(Into::into).and_then(Element::parse))
1364                            .collect::<Result<Vec<_>>>()?,
1365                    );
1366                }
1367                // note: the section is not present in WASM 1.0
1368                Payload::DataCountSection { count, range: _ } => {
1369                    data_section = Some(Vec::with_capacity(count as usize));
1370                }
1371                Payload::DataSection(section) => {
1372                    let data_section = data_section.get_or_insert_with(Vec::new);
1373                    for data in section {
1374                        let data = data?;
1375                        data_section.push(Data::parse(data)?);
1376                    }
1377                }
1378                Payload::CodeSectionStart {
1379                    count,
1380                    range: _,
1381                    size: _,
1382                } => {
1383                    code_section = Some(Vec::with_capacity(count as usize));
1384                }
1385                Payload::CodeSectionEntry(entry) => {
1386                    code_section
1387                        .as_mut()
1388                        .expect("code section start missing")
1389                        .push(Function::from_entry(entry)?);
1390                }
1391                Payload::CustomSection(section) => match section.as_known() {
1392                    KnownCustom::Name(name_section_reader) => {
1393                        name_section = Some(
1394                            name_section_reader
1395                                .into_iter()
1396                                .map(|name| name.map_err(Into::into).and_then(Name::parse))
1397                                .collect::<Result<Vec<_>>>()?,
1398                        );
1399                    }
1400                    _ => {
1401                        let custom_sections = custom_sections.get_or_insert_with(Vec::new);
1402                        let name = section.name().to_string().into();
1403                        let data = section.data().to_vec();
1404                        custom_sections.push((name, data));
1405                    }
1406                },
1407                Payload::UnknownSection { .. } => {}
1408                _ => {}
1409            }
1410        }
1411
1412        Ok(Self {
1413            type_section,
1414            import_section,
1415            function_section,
1416            table_section,
1417            memory_section,
1418            global_section,
1419            export_section,
1420            start_section,
1421            element_section,
1422            code_section,
1423            data_section,
1424            name_section,
1425            custom_sections,
1426        })
1427    }
1428
1429    /// Strips all non-name WASM custom sections from the module.
1430    ///
1431    /// The `name` section is **preserved** to keep Wasmer/Wasmtime trap
1432    /// backtraces readable in production logs. This differs from
1433    /// `wasm_optimizer::Optimizer::strip_custom_sections`, which clears both.
1434    ///
1435    /// Non-name custom sections (`sails:idl`, `producers`, etc.) are not consumed
1436    /// at sandbox execution time; IDL readers pull from `OriginalCode`.
1437    pub fn strip_custom_sections(&mut self) {
1438        self.custom_sections = None;
1439    }
1440
1441    pub fn serialize(&self) -> Result<Vec<u8>> {
1442        let mut module = wasm_encoder::Module::new();
1443
1444        if let Some(crate_section) = &self.type_section {
1445            let mut encoder_section = wasm_encoder::TypeSection::new();
1446            for func_type in crate_section {
1447                encoder_section
1448                    .ty()
1449                    .func_type(&RoundtripReencoder.func_type(func_type.clone())?);
1450            }
1451            module.section(&encoder_section);
1452        }
1453
1454        if let Some(crate_section) = &self.import_section {
1455            let mut encoder_section = wasm_encoder::ImportSection::new();
1456            for import in crate_section {
1457                import.reencode(&mut encoder_section)?;
1458            }
1459            module.section(&encoder_section);
1460        }
1461
1462        if let Some(crate_section) = &self.function_section {
1463            let mut encoder_section = wasm_encoder::FunctionSection::new();
1464            for &function in crate_section {
1465                encoder_section.function(function);
1466            }
1467            module.section(&encoder_section);
1468        }
1469
1470        if let Some(table) = &self.table_section {
1471            let mut encoder_section = wasm_encoder::TableSection::new();
1472            table.reencode(&mut encoder_section)?;
1473            module.section(&encoder_section);
1474        }
1475
1476        if let Some(memory) = &self.memory_section {
1477            let mut encoder_section = wasm_encoder::MemorySection::new();
1478            encoder_section.memory(RoundtripReencoder.memory_type(*memory));
1479            module.section(&encoder_section);
1480        }
1481
1482        if let Some(crate_section) = &self.global_section {
1483            let mut encoder_section = wasm_encoder::GlobalSection::new();
1484            for global in crate_section {
1485                encoder_section.global(
1486                    RoundtripReencoder.global_type(global.ty)?,
1487                    &global.init_expr.reencode()?,
1488                );
1489            }
1490            module.section(&encoder_section);
1491        }
1492
1493        if let Some(crate_section) = &self.export_section {
1494            let mut encoder_section = wasm_encoder::ExportSection::new();
1495            for export in crate_section {
1496                encoder_section.export(
1497                    &export.name,
1498                    RoundtripReencoder.export_kind(export.kind),
1499                    export.index,
1500                );
1501            }
1502            module.section(&encoder_section);
1503        }
1504
1505        if let Some(function_index) = self.start_section {
1506            module.section(&wasm_encoder::StartSection { function_index });
1507        }
1508
1509        if let Some(crate_section) = &self.element_section {
1510            let mut encoder_section = wasm_encoder::ElementSection::new();
1511            for element in crate_section {
1512                element.reencode(&mut encoder_section)?;
1513            }
1514            module.section(&encoder_section);
1515        }
1516
1517        if let Some(crate_section) = &self.code_section {
1518            let mut encoder_section = wasm_encoder::CodeSection::new();
1519            for function in crate_section {
1520                encoder_section.function(&function.reencode()?);
1521            }
1522            module.section(&encoder_section);
1523        }
1524
1525        if let Some(crate_section) = &self.data_section {
1526            let mut encoder_section = wasm_encoder::DataSection::new();
1527            for data in crate_section {
1528                encoder_section.active(0, &data.offset_expr.reencode()?, data.data.iter().copied());
1529            }
1530            module.section(&encoder_section);
1531        }
1532
1533        if let Some(name_section) = &self.name_section {
1534            let mut encoder_section = wasm_encoder::NameSection::new();
1535            for name in name_section {
1536                name.reencode(&mut encoder_section);
1537            }
1538            module.section(&encoder_section);
1539        }
1540
1541        if let Some(custom_sections) = &self.custom_sections {
1542            for (name, data) in custom_sections {
1543                let encoder_section = wasm_encoder::CustomSection {
1544                    name: Cow::Borrowed(name),
1545                    data: Cow::Borrowed(data),
1546                };
1547                module.section(&encoder_section);
1548            }
1549        }
1550
1551        Ok(module.finish())
1552    }
1553
1554    pub fn import_count(&self, pred: impl Fn(&TypeRef) -> bool) -> usize {
1555        self.import_section
1556            .as_ref()
1557            .map(|imports| imports.iter().filter(|import| pred(&import.ty)).count())
1558            .unwrap_or(0)
1559    }
1560
1561    pub fn functions_space(&self) -> usize {
1562        self.import_count(|ty| matches!(ty, TypeRef::Func(_)))
1563            + self
1564                .function_section
1565                .as_ref()
1566                .map(|section| section.len())
1567                .unwrap_or(0)
1568    }
1569
1570    pub fn globals_space(&self) -> usize {
1571        self.import_count(|ty| matches!(ty, TypeRef::Global(_)))
1572            + self
1573                .global_section
1574                .as_ref()
1575                .map(|section| section.len())
1576                .unwrap_or(0)
1577    }
1578}
1579
#[cfg(test)]
mod tests {
    use super::*;

    // Generates one `#[test]` per `name: wat => error;` entry. Each test
    // compiles the WAT text to a binary, expects `Module::new` to fail and
    // compares the `Debug` rendering of the actual and expected errors.
    macro_rules! test_parsing_failed {
        (
            $( $test_name:ident: $wat:literal => $err:expr; )*
        ) => {
            $(
                #[test]
                fn $test_name() {
                    let wasm = wat::parse_str($wat).unwrap();
                    let lhs = Module::new(&wasm).unwrap_err();
                    let rhs: ModuleError = $err;
                    // we cannot compare errors directly because `BinaryReaderError` does not implement `PartialEq`
                    assert_eq!(format!("{lhs:?}"), format!("{rhs:?}"));
                }
            )*
        };
    }

    // Module shapes that parsing must reject, with the exact error expected.
    test_parsing_failed! {
        multiple_tables_denied: r#"
        (module
            (table 10 10 funcref)
            (table 20 20 funcref)
        )"# => ModuleError::MultipleTables;

        multiple_memories_denied: r#"
        (module
            (memory (export "memory") 1)
            (memory (export "memory2") 2)
        )"# => ModuleError::MultipleMemories;

        data_non_zero_memory_idx_denied: r#"
        (module
            (data (memory 123) (offset i32.const 0) "")
        )
        "# => ModuleError::NonZeroMemoryIdx(123);

        element_table_idx_denied: r#"
        (module
            (elem 123 (offset i32.const 0) 0 0 0 0)
        )"# => ModuleError::ElementTableIdx(123);

        passive_data_kind_denied: r#"
        (module
            (data "")
        )
        "# => ModuleError::PassiveDataKind;

        passive_element_denied: r#"
        (module
            (elem funcref (item i32.const 0))
        )
        "# => ModuleError::NonActiveElementKind;

        declared_element_denied: r#"
        (module
            (func $a)
            (elem declare func $a)
        )
        "# => ModuleError::NonActiveElementKind;

        element_expressions_denied: r#"
        (module
            (elem (i32.const 1) funcref)
        )
        "# => ModuleError::ElementExpressions;

        table_init_expr_denied: r#"
        (module
            (table 0 0 funcref (i32.const 0))
        )"# => ModuleError::TableInitExpr;
    }

    // A `call_indirect` naming table 123 must fail with a binary reader
    // error ("zero byte expected") at a fixed offset, not with a
    // crate-specific error variant.
    #[test]
    fn call_indirect_non_zero_table_idx_denied() {
        let wasm = wat::parse_str(
            r#"
            (module
                (func
                    call_indirect 123 (type 333)
                )
            )
            "#,
        )
        .unwrap();
        let err = Module::new(&wasm).unwrap_err();
        if let ModuleError::BinaryReader(err) = err {
            assert_eq!(err.offset(), 26);
            assert_eq!(err.message(), "zero byte expected");
        } else {
            panic!("{err}");
        }
    }

    // A custom section pushed through the builder must survive a
    // serialize -> parse -> serialize round trip (compared as printed WAT).
    #[test]
    fn custom_section_kept() {
        let mut builder = ModuleBuilder::default();
        builder.push_custom_section("dummy", [1, 2, 3]);
        let module = builder.build();
        let module_bytes = module.serialize().unwrap();
        let wat = wasmprinter::print_bytes(&module_bytes).unwrap();

        let parsed_module_bytes = Module::new(&module_bytes).unwrap().serialize().unwrap();
        let parsed_wat = wasmprinter::print_bytes(&parsed_module_bytes).unwrap();
        assert_eq!(wat, parsed_wat);
    }

    // Stripping clears `custom_sections`, leaves `name_section` in place, and
    // the stripped sections do not come back after serialization.
    #[test]
    fn strip_custom_sections_clears_custom_but_keeps_name() {
        let mut builder = ModuleBuilder::default();
        builder.push_custom_section("sails:idl", [0xAA, 0xBB, 0xCC]);
        builder.push_custom_section("producers", [0xDE, 0xAD]);
        let mut module = builder.build();
        // Simulate a preserved name section.
        module.name_section = Some(Vec::new());

        module.strip_custom_sections();

        assert!(
            module.custom_sections.is_none(),
            "custom_sections must be cleared"
        );
        assert!(
            module.name_section.is_some(),
            "name_section must be preserved across strip"
        );

        // Round-trip through serialize/parse: non-name custom sections must not
        // reappear in the serialized bytes.
        let bytes = module.serialize().unwrap();
        let reparsed = Module::new(&bytes).unwrap();
        assert!(
            reparsed
                .custom_sections
                .as_ref()
                .is_none_or(|cs| cs.is_empty()),
            "serialized module must not contain non-name custom sections after strip"
        );
    }

    // Stripping a module that has nothing to strip is a harmless no-op.
    #[test]
    fn strip_custom_sections_on_empty_module_is_noop() {
        let mut module = ModuleBuilder::default().build();
        // No custom sections, no name section: must not panic, stays None.
        assert!(module.custom_sections.is_none());
        module.strip_custom_sections();
        assert!(module.custom_sections.is_none());
    }
}