use crate::bforest;
use crate::entity::SecondaryMap;
use crate::ir::instructions::BranchInfo;
use crate::ir::{Ebb, Function, Inst};
use crate::timing;
use core::mem;
/// An `(ebb, inst)` pair identifying an instruction together with the EBB it is associated with.
///
/// This is the item type yielded by `PredIter`: for each incoming CFG edge it names the EBB
/// the edge originates from and the branch instruction that creates the edge.
#[derive(Debug, PartialEq, Eq)]
pub struct BasicBlock {
    /// The EBB half of the pair; for a predecessor, the EBB the edge comes from.
    pub ebb: Ebb,
    /// The instruction half of the pair; for a predecessor, the branch that forms the edge.
    pub inst: Inst,
}
impl BasicBlock {
    
    pub fn new(ebb: Ebb, inst: Inst) -> Self {
        Self { ebb, inst }
    }
}
/// Per-EBB node of the control flow graph, holding the EBB's incoming and outgoing edges.
///
/// Both collections are `bforest` containers, so every read or write on them must be given
/// the matching forest owned by `ControlFlowGraph` (`pred_forest` / `succ_forest`).
#[derive(Clone, Default)]
struct CFGNode {
    /// Incoming edges of this EBB.
    ///
    /// Each entry maps a branch instruction that targets this EBB to the EBB that the edge
    /// was recorded as coming from (see `ControlFlowGraph::add_edge`). Because the map is
    /// keyed by instruction, several branches in the same source EBB produce several entries.
    pub predecessors: bforest::Map<Inst, Ebb>,
    /// Outgoing edges of this EBB: the set of EBBs that its branch instructions can reach.
    /// Being a set, a destination reached by several branches is recorded once.
    pub successors: bforest::Set<Ebb>,
}
/// Control flow graph of a function, tracking predecessor and successor edges between EBBs.
///
/// The graph is populated by `compute` (or `with_function`) and can be refreshed for a single
/// EBB with `recompute_ebb`.
pub struct ControlFlowGraph {
    /// One `CFGNode` per EBB, indexed by the EBB reference.
    data: SecondaryMap<Ebb, CFGNode>,
    /// Forest passed to every operation on the nodes' `predecessors` maps.
    pred_forest: bforest::MapForest<Inst, Ebb>,
    /// Forest passed to every operation on the nodes' `successors` sets.
    succ_forest: bforest::SetForest<Ebb>,
    /// `true` once `compute` has run; reset to `false` by `clear`.
    valid: bool,
}
impl ControlFlowGraph {
    /// Creates an empty control flow graph.
    ///
    /// The new graph is flagged as not valid; call `compute` to populate it from a function.
    pub fn new() -> Self {
        Self {
            valid: false,
            data: SecondaryMap::new(),
            pred_forest: bforest::MapForest::new(),
            succ_forest: bforest::SetForest::new(),
        }
    }

    /// Drops all recorded nodes and edges and flags the graph as not valid.
    pub fn clear(&mut self) {
        self.valid = false;
        self.data.clear();
        self.pred_forest.clear();
        self.succ_forest.clear();
    }

    /// Creates a control flow graph and immediately computes it for `func`.
    pub fn with_function(func: &Function) -> Self {
        let mut graph = Self::new();
        graph.compute(func);
        graph
    }

    /// Computes the control flow graph of `func` from scratch.
    ///
    /// Any previous contents are cleared first. Every EBB in the function's layout is then
    /// scanned for branches, and the graph is flagged as valid afterwards.
    pub fn compute(&mut self, func: &Function) {
        // Keep the value returned by `timing::flowgraph()` bound until the function returns
        // (presumably a token that attributes the elapsed time to this pass).
        let _timing_token = timing::flowgraph();

        self.clear();
        self.data.resize(func.dfg.num_ebbs());

        for ebb in func.layout.ebbs() {
            self.compute_ebb(func, ebb);
        }

        self.valid = true;
    }

    /// Records the outgoing edges of `ebb` by inspecting each of its instructions.
    fn compute_ebb(&mut self, func: &Function, ebb: Ebb) {
        for inst in func.layout.ebb_insts(ebb) {
            let info = func.dfg.analyze_branch(inst);
            match info {
                BranchInfo::NotABranch => {}
                BranchInfo::SingleDest(target, _) => self.add_edge(ebb, inst, target),
                BranchInfo::Table(table, direct) => {
                    // A table branch may carry one destination of its own in addition to
                    // the jump table entries; record it first when present.
                    if let Some(target) = direct {
                        self.add_edge(ebb, inst, target);
                    }
                    for &target in func.jump_tables[table].iter() {
                        self.add_edge(ebb, inst, target);
                    }
                }
            }
        }
    }

    /// Removes every outgoing edge of `ebb`, on both ends.
    ///
    /// Each current successor drops the predecessor entries that name `ebb`, and the
    /// successor set of `ebb` is left empty.
    fn invalidate_ebb_successors(&mut self, ebb: Ebb) {
        // Move the successor set out of its node, leaving a default (empty) set behind, so
        // that `self.data` can be mutated while the set is being iterated.
        let mut outgoing = mem::replace(&mut self.data[ebb].successors, Default::default());

        for target in outgoing.iter(&self.succ_forest) {
            let node = &mut self.data[target];
            node.predecessors
                .retain(&mut self.pred_forest, |_, source| *source != ebb);
        }

        // The detached set is no longer reachable from any node; clear it against its forest.
        outgoing.clear(&mut self.succ_forest);
    }

    /// Recomputes the outgoing edges of a single EBB after its branches have changed.
    ///
    /// The existing edges leaving `ebb` are removed and then rebuilt from the current
    /// instructions of `ebb` in `func`. Edges leaving other EBBs are not re-examined.
    ///
    /// In debug builds this asserts that the graph is valid.
    pub fn recompute_ebb(&mut self, func: &Function, ebb: Ebb) {
        debug_assert!(self.is_valid());
        self.invalidate_ebb_successors(ebb);
        self.compute_ebb(func, ebb);
    }

    /// Records an edge from `from` to `to` created by the branch `from_inst`.
    ///
    /// `to` is added to the successor set of `from`, and `from_inst -> from` is added to
    /// the predecessor map of `to`.
    fn add_edge(&mut self, from: Ebb, from_inst: Inst, to: Ebb) {
        let source = &mut self.data[from];
        source.successors.insert(to, &mut self.succ_forest, &());

        let target = &mut self.data[to];
        target
            .predecessors
            .insert(from_inst, from, &mut self.pred_forest, &());
    }

    /// Returns an iterator over the predecessors of `ebb`, each yielded as a `BasicBlock`.
    ///
    /// NOTE(review): unlike `succ_iter`, this does not assert that the graph is valid;
    /// confirm whether that difference is intentional.
    pub fn pred_iter(&self, ebb: Ebb) -> PredIter {
        let node = &self.data[ebb];
        PredIter(node.predecessors.iter(&self.pred_forest))
    }

    /// Returns an iterator over the successor EBBs of `ebb`.
    ///
    /// In debug builds this asserts that the graph is valid.
    pub fn succ_iter(&self, ebb: Ebb) -> SuccIter {
        debug_assert!(self.is_valid());
        let node = &self.data[ebb];
        node.successors.iter(&self.succ_forest)
    }

    /// Reports whether the graph has been computed.
    ///
    /// The flag is set at the end of `compute` and reset by `clear`. Nothing in this type
    /// tracks later modifications of the function, so a `true` result only means that
    /// `compute` has run since the last `clear`.
    pub fn is_valid(&self) -> bool {
        self.valid
    }
}
/// Iterator over the predecessors of an EBB, returned by `ControlFlowGraph::pred_iter`.
///
/// Wraps an iterator over a node's predecessor map and yields each entry as a `BasicBlock`.
pub struct PredIter<'a>(bforest::MapIter<'a, Inst, Ebb>);
impl<'a> Iterator for PredIter<'a> {
    type Item = BasicBlock;

    /// Advances the underlying map iterator and repackages its `(inst, ebb)` entry as a
    /// `BasicBlock`; returns `None` once the map is exhausted.
    fn next(&mut self) -> Option<BasicBlock> {
        let (inst, ebb) = self.0.next()?;
        Some(BasicBlock::new(ebb, inst))
    }
}
/// Iterator over the successor EBBs of an EBB, returned by `ControlFlowGraph::succ_iter`.
pub type SuccIter<'a> = bforest::SetIter<'a, Ebb>;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cursor::{Cursor, FuncCursor};
    use crate::ir::{types, Function, InstBuilder};
    use alloc::vec::Vec;

    /// Computing the CFG of a function with no EBBs must not panic.
    #[test]
    fn empty() {
        let func = Function::new();
        ControlFlowGraph::with_function(&func);
    }

    /// EBBs that contain no instructions have neither predecessors nor successors.
    #[test]
    fn no_predecessors() {
        let mut func = Function::new();
        let ebb0 = func.dfg.make_ebb();
        let ebb1 = func.dfg.make_ebb();
        let ebb2 = func.dfg.make_ebb();
        func.layout.append_ebb(ebb0);
        func.layout.append_ebb(ebb1);
        func.layout.append_ebb(ebb2);

        let cfg = ControlFlowGraph::with_function(&func);

        let mut fun_ebbs = func.layout.ebbs();
        for ebb in func.layout.ebbs() {
            assert_eq!(ebb, fun_ebbs.next().unwrap());
            assert_eq!(cfg.pred_iter(ebb).count(), 0);
            assert_eq!(cfg.succ_iter(ebb).count(), 0);
        }
    }

    /// Edges are recorded for conditional branches and jumps, and `recompute_ebb` updates
    /// them after the instructions of an EBB have been replaced.
    #[test]
    fn branches_and_jumps() {
        let mut func = Function::new();
        let ebb0 = func.dfg.make_ebb();
        let cond = func.dfg.append_ebb_param(ebb0, types::I32);
        let ebb1 = func.dfg.make_ebb();
        let ebb2 = func.dfg.make_ebb();

        let br_ebb0_ebb2;
        let br_ebb1_ebb1;
        let jmp_ebb0_ebb1;
        let jmp_ebb1_ebb2;

        {
            let mut cur = FuncCursor::new(&mut func);

            cur.insert_ebb(ebb0);
            br_ebb0_ebb2 = cur.ins().brnz(cond, ebb2, &[]);
            jmp_ebb0_ebb1 = cur.ins().jump(ebb1, &[]);

            cur.insert_ebb(ebb1);
            br_ebb1_ebb1 = cur.ins().brnz(cond, ebb1, &[]);
            jmp_ebb1_ebb2 = cur.ins().jump(ebb2, &[]);

            cur.insert_ebb(ebb2);
        }

        let mut cfg = ControlFlowGraph::with_function(&func);

        {
            let ebb0_predecessors = cfg.pred_iter(ebb0).collect::<Vec<_>>();
            let ebb1_predecessors = cfg.pred_iter(ebb1).collect::<Vec<_>>();
            let ebb2_predecessors = cfg.pred_iter(ebb2).collect::<Vec<_>>();

            let ebb0_successors = cfg.succ_iter(ebb0).collect::<Vec<_>>();
            let ebb1_successors = cfg.succ_iter(ebb1).collect::<Vec<_>>();
            let ebb2_successors = cfg.succ_iter(ebb2).collect::<Vec<_>>();

            assert_eq!(ebb0_predecessors.len(), 0);
            assert_eq!(ebb1_predecessors.len(), 2);
            assert_eq!(ebb2_predecessors.len(), 2);

            assert!(ebb1_predecessors.contains(&BasicBlock::new(ebb0, jmp_ebb0_ebb1)));
            assert!(ebb1_predecessors.contains(&BasicBlock::new(ebb1, br_ebb1_ebb1)));
            assert!(ebb2_predecessors.contains(&BasicBlock::new(ebb0, br_ebb0_ebb2)));
            assert!(ebb2_predecessors.contains(&BasicBlock::new(ebb1, jmp_ebb1_ebb2)));

            assert_eq!(ebb0_successors, [ebb1, ebb2]);
            assert_eq!(ebb1_successors, [ebb1, ebb2]);
            assert_eq!(ebb2_successors, []);
        }

        // Rewrite ebb0 so that its conditional branch targets ebb1 and its jump becomes a
        // return, then refresh only ebb0's edges. `recompute_ebb` takes `&Function`, so a
        // shared borrow is sufficient.
        func.dfg.replace(br_ebb0_ebb2).brnz(cond, ebb1, &[]);
        func.dfg.replace(jmp_ebb0_ebb1).return_(&[]);
        cfg.recompute_ebb(&func, ebb0);
        // Same instruction as before; the alias reflects its new destination.
        let br_ebb0_ebb1 = br_ebb0_ebb2;

        {
            let ebb0_predecessors = cfg.pred_iter(ebb0).collect::<Vec<_>>();
            let ebb1_predecessors = cfg.pred_iter(ebb1).collect::<Vec<_>>();
            let ebb2_predecessors = cfg.pred_iter(ebb2).collect::<Vec<_>>();

            let ebb0_successors = cfg.succ_iter(ebb0).collect::<Vec<_>>();
            let ebb1_successors = cfg.succ_iter(ebb1).collect::<Vec<_>>();
            let ebb2_successors = cfg.succ_iter(ebb2).collect::<Vec<_>>();

            assert_eq!(ebb0_predecessors.len(), 0);
            assert_eq!(ebb1_predecessors.len(), 2);
            assert_eq!(ebb2_predecessors.len(), 1);

            assert!(ebb1_predecessors.contains(&BasicBlock::new(ebb0, br_ebb0_ebb1)));
            assert!(ebb1_predecessors.contains(&BasicBlock::new(ebb1, br_ebb1_ebb1)));
            // The edge ebb0 -> ebb2 must be gone after the branch was retargeted.
            assert!(!ebb2_predecessors.contains(&BasicBlock::new(ebb0, br_ebb0_ebb2)));
            assert!(ebb2_predecessors.contains(&BasicBlock::new(ebb1, jmp_ebb1_ebb2)));

            assert_eq!(ebb0_successors, [ebb1]);
            assert_eq!(ebb1_successors, [ebb1, ebb2]);
            assert_eq!(ebb2_successors, []);
        }
    }
}