1use crate::{
4 types::{
5 predicate::Predicate,
6 solution::{Solution, SolutionIndex, SolutionSet},
7 Key, PredicateAddress, Word,
8 },
9 vm::{
10 self,
11 asm::{self, FromBytesError},
12 Access, BytecodeMapped, Gas, GasLimit, Memory, Stack, StateRead,
13 },
14};
15#[cfg(feature = "tracing")]
16use essential_hash::content_addr;
17use essential_types::{
18 predicate::{Program, Reads},
19 ContentAddress,
20};
21use std::{
22 collections::{HashMap, HashSet},
23 fmt,
24 sync::Arc,
25};
26use thiserror::Error;
27use tokio::{sync::oneshot, task::JoinSet};
28#[cfg(feature = "tracing")]
29use tracing::Instrument;
30
/// Options controlling how [`check_predicate`] and [`check_set_predicates`] report failures.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct CheckPredicateConfig {
    /// When `true`, checking continues past the first failure and every failure found is
    /// reported; when `false` (the default), checking stops at the first failure observed.
    pub collect_all_failures: bool,
}
42
43pub trait GetPredicate {
45 fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate>;
52}
53
54pub trait GetProgram {
56 fn get_program(&self, ca: &ContentAddress) -> Arc<Program>;
64}
65
66struct ProgramCtx {
68 parents: Vec<oneshot::Receiver<Arc<(Stack, Memory)>>>,
72 children: Vec<oneshot::Sender<Arc<(Stack, Memory)>>>,
73 reads: Reads,
74}
75
76#[derive(Debug, Error)]
78pub enum InvalidSolutionSet {
79 #[error("invalid solution: {0}")]
81 Solution(#[from] InvalidSolution),
82 #[error("state mutations validation failed: {0}")]
84 StateMutations(#[from] InvalidSetStateMutations),
85}
86
87#[derive(Debug, Error)]
89pub enum InvalidSolution {
90 #[error("must be at least one solution")]
92 Empty,
93 #[error("the number of solutions ({0}) exceeds the limit ({MAX_SOLUTIONS})")]
95 TooMany(usize),
96 #[error("solution {0}'s predicate data length exceeded {1} (limit: {MAX_PREDICATE_DATA})")]
98 PredicateDataLenExceeded(usize, usize),
99 #[error("Invalid state mutation entry: {0}")]
101 StateMutationEntry(KvError),
102 #[error("Predicate data value len {0} exceeds limit {MAX_VALUE_SIZE}")]
104 PredDataValueTooLarge(usize),
105}
106
107#[derive(Debug, Error)]
109pub enum KvError {
110 #[error("key with length {0} exceeds limit {MAX_KEY_SIZE}")]
112 KeyTooLarge(usize),
113 #[error("value with length {0} exceeds limit {MAX_VALUE_SIZE}")]
115 ValueTooLarge(usize),
116}
117
118#[derive(Debug, Error)]
120pub enum InvalidSetStateMutations {
121 #[error("the number of state mutations ({0}) exceeds the limit ({MAX_STATE_MUTATIONS})")]
123 TooMany(usize),
124 #[error("attempt to apply multiple mutations to the same slot: {0:?} {1:?}")]
126 MultipleMutationsForSlot(PredicateAddress, Key),
127}
128
129#[derive(Debug, Error)]
131pub enum PredicatesError<E> {
132 #[error("{0}")]
134 Failed(#[from] PredicateErrors<E>),
135 #[error("one or more spawned tasks failed to join: {0}")]
137 Join(#[from] tokio::task::JoinError),
138 #[error("summing solution gas overflowed")]
140 GasOverflowed,
141}
142
143#[derive(Debug, Error)]
145pub struct PredicateErrors<E>(pub Vec<(SolutionIndex, PredicateError<E>)>);
146
147#[derive(Debug, Error)]
149pub enum PredicateError<E> {
150 #[error("one or more spawned program tasks failed to join: {0}")]
152 Join(#[from] tokio::task::JoinError),
153 #[error("failed to retrieve edges for node {0} indicating an invalid graph")]
155 InvalidNodeEdges(usize),
156 #[error("one or more program execution errors occurred: {0}")]
158 ProgramErrors(#[from] ProgramErrors<E>),
159 #[error("one or more constraints unsatisfied: {0}")]
161 ConstraintsUnsatisfied(#[from] ConstraintsUnsatisfied),
162}
163
164#[derive(Debug, Error)]
166pub struct ProgramErrors<E>(Vec<(usize, ProgramError<E>)>);
167
168#[derive(Debug, Error)]
170pub enum ProgramError<E> {
171 #[error("failed to parse an op during bytecode mapping: {0}")]
173 OpsFromBytesError(#[from] FromBytesError),
174 #[error("parent result oneshot channel closed: {0}")]
176 ParentChannelDropped(#[from] oneshot::error::RecvError),
177 #[error("concatenating parent program `Stack`s caused an overflow: {0}")]
179 ParentStackConcatOverflow(#[from] vm::error::StackError),
180 #[error("concatenating parent program `Memory` slices caused an overflow: {0}")]
182 ParentMemoryConcatOverflow(#[from] vm::error::MemoryError),
183 #[error("VM execution error: {0}")]
185 Vm(#[from] vm::error::ExecError<E>),
186}
187
188#[derive(Debug, Error)]
190pub struct ConstraintsUnsatisfied(pub Vec<usize>);
191
/// Maximum number of predicate data entries a single solution may carry.
pub const MAX_PREDICATE_DATA: u32 = 100;
/// Maximum number of solutions in a solution set.
pub const MAX_SOLUTIONS: usize = 100;
/// Maximum number of state mutations in a set, as counted by `SolutionSet::state_mutations_len`.
pub const MAX_STATE_MUTATIONS: usize = 1_000;
/// Maximum length, in words, of a state-mutation value or a predicate data value.
pub const MAX_VALUE_SIZE: usize = 10_000;
/// Maximum length, in words, of a state-mutation key.
pub const MAX_KEY_SIZE: usize = 1_000;
202
203impl<E: fmt::Display> fmt::Display for PredicateErrors<E> {
204 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
205 f.write_str("predicate checking failed for one or more solutions:\n")?;
206 for (ix, err) in &self.0 {
207 f.write_str(&format!(" {ix}: {err}\n"))?;
208 }
209 Ok(())
210 }
211}
212
213impl<E: fmt::Display> fmt::Display for ProgramErrors<E> {
214 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
215 f.write_str("the programs at the following node indices failed: \n")?;
216 for (node_ix, err) in &self.0 {
217 f.write_str(&format!(" {node_ix}: {err}\n"))?;
218 }
219 Ok(())
220 }
221}
222
223impl fmt::Display for ConstraintsUnsatisfied {
224 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
225 f.write_str("the constraints at the following indices returned false: \n")?;
226 for ix in &self.0 {
227 f.write_str(&format!(" {ix}\n"))?;
228 }
229 Ok(())
230 }
231}
232
233impl<F> GetPredicate for F
234where
235 F: Fn(&PredicateAddress) -> Arc<Predicate>,
236{
237 fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate> {
238 (*self)(addr)
239 }
240}
241
242impl<F> GetProgram for F
243where
244 F: Fn(&ContentAddress) -> Arc<Program>,
245{
246 fn get_program(&self, ca: &ContentAddress) -> Arc<Program> {
247 (*self)(ca)
248 }
249}
250
251impl GetPredicate for HashMap<PredicateAddress, Arc<Predicate>> {
252 fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate> {
253 self[addr].clone()
254 }
255}
256
257impl GetProgram for HashMap<ContentAddress, Arc<Program>> {
258 fn get_program(&self, ca: &ContentAddress) -> Arc<Program> {
259 self[ca].clone()
260 }
261}
262
263impl<T: GetPredicate> GetPredicate for Arc<T> {
264 fn get_predicate(&self, addr: &PredicateAddress) -> Arc<Predicate> {
265 (**self).get_predicate(addr)
266 }
267}
268
269impl<T: GetProgram> GetProgram for Arc<T> {
270 fn get_program(&self, ca: &ContentAddress) -> Arc<Program> {
271 (**self).get_program(ca)
272 }
273}
274
275#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(solution = %content_addr(set)), err))]
280pub fn check_set(set: &SolutionSet) -> Result<(), InvalidSolutionSet> {
281 check_solutions(&set.solutions)?;
282 check_set_state_mutations(set)?;
283 Ok(())
284}
285
286fn check_value_size(value: &[Word]) -> Result<(), KvError> {
287 if value.len() > MAX_VALUE_SIZE {
288 Err(KvError::ValueTooLarge(value.len()))
289 } else {
290 Ok(())
291 }
292}
293
294fn check_key_size(value: &[Word]) -> Result<(), KvError> {
295 if value.len() > MAX_KEY_SIZE {
296 Err(KvError::KeyTooLarge(value.len()))
297 } else {
298 Ok(())
299 }
300}
301
302pub fn check_solutions(solutions: &[Solution]) -> Result<(), InvalidSolution> {
304 if solutions.is_empty() {
307 return Err(InvalidSolution::Empty);
308 }
309 if solutions.len() > MAX_SOLUTIONS {
311 return Err(InvalidSolution::TooMany(solutions.len()));
312 }
313
314 for (solution_ix, solution) in solutions.iter().enumerate() {
316 if solution.predicate_data.len() > MAX_PREDICATE_DATA as usize {
318 return Err(InvalidSolution::PredicateDataLenExceeded(
319 solution_ix,
320 solution.predicate_data.len(),
321 ));
322 }
323 for v in &solution.predicate_data {
324 check_value_size(v).map_err(|_| InvalidSolution::PredDataValueTooLarge(v.len()))?;
325 }
326 }
327 Ok(())
328}
329
330pub fn check_set_state_mutations(set: &SolutionSet) -> Result<(), InvalidSolutionSet> {
332 if set.state_mutations_len() > MAX_STATE_MUTATIONS {
335 return Err(InvalidSetStateMutations::TooMany(set.state_mutations_len()).into());
336 }
337
338 for solution in &set.solutions {
340 let mut mut_keys = HashSet::new();
341 for mutation in &solution.state_mutations {
342 if !mut_keys.insert(&mutation.key) {
343 return Err(InvalidSetStateMutations::MultipleMutationsForSlot(
344 solution.predicate_to_solve.clone(),
345 mutation.key.clone(),
346 )
347 .into());
348 }
349 check_key_size(&mutation.key).map_err(InvalidSolution::StateMutationEntry)?;
351 check_value_size(&mutation.value).map_err(InvalidSolution::StateMutationEntry)?;
353 }
354 }
355
356 Ok(())
357}
358
359#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
376pub async fn check_set_predicates<SA, SB>(
377 pre_state: &SA,
378 post_state: &SB,
379 solution_set: Arc<SolutionSet>,
380 get_predicate: impl GetPredicate,
381 get_program: impl 'static + Clone + GetProgram + Send + Sync,
382 config: Arc<CheckPredicateConfig>,
383) -> Result<Gas, PredicatesError<SA::Error>>
384where
385 SA: Clone + StateRead + Send + Sync + 'static,
386 SB: Clone + StateRead<Error = SA::Error> + Send + Sync + 'static,
387 SA::Future: Send,
388 SB::Future: Send,
389 SA::Error: Send,
390{
391 #[cfg(feature = "tracing")]
392 tracing::trace!("{}", essential_hash::content_addr(&*solution_set));
393
394 let mut set: JoinSet<(_, Result<_, PredicateError<SA::Error>>)> = JoinSet::new();
396 for (solution_index, solution) in solution_set.solutions.iter().enumerate() {
397 let solution_index: SolutionIndex = solution_index
398 .try_into()
399 .expect("solution index already validated");
400 let predicate = get_predicate.get_predicate(&solution.predicate_to_solve);
401 let solution_set = solution_set.clone();
402 let pre_state: SA = pre_state.clone();
403 let post_state: SB = post_state.clone();
404 let config = config.clone();
405 let get_program = get_program.clone();
406
407 let future = async move {
408 let pre_state = pre_state;
409 let post_state = post_state;
410 let res = check_predicate(
411 &pre_state,
412 &post_state,
413 solution_set,
414 predicate,
415 &get_program,
416 solution_index,
417 &config,
418 )
419 .await;
420 (solution_index, res)
421 };
422
423 #[cfg(feature = "tracing")]
424 let future = future.in_current_span();
425
426 set.spawn(future);
427 }
428
429 let mut total_gas: u64 = 0;
431 let mut failed = vec![];
432 while let Some(res) = set.join_next().await {
433 let (solution_ix, res) = res?;
434 let g = match res {
435 Ok(ok) => ok,
436 Err(e) => {
437 failed.push((solution_ix, e));
438 if config.collect_all_failures {
439 continue;
440 } else {
441 return Err(PredicateErrors(failed).into());
442 }
443 }
444 };
445
446 total_gas = total_gas
447 .checked_add(g)
448 .ok_or(PredicatesError::GasOverflowed)?;
449 }
450
451 if !failed.is_empty() {
453 return Err(PredicateErrors(failed).into());
454 }
455
456 Ok(total_gas)
457}
458
459#[cfg_attr(
477 feature = "tracing",
478 tracing::instrument(
479 skip_all,
480 fields(
481 set = %format!("{}", content_addr(&*solution_set))[0..8],
482 solution={solution_index},
483 ),
484 ),
485)]
486pub async fn check_predicate<SA, SB>(
487 pre_state: &SA,
488 post_state: &SB,
489 solution_set: Arc<SolutionSet>,
490 predicate: Arc<Predicate>,
491 get_program: &impl GetProgram,
492 solution_index: SolutionIndex,
493 config: &CheckPredicateConfig,
494) -> Result<Gas, PredicateError<SA::Error>>
495where
496 SA: Clone + StateRead + Send + Sync + 'static,
497 SB: Clone + StateRead<Error = SA::Error> + Send + Sync + 'static,
498 SA::Future: Send,
499 SB::Future: Send,
500 SA::Error: Send,
501{
502 type NodeIx = usize;
503 type ParentResultRxs = Vec<oneshot::Receiver<Arc<(Stack, Memory)>>>;
504
505 let mut parent_results: HashMap<NodeIx, ParentResultRxs> = HashMap::new();
507
508 let program_futures = predicate
511 .nodes
512 .iter()
513 .enumerate()
514 .map(|(node_ix, node)| {
515 let edges = predicate
516 .node_edges(node_ix)
517 .ok_or_else(|| PredicateError::InvalidNodeEdges(node_ix))?;
518
519 let parents: ParentResultRxs = parent_results.remove(&node_ix).unwrap_or_default();
521
522 let mut txs = vec![];
525 for &e in edges {
526 let (tx, rx) = oneshot::channel();
527 txs.push(tx);
528 let child = usize::from(e);
529 parent_results.entry(child).or_default().push(rx);
530 }
531
532 let program_fut = run_program(
534 pre_state.clone(),
535 post_state.clone(),
536 solution_set.clone(),
537 solution_index,
538 get_program.get_program(&node.program_address),
539 ProgramCtx {
540 parents,
541 children: txs,
542 reads: node.reads,
543 },
544 );
545
546 Ok((node_ix, program_fut))
547 })
548 .collect::<Result<Vec<(NodeIx, _)>, PredicateError<SA::Error>>>()?;
549
550 let mut program_tasks: JoinSet<(NodeIx, Result<_, _>)> = program_futures
552 .into_iter()
553 .map(|(node_ix, program_fut)| async move { (node_ix, program_fut.await) })
554 .collect();
555
556 let mut failed = Vec::new();
558 let mut unsatisfied = Vec::new();
559
560 let mut total_gas: Gas = 0;
562 while let Some(join_res) = program_tasks.join_next().await {
563 let (node_ix, prog_res) = join_res?;
564 match prog_res {
565 Ok((satisfied, gas)) => {
566 if let Some(false) = satisfied {
568 unsatisfied.push(node_ix);
569 }
570 total_gas = total_gas.saturating_add(gas);
571 }
572 Err(err) => {
573 failed.push((node_ix, err));
574 if !config.collect_all_failures {
575 break;
576 }
577 }
578 }
579 }
580
581 if !failed.is_empty() {
583 return Err(ProgramErrors(failed).into());
584 }
585
586 if !unsatisfied.is_empty() {
588 return Err(ConstraintsUnsatisfied(unsatisfied).into());
589 }
590
591 Ok(total_gas)
592}
593
594#[cfg_attr(
599 feature = "tracing",
600 tracing::instrument(
601 fields(CA = %format!("{}:{:?}", &format!("{}", content_addr(&*program))[0..8], ctx.reads)),
602 skip_all,
603 ),
604)]
605async fn run_program<SA, SB>(
606 pre_state: SA,
607 post_state: SB,
608 solution_set: Arc<SolutionSet>,
609 solution_index: SolutionIndex,
610 program: Arc<Program>,
611 ctx: ProgramCtx,
612) -> Result<(Option<bool>, Gas), ProgramError<SA::Error>>
613where
614 SA: StateRead,
615 SB: StateRead<Error = SA::Error>,
616{
617 let program_mapped = BytecodeMapped::try_from(&program.0[..])?;
618
619 let mut vm = vm::Vm::default();
621
622 #[cfg(feature = "tracing")]
623 tracing::trace!(
624 "Program {} [{} {}, {} {}]",
625 content_addr(&*program),
626 ctx.parents.len(),
627 if ctx.parents.len() == 1 {
628 "parent"
629 } else {
630 "parents"
631 },
632 ctx.children.len(),
633 if ctx.children.len() == 1 {
634 "child"
635 } else {
636 "children"
637 },
638 );
639
640 for parent_rx in ctx.parents {
642 let parent_result: Arc<_> = parent_rx.await?;
643 let (parent_stack, parent_memory) = Arc::unwrap_or_clone(parent_result);
644 let mut stack: Vec<Word> = std::mem::take(&mut vm.stack).into();
646 stack.append(&mut parent_stack.into());
647 vm.stack = stack.try_into()?;
648
649 let mut memory: Vec<Word> = std::mem::take(&mut vm.memory).into();
651 memory.append(&mut parent_memory.into());
652 vm.memory = memory.try_into()?;
653 }
654
655 #[cfg(feature = "tracing")]
656 tracing::trace!(
657 "VM initialised with: \n ├── {:?}\n └── {:?}",
658 &vm.stack,
659 &vm.memory
660 );
661
662 let mut_keys = vm::mut_keys_set(&solution_set, solution_index);
664 let access = Access::new(&solution_set, solution_index, &mut_keys);
665
666 let gas_cost = |_: &asm::Op| 1;
668 let gas_limit = GasLimit::UNLIMITED;
669
670 let gas_spent = match ctx.reads {
672 Reads::Pre => {
673 vm.exec_bytecode(&program_mapped, access, &pre_state, &gas_cost, gas_limit)
674 .await?
675 }
676 Reads::Post => {
677 vm.exec_bytecode(&program_mapped, access, &post_state, &gas_cost, gas_limit)
678 .await?
679 }
680 };
681
682 let opt_satisfied = if ctx.children.is_empty() {
684 Some(vm.stack[..] == [1])
685 } else {
686 let output = Arc::new((vm.stack, vm.memory));
687 for tx in ctx.children {
688 let _ = tx.send(output.clone());
689 }
690 None
691 };
692
693 Ok((opt_satisfied, gas_spent))
694}