use snarkvm_fields::PrimeField;
use crate::snark::varuna::{witness_label, CircuitId, SNARKMode};
use itertools::Itertools;
use std::collections::BTreeMap;
/// Combiners stored per circuit inside [`FirstMessage::batch_combiners`].
///
/// NOTE(review): by their names these look like random challenges used to
/// batch-combine work across circuits (`circuit_combiner`) and across the
/// instances of a single circuit (`instance_combiners`) — confirm against the
/// code that samples and consumes them.
#[derive(Clone, Debug)]
pub(crate) struct BatchCombiners<F> {
    /// Combiner associated with the circuit as a whole.
    pub(crate) circuit_combiner: F,
    /// Combiners for the individual instances of the circuit.
    /// Presumably one entry per instance (i.e. `len() == batch_size`) — TODO confirm.
    pub(crate) instance_combiners: Vec<F>,
}
/// Message for the first round of the protocol.
///
/// Holds one [`BatchCombiners`] per circuit, keyed by [`CircuitId`].
/// NOTE(review): presumably this is the verifier's first challenge message — confirm.
#[derive(Clone, Debug)]
pub struct FirstMessage<F: PrimeField> {
    /// Combiners for each circuit participating in the batch.
    pub(crate) batch_combiners: BTreeMap<CircuitId, BatchCombiners<F>>,
}
/// Message for the second round of the protocol.
#[derive(Copy, Clone, Debug)]
pub struct SecondMessage<F> {
    /// Challenge `alpha`; [`QuerySet::new`] uses it as the query value for the
    /// rowcheck zerocheck (labelled `"alpha"`).
    pub alpha: F,
    /// Challenge associated with the `B` matrix.
    /// NOTE(review): presumably a weight for combining the B-matrix check — confirm.
    pub eta_b: F,
    /// Challenge associated with the `C` matrix.
    /// NOTE(review): presumably a weight for combining the C-matrix check — confirm.
    pub eta_c: F,
}
/// Message for the third round of the protocol.
#[derive(Copy, Clone, Debug)]
pub struct ThirdMessage<F> {
    /// Challenge `beta`; [`QuerySet::new`] uses it as the query value for `g_1`
    /// and the lineval sumcheck (labelled `"beta"`).
    pub beta: F,
}
/// Message for the fourth round of the protocol.
///
/// The three vectors are expected to have equal length: iterating the message
/// via `FourthMessage::into_iter` uses `zip_eq` and panics otherwise.
#[derive(Clone, Debug)]
pub struct FourthMessage<F> {
    /// Values associated with the `A` matrix.
    /// NOTE(review): presumably one entry per circuit in the batch — confirm.
    pub delta_a: Vec<F>,
    /// Values associated with the `B` matrix (same length as `delta_a`).
    pub delta_b: Vec<F>,
    /// Values associated with the `C` matrix (same length as `delta_a`).
    pub delta_c: Vec<F>,
}
impl<F: PrimeField> FourthMessage<F> {
    /// Consumes the message and yields its values interleaved index by index:
    /// `delta_a[0], delta_b[0], delta_c[0], delta_a[1], delta_b[1], ...`.
    ///
    /// # Panics
    ///
    /// The returned iterator panics while being consumed if `delta_a`,
    /// `delta_b` and `delta_c` do not all have the same length
    /// (this is the contract of `Itertools::zip_eq`).
    pub fn into_iter(self) -> impl Iterator<Item = F> {
        let Self { delta_a, delta_b, delta_c } = self;
        let triples = delta_a.into_iter().zip_eq(delta_b).zip_eq(delta_c);
        triples.flat_map(|((a, b), c)| [a, b, c])
    }
}
/// The named queries used for the opening of the committed polynomials.
///
/// Each `*_query` field is a `(label, value)` pair; [`QuerySet::new`] sets the
/// label to the name of the challenge (`"alpha"`, `"beta"` or `"gamma"`) and the
/// value to that challenge. [`QuerySet::to_set`] flattens the struct into a
/// `crate::polycommit::sonic_pc::QuerySet`.
#[derive(Clone, Debug)]
pub struct QuerySet<F: PrimeField> {
    /// `batch_size` of every circuit, keyed by circuit id. `to_set` only uses
    /// the keys, to emit the per-circuit `g_a`/`g_b`/`g_c` entries.
    pub batch_sizes: BTreeMap<CircuitId, usize>,
    /// Query for the rowcheck zerocheck (built from `alpha` by `new`).
    pub rowcheck_zerocheck_query: (String, F),
    /// Query for `g_1` (built from `beta` by `new`).
    pub g_1_query: (String, F),
    /// Query for the lineval sumcheck (built from `beta` by `new`).
    pub lineval_sumcheck_query: (String, F),
    /// Query for every circuit's `g_a` (built from `gamma` by `new`).
    pub g_a_query: (String, F),
    /// Query for every circuit's `g_b` (built from `gamma` by `new`).
    pub g_b_query: (String, F),
    /// Query for every circuit's `g_c` (built from `gamma` by `new`).
    pub g_c_query: (String, F),
    /// Query for the matrix sumcheck (built from `gamma` by `new`).
    pub matrix_sumcheck_query: (String, F),
}
impl<F: PrimeField> QuerySet<F> {
    /// Builds the query set from `state`.
    ///
    /// Reads three challenges from `state`:
    /// - `alpha` from the second round message, used for the rowcheck zerocheck;
    /// - `beta` from the third round message, used for `g_1` and the lineval sumcheck;
    /// - `gamma`, used for `g_a`/`g_b`/`g_c` and the matrix sumcheck.
    ///
    /// Each query is labelled with the challenge's name. `batch_sizes` records the
    /// `batch_size` of every entry in `state.circuit_specific_states`.
    ///
    /// # Panics
    ///
    /// Panics if `state` does not yet hold the second round message, the third
    /// round message, or `gamma`. The query set may only be built after all three
    /// challenges have been recorded.
    pub fn new<SM: SNARKMode>(state: &super::State<F, SM>) -> Self {
        let alpha = state
            .second_round_message
            .as_ref()
            .expect("QuerySet::new called before the second round message (alpha) was set")
            .alpha;
        let beta = state
            .third_round_message
            .expect("QuerySet::new called before the third round message (beta) was set")
            .beta;
        let gamma = state.gamma.expect("QuerySet::new called before gamma was set");
        Self {
            batch_sizes: state.circuit_specific_states.iter().map(|(c, s)| (*c, s.batch_size)).collect(),
            rowcheck_zerocheck_query: ("alpha".into(), alpha),
            g_1_query: ("beta".into(), beta),
            lineval_sumcheck_query: ("beta".into(), beta),
            g_a_query: ("gamma".into(), gamma),
            g_b_query: ("gamma".into(), gamma),
            g_c_query: ("gamma".into(), gamma),
            matrix_sumcheck_query: ("gamma".into(), gamma),
        }
    }

    /// Flattens this struct into a `crate::polycommit::sonic_pc::QuerySet`.
    ///
    /// For every circuit id in `batch_sizes`, inserts one entry each for `g_a`,
    /// `g_b` and `g_c`. These are labelled `witness_label(circuit_id, name, 0)` and
    /// paired with the corresponding `g_*_query`. It then adds the circuit-independent
    /// entries `g_1`, `rowcheck_zerocheck`, `lineval_sumcheck` and `matrix_sumcheck`.
    pub fn to_set(&self) -> crate::polycommit::sonic_pc::QuerySet<F> {
        let mut query_set = crate::polycommit::sonic_pc::QuerySet::new();
        // g_a/g_b/g_c are queried separately for each circuit, so their labels
        // must be made circuit-specific.
        for &circuit_id in self.batch_sizes.keys() {
            query_set.insert((witness_label(circuit_id, "g_a", 0), self.g_a_query.clone()));
            query_set.insert((witness_label(circuit_id, "g_b", 0), self.g_b_query.clone()));
            query_set.insert((witness_label(circuit_id, "g_c", 0), self.g_c_query.clone()));
        }
        query_set.insert(("g_1".into(), self.g_1_query.clone()));
        query_set.insert(("rowcheck_zerocheck".into(), self.rowcheck_zerocheck_query.clone()));
        query_set.insert(("lineval_sumcheck".into(), self.lineval_sumcheck_query.clone()));
        query_set.insert(("matrix_sumcheck".into(), self.matrix_sumcheck_query.clone()));
        query_set
    }
}
/// The three challenge points `alpha`, `beta` and `gamma`, grouped together.
#[derive(Debug, Clone, Copy)]
pub(crate) struct QueryPoints<F: PrimeField> {
    pub(crate) alpha: F,
    pub(crate) beta: F,
    pub(crate) gamma: F,
}
impl<F: PrimeField> QueryPoints<F> {
    /// Creates a new set of query points from the three challenges.
    pub(crate) fn new(alpha: F, beta: F, gamma: F) -> Self {
        Self { alpha, beta, gamma }
    }

    /// Consumes `self` and yields `alpha`, `beta`, `gamma`, in that order.
    ///
    /// Returns an `Iterator` directly, matching `FourthMessage::into_iter`.
    /// Every `Iterator` is also `IntoIterator`, so callers written against
    /// the previous `impl IntoIterator` return type keep working.
    pub(crate) fn into_iter(self) -> impl Iterator<Item = F> {
        // Use the trait function rather than method syntax so the by-value array
        // iterator is selected regardless of edition.
        IntoIterator::into_iter([self.alpha, self.beta, self.gamma])
    }
}