makepad_stitch/
validate.rs

1use {
2    crate::{
3        code,
4        code::{
5            BinOpInfo, BlockType, InstrVisitor, LoadInfo, MemArg, StoreInfo, UnOpInfo,
6            UncompiledCode,
7        },
8        decode::DecodeError,
9        func::FuncType,
10        global::Mut,
11        module::ModuleBuilder,
12        ref_::RefType,
13        val::ValType,
14    },
15    std::{mem, ops::Deref},
16};
17
/// Type-checks WebAssembly function bodies.
///
/// The fields are scratch buffers that `validate` resets and reuses on every
/// call, so a single `Validator` can check many functions.
#[derive(Clone, Debug)]
pub(crate) struct Validator {
    /// Scratch buffer handed to `code::decode_instr` (presumably used to collect
    /// `br_table` label indices — confirm against `decode_instr`).
    label_idxs: Vec<u32>,
    /// Types of the function's locals: its parameters first, then declared locals.
    locals: Vec<ValType>,
    /// Stack of open control frames; the last entry is the innermost frame.
    blocks: Vec<Block>,
    /// Operand type stack.
    opds: Vec<OpdType>,
    /// Scratch buffer used while checking `br_table` targets.
    aux_opds: Vec<OpdType>,
}
26
27impl Validator {
28    pub(crate) fn new() -> Validator {
29        Validator {
30            label_idxs: Vec::new(),
31            locals: Vec::new(),
32            blocks: Vec::new(),
33            opds: Vec::new(),
34            aux_opds: Vec::new(),
35        }
36    }
37
38    pub(crate) fn validate(
39        &mut self,
40        type_: &FuncType,
41        module: &ModuleBuilder,
42        code: &UncompiledCode,
43    ) -> Result<(), DecodeError> {
44        use crate::decode::Decoder;
45
46        self.label_idxs.clear();
47        self.locals.clear();
48        self.blocks.clear();
49        self.opds.clear();
50        let mut validation = Validation {
51            module,
52            locals: &mut self.locals,
53            blocks: &mut self.blocks,
54            opds: &mut self.opds,
55            aux_opds: &mut self.aux_opds,
56        };
57        validation.locals.extend(type_.params().iter().copied());
58        validation.locals.extend(code.locals.iter().copied());
59        validation.push_block(
60            BlockKind::Block,
61            FuncType::new([], type_.results().iter().copied()),
62        );
63        let mut decoder = Decoder::new(&code.expr);
64        while !validation.blocks.is_empty() {
65            code::decode_instr(&mut decoder, &mut self.label_idxs, &mut validation)?;
66        }
67        Ok(())
68    }
69}
70
71impl Default for Validator {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
/// Per-call validation state. It borrows the module and the `Validator`'s scratch
/// buffers, and implements `InstrVisitor` to type-check each decoded instruction.
#[derive(Debug)]
struct Validation<'a> {
    /// Module whose index spaces (types, functions, tables, memories, globals,
    /// element and data segments) instructions are checked against.
    module: &'a ModuleBuilder,
    /// Types of the function's locals: parameters first, then declared locals.
    locals: &'a mut Vec<ValType>,
    /// Stack of open control frames; the last entry is the innermost frame.
    blocks: &'a mut Vec<Block>,
    /// Operand type stack.
    opds: &'a mut Vec<OpdType>,
    /// Scratch buffer used by `visit_br_table`.
    aux_opds: &'a mut Vec<OpdType>,
}
85
86impl<'a> Validation<'a> {
87    fn resolve_block_type(&self, type_: BlockType) -> Result<FuncType, DecodeError> {
88        match type_ {
89            BlockType::TypeIdx(idx) => self.module.type_(idx).cloned(),
90            BlockType::ValType(val_type) => Ok(FuncType::from_val_type(val_type)),
91        }
92    }
93
94    fn local(&self, idx: u32) -> Result<ValType, DecodeError> {
95        self.locals
96            .get(idx as usize)
97            .copied()
98            .ok_or_else(|| DecodeError::new("unknown local"))
99    }
100
101    fn label(&self, idx: u32) -> Result<(), DecodeError> {
102        let idx = usize::try_from(idx).unwrap();
103        if idx >= self.blocks.len() {
104            return Err(DecodeError::new("unknown label"));
105        }
106        Ok(())
107    }
108
109    fn block(&self, idx: u32) -> &Block {
110        let idx = usize::try_from(idx).unwrap();
111        &self.blocks[self.blocks.len() - 1 - idx]
112    }
113
114    fn push_block(&mut self, kind: BlockKind, type_: FuncType) {
115        self.blocks.push(Block {
116            kind,
117            type_,
118            is_unreachable: false,
119            height: self.opds.len(),
120        });
121        for start_type in self.block(0).type_.clone().params().iter().copied() {
122            self.push_opd(start_type);
123        }
124    }
125
126    fn pop_block(&mut self) -> Result<Block, DecodeError> {
127        for end_type in self.block(0).type_.clone().results().iter().rev().copied() {
128            self.pop_opd()?.check(end_type)?;
129        }
130        if self.opds.len() != self.block(0).height {
131            return Err(DecodeError::new("type mismatch"));
132        }
133        Ok(self.blocks.pop().unwrap())
134    }
135
136    fn set_unreachable(&mut self) {
137        self.opds.truncate(self.block(0).height);
138        self.blocks.last_mut().unwrap().is_unreachable = true;
139    }
140
141    fn push_opd(&mut self, type_: impl Into<OpdType>) {
142        let type_ = type_.into();
143        self.opds.push(type_);
144    }
145
146    fn pop_opd(&mut self) -> Result<OpdType, DecodeError> {
147        if self.opds.len() == self.block(0).height {
148            if !self.block(0).is_unreachable {
149                return Err(DecodeError::new("type mismatch"));
150            }
151            Ok(OpdType::Unknown)
152        } else {
153            Ok(self.opds.pop().unwrap())
154        }
155    }
156}
157
158impl<'a> InstrVisitor for Validation<'a> {
159    type Error = DecodeError;
160
161    // Control instructions
162    fn visit_nop(&mut self) -> Result<(), Self::Error> {
163        Ok(())
164    }
165
166    fn visit_unreachable(&mut self) -> Result<(), Self::Error> {
167        self.set_unreachable();
168        Ok(())
169    }
170
171    fn visit_block(&mut self, type_: BlockType) -> Result<(), Self::Error> {
172        let type_ = self.resolve_block_type(type_)?;
173        for start_type in type_.params().iter().rev().copied() {
174            self.pop_opd()?.check(start_type)?;
175        }
176        self.push_block(BlockKind::Block, type_);
177        Ok(())
178    }
179
180    fn visit_loop(&mut self, type_: BlockType) -> Result<(), Self::Error> {
181        let type_ = self.resolve_block_type(type_)?;
182        for start_type in type_.params().iter().rev().copied() {
183            self.pop_opd()?.check(start_type)?;
184        }
185        self.push_block(BlockKind::Loop, type_);
186        Ok(())
187    }
188
189    fn visit_if(&mut self, type_: BlockType) -> Result<(), Self::Error> {
190        let type_ = self.resolve_block_type(type_)?;
191        self.pop_opd()?.check(ValType::I32)?;
192        for start_type in type_.params().iter().rev().copied() {
193            self.pop_opd()?.check(start_type)?;
194        }
195        self.push_block(BlockKind::If, type_);
196        Ok(())
197    }
198
199    fn visit_else(&mut self) -> Result<(), Self::Error> {
200        let block = self.pop_block()?;
201        if block.kind != BlockKind::If {
202            return Err(DecodeError::new("unexpected else opcode"));
203        }
204        self.push_block(BlockKind::Else, block.type_);
205        Ok(())
206    }
207
208    fn visit_end(&mut self) -> Result<(), Self::Error> {
209        let block = self.pop_block()?;
210        let block = if block.kind == BlockKind::If {
211            self.push_block(BlockKind::Else, block.type_);
212            self.pop_block()?
213        } else {
214            block
215        };
216        for end_type in block.type_.results().iter().copied() {
217            self.push_opd(end_type);
218        }
219        Ok(())
220    }
221
222    fn visit_br(&mut self, label_idx: u32) -> Result<(), Self::Error> {
223        self.label(label_idx)?;
224        for label_type in self.block(label_idx).label_types().iter().rev().copied() {
225            self.pop_opd()?.check(label_type)?;
226        }
227        self.set_unreachable();
228        Ok(())
229    }
230
231    fn visit_br_if(&mut self, label_idx: u32) -> Result<(), Self::Error> {
232        self.pop_opd()?.check(ValType::I32)?;
233        self.label(label_idx)?;
234        for &label_type in self.block(label_idx).label_types().iter().rev() {
235            self.pop_opd()?.check(label_type)?;
236        }
237        for &label_type in self.block(label_idx).label_types().iter() {
238            self.push_opd(label_type);
239        }
240        Ok(())
241    }
242
243    fn visit_br_table(
244        &mut self,
245        label_idxs: &[u32],
246        default_label_idx: u32,
247    ) -> Result<(), Self::Error> {
248        self.pop_opd()?.check(ValType::I32)?;
249        self.label(default_label_idx)?;
250        let arity = self.block(default_label_idx).label_types().len();
251        for label_idx in label_idxs.iter().copied() {
252            self.label(label_idx)?;
253            if self.block(label_idx).label_types().len() != arity {
254                return Err(DecodeError::new("arity mismatch"));
255            }
256            let mut aux_opds = mem::take(self.aux_opds);
257            for label_type in self.block(label_idx).label_types().iter().rev().copied() {
258                let opd = self.pop_opd()?;
259                opd.check(label_type)?;
260                aux_opds.push(opd);
261            }
262            while let Some(opd) = aux_opds.pop() {
263                self.push_opd(opd);
264            }
265            *self.aux_opds = aux_opds;
266        }
267        for label_type in self
268            .block(default_label_idx)
269            .label_types()
270            .iter()
271            .rev()
272            .copied()
273        {
274            self.pop_opd()?.check(label_type)?;
275        }
276        self.set_unreachable();
277        Ok(())
278    }
279
280    fn visit_return(&mut self) -> Result<(), Self::Error> {
281        self.visit_br(self.blocks.len() as u32 - 1)
282    }
283
284    fn visit_call(&mut self, func_idx: u32) -> Result<(), Self::Error> {
285        let type_ = self.module.func(func_idx)?;
286        for param_type in type_.params().iter().rev().copied() {
287            self.pop_opd()?.check(param_type)?;
288        }
289        for result_type in type_.results().iter().copied() {
290            self.push_opd(result_type);
291        }
292        Ok(())
293    }
294
295    fn visit_call_indirect(&mut self, table_idx: u32, type_idx: u32) -> Result<(), Self::Error> {
296        let table_type = self.module.table(table_idx)?;
297        if table_type.elem != RefType::FuncRef {
298            return Err(DecodeError::new("type mismatch"));
299        }
300        let type_ = self.module.type_(type_idx)?;
301        self.pop_opd()?.check(ValType::I32)?;
302        for param_type in type_.params().iter().rev().copied() {
303            self.pop_opd()?.check(param_type)?;
304        }
305        for result_type in type_.results().iter().copied() {
306            self.push_opd(result_type);
307        }
308        Ok(())
309    }
310
311    // Reference instructions
312    fn visit_ref_null(&mut self, type_: RefType) -> Result<(), Self::Error> {
313        self.push_opd(type_);
314        Ok(())
315    }
316
317    fn visit_ref_is_null(&mut self) -> Result<(), Self::Error> {
318        if !self.pop_opd()?.is_ref() {
319            return Err(DecodeError::new("type mismatch"));
320        };
321        self.push_opd(ValType::I32);
322        Ok(())
323    }
324
325    fn visit_ref_func(&mut self, func_idx: u32) -> Result<(), Self::Error> {
326        self.module.ref_(func_idx)?;
327        self.push_opd(ValType::FuncRef);
328        Ok(())
329    }
330
331    // Parametric instructions
332    fn visit_drop(&mut self) -> Result<(), Self::Error> {
333        self.pop_opd()?;
334        Ok(())
335    }
336
337    fn visit_select(&mut self, type_: Option<ValType>) -> Result<(), Self::Error> {
338        if let Some(type_) = type_ {
339            self.pop_opd()?.check(ValType::I32)?;
340            self.pop_opd()?.check(type_)?;
341            self.pop_opd()?.check(type_)?;
342            self.push_opd(type_);
343        } else {
344            self.pop_opd()?.check(ValType::I32)?;
345            let input_type_1 = self.pop_opd()?;
346            let input_type_0 = self.pop_opd()?;
347            if !(input_type_0.is_num() && input_type_1.is_num()) {
348                return Err(DecodeError::new("type mismatch"));
349            }
350            if let OpdType::ValType(input_type_1) = input_type_1 {
351                input_type_0.check(input_type_1)?;
352            }
353            self.push_opd(if input_type_0.is_unknown() {
354                input_type_1
355            } else {
356                input_type_0
357            });
358        }
359        Ok(())
360    }
361
362    // Variable instructions
363    fn visit_local_get(&mut self, local_idx: u32) -> Result<(), Self::Error> {
364        let type_ = self.local(local_idx)?;
365        self.push_opd(type_);
366        Ok(())
367    }
368
369    fn visit_local_set(&mut self, local_idx: u32) -> Result<(), Self::Error> {
370        let type_ = self.local(local_idx)?;
371        self.pop_opd()?.check(type_)?;
372        Ok(())
373    }
374
375    fn visit_local_tee(&mut self, local_idx: u32) -> Result<(), Self::Error> {
376        let type_ = self.local(local_idx)?;
377        self.pop_opd()?.check(type_)?;
378        self.push_opd(type_);
379        Ok(())
380    }
381
382    fn visit_global_get(&mut self, global_idx: u32) -> Result<(), Self::Error> {
383        let type_ = self.module.global(global_idx)?;
384        self.push_opd(type_.val);
385        Ok(())
386    }
387
388    fn visit_global_set(&mut self, global_idx: u32) -> Result<(), Self::Error> {
389        let type_ = self.module.global(global_idx)?;
390        if type_.mut_ != Mut::Var {
391            return Err(DecodeError::new("type mismatch"));
392        }
393        self.pop_opd()?.check(type_.val)?;
394        Ok(())
395    }
396
397    // Table instructions
398    fn visit_table_get(&mut self, table_idx: u32) -> Result<(), Self::Error> {
399        let type_ = self.module.table(table_idx)?;
400        self.pop_opd()?.check(ValType::I32)?;
401        self.push_opd(type_.elem);
402        Ok(())
403    }
404
405    fn visit_table_set(&mut self, table_idx: u32) -> Result<(), Self::Error> {
406        let type_ = self.module.table(table_idx)?;
407        self.pop_opd()?.check(type_.elem)?;
408        self.pop_opd()?.check(ValType::I32)?;
409        Ok(())
410    }
411
412    fn visit_table_size(&mut self, table_idx: u32) -> Result<(), Self::Error> {
413        self.module.table(table_idx)?;
414        self.push_opd(ValType::I32);
415        Ok(())
416    }
417
418    fn visit_table_grow(&mut self, table_idx: u32) -> Result<(), Self::Error> {
419        let type_ = self.module.table(table_idx)?;
420        self.pop_opd()?.check(ValType::I32)?;
421        self.pop_opd()?.check(type_.elem)?;
422        self.push_opd(ValType::I32);
423        Ok(())
424    }
425
426    fn visit_table_fill(&mut self, table_idx: u32) -> Result<(), Self::Error> {
427        let type_ = self.module.table(table_idx)?;
428        self.pop_opd()?.check(ValType::I32)?;
429        self.pop_opd()?.check(type_.elem)?;
430        self.pop_opd()?.check(ValType::I32)?;
431        Ok(())
432    }
433
434    fn visit_table_copy(
435        &mut self,
436        dst_table_idx: u32,
437        src_table_idx: u32,
438    ) -> Result<(), Self::Error> {
439        let dst_type = self.module.table(dst_table_idx)?;
440        let src_type = self.module.table(src_table_idx)?;
441        if dst_type.elem != src_type.elem {
442            return Err(DecodeError::new("type mismatch"));
443        }
444        self.pop_opd()?.check(ValType::I32)?;
445        self.pop_opd()?.check(ValType::I32)?;
446        self.pop_opd()?.check(ValType::I32)?;
447        Ok(())
448    }
449
450    fn visit_table_init(&mut self, table_idx: u32, elem_idx: u32) -> Result<(), Self::Error> {
451        let dst_type = self.module.table(table_idx)?;
452        let src_type = self.module.elem(elem_idx)?;
453        if dst_type.elem != src_type {
454            return Err(DecodeError::new("type mismatch"));
455        }
456        self.pop_opd()?.check(ValType::I32)?;
457        self.pop_opd()?.check(ValType::I32)?;
458        self.pop_opd()?.check(ValType::I32)?;
459        Ok(())
460    }
461
462    fn visit_elem_drop(&mut self, elem_idx: u32) -> Result<(), Self::Error> {
463        self.module.elem(elem_idx)?;
464        Ok(())
465    }
466
467    // Memory instructions
468    fn visit_load(&mut self, arg: MemArg, info: LoadInfo) -> Result<(), Self::Error> {
469        if arg.align > info.max_align {
470            return Err(DecodeError::new("alignment too large"));
471        }
472        self.module.memory(0)?;
473        self.visit_un_op(info.op)
474    }
475
476    fn visit_store(&mut self, arg: MemArg, info: StoreInfo) -> Result<(), Self::Error> {
477        if arg.align > info.max_align {
478            return Err(DecodeError::new("alignment too large"));
479        }
480        self.module.memory(0)?;
481        self.visit_bin_op(info.op)
482    }
483
484    fn visit_memory_size(&mut self) -> Result<(), Self::Error> {
485        self.module.memory(0)?;
486        self.push_opd(ValType::I32);
487        Ok(())
488    }
489
490    fn visit_memory_grow(&mut self) -> Result<(), Self::Error> {
491        self.module.memory(0)?;
492        self.pop_opd()?.check(ValType::I32)?;
493        self.push_opd(ValType::I32);
494        Ok(())
495    }
496
497    fn visit_memory_fill(&mut self) -> Result<(), Self::Error> {
498        self.module.memory(0)?;
499        self.pop_opd()?.check(ValType::I32)?;
500        self.pop_opd()?.check(ValType::I32)?;
501        self.pop_opd()?.check(ValType::I32)?;
502        Ok(())
503    }
504
505    fn visit_memory_copy(&mut self) -> Result<(), Self::Error> {
506        self.module.memory(0)?;
507        self.pop_opd()?.check(ValType::I32)?;
508        self.pop_opd()?.check(ValType::I32)?;
509        self.pop_opd()?.check(ValType::I32)?;
510        Ok(())
511    }
512
513    fn visit_memory_init(&mut self, data_idx: u32) -> Result<(), Self::Error> {
514        self.module.memory(0)?;
515        self.module.data(data_idx)?;
516        self.pop_opd()?.check(ValType::I32)?;
517        self.pop_opd()?.check(ValType::I32)?;
518        self.pop_opd()?.check(ValType::I32)?;
519        Ok(())
520    }
521
522    fn visit_data_drop(&mut self, data_idx: u32) -> Result<(), Self::Error> {
523        self.module.data(data_idx)?;
524        Ok(())
525    }
526
527    // Numeric instructions
528    fn visit_i32_const(&mut self, _val: i32) -> Result<(), Self::Error> {
529        self.push_opd(ValType::I32);
530        Ok(())
531    }
532
533    fn visit_i64_const(&mut self, _val: i64) -> Result<(), Self::Error> {
534        self.push_opd(ValType::I64);
535        Ok(())
536    }
537
538    fn visit_f32_const(&mut self, _val: f32) -> Result<(), Self::Error> {
539        self.push_opd(ValType::F32);
540        Ok(())
541    }
542
543    fn visit_f64_const(&mut self, _val: f64) -> Result<(), Self::Error> {
544        self.push_opd(ValType::F64);
545        Ok(())
546    }
547
548    fn visit_un_op(&mut self, info: UnOpInfo) -> Result<(), Self::Error> {
549        self.pop_opd()?.check(info.input_type)?;
550        if let Some(output_type) = info.output_type {
551            self.push_opd(output_type);
552        }
553        Ok(())
554    }
555
556    fn visit_bin_op(&mut self, info: BinOpInfo) -> Result<(), Self::Error> {
557        self.pop_opd()?.check(info.input_type_1)?;
558        self.pop_opd()?.check(info.input_type_0)?;
559        if let Some(output_type) = info.output_type {
560            self.push_opd(output_type);
561        }
562        Ok(())
563    }
564}
565
/// A control frame: one entry on the validator's block stack.
#[derive(Clone, Debug)]
struct Block {
    /// Which construct opened the frame.
    kind: BlockKind,
    /// The frame's signature: params are its inputs, results are its outputs.
    type_: FuncType,
    /// Set by `unreachable`, `br`, `br_table` and `return`. Once set, popping at
    /// `height` yields `OpdType::Unknown` instead of an error.
    is_unreachable: bool,
    /// Operand stack height when the frame was opened, before its params were pushed.
    height: usize,
}
573
574impl Block {
575    fn label_types(&self) -> LabelTypes {
576        LabelTypes {
577            kind: self.kind,
578            type_: self.type_.clone(),
579        }
580    }
581}
582
/// The construct that opened a control frame.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BlockKind {
    /// A `block`; also used for the implicit outermost frame of the function body.
    Block,
    /// A `loop`; its label types are its params rather than its results.
    Loop,
    /// The `then` arm of an `if`.
    If,
    /// The `else` arm of an `if`. `visit_end` also uses it to check an `if` that
    /// has no `else`.
    Else,
}
590
/// The operand types that a branch to a given block must supply. Derefs to `[ValType]`.
#[derive(Clone, Debug)]
struct LabelTypes {
    /// Kind of the target block: selects its params (`loop`) or its results.
    kind: BlockKind,
    /// Owned clone of the target block's type, so this value does not borrow the block stack.
    type_: FuncType,
}
596
597impl Deref for LabelTypes {
598    type Target = [ValType];
599
600    fn deref(&self) -> &Self::Target {
601        match self.kind {
602            BlockKind::Block | BlockKind::If | BlockKind::Else => self.type_.results(),
603            BlockKind::Loop => self.type_.params(),
604        }
605    }
606}
607
/// The type of an entry on the validator's operand stack.
#[derive(Clone, Copy, Debug)]
enum OpdType {
    /// A concrete value type.
    ValType(ValType),
    /// Placeholder returned by `pop_opd` when it pops at the base of an unreachable
    /// frame. `check` accepts it for any expected type.
    Unknown,
}
613
614impl OpdType {
615    fn is_num(self) -> bool {
616        match self {
617            OpdType::ValType(type_) => type_.is_num(),
618            _ => true,
619        }
620    }
621
622    fn is_ref(self) -> bool {
623        match self {
624            OpdType::ValType(type_) => type_.is_ref(),
625            _ => true,
626        }
627    }
628
629    fn is_unknown(self) -> bool {
630        match self {
631            OpdType::Unknown => true,
632            _ => false,
633        }
634    }
635
636    fn check(self, expected_type: impl Into<ValType>) -> Result<(), DecodeError> {
637        let expected_type = expected_type.into();
638        match self {
639            OpdType::ValType(actual_type) if actual_type != expected_type => {
640                Err(DecodeError::new("type mismatch"))
641            }
642            _ => Ok(()),
643        }
644    }
645}
646
647impl From<RefType> for OpdType {
648    fn from(type_: RefType) -> Self {
649        OpdType::ValType(type_.into())
650    }
651}
652
653impl From<ValType> for OpdType {
654    fn from(type_: ValType) -> Self {
655        OpdType::ValType(type_)
656    }
657}
658
659impl From<Unknown> for OpdType {
660    fn from(_: Unknown) -> Self {
661        OpdType::Unknown
662    }
663}
664
/// Marker that converts into `OpdType::Unknown` via `From`.
// NOTE(review): nothing in this file constructs `Unknown`; confirm it is still needed.
#[derive(Debug)]
struct Unknown;