use crate::{
    constraint_vm::{
        self,
        error::{CheckError, ConstraintErrors, ConstraintsUnsatisfied},
        TransientData,
    },
    state_read_vm::{
        self, asm::FromBytesError, error::StateReadError, Access, BytecodeMapped, Gas, GasLimit,
        SolutionAccess, StateRead, StateSlotSlice, StateSlots,
    },
    types::{
        predicate::Predicate,
        solution::{Solution, SolutionData, SolutionDataIndex},
        Key, PredicateAddress, StateReadBytecode, Word,
    },
};
#[cfg(feature = "tracing")]
use essential_hash::content_addr;
use std::{collections::HashSet, fmt, sync::Arc};
use thiserror::Error;
use tokio::task::JoinSet;
#[cfg(feature = "tracing")]
use tracing::Instrument;
/// Configuration for [`check_predicates`] and the per-predicate constraint checks it drives.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct CheckPredicateConfig {
    /// When `true`, keep going after a failure and report every failure found.
    ///
    /// When `false` (the default), checking returns as soon as the first failing predicate,
    /// or the first constraint that errors during evaluation, is observed. Constraints that
    /// merely evaluate to `false` are always collected in full for a predicate.
    pub collect_all_failures: bool,
}
/// Reasons a [`Solution`] can fail the structural validation performed by [`check`].
#[derive(Debug, Error)]
pub enum InvalidSolution {
    /// The solution data failed [`check_data`], or a state mutation / transient data entry
    /// had a key or value exceeding its size limit.
    #[error("invalid solution data: {0}")]
    Data(#[from] InvalidSolutionData),
    /// The state mutations failed the count or per-data key uniqueness checks.
    #[error("state mutations validation failed: {0}")]
    StateMutations(#[from] InvalidStateMutations),
    /// The transient data failed its count check.
    #[error("transient data validation failed: {0}")]
    TransientData(#[from] InvalidTransientData),
}
/// Reasons the solution data (or an individual key/value entry within it) is invalid.
#[derive(Debug, Error)]
pub enum InvalidSolutionData {
    /// The solution contains no solution data at all.
    #[error("must be at least one solution data")]
    Empty,
    /// The number of solution data exceeds [`MAX_SOLUTION_DATA`]; holds the actual count.
    #[error("the number of solution data ({0}) exceeds the limit ({MAX_SOLUTION_DATA})")]
    TooMany(usize),
    /// A solution data has more than [`MAX_DECISION_VARIABLES`] decision variables.
    /// Holds `(solution data index, decision variable count)`.
    #[error("data {0} expects too many decision vars {1} (limit: {MAX_DECISION_VARIABLES})")]
    TooManyDecisionVariables(usize, usize),
    /// A state mutation key or value exceeded its size limit.
    #[error("Invalid state mutation entry: {0}")]
    StateMutationEntry(KvError),
    /// A transient data key or value exceeded its size limit.
    #[error("Invalid transient data entry: {0}")]
    TransientDataEntry(KvError),
    /// A decision variable value is longer than [`MAX_VALUE_SIZE`]; holds its length.
    #[error("Decision variable value len {0} exceeds limit {MAX_VALUE_SIZE}")]
    DecVarValueTooLarge(usize),
}
/// A key or value whose length (number of `Word`s) exceeds its limit.
#[derive(Debug, Error)]
pub enum KvError {
    /// Key longer than [`MAX_KEY_SIZE`]; holds the key's length.
    #[error("key with length {0} exceeds limit {MAX_KEY_SIZE}")]
    KeyTooLarge(usize),
    /// Value longer than [`MAX_VALUE_SIZE`]; holds the value's length.
    #[error("value with length {0} exceeds limit {MAX_VALUE_SIZE}")]
    ValueTooLarge(usize),
}
/// Reasons the state mutations of a solution are invalid.
#[derive(Debug, Error)]
pub enum InvalidStateMutations {
    /// The total number of state mutations (as reported by `Solution::state_mutations_len`)
    /// exceeds [`MAX_STATE_MUTATIONS`]; holds that count.
    #[error("the number of state mutations ({0}) exceeds the limit ({MAX_STATE_MUTATIONS})")]
    TooMany(usize),
    /// NOTE(review): not produced by any check visible in this module section; presumably
    /// a mutation refers to a solution data index that does not exist — confirm.
    #[error("state mutation pathway {0} out of range of solution data")]
    PathwayOutOfRangeOfSolutionData(u16),
    /// A single solution data mutates the same key more than once. Holds the solution
    /// data's `predicate_to_solve` address and the duplicated key.
    #[error("attempt to apply multiple mutations to the same slot: {0:?} {1:?}")]
    MultipleMutationsForSlot(PredicateAddress, Key),
}
/// Reasons the transient data of a solution is invalid.
#[derive(Debug, Error)]
pub enum InvalidTransientData {
    /// The total number of transient data entries (as reported by
    /// `Solution::transient_data_len`) exceeds [`MAX_TRANSIENT_DATA`]; holds that count.
    #[error("the number of transient data ({0}) exceeds the limit ({MAX_TRANSIENT_DATA})")]
    TooMany(usize),
}
/// Error returned by [`check_predicates`]; `E` is the state read error type.
#[derive(Debug, Error)]
pub enum PredicatesError<E> {
    /// One or more solution data failed their predicate check.
    #[error("{0}")]
    Failed(#[from] PredicateErrors<E>),
    /// A spawned predicate-checking task could not be joined.
    #[error("one or more spawned tasks failed to join: {0}")]
    Join(#[from] tokio::task::JoinError),
    /// Adding up the gas spent by each solution data overflowed.
    #[error("summing solution data gas overflowed")]
    GasOverflowed,
}
/// The predicate failures collected by [`check_predicates`], as
/// `(solution data index, error)` pairs.
///
/// There is no `#[error]` attribute here; `Display` is implemented by hand.
#[derive(Debug, Error)]
pub struct PredicateErrors<E>(pub Vec<(SolutionDataIndex, PredicateError<E>)>);
/// Failure while checking a single solution data against its predicate.
#[derive(Debug, Error)]
pub enum PredicateError<E> {
    /// A state read program's bytecode could not be mapped into ops.
    #[error("failed to parse an op during bytecode mapping: {0}")]
    OpsFromBytesError(#[from] FromBytesError),
    /// Executing a state read program failed.
    #[error("state read execution error: {0}")]
    StateRead(#[from] StateReadError<E>),
    /// Constraint evaluation failed or a constraint was unsatisfied.
    #[error("constraint checking failed: {0}")]
    Constraints(#[from] PredicateConstraintsError),
}
/// Mismatch between a solution data's decision variable count and the count expected by its
/// predicate.
///
/// NOTE(review): not constructed by any function visible in this module section.
#[derive(Debug, Error)]
#[error("number of solution data decision variables ({data}) differs from predicate ({predicate})")]
pub struct InvalidDecisionVariablesLength {
    /// Number of decision variables provided by the solution data.
    pub data: usize,
    /// Number of decision variables the predicate declares.
    pub predicate: u32,
}
/// Failure while evaluating a predicate's constraints.
#[derive(Debug, Error)]
pub enum PredicateConstraintsError {
    /// One or more constraints errored during evaluation, or evaluated to `false`.
    #[error("check failed: {0}")]
    Check(#[from] constraint_vm::error::CheckError),
    /// A constraint's result could not be received from its worker, e.g. because the worker
    /// dropped its sender without sending.
    #[error("failed to recv: {0}")]
    Recv(#[from] tokio::sync::oneshot::error::RecvError),
}
/// Maximum number of decision variables per solution data.
pub const MAX_DECISION_VARIABLES: u32 = 100;
/// Maximum number of solution data in a single solution.
pub const MAX_SOLUTION_DATA: usize = 100;
/// Maximum number of state mutations in a solution (per `Solution::state_mutations_len`).
pub const MAX_STATE_MUTATIONS: usize = 1000;
/// Maximum number of transient data entries in a solution (per `Solution::transient_data_len`).
pub const MAX_TRANSIENT_DATA: usize = 1000;
/// Maximum length, in `Word`s, of a decision variable, state mutation or transient data value.
pub const MAX_VALUE_SIZE: usize = 10_000;
/// Maximum length, in `Word`s, of a state mutation or transient data key.
pub const MAX_KEY_SIZE: usize = 1000;
/// Renders a header line followed by one indented `index: error` line per failure.
impl<E: fmt::Display> fmt::Display for PredicateErrors<E> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("predicate checking failed for one or more solution data:\n")?;
        for (ix, err) in &self.0 {
            // Write straight into the formatter instead of building a temporary `String`.
            writeln!(f, "  {ix}: {err}")?;
        }
        Ok(())
    }
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(solution = %content_addr(solution)), err))]
pub fn check(solution: &Solution) -> Result<(), InvalidSolution> {
    check_data(&solution.data)?;
    check_state_mutations(solution)?;
    check_transient_data(solution)?;
    Ok(())
}
/// Ensures `value` is at most [`MAX_VALUE_SIZE`] words long.
fn check_value_size(value: &[Word]) -> Result<(), KvError> {
    let len = value.len();
    if len <= MAX_VALUE_SIZE {
        Ok(())
    } else {
        Err(KvError::ValueTooLarge(len))
    }
}
/// Ensures `key` is at most [`MAX_KEY_SIZE`] words long.
fn check_key_size(key: &[Word]) -> Result<(), KvError> {
    match key.len() {
        len if len > MAX_KEY_SIZE => Err(KvError::KeyTooLarge(len)),
        _ => Ok(()),
    }
}
/// Validates the solution data of a solution.
///
/// Checks, in order:
/// - there is at least one and at most [`MAX_SOLUTION_DATA`] solution data;
/// - no solution data has more than [`MAX_DECISION_VARIABLES`] decision variables;
/// - no decision variable value exceeds [`MAX_VALUE_SIZE`].
///
/// The first violation found is returned.
pub fn check_data(data_slice: &[SolutionData]) -> Result<(), InvalidSolutionData> {
    if data_slice.is_empty() {
        return Err(InvalidSolutionData::Empty);
    }
    let data_count = data_slice.len();
    if data_count > MAX_SOLUTION_DATA {
        return Err(InvalidSolutionData::TooMany(data_count));
    }
    let dec_var_limit = MAX_DECISION_VARIABLES as usize;
    data_slice
        .iter()
        .enumerate()
        .try_for_each(|(data_ix, data)| -> Result<(), InvalidSolutionData> {
            let dec_vars = &data.decision_variables;
            if dec_vars.len() > dec_var_limit {
                return Err(InvalidSolutionData::TooManyDecisionVariables(
                    data_ix,
                    dec_vars.len(),
                ));
            }
            dec_vars.iter().try_for_each(|value| {
                check_value_size(value)
                    .map_err(|_| InvalidSolutionData::DecVarValueTooLarge(value.len()))
            })
        })
}
/// Validates the state mutations of every solution data in `solution`.
///
/// First the total mutation count is checked against [`MAX_STATE_MUTATIONS`]. Then each
/// mutation of each solution data is checked in turn:
/// - its key must not already have been mutated by the same solution data;
/// - its key must fit [`MAX_KEY_SIZE`];
/// - its value must fit [`MAX_VALUE_SIZE`].
///
/// The first violation found is returned.
pub fn check_state_mutations(solution: &Solution) -> Result<(), InvalidSolution> {
    let mutation_count = solution.state_mutations_len();
    if mutation_count > MAX_STATE_MUTATIONS {
        return Err(InvalidStateMutations::TooMany(mutation_count).into());
    }
    for data in &solution.data {
        // Keys only need to be unique within one solution data, so the set is per-data.
        let mut seen_keys = HashSet::new();
        for mutation in &data.state_mutations {
            let key = &mutation.key;
            if !seen_keys.insert(key) {
                let predicate = data.predicate_to_solve.clone();
                return Err(
                    InvalidStateMutations::MultipleMutationsForSlot(predicate, key.clone()).into(),
                );
            }
            check_key_size(key).map_err(InvalidSolutionData::StateMutationEntry)?;
            check_value_size(&mutation.value).map_err(InvalidSolutionData::StateMutationEntry)?;
        }
    }
    Ok(())
}
/// Validates the transient data of every solution data in `solution`.
///
/// First the total entry count is checked against [`MAX_TRANSIENT_DATA`]. Then each entry's
/// key and value lengths are checked against [`MAX_KEY_SIZE`] and [`MAX_VALUE_SIZE`]. The
/// first violation found is returned.
pub fn check_transient_data(solution: &Solution) -> Result<(), InvalidSolution> {
    let entry_count = solution.transient_data_len();
    if entry_count > MAX_TRANSIENT_DATA {
        return Err(InvalidTransientData::TooMany(entry_count).into());
    }
    solution
        .data
        .iter()
        .flat_map(|data| &data.transient_data)
        .try_for_each(|entry| -> Result<(), InvalidSolution> {
            check_key_size(&entry.key).map_err(InvalidSolutionData::TransientDataEntry)?;
            check_value_size(&entry.value).map_err(InvalidSolutionData::TransientDataEntry)?;
            Ok(())
        })
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
pub async fn check_predicates<SA, SB>(
    pre_state: &SA,
    post_state: &SB,
    solution: Arc<Solution>,
    get_predicate: impl Fn(&PredicateAddress) -> Arc<Predicate>,
    config: Arc<CheckPredicateConfig>,
) -> Result<Gas, PredicatesError<SA::Error>>
where
    SA: Clone + StateRead + Send + Sync + 'static,
    SB: Clone + StateRead<Error = SA::Error> + Send + Sync + 'static,
    SA::Future: Send,
    SB::Future: Send,
    SA::Error: Send,
{
    #[cfg(feature = "tracing")]
    tracing::trace!("{}", essential_hash::content_addr(&*solution));
    let transient_data: Arc<TransientData> =
        Arc::new(essential_constraint_vm::transient_data(&solution));
    let mut set: JoinSet<(_, Result<_, PredicateError<SA::Error>>)> = JoinSet::new();
    for (solution_data_index, data) in solution.data.iter().enumerate() {
        let solution_data_index: SolutionDataIndex = solution_data_index
            .try_into()
            .expect("solution data index already validated");
        let predicate = get_predicate(&data.predicate_to_solve);
        let solution = solution.clone();
        let transient_data = transient_data.clone();
        let pre_state: SA = pre_state.clone();
        let post_state: SB = post_state.clone();
        let config = config.clone();
        let future = async move {
            let pre_state = pre_state;
            let post_state = post_state;
            let res = check_predicate(
                &pre_state,
                &post_state,
                solution,
                predicate,
                solution_data_index,
                &config,
                transient_data,
            )
            .await;
            (solution_data_index, res)
        };
        #[cfg(feature = "tracing")]
        let future = future.in_current_span();
        set.spawn(future);
    }
    let mut total_gas: u64 = 0;
    let mut failed = vec![];
    while let Some(res) = set.join_next().await {
        let (solution_data_ix, res) = res?;
        let g = match res {
            Ok(ok) => ok,
            Err(e) => {
                failed.push((solution_data_ix, e));
                if config.collect_all_failures {
                    continue;
                } else {
                    return Err(PredicateErrors(failed).into());
                }
            }
        };
        total_gas = total_gas
            .checked_add(g)
            .ok_or(PredicatesError::GasOverflowed)?;
    }
    if !failed.is_empty() {
        return Err(PredicateErrors(failed).into());
    }
    Ok(total_gas)
}
/// Checks a single solution data against its predicate.
///
/// First runs the predicate's state read programs against `pre_state` and `post_state`
/// (see [`predicate_state_slots`]). Then it evaluates the predicate's constraints over the
/// resulting slots (see [`check_predicate_constraints`]).
///
/// Returns the gas spent on state reads only; constraint evaluation is not metered here.
///
/// # Errors
///
/// Returns a [`PredicateError`] if bytecode mapping, state reading or constraint checking
/// fails.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(
        skip_all,
        fields(
            solution = %format!("{}", content_addr(&*solution))[0..8],
            data={solution_data_index},
        ),
    ),
)]
pub async fn check_predicate<SA, SB>(
    pre_state: &SA,
    post_state: &SB,
    solution: Arc<Solution>,
    predicate: Arc<Predicate>,
    solution_data_index: SolutionDataIndex,
    config: &CheckPredicateConfig,
    transient_data: Arc<TransientData>,
) -> Result<Gas, PredicateError<SA::Error>>
where
    SA: StateRead,
    SB: StateRead<Error = SA::Error>,
{
    let (state_read_gas, pre_slots, post_slots) = predicate_state_slots(
        pre_state,
        post_state,
        &solution,
        &predicate.state_read,
        solution_data_index,
        &transient_data,
    )
    .await?;
    // `predicate` is not used after this call, so hand over the `Arc` itself rather than
    // bumping its refcount with a redundant clone.
    check_predicate_constraints(
        solution,
        solution_data_index,
        predicate,
        Arc::from(pre_slots.into_boxed_slice()),
        Arc::from(post_slots.into_boxed_slice()),
        config,
        transient_data,
    )
    .await?;
    Ok(state_read_gas)
}
/// Slots produced by running state read programs against the pre-state, one `Vec<Word>` per
/// slot.
pub type PreStateSlots = Vec<Vec<Word>>;
/// Slots produced by running state read programs against the post-state, one `Vec<Word>` per
/// slot.
pub type PostStateSlots = Vec<Vec<Word>>;
/// Runs a predicate's state read programs for one solution data and returns the slots
/// they produce.
///
/// Each program in `predicate_state_reads` is executed twice, in order: first against
/// `pre_state`, then against `post_state`.
/// - Output from the pre-state run is appended to the pre-state slots.
/// - Output from the post-state run is appended to the post-state slots.
/// - Every run can access all slots accumulated so far. This means the post-state run of a
///   program already sees that same program's pre-state output.
///
/// Returns the gas summed over all runs, together with the pre- and post-state slots.
///
/// # Errors
///
/// Stops at, and propagates, the first failure. This is either a failure to map a program's
/// bytecode (`BytecodeMapped::try_from`) or a state read VM execution error.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
pub async fn predicate_state_slots<SA, SB>(
    pre_state: &SA,
    post_state: &SB,
    solution: &Solution,
    predicate_state_reads: &[StateReadBytecode],
    solution_data_index: SolutionDataIndex,
    transient_data: &TransientData,
) -> Result<(Gas, PreStateSlots, PostStateSlots), PredicateError<SA::Error>>
where
    SA: StateRead,
    SB: StateRead<Error = SA::Error>,
{
    // NOTE(review): accumulated with unchecked `+=` below, unlike the `checked_add` used
    // when summing gas in `check_predicates`.
    let mut total_gas = 0;
    let mut pre_slots: Vec<Vec<Word>> = Vec::new();
    let mut post_slots: Vec<Vec<Word>> = Vec::new();
    let mutable_keys = constraint_vm::mut_keys_set(solution, solution_data_index);
    // Built once; it is passed by value to every run below.
    let solution_access =
        SolutionAccess::new(solution, solution_data_index, &mutable_keys, transient_data);
    for (state_read_index, state_read) in predicate_state_reads.iter().enumerate() {
        // The index is only used in tracing span fields; silence the unused warning otherwise.
        #[cfg(not(feature = "tracing"))]
        let _ = state_read_index;
        let state_read_mapped = BytecodeMapped::try_from(&state_read[..])?;
        // Pre-state run: sees every slot accumulated by earlier programs.
        let future = read_state_slots(
            &state_read_mapped,
            Access {
                solution: solution_access,
                state_slots: StateSlots {
                    pre: &pre_slots,
                    post: &post_slots,
                },
            },
            pre_state,
        );
        #[cfg(feature = "tracing")]
        let (gas, new_pre_slots) = future
            .instrument(tracing::info_span!("pre", ix = state_read_index))
            .await?;
        #[cfg(not(feature = "tracing"))]
        let (gas, new_pre_slots) = future.await?;
        total_gas += gas;
        pre_slots.extend(new_pre_slots);
        // Post-state run: additionally sees this program's freshly appended pre-state slots.
        let future = read_state_slots(
            &state_read_mapped,
            Access {
                solution: solution_access,
                state_slots: StateSlots {
                    pre: &pre_slots,
                    post: &post_slots,
                },
            },
            post_state,
        );
        #[cfg(feature = "tracing")]
        let (gas, new_post_slots) = future
            .instrument(tracing::info_span!("post", ix = state_read_index))
            .await?;
        #[cfg(not(feature = "tracing"))]
        let (gas, new_post_slots) = future.await?;
        total_gas += gas;
        post_slots.extend(new_post_slots);
    }
    Ok((total_gas, pre_slots, post_slots))
}
/// Runs one state read program against `state_read` and collects the slots it produces.
///
/// Every op is charged one unit of gas and no gas limit is applied. Returns the gas spent,
/// together with the VM's resulting state slots.
async fn read_state_slots<S>(
    bytecode_mapped: &BytecodeMapped<&[u8]>,
    access: Access<'_>,
    state_read: &S,
) -> Result<(Gas, Vec<Vec<Word>>), state_read_vm::error::StateReadError<S::Error>>
where
    S: StateRead,
{
    let mut vm = state_read_vm::Vm::default();
    let outcome = vm
        .exec_bytecode(
            bytecode_mapped,
            access,
            state_read,
            // Flat pricing: every op costs exactly one unit of gas.
            &|_: &state_read_vm::asm::Op| 1,
            GasLimit::UNLIMITED,
        )
        .await;
    let gas = outcome?;
    let slots = vm.into_state_slots();
    Ok((gas, slots))
}
/// Evaluates all of `predicate`'s constraints for one solution data.
///
/// Delegates to the rayon-backed parallel evaluator. When the `tracing` feature is enabled,
/// it also emits a trace event describing any error before returning it.
///
/// # Errors
///
/// - [`PredicateConstraintsError::Check`] if one or more constraints fail to evaluate or
///   evaluate to `false`.
/// - [`PredicateConstraintsError::Recv`] if a constraint's result could not be received.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, "check"))]
pub async fn check_predicate_constraints(
    solution: Arc<Solution>,
    solution_data_index: SolutionDataIndex,
    predicate: Arc<Predicate>,
    pre_slots: Arc<StateSlotSlice>,
    post_slots: Arc<StateSlotSlice>,
    config: &CheckPredicateConfig,
    transient_data: Arc<TransientData>,
) -> Result<(), PredicateConstraintsError> {
    // None of the `Arc`s are used after this call, so they are moved through instead of
    // being cloned.
    let result = check_predicate_constraints_parallel(
        solution,
        solution_data_index,
        predicate,
        pre_slots,
        post_slots,
        config,
        transient_data,
    )
    .await;
    #[cfg(feature = "tracing")]
    if let Err(ref err) = result {
        tracing::trace!("error checking constraints: {}", err);
    }
    result
}
/// Evaluates each of `predicate`'s constraints on the rayon thread pool.
///
/// One rayon job is spawned per constraint. Each job builds its own solution access and sends
/// `(constraint index, result)` back over a oneshot channel. Results are then awaited in
/// constraint-index order.
///
/// # Errors
///
/// - [`CheckError`] wrapping [`ConstraintErrors`] if any constraint errored during evaluation.
///   When `config.collect_all_failures` is `false`, awaiting stops at the first such error.
///   The remaining jobs still run, but their results are discarded.
/// - [`CheckError`] wrapping [`ConstraintsUnsatisfied`] if no constraint errored but at least
///   one evaluated to `false`. Evaluation errors take precedence over unsatisfied constraints.
/// - [`PredicateConstraintsError::Recv`] if a job's sender was dropped without sending
///   (e.g. the job panicked).
async fn check_predicate_constraints_parallel(
    solution: Arc<Solution>,
    solution_data_index: SolutionDataIndex,
    predicate: Arc<Predicate>,
    pre_slots: Arc<StateSlotSlice>,
    post_slots: Arc<StateSlotSlice>,
    config: &CheckPredicateConfig,
    transient_data: Arc<TransientData>,
) -> Result<(), PredicateConstraintsError> {
    let mut handles = Vec::with_capacity(predicate.constraints.len());
    for ix in 0..predicate.constraints.len() {
        let (tx, rx) = tokio::sync::oneshot::channel();
        handles.push(rx);
        // Rayon jobs must be `'static`, so each one owns its own `Arc` handles.
        let solution = solution.clone();
        let transient_data = transient_data.clone();
        let pre_slots = pre_slots.clone();
        let post_slots = post_slots.clone();
        let predicate = predicate.clone();
        #[cfg(feature = "tracing")]
        let span = tracing::Span::current();
        rayon::spawn(move || {
            #[cfg(feature = "tracing")]
            let span = tracing::trace_span!(parent: &span, "constraint", ix = ix as u32);
            #[cfg(feature = "tracing")]
            let guard = span.enter();
            // `SolutionAccess` is built from borrows of this job's own data, so it is
            // constructed inside the job rather than shared.
            let mutable_keys = constraint_vm::mut_keys_set(&solution, solution_data_index);
            let solution_access = SolutionAccess::new(
                &solution,
                solution_data_index,
                &mutable_keys,
                &transient_data,
            );
            let access = Access {
                solution: solution_access,
                state_slots: StateSlots {
                    pre: &pre_slots,
                    post: &post_slots,
                },
            };
            let res = constraint_vm::eval_bytecode_iter(
                predicate
                    .constraints
                    .get(ix)
                    .expect("Safe due to above len check")
                    .iter()
                    .copied(),
                access,
            );
            // The receiver may already be gone if an earlier error ended the wait below;
            // the result is then simply discarded.
            let _ = tx.send((ix, res));
            #[cfg(feature = "tracing")]
            drop(guard);
        });
    }
    let mut failed = Vec::new();
    let mut unsatisfied = Vec::new();
    for handle in handles {
        let (ix, res): (usize, Result<bool, _>) = handle.await?;
        match res {
            Ok(true) => (),
            Ok(false) => unsatisfied.push(ix),
            Err(err) => {
                failed.push((ix, err));
                if !config.collect_all_failures {
                    break;
                }
            }
        }
    }
    if !failed.is_empty() {
        return Err(CheckError::from(ConstraintErrors(failed)).into());
    }
    if !unsatisfied.is_empty() {
        return Err(CheckError::from(ConstraintsUnsatisfied(unsatisfied)).into());
    }
    Ok(())
}