pit_patch/
canon.rs

1use alloc::borrow::ToOwned;
2use alloc::collections::{BTreeMap, BTreeSet};
3use alloc::format;
4use alloc::string::String;
5use alloc::vec;
6use alloc::vec::Vec;
7use core::iter::once;
8use core::mem::{replace, take};
9use portal_pc_waffle::util::results_ref_2;
10use portal_pc_waffle::{HeapType, WithMutablility};
11
12use anyhow::Context;
13use pit_core::Interface;
14use portal_pc_waffle::{
15    entity::EntityRef, util::new_sig, BlockTarget, Export, ExportKind, Func, FuncDecl,
16    FunctionBody, Import, ImportKind, Module, Operator, SignatureData, TableData, Type,
17    WithNullable,
18};
19use sha3::{Digest, Sha3_256};
20
21use crate::tutils::{talloc, tfree};
22// use waffle_ast::{results_ref_2, Builder, Expr};
23
24// use crate::util::{talloc, tfree};
25
26pub fn canon(m: &mut Module, rid: &str, target: &str) -> anyhow::Result<()> {
27    let mut xs = vec![];
28    for i in m.imports.iter() {
29        if i.module == format!("pit/{rid}") {
30            if let Some(a) = i.name.strip_prefix("~") {
31                xs.push(a.to_owned())
32            }
33        }
34    }
35    xs.sort();
36    let s = new_sig(
37        m,
38        SignatureData::Func {
39            params: vec![Type::I32],
40            returns: vec![Type::Heap(WithNullable {
41                nullable: true,
42                value: portal_pc_waffle::HeapType::ExternRef,
43            })],
44            shared: true,
45        },
46    );
47    let f2 = m.funcs.push(portal_pc_waffle::FuncDecl::Import(
48        s,
49        format!("pit/{rid}.~{target}"),
50    ));
51    let mut tcache: BTreeMap<Vec<Type>, _> = BTreeMap::new();
52    let tx = m.tables.push(TableData {
53        ty: Type::Heap(WithNullable {
54            nullable: true,
55            value: portal_pc_waffle::HeapType::ExternRef,
56        }),
57        initial: 0,
58        max: None,
59        func_elements: None,
60        table64: false,
61    });
62    // let mut tc2 = BTreeMap::new();
63    let mut tc = |m: &mut Module, tys| {
64        tcache
65            .entry(tys)
66            .or_insert_with_key(|tys| {
67                let st = m.signatures.push(SignatureData::Struct {
68                    fields: tys
69                        .iter()
70                        .cloned()
71                        .map(|a| WithMutablility {
72                            mutable: false,
73                            value: portal_pc_waffle::StorageType::Val(a),
74                        })
75                        .collect(),
76                    shared: true,
77                });
78                let stt = m.tables.push(TableData {
79                    ty: Type::Heap(WithNullable {
80                        value: portal_pc_waffle::HeapType::Sig { sig_index: st },
81                        nullable: true,
82                    }),
83                    initial: 0,
84                    max: None,
85                    func_elements: None,
86                    table64: false,
87                });
88                (
89                    talloc(m, stt, &[]).unwrap(),
90                    tfree(m, stt, &[]).unwrap(),
91                    st,
92                    stt,
93                )
94            })
95            .clone()
96    };
97    let mut m2 = BTreeMap::new();
98    let is = take(&mut m.imports);
99    let stub = new_sig(
100        m,
101        SignatureData::Func {
102            params: vec![],
103            returns: vec![Type::Heap(WithNullable {
104                nullable: true,
105                value: portal_pc_waffle::HeapType::ExternRef,
106            })],
107            shared: true,
108        },
109    );
110    let stub = m.funcs.push(FuncDecl::Import(stub, format!("stub")));
111    m.imports.push(Import {
112        module: format!("system"),
113        name: format!("stub"),
114        kind: ImportKind::Func(stub),
115    });
116    for i in is {
117        if i.module == format!("pit/{rid}") {
118            if let Some(a) = i.name.strip_prefix("~") {
119                if let Ok(x) = xs.binary_search(&a.to_owned()) {
120                    if let ImportKind::Func(f) = i.kind {
121                        let fs = m.funcs[f].sig();
122                        let fname = m.funcs[f].name().to_owned();
123                        let mut b = FunctionBody::new(&m, fs);
124                        let k = b.entry;
125                        let (ta, _, ts, tts) =
126                            tc(m, b.blocks[k].params.iter().map(|a| a.0).collect());
127                        m2.insert(
128                            a.to_owned(),
129                            b.blocks[k].params.iter().map(|a| a.0).collect::<Vec<_>>(),
130                        );
131                        // let mut e = Expr::Bind(
132                        //     Operator::I32Add,
133                        //     vec![
134                        //         Expr::Bind(Operator::I32Const { value: x as u32 }, vec![]),
135                        //         Expr::Bind(
136                        //             Operator::I32Mul,
137                        //             vec![
138                        //                 Expr::Bind(
139                        //                     Operator::I32Const {
140                        //                         value: xs.len() as u32,
141                        //                     },
142                        //                     vec![],
143                        //                 ),
144                        //                 if b.blocks[k].params.iter().map(|a| a.0).collect::<Vec<_>>()
145                        //                     == vec![Type::I32]
146                        //                 {
147                        //                     Expr::Leaf(b.blocks[k].params[0].1)
148                        //                 } else {
149                        //                     Expr::Bind(
150                        //                         Operator::Call { function_index: ta },
151                        //                         once(Expr::Bind(
152                        //                             Operator::Call {
153                        //                                 function_index: stub,
154                        //                             },
155                        //                             vec![],
156                        //                         ))
157                        //                         .chain(
158                        //                             b.blocks[k]
159                        //                                 .params
160                        //                                 .iter()
161                        //                                 .map(|p| Expr::Leaf(p.1)),
162                        //                         )
163                        //                         .collect(),
164                        //                     )
165                        //                 },
166                        //             ],
167                        //         ),
168                        //     ],
169                        // );
170                        // let (a, k) = e.build(m, &mut b, k)?;
171                        let a = {
172                            let a = b.add_op(
173                                k,
174                                Operator::I32Const {
175                                    value: xs.len() as u32,
176                                },
177                                &[],
178                                &[Type::I32],
179                            );
180                            let v = if b.blocks[k].params.iter().map(|a| a.0).collect::<Vec<_>>()
181                                == vec![Type::I32]
182                            {
183                                b.blocks[k].params[0].1
184                            } else {
185                                let a = b.add_op(
186                                    k,
187                                    Operator::StructNew { sig: ts },
188                                    &b.blocks[k].params.iter().map(|a| a.1).collect::<Vec<_>>(),
189                                    &[Type::Heap(WithNullable {
190                                        value: HeapType::Sig { sig_index: ts },
191                                        nullable: true,
192                                    })],
193                                );
194                                b.add_op(k, Operator::Call { function_index: ta }, &[a], &[Type::I32])
195                            };
196                            let a = b.add_op(k, Operator::I32Mul, &[v, a], &[Type::I32]);
197                            let c = b.add_op(
198                                k,
199                                Operator::I32Const { value: x as u32 },
200                                &[],
201                                &[Type::I32],
202                            );
203                            b.add_op(k, Operator::I32Add, &[a, c], &[Type::I32])
204                        };
205                        let args = once(a)
206                            .chain(b.blocks[b.entry].params[1..].iter().map(|a| a.1))
207                            .collect();
208                        b.set_terminator(
209                            k,
210                            portal_pc_waffle::Terminator::ReturnCall { func: f2, args },
211                        );
212                        m.funcs[f] = FuncDecl::Body(fs, fname, b);
213                        continue;
214                    }
215                }
216            }
217        }
218        m.imports.push(i)
219    }
220    m.imports.push(Import {
221        module: format!("pit/{rid}"),
222        name: format!("~{target}"),
223        kind: ImportKind::Func(f2),
224    });
225    let mut b = BTreeMap::new();
226    for x in take(&mut m.exports) {
227        for x2 in xs.iter() {
228            if let Some(a) = x.name.strip_prefix(&format!("pit/{rid}/~{x2}")) {
229                let mut b = b.entry(a.to_owned()).or_insert_with(|| BTreeMap::new());
230                let (e, _) = b
231                    .entry(x2.clone())
232                    .or_insert_with(|| (Func::invalid(), m2.get(a).cloned().unwrap()));
233                if let ExportKind::Func(f) = x.kind {
234                    *e = f;
235                    continue;
236                }
237            }
238        }
239        m.exports.push(x)
240    }
241    for (method, inner) in b.into_iter() {
242        let a = inner
243            .iter()
244            .filter(|a| a.1 .0.is_valid())
245            .next()
246            .context("in getting an instance")?
247            .1;
248        let sig = a.0;
249        let funcs: Vec<_> = xs
250            .iter()
251            .map(|f| inner.get(f).cloned().unwrap_or_default())
252            .collect::<Vec<_>>();
253        let sig = m.funcs[sig].sig();
254        let mut sig = m.signatures[sig].clone();
255        if let SignatureData::Func {
256            params, returns, ..
257        } = &mut sig
258        {
259            *params = once(Type::I32)
260                .chain(params[a.1.len()..].iter().cloned())
261                .collect::<Vec<_>>();
262        }
263        let sig = new_sig(m, sig);
264        let mut b = FunctionBody::new(&m, sig);
265        let k = b.entry;
266        let fl = b.add_op(
267            k,
268            Operator::I32Const {
269                value: funcs.len() as u32,
270            },
271            &[],
272            &[Type::I32],
273        );
274        // let mut e = Expr::Bind(
275        //     Operator::I32DivU,
276        //     vec![
277        //         Expr::Leaf(b.blocks[k].params[0].1),
278        //         Expr::Bind(
279        //             ,
280        //             vec![],
281        //         ),
282        //     ],
283        // );
284        // let (a, k) = e.build(m, &mut b, k)?;
285        let a = b.add_op(
286            k,
287            Operator::I32DivU,
288            &[b.blocks[k].params[0].1, fl],
289            &[Type::I32],
290        );
291        // let mut e = Expr::Bind(
292        //     Operator::I32RemU,
293        //     vec![
294        //         Expr::Leaf(b.blocks[k].params[0].1),
295        //         Expr::Bind(
296        //             Operator::I32Const {
297        //                 value: funcs.len() as u32,
298        //             },
299        //             vec![],
300        //         ),
301        //     ],
302        // );
303        // let (c, k) = e.build(m, &mut b, k)?;
304        let c = b.add_op(
305            k,
306            Operator::I32RemU,
307            &[b.blocks[k].params[0].1, fl],
308            &[Type::I32],
309        );
310        let args = b.blocks[b.entry].params[1..]
311            .iter()
312            .map(|a| a.1)
313            .collect::<Vec<_>>();
314        let blocks = funcs
315            .iter()
316            .map(|(f, t)| {
317                let args = args.clone();
318                let k = b.add_block();
319                if f.is_invalid() {
320                    let rets = b
321                        .rets
322                        .clone()
323                        .into_iter()
324                        .map(|t| match t.clone() {
325                            Type::I32 => b.add_op(k, Operator::I32Const { value: 0 }, &[], &[t]),
326                            Type::I64 => b.add_op(k, Operator::I64Const { value: 0 }, &[], &[t]),
327                            Type::F32 => b.add_op(k, Operator::F32Const { value: 0 }, &[], &[t]),
328                            Type::F64 => b.add_op(k, Operator::F64Const { value: 0 }, &[], &[t]),
329                            Type::V128 => todo!(),
330                            Type::Heap(_) => {
331                                b.add_op(k, Operator::RefNull { ty: t.clone() }, &[], &[t])
332                            }
333                            _ => todo!(),
334                        })
335                        .collect();
336                    b.set_terminator(k, portal_pc_waffle::Terminator::Return { values: rets });
337                } else if *t == vec![Type::I32] {
338                    b.set_terminator(
339                        k,
340                        portal_pc_waffle::Terminator::ReturnCall {
341                            func: *f,
342                            args: once(a).chain(args.into_iter()).collect(),
343                        },
344                    );
345                } else {
346                    let (_, tf, ts, tts) = tc(m, t.clone());
347                    let real = if method == ".drop" {
348                        let c = b.add_op(k, Operator::Call { function_index: tf }, &[a], &t);
349                        t.iter()
350                            .cloned()
351                            // .zip(ts.into_iter())
352                            .enumerate()
353                            .map(|(w, u)| {
354                                b.add_op(k, Operator::StructGet { sig: ts, idx: w }, &[a], &[u])
355                            })
356                            .collect::<Vec<_>>()
357                    } else {
358                        let a = b.add_op(
359                            k,
360                            Operator::TableGet { table_index: tts },
361                            &[a],
362                            &[Type::Heap(WithNullable {
363                                value: portal_pc_waffle::HeapType::Sig { sig_index: ts },
364                                nullable: true,
365                            })],
366                        );
367                        t.iter()
368                            .cloned()
369                            // .zip(ts.into_iter())
370                            .enumerate()
371                            .map(|(w, u)| {
372                                b.add_op(k, Operator::StructGet { sig: ts, idx: w }, &[a], &[u])
373                            })
374                            .collect()
375                    };
376                    b.set_terminator(
377                        k,
378                        portal_pc_waffle::Terminator::ReturnCall {
379                            func: *f,
380                            args: real.into_iter().chain(args.into_iter()).collect(),
381                        },
382                    );
383                    // anyhow::bail!("invalid type");
384                }
385                Ok(BlockTarget {
386                    block: k,
387                    args: vec![],
388                })
389            })
390            .collect::<anyhow::Result<Vec<_>>>()?;
391        b.set_terminator(
392            k,
393            portal_pc_waffle::Terminator::Select {
394                value: c,
395                targets: blocks,
396                default: BlockTarget {
397                    block: b.entry,
398                    args,
399                },
400            },
401        );
402        let f = m.funcs.push(FuncDecl::Body(
403            sig,
404            format!("pit/{rid}/~{target}{method}"),
405            b,
406        ));
407        m.exports.push(Export {
408            name: format!("pit/{rid}/~{target}{method}"),
409            kind: ExportKind::Func(f),
410        });
411    }
412    Ok(())
413}
414pub fn jigger(m: &mut Module, seed: &[u8]) -> anyhow::Result<()> {
415    let mut s = Sha3_256::default();
416    s.update(&m.to_wasm_bytes()?);
417    s.update(seed);
418    let s = s.finalize();
419    for i in m.imports.iter_mut() {
420        if !i.module.starts_with("pit/") {
421            continue;
422        }
423        if let Some(a) = i.name.strip_prefix("~") {
424            let a = format!("{a}-{s:?}");
425            let mut s = Sha3_256::default();
426            s.update(a.as_bytes());
427            let s = s.finalize();
428            let s = hex::encode(s);
429            i.name = format!("~{s}");
430        }
431    }
432    for x in m.exports.iter_mut() {
433        if let Some(a) = x.name.strip_prefix("pit/") {
434            if let Some((b, a)) = a.split_once("/~") {
435                if let Some((a, c)) = a.split_once("/") {
436                    let a = format!("{a}-{s:?}");
437                    let mut s = Sha3_256::default();
438                    s.update(a.as_bytes());
439                    let s = s.finalize();
440                    let s = hex::encode(s);
441                    x.name = format!("pit/{b}/~{s}/{c}");
442                }
443            }
444        }
445    }
446    Ok(())
447}