essential_check/
solution.rs

//! Items related to validating `Solution`s and `SolutionSet`s.

use crate::{
    types::{
        predicate::Predicate,
        solution::{Solution, SolutionIndex, SolutionSet},
        Key, PredicateAddress, Word,
    },
    vm::{
        self,
        asm::{self, FromBytesError},
        Access, Gas, GasLimit, Memory, Stack,
    },
};
#[cfg(feature = "tracing")]
use essential_hash::content_addr;
use essential_types::{predicate::Program, ContentAddress, Value};
use essential_vm::{StateRead, StateReads};
use std::{
    collections::{BTreeMap, HashMap, HashSet},
    fmt,
    sync::Arc,
};
use thiserror::Error;

use rayon::prelude::*;

#[cfg(test)]
mod tests;

#[cfg(test)]
mod test_graph_ops;

#[cfg(test)]
mod test_state_read_fallback;

/// Configuration options passed to [`check_predicate`].
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct CheckPredicateConfig {
    /// Whether or not to wait and collect all failures after a single state
    /// read or constraint fails.
    ///
    /// Potentially useful for debugging or testing tools.
    ///
    /// Default: `false`
    pub collect_all_failures: bool,
}

/// Required impl for retrieving access to any [`Solution`]'s [`Predicate`]s during check.
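///
/// Implemented for `Fn(&PredicateAddress) -> Arc<Predicate>` closures and for
/// `HashMap<PredicateAddress, Arc<Predicate>>`, so a pre-fetched map or a lookup closure
/// can be passed directly. A minimal sketch (`fetch_predicates` and `addr` are
/// placeholders for however predicates were fetched and validated ahead of time):
///
/// ```ignore
/// use std::{collections::HashMap, sync::Arc};
///
/// // Predicates keyed by their address, fetched and validated ahead of time.
/// let predicates: HashMap<PredicateAddress, Arc<Predicate>> = fetch_predicates();
/// // The `HashMap` impl of `GetPredicate` looks up and clones the `Arc` on each call.
/// let predicate = predicates.get_predicate(&addr);
/// ```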
pub trait GetPredicate {
    /// Provides immediate access to the predicate with the given content address.
    ///
    /// This is called by [`check_set_predicates`] for each predicate in each solution being checked.
    ///
    /// All necessary predicates are assumed to have been read from storage and
    /// validated ahead of time.
    fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate>;
}

/// Required impl for retrieving access to any [`Predicate`]'s [`Program`]s during check.
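///
/// Implemented for `Fn(&ContentAddress) -> Arc<Program>` closures and for
/// `HashMap<ContentAddress, Arc<Program>>`. A minimal sketch (`fetch_programs` is a
/// placeholder for however programs were fetched and validated ahead of time):
///
/// ```ignore
/// use std::{collections::HashMap, sync::Arc};
///
/// let programs: HashMap<ContentAddress, Arc<Program>> = fetch_programs();
/// // A closure that looks up the pre-fetched map also satisfies `GetProgram`.
/// let get_program = move |ca: &ContentAddress| programs[ca].clone();
/// ```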
pub trait GetProgram {
    /// Provides immediate access to the program with the given content address.
    ///
    /// This is called by [`check_set_predicates`] for each node within each predicate for
    /// each solution being checked.
    ///
    /// All necessary programs are assumed to have been read from storage and
    /// validated ahead of time.
    fn get_program(&self, ca: &ContentAddress) -> Arc<Program>;
}

#[derive(Debug)]
/// Context for checking a predicate.
pub struct Ctx<'a> {
    /// The mode the check is running in.
    pub run_mode: RunMode,
    /// The global cache of outputs, indexed by node index.
    pub cache: &'a mut Cache,
}

/// Cache of parent outputs, indexed by node index for a predicate.
pub type Cache = HashMap<u16, Arc<(Stack, Memory)>>;

/// The node context in which a `Program` is evaluated (see [`run_program`]).
struct ProgramCtx {
    /// The outputs from the parent nodes.
    parents: Vec<Arc<(Stack, Memory)>>,
    /// If this node is a leaf.
    leaf: bool,
}

/// The outputs of checking a solution set.
#[derive(Debug, PartialEq)]
pub struct Outputs {
    /// The total gas spent.
    pub gas: Gas,
    /// The data outputs from solving each predicate.
    pub data: Vec<DataFromSolution>,
}

/// The data outputs from solving a particular predicate.
#[derive(Debug, PartialEq)]
pub struct DataFromSolution {
    /// The index of the solution that produced this data.
    pub solution_index: SolutionIndex,
    /// The data output from the solution.
    pub data: Vec<DataOutput>,
}

/// The output of a program execution.
#[derive(Debug, PartialEq)]
enum ProgramOutput {
    /// The program output is a boolean value
    /// indicating whether the constraint was satisfied.
    Satisfied(bool),
    /// The program output is data.
    DataOutput(DataOutput),
}

/// Types of data output from a program.
#[derive(Debug, PartialEq)]
pub enum DataOutput {
    /// The program output is the memory.
    Memory(Memory),
}

/// The output of a program depends on
/// whether it is a leaf or a parent.
enum Output {
    /// Leaf nodes output bools or data.
    Leaf(ProgramOutput),
    /// Parent nodes output a stack and memory.
    Parent(Arc<(Stack, Memory)>),
}

/// The mode the check is running in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum RunMode {
    /// Generating outputs
    #[default]
    Outputs,
    /// Checking outputs
    Checks,
}

/// [`check_set`] error.
#[derive(Debug, Error)]
pub enum InvalidSolutionSet {
    /// Invalid solution.
    #[error("invalid solution: {0}")]
    Solution(#[from] InvalidSolution),
    /// State mutations validation failed.
    #[error("state mutations validation failed: {0}")]
    StateMutations(#[from] InvalidSetStateMutations),
}

/// [`check_solutions`] error.
#[derive(Debug, Error)]
pub enum InvalidSolution {
    /// There must be at least one solution.
    #[error("must be at least one solution")]
    Empty,
    /// The number of solutions exceeds the limit.
    #[error("the number of solutions ({0}) exceeds the limit ({MAX_SOLUTIONS})")]
    TooMany(usize),
    /// A solution's predicate data length exceeds the limit.
    #[error("solution {0}'s predicate data length ({1}) exceeds the limit ({MAX_PREDICATE_DATA})")]
    PredicateDataLenExceeded(usize, usize),
    /// Invalid state mutation entry.
    #[error("Invalid state mutation entry: {0}")]
    StateMutationEntry(KvError),
    /// Predicate data value too large.
    #[error("Predicate data value len {0} exceeds limit {MAX_VALUE_SIZE}")]
    PredDataValueTooLarge(usize),
}

/// Error with a slot key or value.
#[derive(Debug, Error)]
pub enum KvError {
    /// The key is too large.
    #[error("key with length {0} exceeds limit {MAX_KEY_SIZE}")]
    KeyTooLarge(usize),
    /// The value is too large.
    #[error("value with length {0} exceeds limit {MAX_VALUE_SIZE}")]
    ValueTooLarge(usize),
}

/// [`check_set_state_mutations`] error.
#[derive(Debug, Error)]
pub enum InvalidSetStateMutations {
    /// The number of state mutations exceeds the limit.
    #[error("the number of state mutations ({0}) exceeds the limit ({MAX_STATE_MUTATIONS})")]
    TooMany(usize),
    /// Discovered multiple mutations to the same slot.
    #[error("attempt to apply multiple mutations to the same slot: {0:?} {1:?}")]
    MultipleMutationsForSlot(PredicateAddress, Key),
}

/// [`check_set_predicates`] error.
#[derive(Debug, Error)]
pub enum PredicatesError<E> {
    /// One or more solutions failed their associated predicate checks.
    #[error("{0}")]
    Failed(#[from] PredicateErrors<E>),
    /// Summing solution gas resulted in overflow.
    #[error("summing solution gas overflowed")]
    GasOverflowed,
    /// Tried to compute mutations on a solution set with existing mutations.
    #[error("tried to compute mutations on solution set with existing mutations")]
    ExistingMutations,
}

/// Predicate checking failed for the solutions at the given indices.
#[derive(Debug, Error)]
pub struct PredicateErrors<E>(pub Vec<(SolutionIndex, PredicateError<E>)>);

/// [`check_predicate`] error.
#[derive(Debug, Error)]
pub enum PredicateError<E> {
    /// Failed to retrieve edges for a node, indicating that the predicate's graph is invalid.
    #[error("failed to retrieve edges for node {0}, indicating an invalid graph")]
    InvalidNodeEdges(usize),
    /// The execution of one or more programs failed.
    #[error("one or more program execution errors occurred: {0}")]
    ProgramErrors(#[from] ProgramErrors<E>),
    /// One or more of the constraints were unsatisfied.
    #[error("one or more constraints unsatisfied: {0}")]
    ConstraintsUnsatisfied(#[from] ConstraintsUnsatisfied),
    /// One or more of the mutations were invalid.
    #[error(transparent)]
    Mutations(#[from] MutationsError),
}

/// Program execution failed for the programs at the given node indices.
#[derive(Debug, Error)]
pub struct ProgramErrors<E>(Vec<(usize, ProgramError<E>)>);

/// An error occurring during a program task.
#[derive(Debug, Error)]
pub enum ProgramError<E> {
    /// Failed to parse ops from bytecode during bytecode mapping.
    #[error("failed to parse an op during bytecode mapping: {0}")]
    OpsFromBytesError(#[from] FromBytesError),
    /// Concatenating the parent program [`Stack`]s caused an overflow.
    #[error("concatenating parent program `Stack`s caused an overflow: {0}")]
    ParentStackConcatOverflow(#[from] vm::error::StackError),
    /// Concatenating the parent program [`Memory`] slices caused an overflow.
    #[error("concatenating parent program `Memory` slices caused an overflow: {0}")]
    ParentMemoryConcatOverflow(#[from] vm::error::MemoryError),
    /// VM execution resulted in an error.
    #[error("VM execution error: {0}")]
    Vm(#[from] vm::error::ExecError<E>),
}

/// The index of each constraint that was not satisfied.
#[derive(Debug, Error)]
pub struct ConstraintsUnsatisfied(pub Vec<usize>);

/// Error with computing mutations.
#[derive(Debug, Error)]
pub enum MutationsError {
    /// Duplicate mutations for the same key.
    #[error("duplicate mutations for the same key: {0:?}")]
    DuplicateMutations(Key),
    /// Error decoding mutations.
    #[error(transparent)]
    DecodeError(#[from] essential_types::solution::decode::MutationDecodeError),
}

/// Maximum number of predicate data values in a solution.
pub const MAX_PREDICATE_DATA: u32 = 100;
/// Maximum number of solutions within a solution set.
pub const MAX_SOLUTIONS: usize = 100;
/// Maximum number of state mutations within a solution set.
pub const MAX_STATE_MUTATIONS: usize = 1000;
/// Maximum number of words in a slot value.
pub const MAX_VALUE_SIZE: usize = 10_000;
/// Maximum number of words in a slot key.
pub const MAX_KEY_SIZE: usize = 1000;

impl<E: fmt::Display + fmt::Debug> fmt::Display for PredicateErrors<E> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("predicate checking failed for one or more solutions:\n")?;
        for (ix, err) in &self.0 {
            f.write_str(&format!("  {ix}: {err}\n"))?;
        }
        Ok(())
    }
}

impl<E: fmt::Display + fmt::Debug> fmt::Display for ProgramErrors<E> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("the programs at the following node indices failed: \n")?;
        for (node_ix, err) in &self.0 {
            f.write_str(&format!("  {node_ix}: {:#?}\n", err))?;
        }
        Ok(())
    }
}

impl fmt::Display for ConstraintsUnsatisfied {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("the constraints at the following indices returned false: \n")?;
        for ix in &self.0 {
            f.write_str(&format!("  {ix}\n"))?;
        }
        Ok(())
    }
}

impl<F> GetPredicate for F
where
    F: Fn(&PredicateAddress) -> Arc<Predicate>,
{
    fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate> {
        (*self)(addr)
    }
}

impl<F> GetProgram for F
where
    F: Fn(&ContentAddress) -> Arc<Program>,
{
    fn get_program(&self, ca: &ContentAddress) -> Arc<Program> {
        (*self)(ca)
    }
}

impl GetPredicate for HashMap<PredicateAddress, Arc<Predicate>> {
    fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate> {
        self[addr].clone()
    }
}

impl GetProgram for HashMap<ContentAddress, Arc<Program>> {
    fn get_program(&self, ca: &ContentAddress) -> Arc<Program> {
        self[ca].clone()
    }
}

impl<T: GetPredicate> GetPredicate for Arc<T> {
    fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate> {
        (**self).get_predicate(addr)
    }
}

impl<T: GetProgram> GetProgram for Arc<T> {
    fn get_program(&self, ca: &ContentAddress) -> Arc<Program> {
        (**self).get_program(ca)
    }
}

/// Validate a solution set, to the extent it can be validated without reference to
/// its associated predicates.
///
/// This includes solutions and state mutations.
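///
/// A minimal sketch (assumes `set` is a `SolutionSet` deserialized or constructed elsewhere):
///
/// ```ignore
/// // Cheap, stateless validation performed before any predicate execution.
/// check_set(&set)?;
/// ```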
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(solution = %content_addr(set)), err))]
pub fn check_set(set: &SolutionSet) -> Result<(), InvalidSolutionSet> {
    check_solutions(&set.solutions)?;
    check_set_state_mutations(set)?;
    Ok(())
}

fn check_value_size(value: &[Word]) -> Result<(), KvError> {
    if value.len() > MAX_VALUE_SIZE {
        Err(KvError::ValueTooLarge(value.len()))
    } else {
        Ok(())
    }
}

fn check_key_size(value: &[Word]) -> Result<(), KvError> {
    if value.len() > MAX_KEY_SIZE {
        Err(KvError::KeyTooLarge(value.len()))
    } else {
        Ok(())
    }
}

/// Validate the solution set's slice of [`Solution`]s.
pub fn check_solutions(solutions: &[Solution]) -> Result<(), InvalidSolution> {
    // Validate solutions.
    // Ensure that there is at least one solution.
    if solutions.is_empty() {
        return Err(InvalidSolution::Empty);
    }
    // Ensure that the number of solutions does not exceed the limit.
    if solutions.len() > MAX_SOLUTIONS {
        return Err(InvalidSolution::TooMany(solutions.len()));
    }

    // Check each solution's predicate data.
    for (solution_ix, solution) in solutions.iter().enumerate() {
        // Ensure the length limit is not exceeded.
        if solution.predicate_data.len() > MAX_PREDICATE_DATA as usize {
            return Err(InvalidSolution::PredicateDataLenExceeded(
                solution_ix,
                solution.predicate_data.len(),
            ));
        }
        // Ensure each predicate data value is within the size limit.
        for v in &solution.predicate_data {
            check_value_size(v).map_err(|_| InvalidSolution::PredDataValueTooLarge(v.len()))?;
        }
    }
    Ok(())
}

/// Validate the solution set's state mutations.
pub fn check_set_state_mutations(set: &SolutionSet) -> Result<(), InvalidSolutionSet> {
    // Validate state mutations.
    // Ensure that the set's total number of state mutations does not exceed the limit.
    if set.state_mutations_len() > MAX_STATE_MUTATIONS {
        return Err(InvalidSetStateMutations::TooMany(set.state_mutations_len()).into());
    }

    // Ensure that no more than one mutation per slot is proposed.
    for solution in &set.solutions {
        let mut mut_keys = HashSet::new();
        for mutation in &solution.state_mutations {
            if !mut_keys.insert(&mutation.key) {
                return Err(InvalidSetStateMutations::MultipleMutationsForSlot(
                    solution.predicate_to_solve.clone(),
                    mutation.key.clone(),
                )
                .into());
            }
            // Check key length.
            check_key_size(&mutation.key).map_err(InvalidSolution::StateMutationEntry)?;
            // Check value length.
            check_value_size(&mutation.value).map_err(InvalidSolution::StateMutationEntry)?;
        }
    }

    Ok(())
}
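/// Decode any [`DataOutput::Memory`] outputs into state mutations and append them to the
/// solution they came from, erroring on duplicate mutation keys.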
fn decode_mutations<E>(
    outputs: Outputs,
    mut set: SolutionSet,
) -> Result<SolutionSet, PredicatesError<E>> {
    // For each output check if there are any state mutations and apply them.
    for output in outputs.data {
        // No two outputs can point to the same solution index.
        // Get the solution that these outputs came from.
        let s = &mut set.solutions[output.solution_index as usize];

        // Set to check for duplicate mutations.
        let mut mut_set = HashSet::new();

        // For each memory output decode the mutations and apply them.
        for data in output.data {
            match data {
                DataOutput::Memory(memory) => {
                    for mutation in essential_types::solution::decode::decode_mutations(&memory)
                        .map_err(|e| {
                            PredicatesError::Failed(PredicateErrors(vec![(
                                output.solution_index,
                                PredicateError::Mutations(MutationsError::DecodeError(e)),
                            )]))
                        })?
                    {
                        // Check for duplicate mutation keys.
                        if !mut_set.insert(mutation.key.clone()) {
                            return Err(PredicatesError::Failed(PredicateErrors(vec![(
                                output.solution_index,
                                PredicateError::Mutations(MutationsError::DuplicateMutations(
                                    mutation.key.clone(),
                                )),
                            )])));
                        }

                        // Apply the mutation.
                        s.state_mutations.push(mutation);
                    }
                }
            }
        }
    }
    Ok(set)
}

/// Internal post state used for mutations.
#[derive(Debug, Default)]
struct PostState {
    /// Contract => Key => Value
    state: HashMap<ContentAddress, HashMap<Key, Value>>,
}

/// Arc wrapper for [`PostState`] to allow for cloning.
/// Must take the same error type as the pre state.
#[derive(Clone, Debug, Default)]
struct PostStateArc<S>(Arc<PostState>, S)
where
    S: StateRead;

impl<S> StateRead for PostStateArc<S>
where
    S: StateRead,
{
    type Error = S::Error;

    fn key_range(
        &self,
        contract_addr: ContentAddress,
        key: Key,
        num_values: usize,
    ) -> Result<Vec<Vec<essential_types::Word>>, Self::Error> {
        read_or_fallback(&self.0, &self.1, contract_addr, key, num_values)
    }
}
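/// Read `num_values` consecutive keys, preferring the in-memory post state and falling back
/// to the underlying pre state for any key the post state does not contain.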
fn read_or_fallback<S: StateRead>(
    post: &PostState,
    state: &S,
    contract_addr: ContentAddress,
    mut key: Key,
    num_values: usize,
) -> Result<Vec<Vec<Word>>, S::Error> {
    let mut out = Vec::with_capacity(num_values);
    match post.state.get(&contract_addr) {
        Some(contract_state) => {
            for _ in 0..num_values {
                match contract_state.get(&key) {
                    Some(value) => out.push(value.clone()),
                    None => {
                        let mut value = state.key_range(contract_addr.clone(), key.clone(), 1)?;
                        out.push(value.pop().unwrap_or_default());
                    }
                }
                match next_key(key) {
                    Some(next_key) => key = next_key,
                    None => break,
                }
            }
        }
        None => {
            out = state.key_range(contract_addr, key.clone(), num_values)?;
        }
    }
    Ok(out)
}

/// Get the next key in the range of keys.
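///
/// Increments the least significant word, carrying on overflow, and returns `None` once
/// every word is `Word::MAX`. Illustrative behaviour (sketch, not a doctest):
///
/// ```ignore
/// assert_eq!(next_key(vec![0, 0]), Some(vec![0, 1]));
/// // Overflow carries into the next most significant word.
/// assert_eq!(next_key(vec![0, Word::MAX]), Some(vec![1, Word::MIN]));
/// // The all-`MAX` key has no successor.
/// assert_eq!(next_key(vec![Word::MAX, Word::MAX]), None);
/// ```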
fn next_key(mut key: Key) -> Option<Key> {
    for w in key.iter_mut().rev() {
        match *w {
            Word::MAX => *w = Word::MIN,
            _ => {
                *w += 1;
                return Some(key);
            }
        }
    }
    None
}

/// Check the given solution set against the given predicates
/// and compute the post state mutations for this set.
///
/// This is a two-pass check. The first pass generates the outputs
/// and does not run any post state reads.
/// The second pass checks the outputs and runs the post state reads.
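///
/// A minimal sketch of a call site (names such as `state`, `set`, `predicates` and
/// `programs` are placeholders assumed to be provided by the caller):
///
/// ```ignore
/// let config = Arc::new(CheckPredicateConfig::default());
/// // Pre-fetched `HashMap`s implement the `GetPredicate` / `GetProgram` traits.
/// let (gas, set_with_mutations) = check_and_compute_solution_set_two_pass(
///     &state,
///     set,
///     Arc::new(predicates),
///     Arc::new(programs),
///     config,
/// )?;
/// ```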
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
pub fn check_and_compute_solution_set_two_pass<S>(
    state: &S,
    solution_set: SolutionSet,
    get_predicate: impl GetPredicate + Sync + Clone,
    get_program: impl 'static + Clone + GetProgram + Send + Sync,
    config: Arc<CheckPredicateConfig>,
) -> Result<(Gas, SolutionSet), PredicatesError<S::Error>>
where
    S: Clone + StateRead + Send + Sync + 'static,
    S::Error: Send + Sync + 'static,
{
    // Create an empty post state.
    let post_state = PostStateArc(Arc::new(PostState::default()), state.clone());

    // Create an empty cache.
    let mut cache = HashMap::new();

    // Generate the outputs.
    let (mut gas, solution_set) = check_and_compute_solution_set(
        &(state.clone(), post_state.clone()),
        solution_set,
        get_predicate.clone(),
        get_program.clone(),
        config.clone(),
        RunMode::Outputs,
        &mut cache,
    )?;

    // Get the post state back.
    let mut post_state =
        Arc::try_unwrap(post_state.0).expect("post state should have one reference");

    // Apply the state mutations to the post state.
    for solution in &solution_set.solutions {
        for mutation in &solution.state_mutations {
            post_state
                .state
                .entry(solution.predicate_to_solve.contract.clone())
                .or_default()
                .insert(mutation.key.clone(), mutation.value.clone());
        }
    }

    // Put the post state back into an `Arc`.
    let post_state = PostStateArc(Arc::new(post_state), state.clone());

    // Check the outputs.
    let (g, solution_set) = check_and_compute_solution_set(
        &(state.clone(), post_state.clone()),
        solution_set,
        get_predicate,
        get_program,
        config,
        RunMode::Checks,
        &mut cache,
    )?;

    // Add the total gas.
    gas = gas.saturating_add(g);

    // Return the solution set.
    Ok((gas, solution_set))
}

/// Check the given solution set against the given predicates
/// and compute the post state mutations for this set.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
pub fn check_and_compute_solution_set<S>(
    state: &S,
    solution_set: SolutionSet,
    get_predicate: impl GetPredicate + Sync,
    get_program: impl 'static + Clone + GetProgram + Send + Sync,
    config: Arc<CheckPredicateConfig>,
    run_mode: RunMode,
    cache: &mut HashMap<SolutionIndex, Cache>,
) -> Result<(Gas, SolutionSet), PredicatesError<S::Error>>
where
    S: Clone + StateReads + Send + Sync + 'static,
    S::Error: Send,
{
    // Check the set and gather any outputs.
    let set = Arc::new(solution_set);
    let outputs = check_set_predicates(
        state,
        set.clone(),
        get_predicate,
        get_program,
        config,
        run_mode,
        cache,
    )?;

    // Safe to unwrap the arc here as we have no other references.
    let set = Arc::try_unwrap(set).expect("set should have one reference");

    // Get the gas.
    let gas = outputs.gas;

    let set = decode_mutations(outputs, set)?;

    Ok((gas, set))
}

/// Checks all of a [`SolutionSet`]'s [`Solution`]s against their associated [`Predicate`]s.
///
/// For each solution, we load the associated predicate and its programs and execute them
/// in topological order, in parallel where possible. Leaf nodes are treated as constraints
/// or data outputs; if any constraint returns `false`, the whole solution set is considered
/// invalid.
///
/// **NOTE:** This assumes that the given `SolutionSet` and all `Predicate`s have already
/// been independently validated using [`solution::check_set`][check_set] and
/// [`predicate::check`][crate::predicate::check] respectively.
///
/// ## Arguments
///
/// - `state` must provide access to both pre-mutation and post-mutation state (see
///   [`StateReads`]).
///
/// Returns the total gas spent.
pub fn check_set_predicates<S>(
    state: &S,
    solution_set: Arc<SolutionSet>,
    get_predicate: impl GetPredicate + Sync,
    get_program: impl 'static + Clone + GetProgram + Send + Sync,
    config: Arc<CheckPredicateConfig>,
    run_mode: RunMode,
    cache: &mut HashMap<SolutionIndex, Cache>,
) -> Result<Outputs, PredicatesError<S::Error>>
where
    S: Clone + StateReads + Send + Sync + 'static,
    S::Error: Send,
{
    #[cfg(feature = "tracing")]
    tracing::trace!("{}", essential_hash::content_addr(&*solution_set));

    let caches: Vec<_> = (0..solution_set.solutions.len())
        .map(|i| {
            let cache = cache.entry(i as u16).or_default();
            core::mem::take(cache)
        })
        .collect();
    // Check each solution in parallel.
    let (ok, failed): (Vec<_>, Vec<_>) = solution_set
        .solutions
        .par_iter()
        .zip(caches)
        .enumerate()
        .map(|(solution_index, (solution, mut cache))| {
            let predicate = get_predicate.get_predicate(&solution.predicate_to_solve);
            let solution_set = solution_set.clone();
            let state = state.clone();
            let config = config.clone();
            let get_program = get_program.clone();

            let res = check_predicate(
                &state,
                solution_set,
                predicate,
                get_program,
                solution_index
                    .try_into()
                    .expect("solution index already validated"),
                &config,
                Ctx {
                    run_mode,
                    cache: &mut cache,
                },
            );

            match res {
                Ok(ok) => Ok((solution_index as u16, ok, cache)),
                Err(e) => Err((solution_index as u16, e)),
            }
        })
        .partition(Result::is_ok);

    // If any predicates failed, return an error.
    if !failed.is_empty() {
        return Err(PredicateErrors(failed.into_iter().map(Result::unwrap_err).collect()).into());
    }

    // Calculate gas used.
    let mut total_gas: u64 = 0;
    let outputs = ok
        .into_iter()
        .map(Result::unwrap)
        .map(|(solution_index, (gas, data_outputs), c)| {
            let output = DataFromSolution {
                solution_index,
                data: data_outputs,
            };
            total_gas = total_gas.saturating_add(gas);
            *cache.get_mut(&solution_index).expect("cache should exist") = c;
            output
        })
        .collect();

    Ok(Outputs {
        gas: total_gas,
        data: outputs,
    })
}

/// Checks the predicate of the solution within the given set at the given `solution_index`.
///
/// The predicate's nodes are executed in parallel (via rayon) in topological order,
/// once their parent outputs are ready.
///
/// **NOTE:** This assumes that the given `SolutionSet` and `Predicate` have been
/// independently validated using [`solution::check_set`][check_set]
/// and [`predicate::check`][crate::predicate::check] respectively.
///
/// ## Arguments
///
/// - `state` must provide access to both pre-mutation and post-mutation state (see
///   [`StateReads`]).
/// - `solution_index` represents the solution within `solution_set.solutions` that
///   claims to solve this predicate.
///
/// Returns the total gas spent.
pub fn check_predicate<S>(
    state: &S,
    solution_set: Arc<SolutionSet>,
    predicate: Arc<Predicate>,
    get_program: impl GetProgram + Send + Sync + 'static,
    solution_index: SolutionIndex,
    config: &CheckPredicateConfig,
    ctx: Ctx,
) -> Result<(Gas, Vec<DataOutput>), PredicateError<S::Error>>
where
    S: Clone + StateReads + Send + Sync + 'static,
    S::Error: Send,
{
    let p = predicate.clone();

    // Closure that runs a single node once all of its parent outputs are available.
    let run = |ix: u16, parents: Vec<Arc<(Stack, Memory)>>| {
        let program = get_program.get_program(&predicate.nodes[ix as usize].program_address);
        let ctx = ProgramCtx {
            parents,
            leaf: predicate
                .node_edges(ix as usize)
                .expect("This is already checked")
                .is_empty(),
        };
        let res = run_program(
            state.clone(),
            solution_set.clone(),
            solution_index,
            program,
            ctx,
        );
        (ix, res)
    };

    check_predicate_inner(run, p, config, &get_program, ctx)
}

/// Build a map from each node to the indices of its parent nodes.
/// Includes nodes with no parents.
fn create_parent_map<E>(
    predicate: &Predicate,
) -> Result<BTreeMap<u16, Vec<u16>>, PredicateError<E>> {
    let mut nodes: BTreeMap<u16, Vec<u16>> = BTreeMap::new();
    // For each node, add it to its children's parent sets.
    for node_ix in 0..predicate.nodes.len() {
        // Insert this node in case it's a root.
        nodes.entry(node_ix as u16).or_default();

        // Add any children.
        for edge in predicate
            .node_edges(node_ix)
            .ok_or_else(|| PredicateError::InvalidNodeEdges(node_ix))?
        {
            // Insert the child if it's not already there and then add this node as a parent.
            nodes.entry(*edge).or_default().push(node_ix as u16);
        }
    }
    Ok(nodes)
}

fn in_degrees(num_nodes: usize, parent_map: &BTreeMap<u16, Vec<u16>>) -> BTreeMap<u16, usize> {
    let mut in_degrees = BTreeMap::new();
    for node in 0..num_nodes {
        in_degrees.insert(
            node as u16,
            parent_map.get(&(node as u16)).map_or(0, |v| v.len()),
        );
    }

    in_degrees
}

fn reduce_in_degrees(in_degrees: &mut BTreeMap<u16, usize>, children: &[u16]) {
    for child in children {
        if let Some(in_degree) = in_degrees.get_mut(child) {
            *in_degree = in_degree.saturating_sub(1);
        }
    }
}

fn find_nodes_with_no_parents(in_degrees: &BTreeMap<u16, usize>) -> Vec<u16> {
    in_degrees
        .iter()
        .filter_map(
            |(node, in_degree)| {
                if *in_degree == 0 {
                    Some(*node)
                } else {
                    None
                }
            },
        )
        .collect()
}

/// Sorts the nodes in parallel topological order.
///
/// ## Note
/// This is not a perfect ordering. For example, the following graph:
/// ```text
///   A
///  / \
/// B   C
/// |   |
/// D   E
///  \ /
///   F
/// ```
/// results in:
/// ```text
/// [[A], [B, C], [D, E], [F]]
/// ```
/// If `B` or `C` finishes first, it could start on `D` or `E` respectively,
/// but this sort doesn't allow that.
fn parallel_topo_sort<E>(
    predicate: &Predicate,
    parent_map: &BTreeMap<u16, Vec<u16>>,
) -> Result<Vec<Vec<u16>>, PredicateError<E>> {
    let mut in_degrees = in_degrees(predicate.nodes.len(), parent_map);

    let mut out = Vec::new();
    while !in_degrees.is_empty() {
        let current_level = find_nodes_with_no_parents(&in_degrees);
        if current_level.is_empty() {
            // Cycle detected.
            // TODO: Change error
            return Err(PredicateError::InvalidNodeEdges(0));
        }

        out.push(current_level.clone());

        for node in current_level {
            let children = predicate
                .node_edges(node as usize)
                .ok_or_else(|| PredicateError::InvalidNodeEdges(node as usize))?;
            reduce_in_degrees(&mut in_degrees, children);
            in_degrees.remove(&node);
        }
    }

    Ok(out)
}
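/// Collect the deferred nodes: every node for which `is_deferred` returns `true`,
/// plus the children of any node already marked deferred.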
fn find_deferred<F>(predicate: &Predicate, is_deferred: F) -> HashSet<u16>
where
    F: Fn(&essential_types::predicate::Node) -> bool,
{
    let mut deferred = HashSet::new();
    for (ix, node) in predicate.nodes.iter().enumerate() {
        if is_deferred(node) {
            deferred.insert(ix as u16);
        }
        if deferred.contains(&(ix as u16)) {
            for child in predicate.node_edges(ix).expect("Already checked") {
                deferred.insert(*child);
            }
        }
    }
    deferred
}
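/// A node's output goes in the global cache (rather than the run-local cache) when the
/// node itself is not deferred but at least one of its children is, so the output is
/// still available when the deferred children run in the second pass.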
fn should_cache(node: u16, predicate: &Predicate, deferred: &HashSet<u16>) -> bool {
    !deferred.contains(&node)
        && predicate
            .node_edges(node as usize)
            .expect("Already checked")
            .iter()
            .any(|child| deferred.contains(child))
}
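/// Drop all deferred nodes from the sorted levels, discarding any level left empty.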
fn remove_deferred(nodes: Vec<Vec<u16>>, deferred: &HashSet<u16>) -> Vec<Vec<u16>> {
    nodes
        .into_iter()
        .map(|level| {
            level
                .into_iter()
                .filter(|node| !deferred.contains(node))
                .collect::<Vec<_>>()
        })
        .filter(|level| !level.is_empty())
        .collect()
}
fn remove_not_deferred(nodes: Vec<Vec<u16>>, deferred: &HashSet<u16>) -> Vec<Vec<u16>> {
    nodes
        .into_iter()
        .map(|level| {
            level
                .into_iter()
                .filter(|node| deferred.contains(node))
                .collect::<Vec<_>>()
        })
        .filter(|level| !level.is_empty())
        .collect()
}

/// Handles the checking of a predicate.
/// - Sorts the nodes into parallel topological order.
/// - Sets up the run for the given [`RunMode`].
/// - Runs the programs in parallel where appropriate.
/// - Collects the outputs and gas.
fn check_predicate_inner<F, E>(
    run: F,
    predicate: Arc<Predicate>,
    config: &CheckPredicateConfig,
    get_program: &(impl GetProgram + Send + Sync + 'static),
    ctx: Ctx<'_>,
) -> Result<(Gas, Vec<DataOutput>), PredicateError<E>>
where
    F: Fn(u16, Vec<Arc<(Stack, Memory)>>) -> (u16, Result<(Output, u64), ProgramError<E>>)
        + Send
        + Sync
        + Copy,
    E: Send + std::fmt::Display,
{
    // Get the mode we are running in and the global cache.
    let Ctx { run_mode, cache } = ctx;

    // Create the parent map.
    let parent_map = create_parent_map(&predicate)?;

    // Create a parallel topological sort of the nodes.
    let sorted_nodes = parallel_topo_sort(&predicate, &parent_map)?;

    // Filter for which nodes are deferred, i.e. nodes that perform a post state read.
    let deferred_filter = |node: &essential_types::predicate::Node| -> bool {
        asm::effects::bytes_contains_any(
            &get_program.get_program(&node.program_address).0,
            asm::effects::Effects::PostKeyRange | asm::effects::Effects::PostKeyRangeExtern,
        )
    };

    // Get the set of deferred nodes.
    let deferred = find_deferred(&predicate, deferred_filter);

    // Depending on the run mode, keep either the non-deferred or the deferred nodes.
    let sorted_nodes = match run_mode {
        RunMode::Outputs => remove_deferred(sorted_nodes, &deferred),
        RunMode::Checks => remove_not_deferred(sorted_nodes, &deferred),
    };

    // Set up a local cache for the outputs.
    let mut local_cache = Cache::new();

    // The outputs from a run.
    let mut failed: Vec<(_, _)> = vec![];
    let mut total_gas: Gas = 0;
    let mut unsatisfied = Vec::new();
    let mut data_outputs = Vec::new();

    // Run each set of parallel nodes.
    for parallel_nodes in sorted_nodes {
        // Run levels with zero or one node serially to avoid parallelism overhead.
        let outputs: BTreeMap<u16, Result<(Output, Gas), _>> =
            if parallel_nodes.len() == 1 || parallel_nodes.is_empty() {
                parallel_nodes
                    .into_iter()
                    .map(|ix| {
                        // Check global cache then local cache
                        // for parent inputs.
                        let inputs = parent_map[&ix]
                            .iter()
                            .filter_map(|parent_ix| {
                                cache
                                    .get(parent_ix)
                                    .cloned()
                                    .or_else(|| local_cache.get(parent_ix).cloned())
                            })
                            .collect();

                        // Run the program.
                        run(ix, inputs)
                    })
                    .collect()
            } else {
                parallel_nodes
                    .into_par_iter()
                    .map(|ix| {
                        // Check global cache then local cache
                        // for parent inputs.
                        let inputs = parent_map[&ix]
                            .iter()
                            .filter_map(|parent_ix| {
                                cache
                                    .get(parent_ix)
                                    .cloned()
                                    .or_else(|| local_cache.get(parent_ix).cloned())
                            })
                            .collect();

                        // Run the program.
                        run(ix, inputs)
                    })
                    .collect()
            };
        for (node, res) in outputs {
            match res {
                Ok((Output::Parent(o), gas)) => {
                    // Check if we should add this output to the global or local cache.
                    if should_cache(node, &predicate, &deferred) {
                        cache.insert(node, o.clone());
                    } else {
                        local_cache.insert(node, o.clone());
                    }

                    // Add to the total gas.
                    total_gas = total_gas.saturating_add(gas);
                }
                Ok((Output::Leaf(o), gas)) => {
                    match o {
                        ProgramOutput::Satisfied(false) => {
                            unsatisfied.push(node as usize);
                        }
                        ProgramOutput::Satisfied(true) => {
                            // Nothing to do here.
                        }
                        ProgramOutput::DataOutput(data_output) => {
                            data_outputs.push(data_output);
                        }
                    }

                    // Add to the total gas.
                    total_gas = total_gas.saturating_add(gas);
                }
                Err(e) => {
                    failed.push((node as usize, e));

                    if !config.collect_all_failures {
                        return Err(ProgramErrors(failed).into());
                    }
                }
            }
        }
    }

    // If any programs failed, return an error.
    if !failed.is_empty() {
        return Err(ProgramErrors(failed).into());
    }

    // If there are any unsatisfied constraints, return an error.
    if !unsatisfied.is_empty() {
        return Err(ConstraintsUnsatisfied(unsatisfied).into());
    }

    Ok((total_gas, data_outputs))
}

/// Map the given program's bytecode and evaluate it.
///
/// For a leaf node, the output indicates whether the constraint was satisfied or carries
/// a data output. For a non-leaf (parent) node, the resulting stack and memory are
/// returned so they can be passed on to its children.
fn run_program<S>(
    state: S,
    solution_set: Arc<SolutionSet>,
    solution_index: SolutionIndex,
    program: Arc<Program>,
    ctx: ProgramCtx,
) -> Result<(Output, Gas), ProgramError<S::Error>>
where
    S: StateReads,
{
    let ProgramCtx { parents, leaf } = ctx;

    // Pull ops into memory.
    let ops = asm::from_bytes(program.0.iter().copied()).collect::<Result<Vec<_>, _>>()?;

    // Create a new VM.
    let mut vm = vm::Vm::default();

    // Use the results of the parent execution to initialise our stack and memory.
    for parent_result in parents {
        let (parent_stack, parent_memory) = Arc::unwrap_or_clone(parent_result);
        // Extend the stack.
        let mut stack: Vec<Word> = std::mem::take(&mut vm.stack).into();
        stack.append(&mut parent_stack.into());
        vm.stack = stack.try_into()?;

        // Extend the memory.
        let mut memory: Vec<Word> = std::mem::take(&mut vm.memory).into();
        memory.append(&mut parent_memory.into());
        vm.memory = memory.try_into()?;
    }

    // Set up solution access for execution.
    let access = Access::new(Arc::new(solution_set.solutions.clone()), solution_index);

    // FIXME: Provide these from Config.
    let gas_cost = |_: &asm::Op| 1;
    let gas_limit = GasLimit::UNLIMITED;

    // Execute the program.
    let gas_spent = vm.exec_ops(&ops, access, &state, &gas_cost, gas_limit)?;

    let out = if leaf {
        // Leaf convention: a lone `2` on the stack marks a data output held in memory,
        // a lone `1` marks a satisfied constraint, anything else is unsatisfied.
        match vm.stack[..] {
            [2] => Output::Leaf(ProgramOutput::DataOutput(DataOutput::Memory(vm.memory))),
            [1] => Output::Leaf(ProgramOutput::Satisfied(true)),
            _ => Output::Leaf(ProgramOutput::Satisfied(false)),
        }
    } else {
        let output = Arc::new((vm.stack, vm.memory));
        Output::Parent(output)
    };

    Ok((out, gas_spent))