1use indexmap::IndexMap;
5use itertools::{Either, Itertools};
6use rustc_hash::{FxHashMap, FxHashSet};
7use sway_types::{FxIndexMap, FxIndexSet};
8
9use crate::{
10 get_gep_symbol, get_loaded_symbols, get_referred_symbol, get_referred_symbols,
11 get_stored_symbols, memory_utils, AnalysisResults, Block, Context, EscapedSymbols,
12 FuelVmInstruction, Function, InstOp, Instruction, InstructionInserter, IrError, LocalVar, Pass,
13 PassMutability, ReferredSymbols, ScopedPass, Symbol, Type, Value, ValueDatum,
14 ESCAPED_SYMBOLS_NAME,
15};
16
17pub const MEMCPYOPT_NAME: &str = "memcpyopt";
18
19pub fn create_memcpyopt_pass() -> Pass {
20 Pass {
21 name: MEMCPYOPT_NAME,
22 descr: "Optimizations related to MemCopy instructions",
23 deps: vec![ESCAPED_SYMBOLS_NAME],
24 runner: ScopedPass::FunctionPass(PassMutability::Transform(mem_copy_opt)),
25 }
26}
27
28pub fn mem_copy_opt(
29 context: &mut Context,
30 analyses: &AnalysisResults,
31 function: Function,
32) -> Result<bool, IrError> {
33 let mut modified = false;
34 modified |= local_copy_prop_prememcpy(context, analyses, function)?;
35 modified |= load_store_to_memcopy(context, function)?;
36 modified |= local_copy_prop(context, analyses, function)?;
37
38 Ok(modified)
39}
40
41fn local_copy_prop_prememcpy(
42 context: &mut Context,
43 analyses: &AnalysisResults,
44 function: Function,
45) -> Result<bool, IrError> {
46 struct InstInfo {
47 block: Block,
49 pos: usize,
51 }
52
53 let escaped_symbols = match analyses.get_analysis_result(function) {
56 EscapedSymbols::Complete(syms) => syms,
57 EscapedSymbols::Incomplete(_) => return Ok(false),
58 };
59
60 let mut loads_map = FxHashMap::<Symbol, Vec<Value>>::default();
62 let mut stores_map = FxHashMap::<Symbol, Vec<Value>>::default();
64 let mut instr_info_map = FxHashMap::<Value, InstInfo>::default();
66
67 for (pos, (block, inst)) in function.instruction_iter(context).enumerate() {
68 let info = || InstInfo { block, pos };
69 let inst_e = inst.get_instruction(context).unwrap();
70 match inst_e {
71 Instruction {
72 op: InstOp::Load(src_val_ptr),
73 ..
74 } => {
75 if let Some(local) = get_referred_symbol(context, *src_val_ptr) {
76 loads_map
77 .entry(local)
78 .and_modify(|loads| loads.push(inst))
79 .or_insert(vec![inst]);
80 instr_info_map.insert(inst, info());
81 }
82 }
83 Instruction {
84 op: InstOp::Store { dst_val_ptr, .. },
85 ..
86 } => {
87 if let Some(local) = get_referred_symbol(context, *dst_val_ptr) {
88 stores_map
89 .entry(local)
90 .and_modify(|stores| stores.push(inst))
91 .or_insert(vec![inst]);
92 instr_info_map.insert(inst, info());
93 }
94 }
95 _ => (),
96 }
97 }
98
99 let mut to_delete = FxHashSet::<Value>::default();
100 let candidates: FxHashMap<Symbol, Symbol> = function
106 .instruction_iter(context)
107 .enumerate()
108 .filter_map(|(pos, (block, instr_val))| {
109 instr_val
112 .get_instruction(context)
113 .and_then(|instr| {
114 if let Instruction {
116 op:
117 InstOp::Store {
118 dst_val_ptr,
119 stored_val,
120 },
121 ..
122 } = instr
123 {
124 get_gep_symbol(context, *dst_val_ptr).and_then(|dst_local| {
125 stored_val
126 .get_instruction(context)
127 .map(|src_instr| (src_instr, stored_val, dst_local))
128 })
129 } else {
130 None
131 }
132 })
133 .and_then(|(src_instr, stored_val, dst_local)| {
134 if let Instruction {
136 op: InstOp::Load(src_val_ptr),
137 ..
138 } = src_instr
139 {
140 get_gep_symbol(context, *src_val_ptr)
141 .map(|src_local| (stored_val, dst_local, src_local))
142 } else {
143 None
144 }
145 })
146 .and_then(|(src_load, dst_local, src_local)| {
147 let (temp_empty1, temp_empty2, temp_empty3) = (vec![], vec![], vec![]);
151 let dst_local_stores = stores_map.get(&dst_local).unwrap_or(&temp_empty1);
152 let src_local_stores = stores_map.get(&src_local).unwrap_or(&temp_empty2);
153 let dst_local_loads = loads_map.get(&dst_local).unwrap_or(&temp_empty3);
154 if dst_local_stores.len() != 1 || dst_local_stores[0] != instr_val
156 ||
157 !src_local_stores.iter().all(|store_val|{
159 let instr_info = instr_info_map.get(store_val).unwrap();
160 let src_load_info = instr_info_map.get(src_load).unwrap();
161 instr_info.block == block && instr_info.pos < src_load_info.pos
162 })
163 ||
164 !dst_local_loads.iter().all(|load_val| {
166 let instr_info = instr_info_map.get(load_val).unwrap();
167 instr_info.block == block && instr_info.pos > pos
168 })
169 || escaped_symbols.contains(&dst_local)
171 || escaped_symbols.contains(&src_local)
172 || dst_local.get_type(context) != src_local.get_type(context)
174 || matches!(dst_local, Symbol::Arg(_))
176 {
177 None
178 } else {
179 to_delete.insert(instr_val);
180 Some((dst_local, src_local))
181 }
182 })
183 })
184 .collect();
185
186 fn get_replace_with(candidates: &FxHashMap<Symbol, Symbol>, local: &Symbol) -> Option<Symbol> {
190 candidates
191 .get(local)
192 .map(|replace_with| get_replace_with(candidates, replace_with).unwrap_or(*replace_with))
193 }
194
195 enum ReplaceWith {
198 InPlaceLocal(LocalVar),
199 Value(Value),
200 }
201
202 let replaces: Vec<_> = function
206 .instruction_iter(context)
207 .filter_map(|(_block, value)| match value.get_instruction(context) {
208 Some(Instruction {
209 op: InstOp::GetLocal(local),
210 ..
211 }) => get_replace_with(&candidates, &Symbol::Local(*local)).map(|replace_with| {
212 (
213 value,
214 match replace_with {
215 Symbol::Local(local) => ReplaceWith::InPlaceLocal(local),
216 Symbol::Arg(ba) => {
217 ReplaceWith::Value(ba.block.get_arg(context, ba.idx).unwrap())
218 }
219 },
220 )
221 }),
222 _ => None,
223 })
224 .collect();
225
226 let mut value_replace = FxHashMap::<Value, Value>::default();
227 for (value, replace_with) in replaces.into_iter() {
228 match replace_with {
229 ReplaceWith::InPlaceLocal(replacement_var) => {
230 let Some(&Instruction {
231 op: InstOp::GetLocal(redundant_var),
232 parent,
233 }) = value.get_instruction(context)
234 else {
235 panic!("earlier match now fails");
236 };
237 if redundant_var.is_mutable(context) {
238 replacement_var.set_mutable(context, true);
239 }
240 value.replace(
241 context,
242 ValueDatum::Instruction(Instruction {
243 op: InstOp::GetLocal(replacement_var),
244 parent,
245 }),
246 )
247 }
248 ReplaceWith::Value(replace_with) => {
249 value_replace.insert(value, replace_with);
250 }
251 }
252 }
253 function.replace_values(context, &value_replace, None);
254
255 let blocks: Vec<Block> = function.block_iter(context).collect();
257 for block in blocks {
258 block.remove_instructions(context, |value| to_delete.contains(&value));
259 }
260 Ok(true)
261}
262
263fn deconstruct_memcpy(context: &Context, inst: Value) -> Option<(Value, Value, u64)> {
265 match inst.get_instruction(context).unwrap() {
266 Instruction {
267 op:
268 InstOp::MemCopyBytes {
269 dst_val_ptr,
270 src_val_ptr,
271 byte_len,
272 },
273 ..
274 } => Some((*dst_val_ptr, *src_val_ptr, *byte_len)),
275 Instruction {
276 op:
277 InstOp::MemCopyVal {
278 dst_val_ptr,
279 src_val_ptr,
280 },
281 ..
282 } => Some((
283 *dst_val_ptr,
284 *src_val_ptr,
285 memory_utils::pointee_size(context, *dst_val_ptr),
286 )),
287 _ => None,
288 }
289}
290
/// Block-local copy propagation through memcpys.
///
/// Walks each block, tracking the memcpys (`mem_copy_bytes` / `mem_copy_val`) whose
/// source and destination are both non-escaping symbols ("available copies"). A later
/// read through a pointer into a copy's destination is redirected to the copy's source.
/// Such reads are the pointer of a `load`, the source of another memcpy, or an immutable
/// call argument.
/// Writes that may alias either end of a copy, and instructions whose effects cannot be
/// analysed, invalidate the affected copies. Each block is re-processed until a round
/// finds no more replacements.
///
/// Returns `Ok(false)` immediately if escape analysis is incomplete; otherwise returns
/// whether any replacement was made.
fn local_copy_prop(
    context: &mut Context,
    analyses: &AnalysisResults,
    function: Function,
) -> Result<bool, IrError> {
    // The analysis below relies on a complete escape analysis.
    let escaped_symbols = match analyses.get_analysis_result(function) {
        EscapedSymbols::Complete(syms) => syms,
        EscapedSymbols::Incomplete(_) => return Ok(false),
    };

    // Copies that are still valid at the current point of the block walk.
    let mut available_copies: FxHashSet<Value>;
    // Available copies indexed by the symbol they read from.
    let mut src_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;
    // Available copies indexed by the symbol they write to.
    let mut dest_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;

    // Invalidates every available copy whose source or destination may alias the `len`
    // bytes at `value`. If the symbols behind `value` are unknown, all copies are dropped.
    fn kill_defined_symbol(
        context: &Context,
        value: Value,
        len: u64,
        available_copies: &mut FxHashSet<Value>,
        src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
        dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
    ) {
        match get_referred_symbols(context, value) {
            ReferredSymbols::Complete(rs) => {
                for sym in rs {
                    if let Some(copies) = src_to_copies.get_mut(&sym) {
                        for copy in &*copies {
                            let (_, src_ptr, copy_size) = deconstruct_memcpy(context, *copy)
                                .expect("Expected copy instruction");
                            if memory_utils::may_alias(context, value, len, src_ptr, copy_size) {
                                available_copies.remove(copy);
                            }
                        }
                    }
                    if let Some(copies) = dest_to_copies.get_mut(&sym) {
                        for copy in &*copies {
                            let (dest_ptr, copy_size) = match copy.get_instruction(context).unwrap()
                            {
                                Instruction {
                                    op:
                                        InstOp::MemCopyBytes {
                                            dst_val_ptr,
                                            src_val_ptr: _,
                                            byte_len,
                                        },
                                    ..
                                } => (*dst_val_ptr, *byte_len),
                                Instruction {
                                    op:
                                        InstOp::MemCopyVal {
                                            dst_val_ptr,
                                            src_val_ptr: _,
                                        },
                                    ..
                                } => (
                                    *dst_val_ptr,
                                    memory_utils::pointee_size(context, *dst_val_ptr),
                                ),
                                _ => panic!("Unexpected copy instruction"),
                            };
                            if memory_utils::may_alias(context, value, len, dest_ptr, copy_size) {
                                available_copies.remove(copy);
                            }
                        }
                    }
                }
                // Drop killed copies from both indices, and drop symbols left with none.
                src_to_copies.retain(|_, copies| {
                    copies.retain(|copy| available_copies.contains(copy));
                    !copies.is_empty()
                });
                dest_to_copies.retain(|_, copies| {
                    copies.retain(|copy| available_copies.contains(copy));
                    !copies.is_empty()
                });
            }
            ReferredSymbols::Incomplete(_) => {
                // Unknown write target: nothing can be trusted any more.
                available_copies.clear();
                src_to_copies.clear();
                dest_to_copies.clear();
            }
        }
    }

    // Records `copy_inst` as available when both pointers resolve to a symbol via
    // `get_gep_symbol` and neither symbol escapes.
    #[allow(clippy::too_many_arguments)]
    fn gen_new_copy(
        context: &Context,
        escaped_symbols: &FxHashSet<Symbol>,
        copy_inst: Value,
        dst_val_ptr: Value,
        src_val_ptr: Value,
        available_copies: &mut FxHashSet<Value>,
        src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
        dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
    ) {
        if let (Some(dst_sym), Some(src_sym)) = (
            get_gep_symbol(context, dst_val_ptr),
            get_gep_symbol(context, src_val_ptr),
        ) {
            if escaped_symbols.contains(&dst_sym) || escaped_symbols.contains(&src_sym) {
                return;
            }
            dest_to_copies
                .entry(dst_sym)
                .and_modify(|set| {
                    set.insert(copy_inst);
                })
                .or_insert([copy_inst].into_iter().collect());
            src_to_copies
                .entry(src_sym)
                .and_modify(|set| {
                    set.insert(copy_inst);
                })
                .or_insert([copy_inst].into_iter().collect());
            available_copies.insert(copy_inst);
        }
    }

    // A new GEP to be materialised: `base` symbol, result pointer type and indices.
    struct ReplGep {
        base: Symbol,
        elem_ptr_ty: Type,
        indices: Vec<Value>,
    }
    // What a redirected pointer is replaced with: an existing value, or a new GEP.
    enum Replacement {
        OldGep(Value),
        NewGep(ReplGep),
    }

    // Tries to redirect a read of `src_val_ptr` (by instruction `inst`) to the source of an
    // available copy into the same, non-escaping symbol. Records the replacement in
    // `replacements` (keyed by `inst`) and returns `true` on success.
    fn process_load(
        context: &Context,
        escaped_symbols: &FxHashSet<Symbol>,
        inst: Value,
        src_val_ptr: Value,
        dest_to_copies: &FxIndexMap<Symbol, FxIndexSet<Value>>,
        replacements: &mut FxHashMap<Value, (Value, Replacement)>,
    ) -> bool {
        if let Some(src_sym) = get_referred_symbol(context, src_val_ptr) {
            if escaped_symbols.contains(&src_sym) {
                return false;
            }
            for memcpy in dest_to_copies
                .get(&src_sym)
                .iter()
                .flat_map(|set| set.iter())
            {
                let (dst_ptr_memcpy, src_ptr_memcpy, copy_len) =
                    deconstruct_memcpy(context, *memcpy).expect("Expected copy instruction");
                // The read covers exactly the copied region: read straight from the copy's
                // source pointer, provided the pointer types agree.
                if memory_utils::must_alias(
                    context,
                    src_val_ptr,
                    memory_utils::pointee_size(context, src_val_ptr),
                    dst_ptr_memcpy,
                    copy_len,
                ) {
                    if src_val_ptr.get_type(context) == src_ptr_memcpy.get_type(context) {
                        replacements
                            .insert(inst, (src_val_ptr, Replacement::OldGep(src_ptr_memcpy)));
                        return true;
                    }
                } else {
                    // Otherwise, if the copy transfers a whole symbol between two symbols of
                    // the same pointee type, rebuild the read's access path (the indices
                    // from `combine_indices`) on top of the copy's source symbol.
                    if let (Some(memcpy_src_sym), Some(memcpy_dst_sym), Some(new_indices)) = (
                        get_gep_symbol(context, src_ptr_memcpy),
                        get_gep_symbol(context, dst_ptr_memcpy),
                        memory_utils::combine_indices(context, src_val_ptr),
                    ) {
                        let memcpy_src_sym_type = memcpy_src_sym
                            .get_type(context)
                            .get_pointee_type(context)
                            .unwrap();
                        let memcpy_dst_sym_type = memcpy_dst_sym
                            .get_type(context)
                            .get_pointee_type(context)
                            .unwrap();
                        if memcpy_src_sym_type == memcpy_dst_sym_type
                            && memcpy_dst_sym_type.size(context).in_bytes() == copy_len
                        {
                            replacements.insert(
                                inst,
                                (
                                    src_val_ptr,
                                    Replacement::NewGep(ReplGep {
                                        base: memcpy_src_sym,
                                        elem_ptr_ty: src_val_ptr.get_type(context).unwrap(),
                                        indices: new_indices,
                                    }),
                                ),
                            );
                            return true;
                        }
                    }
                }
            }
        }

        false
    }

    let mut modified = false;
    for block in function.block_iter(context) {
        // Re-run the block until a round produces no replacement.
        loop {
            // Copies are tracked per block only; start every round from scratch.
            available_copies = FxHashSet::default();
            src_to_copies = IndexMap::default();
            dest_to_copies = IndexMap::default();

            let mut replacements = FxHashMap::default();

            // Treats each pointer argument as possibly written over the largest pointee
            // size among the symbols it refers to. An argument whose symbols are unknown
            // clears all tracked copies.
            fn kill_escape_args(
                context: &Context,
                args: &Vec<Value>,
                available_copies: &mut FxHashSet<Value>,
                src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
                dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
            ) {
                for arg in args {
                    match get_referred_symbols(context, *arg) {
                        ReferredSymbols::Complete(rs) => {
                            let max_size = rs
                                .iter()
                                .filter_map(|sym| {
                                    sym.get_type(context)
                                        .get_pointee_type(context)
                                        .map(|pt| pt.size(context).in_bytes())
                                })
                                .max()
                                .unwrap_or(0);
                            kill_defined_symbol(
                                context,
                                *arg,
                                max_size,
                                available_copies,
                                src_to_copies,
                                dest_to_copies,
                            );
                        }
                        ReferredSymbols::Incomplete(_) => {
                            available_copies.clear();
                            src_to_copies.clear();
                            dest_to_copies.clear();

                            break;
                        }
                    }
                }
            }

            for inst in block.instruction_iter(context) {
                match inst.get_instruction(context).unwrap() {
                    Instruction {
                        op: InstOp::Call(callee, args),
                        ..
                    } => {
                        // Arguments the callee may mutate act as writes; immutable
                        // arguments act as reads and may be redirected.
                        // NOTE(review): `replacements` is keyed by instruction, so only the
                        // last redirectable immutable argument is kept per round; later
                        // rounds of the enclosing `loop` may pick up the others.
                        let (immutable_args, mutable_args): (Vec<_>, Vec<_>) =
                            args.iter().enumerate().partition_map(|(arg_idx, arg)| {
                                if callee.is_arg_immutable(context, arg_idx) {
                                    Either::Left(*arg)
                                } else {
                                    Either::Right(*arg)
                                }
                            });
                        kill_escape_args(
                            context,
                            &mutable_args,
                            &mut available_copies,
                            &mut src_to_copies,
                            &mut dest_to_copies,
                        );
                        for arg in immutable_args {
                            process_load(
                                context,
                                escaped_symbols,
                                inst,
                                arg,
                                &dest_to_copies,
                                &mut replacements,
                            );
                        }
                    }
                    Instruction {
                        op: InstOp::AsmBlock(_, args),
                        ..
                    } => {
                        // Initialised asm arguments are treated as possibly written.
                        let args = args.iter().filter_map(|arg| arg.initializer).collect();
                        kill_escape_args(
                            context,
                            &args,
                            &mut available_copies,
                            &mut src_to_copies,
                            &mut dest_to_copies,
                        );
                    }
                    Instruction {
                        op: InstOp::IntToPtr(_, _),
                        ..
                    } => {
                        // A pointer from an integer may point anywhere: forget everything.
                        available_copies.clear();
                        src_to_copies.clear();
                        dest_to_copies.clear();
                    }
                    Instruction {
                        op: InstOp::Load(src_val_ptr),
                        ..
                    } => {
                        process_load(
                            context,
                            escaped_symbols,
                            inst,
                            *src_val_ptr,
                            &dest_to_copies,
                            &mut replacements,
                        );
                    }
                    Instruction {
                        op: InstOp::MemCopyBytes { .. } | InstOp::MemCopyVal { .. },
                        ..
                    } => {
                        // A memcpy writes its destination, reads its source, and, if its
                        // source was not redirected this round, becomes an available copy.
                        let (dst_val_ptr, src_val_ptr, copy_len) =
                            deconstruct_memcpy(context, inst).expect("Expected copy instruction");
                        kill_defined_symbol(
                            context,
                            dst_val_ptr,
                            copy_len,
                            &mut available_copies,
                            &mut src_to_copies,
                            &mut dest_to_copies,
                        );
                        if !process_load(
                            context,
                            escaped_symbols,
                            inst,
                            src_val_ptr,
                            &dest_to_copies,
                            &mut replacements,
                        ) {
                            gen_new_copy(
                                context,
                                escaped_symbols,
                                inst,
                                dst_val_ptr,
                                src_val_ptr,
                                &mut available_copies,
                                &mut src_to_copies,
                                &mut dest_to_copies,
                            );
                        }
                    }
                    Instruction {
                        op:
                            InstOp::Store {
                                dst_val_ptr,
                                stored_val: _,
                            },
                        ..
                    } => {
                        kill_defined_symbol(
                            context,
                            *dst_val_ptr,
                            memory_utils::pointee_size(context, *dst_val_ptr),
                            &mut available_copies,
                            &mut src_to_copies,
                            &mut dest_to_copies,
                        );
                    }
                    Instruction {
                        op:
                            InstOp::FuelVm(
                                FuelVmInstruction::WideBinaryOp { result, .. }
                                | FuelVmInstruction::WideUnaryOp { result, .. }
                                | FuelVmInstruction::WideModularOp { result, .. }
                                | FuelVmInstruction::StateLoadQuadWord {
                                    load_val: result, ..
                                },
                            ),
                        ..
                    } => {
                        // These instructions write through their result pointer.
                        kill_defined_symbol(
                            context,
                            *result,
                            memory_utils::pointee_size(context, *result),
                            &mut available_copies,
                            &mut src_to_copies,
                            &mut dest_to_copies,
                        );
                    }
                    _ => (),
                }
            }

            if replacements.is_empty() {
                break;
            } else {
                modified = true;
            }

            // Rebuild the block body, materialising any new `get_local`/GEP right before
            // the instruction that uses it, and patching that instruction's pointer.
            let mut new_insts = vec![];
            for inst in block.instruction_iter(context) {
                if let Some(replacement) = replacements.remove(&inst) {
                    let (to_replace, replacement) = match replacement {
                        (to_replace, Replacement::OldGep(v)) => (to_replace, v),
                        (
                            to_replace,
                            Replacement::NewGep(ReplGep {
                                base,
                                elem_ptr_ty,
                                indices,
                            }),
                        ) => {
                            let base = match base {
                                Symbol::Local(local) => {
                                    let base = Value::new_instruction(
                                        context,
                                        block,
                                        InstOp::GetLocal(local),
                                    );
                                    new_insts.push(base);
                                    base
                                }
                                Symbol::Arg(block_arg) => {
                                    block_arg.block.get_arg(context, block_arg.idx).unwrap()
                                }
                            };
                            let v = Value::new_instruction(
                                context,
                                block,
                                InstOp::GetElemPtr {
                                    base,
                                    elem_ptr_ty,
                                    indices,
                                },
                            );
                            new_insts.push(v);
                            (to_replace, v)
                        }
                    };
                    match inst.get_instruction_mut(context) {
                        Some(Instruction {
                            op: InstOp::Load(ref mut src_val_ptr),
                            ..
                        })
                        | Some(Instruction {
                            op:
                                InstOp::MemCopyBytes {
                                    ref mut src_val_ptr,
                                    ..
                                },
                            ..
                        })
                        | Some(Instruction {
                            op:
                                InstOp::MemCopyVal {
                                    ref mut src_val_ptr,
                                    ..
                                },
                            ..
                        }) => {
                            assert!(to_replace == *src_val_ptr);
                            *src_val_ptr = replacement
                        }
                        Some(Instruction {
                            op: InstOp::Call(_callee, args),
                            ..
                        }) => {
                            for arg in args {
                                if *arg == to_replace {
                                    *arg = replacement;
                                }
                            }
                        }
                        _ => panic!("Unexpected instruction type"),
                    }
                }
                new_insts.push(inst);
            }

            block.take_body(context, new_insts);
        }
    }

    Ok(modified)
}
812
813struct Candidate {
814 load_val: Value,
815 store_val: Value,
816 dst_ptr: Value,
817 src_ptr: Value,
818}
819
820enum CandidateKind {
821 ClobberedNoncopyType(Candidate),
826 NonClobbered(Candidate),
827}
828
829fn is_clobbered(
834 context: &Context,
835 start_inst: &Value,
836 end_inst: &Value,
837 no_overlap_ptr: &Value,
838 scrutiny_ptr: &Value,
839) -> bool {
840 let end_block = end_inst.get_instruction(context).unwrap().parent;
841 let entry_block = end_block.get_function(context).get_entry_block(context);
842
843 let mut iter = end_block
844 .instruction_iter(context)
845 .rev()
846 .skip_while(|i| i != end_inst);
847 assert!(iter.next().unwrap() == *end_inst);
848
849 let ReferredSymbols::Complete(scrutiny_symbols) = get_referred_symbols(context, *scrutiny_ptr)
850 else {
851 return true;
852 };
853
854 let ReferredSymbols::Complete(no_overlap_symbols) =
855 get_referred_symbols(context, *no_overlap_ptr)
856 else {
857 return true;
858 };
859
860 if scrutiny_symbols
863 .intersection(&no_overlap_symbols)
864 .next()
865 .is_some()
866 {
867 return true;
868 }
869
870 let mut worklist: Vec<(Block, Box<dyn Iterator<Item = Value>>)> =
873 vec![(end_block, Box::new(iter))];
874 let mut visited = FxHashSet::default();
875 'next_job: while let Some((block, iter)) = worklist.pop() {
876 visited.insert(block);
877 for inst in iter {
878 if inst == *start_inst || inst == *end_inst {
879 continue 'next_job;
881 }
882 let stored_syms = get_stored_symbols(context, inst);
883 if let ReferredSymbols::Complete(syms) = stored_syms {
884 if syms.iter().any(|sym| scrutiny_symbols.contains(sym)) {
885 return true;
886 }
887 } else {
888 return true;
889 }
890 }
891
892 if entry_block == block {
893 if scrutiny_symbols
896 .iter()
897 .any(|sym| matches!(sym, Symbol::Arg(_)))
898 {
899 return true;
900 }
901 }
902
903 for pred in block.pred_iter(context) {
904 if !visited.contains(pred) {
905 worklist.push((
906 *pred,
907 Box::new(pred.instruction_iter(context).rev().skip_while(|_| false)),
908 ));
909 }
910 }
911 }
912
913 false
914}
915
916fn is_copy_type(ty: &Type, context: &Context) -> bool {
918 ty.is_unit(context)
919 || ty.is_never(context)
920 || ty.is_bool(context)
921 || ty.is_ptr(context)
922 || ty.get_uint_width(context).map(|x| x < 256).unwrap_or(false)
923}
924
925fn load_store_to_memcopy(context: &mut Context, function: Function) -> Result<bool, IrError> {
926 let candidates = function
929 .instruction_iter(context)
930 .filter_map(|(_, store_instr_val)| {
931 store_instr_val
932 .get_instruction(context)
933 .and_then(|instr| {
934 if let Instruction {
936 op:
937 InstOp::Store {
938 dst_val_ptr,
939 stored_val,
940 },
941 ..
942 } = instr
943 {
944 stored_val
945 .get_instruction(context)
946 .map(|src_instr| (*stored_val, src_instr, dst_val_ptr))
947 } else {
948 None
949 }
950 })
951 .and_then(|(src_instr_val, src_instr, dst_val_ptr)| {
952 if let Instruction {
954 op: InstOp::Load(src_val_ptr),
955 ..
956 } = src_instr
957 {
958 Some(Candidate {
959 load_val: src_instr_val,
960 store_val: store_instr_val,
961 dst_ptr: *dst_val_ptr,
962 src_ptr: *src_val_ptr,
963 })
964 } else {
965 None
966 }
967 })
968 .and_then(|candidate @ Candidate { dst_ptr, .. }| {
969 if !is_clobbered(
971 context,
972 &candidate.load_val,
973 &candidate.store_val,
974 &candidate.dst_ptr,
975 &candidate.src_ptr,
976 ) {
977 Some(CandidateKind::NonClobbered(candidate))
978 } else if !is_copy_type(&dst_ptr.match_ptr_type(context).unwrap(), context) {
979 Some(CandidateKind::ClobberedNoncopyType(candidate))
980 } else {
981 None
982 }
983 })
984 })
985 .collect::<Vec<_>>();
986
987 if candidates.is_empty() {
988 return Ok(false);
989 }
990
991 for candidate in candidates {
992 match candidate {
993 CandidateKind::ClobberedNoncopyType(Candidate {
994 load_val,
995 store_val,
996 dst_ptr,
997 src_ptr,
998 }) => {
999 let load_block = load_val.get_instruction(context).unwrap().parent;
1000 let temp = function.new_unique_local_var(
1001 context,
1002 "__aggr_memcpy_0".into(),
1003 src_ptr.match_ptr_type(context).unwrap(),
1004 None,
1005 true,
1006 );
1007 let temp_local =
1008 Value::new_instruction(context, load_block, InstOp::GetLocal(temp));
1009 let to_temp = Value::new_instruction(
1010 context,
1011 load_block,
1012 InstOp::MemCopyVal {
1013 dst_val_ptr: temp_local,
1014 src_val_ptr: src_ptr,
1015 },
1016 );
1017 let mut inserter = InstructionInserter::new(
1018 context,
1019 load_block,
1020 crate::InsertionPosition::After(load_val),
1021 );
1022 inserter.insert_slice(&[temp_local, to_temp]);
1023
1024 let store_block = store_val.get_instruction(context).unwrap().parent;
1025 let mem_copy_val = Value::new_instruction(
1026 context,
1027 store_block,
1028 InstOp::MemCopyVal {
1029 dst_val_ptr: dst_ptr,
1030 src_val_ptr: temp_local,
1031 },
1032 );
1033 store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
1034 }
1035 CandidateKind::NonClobbered(Candidate {
1036 dst_ptr: dst_val_ptr,
1037 src_ptr: src_val_ptr,
1038 store_val,
1039 ..
1040 }) => {
1041 let store_block = store_val.get_instruction(context).unwrap().parent;
1042 let mem_copy_val = Value::new_instruction(
1043 context,
1044 store_block,
1045 InstOp::MemCopyVal {
1046 dst_val_ptr,
1047 src_val_ptr,
1048 },
1049 );
1050 store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
1051 }
1052 }
1053 }
1054
1055 Ok(true)
1056}
1057
1058pub const MEMCPYPROP_REVERSE_NAME: &str = "memcpyprop_reverse";
1059
1060pub fn create_memcpyprop_reverse_pass() -> Pass {
1061 Pass {
1062 name: MEMCPYPROP_REVERSE_NAME,
1063 descr: "Copy propagation of MemCpy instructions",
1064 deps: vec![],
1065 runner: ScopedPass::FunctionPass(PassMutability::Transform(copy_prop_reverse)),
1066 }
1067}
1068
1069fn copy_prop_reverse(
1071 context: &mut Context,
1072 _analyses: &AnalysisResults,
1073 function: Function,
1074) -> Result<bool, IrError> {
1075 let mut modified = false;
1076
1077 let mut stores_map: FxHashMap<Symbol, Vec<Value>> = FxHashMap::default();
1079 let mut loads_map: FxHashMap<Symbol, Vec<Value>> = FxHashMap::default();
1080 for (_block, instr_val) in function.instruction_iter(context) {
1081 let stored_syms = get_stored_symbols(context, instr_val);
1082 let stored_syms = match stored_syms {
1083 ReferredSymbols::Complete(syms) => syms,
1084 ReferredSymbols::Incomplete(_) => return Ok(false),
1085 };
1086 let loaded_syms = get_loaded_symbols(context, instr_val);
1087 let loaded_syms = match loaded_syms {
1088 ReferredSymbols::Complete(syms) => syms,
1089 ReferredSymbols::Incomplete(_) => return Ok(false),
1090 };
1091 for sym in stored_syms {
1092 stores_map.entry(sym).or_default().push(instr_val);
1093 }
1094 for sym in loaded_syms {
1095 loads_map.entry(sym).or_default().push(instr_val);
1096 }
1097 }
1098
1099 let mut candidates = vec![];
1100
1101 for (_block, inst) in function.instruction_iter(context) {
1102 let Some((dst_ptr, src_ptr, byte_len)) = deconstruct_memcpy(context, inst) else {
1103 continue;
1104 };
1105
1106 if dst_ptr.get_type(context) != src_ptr.get_type(context) {
1107 continue;
1108 }
1109
1110 let dst_sym = match get_referred_symbols(context, dst_ptr) {
1116 ReferredSymbols::Complete(syms) if syms.len() == 1 => syms.into_iter().next().unwrap(),
1117 _ => continue,
1118 };
1119 let src_sym = match get_referred_symbols(context, src_ptr) {
1120 ReferredSymbols::Complete(syms) if syms.len() == 1 => syms.into_iter().next().unwrap(),
1121 _ => continue,
1122 };
1123
1124 if dst_sym.get_type(context) != src_sym.get_type(context) {
1125 continue;
1126 }
1127
1128 if dst_sym
1130 .get_type(context)
1131 .get_pointee_type(context)
1132 .expect("All symbols must be pointer types")
1133 .size(context)
1134 .in_bytes()
1135 != byte_len
1136 {
1137 continue;
1138 }
1139
1140 let source_uses_not_clobbered = loads_map
1144 .get(&src_sym)
1145 .map(|uses| {
1146 uses.iter().all(|use_val: &Value| {
1147 *use_val == inst || !is_clobbered(context, &inst, use_val, &src_ptr, &dst_ptr)
1148 })
1149 })
1150 .unwrap_or(true);
1151
1152 let destination_uses_not_clobbered = loads_map
1156 .get(&dst_sym)
1157 .map(|uses| {
1158 uses.iter()
1159 .all(|use_val| !is_clobbered(context, &inst, use_val, &dst_ptr, &src_ptr))
1160 })
1161 .unwrap_or(true);
1162
1163 if source_uses_not_clobbered && destination_uses_not_clobbered {
1164 candidates.push((inst, dst_sym, src_sym));
1165 }
1166 }
1167
1168 if candidates.is_empty() {
1169 return Ok(false);
1170 }
1171
1172 let mut to_delete: FxHashSet<Value> = FxHashSet::default();
1173 let mut src_to_dst: FxHashMap<Symbol, Symbol> = FxHashMap::default();
1174
1175 for (inst, dst_sym, src_sym) in candidates {
1176 match src_sym {
1177 Symbol::Arg(_) => {
1178 continue;
1183 }
1184 Symbol::Local(local) => {
1185 if local.get_initializer(context).is_some() {
1186 continue;
1192 }
1193 match src_to_dst.entry(src_sym) {
1194 std::collections::hash_map::Entry::Vacant(e) => {
1195 e.insert(dst_sym);
1196 }
1197 std::collections::hash_map::Entry::Occupied(e) => {
1198 if *e.get() != dst_sym {
1199 continue;
1201 }
1202 }
1203 }
1204 to_delete.insert(inst);
1205 }
1206 }
1207 }
1208
1209 {
1211 let mut changed = true;
1212 let mut cycle_detected = false;
1213 while changed {
1214 changed = false;
1215 src_to_dst.clone().iter().for_each(|(src, dst)| {
1216 if let Some(next_dst) = src_to_dst.get(dst) {
1217 if *next_dst == *src {
1219 cycle_detected = true;
1220 return;
1221 }
1222 src_to_dst.insert(*src, *next_dst);
1223 changed = true;
1224 }
1225 });
1226 }
1227 if cycle_detected {
1228 return Ok(modified);
1230 }
1231 }
1232
1233 let mut repl_locals = vec![];
1235 for (_block, inst) in function.instruction_iter(context) {
1236 match inst.get_instruction(context).unwrap() {
1237 Instruction {
1238 op: InstOp::GetLocal(sym),
1239 ..
1240 } => {
1241 if let Some(dst) = src_to_dst.get(&Symbol::Local(*sym)) {
1242 repl_locals.push((inst, *dst));
1243 }
1244 }
1245 _ => {
1246 }
1249 }
1250 }
1251
1252 if repl_locals.is_empty() {
1253 return Ok(modified);
1254 }
1255 modified = true;
1256
1257 let mut value_replacements = FxHashMap::default();
1258 for (to_repl, repl_with) in repl_locals {
1259 let Instruction {
1260 op: InstOp::GetLocal(sym),
1261 ..
1262 } = to_repl.get_instruction_mut(context).unwrap()
1263 else {
1264 panic!("Expected GetLocal instruction");
1265 };
1266 match repl_with {
1267 Symbol::Local(dst_local) => {
1268 *sym = dst_local;
1270 }
1271 Symbol::Arg(arg) => {
1272 value_replacements.insert(to_repl, arg.as_value(context));
1274 }
1275 }
1276 }
1277
1278 function.replace_values(context, &value_replacements, None);
1280
1281 for (_, inst) in function.instruction_iter(context) {
1289 let Some((dst_ptr, src_ptr, _byte_len)) = deconstruct_memcpy(context, inst) else {
1290 continue;
1291 };
1292
1293 let dst_sym = match get_referred_symbols(context, dst_ptr) {
1294 ReferredSymbols::Complete(syms) if syms.len() == 1 => syms.into_iter().next().unwrap(),
1295 _ => continue,
1296 };
1297 let src_sym = match get_referred_symbols(context, src_ptr) {
1298 ReferredSymbols::Complete(syms) if syms.len() == 1 => syms.into_iter().next().unwrap(),
1299 _ => continue,
1300 };
1301
1302 if dst_sym == src_sym {
1303 to_delete.insert(inst);
1304 }
1305 }
1306
1307 function.remove_instructions(context, |v| to_delete.contains(&v));
1308
1309 Ok(modified)
1310}