use std::{collections::BTreeSet, fmt};
use cranelift_entity::PrimaryMap;
use miden_assembly::ast;
use midenc_hir::{diagnostics::Span, formatter::PrettyPrint, FunctionIdent, Ident};
use smallvec::smallvec;
use super::*;
use crate::InstructionPointer;
#[derive(Debug, Clone)]
pub struct Region {
    pub body: BlockId,
    pub blocks: PrimaryMap<BlockId, Block>,
}
impl Default for Region {
    fn default() -> Self {
        let mut blocks = PrimaryMap::<BlockId, Block>::default();
        let id = blocks.next_key();
        let body = blocks.push(Block {
            id,
            ops: smallvec![],
        });
        Self { body, blocks }
    }
}
impl Region {
    #[inline(always)]
    pub const fn id(&self) -> BlockId {
        self.body
    }
    #[inline]
    pub fn block(&self, id: BlockId) -> &Block {
        &self.blocks[id]
    }
    #[inline]
    pub fn block_mut(&mut self, id: BlockId) -> &mut Block {
        &mut self.blocks[id]
    }
    pub fn get(&self, ip: InstructionPointer) -> Option<Span<Op>> {
        self.blocks[ip.block].ops.get(ip.index).copied()
    }
    pub fn create_block(&mut self) -> BlockId {
        let id = self.blocks.next_key();
        self.blocks.push(Block {
            id,
            ops: smallvec![],
        });
        id
    }
    pub fn display<'a, 'b: 'a>(
        &'b self,
        function: Option<FunctionIdent>,
        imports: &'b ModuleImportInfo,
    ) -> DisplayRegion<'a> {
        DisplayRegion {
            region: self,
            function,
            imports,
        }
    }
    pub fn to_block(
        &self,
        imports: &ModuleImportInfo,
        locals: &BTreeSet<FunctionIdent>,
    ) -> ast::Block {
        emit_block(self.body, &self.blocks, imports, locals)
    }
    pub fn from_block(current_module: Ident, code: &ast::Block) -> Self {
        let mut region = Self::default();
        let body = region.body;
        import_block(current_module, &mut region, body, code);
        region
    }
}
impl core::ops::Index<InstructionPointer> for Region {
    type Output = Op;
    #[inline]
    fn index(&self, ip: InstructionPointer) -> &Self::Output {
        &self.blocks[ip.block].ops[ip.index]
    }
}
#[doc(hidden)]
pub struct DisplayRegion<'a> {
    region: &'a Region,
    function: Option<FunctionIdent>,
    imports: &'a ModuleImportInfo,
}
impl<'a> midenc_hir::formatter::PrettyPrint for DisplayRegion<'a> {
    fn render(&self) -> midenc_hir::formatter::Document {
        use midenc_hir::DisplayMasmBlock;
        let block = DisplayMasmBlock::new(
            self.function,
            Some(self.imports),
            &self.region.blocks,
            self.region.body,
        );
        block.render()
    }
}
impl<'a> fmt::Display for DisplayRegion<'a> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.pretty_print(f)
    }
}
fn import_block(
    current_module: Ident,
    region: &mut Region,
    current_block_id: BlockId,
    block: &ast::Block,
) {
    for op in block.iter() {
        match op {
            ast::Op::Inst(ix) => {
                let span = ix.span();
                let current_block = region.block_mut(current_block_id);
                let ops = Op::from_masm(current_module, (**ix).clone());
                current_block.extend(ops.into_iter().map(|op| Span::new(span, op)));
            }
            ast::Op::If {
                span,
                ref then_blk,
                ref else_blk,
                ..
            } => {
                let then_blk_id = region.create_block();
                let else_blk_id = region.create_block();
                import_block(current_module, region, then_blk_id, then_blk);
                import_block(current_module, region, else_blk_id, else_blk);
                region.block_mut(current_block_id).push(Op::If(then_blk_id, else_blk_id), *span);
            }
            ast::Op::Repeat {
                span,
                count,
                ref body,
                ..
            } => {
                let body_blk = region.create_block();
                import_block(current_module, region, body_blk, body);
                let count = u16::try_from(*count).unwrap_or_else(|_| {
                    panic!("invalid repeat count: expected {count} to be less than 255")
                });
                region.block_mut(current_block_id).push(Op::Repeat(count, body_blk), *span);
            }
            ast::Op::While { span, ref body, .. } => {
                let body_blk = region.create_block();
                import_block(current_module, region, body_blk, body);
                region.block_mut(current_block_id).push(Op::While(body_blk), *span);
            }
        }
    }
}
#[allow(clippy::only_used_in_recursion)]
fn emit_block(
    block_id: BlockId,
    blocks: &PrimaryMap<BlockId, Block>,
    imports: &ModuleImportInfo,
    locals: &BTreeSet<FunctionIdent>,
) -> ast::Block {
    let current_block = &blocks[block_id];
    let mut ops = Vec::with_capacity(current_block.ops.len());
    for op in current_block.ops.iter().copied() {
        let span = op.span();
        match op.into_inner() {
            Op::If(then_blk, else_blk) => {
                let then_blk = emit_block(then_blk, blocks, imports, locals);
                let else_blk = emit_block(else_blk, blocks, imports, locals);
                ops.push(ast::Op::If {
                    span,
                    then_blk,
                    else_blk,
                });
            }
            Op::While(blk) => {
                let body = emit_block(blk, blocks, imports, locals);
                ops.push(ast::Op::While { span, body });
            }
            Op::Repeat(n, blk) => {
                let body = emit_block(blk, blocks, imports, locals);
                ops.push(ast::Op::Repeat {
                    span,
                    count: n as u32,
                    body,
                });
            }
            op => {
                ops.extend(
                    op.into_masm(imports, locals)
                        .into_iter()
                        .map(|inst| ast::Op::Inst(Span::new(span, inst))),
                );
            }
        }
    }
    ast::Block::new(Default::default(), ops)
}