essential_check/solution.rs

1//! Items related to validating `Solution`s and `SolutionSet`s.
2
3use crate::{
4    types::{
5        predicate::Predicate,
6        solution::{Solution, SolutionIndex, SolutionSet},
7        Key, PredicateAddress, Word,
8    },
9    vm::{
10        self,
11        asm::{self, FromBytesError},
12        Access, Gas, GasLimit, Memory, Stack,
13    },
14};
15#[cfg(feature = "tracing")]
16use essential_hash::content_addr;
17use essential_types::{predicate::Program, ContentAddress, Value};
18use essential_vm::{StateRead, StateReads};
19use std::{
20    collections::{BTreeMap, HashMap, HashSet},
21    fmt,
22    sync::Arc,
23};
24use thiserror::Error;
25
26use rayon::prelude::*;
27
28#[cfg(test)]
29mod tests;
30
31#[cfg(test)]
32mod test_graph_ops;
33
/// Configuration options passed to [`check_predicate`].
///
/// Typically shared between checks via `Arc` (see [`check_set_predicates`]).
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct CheckPredicateConfig {
    /// Whether or not to wait and collect all failures after a single state
    /// read or constraint fails.
    ///
    /// Potentially useful for debugging or testing tools.
    ///
    /// Default: `false`
    pub collect_all_failures: bool,
}
45
/// Required impl for retrieving access to any [`Solution`]'s [`Predicate`]s during check.
///
/// Implemented below for any `Fn(&PredicateAddress) -> Arc<Predicate>`, for
/// `HashMap<PredicateAddress, Arc<Predicate>>` (which panics on a missing
/// entry) and for `Arc<T>` where `T: GetPredicate`.
pub trait GetPredicate {
    /// Provides immediate access to the predicate with the given content address.
    ///
    /// This is called by [`check_set_predicates`] for each predicate in each solution being checked.
    ///
    /// All necessary programs are assumed to have been read from storage and
    /// validated ahead of time.
    fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate>;
}
56
/// Required impl for retrieving access to any [`Predicate`]'s [`Program`]s during check.
///
/// Implemented below for any `Fn(&ContentAddress) -> Arc<Program>`, for
/// `HashMap<ContentAddress, Arc<Program>>` (which panics on a missing entry)
/// and for `Arc<T>` where `T: GetProgram`.
pub trait GetProgram {
    /// Provides immediate access to the program with the given content address.
    ///
    /// This is called by [`check_set_predicates`] for each node within each predicate for
    /// each solution being checked.
    ///
    /// All necessary programs are assumed to have been read from storage and
    /// validated ahead of time.
    fn get_program(&self, ca: &ContentAddress) -> Arc<Program>;
}
68
#[derive(Debug)]
/// Context for checking a predicate.
///
/// Passed into [`check_predicate`]. The same cache map is reused across the
/// two passes of [`check_and_compute_solution_set_two_pass`], so outputs
/// produced while generating can be reused while checking.
pub struct Ctx<'a> {
    /// The mode the check is running in.
    pub run_mode: RunMode,
    /// The global cache of outputs, indexed by node index.
    pub cache: &'a mut Cache,
}
77
/// Cache of parent outputs, indexed by node index for a predicate.
///
/// One `Cache` is kept per solution — see the `HashMap<SolutionIndex, Cache>`
/// parameter of [`check_set_predicates`].
pub type Cache = HashMap<u16, Arc<(Stack, Memory)>>;
80
/// The node context in which a `Program` is evaluated (see [`run_program`]).
struct ProgramCtx {
    /// The `(Stack, Memory)` outputs produced by this node's parent nodes.
    parents: Vec<Arc<(Stack, Memory)>>,
    /// `true` if this node has no outgoing edges, i.e. it is a leaf.
    leaf: bool,
}
88
/// The outputs of checking a solution set.
///
/// Produced by [`check_set_predicates`].
#[derive(Debug, PartialEq)]
pub struct Outputs {
    /// The total gas spent.
    pub gas: Gas,
    /// The data outputs from solving each predicate.
    pub data: Vec<DataFromSolution>,
}
97
/// The data outputs from solving a particular predicate.
///
/// Consumed when decoding data outputs back into state mutations on the set.
#[derive(Debug, PartialEq)]
pub struct DataFromSolution {
    /// The index of the solution that produced this data.
    pub solution_index: SolutionIndex,
    /// The data output from the solution.
    pub data: Vec<DataOutput>,
}
106
/// The output of a program execution.
#[derive(Debug, PartialEq)]
enum ProgramOutput {
    /// The program output is a boolean value
    /// indicating whether the constraint was satisfied.
    Satisfied(bool),
    /// The program output is data (see [`DataOutput`]).
    DataOutput(DataOutput),
}
116
/// Types of data output from a program.
#[derive(Debug, PartialEq)]
pub enum DataOutput {
    /// The program output is the memory, which may later be decoded into
    /// state mutations.
    Memory(Memory),
}
123
/// The output of a program depends on
/// whether it is a leaf or a parent.
enum Output {
    /// Leaf nodes output bools (constraint results) or data.
    Leaf(ProgramOutput),
    /// Parent nodes output a stack and memory for their children to consume.
    Parent(Arc<(Stack, Memory)>),
}
132
/// The mode the check is running in.
///
/// See [`check_and_compute_solution_set_two_pass`]: the first pass runs in
/// [`RunMode::Outputs`], the second in [`RunMode::Checks`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum RunMode {
    /// Generating outputs
    #[default]
    Outputs,
    /// Checking outputs
    Checks,
}
142
/// [`check_set`] error.
///
/// Wraps the error of whichever sub-check failed first.
#[derive(Debug, Error)]
pub enum InvalidSolutionSet {
    /// Invalid solution.
    #[error("invalid solution: {0}")]
    Solution(#[from] InvalidSolution),
    /// State mutations validation failed.
    #[error("state mutations validation failed: {0}")]
    StateMutations(#[from] InvalidSetStateMutations),
}
153
154/// [`check_solutions`] error.
155#[derive(Debug, Error)]
156pub enum InvalidSolution {
157    /// There must be at least one solution.
158    #[error("must be at least one solution")]
159    Empty,
160    /// The number of solutions exceeds the limit.
161    #[error("the number of solutions ({0}) exceeds the limit ({MAX_SOLUTIONS})")]
162    TooMany(usize),
163    /// A solution's predicate data length exceeds the limit.
164    #[error("solution {0}'s predicate data length exceeded {1} (limit: {MAX_PREDICATE_DATA})")]
165    PredicateDataLenExceeded(usize, usize),
166    /// Invalid state mutation entry.
167    #[error("Invalid state mutation entry: {0}")]
168    StateMutationEntry(KvError),
169    /// Predicate data value too large.
170    #[error("Predicate data value len {0} exceeds limit {MAX_VALUE_SIZE}")]
171    PredDataValueTooLarge(usize),
172}
173
/// Error with a slot key or value.
///
/// Produced by the internal key/value size checks and wrapped in
/// [`InvalidSolution::StateMutationEntry`].
#[derive(Debug, Error)]
pub enum KvError {
    /// The key is too large.
    #[error("key with length {0} exceeds limit {MAX_KEY_SIZE}")]
    KeyTooLarge(usize),
    /// The value is too large.
    #[error("value with length {0} exceeds limit {MAX_VALUE_SIZE}")]
    ValueTooLarge(usize),
}
184
/// [`check_set_state_mutations`] error.
#[derive(Debug, Error)]
pub enum InvalidSetStateMutations {
    /// The number of state mutations (summed across the whole set) exceeds the limit.
    #[error("the number of state mutations ({0}) exceeds the limit ({MAX_STATE_MUTATIONS})")]
    TooMany(usize),
    /// Discovered multiple mutations to the same slot.
    #[error("attempt to apply multiple mutations to the same slot: {0:?} {1:?}")]
    MultipleMutationsForSlot(PredicateAddress, Key),
}
195
/// [`check_set_predicates`] error.
///
/// `E` is the error type of the underlying state-read implementation.
#[derive(Debug, Error)]
pub enum PredicatesError<E> {
    /// One or more solution failed their associated predicate checks.
    #[error("{0}")]
    Failed(#[from] PredicateErrors<E>),
    /// Summing solution gas resulted in overflow.
    #[error("summing solution gas overflowed")]
    GasOverflowed,
    /// Tried to compute mutations on solution set with existing mutations.
    #[error("tried to compute mutations on solution set with existing mutations")]
    ExistingMutations,
}
209
/// Predicate checking failed for the solution at the given indices.
///
/// Each entry pairs a failing solution's index with the error it produced;
/// entries are rendered one per line by this type's `Display` impl.
#[derive(Debug, Error)]
pub struct PredicateErrors<E>(pub Vec<(SolutionIndex, PredicateError<E>)>);
213
/// [`check_predicate`] error.
#[derive(Debug, Error)]
pub enum PredicateError<E> {
    /// Failed to retrieve edges for a node, indicating that the predicate's graph is invalid.
    #[error("failed to retrieve edges for node {0} indicating an invalid graph")]
    InvalidNodeEdges(usize),
    /// The execution of one or more programs failed.
    #[error("one or more program execution errors occurred: {0}")]
    ProgramErrors(#[from] ProgramErrors<E>),
    /// One or more of the constraints were unsatisfied.
    #[error("one or more constraints unsatisfied: {0}")]
    ConstraintsUnsatisfied(#[from] ConstraintsUnsatisfied),
    /// One or more of the mutations were invalid.
    #[error(transparent)]
    Mutations(#[from] MutationsError),
}
230
/// Program execution failed for the programs at the given node indices.
///
/// Each entry pairs a predicate-graph node index with its program's error;
/// entries are rendered one per line by this type's `Display` impl.
#[derive(Debug, Error)]
pub struct ProgramErrors<E>(Vec<(usize, ProgramError<E>)>);
234
/// An error occurring during a program task.
///
/// `E` is the error type of the underlying state-read implementation,
/// propagated out of the VM via [`vm::error::ExecError`].
#[derive(Debug, Error)]
pub enum ProgramError<E> {
    /// Failed to parse ops from bytecode during bytecode mapping.
    #[error("failed to parse an op during bytecode mapping: {0}")]
    OpsFromBytesError(#[from] FromBytesError),
    /// Concatenating the parent program [`Stack`]s caused an overflow.
    #[error("concatenating parent program `Stack`s caused an overflow: {0}")]
    ParentStackConcatOverflow(#[from] vm::error::StackError),
    /// Concatenating the parent program [`Memory`] slices caused an overflow.
    #[error("concatenating parent program `Memory` slices caused an overflow: {0}")]
    ParentMemoryConcatOverflow(#[from] vm::error::MemoryError),
    /// VM execution resulted in an error.
    #[error("VM execution error: {0}")]
    Vm(#[from] vm::error::ExecError<E>),
}
251
/// The index of each constraint that was not satisfied.
///
/// Indices are rendered one per line by this type's `Display` impl.
#[derive(Debug, Error)]
pub struct ConstraintsUnsatisfied(pub Vec<usize>);
255
/// Error with computing mutations.
///
/// Produced while decoding memory data outputs into state mutations.
#[derive(Debug, Error)]
pub enum MutationsError {
    /// Duplicate mutations for the same key.
    #[error("duplicate mutations for the same key: {0:?}")]
    DuplicateMutations(Key),
    /// Error decoding mutations.
    #[error(transparent)]
    DecodeError(#[from] essential_types::solution::decode::MutationDecodeError),
}
266
/// Maximum number of predicate data values a single solution may carry.
// NOTE(review): declared as `u32` while the sibling limits are `usize`;
// call sites cast with `as usize` — confirm whether this was intentional.
pub const MAX_PREDICATE_DATA: u32 = 100;
/// Maximum number of solutions within a solution set.
pub const MAX_SOLUTIONS: usize = 100;
/// Maximum number of state mutations summed across an entire solution set
/// (checked against `SolutionSet::state_mutations_len`).
pub const MAX_STATE_MUTATIONS: usize = 1000;
/// Maximum number of words in a slot value.
pub const MAX_VALUE_SIZE: usize = 10_000;
/// Maximum number of words in a slot key.
pub const MAX_KEY_SIZE: usize = 1000;
277
278impl<E: fmt::Display> fmt::Display for PredicateErrors<E> {
279    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
280        f.write_str("predicate checking failed for one or more solutions:\n")?;
281        for (ix, err) in &self.0 {
282            f.write_str(&format!("  {ix}: {err}\n"))?;
283        }
284        Ok(())
285    }
286}
287
288impl<E: fmt::Display> fmt::Display for ProgramErrors<E> {
289    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
290        f.write_str("the programs at the following node indices failed: \n")?;
291        for (node_ix, err) in &self.0 {
292            f.write_str(&format!("  {node_ix}: {err}\n"))?;
293        }
294        Ok(())
295    }
296}
297
298impl fmt::Display for ConstraintsUnsatisfied {
299    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
300        f.write_str("the constraints at the following indices returned false: \n")?;
301        for ix in &self.0 {
302            f.write_str(&format!("  {ix}\n"))?;
303        }
304        Ok(())
305    }
306}
307
308impl<F> GetPredicate for F
309where
310    F: Fn(&PredicateAddress) -> Arc<Predicate>,
311{
312    fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate> {
313        (*self)(addr)
314    }
315}
316
317impl<F> GetProgram for F
318where
319    F: Fn(&ContentAddress) -> Arc<Program>,
320{
321    fn get_program(&self, ca: &ContentAddress) -> Arc<Program> {
322        (*self)(ca)
323    }
324}
325
326impl GetPredicate for HashMap<PredicateAddress, Arc<Predicate>> {
327    fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate> {
328        self[addr].clone()
329    }
330}
331
332impl GetProgram for HashMap<ContentAddress, Arc<Program>> {
333    fn get_program(&self, ca: &ContentAddress) -> Arc<Program> {
334        self[ca].clone()
335    }
336}
337
338impl<T: GetPredicate> GetPredicate for Arc<T> {
339    fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate> {
340        (**self).get_predicate(addr)
341    }
342}
343
344impl<T: GetProgram> GetProgram for Arc<T> {
345    fn get_program(&self, ca: &ContentAddress) -> Arc<Program> {
346        (**self).get_program(ca)
347    }
348}
349
350/// Validate a solution set, to the extent it can be validated without reference to
351/// its associated predicates.
352///
353/// This includes solutions and state mutations.
354#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(solution = %content_addr(set)), err))]
355pub fn check_set(set: &SolutionSet) -> Result<(), InvalidSolutionSet> {
356    check_solutions(&set.solutions)?;
357    check_set_state_mutations(set)?;
358    Ok(())
359}
360
361fn check_value_size(value: &[Word]) -> Result<(), KvError> {
362    if value.len() > MAX_VALUE_SIZE {
363        Err(KvError::ValueTooLarge(value.len()))
364    } else {
365        Ok(())
366    }
367}
368
369fn check_key_size(value: &[Word]) -> Result<(), KvError> {
370    if value.len() > MAX_KEY_SIZE {
371        Err(KvError::KeyTooLarge(value.len()))
372    } else {
373        Ok(())
374    }
375}
376
377/// Validate the solution set's slice of [`Solution`]s.
378pub fn check_solutions(solutions: &[Solution]) -> Result<(), InvalidSolution> {
379    // Validate solution.
380    // Ensure that at solution has at least one solution.
381    if solutions.is_empty() {
382        return Err(InvalidSolution::Empty);
383    }
384    // Ensure that solution length is below limit length.
385    if solutions.len() > MAX_SOLUTIONS {
386        return Err(InvalidSolution::TooMany(solutions.len()));
387    }
388
389    // Check whether the predicate data length has been exceeded.
390    for (solution_ix, solution) in solutions.iter().enumerate() {
391        // Ensure the length limit is not exceeded.
392        if solution.predicate_data.len() > MAX_PREDICATE_DATA as usize {
393            return Err(InvalidSolution::PredicateDataLenExceeded(
394                solution_ix,
395                solution.predicate_data.len(),
396            ));
397        }
398        for v in &solution.predicate_data {
399            check_value_size(v).map_err(|_| InvalidSolution::PredDataValueTooLarge(v.len()))?;
400        }
401    }
402    Ok(())
403}
404
405/// Validate the solution set's state mutations.
406pub fn check_set_state_mutations(set: &SolutionSet) -> Result<(), InvalidSolutionSet> {
407    // Validate state mutations.
408    // Ensure that the solution set's state mutations length is below limit length.
409    if set.state_mutations_len() > MAX_STATE_MUTATIONS {
410        return Err(InvalidSetStateMutations::TooMany(set.state_mutations_len()).into());
411    }
412
413    // Ensure that no more than one mutation per slot is proposed.
414    for solution in &set.solutions {
415        let mut mut_keys = HashSet::new();
416        for mutation in &solution.state_mutations {
417            if !mut_keys.insert(&mutation.key) {
418                return Err(InvalidSetStateMutations::MultipleMutationsForSlot(
419                    solution.predicate_to_solve.clone(),
420                    mutation.key.clone(),
421                )
422                .into());
423            }
424            // Check key length.
425            check_key_size(&mutation.key).map_err(InvalidSolution::StateMutationEntry)?;
426            // Check value length.
427            check_value_size(&mutation.value).map_err(InvalidSolution::StateMutationEntry)?;
428        }
429    }
430
431    Ok(())
432}
433
/// Decode any [`DataOutput::Memory`] outputs into state mutations and append
/// them to the solution they came from.
///
/// Returns the updated set, or a [`PredicatesError::Failed`] wrapping a
/// [`MutationsError`] if decoding fails or a single solution's outputs contain
/// two mutations for the same key.
///
/// Note: the duplicate check only covers mutations decoded here; keys already
/// present in the solution's `state_mutations` are not compared against.
fn decode_mutations<E>(
    outputs: Outputs,
    mut set: SolutionSet,
) -> Result<SolutionSet, PredicatesError<E>> {
    // For each output check if there are any state mutations and apply them.
    for output in outputs.data {
        // No two outputs can point to the same solution index.
        // Get the solution that these outputs came from.
        let s = &mut set.solutions[output.solution_index as usize];

        // Set to check for duplicate mutation keys within this solution's outputs.
        let mut mut_set = HashSet::new();

        // For each memory output decode the mutations and apply them.
        for data in output.data {
            match data {
                DataOutput::Memory(memory) => {
                    for mutation in essential_types::solution::decode::decode_mutations(&memory)
                        .map_err(|e| {
                            // Attribute the decode failure to the producing solution.
                            PredicatesError::Failed(PredicateErrors(vec![(
                                output.solution_index,
                                PredicateError::Mutations(MutationsError::DecodeError(e)),
                            )]))
                        })?
                    {
                        // Check for duplicate mutation keys.
                        if !mut_set.insert(mutation.key.clone()) {
                            return Err(PredicatesError::Failed(PredicateErrors(vec![(
                                output.solution_index,
                                PredicateError::Mutations(MutationsError::DuplicateMutations(
                                    mutation.key.clone(),
                                )),
                            )])));
                        }

                        // Apply the mutation.
                        s.state_mutations.push(mutation);
                    }
                }
            }
        }
    }
    Ok(set)
}
478
/// Internal post state used for mutations.
///
/// Holds the values written by the first check pass so the second pass can
/// read them back.
#[derive(Debug, Default)]
struct PostState {
    /// Contract => Key => Value
    state: HashMap<ContentAddress, HashMap<Key, Value>>,
}
485
/// Arc wrapper for [`PostState`] to allow for cheap cloning.
/// Must take the same error type as the pre state.
///
/// The `PhantomData<E>` only records the error type; no `E` value is stored.
#[derive(Debug, Default)]
struct PostStateArc<E>(Arc<PostState>, std::marker::PhantomData<E>);
490
491impl<E> Clone for PostStateArc<E> {
492    fn clone(&self) -> Self {
493        Self(self.0.clone(), Default::default())
494    }
495}
496
497impl<E> StateRead for PostStateArc<E>
498where
499    E: std::fmt::Display + std::fmt::Debug,
500{
501    type Error = E;
502
503    fn key_range(
504        &self,
505        contract_addr: ContentAddress,
506        mut key: Key,
507        num_values: usize,
508    ) -> Result<Vec<Vec<essential_types::Word>>, Self::Error> {
509        let out = self
510            .0
511            .state
512            .get(&contract_addr)
513            .map(|state| {
514                let mut values = Vec::with_capacity(num_values);
515                for _ in 0..num_values {
516                    let Some(value) = state.get(&key) else {
517                        return values;
518                    };
519                    values.push(value.clone());
520                    let Some(k) = next_key(key) else {
521                        return values;
522                    };
523                    key = k;
524                }
525                values
526            })
527            .unwrap_or_default();
528        Ok(out)
529    }
530}
531
532/// Get the next key in the range of keys.
533fn next_key(mut key: Key) -> Option<Key> {
534    for w in key.iter_mut().rev() {
535        match *w {
536            Word::MAX => *w = Word::MIN,
537            _ => {
538                *w += 1;
539                return Some(key);
540            }
541        }
542    }
543    None
544}
545
/// Check the given solution set against the given predicates and
/// compute the post state mutations for this set.
///
/// This is a two-pass check. The first pass ([`RunMode::Outputs`]) generates
/// the outputs and does not run any post state reads.
/// The second pass ([`RunMode::Checks`]) checks the outputs and runs the post
/// state reads against the mutations computed by the first pass.
///
/// Returns the gas summed over both passes (saturating) and the solution set
/// with its computed state mutations.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
pub fn check_and_compute_solution_set_two_pass<S>(
    state: &S,
    solution_set: SolutionSet,
    get_predicate: impl GetPredicate + Sync + Clone,
    get_program: impl 'static + Clone + GetProgram + Send + Sync,
    config: Arc<CheckPredicateConfig>,
) -> Result<(Gas, SolutionSet), PredicatesError<S::Error>>
where
    S: Clone + StateRead + Send + Sync + 'static,
    S::Error: Send + Sync + 'static,
{
    // Create an empty post state: the first pass must not observe any
    // post-mutation values.
    let post_state = PostStateArc::<S::Error>(Arc::new(PostState::default()), Default::default());

    // Create an empty cache, shared between the passes so parent outputs
    // computed while generating can be reused while checking.
    let mut cache = HashMap::new();

    // First pass: generate the outputs.
    let (mut gas, solution_set) = check_and_compute_solution_set(
        &(state.clone(), post_state.clone()),
        solution_set,
        get_predicate.clone(),
        get_program.clone(),
        config.clone(),
        RunMode::Outputs,
        &mut cache,
    )?;

    // Get the post state back. All clones made for the first pass have been
    // dropped by now, so this is expected to be the sole remaining reference.
    let mut post_state =
        Arc::try_unwrap(post_state.0).expect("post state should have one reference");

    // Apply the state mutations computed by the first pass to the post state.
    for solution in &solution_set.solutions {
        for mutation in &solution.state_mutations {
            post_state
                .state
                .entry(solution.predicate_to_solve.contract.clone())
                .or_default()
                .insert(mutation.key.clone(), mutation.value.clone());
        }
    }

    // Put the post state back into an arc for the second pass.
    let post_state = PostStateArc(Arc::new(post_state), Default::default());

    // Second pass: check the outputs, now with post state reads available.
    let (g, solution_set) = check_and_compute_solution_set(
        &(state.clone(), post_state.clone()),
        solution_set,
        get_predicate,
        get_program,
        config,
        RunMode::Checks,
        &mut cache,
    )?;

    // Sum the gas from both passes, saturating rather than overflowing.
    gas = gas.saturating_add(g);

    // Return solution set along with the total gas.
    Ok((gas, solution_set))
}
616
/// Check the given solution set against the given predicates and
/// compute the post state mutations for this set.
///
/// Runs a single pass in the given `run_mode`, then decodes any data outputs
/// into state mutations on the returned set.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
pub fn check_and_compute_solution_set<S>(
    state: &S,
    solution_set: SolutionSet,
    get_predicate: impl GetPredicate + Sync,
    get_program: impl 'static + Clone + GetProgram + Send + Sync,
    config: Arc<CheckPredicateConfig>,
    run_mode: RunMode,
    cache: &mut HashMap<SolutionIndex, Cache>,
) -> Result<(Gas, SolutionSet), PredicatesError<S::Error>>
where
    S: Clone + StateReads + Send + Sync + 'static,
    S::Error: Send,
{
    // Check the set and gather any outputs.
    let set = Arc::new(solution_set);
    let outputs = check_set_predicates(
        state,
        set.clone(),
        get_predicate,
        get_program,
        config,
        run_mode,
        cache,
    )?;

    // Safe to unwrap the arc here as we have no other references.
    let set = Arc::try_unwrap(set).expect("set should have one reference");

    // Gas was summed (saturating) by `check_set_predicates`.
    let gas = outputs.gas;

    // Decode any memory data outputs into state mutations on the set.
    let set = decode_mutations(outputs, set)?;

    Ok((gas, set))
}
655
/// Checks all of a [`SolutionSet`]'s [`Solution`]s against their associated [`Predicate`]s.
///
/// For each solution, we load the associated predicate and its programs and execute each
/// in parallel and in topological order. The leaf nodes are treated as constraints or data outputs and if
/// any constraint returns `false`, the whole solution set is considered to be invalid.
///
/// **NOTE:** This assumes that the given `SolutionSet` and all `Predicate`s have already
/// been independently validated using [`solution::check_set`][check_set] and
/// [`predicate::check`][crate::predicate::check] respectively.
///
/// ## Arguments
///
/// - `pre_state` must provide access to state *prior to* mutations being applied.
/// - `post_state` must provide access to state *post* mutations being applied.
/// - `cache` holds one [`Cache`] per solution index; entries are taken for the
///   duration of the check and written back on success.
///
/// Returns the total gas spent (saturating) along with any data outputs.
pub fn check_set_predicates<S>(
    state: &S,
    solution_set: Arc<SolutionSet>,
    get_predicate: impl GetPredicate + Sync,
    get_program: impl 'static + Clone + GetProgram + Send + Sync,
    config: Arc<CheckPredicateConfig>,
    run_mode: RunMode,
    cache: &mut HashMap<SolutionIndex, Cache>,
) -> Result<Outputs, PredicatesError<S::Error>>
where
    S: Clone + StateReads + Send + Sync + 'static,
    S::Error: Send,
{
    #[cfg(feature = "tracing")]
    tracing::trace!("{}", essential_hash::content_addr(&*solution_set));

    // Take each solution's cache out of the map (creating missing entries) so
    // the caches can be moved into the parallel iterator below.
    let caches: Vec<_> = (0..solution_set.solutions.len())
        .map(|i| {
            let cache = cache.entry(i as u16).or_default();
            core::mem::take(cache)
        })
        .collect();
    // Check each solution in parallel.
    let (ok, failed): (Vec<_>, Vec<_>) = solution_set
        .solutions
        .par_iter()
        .zip(caches)
        .enumerate()
        .map(|(solution_index, (solution, mut cache))| {
            let predicate = get_predicate.get_predicate(&solution.predicate_to_solve);
            // Cheap clones: `Arc` bumps for the set/config, plus whatever the
            // caller's `Clone` impls do for state and program source.
            let solution_set = solution_set.clone();
            let state = state.clone();
            let config = config.clone();
            let get_program = get_program.clone();

            let res = check_predicate(
                &state,
                solution_set,
                predicate,
                get_program,
                solution_index
                    .try_into()
                    .expect("solution index already validated"),
                &config,
                Ctx {
                    run_mode,
                    cache: &mut cache,
                },
            );

            // Tag success and failure with the solution index for reporting.
            match res {
                Ok(ok) => Ok((solution_index as u16, ok, cache)),
                Err(e) => Err((solution_index as u16, e)),
            }
        })
        .partition(Result::is_ok);

    // If any predicates failed, return an error.
    // Note that the caches taken above are dropped on this path, leaving the
    // caller's `cache` map with empty (defaulted) entries.
    if !failed.is_empty() {
        return Err(PredicateErrors(failed.into_iter().map(Result::unwrap_err).collect()).into());
    }

    // Calculate gas used and write each solution's cache back into the map.
    let mut total_gas: u64 = 0;
    let outputs = ok
        .into_iter()
        .map(Result::unwrap)
        .map(|(solution_index, (gas, data_outputs), c)| {
            let output = DataFromSolution {
                solution_index,
                data: data_outputs,
            };
            total_gas = total_gas.saturating_add(gas);
            *cache.get_mut(&solution_index).expect("cache should exist") = c;
            output
        })
        .collect();

    Ok(Outputs {
        gas: total_gas,
        data: outputs,
    })
}
755
/// Checks the predicate of the solution within the given set at the given `solution_index`.
///
/// Spawns a rayon task for each of the predicate's nodes to execute in parallel
/// once their inputs are ready.
///
/// **NOTE:** This assumes that the given `SolutionSet` and `Predicate` have been
/// independently validated using [`solution::check_set`][check_set]
/// and [`predicate::check`][crate::predicate::check] respectively.
///
/// ## Arguments
///
/// - `pre_state` must provide access to state *prior to* mutations being applied.
/// - `post_state` must provide access to state *post* mutations being applied.
/// - `solution_index` represents the solution within `solution_set.solutions` that
///   claims to solve this predicate.
///
/// Returns the total gas spent along with any data outputs.
pub fn check_predicate<S>(
    state: &S,
    solution_set: Arc<SolutionSet>,
    predicate: Arc<Predicate>,
    get_program: impl GetProgram + Send + Sync + 'static,
    solution_index: SolutionIndex,
    config: &CheckPredicateConfig,
    ctx: Ctx,
) -> Result<(Gas, Vec<DataOutput>), PredicateError<S::Error>>
where
    S: Clone + StateReads + Send + Sync + 'static,
    S::Error: Send,
{
    let p = predicate.clone();

    // Closure that executes a single node's program once all of its parent
    // outputs are available.
    let run = |ix: u16, parents: Vec<Arc<(Stack, Memory)>>| {
        // Resolve the node's program by content address.
        let program = get_program.get_program(&predicate.nodes[ix as usize].program_address);
        let ctx = ProgramCtx {
            parents,
            // A node with no outgoing edges is a leaf (constraint/data output).
            leaf: predicate
                .node_edges(ix as usize)
                .expect("This is already checked")
                .is_empty(),
        };
        let res = run_program(
            state.clone(),
            solution_set.clone(),
            solution_index,
            program,
            ctx,
        );
        (ix, res)
    };

    check_predicate_inner(run, p, config, &get_program, ctx)
}
810
811/// Includes nodes with no parents
812fn create_parent_map<E>(
813    predicate: &Predicate,
814) -> Result<BTreeMap<u16, Vec<u16>>, PredicateError<E>> {
815    let mut nodes: BTreeMap<u16, Vec<u16>> = BTreeMap::new();
816    // For each node add it their children's parents set
817    for node_ix in 0..predicate.nodes.len() {
818        // Insert this node incase it's a root
819        nodes.entry(node_ix as u16).or_default();
820
821        // Add any children
822        for edge in predicate
823            .node_edges(node_ix)
824            .ok_or_else(|| PredicateError::InvalidNodeEdges(node_ix))?
825        {
826            // Insert the child if it's not already there and then add this node as a parent
827            nodes.entry(*edge).or_default().push(node_ix as u16);
828        }
829    }
830    Ok(nodes)
831}
832
/// Compute the in-degree (number of parents) of each of the `num_nodes` nodes.
///
/// Nodes absent from `parent_map` are given an in-degree of zero.
fn in_degrees(num_nodes: usize, parent_map: &BTreeMap<u16, Vec<u16>>) -> BTreeMap<u16, usize> {
    (0..num_nodes)
        .map(|node| {
            let node = node as u16;
            let degree = parent_map.get(&node).map(Vec::len).unwrap_or(0);
            (node, degree)
        })
        .collect()
}
844
/// Decrement (saturating at zero) the in-degree of each child that is present.
fn reduce_in_degrees(in_degrees: &mut BTreeMap<u16, usize>, children: &[u16]) {
    for &child in children {
        // `and_modify` without `or_insert` leaves missing keys untouched,
        // mirroring a `get_mut` that finds nothing.
        in_degrees
            .entry(child)
            .and_modify(|degree| *degree = degree.saturating_sub(1));
    }
}
852
/// Collect every node whose in-degree has reached zero, in ascending order.
fn find_nodes_with_no_parents(in_degrees: &BTreeMap<u16, usize>) -> Vec<u16> {
    in_degrees
        .iter()
        .filter(|(_, degree)| **degree == 0)
        .map(|(node, _)| *node)
        .collect()
}
867
/// Sorts the nodes into levels for parallel execution using an in-degree
/// (Kahn-style) topological sort: each level holds every node whose parents
/// all appear in earlier levels.
///
/// ## Note
/// This is not a perfect ordering as the following:
/// ```text
///   A
///  / \
/// B   C
/// |   |
/// D   E
///  \ /
///   F
/// ```
/// Results in:
/// ```text
/// [[A], [B, C], [D, E], [F]]
/// ```
/// If `B` or `C` finish first then they could start on
/// `D` or `E` respectively but this sort doesn't allow that.
fn parallel_topo_sort<E>(
    predicate: &Predicate,
    parent_map: &BTreeMap<u16, Vec<u16>>,
) -> Result<Vec<Vec<u16>>, PredicateError<E>> {
    let mut in_degrees = in_degrees(predicate.nodes.len(), parent_map);

    let mut out = Vec::new();
    while !in_degrees.is_empty() {
        // Every remaining node with no unprocessed parents forms the next level.
        let current_level = find_nodes_with_no_parents(&in_degrees);
        if current_level.is_empty() {
            // Nodes remain but none is parent-free: the graph has a cycle.
            // TODO: Change error
            return Err(PredicateError::InvalidNodeEdges(0));
        }

        out.push(current_level.clone());

        // Remove this level's nodes and decrement their children's in-degrees.
        for node in current_level {
            let children = predicate
                .node_edges(node as usize)
                .ok_or_else(|| PredicateError::InvalidNodeEdges(node as usize))?;
            reduce_in_degrees(&mut in_degrees, children);
            in_degrees.remove(&node);
        }
    }

    Ok(out)
}
915
916fn find_deferred<F>(predicate: &Predicate, is_deferred: F) -> HashSet<u16>
917where
918    F: Fn(&essential_types::predicate::Node) -> bool,
919{
920    let mut deferred = HashSet::new();
921    for (ix, node) in predicate.nodes.iter().enumerate() {
922        if is_deferred(node) {
923            deferred.insert(ix as u16);
924        }
925        if deferred.contains(&(ix as u16)) {
926            for child in predicate.node_edges(ix).expect("Already checked") {
927                deferred.insert(*child);
928            }
929        }
930    }
931    deferred
932}
933
934fn should_cache(node: u16, predicate: &Predicate, deferred: &HashSet<u16>) -> bool {
935    !deferred.contains(&node)
936        && predicate
937            .node_edges(node as usize)
938            .expect("Already checked")
939            .iter()
940            .any(|child| deferred.contains(child))
941}
942
/// Strip all deferred nodes out of the levelled ordering, dropping any
/// level that becomes empty. Relative order is preserved.
fn remove_deferred(nodes: Vec<Vec<u16>>, deferred: &HashSet<u16>) -> Vec<Vec<u16>> {
    let mut out = Vec::with_capacity(nodes.len());
    for mut level in nodes {
        level.retain(|node| !deferred.contains(node));
        if !level.is_empty() {
            out.push(level);
        }
    }
    out
}
955
/// Keep only the deferred nodes in the levelled ordering, dropping any
/// level that becomes empty. Relative order is preserved.
fn remove_not_deferred(nodes: Vec<Vec<u16>>, deferred: &HashSet<u16>) -> Vec<Vec<u16>> {
    nodes
        .into_iter()
        .filter_map(|level| {
            let kept: Vec<u16> = level
                .into_iter()
                .filter(|node| deferred.contains(node))
                .collect();
            (!kept.is_empty()).then_some(kept)
        })
        .collect()
}
968
/// Handles the checking of a predicate.
/// - Sorts the nodes into parallel topological order.
/// - Sets up for the run type.
/// - Runs the programs in parallel where appropriate.
/// - Collects the outputs and gas.
///
/// `run` executes a single node: it receives the node index and the
/// `(Stack, Memory)` outputs of that node's parents, and returns the index
/// alongside the program's output and gas spent.
///
/// Depending on `ctx.run_mode`, either only the non-deferred nodes
/// (`RunMode::Outputs`) or only the deferred nodes (`RunMode::Checks`) are
/// executed; "deferred" means the node's program reads post-state.
///
/// Returns the total gas spent plus any leaf `DataOutput`s, or an error if
/// any program failed or any constraint was unsatisfied.
fn check_predicate_inner<F, E>(
    run: F,
    predicate: Arc<Predicate>,
    config: &CheckPredicateConfig,
    get_program: &(impl GetProgram + Send + Sync + 'static),
    ctx: Ctx<'_>,
) -> Result<(Gas, Vec<DataOutput>), PredicateError<E>>
where
    F: Fn(u16, Vec<Arc<(Stack, Memory)>>) -> (u16, Result<(Output, u64), ProgramError<E>>)
        + Send
        + Sync
        + Copy,
    E: Send,
{
    // Get the mode we are running and the global cache.
    let Ctx { run_mode, cache } = ctx;

    // Create the parent map
    let parent_map = create_parent_map(&predicate)?;

    // Create a parallel topological sort of the nodes
    let sorted_nodes = parallel_topo_sort(&predicate, &parent_map)?;

    // Filter for which nodes are deferred. This is nodes with a post state read.
    let deferred_filter = |node: &essential_types::predicate::Node| -> bool {
        asm::effects::bytes_contains_any(
            &get_program.get_program(&node.program_address).0,
            asm::effects::Effects::PostKeyRange | asm::effects::Effects::PostKeyRangeExtern,
        )
    };

    // Get the set of deferred nodes.
    let deferred = find_deferred(&predicate, deferred_filter);

    // Depending on the run mode remove the deferred nodes or other nodes.
    let sorted_nodes = match run_mode {
        RunMode::Outputs => remove_deferred(sorted_nodes, &deferred),
        RunMode::Checks => remove_not_deferred(sorted_nodes, &deferred),
    };

    // Setup a local cache for the outputs.
    // Outputs only needed within this call go here; outputs that deferred
    // nodes will consume in a later pass go into the global `cache` instead.
    let mut local_cache = Cache::new();

    // The outputs from a run.
    let mut failed: Vec<(_, _)> = vec![];
    let mut total_gas: Gas = 0;
    let mut unsatisfied = Vec::new();
    let mut data_outputs = Vec::new();

    // Run each set of parallel nodes.
    // Levels execute sequentially; nodes within a level may execute in
    // parallel via rayon.
    for parallel_nodes in sorted_nodes {
        // Run 1 or no length in serial to avoid overhead.
        let outputs: BTreeMap<u16, Result<(Output, Gas), _>> =
            if parallel_nodes.len() == 1 || parallel_nodes.is_empty() {
                parallel_nodes
                    .into_iter()
                    .map(|ix| {
                        // Check global cache then local cache
                        // for parent inputs.
                        let inputs = parent_map[&ix]
                            .iter()
                            .filter_map(|parent_ix| {
                                cache
                                    .get(parent_ix)
                                    .cloned()
                                    .or_else(|| local_cache.get(parent_ix).cloned())
                            })
                            .collect();

                        // Run the program.
                        run(ix, inputs)
                    })
                    .collect()
            } else {
                parallel_nodes
                    .into_par_iter()
                    .map(|ix| {
                        // Check global cache then local cache
                        // for parent inputs.
                        let inputs = parent_map[&ix]
                            .iter()
                            .filter_map(|parent_ix| {
                                cache
                                    .get(parent_ix)
                                    .cloned()
                                    .or_else(|| local_cache.get(parent_ix).cloned())
                            })
                            .collect();

                        // Run the program.
                        run(ix, inputs)
                    })
                    .collect()
            };
        // Fold this level's results into the caches, gas total and
        // failure/unsatisfied lists before starting the next level.
        for (node, res) in outputs {
            match res {
                Ok((Output::Parent(o), gas)) => {
                    // Check if we should add this output to the global or local cache.
                    if should_cache(node, &predicate, &deferred) {
                        cache.insert(node, o.clone());
                    } else {
                        local_cache.insert(node, o.clone());
                    }

                    // Add to the total gas
                    total_gas = total_gas.saturating_add(gas);
                }
                Ok((Output::Leaf(o), gas)) => {
                    match o {
                        ProgramOutput::Satisfied(false) => {
                            unsatisfied.push(node as usize);
                        }
                        ProgramOutput::Satisfied(true) => {
                            // Nothing to do here.
                        }
                        ProgramOutput::DataOutput(data_output) => {
                            data_outputs.push(data_output);
                        }
                    }

                    // Add to the total gas
                    total_gas = total_gas.saturating_add(gas);
                }
                Err(e) => {
                    failed.push((node as usize, e));

                    // Bail out on the first failure unless the config asks
                    // us to keep running and report all of them together.
                    if !config.collect_all_failures {
                        return Err(ProgramErrors(failed).into());
                    }
                }
            }
        }
    }

    // If there are any failed constraints, return an error.
    if !failed.is_empty() {
        return Err(ProgramErrors(failed).into());
    }

    // If there are any unsatisfied constraints, return an error.
    if !unsatisfied.is_empty() {
        return Err(ConstraintsUnsatisfied(unsatisfied).into());
    }

    Ok((total_gas, data_outputs))
}
1120
/// Decode the given program's bytecode and execute it on a fresh VM.
///
/// The VM's stack and memory are seeded by concatenating the `(Stack, Memory)`
/// results of the program's `parents` (in the order given by `ctx`), and the
/// ops execute with [`Access`] to the solution at `solution_index`.
///
/// For a leaf node the top of the final stack selects the output: `[2]`
/// yields the VM memory as a [`DataOutput`], `[1]` means the constraint is
/// satisfied, and anything else means unsatisfied. For a non-leaf node the
/// final `(Stack, Memory)` pair is returned for child programs to consume.
///
/// Returns the output together with the gas spent executing the ops.
fn run_program<S>(
    state: S,
    solution_set: Arc<SolutionSet>,
    solution_index: SolutionIndex,
    program: Arc<Program>,
    ctx: ProgramCtx,
) -> Result<(Output, Gas), ProgramError<S::Error>>
where
    S: StateReads,
{
    let ProgramCtx { parents, leaf } = ctx;

    // Pull ops into memory.
    let ops = asm::from_bytes(program.0.iter().copied()).collect::<Result<Vec<_>, _>>()?;

    // Create a new VM.
    let mut vm = vm::Vm::default();

    // Use the results of the parent execution to initialise our stack and memory.
    // Each parent's stack/memory is appended after whatever is already in the
    // VM, so parents contribute in iteration order.
    for parent_result in parents {
        let (parent_stack, parent_memory) = Arc::unwrap_or_clone(parent_result);
        // Extend the stack.
        let mut stack: Vec<Word> = std::mem::take(&mut vm.stack).into();
        stack.append(&mut parent_stack.into());
        // Fails if the combined stack exceeds the VM's limits.
        vm.stack = stack.try_into()?;

        // Extend the memory.
        let mut memory: Vec<Word> = std::mem::take(&mut vm.memory).into();
        memory.append(&mut parent_memory.into());
        vm.memory = memory.try_into()?;
    }

    // Setup solution access for execution.
    // NOTE(review): `solution_set.solutions.clone()` deep-clones the solutions
    // for every program run; consider sharing an `Arc` built once per set.
    let access = Access::new(Arc::new(solution_set.solutions.clone()), solution_index);

    // FIXME: Provide these from Config.
    let gas_cost = |_: &asm::Op| 1;
    let gas_limit = GasLimit::UNLIMITED;

    // Execute the program's ops on the VM, metering gas.
    let gas_spent = vm.exec_ops(&ops, access, &state, &gas_cost, gas_limit)?;

    let out = if leaf {
        // The final stack encodes the leaf result: 2 = data output in memory,
        // 1 = satisfied, anything else = unsatisfied.
        match vm.stack[..] {
            [2] => Output::Leaf(ProgramOutput::DataOutput(DataOutput::Memory(vm.memory))),
            [1] => Output::Leaf(ProgramOutput::Satisfied(true)),
            _ => Output::Leaf(ProgramOutput::Satisfied(false)),
        }
    } else {
        let output = Arc::new((vm.stack, vm.memory));
        Output::Parent(output)
    };

    Ok((out, gas_spent))
}