1use crate::{
4 db::{
5 self,
6 finalized::{query_state_exclusive_solution_set, query_state_inclusive_solution_set},
7 pool::ConnectionHandle,
8 ConnectionPool, QueryError,
9 },
10 error::{
11 PredicatesProgramsError, QueryPredicateError, QueryProgramError,
12 SolutionSetPredicatesError, StateReadError, ValidationError,
13 },
14};
15use essential_check::{
16 solution::{check_set_predicates, CheckPredicateConfig, PredicatesError},
17 vm::{Gas, StateRead},
18};
19use essential_node_types::{Block, BlockHeader};
20use essential_types::{
21 convert::bytes_from_word,
22 predicate::{Predicate, Program},
23 solution::{Solution, SolutionSet},
24 ContentAddress, Key, PredicateAddress, Value, Word,
25};
26use futures::FutureExt;
27use std::{collections::HashMap, pin::Pin, sync::Arc};
28
29#[cfg(test)]
30mod tests;
31
/// A view of contract state used to serve `StateRead` requests during validation.
///
/// The view is pinned to one solution set of one block. `pre_state` decides
/// which query `query_state` dispatches to: the exclusive solution-set query
/// when `true`, the inclusive one when `false`.
#[derive(Clone)]
struct State {
    /// Number of the block being validated.
    block_number: Word,
    /// Index of the solution set within the block.
    solution_set_index: u64,
    /// `true` selects the exclusive query, `false` the inclusive query.
    pre_state: bool,
    /// Source of the database connections used to answer reads.
    conn_pool: Db,
}
39
/// The connection source backing a validation run.
#[derive(Clone)]
enum Db {
    /// Dry run: connections are acquired from both an in-memory database and
    /// the regular pool, and reads consult the in-memory one first.
    DryRun(DryRun),
    /// Regular validation: connections come from the pool only.
    ConnectionPool(ConnectionPool),
}
46
/// Connection sources used when validating a block as a dry run.
#[derive(Clone)]
struct DryRun {
    /// In-memory database into which the block under validation was inserted.
    memory: Memory,
    /// The regular connection pool, used when the in-memory database has no value.
    conn_pool: ConnectionPool,
}
55
/// Newtype around a connection pool whose source is an in-memory database
/// (see `Memory::new`).
#[derive(Clone)]
struct Memory(db::ConnectionPool);
59
/// An acquired connection, mirroring the variants of `Db`.
enum Conn {
    /// A pair of connections: the in-memory database and the regular database.
    Cascade(Cascade),
    /// A single connection from the regular pool.
    Handle(ConnectionHandle),
}
65
/// A transaction opened on a `Conn`.
enum Transaction<'a> {
    /// One transaction on each of the in-memory and regular databases.
    Cascade(CascadeTransaction<'a>),
    /// A single transaction on the regular database.
    Handle(rusqlite::Transaction<'a>),
}
71
/// The two connections held during a dry run.
struct Cascade {
    /// Connection to the in-memory database.
    memory: ConnectionHandle,
    /// Connection to the regular database.
    db: ConnectionHandle,
}
77
/// The two transactions used during a dry run; `cascade` queries `memory`
/// first and only falls back to `db` when `memory` yields no value.
struct CascadeTransaction<'a> {
    /// Transaction on the in-memory database.
    memory: rusqlite::Transaction<'a>,
    /// Transaction on the regular database.
    db: rusqlite::Transaction<'a>,
}
83
/// The result of validating a block when no internal error occurred.
#[derive(Debug)]
pub enum ValidateOutcome {
    /// Every solution set in the block passed validation.
    Valid(ValidOutcome),
    /// Validation stopped at the first solution set that failed.
    Invalid(InvalidOutcome),
}
92
/// Details of a successful validation.
#[derive(Debug)]
pub struct ValidOutcome {
    /// Sum of the gas reported for every solution set in the block.
    pub total_gas: Gas,
}
100
/// Details of a failed validation.
#[derive(Debug)]
pub struct InvalidOutcome {
    /// Why validation failed.
    pub failure: ValidateFailure,
    /// Index, within the block, of the solution set that failed.
    pub solution_set_index: usize,
}
110
/// The reason a solution set failed validation.
#[derive(Debug)]
pub enum ValidateFailure {
    /// A solution targets a predicate that was not found in the contract registry.
    MissingPredicate(PredicateAddress),
    /// The stored predicate had a missing or invalid length prefix, or failed to decode.
    InvalidPredicate(PredicateAddress),
    /// A predicate node references a program that was not found in the program registry.
    MissingProgram(ContentAddress),
    /// The stored program had a missing or invalid length prefix.
    InvalidProgram(ContentAddress),
    /// `check_set_predicates` returned an error for the solution set.
    // NOTE(review): the payload is presumably only read through `Debug`, hence
    // the lint allowance — confirm.
    #[allow(dead_code)]
    PredicatesError(PredicatesError<StateReadError>),
    /// Adding the solution set's gas to the running total overflowed.
    GasOverflow,
}
129
130#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
136pub async fn validate_solution_set_dry_run(
137 conn_pool: &ConnectionPool,
138 contract_registry: &ContentAddress,
139 program_registry: &ContentAddress,
140 solution_set: SolutionSet,
141) -> Result<ValidateOutcome, ValidationError> {
142 let mut conn = conn_pool.acquire().await?;
143 let tx = conn.transaction()?;
144 let number = match db::get_latest_finalized_block_address(&tx)? {
145 Some(address) => db::get_block_header(&tx, &address)?
146 .map(|header| header.number)
147 .unwrap_or(1),
148 None => 1,
149 };
150 let block = Block {
151 header: BlockHeader {
152 number,
153 timestamp: std::time::SystemTime::now()
154 .duration_since(std::time::UNIX_EPOCH)
155 .expect("time must be valid"),
156 },
157 solution_sets: vec![solution_set],
158 };
159 drop(tx);
160 validate_dry_run(conn_pool, contract_registry, program_registry, &block).await
161}
162
163#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
167pub async fn validate_dry_run(
168 conn_pool: &ConnectionPool,
169 contract_registry: &ContentAddress,
170 program_registry: &ContentAddress,
171 block: &Block,
172) -> Result<ValidateOutcome, ValidationError> {
173 let dry_run = DryRun::new(conn_pool.clone(), block).await?;
174 let db_type = Db::DryRun(dry_run);
175 validate_inner(db_type, contract_registry, program_registry, block).await
176}
177
178#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
182pub(crate) async fn validate(
183 conn_pool: &ConnectionPool,
184 contract_registry: &ContentAddress,
185 program_registry: &ContentAddress,
186 block: &Block,
187) -> Result<ValidateOutcome, ValidationError> {
188 let db_type = Db::ConnectionPool(conn_pool.clone());
189 validate_inner(db_type, contract_registry, program_registry, block).await
190}
191
192#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
196async fn validate_inner(
197 conn: Db,
198 contract_registry: &ContentAddress,
199 program_registry: &ContentAddress,
200 block: &Block,
201) -> Result<ValidateOutcome, ValidationError> {
202 let mut total_gas: u64 = 0;
203
204 for (solution_set_index, solution_set) in block.solution_sets.iter().enumerate() {
206 let pre_state = State {
207 block_number: block.header.number,
208 solution_set_index: solution_set_index as u64,
209 pre_state: true,
210 conn_pool: conn.clone(),
211 };
212 let post_state = State {
213 block_number: block.header.number,
214 solution_set_index: solution_set_index as u64,
215 pre_state: false,
216 conn_pool: conn.clone(),
217 };
218
219 let res =
221 query_solution_set_predicates(&post_state, contract_registry, &solution_set.solutions)
222 .await;
223 let predicates = match res {
224 Ok(predicates) => Arc::new(predicates),
225 Err(err) => match err {
226 SolutionSetPredicatesError::Acquire(err) => {
227 return Err(ValidationError::DbPoolClosed(err))
228 }
229 SolutionSetPredicatesError::QueryPredicate(addr, err) => match err {
230 QueryPredicateError::Query(err) => return Err(ValidationError::Query(err)),
231 QueryPredicateError::Decode(_)
232 | QueryPredicateError::MissingLenBytes
233 | QueryPredicateError::InvalidLenBytes => {
234 return Ok(ValidateOutcome::Invalid(InvalidOutcome {
235 failure: ValidateFailure::InvalidPredicate(addr),
236 solution_set_index,
237 }));
238 }
239 },
240 SolutionSetPredicatesError::MissingPredicate(addr) => {
241 return Ok(ValidateOutcome::Invalid(InvalidOutcome {
242 failure: ValidateFailure::MissingPredicate(addr),
243 solution_set_index,
244 }));
245 }
246 },
247 };
248
249 let res = query_predicates_programs(&post_state, program_registry, &predicates).await;
251 let programs = match res {
252 Ok(programs) => Arc::new(programs),
253 Err(err) => match err {
254 PredicatesProgramsError::Acquire(err) => {
255 return Err(ValidationError::DbPoolClosed(err))
256 }
257 PredicatesProgramsError::QueryProgram(addr, err) => match err {
258 QueryProgramError::Query(err) => return Err(ValidationError::Query(err)),
259 QueryProgramError::MissingLenBytes | QueryProgramError::InvalidLenBytes => {
260 return Ok(ValidateOutcome::Invalid(InvalidOutcome {
261 failure: ValidateFailure::InvalidProgram(addr),
262 solution_set_index,
263 }));
264 }
265 },
266 PredicatesProgramsError::MissingProgram(addr) => {
267 return Ok(ValidateOutcome::Invalid(InvalidOutcome {
268 failure: ValidateFailure::MissingProgram(addr),
269 solution_set_index,
270 }));
271 }
272 },
273 };
274
275 let get_predicate = move |addr: &PredicateAddress| {
276 predicates
277 .get(addr)
278 .cloned()
279 .expect("predicate must have been fetched in the previous step")
280 };
281
282 let get_program = move |addr: &ContentAddress| {
283 programs
284 .get(addr)
285 .cloned()
286 .expect("program must have been fetched in the previous step")
287 };
288
289 match check_set_predicates(
290 &pre_state,
291 &post_state,
292 Arc::new(solution_set.clone()),
293 get_predicate,
294 get_program,
295 Arc::new(CheckPredicateConfig::default()),
296 )
297 .await
298 {
299 Ok(g) => {
300 if let Some(g) = total_gas.checked_add(g) {
301 total_gas = g;
302 } else {
303 return Ok(ValidateOutcome::Invalid(InvalidOutcome {
304 failure: ValidateFailure::GasOverflow,
305 solution_set_index,
306 }));
307 }
308 }
309 Err(err) => {
310 #[cfg(feature = "tracing")]
311 tracing::debug!(
312 "Validation failed for block with number {} and address {} at solution set index {} with error {}",
313 block.header.number,
314 essential_hash::content_addr(block),
315 solution_set_index,
316 err
317 );
318 return Ok(ValidateOutcome::Invalid(InvalidOutcome {
319 failure: ValidateFailure::PredicatesError(err),
320 solution_set_index,
321 }));
322 }
323 }
324 }
325
326 #[cfg(feature = "tracing")]
327 tracing::debug!(
328 "Validation successful for block with number {} and address {}. Gas: {}",
329 block.header.number,
330 essential_hash::content_addr(block),
331 total_gas
332 );
333 Ok(ValidateOutcome::Valid(ValidOutcome { total_gas }))
334}
335
336impl DryRun {
337 pub async fn new(conn_pool: ConnectionPool, block: &Block) -> Result<Self, rusqlite::Error> {
340 let memory = Memory::new(block)?;
341 Ok(Self { memory, conn_pool })
342 }
343}
344
345impl Memory {
346 fn new(block: &Block) -> Result<Self, rusqlite::Error> {
348 let config = db::pool::Config {
351 conn_limit: 1,
352 source: db::pool::Source::Memory(uuid::Uuid::new_v4().to_string()),
353 };
354 let memory = db::ConnectionPool::new(&config)?;
355 let mut conn = memory
356 .try_acquire()
357 .expect("can't fail due to no other connections");
358
359 let tx = conn.transaction()?;
361 essential_node_db::create_tables(&tx)?;
362 let hash = essential_node_db::insert_block(&tx, block)?;
363 essential_node_db::finalize_block(&tx, &hash)?;
364 tx.commit()?;
365
366 Ok(Self(memory))
367 }
368}
369
370impl Db {
371 pub async fn acquire(&self) -> Result<Conn, tokio::sync::AcquireError> {
373 let conn = match self {
374 Db::DryRun(dry_run) => {
375 let cascade = Cascade {
376 memory: dry_run.memory.as_ref().acquire().await?,
377 db: dry_run.conn_pool.acquire().await?,
378 };
379 Conn::Cascade(cascade)
380 }
381 Db::ConnectionPool(conn_pool) => Conn::Handle(conn_pool.acquire().await?),
382 };
383 Ok(conn)
384 }
385}
386
387impl Conn {
388 fn transaction(&mut self) -> Result<Transaction<'_>, rusqlite::Error> {
390 match self {
391 Conn::Cascade(cascade) => {
392 let memory = cascade.memory.transaction()?;
393 let db = cascade.db.transaction()?;
394 Ok(Transaction::Cascade(CascadeTransaction { memory, db }))
395 }
396 Conn::Handle(handle) => {
397 let tx = handle.transaction()?;
398 Ok(Transaction::Handle(tx))
399 }
400 }
401 }
402}
403
404fn cascade(
406 conn: &CascadeTransaction,
407 f: impl Fn(&rusqlite::Transaction) -> Result<Option<Value>, QueryError>,
408) -> Result<Option<Value>, QueryError> {
409 match f(&conn.memory)? {
410 Some(val) => Ok(Some(val)),
411 None => f(&conn.db),
412 }
413}
414
415fn query(
417 conn: &Transaction,
418 f: impl Fn(&rusqlite::Transaction) -> Result<Option<Value>, QueryError>,
419) -> Result<Option<Value>, QueryError> {
420 match conn {
421 Transaction::Cascade(cascade_tx) => cascade(cascade_tx, f),
422 Transaction::Handle(tx) => f(tx),
423 }
424}
425
426impl StateRead for State {
427 type Error = StateReadError;
428
429 type Future =
430 Pin<Box<dyn std::future::Future<Output = Result<Vec<Vec<Word>>, Self::Error>> + Send>>;
431
432 fn key_range(
433 &self,
434 contract_addr: ContentAddress,
435 mut key: Key,
436 num_values: usize,
437 ) -> Self::Future {
438 let Self {
439 block_number,
440 solution_set_index,
441 pre_state,
442 conn_pool,
443 } = self.clone();
444
445 async move {
446 let mut conn = conn_pool.acquire().await?;
447
448 tokio::task::spawn_blocking(move || {
449 let mut values = vec![];
450 let tx = conn.transaction()?;
451
452 for _ in 0..num_values {
453 let value = query(&tx, |tx| {
454 query_state(
455 tx,
456 &contract_addr,
457 &key,
458 block_number,
459 solution_set_index,
460 pre_state,
461 )
462 })?;
463 let value = value.unwrap_or_default();
464 values.push(value);
465
466 key = next_key(key).ok_or_else(|| StateReadError::KeyRangeError)?;
467 }
468 Ok(values)
469 })
470 .await?
471 }
472 .boxed()
473 }
474}
475
476async fn query_solution_set_predicates(
479 state: &State,
480 contract_registry: &ContentAddress,
481 solutions: &[Solution],
482) -> Result<HashMap<PredicateAddress, Arc<Predicate>>, SolutionSetPredicatesError> {
483 let mut predicates = HashMap::default();
484 let mut conn = state.conn_pool.acquire().await?;
485 for solution in solutions {
486 let pred_addr = solution.predicate_to_solve.clone();
487 let Some(pred) = query_predicate(
488 &mut conn,
489 contract_registry,
490 &pred_addr,
491 state.block_number,
492 state.solution_set_index,
493 )
494 .map_err(|e| SolutionSetPredicatesError::QueryPredicate(pred_addr.clone(), e))?
495 else {
496 return Err(SolutionSetPredicatesError::MissingPredicate(
497 pred_addr.clone(),
498 ));
499 };
500 predicates.insert(pred_addr, Arc::new(pred));
501 }
502 Ok(predicates)
503}
504
505#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, err))]
510fn query_predicate(
511 conn: &mut Conn,
512 contract_registry: &ContentAddress,
513 pred_addr: &PredicateAddress,
514 block_number: Word,
515 solution_set_ix: u64,
516) -> Result<Option<Predicate>, QueryPredicateError> {
517 use essential_node_types::contract_registry;
518 let pre_state = false;
519
520 #[cfg(feature = "tracing")]
521 tracing::trace!("{}:{}", pred_addr.contract, pred_addr.predicate);
522
523 let contract_predicate_key = contract_registry::contract_predicate_key(pred_addr);
525 let tx = conn.transaction().map_err(QueryError::Rusqlite)?;
526 if query(&tx, |tx| {
527 query_state(
528 tx,
529 contract_registry,
530 &contract_predicate_key,
531 block_number,
532 solution_set_ix,
533 pre_state,
534 )
535 })?
536 .is_none()
537 {
538 return Ok(None);
540 }
541
542 let predicate_key = contract_registry::predicate_key(&pred_addr.predicate);
544 let Some(pred_words) = query(&tx, |tx| {
545 query_state(
546 tx,
547 contract_registry,
548 &predicate_key,
549 block_number,
550 solution_set_ix,
551 pre_state,
552 )
553 })?
554 else {
555 return Ok(None);
557 };
558
559 let Some(&pred_len_bytes) = pred_words.first() else {
561 return Err(QueryPredicateError::MissingLenBytes);
562 };
563 let pred_len_bytes: usize = pred_len_bytes
564 .try_into()
565 .map_err(|_| QueryPredicateError::InvalidLenBytes)?;
566 let pred_words = &pred_words[1..];
567 let pred_bytes: Vec<u8> = pred_words
568 .iter()
569 .copied()
570 .flat_map(bytes_from_word)
571 .take(pred_len_bytes)
572 .collect();
573
574 let predicate = Predicate::decode(&pred_bytes)?;
575 Ok(Some(predicate))
576}
577
578async fn query_predicates_programs(
581 state: &State,
582 program_registry: &ContentAddress,
583 predicates: &HashMap<PredicateAddress, Arc<Predicate>>,
584) -> Result<HashMap<ContentAddress, Arc<Program>>, PredicatesProgramsError> {
585 let mut programs = HashMap::default();
586 let mut conn = state.conn_pool.acquire().await?;
587 for predicate in predicates.values() {
588 for node in &predicate.nodes {
589 let prog_addr = node.program_address.clone();
590 let Some(prog) = query_program(
591 &mut conn,
592 program_registry,
593 &prog_addr,
594 state.block_number,
595 state.solution_set_index,
596 )
597 .map_err(|e| PredicatesProgramsError::QueryProgram(prog_addr.clone(), e))?
598 else {
599 return Err(PredicatesProgramsError::MissingProgram(prog_addr.clone()));
600 };
601 programs.insert(prog_addr, Arc::new(prog));
602 }
603 }
604 Ok(programs)
605}
606
607fn query_program(
612 conn: &mut Conn,
613 program_registry: &ContentAddress,
614 prog_addr: &ContentAddress,
615 block_number: Word,
616 solution_set_ix: u64,
617) -> Result<Option<Program>, QueryProgramError> {
618 use essential_node_types::program_registry;
619 let pre_state = false;
620
621 #[cfg(feature = "tracing")]
622 tracing::trace!("{}", prog_addr);
623
624 let program_key = program_registry::program_key(prog_addr);
626 let tx = conn.transaction().map_err(QueryError::Rusqlite)?;
627 let Some(prog_words) = query(&tx, |tx| {
628 query_state(
629 tx,
630 program_registry,
631 &program_key,
632 block_number,
633 solution_set_ix,
634 pre_state,
635 )
636 })?
637 else {
638 return Ok(None);
640 };
641
642 let Some(&prog_len_bytes) = prog_words.first() else {
644 return Err(QueryProgramError::MissingLenBytes);
645 };
646 let prog_len_bytes: usize = prog_len_bytes
647 .try_into()
648 .map_err(|_| QueryProgramError::InvalidLenBytes)?;
649 let prog_words = &prog_words[1..];
650 let prog_bytes: Vec<u8> = prog_words
651 .iter()
652 .copied()
653 .flat_map(bytes_from_word)
654 .take(prog_len_bytes)
655 .collect();
656
657 let program = Program(prog_bytes);
658 Ok(Some(program))
659}
660
661fn query_state(
662 conn: &rusqlite::Connection,
663 contract_ca: &ContentAddress,
664 key: &Key,
665 block_number: Word,
666 solution_set_ix: u64,
667 pre_state: bool,
668) -> Result<Option<Value>, QueryError> {
669 if pre_state {
670 query_state_exclusive_solution_set(conn, contract_ca, key, block_number, solution_set_ix)
671 } else {
672 query_state_inclusive_solution_set(conn, contract_ca, key, block_number, solution_set_ix)
673 }
674}
675
676pub fn next_key(mut key: Key) -> Option<Key> {
678 for w in key.iter_mut().rev() {
679 match *w {
680 Word::MAX => *w = Word::MIN,
681 _ => {
682 *w += 1;
683 return Some(key);
684 }
685 }
686 }
687 None
688}
689
690impl AsRef<db::ConnectionPool> for Memory {
691 fn as_ref(&self) -> &db::ConnectionPool {
692 &self.0
693 }
694}