use crate::dominator_tree::DominatorTree;
use crate::entity::entity_impl;
use crate::entity::SecondaryMap;
use crate::entity::{Keys, PrimaryMap};
use crate::flowgraph::{BlockPredecessor, ControlFlowGraph};
use crate::ir::{Block, Function, Layout};
use crate::packed_option::PackedOption;
use crate::timing;
use alloc::vec::Vec;
use smallvec::{smallvec, SmallVec};
/// Entity reference identifying one loop found by a [`LoopAnalysis`].
///
/// `entity_impl!` makes it usable as a `PrimaryMap`/`SecondaryMap` key (see the
/// `LoopAnalysis` fields); it is presumably displayed with the `loop` prefix
/// (e.g. `loop0`) — NOTE(review): confirm against `entity_impl!`.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct Loop(u32);
entity_impl!(Loop, "loop");
/// Loop forest of a function, built by `LoopAnalysis::compute` from the function's
/// control flow graph and dominator tree.
pub struct LoopAnalysis {
    /// One entry per detected loop, in the order their headers were found
    /// (reverse postorder of the CFG).
    loops: PrimaryMap<Loop, LoopData>,
    /// For each block, the innermost loop containing it, or `None` when the block
    /// is not inside any loop.
    block_loop_map: SecondaryMap<Block, PackedOption<Loop>>,
    /// `true` after `compute` has run; reset to `false` by `clear`.
    valid: bool,
}
/// Per-loop record stored in `LoopAnalysis::loops`.
struct LoopData {
    /// The loop header: a block that dominates one of its own predecessors.
    header: Block,
    /// The immediately enclosing loop, or `None` for an outermost loop.
    parent: PackedOption<Loop>,
    /// Nesting depth; `LoopLevel::invalid()` until levels are assigned.
    level: LoopLevel,
}
/// Loop nesting depth of a loop (or of a block, via `LoopAnalysis::loop_level`).
///
/// Depth 0 means "outside every loop" and outermost loops sit at depth 1. The raw
/// value `u8::MAX` is reserved as the "invalid / not yet computed" marker, so real
/// depths saturate at `u8::MAX - 1`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LoopLevel(u8);

impl LoopLevel {
    /// Raw sentinel meaning "no level assigned yet".
    const INVALID: u8 = u8::MAX;

    /// The depth of code that is not inside any loop.
    pub fn root() -> Self {
        LoopLevel(0)
    }

    /// The depth as a `usize`.
    pub fn level(self) -> usize {
        usize::from(self.0)
    }

    /// The sentinel level used before depths have been computed.
    pub fn invalid() -> Self {
        LoopLevel(Self::INVALID)
    }

    /// One level deeper, saturating at the largest valid depth (`u8::MAX - 1`).
    ///
    /// Calling this on `invalid()` overflows the underlying `u8`.
    pub fn inc(self) -> Self {
        let at_ceiling = self.0 == Self::INVALID - 1;
        if !at_ceiling {
            LoopLevel(self.0 + 1)
        } else {
            self
        }
    }

    /// Builds a level from an arbitrary depth, clamping it to the largest valid depth.
    pub fn clamped(level: usize) -> Self {
        let ceiling = usize::from(Self::INVALID) - 1;
        let raw = u8::try_from(level.min(ceiling)).expect("Clamped value must always convert");
        LoopLevel(raw)
    }
}

impl Default for LoopLevel {
    /// Defaults to [`LoopLevel::invalid`], i.e. "not yet computed".
    fn default() -> Self {
        Self::invalid()
    }
}
impl LoopData {
    /// Creates the record for a loop headed by `header`, nested in `parent` if given.
    ///
    /// The level starts out as `LoopLevel::invalid()` and is filled in later by the
    /// level-assignment pass.
    pub fn new(header: Block, parent: Option<Loop>) -> Self {
        let parent = parent.into();
        let level = LoopLevel::invalid();
        Self {
            header,
            parent,
            level,
        }
    }
}
impl LoopAnalysis {
    pub fn new() -> Self {
        Self {
            valid: false,
            loops: PrimaryMap::new(),
            block_loop_map: SecondaryMap::new(),
        }
    }
    pub fn loops(&self) -> Keys<Loop> {
        self.loops.keys()
    }
    pub fn loop_header(&self, lp: Loop) -> Block {
        self.loops[lp].header
    }
    pub fn loop_parent(&self, lp: Loop) -> Option<Loop> {
        self.loops[lp].parent.expand()
    }
    pub fn innermost_loop(&self, block: Block) -> Option<Loop> {
        self.block_loop_map[block].expand()
    }
    pub fn is_loop_header(&self, block: Block) -> Option<Loop> {
        self.innermost_loop(block)
            .filter(|&lp| self.loop_header(lp) == block)
    }
    pub fn is_in_loop(&self, block: Block, lp: Loop) -> bool {
        let block_loop = self.block_loop_map[block];
        match block_loop.expand() {
            None => false,
            Some(block_loop) => self.is_child_loop(block_loop, lp),
        }
    }
    pub fn is_child_loop(&self, child: Loop, parent: Loop) -> bool {
        let mut finger = Some(child);
        while let Some(finger_loop) = finger {
            if finger_loop == parent {
                return true;
            }
            finger = self.loop_parent(finger_loop);
        }
        false
    }
    pub fn loop_level(&self, block: Block) -> LoopLevel {
        self.innermost_loop(block)
            .map_or(LoopLevel(0), |lp| self.loops[lp].level)
    }
}
impl LoopAnalysis {
    /// (Re)computes the loop forest of `func`, discarding any previous result.
    ///
    /// `cfg` and `domtree` are presumably expected to be up to date for `func`.
    pub fn compute(&mut self, func: &Function, cfg: &ControlFlowGraph, domtree: &DominatorTree) {
        let _tt = timing::loop_analysis();
        self.loops.clear();
        self.block_loop_map.clear();
        self.block_loop_map.resize(func.dfg.num_blocks());

        let layout = &func.layout;
        // Phase 1: one loop per header (a block dominating one of its predecessors).
        self.find_loop_headers(cfg, domtree, layout);
        // Phase 2: flood each loop backwards from its back edges, linking nested loops.
        self.discover_loop_blocks(cfg, domtree, layout);
        // Phase 3: turn parent links into nesting depths.
        self.assign_loop_levels();
        self.valid = true;
    }

    /// Returns whether `compute` has run since construction or the last `clear`.
    pub fn is_valid(&self) -> bool {
        self.valid
    }

    /// Drops all loop information and marks the analysis as not valid.
    pub fn clear(&mut self) {
        self.valid = false;
        self.loops.clear();
        self.block_loop_map.clear();
    }

    /// Creates a loop for every block that is the target of a back edge, i.e. a block
    /// that dominates the branch instruction of one of its predecessors.
    fn find_loop_headers(
        &mut self,
        cfg: &ControlFlowGraph,
        domtree: &DominatorTree,
        layout: &Layout,
    ) {
        // Reverse postorder records every header after the headers of enclosing loops.
        for &block in domtree.cfg_postorder().iter().rev() {
            // One back edge is enough to make `block` a loop header.
            let is_header = cfg
                .pred_iter(block)
                .any(|BlockPredecessor { inst, .. }| domtree.dominates(block, inst, layout));
            if is_header {
                let lp = self.loops.push(LoopData::new(block, None));
                self.block_loop_map[block] = lp.into();
            }
        }
    }

    /// Assigns every block to its innermost loop and links each loop to its parent by
    /// walking backwards (through predecessors) from each loop's back edges.
    fn discover_loop_blocks(
        &mut self,
        cfg: &ControlFlowGraph,
        domtree: &DominatorTree,
        layout: &Layout,
    ) {
        // Worklist of blocks still to visit; the allocation is reused across loops.
        let mut worklist: Vec<Block> = Vec::new();

        // Headers were recorded in reverse postorder, so walking them backwards handles
        // nested loops before the loops that enclose them.
        for lp in self.loops.keys().rev() {
            let header = self.loops[lp].header;

            // Seed the search with the source block of every back edge into `header`.
            for BlockPredecessor {
                block: pred,
                inst: pred_inst,
            } in cfg.pred_iter(header)
            {
                if domtree.dominates(header, pred_inst, layout) {
                    worklist.push(pred);
                }
            }

            while let Some(block) = worklist.pop() {
                let expand_from = match self.block_loop_map[block].expand() {
                    // Unclaimed block: it belongs directly to `lp`.
                    None => {
                        self.block_loop_map[block] = lp.into();
                        Some(block)
                    }
                    // Already claimed by `lp` itself or by a loop nested under it.
                    Some(inner) => self.adopt_outermost(inner, lp),
                };

                if let Some(from) = expand_from {
                    worklist.extend(cfg.pred_iter(from).map(|pred| pred.block));
                }
            }
        }
    }

    /// Climbs the parent chain from `inner` while handling loop `lp`.
    ///
    /// If the chain already reaches `lp`, or its top is `lp` itself, there is nothing
    /// new to explore and `None` is returned. Otherwise the top of the chain is an
    /// outermost loop found inside `lp`: it is re-parented under `lp` and its header is
    /// returned so the search continues from that header's predecessors.
    fn adopt_outermost(&mut self, inner: Loop, lp: Loop) -> Option<Block> {
        let mut top = inner;
        loop {
            match self.loops[top].parent.expand() {
                Some(up) if up == lp => return None,
                Some(up) => top = up,
                None => break,
            }
        }

        if top == lp {
            // e.g. the header of `lp`, or the single block of a one-block loop.
            None
        } else {
            self.loops[top].parent = lp.into();
            Some(self.loops[top].header)
        }
    }

    /// Computes each loop's nesting depth from its parent links: outermost loops get
    /// `root().inc()`, every other loop gets its parent's level plus one.
    fn assign_loop_levels(&mut self) {
        // Chain of loops waiting for their parent's level to be known.
        let mut pending: SmallVec<[Loop; 8]> = smallvec![];
        for start in self.loops.keys() {
            if self.loops[start].level != LoopLevel::invalid() {
                continue;
            }
            pending.push(start);
            while let Some(&current) = pending.last() {
                match self.loops[current].parent.expand() {
                    None => {
                        self.loops[current].level = LoopLevel::root().inc();
                        pending.pop();
                    }
                    Some(parent) => {
                        let parent_level = self.loops[parent].level;
                        if parent_level == LoopLevel::invalid() {
                            // Resolve the parent first, then come back to `current`.
                            pending.push(parent);
                        } else {
                            self.loops[current].level = parent_level.inc();
                            pending.pop();
                        }
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::cursor::{Cursor, FuncCursor};
    use crate::dominator_tree::DominatorTree;
    use crate::flowgraph::ControlFlowGraph;
    use crate::ir::{types, Function, InstBuilder};
    use crate::loop_analysis::{Loop, LoopAnalysis};
    use alloc::vec::Vec;

    /// Two nested loops: block0..block3 form the outer loop, block1..block2 the inner
    /// one, and block4 is the exit.
    #[test]
    fn nested_loops_detection() {
        let mut func = Function::new();
        let block0 = func.dfg.make_block();
        let block1 = func.dfg.make_block();
        let block2 = func.dfg.make_block();
        let block3 = func.dfg.make_block();
        let block4 = func.dfg.make_block();
        let cond = func.dfg.append_block_param(block0, types::I32);
        {
            let mut cur = FuncCursor::new(&mut func);
            cur.insert_block(block0);
            cur.ins().jump(block1, &[]);
            cur.insert_block(block1);
            cur.ins().jump(block2, &[]);
            cur.insert_block(block2);
            cur.ins().brif(cond, block1, &[], block3, &[]);
            cur.insert_block(block3);
            cur.ins().brif(cond, block0, &[], block4, &[]);
            cur.insert_block(block4);
            cur.ins().return_(&[]);
        }
        let mut loop_analysis = LoopAnalysis::new();
        let mut cfg = ControlFlowGraph::new();
        let mut domtree = DominatorTree::new();
        cfg.compute(&func);
        domtree.compute(&func, &cfg);
        loop_analysis.compute(&func, &cfg, &domtree);
        let loops = loop_analysis.loops().collect::<Vec<Loop>>();
        assert_eq!(loops.len(), 2);
        assert_eq!(loop_analysis.loop_header(loops[0]), block0);
        assert_eq!(loop_analysis.loop_header(loops[1]), block1);
        assert_eq!(loop_analysis.loop_parent(loops[1]), Some(loops[0]));
        assert_eq!(loop_analysis.loop_parent(loops[0]), None);
        assert_eq!(loop_analysis.is_loop_header(block0), Some(loops[0]));
        assert_eq!(loop_analysis.is_loop_header(block1), Some(loops[1]));
        assert_eq!(loop_analysis.is_loop_header(block2), None);
        assert!(loop_analysis.is_in_loop(block0, loops[0]));
        assert!(!loop_analysis.is_in_loop(block0, loops[1]));
        assert!(loop_analysis.is_in_loop(block1, loops[1]));
        assert!(loop_analysis.is_in_loop(block1, loops[0]));
        assert!(loop_analysis.is_in_loop(block2, loops[1]));
        assert!(loop_analysis.is_in_loop(block2, loops[0]));
        assert!(loop_analysis.is_in_loop(block3, loops[0]));
        assert!(!loop_analysis.is_in_loop(block3, loops[1]));
        // The exit block is outside every loop.
        assert!(!loop_analysis.is_in_loop(block4, loops[0]));
        assert!(!loop_analysis.is_in_loop(block4, loops[1]));
        assert_eq!(loop_analysis.loop_level(block0).level(), 1);
        assert_eq!(loop_analysis.loop_level(block1).level(), 2);
        assert_eq!(loop_analysis.loop_level(block2).level(), 2);
        assert_eq!(loop_analysis.loop_level(block3).level(), 1);
        assert_eq!(loop_analysis.loop_level(block4).level(), 0);
    }

    /// An outer loop headed by block0 containing two sibling inner loops (block1..block2
    /// and block3..block4); block6 is the exit.
    #[test]
    fn complex_loop_detection() {
        let mut func = Function::new();
        let block0 = func.dfg.make_block();
        let block1 = func.dfg.make_block();
        let block2 = func.dfg.make_block();
        let block3 = func.dfg.make_block();
        let block4 = func.dfg.make_block();
        let block5 = func.dfg.make_block();
        let block6 = func.dfg.make_block();
        let cond = func.dfg.append_block_param(block0, types::I32);
        {
            let mut cur = FuncCursor::new(&mut func);
            cur.insert_block(block0);
            cur.ins().brif(cond, block1, &[], block3, &[]);
            cur.insert_block(block1);
            cur.ins().jump(block2, &[]);
            cur.insert_block(block2);
            cur.ins().brif(cond, block1, &[], block5, &[]);
            cur.insert_block(block3);
            cur.ins().jump(block4, &[]);
            cur.insert_block(block4);
            cur.ins().brif(cond, block3, &[], block5, &[]);
            cur.insert_block(block5);
            cur.ins().brif(cond, block0, &[], block6, &[]);
            cur.insert_block(block6);
            cur.ins().return_(&[]);
        }
        let mut loop_analysis = LoopAnalysis::new();
        let cfg = ControlFlowGraph::with_function(&func);
        let domtree = DominatorTree::with_function(&func, &cfg);
        loop_analysis.compute(&func, &cfg, &domtree);
        let loops = loop_analysis.loops().collect::<Vec<Loop>>();
        assert_eq!(loops.len(), 3);
        assert_eq!(loop_analysis.loop_header(loops[0]), block0);
        assert_eq!(loop_analysis.loop_header(loops[1]), block3);
        assert_eq!(loop_analysis.loop_header(loops[2]), block1);
        assert_eq!(loop_analysis.loop_parent(loops[1]), Some(loops[0]));
        assert_eq!(loop_analysis.loop_parent(loops[2]), Some(loops[0]));
        assert_eq!(loop_analysis.loop_parent(loops[0]), None);
        assert_eq!(loop_analysis.is_loop_header(block3), Some(loops[1]));
        assert_eq!(loop_analysis.is_loop_header(block5), None);
        assert!(loop_analysis.is_in_loop(block0, loops[0]));
        assert!(loop_analysis.is_in_loop(block1, loops[2]));
        assert!(loop_analysis.is_in_loop(block2, loops[2]));
        assert!(loop_analysis.is_in_loop(block3, loops[1]));
        assert!(loop_analysis.is_in_loop(block4, loops[1]));
        assert!(loop_analysis.is_in_loop(block5, loops[0]));
        // Sibling inner loops do not contain each other's blocks.
        assert!(!loop_analysis.is_in_loop(block1, loops[1]));
        assert!(!loop_analysis.is_in_loop(block3, loops[2]));
        // The exit block is outside every loop.
        assert!(!loop_analysis.is_in_loop(block6, loops[0]));
        assert_eq!(loop_analysis.loop_level(block0).level(), 1);
        assert_eq!(loop_analysis.loop_level(block1).level(), 2);
        assert_eq!(loop_analysis.loop_level(block2).level(), 2);
        assert_eq!(loop_analysis.loop_level(block3).level(), 2);
        assert_eq!(loop_analysis.loop_level(block4).level(), 2);
        assert_eq!(loop_analysis.loop_level(block5).level(), 1);
        assert_eq!(loop_analysis.loop_level(block6).level(), 0);
    }
}