gear_wasm_instrument/
module.rs

1// This file is part of Gear.
2//
3// Copyright (C) 2025 Gear Technologies Inc.
4// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
5//
6// This program is free software: you can redistribute it and/or modify
7// it under the terms of the GNU General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
10//
11// This program is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU General Public License for more details.
15//
16// You should have received a copy of the GNU General Public License
17// along with this program. If not, see <https://www.gnu.org/licenses/>.
18
19use alloc::{
20    borrow::Cow,
21    format,
22    string::{String, ToString},
23    vec,
24    vec::Vec,
25};
26use core::convert::Infallible;
27use wasm_encoder::{
28    reencode,
29    reencode::{Reencode, RoundtripReencoder},
30};
31use wasmparser::{
32    BinaryReaderError, Encoding, ExternalKind, FuncType, FunctionBody, GlobalType, KnownCustom,
33    MemoryType, Payload, RefType, TableType, TypeRef, ValType, WasmFeatures,
34};
35
36pub const GEAR_SUPPORTED_FEATURES: WasmFeatures = WasmFeatures::WASM1
37    .union(WasmFeatures::SIGN_EXTENSION)
38    .difference(WasmFeatures::FLOATS);
39
// Lists every instruction supported by Gear, grouped by the proposal that introduced it,
// and passes that list to the macro named by `$mac`.
//
// Based on `wasmparser::_for_each_operator_group`; read its documentation to understand
// the input format. Each entry is either a bare `Name` or `Name { field: Type, .. }`;
// field names match the corresponding `wasmparser::Operator` variant fields.
//
// Only the `@mvp` and `@sign_extension` groups are kept and every floating-point
// instruction is removed, matching `GEAR_SUPPORTED_FEATURES`.
macro_rules! for_each_instruction_group {
    ($mac:ident) => {
        $mac! {
            @mvp {
                Unreachable
                Nop
                Block { blockty: $crate::BlockType }
                Loop { blockty: $crate::BlockType }
                If { blockty: $crate::BlockType }
                Else
                End
                Br { relative_depth: u32 }
                BrIf { relative_depth: u32 }
                BrTable { targets: $crate::BrTable }
                Return
                Call { function_index: u32 }
                CallIndirect { type_index: u32, table_index: u32 }
                Drop
                Select
                LocalGet { local_index: u32 }
                LocalSet { local_index: u32 }
                LocalTee { local_index: u32 }
                GlobalGet { global_index: u32 }
                GlobalSet { global_index: u32 }
                I32Load { memarg: $crate::MemArg }
                I64Load { memarg: $crate::MemArg }
                I32Load8S { memarg: $crate::MemArg }
                I32Load8U { memarg: $crate::MemArg }
                I32Load16S { memarg: $crate::MemArg }
                I32Load16U { memarg: $crate::MemArg }
                I64Load8S { memarg: $crate::MemArg }
                I64Load8U { memarg: $crate::MemArg }
                I64Load16S { memarg: $crate::MemArg }
                I64Load16U { memarg: $crate::MemArg }
                I64Load32S { memarg: $crate::MemArg }
                I64Load32U { memarg: $crate::MemArg }
                I32Store { memarg: $crate::MemArg }
                I64Store { memarg: $crate::MemArg }
                I32Store8 { memarg: $crate::MemArg }
                I32Store16 { memarg: $crate::MemArg }
                I64Store8 { memarg: $crate::MemArg }
                I64Store16 { memarg: $crate::MemArg }
                I64Store32 { memarg: $crate::MemArg }
                MemorySize { mem: u32 }
                MemoryGrow { mem: u32 }
                I32Const { value: i32 }
                I64Const { value: i64 }
                I32Eqz
                I32Eq
                I32Ne
                I32LtS
                I32LtU
                I32GtS
                I32GtU
                I32LeS
                I32LeU
                I32GeS
                I32GeU
                I64Eqz
                I64Eq
                I64Ne
                I64LtS
                I64LtU
                I64GtS
                I64GtU
                I64LeS
                I64LeU
                I64GeS
                I64GeU
                I32Clz
                I32Ctz
                I32Popcnt
                I32Add
                I32Sub
                I32Mul
                I32DivS
                I32DivU
                I32RemS
                I32RemU
                I32And
                I32Or
                I32Xor
                I32Shl
                I32ShrS
                I32ShrU
                I32Rotl
                I32Rotr
                I64Clz
                I64Ctz
                I64Popcnt
                I64Add
                I64Sub
                I64Mul
                I64DivS
                I64DivU
                I64RemS
                I64RemU
                I64And
                I64Or
                I64Xor
                I64Shl
                I64ShrS
                I64ShrU
                I64Rotl
                I64Rotr
                I32WrapI64
                I64ExtendI32S
                I64ExtendI32U
            }

            // Operators from the sign-extension proposal.
            @sign_extension {
                I32Extend8S
                I32Extend16S
                I64Extend8S
                I64Extend16S
                I64Extend32S
            }
        }
    };
}
164
// Generates `for_each_instruction!`, which passes the same instruction list as
// `for_each_instruction_group!` to a macro but with the `@proposal { .. }` grouping
// flattened away, so consumers receive one plain sequence of `Name { field: Type, .. }`.
macro_rules! define_for_each_instruction {
    ($(
        @$proposal:ident {
            $($op:ident $( { $( $arg:ident: $argty:ty ),+ } )?)+
        }
    )+) => {
        macro_rules! for_each_instruction {
            ($mac:ident) => {
                $mac! {
                    $(
                        $(
                            $op $( { $( $arg: $argty ),+ } )?
                        )+
                    )+
                }
            };
        }
    };
}

// Expand the grouped list through the generator above; after this line
// `for_each_instruction!` is defined and can be used below.
for_each_instruction_group!(define_for_each_instruction);
188
// Defines the `Instruction` enum plus its `parse` (from `wasmparser::Operator`) and
// `reencode` (to `wasm_encoder::Instruction`) methods from the flat instruction list
// produced by `for_each_instruction!`.
//
// Works as a tt-muncher:
// 1. The entry arm starts a `@convert <remaining ops> @accum <converted ops>` loop.
// 2. Each `@convert` step moves the head op into the accumulator as
//    `Op { original fields } => { fields kept in our enum }`. New entries are prepended,
//    so the final variant order is the reverse of the input list.
// 3. When nothing is left to convert, the final arm emits the enum and its impl.
// The `@parse`, `@arg` and `@build` arms are helpers invoked from the generated code.
macro_rules! define_instruction {
    ($( $op:ident $( { $( $arg:ident: $argty:ty ),+ } )? )+) => {
        define_instruction!(@convert $( $op $( { $( $arg: $argty ),+ } )? )+ @accum);
    };
    // omit `table_index` field of `call_indirect` instruction because it's always zero
    // but we still save original fields to use them for `wasmparser` and `wasm-encoder` types
    // during parsing and reencoding
    (
        @convert
        CallIndirect { $type_index_arg:ident: $type_index_argty:ty, $table_index_arg:ident: $table_index_argty:ty }
        $( $ops:ident $( { $($args:ident: $argtys:ty),+ } )? )*
        @accum
        $( $accum_op:ident $( { $($original_arg:ident: $original_argty:ty),+ } => { $($accum_arg:ident: $accum_argty:ty),+ } )? )*
    ) => {
        define_instruction!(
            @convert
            $( $ops $( { $($args: $argtys),+ } )? )*
            @accum
            CallIndirect { $type_index_arg: $type_index_argty, $table_index_arg: $table_index_argty } => { $type_index_arg: $type_index_argty }
            $( $accum_op $( { $($original_arg: $original_argty),+ } => { $($accum_arg: $accum_argty),+ } )? )*
        );
    };
    // do nothing to the rest of the instructions: keep every field and collect them
    (
        @convert
        $op:ident $( { $($arg:ident: $argty:ty),+ } )?
        $( $ops:ident $( { $($args:ident: $argtys:ty),+ } )? )*
        @accum
        $( $accum_op:ident $( { $($original_arg:ident: $original_argty:ty),+ } => { $($accum_arg:ident: $accum_argty:ty),+ } )? )*
    ) => {
        define_instruction!(
            @convert
            $( $ops $( { $($args: $argtys),+ } )? )*
            @accum
            $op $( { $($arg: $argty),+ } => { $($arg: $argty),+ } )?
            $( $accum_op $( { $($original_arg: $original_argty),+ } => { $($accum_arg: $accum_argty),+ } )? )*
        );
    };
    // collection is done so we define `Instruction` itself now
    (
        @convert
        @accum
        $( $op:ident $( { $( $original_arg:ident: $original_argty:ty ),+ } => { $( $arg:ident: $argty:ty ),+ } )? )+
    ) => {
        #[derive(Debug, Clone, Eq, PartialEq)]
        pub enum Instruction {
            $(
                $op $(( $( $argty ),+ ))?,
            )+
        }

        impl Instruction {
            // Converts a `wasmparser` operator; any operator not in the list above is
            // reported as `UnsupportedInstruction` with its `Debug` text.
            fn parse(op: wasmparser::Operator) -> Result<Self> {
                match op {
                    $(
                        // Each field name is passed twice: once as the value, once as a
                        // literal token so `@parse` can match on specific field names.
                        wasmparser::Operator::$op $({ $($original_arg),* })? => {
                            define_instruction!(@parse $op $(( $($original_arg $original_arg),* ))?)
                        }
                    )*
                    op => Err(ModuleError::UnsupportedInstruction(format!("{op:?}"))),
                }
            }

            // Converts back to a `wasm_encoder` instruction, translating each field
            // through the matching `@arg` arm first.
            fn reencode(&self) -> Result<wasm_encoder::Instruction> {
                Ok(match self {
                    $(
                        Self::$op $(( $($arg),+ ))? => {
                            $(
                                $(let $arg = define_instruction!(@arg $arg $arg);)*
                            )?
                            define_instruction!(@build $op $($($arg)*)?)
                        }
                    )*
                })
            }
        }
    };

    // further macro branches are based on `wasm_encoder::reencode` module

    (@parse CallIndirect($type_index:ident type_index, $table_index:ident table_index)) => {{
        // already verified by wasmparser
        // NOTE(review): `debug_assert_eq!` is compiled out in release builds, and
        // `@build CallIndirect` always writes `table_index: 0`, so a non-zero index would
        // be silently dropped there. Confirm callers validate the module first.
        debug_assert_eq!($table_index, 0);

        Ok(Self::CallIndirect(<_>::try_from($type_index)?))
    }};
    // Generic case: convert every field with `TryFrom` (identity conversions for plain
    // integers, crate conversions for `MemArg`/`BrTable`, etc.).
    (@parse $op:ident $(( $($arg:ident $_arg:ident),* ))?) => {
        Ok(Self::$op $(( $(<_>::try_from($arg)?),* ))?)
    };

    // `@arg <value> <field name>`: dispatch on the field name to produce the value
    // expected by `wasm_encoder`.
    (@arg $arg:ident blockty) => (RoundtripReencoder.block_type(*$arg)?);
    (@arg $arg:ident targets) => ((
        ($arg).targets.clone().into(),
        ($arg).default,
    ));
    (@arg $arg:ident memarg) => ((*$arg).reencode());
    (@arg $arg:ident $_arg:ident) => (*$arg);

    // `@build <op> <converted fields>`: construct the `wasm_encoder::Instruction`.
    // The float arms are not reachable with the current instruction list.
    (@build $op:ident) => (wasm_encoder::Instruction::$op);
    (@build BrTable $arg:ident) => (wasm_encoder::Instruction::BrTable($arg.0, $arg.1));
    (@build I32Const $arg:ident) => (wasm_encoder::Instruction::I32Const($arg));
    (@build I64Const $arg:ident) => (wasm_encoder::Instruction::I64Const($arg));
    (@build F32Const $arg:ident) => (wasm_encoder::Instruction::F32Const(f32::from_bits($arg.bits())));
    (@build F64Const $arg:ident) => (wasm_encoder::Instruction::F64Const(f64::from_bits($arg.bits())));
    (@build CallIndirect $arg:ident) => (wasm_encoder::Instruction::CallIndirect { type_index: $arg, table_index: 0 });
    (@build $op:ident $arg:ident) => (wasm_encoder::Instruction::$op($arg));
    (@build $op:ident $($arg:ident)*) => (wasm_encoder::Instruction::$op { $($arg),* });
}

// Defines `Instruction` and its conversions from the flattened instruction list.
for_each_instruction!(define_instruction);
299
300impl Instruction {
301    /// Returns `true` if instruction is forbidden to be used by user
302    /// but allowed to be used by instrumentation stage.
303    pub fn is_user_forbidden(&self) -> bool {
304        matches!(self, Self::MemoryGrow { .. })
305    }
306}
307
/// Result type used throughout this module; the error defaults to [`ModuleError`].
pub type Result<T, E = ModuleError> = core::result::Result<T, E>;

/// Errors produced while parsing a module into this crate's representation or
/// re-encoding it back to binary form.
#[derive(Debug, derive_more::Display, derive_more::From)]
pub enum ModuleError {
    /// Error reported by `wasmparser` while reading the binary.
    #[display("Binary reader error: {_0}")]
    BinaryReader(BinaryReaderError),
    /// Error converting a `wasmparser` type into its `wasm-encoder` counterpart.
    #[display("Reencode error: {_0}")]
    Reencode(reencode::Error),
    /// An integer did not fit the narrower type used here
    /// (e.g. a `u64` memory offset converted to `u32`).
    #[display("Int conversion error: {_0}")]
    TryFromInt(core::num::TryFromIntError),
    /// Operator outside the supported instruction set; holds its `Debug` text.
    #[display("Unsupported instruction: {_0}")]
    UnsupportedInstruction(String),
    /// The module declares more than one table.
    #[display("Multiple tables")]
    MultipleTables,
    /// The module declares more than one memory.
    #[display("Multiple memories")]
    MultipleMemories,
    /// A memory index other than 0 was encountered (e.g. in an active data segment).
    // `from(skip)`: both this and `ElementTableIdx` wrap `u32`, so no `From<u32>` is derived.
    #[from(skip)]
    #[display("Memory index must be zero (actual: {_0})")]
    NonZeroMemoryIdx(u32),
    /// An active element segment carried an explicit table index.
    #[from(skip)]
    #[display("Optional table index of element segment is not supported (index: {_0})")]
    ElementTableIdx(u32),
    /// Passive data segments are not supported.
    #[display("Passive data is not supported")]
    PassiveDataKind,
    /// Element segments made of expressions (instead of function indices) are not supported.
    #[display("Element expressions are not supported")]
    ElementExpressions,
    /// Passive or declared element segments are not supported.
    #[display("Only active element is supported")]
    NonActiveElementKind,
    /// Tables with an explicit initializer expression are not supported.
    #[display("Table init expression is not supported")]
    TableInitExpr,
}
339
340impl core::error::Error for ModuleError {
341    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
342        match self {
343            ModuleError::BinaryReader(e) => Some(e),
344            ModuleError::Reencode(e) => Some(e),
345            ModuleError::TryFromInt(e) => Some(e),
346            ModuleError::UnsupportedInstruction(_) => None,
347            ModuleError::MultipleTables => None,
348            ModuleError::MultipleMemories => None,
349            ModuleError::NonZeroMemoryIdx(_) => None,
350            ModuleError::ElementTableIdx(_) => None,
351            ModuleError::PassiveDataKind => None,
352            ModuleError::ElementExpressions => None,
353            ModuleError::NonActiveElementKind => None,
354            ModuleError::TableInitExpr => None,
355        }
356    }
357}
358
359impl From<Infallible> for ModuleError {
360    fn from(value: Infallible) -> Self {
361        match value {}
362    }
363}
364
365#[derive(Debug, Clone, Copy, Eq, PartialEq)]
366pub struct MemArg {
367    /// The expected alignment of the instruction's dynamic address operand
368    /// (expressed the exponent of a power of two).
369    pub align: u8,
370    /// A static offset to add to the instruction's dynamic address operand.
371    pub offset: u32,
372}
373
374impl TryFrom<wasmparser::MemArg> for MemArg {
375    type Error = ModuleError;
376
377    fn try_from(
378        wasmparser::MemArg {
379            align,
380            max_align: _,
381            offset,
382            memory,
383        }: wasmparser::MemArg,
384    ) -> Result<Self, Self::Error> {
385        // always zero if multi-memory is not enabled
386        debug_assert_eq!(memory, 0);
387        Ok(Self {
388            align,
389            offset: offset.try_into()?,
390        })
391    }
392}
393
394impl MemArg {
395    pub fn zero() -> Self {
396        Self {
397            align: 0,
398            offset: 0,
399        }
400    }
401
402    pub fn i32() -> Self {
403        Self::i32_offset(0)
404    }
405
406    pub fn i64() -> Self {
407        Self::i64_offset(0)
408    }
409
410    pub fn i32_offset(offset: u32) -> Self {
411        Self { align: 2, offset }
412    }
413
414    pub fn i64_offset(offset: u32) -> Self {
415        Self { align: 3, offset }
416    }
417
418    fn reencode(self) -> wasm_encoder::MemArg {
419        wasm_encoder::MemArg {
420            offset: self.offset as u64,
421            align: self.align as u32,
422            memory_index: 0,
423        }
424    }
425}
426
/// Immediate of the `br_table` instruction.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct BrTable {
    /// Label (relative depth) taken when the index operand is out of range of `targets`.
    pub default: u32,
    /// Labels (relative depths) selected by the index operand.
    pub targets: Vec<u32>,
}

impl BrTable {
    /// Returns the number of `br_table` entries, not including the default label.
    ///
    /// The count is narrowed with `as`, so it would wrap above `u32::MAX` entries.
    pub fn len(&self) -> u32 {
        let count = self.targets.len();
        count as u32
    }

    /// Returns `true` when the table has no labels other than the default one.
    pub fn is_empty(&self) -> bool {
        self.targets.first().is_none()
    }
}
444
445impl TryFrom<wasmparser::BrTable<'_>> for BrTable {
446    type Error = ModuleError;
447
448    fn try_from(targets: wasmparser::BrTable) -> Result<Self> {
449        Ok(Self {
450            default: targets.default(),
451            targets: targets
452                .targets()
453                .collect::<Result<Vec<_>, BinaryReaderError>>()?,
454        })
455    }
456}
457
458#[derive(Default, Clone, derive_more::Debug, Eq, PartialEq)]
459#[debug("ConstExpr {{ .. }}")]
460pub struct ConstExpr {
461    pub instructions: Vec<Instruction>,
462}
463
464impl ConstExpr {
465    pub fn empty() -> Self {
466        Self {
467            instructions: Vec::new(),
468        }
469    }
470
471    pub fn i32_value(value: i32) -> Self {
472        Self {
473            instructions: vec![Instruction::I32Const(value)],
474        }
475    }
476
477    pub fn i64_value(value: i64) -> Self {
478        Self {
479            instructions: vec![Instruction::I64Const(value)],
480        }
481    }
482
483    fn parse(expr: wasmparser::ConstExpr) -> Result<Self> {
484        let mut instructions = Vec::new();
485        let mut ops = expr.get_operators_reader();
486        while !ops.is_end_then_eof() {
487            instructions.push(Instruction::parse(ops.read()?)?);
488        }
489
490        Ok(Self { instructions })
491    }
492
493    fn reencode(&self) -> Result<wasm_encoder::ConstExpr> {
494        Ok(wasm_encoder::ConstExpr::extended(
495            self.instructions
496                .iter()
497                .map(Instruction::reencode)
498                .collect::<Result<Vec<_>>>()?,
499        ))
500    }
501}
502
503#[derive(Debug, Clone, Eq, PartialEq)]
504pub struct Import {
505    pub module: Cow<'static, str>,
506    pub name: Cow<'static, str>,
507    pub ty: TypeRef,
508}
509
510impl Import {
511    pub fn func(
512        module: impl Into<Cow<'static, str>>,
513        name: impl Into<Cow<'static, str>>,
514        index: u32,
515    ) -> Self {
516        Self {
517            module: module.into(),
518            name: name.into(),
519            ty: TypeRef::Func(index),
520        }
521    }
522
523    pub fn memory(initial: u32, maximum: Option<u32>) -> Self {
524        Self {
525            module: "env".into(),
526            name: "memory".into(),
527            ty: TypeRef::Memory(MemoryType {
528                memory64: false,
529                shared: false,
530                initial: initial as u64,
531                maximum: maximum.map(|v| v as u64),
532                page_size_log2: None,
533            }),
534        }
535    }
536
537    fn parse(import: wasmparser::Import) -> Self {
538        Self {
539            module: import.module.to_string().into(),
540            name: import.name.to_string().into(),
541            ty: import.ty,
542        }
543    }
544
545    pub fn reencode(&self, imports: &mut wasm_encoder::ImportSection) -> Result<()> {
546        imports.import(
547            &self.module,
548            &self.name,
549            RoundtripReencoder.entity_type(self.ty)?,
550        );
551        Ok(())
552    }
553}
554
555#[derive(Clone, Debug, Eq, PartialEq)]
556pub enum TableInit {
557    RefNull,
558}
559
560#[derive(Clone, Debug, Eq, PartialEq)]
561pub struct Table {
562    pub ty: TableType,
563    pub init: TableInit,
564}
565
566impl Table {
567    pub fn funcref(initial: u32, maximum: Option<u32>) -> Self {
568        Table {
569            ty: TableType {
570                element_type: RefType::FUNCREF,
571                table64: false,
572                initial: initial as u64,
573                maximum: maximum.map(|v| v as u64),
574                shared: false,
575            },
576            init: TableInit::RefNull,
577        }
578    }
579
580    fn parse(table: wasmparser::Table) -> Result<Self> {
581        Ok(Self {
582            ty: table.ty,
583            init: match table.init {
584                wasmparser::TableInit::RefNull => TableInit::RefNull,
585                wasmparser::TableInit::Expr(_expr) => return Err(ModuleError::TableInitExpr),
586            },
587        })
588    }
589
590    fn reencode(&self, tables: &mut wasm_encoder::TableSection) -> Result<()> {
591        let ty = RoundtripReencoder.table_type(self.ty)?;
592        match &self.init {
593            TableInit::RefNull => {
594                tables.table(ty);
595            }
596        }
597        Ok(())
598    }
599}
600
601#[derive(Debug, Clone)]
602pub struct Global {
603    pub ty: GlobalType,
604    pub init_expr: ConstExpr,
605}
606
607impl Global {
608    pub fn i32_value(value: i32) -> Self {
609        Self {
610            ty: GlobalType {
611                content_type: ValType::I32,
612                mutable: false,
613                shared: false,
614            },
615            init_expr: ConstExpr::i32_value(value),
616        }
617    }
618
619    pub fn i64_value(value: i64) -> Self {
620        Self {
621            ty: GlobalType {
622                content_type: ValType::I64,
623                mutable: false,
624                shared: false,
625            },
626            init_expr: ConstExpr::i64_value(value),
627        }
628    }
629
630    pub fn i64_value_mut(value: i64) -> Self {
631        Self {
632            ty: GlobalType {
633                content_type: ValType::I64,
634                mutable: true,
635                shared: false,
636            },
637            init_expr: ConstExpr::i64_value(value),
638        }
639    }
640
641    fn parse(global: wasmparser::Global) -> Result<Self> {
642        Ok(Self {
643            ty: global.ty,
644            init_expr: ConstExpr::parse(global.init_expr)?,
645        })
646    }
647}
648
649#[derive(Debug, Clone)]
650pub struct Export {
651    pub name: Cow<'static, str>,
652    pub kind: ExternalKind,
653    pub index: u32,
654}
655
656impl Export {
657    pub fn func(name: impl Into<Cow<'static, str>>, index: u32) -> Self {
658        Self {
659            name: name.into(),
660            kind: ExternalKind::Func,
661            index,
662        }
663    }
664
665    pub fn global(name: impl Into<Cow<'static, str>>, index: u32) -> Self {
666        Self {
667            name: name.into(),
668            kind: ExternalKind::Global,
669            index,
670        }
671    }
672
673    fn parse(export: wasmparser::Export) -> Self {
674        Self {
675            name: export.name.to_string().into(),
676            kind: export.kind,
677            index: export.index,
678        }
679    }
680}
681
682#[derive(Clone)]
683pub enum ElementKind {
684    Active { offset_expr: ConstExpr },
685}
686
687impl ElementKind {
688    fn parse(kind: wasmparser::ElementKind) -> Result<Self> {
689        match kind {
690            wasmparser::ElementKind::Passive => Err(ModuleError::NonActiveElementKind),
691            wasmparser::ElementKind::Active {
692                table_index,
693                offset_expr,
694            } => {
695                if let Some(table_index) = table_index {
696                    return Err(ModuleError::ElementTableIdx(table_index));
697                }
698
699                Ok(Self::Active {
700                    offset_expr: ConstExpr::parse(offset_expr)?,
701                })
702            }
703            wasmparser::ElementKind::Declared => Err(ModuleError::NonActiveElementKind),
704        }
705    }
706}
707
708#[derive(Clone)]
709pub enum ElementItems {
710    Functions(Vec<u32>),
711}
712
713impl ElementItems {
714    fn parse(elements: wasmparser::ElementItems) -> Result<Self> {
715        match elements {
716            wasmparser::ElementItems::Functions(f) => {
717                let mut funcs = Vec::new();
718                for func in f {
719                    funcs.push(func?);
720                }
721                Ok(Self::Functions(funcs))
722            }
723            wasmparser::ElementItems::Expressions(_ty, _e) => Err(ModuleError::ElementExpressions),
724        }
725    }
726}
727
728#[derive(Clone)]
729pub struct Element {
730    pub kind: ElementKind,
731    pub items: ElementItems,
732}
733
734impl Element {
735    pub fn functions(funcs: Vec<u32>) -> Self {
736        Self {
737            kind: ElementKind::Active {
738                offset_expr: ConstExpr::i32_value(0),
739            },
740            items: ElementItems::Functions(funcs),
741        }
742    }
743
744    fn parse(element: wasmparser::Element) -> Result<Self> {
745        Ok(Self {
746            kind: ElementKind::parse(element.kind)?,
747            items: ElementItems::parse(element.items)?,
748        })
749    }
750
751    fn reencode(&self, encoder_section: &mut wasm_encoder::ElementSection) -> Result<()> {
752        let items = match &self.items {
753            ElementItems::Functions(funcs) => {
754                wasm_encoder::Elements::Functions(funcs.clone().into())
755            }
756        };
757
758        match &self.kind {
759            ElementKind::Active { offset_expr } => {
760                encoder_section.active(None, &offset_expr.reencode()?, items);
761            }
762        }
763
764        Ok(())
765    }
766}
767
768#[derive(Debug, Clone)]
769pub struct Data {
770    pub offset_expr: ConstExpr,
771    pub data: Cow<'static, [u8]>,
772}
773
774impl Data {
775    pub fn with_offset(data: impl Into<Cow<'static, [u8]>>, offset: u32) -> Self {
776        Self {
777            offset_expr: ConstExpr::i32_value(offset as i32),
778            data: data.into(),
779        }
780    }
781
782    fn parse(data: wasmparser::Data) -> Result<Self> {
783        Ok(Self {
784            offset_expr: match data.kind {
785                wasmparser::DataKind::Passive => return Err(ModuleError::PassiveDataKind),
786                wasmparser::DataKind::Active {
787                    memory_index,
788                    offset_expr,
789                } => {
790                    if memory_index != 0 {
791                        return Err(ModuleError::NonZeroMemoryIdx(memory_index));
792                    }
793
794                    ConstExpr::parse(offset_expr)?
795                }
796            },
797            data: data.data.to_vec().into(),
798        })
799    }
800}
801
802#[derive(Debug, Clone, Default)]
803pub struct Function {
804    pub locals: Vec<(u32, ValType)>,
805    pub instructions: Vec<Instruction>,
806}
807
808impl Function {
809    pub fn from_instructions(instructions: impl Into<Vec<Instruction>>) -> Self {
810        Self {
811            locals: Vec::new(),
812            instructions: instructions.into(),
813        }
814    }
815
816    fn from_entry(func: FunctionBody) -> Result<Self> {
817        let mut locals = Vec::new();
818        for pair in func.get_locals_reader()? {
819            let (cnt, ty) = pair?;
820            locals.push((cnt, ty));
821        }
822
823        let mut instructions = Vec::new();
824        let mut reader = func.get_operators_reader()?;
825        while !reader.eof() {
826            instructions.push(Instruction::parse(reader.read()?)?);
827        }
828
829        Ok(Self {
830            locals,
831            instructions,
832        })
833    }
834
835    fn reencode(&self) -> Result<wasm_encoder::Function> {
836        let mut encoder_func = wasm_encoder::Function::new(
837            self.locals
838                .iter()
839                .map(|&(cnt, ty)| Ok((cnt, RoundtripReencoder.val_type(ty)?)))
840                .collect::<Result<Vec<_>, reencode::Error>>()?,
841        );
842
843        for op in &self.instructions {
844            encoder_func.instruction(&op.reencode()?);
845        }
846
847        if self.instructions.is_empty() {
848            encoder_func.instruction(&wasm_encoder::Instruction::End);
849        }
850
851        Ok(encoder_func)
852    }
853}
854
/// Direct name map: `(index, name)` entries of a names subsection.
pub type NameMap = Vec<Naming>;

/// Represents a name for an index from the names section.
#[derive(Debug, Clone)]
pub struct Naming {
    /// The index being named.
    pub index: u32,
    /// The name for the index.
    pub name: Cow<'static, str>,
}

/// Indirect name map: for each outer index (e.g. a function), a map of inner names
/// (e.g. its locals or labels).
pub type IndirectNameMap = Vec<IndirectNaming>;

/// Represents an indirect name in the names custom section.
#[derive(Debug, Clone)]
pub struct IndirectNaming {
    /// The outer index whose inner items are named.
    pub index: u32,
    /// The names of the inner items of `index`.
    pub names: NameMap,
}
876
877#[derive(Debug, Clone)]
878pub enum Name {
879    /// The name is for the module.
880    Module(Cow<'static, str>),
881    /// The name is for the functions.
882    Function(NameMap),
883    /// The name is for the function locals.
884    Local(IndirectNameMap),
885    /// The name is for the function labels.
886    Label(IndirectNameMap),
887    /// The name is for the types.
888    Type(NameMap),
889    /// The name is for the tables.
890    Table(NameMap),
891    /// The name is for the memories.
892    Memory(NameMap),
893    /// The name is for the globals.
894    Global(NameMap),
895    /// The name is for the element segments.
896    Element(NameMap),
897    /// The name is for the data segments.
898    Data(NameMap),
899    /// The name is for fields.
900    Field(IndirectNameMap),
901    /// The name is for tags.
902    Tag(NameMap),
903    /// An unknown [name subsection](https://webassembly.github.io/spec/core/appendix/custom.html#subsections).
904    Unknown {
905        /// The identifier for this subsection.
906        ty: u8,
907        /// The contents of this subsection.
908        data: Cow<'static, [u8]>,
909    },
910}
911
912impl Name {
913    fn parse(name: wasmparser::Name) -> Result<Self> {
914        let name_map = |map: wasmparser::NameMap| {
915            map.into_iter()
916                .map(|n| {
917                    n.map(|n| Naming {
918                        index: n.index,
919                        name: n.name.to_string().into(),
920                    })
921                })
922                .collect::<Result<Vec<_>, BinaryReaderError>>()
923        };
924
925        let indirect_name_map = |map: wasmparser::IndirectNameMap| {
926            map.into_iter()
927                .map(|n| {
928                    n.and_then(|n| {
929                        Ok(IndirectNaming {
930                            index: n.index,
931                            names: name_map(n.names)?,
932                        })
933                    })
934                })
935                .collect::<Result<Vec<_>, BinaryReaderError>>()
936        };
937
938        Ok(match name {
939            wasmparser::Name::Module {
940                name,
941                name_range: _,
942            } => Self::Module(name.to_string().into()),
943            wasmparser::Name::Function(map) => Self::Function(name_map(map)?),
944            wasmparser::Name::Local(map) => Self::Local(indirect_name_map(map)?),
945            wasmparser::Name::Label(map) => Self::Label(indirect_name_map(map)?),
946            wasmparser::Name::Type(map) => Self::Type(name_map(map)?),
947            wasmparser::Name::Table(map) => Self::Table(name_map(map)?),
948            wasmparser::Name::Memory(map) => Self::Memory(name_map(map)?),
949            wasmparser::Name::Global(map) => Self::Global(name_map(map)?),
950            wasmparser::Name::Element(map) => Self::Element(name_map(map)?),
951            wasmparser::Name::Data(map) => Self::Data(name_map(map)?),
952            wasmparser::Name::Field(map) => Self::Field(indirect_name_map(map)?),
953            wasmparser::Name::Tag(map) => Self::Tag(name_map(map)?),
954            wasmparser::Name::Unknown { ty, data, range: _ } => Self::Unknown {
955                ty,
956                data: data.to_vec().into(),
957            },
958        })
959    }
960
961    fn reencode(&self, section: &mut wasm_encoder::NameSection) {
962        let name_map = |map: &NameMap| {
963            map.iter()
964                .fold(wasm_encoder::NameMap::new(), |mut map, naming| {
965                    map.append(naming.index, &naming.name);
966                    map
967                })
968        };
969
970        let indirect_name_map = |map: &IndirectNameMap| {
971            map.iter()
972                .fold(wasm_encoder::IndirectNameMap::new(), |mut map, naming| {
973                    map.append(naming.index, &name_map(&naming.names));
974                    map
975                })
976        };
977
978        match self {
979            Name::Module(name) => {
980                section.module(name);
981            }
982            Name::Function(map) => section.functions(&name_map(map)),
983            Name::Local(map) => section.locals(&indirect_name_map(map)),
984            Name::Label(map) => section.labels(&indirect_name_map(map)),
985            Name::Type(map) => section.types(&name_map(map)),
986            Name::Table(map) => section.tables(&name_map(map)),
987            Name::Memory(map) => section.memories(&name_map(map)),
988            Name::Global(map) => section.globals(&name_map(map)),
989            Name::Element(map) => section.elements(&name_map(map)),
990            Name::Data(map) => section.data(&name_map(map)),
991            Name::Field(map) => section.fields(&indirect_name_map(map)),
992            Name::Tag(map) => section.tags(&name_map(map)),
993            Name::Unknown { ty, data } => section.raw(*ty, data),
994        }
995    }
996}
997
998pub struct ModuleFuncIndexShifter {
999    builder: ModuleBuilder,
1000    inserted_at: u32,
1001    code_section: bool,
1002    export_section: bool,
1003    element_section: bool,
1004    start_section: bool,
1005    name_section: bool,
1006}
1007
1008impl ModuleFuncIndexShifter {
1009    pub fn with_code_section(mut self) -> Self {
1010        self.code_section = true;
1011        self
1012    }
1013
1014    pub fn with_export_section(mut self) -> Self {
1015        self.export_section = true;
1016        self
1017    }
1018
1019    pub fn with_element_section(mut self) -> Self {
1020        self.element_section = true;
1021        self
1022    }
1023
1024    pub fn with_start_section(mut self) -> Self {
1025        self.start_section = true;
1026        self
1027    }
1028
1029    pub fn with_name_section(mut self) -> Self {
1030        self.name_section = true;
1031        self
1032    }
1033
1034    /// Shift function indices in every section
1035    pub fn with_all(self) -> Self {
1036        self.with_code_section()
1037            .with_export_section()
1038            .with_element_section()
1039            .with_start_section()
1040            .with_name_section()
1041    }
1042
1043    pub fn shift_all(self) -> ModuleBuilder {
1044        self.with_all().shift()
1045    }
1046
1047    /// Do actual shifting
1048    pub fn shift(mut self) -> ModuleBuilder {
1049        if let Some(section) = self
1050            .builder
1051            .module
1052            .code_section
1053            .as_mut()
1054            .filter(|_| self.code_section)
1055        {
1056            for func in section {
1057                for instruction in &mut func.instructions {
1058                    if let Instruction::Call(function_index) = instruction {
1059                        if *function_index >= self.inserted_at {
1060                            *function_index += 1
1061                        }
1062                    }
1063                }
1064            }
1065        }
1066
1067        if let Some(section) = self
1068            .builder
1069            .module
1070            .export_section
1071            .as_mut()
1072            .filter(|_| self.export_section)
1073        {
1074            for export in section {
1075                if let ExternalKind::Func = export.kind {
1076                    if export.index >= self.inserted_at {
1077                        export.index += 1
1078                    }
1079                }
1080            }
1081        }
1082
1083        if let Some(section) = self
1084            .builder
1085            .module
1086            .element_section
1087            .as_mut()
1088            .filter(|_| self.element_section)
1089        {
1090            for segment in section {
1091                // update all indirect call addresses initial values
1092                match &mut segment.items {
1093                    ElementItems::Functions(funcs) => {
1094                        for func_index in funcs.iter_mut() {
1095                            if *func_index >= self.inserted_at {
1096                                *func_index += 1
1097                            }
1098                        }
1099                    }
1100                }
1101            }
1102        }
1103
1104        if let Some(start_idx) = self
1105            .builder
1106            .module
1107            .start_section
1108            .as_mut()
1109            .filter(|_| self.start_section)
1110        {
1111            if *start_idx >= self.inserted_at {
1112                *start_idx += 1
1113            }
1114        }
1115
1116        if let Some(section) = self
1117            .builder
1118            .module
1119            .name_section
1120            .as_mut()
1121            .filter(|_| self.name_section)
1122        {
1123            for name in section {
1124                if let Name::Function(map) = name {
1125                    for naming in map {
1126                        if naming.index >= self.inserted_at {
1127                            naming.index += 1;
1128                        }
1129                    }
1130                }
1131            }
1132        }
1133
1134        self.builder
1135    }
1136}
1137
1138#[derive(Debug, Default)]
1139pub struct ModuleBuilder {
1140    module: Module,
1141}
1142
1143impl ModuleBuilder {
1144    pub fn from_module(module: Module) -> Self {
1145        Self { module }
1146    }
1147
1148    pub fn shift_func_index(self, inserted_at: u32) -> ModuleFuncIndexShifter {
1149        ModuleFuncIndexShifter {
1150            builder: self,
1151            inserted_at,
1152            code_section: false,
1153            export_section: false,
1154            element_section: false,
1155            start_section: false,
1156            name_section: false,
1157        }
1158    }
1159
1160    pub fn build(self) -> Module {
1161        self.module
1162    }
1163
1164    fn type_section(&mut self) -> &mut TypeSection {
1165        self.module
1166            .type_section
1167            .get_or_insert_with(Default::default)
1168    }
1169
1170    fn import_section(&mut self) -> &mut Vec<Import> {
1171        self.module.import_section.get_or_insert_with(Vec::new)
1172    }
1173
1174    fn func_section(&mut self) -> &mut Vec<u32> {
1175        self.module.function_section.get_or_insert_with(Vec::new)
1176    }
1177
1178    fn global_section(&mut self) -> &mut Vec<Global> {
1179        self.module.global_section.get_or_insert_with(Vec::new)
1180    }
1181
1182    fn export_section(&mut self) -> &mut Vec<Export> {
1183        self.module.export_section.get_or_insert_with(Vec::new)
1184    }
1185
1186    fn element_section(&mut self) -> &mut Vec<Element> {
1187        self.module.element_section.get_or_insert_with(Vec::new)
1188    }
1189
1190    fn code_section(&mut self) -> &mut CodeSection {
1191        self.module.code_section.get_or_insert_with(Vec::new)
1192    }
1193
1194    fn data_section(&mut self) -> &mut DataSection {
1195        self.module.data_section.get_or_insert_with(Vec::new)
1196    }
1197
1198    fn custom_sections(&mut self) -> &mut Vec<CustomSection> {
1199        self.module.custom_sections.get_or_insert_with(Vec::new)
1200    }
1201
1202    pub fn push_custom_section(
1203        &mut self,
1204        name: impl Into<Cow<'static, str>>,
1205        data: impl Into<Vec<u8>>,
1206    ) {
1207        self.custom_sections().push((name.into(), data.into()));
1208    }
1209
1210    /// Adds a new function to the module.
1211    ///
1212    /// Returns index from function section
1213    pub fn add_func(&mut self, ty: FuncType, function: Function) -> u32 {
1214        let type_idx = self.push_type(ty);
1215        self.func_section().push(type_idx);
1216        let func_idx = self.func_section().len() as u32 - 1;
1217        self.code_section().push(function);
1218        func_idx
1219    }
1220
1221    pub fn push_type(&mut self, ty: FuncType) -> u32 {
1222        let idx = self.type_section().iter().position(|vec_ty| *vec_ty == ty);
1223        idx.map(|pos| pos as u32).unwrap_or_else(|| {
1224            self.type_section().push(ty);
1225            self.type_section().len() as u32 - 1
1226        })
1227    }
1228
1229    pub fn push_import(&mut self, import: Import) -> u32 {
1230        self.import_section().push(import);
1231        self.import_section().len() as u32 - 1
1232    }
1233
1234    pub fn set_table(&mut self, table: Table) {
1235        debug_assert_eq!(self.module.table_section, None);
1236        self.module.table_section = Some(table);
1237    }
1238
1239    pub fn push_global(&mut self, global: Global) -> u32 {
1240        self.global_section().push(global);
1241        self.global_section().len() as u32 - 1
1242    }
1243
1244    pub fn push_export(&mut self, export: Export) {
1245        self.export_section().push(export);
1246    }
1247
1248    pub fn push_element(&mut self, element: Element) {
1249        self.element_section().push(element);
1250    }
1251
1252    pub fn push_data(&mut self, data: Data) {
1253        self.data_section().push(data);
1254    }
1255}
1256
/// Function types, in type-index order.
pub type TypeSection = Vec<FuncType>;
/// Type index of each function defined in the module (imports are not included).
pub type FuncSection = Vec<u32>;
/// Bodies of the functions defined in the module; `ModuleBuilder::add_func` pushes to
/// this and to [`FuncSection`] together.
pub type CodeSection = Vec<Function>;
/// Data segments of the module.
pub type DataSection = Vec<Data>;
/// A custom section as a `(name, raw payload)` pair.
pub type CustomSection = (Cow<'static, str>, Vec<u8>);

/// In-memory representation of a WASM module, parsed with [`GEAR_SUPPORTED_FEATURES`].
///
/// Each field holds one section; `None` means the section is absent. At most one table and
/// one memory are supported (`Module::new` rejects modules with more).
#[derive(derive_more::Debug, Clone, Default)]
#[debug("Module {{ .. }}")]
pub struct Module {
    pub type_section: Option<TypeSection>,
    pub import_section: Option<Vec<Import>>,
    pub function_section: Option<FuncSection>,
    // the single table, if any
    pub table_section: Option<Table>,
    // the single memory, if any
    pub memory_section: Option<MemoryType>,
    pub global_section: Option<Vec<Global>>,
    pub export_section: Option<Vec<Export>>,
    // index of the start function, if any
    pub start_section: Option<u32>,
    pub element_section: Option<Vec<Element>>,
    pub code_section: Option<CodeSection>,
    // serialized as active segments targeting memory 0
    pub data_section: Option<DataSection>,
    // decoded subsections of the `name` custom section
    pub name_section: Option<Vec<Name>>,
    // every other custom section, in the order encountered while parsing
    pub custom_sections: Option<Vec<CustomSection>>,
}
1280
1281impl Module {
1282    pub fn new(code: &[u8]) -> Result<Self> {
1283        let mut type_section = None;
1284        let mut import_section = None;
1285        let mut function_section = None;
1286        let mut table_section = None;
1287        let mut memory_section = None;
1288        let mut global_section = None;
1289        let mut export_section = None;
1290        let mut start_section = None;
1291        let mut element_section = None;
1292        let mut code_section = None;
1293        let mut data_section = None;
1294        let mut name_section = None;
1295        let mut custom_sections = None;
1296
1297        let mut parser = wasmparser::Parser::new(0);
1298        parser.set_features(GEAR_SUPPORTED_FEATURES);
1299        for payload in parser.parse_all(code) {
1300            match payload? {
1301                Payload::Version {
1302                    num: _,
1303                    encoding,
1304                    range: _,
1305                } => {
1306                    debug_assert_eq!(encoding, Encoding::Module);
1307                }
1308                Payload::TypeSection(section) => {
1309                    debug_assert!(type_section.is_none());
1310                    type_section = Some(
1311                        section
1312                            .into_iter_err_on_gc_types()
1313                            .collect::<Result<_, _>>()?,
1314                    );
1315                }
1316                Payload::ImportSection(section) => {
1317                    debug_assert!(import_section.is_none());
1318                    import_section = Some(
1319                        section
1320                            .into_iter()
1321                            .map(|import| import.map(Import::parse))
1322                            .collect::<Result<_, _>>()?,
1323                    );
1324                }
1325                Payload::FunctionSection(section) => {
1326                    debug_assert!(function_section.is_none());
1327                    function_section = Some(section.into_iter().collect::<Result<_, _>>()?);
1328                }
1329                Payload::TableSection(section) => {
1330                    debug_assert!(table_section.is_none());
1331                    let mut section = section.into_iter();
1332
1333                    table_section = section
1334                        .next()
1335                        .map(|table| table.map_err(Into::into).and_then(Table::parse))
1336                        .transpose()?;
1337
1338                    if section.next().is_some() {
1339                        return Err(ModuleError::MultipleTables);
1340                    }
1341                }
1342                Payload::MemorySection(section) => {
1343                    debug_assert!(memory_section.is_none());
1344                    let mut section = section.into_iter();
1345
1346                    memory_section = section.next().transpose()?;
1347
1348                    if section.next().is_some() {
1349                        return Err(ModuleError::MultipleMemories);
1350                    }
1351                }
1352                Payload::TagSection(_) => {}
1353                Payload::GlobalSection(section) => {
1354                    debug_assert!(global_section.is_none());
1355                    global_section = Some(
1356                        section
1357                            .into_iter()
1358                            .map(|element| element.map_err(Into::into).and_then(Global::parse))
1359                            .collect::<Result<_, _>>()?,
1360                    );
1361                }
1362                Payload::ExportSection(section) => {
1363                    debug_assert!(export_section.is_none());
1364                    export_section = Some(
1365                        section
1366                            .into_iter()
1367                            .map(|e| e.map(Export::parse))
1368                            .collect::<Result<_, _>>()?,
1369                    );
1370                }
1371                Payload::StartSection { func, range: _ } => {
1372                    start_section = Some(func);
1373                }
1374                Payload::ElementSection(section) => {
1375                    debug_assert!(element_section.is_none());
1376                    element_section = Some(
1377                        section
1378                            .into_iter()
1379                            .map(|element| element.map_err(Into::into).and_then(Element::parse))
1380                            .collect::<Result<Vec<_>>>()?,
1381                    );
1382                }
1383                // note: the section is not present in WASM 1.0
1384                Payload::DataCountSection { count, range: _ } => {
1385                    data_section = Some(Vec::with_capacity(count as usize));
1386                }
1387                Payload::DataSection(section) => {
1388                    let data_section = data_section.get_or_insert_with(Vec::new);
1389                    for data in section {
1390                        let data = data?;
1391                        data_section.push(Data::parse(data)?);
1392                    }
1393                }
1394                Payload::CodeSectionStart {
1395                    count,
1396                    range: _,
1397                    size: _,
1398                } => {
1399                    code_section = Some(Vec::with_capacity(count as usize));
1400                }
1401                Payload::CodeSectionEntry(entry) => {
1402                    code_section
1403                        .as_mut()
1404                        .expect("code section start missing")
1405                        .push(Function::from_entry(entry)?);
1406                }
1407                Payload::CustomSection(section) => match section.as_known() {
1408                    KnownCustom::Name(name_section_reader) => {
1409                        name_section = Some(
1410                            name_section_reader
1411                                .into_iter()
1412                                .map(|name| name.map_err(Into::into).and_then(Name::parse))
1413                                .collect::<Result<Vec<_>>>()?,
1414                        );
1415                    }
1416                    _ => {
1417                        let custom_sections = custom_sections.get_or_insert_with(Vec::new);
1418                        let name = section.name().to_string().into();
1419                        let data = section.data().to_vec();
1420                        custom_sections.push((name, data));
1421                    }
1422                },
1423                Payload::UnknownSection { .. } => {}
1424                _ => {}
1425            }
1426        }
1427
1428        Ok(Self {
1429            type_section,
1430            import_section,
1431            function_section,
1432            table_section,
1433            memory_section,
1434            global_section,
1435            export_section,
1436            start_section,
1437            element_section,
1438            code_section,
1439            data_section,
1440            name_section,
1441            custom_sections,
1442        })
1443    }
1444
1445    pub fn serialize(&self) -> Result<Vec<u8>> {
1446        let mut module = wasm_encoder::Module::new();
1447
1448        if let Some(crate_section) = &self.type_section {
1449            let mut encoder_section = wasm_encoder::TypeSection::new();
1450            for func_type in crate_section.clone() {
1451                encoder_section
1452                    .ty()
1453                    .func_type(&RoundtripReencoder.func_type(func_type)?);
1454            }
1455            module.section(&encoder_section);
1456        }
1457
1458        if let Some(crate_section) = &self.import_section {
1459            let mut encoder_section = wasm_encoder::ImportSection::new();
1460            for import in crate_section.clone() {
1461                import.reencode(&mut encoder_section)?;
1462            }
1463            module.section(&encoder_section);
1464        }
1465
1466        if let Some(crate_section) = &self.function_section {
1467            let mut encoder_section = wasm_encoder::FunctionSection::new();
1468            for &function in crate_section {
1469                encoder_section.function(function);
1470            }
1471            module.section(&encoder_section);
1472        }
1473
1474        if let Some(table) = &self.table_section {
1475            let mut encoder_section = wasm_encoder::TableSection::new();
1476            table.reencode(&mut encoder_section)?;
1477            module.section(&encoder_section);
1478        }
1479
1480        if let Some(memory) = &self.memory_section {
1481            let mut encoder_section = wasm_encoder::MemorySection::new();
1482            encoder_section.memory(RoundtripReencoder.memory_type(*memory));
1483            module.section(&encoder_section);
1484        }
1485
1486        if let Some(crate_section) = &self.global_section {
1487            let mut encoder_section = wasm_encoder::GlobalSection::new();
1488            for global in crate_section {
1489                encoder_section.global(
1490                    RoundtripReencoder.global_type(global.ty)?,
1491                    &global.init_expr.reencode()?,
1492                );
1493            }
1494            module.section(&encoder_section);
1495        }
1496
1497        if let Some(crate_section) = &self.export_section {
1498            let mut encoder_section = wasm_encoder::ExportSection::new();
1499            for export in crate_section {
1500                encoder_section.export(
1501                    &export.name,
1502                    RoundtripReencoder.export_kind(export.kind),
1503                    export.index,
1504                );
1505            }
1506            module.section(&encoder_section);
1507        }
1508
1509        if let Some(function_index) = self.start_section {
1510            module.section(&wasm_encoder::StartSection { function_index });
1511        }
1512
1513        if let Some(crate_section) = &self.element_section {
1514            let mut encoder_section = wasm_encoder::ElementSection::new();
1515            for element in crate_section {
1516                element.reencode(&mut encoder_section)?;
1517            }
1518            module.section(&encoder_section);
1519        }
1520
1521        if let Some(crate_section) = &self.code_section {
1522            let mut encoder_section = wasm_encoder::CodeSection::new();
1523            for function in crate_section {
1524                encoder_section.function(&function.reencode()?);
1525            }
1526            module.section(&encoder_section);
1527        }
1528
1529        if let Some(crate_section) = &self.data_section {
1530            let mut encoder_section = wasm_encoder::DataSection::new();
1531            for data in crate_section {
1532                encoder_section.active(0, &data.offset_expr.reencode()?, data.data.iter().copied());
1533            }
1534            module.section(&encoder_section);
1535        }
1536
1537        if let Some(name_section) = &self.name_section {
1538            let mut encoder_section = wasm_encoder::NameSection::new();
1539            for name in name_section {
1540                name.reencode(&mut encoder_section);
1541            }
1542            module.section(&encoder_section);
1543        }
1544
1545        if let Some(custom_sections) = &self.custom_sections {
1546            for (name, data) in custom_sections {
1547                let encoder_section = wasm_encoder::CustomSection {
1548                    name: Cow::Borrowed(name),
1549                    data: Cow::Borrowed(data),
1550                };
1551                module.section(&encoder_section);
1552            }
1553        }
1554
1555        Ok(module.finish())
1556    }
1557
1558    pub fn import_count(&self, pred: impl Fn(&TypeRef) -> bool) -> usize {
1559        self.import_section
1560            .as_ref()
1561            .map(|imports| imports.iter().filter(|import| pred(&import.ty)).count())
1562            .unwrap_or(0)
1563    }
1564
1565    pub fn functions_space(&self) -> usize {
1566        self.import_count(|ty| matches!(ty, TypeRef::Func(_)))
1567            + self
1568                .function_section
1569                .as_ref()
1570                .map(|section| section.len())
1571                .unwrap_or(0)
1572    }
1573
1574    pub fn globals_space(&self) -> usize {
1575        self.import_count(|ty| matches!(ty, TypeRef::Global(_)))
1576            + self
1577                .global_section
1578                .as_ref()
1579                .map(|section| section.len())
1580                .unwrap_or(0)
1581    }
1582}
1583
#[cfg(test)]
mod tests {
    use super::*;

    /// Expands every `name: wat => error;` entry into a `#[test]` asserting that
    /// `Module::new` rejects the compiled WAT with exactly that `ModuleError`.
    macro_rules! test_parsing_failed {
        ($($test_name:ident: $wat:literal => $err:expr;)*) => {
            $(
                #[test]
                fn $test_name() {
                    let expected: ModuleError = $err;
                    let binary = wat::parse_str($wat).unwrap();
                    let actual = Module::new(&binary).unwrap_err();
                    // `BinaryReaderError` lacks `PartialEq`, so the `Debug` renderings are compared
                    assert_eq!(format!("{actual:?}"), format!("{expected:?}"));
                }
            )*
        };
    }

    test_parsing_failed! {
        multiple_tables_denied: r#"
        (module
            (table 10 10 funcref)
            (table 20 20 funcref)
        )"# => ModuleError::MultipleTables;

        multiple_memories_denied: r#"
        (module
            (memory (export "memory") 1)
            (memory (export "memory2") 2)
        )"# => ModuleError::MultipleMemories;

        data_non_zero_memory_idx_denied: r#"
        (module
            (data (memory 123) (offset i32.const 0) "")
        )
        "# => ModuleError::NonZeroMemoryIdx(123);

        element_table_idx_denied: r#"
        (module
            (elem 123 (offset i32.const 0) 0 0 0 0)
        )"# => ModuleError::ElementTableIdx(123);

        passive_data_kind_denied: r#"
        (module
            (data "")
        )
        "# => ModuleError::PassiveDataKind;

        passive_element_denied: r#"
        (module
            (elem funcref (item i32.const 0))
        )
        "# => ModuleError::NonActiveElementKind;

        declared_element_denied: r#"
        (module
            (func $a)
            (elem declare func $a)
        )
        "# => ModuleError::NonActiveElementKind;

        element_expressions_denied: r#"
        (module
            (elem (i32.const 1) funcref)
        )
        "# => ModuleError::ElementExpressions;

        table_init_expr_denied: r#"
        (module
            (table 0 0 funcref (i32.const 0))
        )"# => ModuleError::TableInitExpr;
    }

    #[test]
    fn call_indirect_non_zero_table_idx_denied() {
        let wasm = wat::parse_str(
            r#"
            (module
                (func
                    call_indirect 123 (type 333)
                )
            )
            "#,
        )
        .unwrap();
        // the non-zero table index is reported by wasmparser while decoding the body
        match Module::new(&wasm).unwrap_err() {
            ModuleError::BinaryReader(err) => {
                assert_eq!(err.offset(), 26);
                assert_eq!(err.message(), "zero byte expected");
            }
            other => panic!("{other}"),
        }
    }

    #[test]
    fn custom_section_kept() {
        let mut builder = ModuleBuilder::default();
        builder.push_custom_section("dummy", [1, 2, 3]);
        let original_bytes = builder.build().serialize().unwrap();
        let reparsed_bytes = Module::new(&original_bytes).unwrap().serialize().unwrap();

        // compare the printed text forms of the original and the round-tripped binary
        let original_wat = wasmprinter::print_bytes(&original_bytes).unwrap();
        let reparsed_wat = wasmprinter::print_bytes(&reparsed_bytes).unwrap();
        assert_eq!(original_wat, reparsed_wat);
    }
}