use alloc::vec::Vec;
use anyhow::Result;
use risc0_core::field::{Elem, ExtElem, Field};
use crate::{hal::cpu::SyncSlice, taps::TapSet};
/// Index of the "accum" register group.
pub const REGISTER_GROUP_ACCUM: usize = 0;
/// Index of the "code" register group; [`TapsProvider::code_size`] passes this
/// to `TapSet::group_size`.
pub const REGISTER_GROUP_CODE: usize = 1;
/// Index of the "data" register group.
pub const REGISTER_GROUP_DATA: usize = 2;
/// Accumulator for a weighted running sum over extension-field values.
///
/// `tot` is the sum accumulated so far and `mul` is the weight applied to the
/// next term. In [`PolyExtStep::step`], `True` starts a state at
/// `tot = 0, mul = 1`. Each `AndEqz` adds `mul * value` to `tot` and multiplies
/// `mul` by the `mix` challenge. NOTE(review): the values are presumably
/// constraints expected to equal zero — confirm against the circuit generator.
#[derive(Clone, Copy)]
pub struct MixState<EE: ExtElem> {
    /// Weighted sum accumulated so far.
    pub tot: EE,
    /// Weight the next accumulated term is multiplied by.
    pub mul: EE,
}
/// Callback interface handed to every [`CircuitStep`] method as its `custom`
/// argument.
///
/// NOTE(review): the per-method descriptions below are inferred from names and
/// signatures only; the implementations live outside this file — confirm.
pub trait CircuitStepHandler<E: Elem> {
    /// Handles a named call issued while evaluating step `cycle`.
    ///
    /// Presumably `name` (qualified by `extra`) selects the operation, `args`
    /// are its inputs, and its results are written into `outs` — TODO confirm.
    /// Returns `Err` if the handler fails.
    fn call(
        &mut self,
        cycle: usize,
        name: &str,
        extra: &str,
        args: &[E],
        outs: &mut [E],
    ) -> Result<()>;
    /// NOTE(review): presumably sorts the data identified by `name` — confirm.
    fn sort(&mut self, name: &str);
    /// NOTE(review): presumably computes prefix products over accumulated
    /// data — confirm against implementors.
    fn calc_prefix_products(&mut self);
}
/// Context passed by reference to each [`CircuitStep`] method.
pub struct CircuitStepContext {
    /// NOTE(review): presumably the total number of cycles in the trace — confirm.
    pub size: usize,
    /// NOTE(review): presumably the index of the cycle being evaluated — confirm.
    pub cycle: usize,
}
/// Per-cycle step functions of a circuit.
///
/// Every method has the same shape:
/// - it receives the step context `ctx`;
/// - it receives a [`CircuitStepHandler`] `custom` for callbacks;
/// - it receives the register data as `args` (a slice of [`SyncSlice`]s);
/// - it returns an `E` or an error.
///
/// NOTE(review): the phase descriptions on each method are inferred from the
/// method names only — confirm against the generated implementations.
pub trait CircuitStep<E: Elem> {
    /// Presumably the execution phase of the step.
    fn step_exec<S: CircuitStepHandler<E>>(
        &self,
        ctx: &CircuitStepContext,
        custom: &mut S,
        args: &[SyncSlice<E>],
    ) -> Result<E>;
    /// Presumably verifies byte-related data for the step.
    fn step_verify_bytes<S: CircuitStepHandler<E>>(
        &self,
        ctx: &CircuitStepContext,
        custom: &mut S,
        args: &[SyncSlice<E>],
    ) -> Result<E>;
    /// Presumably verifies memory-related data for the step.
    fn step_verify_mem<S: CircuitStepHandler<E>>(
        &self,
        ctx: &CircuitStepContext,
        custom: &mut S,
        args: &[SyncSlice<E>],
    ) -> Result<E>;
    /// Presumably computes the accumulator columns for the step.
    fn step_compute_accum<S: CircuitStepHandler<E>>(
        &self,
        ctx: &CircuitStepContext,
        custom: &mut S,
        args: &[SyncSlice<E>],
    ) -> Result<E>;
    /// Presumably verifies the accumulator columns for the step.
    fn step_verify_accum<S: CircuitStepHandler<E>>(
        &self,
        ctx: &CircuitStepContext,
        custom: &mut S,
        args: &[SyncSlice<E>],
    ) -> Result<E>;
}
/// Evaluates a circuit polynomial at a single `cycle` from base-field inputs,
/// producing an extension-field value.
///
/// NOTE(review): presumably `steps` is the trace length, `mix` is the challenge
/// used to combine constraints, and `args` are the register groups — confirm
/// against generated implementations.
pub trait PolyFp<F: Field> {
    fn poly_fp(
        &self,
        cycle: usize,
        steps: usize,
        mix: &F::ExtElem,
        args: &[&[F::Elem]],
    ) -> F::ExtElem;
}
/// Evaluates a circuit polynomial from extension-field values `u` and
/// base-field `args`, returning the accumulated [`MixState`].
///
/// NOTE(review): [`PolyExtStepDef::step`] has the same inputs and output and
/// looks like the intended building block for implementations — confirm.
pub trait PolyExt<F: Field> {
    fn poly_ext(
        &self,
        mix: &F::ExtElem,
        u: &[F::ExtElem],
        args: &[&[F::Elem]],
    ) -> MixState<F::ExtElem>;
}
/// Supplies the circuit's static [`TapSet`].
pub trait TapsProvider {
    /// Returns the circuit's tap set.
    fn get_taps(&self) -> &'static TapSet<'static>;
    /// Size of the code register group, as reported by the tap set for
    /// [`REGISTER_GROUP_CODE`].
    fn code_size(&self) -> usize {
        self.get_taps().group_size(REGISTER_GROUP_CODE)
    }
}
/// Compile-time size constants describing a circuit.
pub trait CircuitInfo {
    /// NOTE(review): presumably the number of output elements — confirm.
    const OUTPUT_SIZE: usize;
    /// NOTE(review): presumably the number of mix (challenge) elements — confirm.
    const MIX_SIZE: usize;
}
/// Bundle trait combining [`CircuitInfo`], [`PolyExt`] and [`TapsProvider`];
/// it declares no items of its own.
pub trait CircuitCoreDef<F: Field>: CircuitInfo + PolyExt<F> + TapsProvider {}
/// Bundle trait adding [`CircuitStep`] and [`PolyFp`] on top of
/// [`CircuitCoreDef`]. It also requires `Sync`, so shared references to the
/// circuit can be used from multiple threads. Declares no items of its own.
pub trait CircuitProveDef<F: Field>:
    CircuitStep<F::Elem> + PolyFp<F> + CircuitCoreDef<F> + Sync
{
}
/// Index selecting one of the `args` slices (the first field of
/// [`PolyExtStep::GetGlobal`]).
pub type Arg = usize;
/// Index of a value produced by an earlier [`PolyExtStep`].
/// Depending on the operand, it indexes either the fp stack or the mix stack.
pub type Var = usize;
/// A straight-line program of [`PolyExtStep`]s plus the index of the mix value
/// it yields.
pub struct PolyExtStepDef {
    /// Operations, executed in order.
    pub block: &'static [PolyExtStep],
    /// Mix-stack index of the result. [`PolyExtStepDef::step`] also requires
    /// this to be the last mix slot: the program must push exactly `ret + 1`
    /// mix values.
    pub ret: Var,
}
/// One operation of a [`PolyExtStepDef`] program.
///
/// Each operation pushes exactly one value. It goes onto either the fp stack
/// (extension-field values) or the mix stack ([`MixState`]s). Operand indices
/// refer to values pushed earlier.
pub enum PolyExtStep {
    /// Push the constant, lifted into the extension field, onto the fp stack.
    Const(u32),
    /// Push `u[tap]` onto the fp stack.
    Get(usize),
    /// Push `args[base][offset]`, lifted into the extension field, onto the fp
    /// stack.
    GetGlobal(Arg, usize),
    /// Push `fp[a] + fp[b]` onto the fp stack.
    Add(Var, Var),
    /// Push `fp[a] - fp[b]` onto the fp stack.
    Sub(Var, Var),
    /// Push `fp[a] * fp[b]` onto the fp stack.
    Mul(Var, Var),
    /// Push a fresh mix state `{ tot: 0, mul: 1 }`.
    True,
    /// From mix state `x` and fp value `val`, push
    /// `{ tot: x.tot + x.mul * val, mul: x.mul * mix }`.
    AndEqz(Var, Var),
    /// From mix state `x`, fp value `cond` and mix state `inner`, push
    /// `{ tot: x.tot + cond * inner.tot * x.mul, mul: x.mul * inner.mul }`.
    AndCond(Var, Var, Var),
}
impl PolyExtStep {
    /// Executes this single operation against the two scratch stacks.
    ///
    /// Each variant appends exactly one value. `Const`, `Get`, `GetGlobal`,
    /// `Add`, `Sub` and `Mul` append to `fp_vars`. `True`, `AndEqz` and
    /// `AndCond` append to `mix_vars`. Operand indices refer to values already
    /// pushed by earlier operations.
    ///
    /// * `mix` - factor multiplied into the running `mul` by `AndEqz`.
    /// * `u` - values read by `Get`, indexed by tap.
    /// * `args` - slices read by `GetGlobal(base, offset)` as `args[base][offset]`.
    ///
    /// # Panics
    ///
    /// Panics if an operand index is out of range for the slice or stack it
    /// refers to.
    pub fn step<F: Field>(
        &self,
        fp_vars: &mut Vec<F::ExtElem>,
        mix_vars: &mut Vec<MixState<F::ExtElem>>,
        mix: &F::ExtElem,
        u: &[F::ExtElem],
        args: &[&[F::Elem]],
    ) {
        match *self {
            // ---- operations producing an fp value ----
            Self::Const(value) => {
                let base = F::Elem::from_u64(u64::from(value));
                fp_vars.push(F::ExtElem::from_subfield(&base));
            }
            Self::Get(tap) => fp_vars.push(u[tap]),
            Self::GetGlobal(base, offset) => {
                let global = &args[base][offset];
                fp_vars.push(F::ExtElem::from_subfield(global));
            }
            Self::Add(lhs, rhs) => {
                let sum = fp_vars[lhs] + fp_vars[rhs];
                fp_vars.push(sum);
            }
            Self::Sub(lhs, rhs) => {
                let diff = fp_vars[lhs] - fp_vars[rhs];
                fp_vars.push(diff);
            }
            Self::Mul(lhs, rhs) => {
                let prod = fp_vars[lhs] * fp_vars[rhs];
                fp_vars.push(prod);
            }
            // ---- operations producing a mix state ----
            Self::True => mix_vars.push(MixState {
                tot: F::ExtElem::ZERO,
                mul: F::ExtElem::ONE,
            }),
            Self::AndEqz(prev, value) => {
                let MixState { tot, mul } = mix_vars[prev];
                let constraint = fp_vars[value];
                mix_vars.push(MixState {
                    tot: tot + mul * constraint,
                    mul: mul * *mix,
                });
            }
            Self::AndCond(prev, cond, inner) => {
                let outer = mix_vars[prev];
                let guard = fp_vars[cond];
                let nested = mix_vars[inner];
                mix_vars.push(MixState {
                    tot: outer.tot + guard * nested.tot * outer.mul,
                    mul: outer.mul * nested.mul,
                });
            }
        }
    }
}
impl PolyExtStepDef {
    /// Runs the straight-line program in `block` and returns the mix state in
    /// slot `ret`.
    ///
    /// Every [`PolyExtStep`] pushes exactly one value onto either the fp stack
    /// or the mix stack. The program must push exactly `ret + 1` mix values
    /// (slots `0..=ret`); every other step pushes an fp value. Both scratch
    /// vectors are pre-sized from that split.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `ret + 1` exceeds the number of steps in `block`;
    /// - the program pushes a different number of fp or mix values than that
    ///   split predicts;
    /// - a step references an out-of-range variable, tap, or global.
    pub fn step<F: Field>(
        &self,
        mix: &F::ExtElem,
        u: &[F::ExtElem],
        args: &[&[F::Elem]],
    ) -> MixState<F::ExtElem> {
        // Mix values occupy slots 0..=ret; all remaining steps produce fp values.
        let mix_count = self.ret + 1;
        // Checked so a malformed definition fails with a clear message instead
        // of an arithmetic-overflow panic (debug) or a wrapped, absurdly large
        // capacity request (release).
        let fp_count = self
            .block
            .len()
            .checked_sub(mix_count)
            .expect("PolyExtStepDef::ret is out of range for its block");

        let mut fp_vars = Vec::with_capacity(fp_count);
        let mut mix_vars = Vec::with_capacity(mix_count);
        for op in self.block.iter() {
            op.step::<F>(&mut fp_vars, &mut mix_vars, mix, u, args);
        }
        assert_eq!(
            fp_vars.len(),
            fp_count,
            "Miscalculated capacity for fp_vars"
        );
        assert_eq!(
            mix_vars.len(),
            mix_count,
            "Miscalculated capacity for mix_vars"
        );
        // In bounds: the assertion above guarantees mix_vars.len() == ret + 1.
        mix_vars[self.ret]
    }
}