1use candid::Principal;
2use std::collections::{HashMap, HashSet};
3use walrus::ir::*;
4use walrus::*;
5
6pub struct Config {
7 pub remove_cycles_add: bool,
8 pub filter_cycles_add: bool,
9 pub limit_stable_memory_page: Option<u32>,
10 pub limit_heap_memory_page: Option<u32>,
11 pub playground_canister_id: Option<candid::Principal>,
12}
13
14struct Replacer(HashMap<FunctionId, FunctionId>);
15impl VisitorMut for Replacer {
16 fn visit_instr_mut(&mut self, instr: &mut Instr, _: &mut InstrLocId) {
17 if let Instr::Call(walrus::ir::Call { func }) = instr {
18 if let Some(new_id) = self.0.get(func) {
19 *instr = Call { func: *new_id }.into();
20 }
21 }
22 }
23}
24impl Replacer {
25 fn new() -> Self {
26 Self(HashMap::new())
27 }
28 fn add(&mut self, old: FunctionId, new: FunctionId) {
29 self.0.insert(old, new);
30 }
31}
32
33pub fn limit_resource(m: &mut Module, config: &Config) {
34 let wasm64 = match m.memories.len() {
35 0 => false, 1 => m.memories.get(m.get_memory_id().unwrap()).memory64,
37 _ => panic!("The Canister Wasm module should have at most one memory"),
38 };
39
40 if let Some(limit) = config.limit_heap_memory_page {
41 limit_heap_memory(m, limit);
42 }
43
44 let mut replacer = Replacer::new();
45
46 if config.remove_cycles_add {
47 make_cycles_add(m, &mut replacer, wasm64);
48 make_cycles_add128(m, &mut replacer);
49 make_cycles_burn128(m, &mut replacer, wasm64);
50 }
51
52 if config.filter_cycles_add {
53 let global_id = m.globals.add_local(
57 ValType::I32,
58 true, false, ConstExpr::Value(Value::I32(0)),
61 );
62 make_filter_cycles_add(m, &mut replacer, wasm64, global_id);
64 make_filter_cycles_add128(m, &mut replacer, global_id);
65 make_cycles_burn128(m, &mut replacer, wasm64);
67 make_filter_call_new(m, &mut replacer, wasm64, global_id);
69 }
70
71 if let Some(limit) = config.limit_stable_memory_page {
72 make_stable_grow(m, &mut replacer, wasm64, limit as i32);
73 make_stable64_grow(m, &mut replacer, limit as i64);
74 }
75
76 if let Some(redirect_id) = config.playground_canister_id {
77 make_redirect_call_new(m, &mut replacer, wasm64, redirect_id);
78 }
79
80 let new_ids = replacer.0.values().cloned().collect::<HashSet<_>>();
81 m.funcs.iter_local_mut().for_each(|(id, func)| {
82 if new_ids.contains(&id) {
83 return;
84 }
85 dfs_pre_order_mut(&mut replacer, func, func.entry_block());
86 });
87}
88
89fn limit_heap_memory(m: &mut Module, limit: u32) {
90 if let Ok(memory_id) = m.get_memory_id() {
91 let memory = m.memories.get_mut(memory_id);
92 let limit = limit as u64;
93 if memory.initial > limit {
94 if m.data
102 .iter()
103 .filter_map(|data| {
104 match data.kind {
105 DataKind::Passive => None,
106 DataKind::Active {
107 memory: data_memory_id,
108 offset,
109 } => {
110 if data_memory_id == memory_id {
111 match offset {
112 ConstExpr::Value(Value::I32(offset)) => Some(offset as u64),
113 ConstExpr::Value(Value::I64(offset)) => Some(offset as u64),
114 _ => {
115 None
117 }
118 }
119 } else {
120 None
121 }
122 }
123 }
124 })
125 .all(|offset| offset < limit * 65536)
126 {
127 memory.initial = limit;
128 } else {
129 panic!("Unable to restrict Wasm heap memory to {limit} pages");
130 }
131 }
132 memory.maximum = Some(limit);
133 }
134}
135
136fn make_cycles_add(m: &mut Module, replacer: &mut Replacer, wasm64: bool) {
137 if let Some(old_cycles_add) = get_ic_func_id(m, "call_cycles_add") {
138 if wasm64 {
139 panic!("Wasm64 module should not call `call_cycles_add`");
140 }
141 let mut builder = FunctionBuilder::new(&mut m.types, &[ValType::I64], &[]);
142 let amount = m.locals.add(ValType::I64);
143 builder.func_body().local_get(amount).drop();
144 let new_cycles_add = builder.finish(vec![amount], &mut m.funcs);
145 replacer.add(old_cycles_add, new_cycles_add);
146 }
147}
148
149fn make_filter_cycles_add(
150 m: &mut Module,
151 replacer: &mut Replacer,
152 wasm64: bool,
153 global_id: GlobalId,
154) {
155 if let Some(old_cycles_add) = get_ic_func_id(m, "call_cycles_add") {
156 if wasm64 {
157 panic!("Wasm64 module should not call `call_cycles_add`");
158 }
159 let mut builder = FunctionBuilder::new(&mut m.types, &[ValType::I64], &[]);
160 let amount = m.locals.add(ValType::I64);
161 let mut func = builder.func_body();
162 func.global_get(global_id);
164 func.i32_const(0);
165 func.binop(BinaryOp::I32Ne);
166 func.if_else(
168 None,
169 |then| {
170 then.local_get(amount).drop(); },
172 |otherwise| {
173 otherwise.local_get(amount).call(old_cycles_add); },
175 );
176 let new_cycles_add = builder.finish(vec![amount], &mut m.funcs);
177 replacer.add(old_cycles_add, new_cycles_add);
178 }
179}
180
181fn make_cycles_add128(m: &mut Module, replacer: &mut Replacer) {
182 if let Some(old_cycles_add128) = get_ic_func_id(m, "call_cycles_add128") {
183 let mut builder = FunctionBuilder::new(&mut m.types, &[ValType::I64, ValType::I64], &[]);
184 let high = m.locals.add(ValType::I64);
185 let low = m.locals.add(ValType::I64);
186 builder
187 .func_body()
188 .local_get(high)
189 .local_get(low)
190 .drop()
191 .drop();
192 let new_cycles_add128 = builder.finish(vec![high, low], &mut m.funcs);
193 replacer.add(old_cycles_add128, new_cycles_add128);
194 }
195}
196
197fn make_filter_cycles_add128(m: &mut Module, replacer: &mut Replacer, global_id: GlobalId) {
198 if let Some(old_cycles_add128) = get_ic_func_id(m, "call_cycles_add128") {
199 let mut builder = FunctionBuilder::new(&mut m.types, &[ValType::I64, ValType::I64], &[]);
200 let high = m.locals.add(ValType::I64);
201 let low = m.locals.add(ValType::I64);
202 let mut func = builder.func_body();
203 func.global_get(global_id);
205 func.i32_const(0);
206 func.binop(BinaryOp::I32Ne);
207 func.if_else(
209 None,
210 |then| {
211 then.local_get(high).local_get(low).drop().drop(); },
213 |otherwise| {
214 otherwise
215 .local_get(high)
216 .local_get(low)
217 .call(old_cycles_add128); },
219 );
220 let new_cycles_add128 = builder.finish(vec![high, low], &mut m.funcs);
221 replacer.add(old_cycles_add128, new_cycles_add128);
222 }
223}
224
225fn make_cycles_burn128(m: &mut Module, replacer: &mut Replacer, wasm64: bool) {
226 if let Some(older_cycles_burn128) = get_ic_func_id(m, "cycles_burn128") {
227 let dst_type = match wasm64 {
228 true => ValType::I64,
229 false => ValType::I32,
230 };
231 let mut builder =
232 FunctionBuilder::new(&mut m.types, &[ValType::I64, ValType::I64, dst_type], &[]);
233 let high = m.locals.add(ValType::I64);
234 let low = m.locals.add(ValType::I64);
235 let dst = m.locals.add(dst_type);
236 builder
237 .func_body()
238 .local_get(high)
239 .local_get(low)
240 .local_get(dst)
241 .drop()
242 .drop()
243 .drop();
244 let new_cycles_burn128 = builder.finish(vec![high, low, dst], &mut m.funcs);
245 replacer.add(older_cycles_burn128, new_cycles_burn128);
246 }
247}
248
249fn make_stable_grow(m: &mut Module, replacer: &mut Replacer, wasm64: bool, limit: i32) {
250 if let Some(old_stable_grow) = get_ic_func_id(m, "stable_grow") {
251 if wasm64 {
252 panic!("Wasm64 module should not call `stable_grow`");
253 }
254 let stable_size = get_ic_func_id(m, "stable_size").unwrap();
256 let mut builder = FunctionBuilder::new(&mut m.types, &[ValType::I32], &[ValType::I32]);
257 let requested = m.locals.add(ValType::I32);
258 builder
259 .func_body()
260 .call(stable_size)
261 .local_get(requested)
262 .binop(BinaryOp::I32Add)
263 .i32_const(limit)
264 .binop(BinaryOp::I32GtU)
265 .if_else(
266 Some(ValType::I32),
267 |then| {
268 then.i32_const(-1);
269 },
270 |else_| {
271 else_.local_get(requested).call(old_stable_grow);
272 },
273 );
274 let new_stable_grow = builder.finish(vec![requested], &mut m.funcs);
275 replacer.add(old_stable_grow, new_stable_grow);
276 }
277}
278
279fn make_stable64_grow(m: &mut Module, replacer: &mut Replacer, limit: i64) {
280 if let Some(old_stable64_grow) = get_ic_func_id(m, "stable64_grow") {
281 let stable64_size = get_ic_func_id(m, "stable64_size").unwrap();
283 let mut builder = FunctionBuilder::new(&mut m.types, &[ValType::I64], &[ValType::I64]);
284 let requested = m.locals.add(ValType::I64);
285 builder
286 .func_body()
287 .call(stable64_size)
288 .local_get(requested)
289 .binop(BinaryOp::I64Add)
290 .i64_const(limit)
291 .binop(BinaryOp::I64GtU)
292 .if_else(
293 Some(ValType::I64),
294 |then| {
295 then.i64_const(-1);
296 },
297 |else_| {
298 else_.local_get(requested).call(old_stable64_grow);
299 },
300 );
301 let new_stable64_grow = builder.finish(vec![requested], &mut m.funcs);
302 replacer.add(old_stable64_grow, new_stable64_grow);
303 }
304}
305
306#[allow(clippy::too_many_arguments)]
307fn check_list(
308 memory: MemoryId,
309 checks: &mut InstrSeqBuilder,
310 no_redirect: LocalId,
311 size: LocalId,
312 src: LocalId,
313 is_rename: Option<LocalId>,
314 list: &Vec<&[u8]>,
315 wasm64: bool,
316) {
317 let checks_id = checks.id();
318 for bytes in list {
319 checks.block(None, |list_check| {
320 let list_check_id = list_check.id();
321 list_check.local_get(size);
323 match wasm64 {
324 true => {
325 list_check
326 .i64_const(bytes.len() as i64)
327 .binop(BinaryOp::I64Ne);
328 }
329 false => {
330 list_check
331 .i32_const(bytes.len() as i32)
332 .binop(BinaryOp::I32Ne);
333 }
334 }
335 list_check.br_if(list_check_id);
336 for i in 0..bytes.len() {
338 list_check.local_get(src).load(
339 memory,
340 match wasm64 {
341 true => LoadKind::I64_8 {
342 kind: ExtendedLoad::ZeroExtend,
343 },
344 false => LoadKind::I32_8 {
345 kind: ExtendedLoad::ZeroExtend,
346 },
347 },
348 MemArg {
349 offset: i as u32,
350 align: 1,
351 },
352 );
353 }
354 for byte in bytes.iter().rev() {
355 match wasm64 {
356 true => {
357 list_check.i64_const(*byte as i64).binop(BinaryOp::I64Ne);
358 }
359 false => {
360 list_check.i32_const(*byte as i32).binop(BinaryOp::I32Ne);
361 }
362 }
363 list_check.br_if(list_check_id);
364 }
365 if let Some(is_rename) = is_rename {
367 if bytes == b"http_request" {
368 list_check.i32_const(1).local_set(is_rename);
369 } else {
370 list_check.i32_const(0).local_set(is_rename);
371 }
372 }
373 list_check.i32_const(0).local_set(no_redirect).br(checks_id);
374 });
375 }
376 checks.i32_const(1).local_set(no_redirect);
378}
379
/// Wraps `ic0.call_new`, if imported, so that selected calls are redirected to
/// the canister `redirect_id`.
///
/// The generated wrapper takes the same eight pointer-typed arguments as
/// `call_new`. It checks whether the callee bytes match one of
/// `redirect_canisters` and the method name matches one of
/// `controller_function_names`:
/// - If either check fails, it forwards all arguments unchanged.
/// - If both match, it saves the first `redirect_id.len()` bytes of linear memory
///   into locals, writes `redirect_id` at address 0, and calls `call_new` with
///   callee `(0, redirect_id.len())` and the original remaining arguments. It then
///   restores the saved bytes.
/// - If the matched name was `http_request`, the first byte of the caller's name
///   buffer is overwritten with `_` before the call, and this is NOT restored.
///
/// # Panics
/// Panics if `m.get_memory_id()` fails.
///
/// NOTE(review): if the callee or name buffer overlaps addresses
/// `0..redirect_id.len()`, it is clobbered by the temporary callee bytes before
/// `call_new` reads it. Confirm callers never place these buffers there.
fn make_redirect_call_new(
    m: &mut Module,
    replacer: &mut Replacer,
    wasm64: bool,
    redirect_id: Principal,
) {
    if let Some(old_call_new) = get_ic_func_id(m, "call_new") {
        // Addresses and sizes passed to `call_new` use the memory's index type.
        let pointer_type = match wasm64 {
            true => ValType::I64,
            false => ValType::I32,
        };
        let redirect_id = redirect_id.as_slice();
        // Parameters of the wrapper, mirroring `call_new`'s eight arguments.
        let callee_src = m.locals.add(pointer_type);
        let callee_size = m.locals.add(pointer_type);
        let name_src = m.locals.add(pointer_type);
        let name_size = m.locals.add(pointer_type);
        let arg5 = m.locals.add(pointer_type);
        let arg6 = m.locals.add(pointer_type);
        let arg7 = m.locals.add(pointer_type);
        let arg8 = m.locals.add(pointer_type);

        let memory = m
            .get_memory_id()
            .expect("Canister Wasm module should have only one memory");

        // 1 = forward unchanged, 0 = redirect (set by `check_list`).
        let no_redirect = m.locals.add(ValType::I32);
        // 1 when the matched method name was `http_request`.
        let is_rename = m.locals.add(ValType::I32);
        // One local per byte of linear memory temporarily overwritten at address 0.
        let mut memory_backup = Vec::new();
        for _ in 0..redirect_id.len() {
            memory_backup.push(m.locals.add(pointer_type));
        }
        // Callees eligible for redirection: the empty principal (presumably the
        // management canister, `aaaaa-aa`) and one fixed canister (presumably the
        // EVM RPC canister, given the `eth_*` names below — confirm).
        let redirect_canisters = [
            Principal::from_slice(&[]),
            Principal::from_text("7hfb6-caaaa-aaaar-qadga-cai").unwrap(),
        ];

        // Method names that trigger redirection. `_ttp_request` is listed
        // presumably because the `http_request` rename below is never undone, so a
        // reused name buffer reads `_ttp_request` on later calls.
        let controller_function_names = [
            "create_canister",
            "update_settings",
            "install_code",
            "uninstall_code",
            "canister_status",
            "stop_canister",
            "start_canister",
            "delete_canister",
            "list_canister_snapshots",
            "take_canister_snapshot",
            "load_canister_snapshot",
            "delete_canister_snapshot",
            "sign_with_ecdsa",
            "sign_with_schnorr",
            "http_request",
            "_ttp_request",
            "eth_call",
            "eth_feeHistory",
            "eth_getBlockByNumber",
            "eth_getLogs",
            "eth_getTransactionCount",
            "eth_getTransactionReceipt",
            "eth_sendRawTransaction",
            "request",
        ];

        let mut builder = FunctionBuilder::new(
            &mut m.types,
            &[
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
            ],
            &[],
        );

        builder
            .func_body()
            .block(None, |checks| {
                let checks_id = checks.id();
                checks
                    // Callee check: sets `no_redirect` to 0 iff the callee is listed.
                    .block(None, |id_check| {
                        check_list(
                            memory,
                            id_check,
                            no_redirect,
                            callee_size,
                            callee_src,
                            None,
                            &redirect_canisters
                                .iter()
                                .map(|p| p.as_slice())
                                .collect::<Vec<_>>(),
                            wasm64,
                        );
                    })
                    // Unlisted callee: skip the name check and forward unchanged.
                    .local_get(no_redirect)
                    .br_if(checks_id);
                // Name check: sets `no_redirect` (and `is_rename`) for listed names.
                check_list(
                    memory,
                    checks,
                    no_redirect,
                    name_size,
                    name_src,
                    Some(is_rename),
                    &controller_function_names
                        .iter()
                        .map(|s| s.as_bytes())
                        .collect::<Vec<_>>(),
                    wasm64,
                );
            })
            .local_get(no_redirect)
            .if_else(
                None,
                // No redirect: forward the original arguments untouched.
                |block| {
                    block
                        .local_get(callee_src)
                        .local_get(callee_size)
                        .local_get(name_src)
                        .local_get(name_size)
                        .local_get(arg5)
                        .local_get(arg6)
                        .local_get(arg7)
                        .local_get(arg8)
                        .call(old_call_new);
                },
                // Redirect path.
                |block| {
                    // Save the bytes at addresses 0..len that are about to be overwritten.
                    for (address, backup_var) in memory_backup.iter().enumerate() {
                        match wasm64 {
                            true => {
                                block
                                    .i64_const(address as i64)
                                    .load(
                                        memory,
                                        LoadKind::I64_8 {
                                            kind: ExtendedLoad::ZeroExtend,
                                        },
                                        MemArg {
                                            offset: 0,
                                            align: 1,
                                        },
                                    )
                                    .local_set(*backup_var);
                            }
                            false => {
                                block
                                    .i32_const(address as i32)
                                    .load(
                                        memory,
                                        LoadKind::I32_8 {
                                            kind: ExtendedLoad::ZeroExtend,
                                        },
                                        MemArg {
                                            offset: 0,
                                            align: 1,
                                        },
                                    )
                                    .local_set(*backup_var);
                            }
                        }
                    }

                    // Write the redirect canister id at address 0 to serve as the new callee.
                    for (address, byte) in redirect_id.iter().enumerate() {
                        match wasm64 {
                            true => {
                                block
                                    .i64_const(address as i64)
                                    .i64_const(*byte as i64)
                                    .store(
                                        memory,
                                        StoreKind::I64_8 { atomic: false },
                                        MemArg {
                                            offset: 0,
                                            align: 1,
                                        },
                                    );
                            }
                            false => {
                                block
                                    .i32_const(address as i32)
                                    .i32_const(*byte as i32)
                                    .store(
                                        memory,
                                        StoreKind::I32_8 { atomic: false },
                                        MemArg {
                                            offset: 0,
                                            align: 1,
                                        },
                                    );
                            }
                        }
                    }
                    // `http_request` -> `_ttp_request`: overwrite the first name byte in
                    // the caller's buffer. This change is not restored afterwards.
                    block.local_get(is_rename).if_else(
                        None,
                        |then| match wasm64 {
                            true => {
                                then.local_get(name_src).i64_const('_' as i64).store(
                                    memory,
                                    StoreKind::I64_8 { atomic: false },
                                    MemArg {
                                        offset: 0,
                                        align: 1,
                                    },
                                );
                            }
                            false => {
                                then.local_get(name_src).i32_const('_' as i32).store(
                                    memory,
                                    StoreKind::I32_8 { atomic: false },
                                    MemArg {
                                        offset: 0,
                                        align: 1,
                                    },
                                );
                            }
                        },
                        |_| {},
                    );
                    // New callee arguments: (address 0, length of redirect_id).
                    match wasm64 {
                        true => {
                            block.i64_const(0).i64_const(redirect_id.len() as i64);
                        }
                        false => {
                            block.i32_const(0).i32_const(redirect_id.len() as i32);
                        }
                    }

                    block
                        .local_get(name_src)
                        .local_get(name_size)
                        .local_get(arg5)
                        .local_get(arg6)
                        .local_get(arg7)
                        .local_get(arg8)
                        .call(old_call_new);

                    // Restore the bytes saved before the callee id was written.
                    for (address, byte) in memory_backup.iter().enumerate() {
                        match wasm64 {
                            true => {
                                block.i64_const(address as i64).local_get(*byte).store(
                                    memory,
                                    StoreKind::I64_8 { atomic: false },
                                    MemArg {
                                        offset: 0,
                                        align: 1,
                                    },
                                );
                            }
                            false => {
                                block.i32_const(address as i32).local_get(*byte).store(
                                    memory,
                                    StoreKind::I32_8 { atomic: false },
                                    MemArg {
                                        offset: 0,
                                        align: 1,
                                    },
                                );
                            }
                        }
                    }
                },
            );
        let new_call_new = builder.finish(
            vec![
                callee_src,
                callee_size,
                name_src,
                name_size,
                arg5,
                arg6,
                arg7,
                arg8,
            ],
            &mut m.funcs,
        );
        replacer.add(old_call_new, new_call_new);
    }
}
674
/// Wraps `ic0.call_new`, if imported, so that each new call updates the i32
/// global `global_id`, which the filtered `call_cycles_add*` wrappers consult.
///
/// The global is set to `0` (cycles forwarded) only when both conditions hold:
/// - the callee is in `allowed_canisters`;
/// - the method name is NOT in `forbidden_function_names`.
///
/// In every other case it is set to `1` (cycles dropped). The original
/// `call_new` is then called with all arguments unchanged.
///
/// # Panics
/// Panics if `m.get_memory_id()` fails.
fn make_filter_call_new(
    m: &mut Module,
    replacer: &mut Replacer,
    wasm64: bool,
    global_id: GlobalId,
) {
    if let Some(old_call_new) = get_ic_func_id(m, "call_new") {
        // Addresses and sizes passed to `call_new` use the memory's index type.
        let pointer_type = match wasm64 {
            true => ValType::I64,
            false => ValType::I32,
        };
        // Parameters of the wrapper, mirroring `call_new`'s eight arguments.
        let callee_src = m.locals.add(pointer_type);
        let callee_size = m.locals.add(pointer_type);
        let name_src = m.locals.add(pointer_type);
        let name_size = m.locals.add(pointer_type);
        let arg5 = m.locals.add(pointer_type);
        let arg6 = m.locals.add(pointer_type);
        let arg7 = m.locals.add(pointer_type);
        let arg8 = m.locals.add(pointer_type);

        let memory = m
            .get_memory_id()
            .expect("Canister Wasm module should have only one memory");

        // 0 when the callee is in `allowed_canisters`, 1 otherwise (set by `check_list`).
        let not_allowed_canister = m.locals.add(ValType::I32);
        // 0 when the name is forbidden, 1 otherwise. It stays at the Wasm
        // local default of 0 when the callee check bails out early.
        let allow_cycles = m.locals.add(ValType::I32);

        // Callees that may receive cycles: the empty principal (presumably the
        // management canister) and one fixed canister id.
        let allowed_canisters = [
            Principal::from_slice(&[]),
            Principal::from_text("7hfb6-caaaa-aaaar-qadga-cai").unwrap(),
        ];
        // Methods on allowed callees that must still not receive cycles.
        let forbidden_function_names = ["create_canister", "deposit_cycles"];

        let mut builder = FunctionBuilder::new(
            &mut m.types,
            &[
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
                pointer_type,
            ],
            &[],
        );

        builder
            .func_body()
            .block(None, |checks| {
                let checks_id = checks.id();
                checks
                    .block(None, |id_check| {
                        check_list(
                            memory,
                            id_check,
                            not_allowed_canister,
                            callee_size,
                            callee_src,
                            None,
                            &allowed_canisters
                                .iter()
                                .map(|p| p.as_slice())
                                .collect::<Vec<_>>(),
                            wasm64,
                        );
                    })
                    // Callee not allowed: skip the name check (allow_cycles stays 0).
                    .local_get(not_allowed_canister)
                    .br_if(checks_id);
                check_list(
                    memory,
                    checks,
                    allow_cycles,
                    name_size,
                    name_src,
                    None,
                    &forbidden_function_names
                        .iter()
                        .map(|s| s.as_bytes())
                        .collect::<Vec<_>>(),
                    wasm64,
                );
            })
            .local_get(allow_cycles)
            .if_else(
                None,
                // Cycles permitted for this call: clear the filter flag.
                |block| {
                    block.i32_const(0).global_set(global_id);
                },
                // Cycles not permitted: set the filter flag.
                |block| {
                    block.i32_const(1).global_set(global_id);
                },
            )
            // Always forward the call itself unchanged.
            .local_get(callee_src)
            .local_get(callee_size)
            .local_get(name_src)
            .local_get(name_size)
            .local_get(arg5)
            .local_get(arg6)
            .local_get(arg7)
            .local_get(arg8)
            .call(old_call_new);
        let new_call_new = builder.finish(
            vec![
                callee_src,
                callee_size,
                name_src,
                name_size,
                arg5,
                arg6,
                arg7,
                arg8,
            ],
            &mut m.funcs,
        );
        replacer.add(old_call_new, new_call_new);
    }
}
806
807fn get_ic_func_id(m: &mut Module, method: &str) -> Option<FunctionId> {
811 match m.imports.find("ic0", method) {
812 Some(id) => match m.imports.get(id).kind {
813 ImportKind::Function(func_id) => Some(func_id),
814 _ => unreachable!(),
815 },
816 None => {
817 let ty = match method {
818 "stable_size" => Some(m.types.add(&[], &[ValType::I32])),
819 "stable64_size" => Some(m.types.add(&[], &[ValType::I64])),
820 _ => None,
821 };
822 match ty {
823 Some(ty) => {
824 let func_id = m.add_import_func("ic0", method, ty).0;
825 Some(func_id)
826 }
827 None => None,
828 }
829 }
830 }
831}