duskphantom_middle/transform/
make_parallel.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use crate::ir::instruction::downcast_ref;
18use crate::{
19    analysis::{
20        dominator_tree::DominatorTree,
21        effect_analysis::{Effect, EffectAnalysis},
22        loop_tools::{self, LoopForest, LoopPtr},
23    },
24    ir::{
25        instruction::{
26            downcast_mut,
27            misc_inst::{ICmp, ICmpOp, Phi},
28            InstType,
29        },
30        BBPtr, Constant, InstPtr, Operand, ValueType,
31    },
32    Program,
33};
34use anyhow::Result;
35use duskphantom_utils::cprintln;
36use std::collections::{HashMap, HashSet};
37
38use super::{loop_simplify, Transform};
39
40pub fn optimize_program<const N_THREAD: i32>(program: &mut Program) -> Result<bool> {
41    let mut changed = false;
42    let effect_analysis = EffectAnalysis::new(program);
43    for func in program.module.functions.clone() {
44        let Some(mut forest) = loop_tools::LoopForest::make_forest(func) else {
45            continue;
46        };
47        loop_simplify::LoopSimplifier::new(&mut program.mem_pool).run(&mut forest)?;
48        let mut dom_tree = DominatorTree::new(func);
49        changed |=
50            MakeParallel::<N_THREAD>::new(program, &mut forest, &mut dom_tree, &effect_analysis)
51                .run_and_log()?;
52    }
53    Ok(changed)
54}
55
/// Loop parallelization pass.
///
/// Splits the iteration range of a candidate loop into `N_THREAD` contiguous
/// parts, one per thread. The loop is bracketed by calls to the `thrd_create`
/// and `thrd_join` functions looked up by name in the program's module.
pub struct MakeParallel<'a, const N_THREAD: i32> {
    /// Program being transformed; new instructions are allocated from its `mem_pool`.
    program: &'a mut Program,
    /// Loop forest of the function being processed.
    loop_forest: &'a mut LoopForest,
    /// Dominator tree of the same function, used when validating the loop's exit value.
    dom_tree: &'a mut DominatorTree,
    /// Per-instruction effect information used to detect IO and conflicting memory effects.
    effect_analysis: &'a EffectAnalysis,
    /// For each loop, the `alloca` instructions reached from its memory
    /// instructions, including those of its sub loops.
    stack_ref: HashMap<LoopPtr, HashSet<InstPtr>>,
}
63
64impl<'a, const N_THREAD: i32> Transform for MakeParallel<'a, N_THREAD> {
65    fn get_program_mut(&mut self) -> &mut Program {
66        self.program
67    }
68
69    fn name() -> String {
70        "make_parallel".to_string()
71    }
72
73    fn run(&mut self) -> Result<bool> {
74        let mut changed = false;
75        for lo in self.loop_forest.forest.clone() {
76            self.check_stack_reference(lo)?;
77        }
78        for lo in self.loop_forest.forest.clone() {
79            let mut candidate = Vec::new();
80            self.make_candidate(&mut candidate, lo)?;
81            for c in candidate {
82                changed |= self.make_parallel(c)?;
83            }
84        }
85        Ok(changed)
86    }
87}
88
89impl<'a, const N_THREAD: i32> MakeParallel<'a, N_THREAD> {
90    pub fn new(
91        program: &'a mut Program,
92        loop_forest: &'a mut LoopForest,
93        dom_tree: &'a mut DominatorTree,
94        effect_analysis: &'a EffectAnalysis,
95    ) -> Self {
96        Self {
97            program,
98            loop_forest,
99            dom_tree,
100            effect_analysis,
101            stack_ref: HashMap::new(),
102        }
103    }
104
105    fn check_stack_reference(&mut self, lo: LoopPtr) -> Result<()> {
106        for bb in &lo.blocks {
107            for inst in bb.iter() {
108                if let Some(inst) = get_base_alloc(inst) {
109                    self.stack_ref.entry(lo).or_default().insert(inst);
110                }
111            }
112        }
113        for sub_loop in lo.sub_loops.iter() {
114            self.check_stack_reference(*sub_loop)?;
115            let Some(sub_ref) = self.stack_ref.get(sub_loop).cloned() else {
116                continue;
117            };
118            self.stack_ref.entry(lo).or_default().extend(sub_ref);
119        }
120        Ok(())
121    }
122
123    fn get_loop_effect(&mut self, lo: LoopPtr, indvar: &Operand) -> Result<Option<Effect>> {
124        let mut effect = Effect::new();
125        for bb in &lo.blocks {
126            let Some(bb_effect) = self.get_block_effect(*bb, indvar)? else {
127                return Ok(None);
128            };
129            if !merge_effect(&mut effect, &bb_effect, indvar)? {
130                return Ok(None);
131            }
132        }
133
134        // Additionally collect effect in sub loops
135        for sub_loop in lo.sub_loops.iter() {
136            let Some(sub_effect) = self.get_loop_effect(*sub_loop, indvar)? else {
137                return Ok(None);
138            };
139            if !merge_effect(&mut effect, &sub_effect, indvar)? {
140                return Ok(None);
141            }
142        }
143        Ok(Some(effect))
144    }
145
146    fn get_block_effect(&mut self, bb: BBPtr, indvar: &Operand) -> Result<Option<Effect>> {
147        let mut effect = Effect::new();
148        for inst in bb.iter() {
149            // Prevent instruction with IO to be parallelized
150            if self.effect_analysis.has_io(inst) {
151                return Ok(None);
152            }
153
154            // Attempt to merge effect with no conflict
155            if let Some(inst_effect) = &self.effect_analysis.inst_effect.get(&inst) {
156                if !merge_effect(&mut effect, inst_effect, indvar)? {
157                    return Ok(None);
158                }
159            }
160        }
161        Ok(Some(effect))
162    }
163
164    fn make_candidate(&mut self, result: &mut Vec<Candidate>, lo: LoopPtr) -> Result<()> {
165        #[allow(unused)]
166        let pre_header = lo.pre_header.unwrap();
167
168        // Get all exit edges
169        // TODO-TLE: ignore all bb with one succ
170        // TODO-TLE: for sub loops, only check for return
171        let mut exit = Vec::new();
172        get_exit_inst(lo, lo, &mut exit);
173
174        // If there are multiple exit edges, then it can't be parallelized
175        if exit.len() != 1 {
176            cprintln!("[INFO] loop {} has multiple exit edges", pre_header.name);
177            return Ok(());
178        }
179
180        // If succ of pre_header is not exit, then it can't be parallelized
181        // We only parallelize while loops instead of do-while loops! (no canonical form and it's hard to analysis)
182        let exit = exit[0];
183        if pre_header.get_succ_bb() != &vec![exit.get_parent_bb().unwrap()] {
184            cprintln!(
185                "[INFO] loop {}'s pred is not {}",
186                pre_header.name,
187                pre_header.name
188            );
189            return Ok(());
190        }
191
192        // Get induction var from exit. If failed, check sub loops instead
193        let Some(candidate) = Candidate::from_exit(exit, lo, self.dom_tree) else {
194            cprintln!("[INFO] loop {} does not have indvar", pre_header.name);
195            return Ok(());
196        };
197
198        // If effect range collides, then it can't be parallelized, check sub loops instead
199        if self
200            .get_loop_effect(lo, &candidate.indvar.into())?
201            .is_none()
202        {
203            cprintln!("[INFO] loop {} has conflict effect", pre_header.name);
204            return Ok(());
205        }
206
207        // Insert candidate to results
208        cprintln!(
209            "[INFO] loop {} is made candidate {}!",
210            pre_header.name,
211            candidate.dump()
212        );
213        result.push(candidate);
214        Ok(())
215    }
216
217    fn make_parallel(&mut self, mut candidate: Candidate) -> Result<bool> {
218        // Copy global array address to local stack with consistent order
219        let mut map = HashMap::new();
220        if let Some(stack_ref) = self.stack_ref.get(&candidate.lo) {
221            let mut vec = stack_ref.iter().cloned().collect::<Vec<_>>();
222            vec.sort_by_key(|inst| inst.get_id());
223            for inst in vec {
224                let gep_zero = self.program.mem_pool.get_getelementptr(
225                    inst.get_value_type().get_sub_type().cloned().unwrap(),
226                    inst.into(),
227                    vec![Constant::Int(0).into()],
228                );
229                candidate.init_bb.get_last_inst().insert_before(gep_zero);
230                map.insert(inst, gep_zero);
231            }
232        }
233        replace_stack_reference(candidate.lo, &map)?;
234
235        // Get current thread ID
236        let func_create = self
237            .program
238            .module
239            .functions
240            .iter()
241            .find(|f| f.name == "thrd_create")
242            .unwrap();
243        let inst_create = self
244            .program
245            .mem_pool
246            .get_call(*func_create, vec![Constant::Int(N_THREAD - 1).into()]);
247        candidate.init_bb.get_last_inst().insert_before(inst_create);
248
249        // Create parallelized exit and indvar
250        //
251        // i = init_val
252        // d = next_delta
253        // e = exit_value
254        // N = N_THREAD
255        // n = current_thread
256        //
257        // Before: i, i + d, i + 2d, ..., i + d(X = (e - i) ceildiv d)
258        // After: [ LB = i + (e-i)n/N, UB = i + ((e-i)n + e-i)/N )
259        let i = candidate.init_val;
260        let e = candidate.exit_val;
261
262        // e - i
263        let inst_sub = self.program.mem_pool.get_sub(e.clone(), i.clone());
264        candidate.init_bb.get_last_inst().insert_before(inst_sub);
265
266        // (e - i) * n
267        let inst_mul = self
268            .program
269            .mem_pool
270            .get_mul(inst_create.into(), inst_sub.into());
271        candidate.init_bb.get_last_inst().insert_before(inst_mul);
272
273        // (e - i) * n / N
274        let inst_div = self
275            .program
276            .mem_pool
277            .get_sdiv(inst_mul.into(), Constant::Int(N_THREAD).into());
278        candidate.init_bb.get_last_inst().insert_before(inst_div);
279
280        // Lower bound: i + (e - i) * n / N
281        let inst_lb = self.program.mem_pool.get_add(inst_div.into(), i.clone());
282        candidate.init_bb.get_last_inst().insert_before(inst_lb);
283
284        // (e - i) * n + e - i
285        let inst_add = self
286            .program
287            .mem_pool
288            .get_add(inst_mul.into(), inst_sub.into());
289        candidate.init_bb.get_last_inst().insert_before(inst_add);
290
291        // ((e - i) * n + e - i) / N
292        let inst_div = self
293            .program
294            .mem_pool
295            .get_sdiv(inst_add.into(), Constant::Int(N_THREAD).into());
296        candidate.init_bb.get_last_inst().insert_before(inst_div);
297
298        // Upper bound: i + ((e - i) * n + e - i) / N
299        let inst_ub = self.program.mem_pool.get_add(inst_div.into(), i.clone());
300        candidate.init_bb.get_last_inst().insert_before(inst_ub);
301
302        // Replace indvar to parallelized indvar
303        let phi = downcast_mut::<Phi>(candidate.indvar.as_mut());
304        phi.replace_incoming_value_at(candidate.init_bb, inst_lb.into());
305
306        // Replace exit condition to parallelized exit condition
307        let inst_cond = self.program.mem_pool.get_icmp(
308            ICmpOp::Slt,
309            ValueType::Int,
310            candidate.indvar.into(),
311            inst_ub.into(),
312        );
313        candidate.exit.insert_before(inst_cond);
314        candidate.exit.set_operand(0, inst_cond.into());
315
316        // Join threads
317        let func_join = self
318            .program
319            .module
320            .functions
321            .iter()
322            .find(|f| f.name == "thrd_join")
323            .unwrap();
324        let mut inst_join = self.program.mem_pool.get_call(*func_join, vec![]);
325        candidate.exit_bb.push_front(inst_join);
326
327        // For out-of-loop indvar, replace with predicted value:
328        // i + ((e - i - 1) / delta + 1) * delta
329        let mut inst_sub = self
330            .program
331            .mem_pool
332            .get_sub(inst_sub.into(), Constant::Int(1).into());
333        inst_join.insert_after(inst_sub);
334
335        // (e - i - 1) / delta
336        let mut inst_div = self
337            .program
338            .mem_pool
339            .get_sdiv(inst_sub.into(), Constant::Int(candidate.delta).into());
340        inst_sub.insert_after(inst_div);
341
342        // (e - i - 1) / delta + 1
343        let mut inst_add = self
344            .program
345            .mem_pool
346            .get_add(inst_div.into(), Constant::Int(1).into());
347        inst_div.insert_after(inst_add);
348
349        // ((e - i - 1) / delta + 1) * delta
350        let mut inst_mul = self
351            .program
352            .mem_pool
353            .get_mul(inst_add.into(), Constant::Int(candidate.delta).into());
354        inst_add.insert_after(inst_mul);
355
356        // i + ((e - i - 1) / delta + 1) * delta
357        let inst_pred = self.program.mem_pool.get_add(inst_mul.into(), i.clone());
358        inst_mul.insert_after(inst_pred);
359
360        // Iterate all indvar users, if not in loop, replace with predicted value
361        for mut user in candidate.indvar.get_user().iter().cloned() {
362            if !candidate.lo.is_in_loop(&user.get_parent_bb().unwrap()) {
363                user.replace_operand(&candidate.indvar.into(), &inst_pred.into());
364            }
365        }
366        Ok(true)
367    }
368}
369
370/// Get all exit `br` in loop.
371fn get_exit_inst(lo: LoopPtr, parent: LoopPtr, result: &mut Vec<InstPtr>) {
372    for bb in &lo.blocks {
373        if bb.get_succ_bb().iter().any(|bb| !parent.is_in_loop(bb)) {
374            result.push(bb.get_last_inst());
375        }
376    }
377    for sub_loop in lo.sub_loops.iter() {
378        get_exit_inst(*sub_loop, parent, result);
379    }
380}
381
382/// Get base pointer of load / store / gep instruction, return if it's alloc.
383fn get_base_alloc(inst: InstPtr) -> Option<InstPtr> {
384    if inst.get_type() == InstType::Alloca {
385        return Some(inst);
386    }
387    match inst.get_type() {
388        InstType::Load | InstType::GetElementPtr => {
389            let ptr = inst.get_operand().first()?;
390            if let Operand::Instruction(ptr) = ptr {
391                get_base_alloc(*ptr)
392            } else {
393                None
394            }
395        }
396        InstType::Store => {
397            let base = inst.get_operand().get(1)?;
398            if let Operand::Instruction(inst) = base {
399                get_base_alloc(*inst)
400            } else {
401                None
402            }
403        }
404        _ => None,
405    }
406}
407
408/// Replace stack reference to copied global array address.
409fn replace_stack_reference(lo: LoopPtr, map: &HashMap<InstPtr, InstPtr>) -> Result<()> {
410    for bb in &lo.blocks {
411        for inst in bb.iter() {
412            replace_base_alloc(inst, map);
413        }
414    }
415    for sub_loop in lo.sub_loops.iter() {
416        replace_stack_reference(*sub_loop, map)?;
417    }
418    Ok(())
419}
420
421/// Replace base pointer of load / store / gep instruction, return if it's alloc.
422fn replace_base_alloc(mut inst: InstPtr, map: &HashMap<InstPtr, InstPtr>) {
423    match inst.get_type() {
424        InstType::Load => {
425            let ptr = inst.get_operand().first().unwrap();
426            if let Operand::Instruction(ptr) = ptr.clone() {
427                if ptr.get_type() == InstType::Alloca {
428                    inst.set_operand(0, map[&ptr].into());
429                }
430            }
431        }
432        InstType::GetElementPtr => {
433            if inst.get_operand().len() == 2 {
434                // Refuse to replace `getelementptr %ptr, 0`
435                return;
436            }
437            let ptr = inst.get_operand().first().unwrap();
438            if let Operand::Instruction(ptr) = ptr.clone() {
439                if ptr.get_type() == InstType::Alloca {
440                    inst.set_operand(0, map[&ptr].into());
441                }
442            }
443        }
444        InstType::Store => {
445            let base = inst.get_operand().get(1).unwrap();
446            if let Operand::Instruction(ptr) = base.clone() {
447                if ptr.get_type() == InstType::Alloca {
448                    inst.set_operand(1, map[&ptr].into());
449                }
450            }
451        }
452        _ => (),
453    }
454}
455
456/// Merge effect if parallelizing them doesn't cause collision.
457/// Returns changed or not.
458fn merge_effect(a: &mut Effect, b: &Effect, indvar: &Operand) -> Result<bool> {
459    if a.def_range.can_conflict(&b.def_range, indvar)
460        || a.use_range.can_conflict(&b.def_range, indvar)
461        || a.def_range.can_conflict(&b.use_range, indvar)
462        || b.def_range.can_conflict(&b.use_range, indvar)
463        || b.def_range.can_conflict(&b.def_range, indvar)
464    {
465        cprintln!(
466            "[INFO] failed to merge {} with {} (indvar = {})",
467            a.dump(),
468            b.dump(),
469            indvar
470        );
471        return Ok(false);
472    }
473    a.def_range.merge(&b.def_range);
474    a.use_range.merge(&b.use_range);
475    Ok(true)
476}
477
/// A candidate for parallelization.
///
/// For example, given the loop
///
/// ```llvm
/// indvar = phi [2, pre_header], [indvar + 3, loop]
/// exit = br (indvar < 6), loop, exit
/// ```
///
/// the candidate holds:
///
/// * `indvar` = the `phi` instruction
/// * `exit` = the `br` instruction
/// * `delta` = 3
/// * `init_val` = 2
/// * `init_bb` = pre_header
/// * `exit_val` = 6
/// * `exit_bb` = the block branched to when leaving the loop
struct Candidate {
    /// Loop to be parallelized.
    lo: LoopPtr,
    /// Induction variable: a `phi` with an initial value and an incremented value.
    indvar: InstPtr,
    /// Conditional `br` that leaves the loop.
    exit: InstPtr,
    /// Constant added to `indvar` on each iteration.
    delta: i32,
    /// Value of `indvar` on loop entry.
    init_val: Operand,
    /// Block the initial value comes from (the loop's pre-header).
    init_bb: BBPtr,
    /// Right-hand side of the `indvar < exit_val` exit test.
    exit_val: Operand,
    /// Successor of the exiting block that is outside the loop.
    exit_bb: BBPtr,
}
501
502impl Candidate {
503    #[allow(clippy::too_many_arguments)]
504    fn new(
505        lo: LoopPtr,
506        indvar: InstPtr,
507        exit: InstPtr,
508        delta: i32,
509        init_val: Operand,
510        init_bb: BBPtr,
511        exit_val: Operand,
512        exit_bb: BBPtr,
513    ) -> Self {
514        Self {
515            lo,
516            indvar,
517            exit,
518            delta,
519            init_val,
520            init_bb,
521            exit_val,
522            exit_bb,
523        }
524    }
525
526    /// Dump candidate to string for debugging.
527    #[allow(unused)]
528    fn dump(&self) -> String {
529        format!(
530            "Candidate {{\n  indvar: {},\n  exit: {},\n  init_val: {},\n  init_bb: {},\n  exit_val: {},\n  exit_bb: {},\n}}",
531            self.indvar.gen_llvm_ir(),
532            self.exit.gen_llvm_ir(),
533            self.init_val,
534            self.init_bb.name,
535            self.exit_val,
536            self.exit_bb.name,
537        )
538    }
539
540    /// Get induction variable from exit instruction.
541    /// Exit instruction should shape like:
542    /// `exit = br (indvar < N), loop, exit`
543    fn from_exit(exit: InstPtr, lo: LoopPtr, dom_tree: &mut DominatorTree) -> Option<Self> {
544        let pre_header = lo.pre_header.unwrap();
545        if exit.get_type() != InstType::Br {
546            cprintln!(
547                "[INFO] loop {} fails because {} is not br",
548                pre_header.name,
549                exit.gen_llvm_ir()
550            );
551            return None;
552        }
553
554        // Get basic block to go to when exit
555        let parent_bb = exit.get_parent_bb()?;
556        let exit_bb = parent_bb
557            .get_succ_bb()
558            .iter()
559            .find(|bb| !lo.is_in_loop(bb))?;
560
561        // Exit block should have only one pred
562        // TODO-PERF: this is for easy thread join implementation, but weakens optimization
563        if exit_bb.get_pred_bb().len() != 1 {
564            cprintln!(
565                "[INFO] loop {} fails because {} has multiple preds",
566                pre_header.name,
567                exit_bb.name
568            );
569            return None;
570        }
571
572        // Condition should be `indvar < op`, get `indvar` from condition
573        // TODO-PERF: use induction variable analysis to get `indvar` consistently
574        let Operand::Instruction(cond) = exit.get_operand().first()? else {
575            cprintln!(
576                "[INFO] loop {} fails because {}'s first operand is not inst",
577                pre_header.name,
578                exit.gen_llvm_ir()
579            );
580            return None;
581        };
582        let InstType::ICmp = cond.get_type() else {
583            cprintln!(
584                "[INFO] loop {} fails because {} is not condition",
585                pre_header.name,
586                cond.gen_llvm_ir()
587            );
588            return None;
589        };
590        let icmp = downcast_ref::<ICmp>(cond.as_ref().as_ref());
591        if icmp.op != ICmpOp::Slt {
592            cprintln!(
593                "[INFO] loop {} fails because {} is not slt",
594                pre_header.name,
595                cond.gen_llvm_ir()
596            );
597            return None;
598        }
599        let Operand::Instruction(indvar) = icmp.get_lhs() else {
600            cprintln!(
601                "[INFO] loop {} fails because {}'s lhs is not inst",
602                pre_header.name,
603                cond.gen_llvm_ir()
604            );
605            return None;
606        };
607        let exit_val = icmp.get_rhs().clone();
608
609        // Exit val should be calculated before loop (dominates pre_header)
610        if let Operand::Instruction(inst) = exit_val {
611            if !dom_tree.is_dominate(inst.get_parent_bb().unwrap(), pre_header) {
612                cprintln!(
613                    "[INFO] loop {} fails because {} is not calculated before loop",
614                    pre_header.name,
615                    inst.gen_llvm_ir()
616                );
617                return None;
618            }
619        }
620
621        // Indvar should be `phi [init_val, init_bb], [indvar + delta, next_bb]`
622        // `indvar` should be the only phi in its block (other phi can be non-trivial)
623        // `init_bb` should be `pre_header`
624        // `next_bb` should be in loop
625        if indvar.get_type() != InstType::Phi {
626            cprintln!(
627                "[INFO] loop {} fails because {} is not phi",
628                pre_header.name,
629                indvar.gen_llvm_ir()
630            );
631            return None;
632        }
633        let phi = downcast_ref::<Phi>(indvar.as_ref().as_ref());
634        let inc = phi.get_incoming_values();
635        if inc.len() != 2 {
636            cprintln!(
637                "[INFO] loop {} fails because {}'s incoming value length is not 2",
638                pre_header.name,
639                indvar.gen_llvm_ir()
640            );
641            return None;
642        }
643        let init_val = inc[0].0.clone();
644        let init_bb = lo.pre_header?;
645        if init_bb != inc[0].1 {
646            cprintln!(
647                "[INFO] loop {} fails because {} is not pre_header",
648                pre_header.name,
649                inc[0].1.name.clone()
650            );
651            return None;
652        }
653        let Operand::Instruction(next_val) = inc[1].0 else {
654            cprintln!(
655                "[INFO] loop {} fails because {}'s second incoming value is not inst",
656                pre_header.name,
657                indvar.gen_llvm_ir()
658            );
659            return None;
660        };
661        if next_val.get_type() != InstType::Add {
662            cprintln!(
663                "[INFO] loop {} fails because {} is not add",
664                pre_header.name,
665                next_val.gen_llvm_ir()
666            );
667            return None;
668        }
669        let Operand::Constant(Constant::Int(delta)) = next_val.get_operand()[1] else {
670            cprintln!(
671                "[INFO] loop {} fails because {}'s second operand is not constant",
672                pre_header.name,
673                next_val.gen_llvm_ir()
674            );
675            return None;
676        };
677        let next_bb = inc[1].1;
678        if !lo.is_in_loop(&next_bb) {
679            cprintln!(
680                "[INFO] loop {} fails because {} is not in loop",
681                pre_header.name,
682                next_bb.name
683            );
684            return None;
685        }
686        for inst in indvar.get_parent_bb().unwrap().iter() {
687            if inst.get_type() == InstType::Phi && inst != *indvar {
688                cprintln!(
689                    "[INFO] loop {} fails because {} has multiple phi",
690                    pre_header.name,
691                    inst.gen_llvm_ir()
692                );
693                return None;
694            }
695        }
696
697        // Construct induction variable
698        Some(Self::new(
699            lo, *indvar, exit, delta, init_val, init_bb, exit_val, *exit_bb,
700        ))
701    }
702}