1use crate::ir::instruction::downcast_ref;
18use crate::{
19 analysis::{
20 dominator_tree::DominatorTree,
21 effect_analysis::{Effect, EffectAnalysis},
22 loop_tools::{self, LoopForest, LoopPtr},
23 },
24 ir::{
25 instruction::{
26 downcast_mut,
27 misc_inst::{ICmp, ICmpOp, Phi},
28 InstType,
29 },
30 BBPtr, Constant, InstPtr, Operand, ValueType,
31 },
32 Program,
33};
34use anyhow::Result;
35use duskphantom_utils::cprintln;
36use std::collections::{HashMap, HashSet};
37
38use super::{loop_simplify, Transform};
39
40pub fn optimize_program<const N_THREAD: i32>(program: &mut Program) -> Result<bool> {
41 let mut changed = false;
42 let effect_analysis = EffectAnalysis::new(program);
43 for func in program.module.functions.clone() {
44 let Some(mut forest) = loop_tools::LoopForest::make_forest(func) else {
45 continue;
46 };
47 loop_simplify::LoopSimplifier::new(&mut program.mem_pool).run(&mut forest)?;
48 let mut dom_tree = DominatorTree::new(func);
49 changed |=
50 MakeParallel::<N_THREAD>::new(program, &mut forest, &mut dom_tree, &effect_analysis)
51 .run_and_log()?;
52 }
53 Ok(changed)
54}
55
56pub struct MakeParallel<'a, const N_THREAD: i32> {
57 program: &'a mut Program,
58 loop_forest: &'a mut LoopForest,
59 dom_tree: &'a mut DominatorTree,
60 effect_analysis: &'a EffectAnalysis,
61 stack_ref: HashMap<LoopPtr, HashSet<InstPtr>>,
62}
63
64impl<'a, const N_THREAD: i32> Transform for MakeParallel<'a, N_THREAD> {
65 fn get_program_mut(&mut self) -> &mut Program {
66 self.program
67 }
68
69 fn name() -> String {
70 "make_parallel".to_string()
71 }
72
73 fn run(&mut self) -> Result<bool> {
74 let mut changed = false;
75 for lo in self.loop_forest.forest.clone() {
76 self.check_stack_reference(lo)?;
77 }
78 for lo in self.loop_forest.forest.clone() {
79 let mut candidate = Vec::new();
80 self.make_candidate(&mut candidate, lo)?;
81 for c in candidate {
82 changed |= self.make_parallel(c)?;
83 }
84 }
85 Ok(changed)
86 }
87}
88
89impl<'a, const N_THREAD: i32> MakeParallel<'a, N_THREAD> {
90 pub fn new(
91 program: &'a mut Program,
92 loop_forest: &'a mut LoopForest,
93 dom_tree: &'a mut DominatorTree,
94 effect_analysis: &'a EffectAnalysis,
95 ) -> Self {
96 Self {
97 program,
98 loop_forest,
99 dom_tree,
100 effect_analysis,
101 stack_ref: HashMap::new(),
102 }
103 }
104
105 fn check_stack_reference(&mut self, lo: LoopPtr) -> Result<()> {
106 for bb in &lo.blocks {
107 for inst in bb.iter() {
108 if let Some(inst) = get_base_alloc(inst) {
109 self.stack_ref.entry(lo).or_default().insert(inst);
110 }
111 }
112 }
113 for sub_loop in lo.sub_loops.iter() {
114 self.check_stack_reference(*sub_loop)?;
115 let Some(sub_ref) = self.stack_ref.get(sub_loop).cloned() else {
116 continue;
117 };
118 self.stack_ref.entry(lo).or_default().extend(sub_ref);
119 }
120 Ok(())
121 }
122
123 fn get_loop_effect(&mut self, lo: LoopPtr, indvar: &Operand) -> Result<Option<Effect>> {
124 let mut effect = Effect::new();
125 for bb in &lo.blocks {
126 let Some(bb_effect) = self.get_block_effect(*bb, indvar)? else {
127 return Ok(None);
128 };
129 if !merge_effect(&mut effect, &bb_effect, indvar)? {
130 return Ok(None);
131 }
132 }
133
134 for sub_loop in lo.sub_loops.iter() {
136 let Some(sub_effect) = self.get_loop_effect(*sub_loop, indvar)? else {
137 return Ok(None);
138 };
139 if !merge_effect(&mut effect, &sub_effect, indvar)? {
140 return Ok(None);
141 }
142 }
143 Ok(Some(effect))
144 }
145
146 fn get_block_effect(&mut self, bb: BBPtr, indvar: &Operand) -> Result<Option<Effect>> {
147 let mut effect = Effect::new();
148 for inst in bb.iter() {
149 if self.effect_analysis.has_io(inst) {
151 return Ok(None);
152 }
153
154 if let Some(inst_effect) = &self.effect_analysis.inst_effect.get(&inst) {
156 if !merge_effect(&mut effect, inst_effect, indvar)? {
157 return Ok(None);
158 }
159 }
160 }
161 Ok(Some(effect))
162 }
163
164 fn make_candidate(&mut self, result: &mut Vec<Candidate>, lo: LoopPtr) -> Result<()> {
165 #[allow(unused)]
166 let pre_header = lo.pre_header.unwrap();
167
168 let mut exit = Vec::new();
172 get_exit_inst(lo, lo, &mut exit);
173
174 if exit.len() != 1 {
176 cprintln!("[INFO] loop {} has multiple exit edges", pre_header.name);
177 return Ok(());
178 }
179
180 let exit = exit[0];
183 if pre_header.get_succ_bb() != &vec![exit.get_parent_bb().unwrap()] {
184 cprintln!(
185 "[INFO] loop {}'s pred is not {}",
186 pre_header.name,
187 pre_header.name
188 );
189 return Ok(());
190 }
191
192 let Some(candidate) = Candidate::from_exit(exit, lo, self.dom_tree) else {
194 cprintln!("[INFO] loop {} does not have indvar", pre_header.name);
195 return Ok(());
196 };
197
198 if self
200 .get_loop_effect(lo, &candidate.indvar.into())?
201 .is_none()
202 {
203 cprintln!("[INFO] loop {} has conflict effect", pre_header.name);
204 return Ok(());
205 }
206
207 cprintln!(
209 "[INFO] loop {} is made candidate {}!",
210 pre_header.name,
211 candidate.dump()
212 );
213 result.push(candidate);
214 Ok(())
215 }
216
217 fn make_parallel(&mut self, mut candidate: Candidate) -> Result<bool> {
218 let mut map = HashMap::new();
220 if let Some(stack_ref) = self.stack_ref.get(&candidate.lo) {
221 let mut vec = stack_ref.iter().cloned().collect::<Vec<_>>();
222 vec.sort_by_key(|inst| inst.get_id());
223 for inst in vec {
224 let gep_zero = self.program.mem_pool.get_getelementptr(
225 inst.get_value_type().get_sub_type().cloned().unwrap(),
226 inst.into(),
227 vec![Constant::Int(0).into()],
228 );
229 candidate.init_bb.get_last_inst().insert_before(gep_zero);
230 map.insert(inst, gep_zero);
231 }
232 }
233 replace_stack_reference(candidate.lo, &map)?;
234
235 let func_create = self
237 .program
238 .module
239 .functions
240 .iter()
241 .find(|f| f.name == "thrd_create")
242 .unwrap();
243 let inst_create = self
244 .program
245 .mem_pool
246 .get_call(*func_create, vec![Constant::Int(N_THREAD - 1).into()]);
247 candidate.init_bb.get_last_inst().insert_before(inst_create);
248
249 let i = candidate.init_val;
260 let e = candidate.exit_val;
261
262 let inst_sub = self.program.mem_pool.get_sub(e.clone(), i.clone());
264 candidate.init_bb.get_last_inst().insert_before(inst_sub);
265
266 let inst_mul = self
268 .program
269 .mem_pool
270 .get_mul(inst_create.into(), inst_sub.into());
271 candidate.init_bb.get_last_inst().insert_before(inst_mul);
272
273 let inst_div = self
275 .program
276 .mem_pool
277 .get_sdiv(inst_mul.into(), Constant::Int(N_THREAD).into());
278 candidate.init_bb.get_last_inst().insert_before(inst_div);
279
280 let inst_lb = self.program.mem_pool.get_add(inst_div.into(), i.clone());
282 candidate.init_bb.get_last_inst().insert_before(inst_lb);
283
284 let inst_add = self
286 .program
287 .mem_pool
288 .get_add(inst_mul.into(), inst_sub.into());
289 candidate.init_bb.get_last_inst().insert_before(inst_add);
290
291 let inst_div = self
293 .program
294 .mem_pool
295 .get_sdiv(inst_add.into(), Constant::Int(N_THREAD).into());
296 candidate.init_bb.get_last_inst().insert_before(inst_div);
297
298 let inst_ub = self.program.mem_pool.get_add(inst_div.into(), i.clone());
300 candidate.init_bb.get_last_inst().insert_before(inst_ub);
301
302 let phi = downcast_mut::<Phi>(candidate.indvar.as_mut());
304 phi.replace_incoming_value_at(candidate.init_bb, inst_lb.into());
305
306 let inst_cond = self.program.mem_pool.get_icmp(
308 ICmpOp::Slt,
309 ValueType::Int,
310 candidate.indvar.into(),
311 inst_ub.into(),
312 );
313 candidate.exit.insert_before(inst_cond);
314 candidate.exit.set_operand(0, inst_cond.into());
315
316 let func_join = self
318 .program
319 .module
320 .functions
321 .iter()
322 .find(|f| f.name == "thrd_join")
323 .unwrap();
324 let mut inst_join = self.program.mem_pool.get_call(*func_join, vec![]);
325 candidate.exit_bb.push_front(inst_join);
326
327 let mut inst_sub = self
330 .program
331 .mem_pool
332 .get_sub(inst_sub.into(), Constant::Int(1).into());
333 inst_join.insert_after(inst_sub);
334
335 let mut inst_div = self
337 .program
338 .mem_pool
339 .get_sdiv(inst_sub.into(), Constant::Int(candidate.delta).into());
340 inst_sub.insert_after(inst_div);
341
342 let mut inst_add = self
344 .program
345 .mem_pool
346 .get_add(inst_div.into(), Constant::Int(1).into());
347 inst_div.insert_after(inst_add);
348
349 let mut inst_mul = self
351 .program
352 .mem_pool
353 .get_mul(inst_add.into(), Constant::Int(candidate.delta).into());
354 inst_add.insert_after(inst_mul);
355
356 let inst_pred = self.program.mem_pool.get_add(inst_mul.into(), i.clone());
358 inst_mul.insert_after(inst_pred);
359
360 for mut user in candidate.indvar.get_user().iter().cloned() {
362 if !candidate.lo.is_in_loop(&user.get_parent_bb().unwrap()) {
363 user.replace_operand(&candidate.indvar.into(), &inst_pred.into());
364 }
365 }
366 Ok(true)
367 }
368}
369
370fn get_exit_inst(lo: LoopPtr, parent: LoopPtr, result: &mut Vec<InstPtr>) {
372 for bb in &lo.blocks {
373 if bb.get_succ_bb().iter().any(|bb| !parent.is_in_loop(bb)) {
374 result.push(bb.get_last_inst());
375 }
376 }
377 for sub_loop in lo.sub_loops.iter() {
378 get_exit_inst(*sub_loop, parent, result);
379 }
380}
381
382fn get_base_alloc(inst: InstPtr) -> Option<InstPtr> {
384 if inst.get_type() == InstType::Alloca {
385 return Some(inst);
386 }
387 match inst.get_type() {
388 InstType::Load | InstType::GetElementPtr => {
389 let ptr = inst.get_operand().first()?;
390 if let Operand::Instruction(ptr) = ptr {
391 get_base_alloc(*ptr)
392 } else {
393 None
394 }
395 }
396 InstType::Store => {
397 let base = inst.get_operand().get(1)?;
398 if let Operand::Instruction(inst) = base {
399 get_base_alloc(*inst)
400 } else {
401 None
402 }
403 }
404 _ => None,
405 }
406}
407
408fn replace_stack_reference(lo: LoopPtr, map: &HashMap<InstPtr, InstPtr>) -> Result<()> {
410 for bb in &lo.blocks {
411 for inst in bb.iter() {
412 replace_base_alloc(inst, map);
413 }
414 }
415 for sub_loop in lo.sub_loops.iter() {
416 replace_stack_reference(*sub_loop, map)?;
417 }
418 Ok(())
419}
420
421fn replace_base_alloc(mut inst: InstPtr, map: &HashMap<InstPtr, InstPtr>) {
423 match inst.get_type() {
424 InstType::Load => {
425 let ptr = inst.get_operand().first().unwrap();
426 if let Operand::Instruction(ptr) = ptr.clone() {
427 if ptr.get_type() == InstType::Alloca {
428 inst.set_operand(0, map[&ptr].into());
429 }
430 }
431 }
432 InstType::GetElementPtr => {
433 if inst.get_operand().len() == 2 {
434 return;
436 }
437 let ptr = inst.get_operand().first().unwrap();
438 if let Operand::Instruction(ptr) = ptr.clone() {
439 if ptr.get_type() == InstType::Alloca {
440 inst.set_operand(0, map[&ptr].into());
441 }
442 }
443 }
444 InstType::Store => {
445 let base = inst.get_operand().get(1).unwrap();
446 if let Operand::Instruction(ptr) = base.clone() {
447 if ptr.get_type() == InstType::Alloca {
448 inst.set_operand(1, map[&ptr].into());
449 }
450 }
451 }
452 _ => (),
453 }
454}
455
456fn merge_effect(a: &mut Effect, b: &Effect, indvar: &Operand) -> Result<bool> {
459 if a.def_range.can_conflict(&b.def_range, indvar)
460 || a.use_range.can_conflict(&b.def_range, indvar)
461 || a.def_range.can_conflict(&b.use_range, indvar)
462 || b.def_range.can_conflict(&b.use_range, indvar)
463 || b.def_range.can_conflict(&b.def_range, indvar)
464 {
465 cprintln!(
466 "[INFO] failed to merge {} with {} (indvar = {})",
467 a.dump(),
468 b.dump(),
469 indvar
470 );
471 return Ok(false);
472 }
473 a.def_range.merge(&b.def_range);
474 a.use_range.merge(&b.use_range);
475 Ok(true)
476}
477
478struct Candidate {
492 lo: LoopPtr,
493 indvar: InstPtr,
494 exit: InstPtr,
495 delta: i32,
496 init_val: Operand,
497 init_bb: BBPtr,
498 exit_val: Operand,
499 exit_bb: BBPtr,
500}
501
502impl Candidate {
503 #[allow(clippy::too_many_arguments)]
504 fn new(
505 lo: LoopPtr,
506 indvar: InstPtr,
507 exit: InstPtr,
508 delta: i32,
509 init_val: Operand,
510 init_bb: BBPtr,
511 exit_val: Operand,
512 exit_bb: BBPtr,
513 ) -> Self {
514 Self {
515 lo,
516 indvar,
517 exit,
518 delta,
519 init_val,
520 init_bb,
521 exit_val,
522 exit_bb,
523 }
524 }
525
526 #[allow(unused)]
528 fn dump(&self) -> String {
529 format!(
530 "Candidate {{\n indvar: {},\n exit: {},\n init_val: {},\n init_bb: {},\n exit_val: {},\n exit_bb: {},\n}}",
531 self.indvar.gen_llvm_ir(),
532 self.exit.gen_llvm_ir(),
533 self.init_val,
534 self.init_bb.name,
535 self.exit_val,
536 self.exit_bb.name,
537 )
538 }
539
540 fn from_exit(exit: InstPtr, lo: LoopPtr, dom_tree: &mut DominatorTree) -> Option<Self> {
544 let pre_header = lo.pre_header.unwrap();
545 if exit.get_type() != InstType::Br {
546 cprintln!(
547 "[INFO] loop {} fails because {} is not br",
548 pre_header.name,
549 exit.gen_llvm_ir()
550 );
551 return None;
552 }
553
554 let parent_bb = exit.get_parent_bb()?;
556 let exit_bb = parent_bb
557 .get_succ_bb()
558 .iter()
559 .find(|bb| !lo.is_in_loop(bb))?;
560
561 if exit_bb.get_pred_bb().len() != 1 {
564 cprintln!(
565 "[INFO] loop {} fails because {} has multiple preds",
566 pre_header.name,
567 exit_bb.name
568 );
569 return None;
570 }
571
572 let Operand::Instruction(cond) = exit.get_operand().first()? else {
575 cprintln!(
576 "[INFO] loop {} fails because {}'s first operand is not inst",
577 pre_header.name,
578 exit.gen_llvm_ir()
579 );
580 return None;
581 };
582 let InstType::ICmp = cond.get_type() else {
583 cprintln!(
584 "[INFO] loop {} fails because {} is not condition",
585 pre_header.name,
586 cond.gen_llvm_ir()
587 );
588 return None;
589 };
590 let icmp = downcast_ref::<ICmp>(cond.as_ref().as_ref());
591 if icmp.op != ICmpOp::Slt {
592 cprintln!(
593 "[INFO] loop {} fails because {} is not slt",
594 pre_header.name,
595 cond.gen_llvm_ir()
596 );
597 return None;
598 }
599 let Operand::Instruction(indvar) = icmp.get_lhs() else {
600 cprintln!(
601 "[INFO] loop {} fails because {}'s lhs is not inst",
602 pre_header.name,
603 cond.gen_llvm_ir()
604 );
605 return None;
606 };
607 let exit_val = icmp.get_rhs().clone();
608
609 if let Operand::Instruction(inst) = exit_val {
611 if !dom_tree.is_dominate(inst.get_parent_bb().unwrap(), pre_header) {
612 cprintln!(
613 "[INFO] loop {} fails because {} is not calculated before loop",
614 pre_header.name,
615 inst.gen_llvm_ir()
616 );
617 return None;
618 }
619 }
620
621 if indvar.get_type() != InstType::Phi {
626 cprintln!(
627 "[INFO] loop {} fails because {} is not phi",
628 pre_header.name,
629 indvar.gen_llvm_ir()
630 );
631 return None;
632 }
633 let phi = downcast_ref::<Phi>(indvar.as_ref().as_ref());
634 let inc = phi.get_incoming_values();
635 if inc.len() != 2 {
636 cprintln!(
637 "[INFO] loop {} fails because {}'s incoming value length is not 2",
638 pre_header.name,
639 indvar.gen_llvm_ir()
640 );
641 return None;
642 }
643 let init_val = inc[0].0.clone();
644 let init_bb = lo.pre_header?;
645 if init_bb != inc[0].1 {
646 cprintln!(
647 "[INFO] loop {} fails because {} is not pre_header",
648 pre_header.name,
649 inc[0].1.name.clone()
650 );
651 return None;
652 }
653 let Operand::Instruction(next_val) = inc[1].0 else {
654 cprintln!(
655 "[INFO] loop {} fails because {}'s second incoming value is not inst",
656 pre_header.name,
657 indvar.gen_llvm_ir()
658 );
659 return None;
660 };
661 if next_val.get_type() != InstType::Add {
662 cprintln!(
663 "[INFO] loop {} fails because {} is not add",
664 pre_header.name,
665 next_val.gen_llvm_ir()
666 );
667 return None;
668 }
669 let Operand::Constant(Constant::Int(delta)) = next_val.get_operand()[1] else {
670 cprintln!(
671 "[INFO] loop {} fails because {}'s second operand is not constant",
672 pre_header.name,
673 next_val.gen_llvm_ir()
674 );
675 return None;
676 };
677 let next_bb = inc[1].1;
678 if !lo.is_in_loop(&next_bb) {
679 cprintln!(
680 "[INFO] loop {} fails because {} is not in loop",
681 pre_header.name,
682 next_bb.name
683 );
684 return None;
685 }
686 for inst in indvar.get_parent_bb().unwrap().iter() {
687 if inst.get_type() == InstType::Phi && inst != *indvar {
688 cprintln!(
689 "[INFO] loop {} fails because {} has multiple phi",
690 pre_header.name,
691 inst.gen_llvm_ir()
692 );
693 return None;
694 }
695 }
696
697 Some(Self::new(
699 lo, *indvar, exit, delta, init_val, init_bb, exit_val, *exit_bb,
700 ))
701 }
702}