1mod byte_size;
5mod error;
6mod ext;
7mod ext_entity;
8mod meter;
9#[cfg(test)]
10mod tests;
11
12use std::{
13 borrow::Borrow,
14 collections::{BTreeMap, BTreeSet, HashSet, VecDeque},
15 convert::{From, TryInto},
16 fmt::Debug,
17 sync::Arc,
18};
19
20use linked_hash_map::LinkedHashMap;
21use thiserror::Error;
22use tracing::error;
23
24use crate::{
25 global_state::{
26 error::Error as GlobalStateError, state::StateReader,
27 trie_store::operations::compute_state_hash, DEFAULT_MAX_QUERY_DEPTH,
28 },
29 KeyPrefix,
30};
31use casper_types::{
32 addressable_entity::NamedKeyAddr,
33 bytesrepr::{self, ToBytes},
34 contract_messages::{Message, Messages},
35 contracts::NamedKeys,
36 execution::{Effects, TransformError, TransformInstruction, TransformKindV2, TransformV2},
37 global_state::TrieMerkleProof,
38 handle_stored_dictionary_value, BlockGlobalAddr, CLType, CLValue, CLValueError, Digest, Key,
39 KeyTag, StoredValue, StoredValueTypeMismatch, U512,
40};
41
42use self::meter::{heap_meter::HeapSize, Meter};
43pub use self::{
44 error::Error as TrackingCopyError,
45 ext::TrackingCopyExt,
46 ext_entity::{FeesPurseHandling, TrackingCopyEntityExt},
47};
48
/// Outcome of a path query against a tracking copy.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum TrackingCopyQueryResult {
    /// The state root could not be found; `into_result` converts this into a
    /// storage-level `RootNotFound` error.
    RootNotFound,
    /// No value could be resolved; the string describes what failed and the
    /// path that had been walked so far.
    ValueNotFound(String),
    /// A key was encountered twice while walking the path; the string names
    /// the offending key and the path walked so far.
    CircularReference(String),
    /// The query was aborted because it reached the configured depth limit.
    DepthLimit {
        /// Number of key-to-key navigations performed when the query stopped.
        depth: u64,
    },
    /// The query resolved to a value.
    Success {
        /// The value found at the end of the path.
        value: StoredValue,
        /// One Merkle proof per key read while walking the path, in visiting
        /// order.
        proofs: Vec<TrieMerkleProof<Key, StoredValue>>,
    },
}
72
73impl TrackingCopyQueryResult {
74 pub fn is_success(&self) -> bool {
76 matches!(self, TrackingCopyQueryResult::Success { .. })
77 }
78
79 pub fn into_result(self) -> Result<StoredValue, TrackingCopyError> {
81 match self {
82 TrackingCopyQueryResult::RootNotFound => {
83 Err(TrackingCopyError::Storage(Error::RootNotFound))
84 }
85 TrackingCopyQueryResult::ValueNotFound(msg) => {
86 Err(TrackingCopyError::ValueNotFound(msg))
87 }
88 TrackingCopyQueryResult::CircularReference(msg) => {
89 Err(TrackingCopyError::CircularReference(msg))
90 }
91 TrackingCopyQueryResult::DepthLimit { depth } => {
92 Err(TrackingCopyError::QueryDepthLimit { depth })
93 }
94 TrackingCopyQueryResult::Success { value, .. } => Ok(value),
95 }
96 }
97}
98
/// Mutable state of a single path query: where it started, where it currently
/// is, and which names and keys it has already consumed.
struct Query {
    /// The key the query started from, exactly as supplied by the caller
    /// (used when rendering the path in messages).
    base_key: Key,
    /// Keys already read during this query; used to detect circular
    /// references.
    visited_keys: HashSet<Key>,
    /// The (normalized) key that will be read next.
    current_key: Key,
    /// Path components that have not been consumed yet, in order.
    unvisited_names: VecDeque<String>,
    /// Path components that have already been consumed, in order.
    visited_names: Vec<String>,
    /// Number of `navigate` calls performed so far.
    depth: u64,
}
116
117impl Query {
118 fn new(base_key: Key, path: &[String]) -> Self {
119 Query {
120 base_key,
121 current_key: base_key.normalize(),
122 unvisited_names: path.iter().cloned().collect(),
123 visited_names: Vec::new(),
124 visited_keys: HashSet::new(),
125 depth: 0,
126 }
127 }
128
129 fn next_name(&mut self) -> Option<&String> {
131 let next_name = self.unvisited_names.pop_front()?;
132 self.visited_names.push(next_name);
133 self.visited_names.last()
134 }
135
136 fn navigate(&mut self, key: Key) {
137 self.current_key = key.normalize();
138 self.depth += 1;
139 }
140
141 fn navigate_for_named_key(&mut self, named_key: Key) {
142 if let Key::NamedKey(_) = &named_key {
143 self.current_key = named_key.normalize();
144 }
145 }
146
147 fn into_not_found_result(self, msg_prefix: &str) -> TrackingCopyQueryResult {
148 let msg = format!("{} at path: {}", msg_prefix, self.current_path());
149 TrackingCopyQueryResult::ValueNotFound(msg)
150 }
151
152 fn into_circular_ref_result(self) -> TrackingCopyQueryResult {
153 let msg = format!(
154 "{:?} has formed a circular reference at path: {}",
155 self.current_key,
156 self.current_path()
157 );
158 TrackingCopyQueryResult::CircularReference(msg)
159 }
160
161 fn into_depth_limit_result(self) -> TrackingCopyQueryResult {
162 TrackingCopyQueryResult::DepthLimit { depth: self.depth }
163 }
164
165 fn current_path(&self) -> String {
166 let mut path = format!("{:?}", self.base_key);
167 for name in &self.visited_names {
168 path.push('/');
169 path.push_str(name);
170 }
171 path
172 }
173}
174
/// Cache of reads, pending writes and pending prunes kept by a tracking copy.
///
/// Only the read cache is size-bounded: its entries are measured with the
/// meter `M` and evicted front-first once the budget is exceeded.
#[derive(Clone, Debug)]
pub struct GenericTrackingCopyCache<M: Copy + Debug> {
    /// Upper bound for `current_cache_size`, in the units reported by `meter`.
    max_cache_size: usize,
    /// Running total of the measured sizes of the entries in `reads_cached`.
    current_cache_size: usize,
    /// Values read from the underlying reader, in insertion/refresh order.
    reads_cached: LinkedHashMap<Key, StoredValue>,
    /// Pending writes, ordered by the serialized bytes of their key so they
    /// can be range-scanned by byte prefix.
    muts_cached: BTreeMap<KeyWithByteRepr, StoredValue>,
    /// Keys marked for pruning.
    prunes_cached: BTreeSet<Key>,
    /// Measures the size of a key/value pair for the read-cache budget.
    meter: M,
}
187
188impl<M: Meter<Key, StoredValue> + Copy + Default> GenericTrackingCopyCache<M> {
189 pub fn new(max_cache_size: usize, meter: M) -> GenericTrackingCopyCache<M> {
194 GenericTrackingCopyCache {
195 max_cache_size,
196 current_cache_size: 0,
197 reads_cached: LinkedHashMap::new(),
198 muts_cached: BTreeMap::new(),
199 prunes_cached: BTreeSet::new(),
200 meter,
201 }
202 }
203
204 pub fn new_default(max_cache_size: usize) -> GenericTrackingCopyCache<M> {
208 GenericTrackingCopyCache::new(max_cache_size, M::default())
209 }
210
211 pub fn insert_read(&mut self, key: Key, value: StoredValue) {
213 let element_size = Meter::measure(&self.meter, &key, &value);
214 self.reads_cached.insert(key, value);
215 self.current_cache_size += element_size;
216 while self.current_cache_size > self.max_cache_size {
217 match self.reads_cached.pop_front() {
218 Some((k, v)) => {
219 let element_size = Meter::measure(&self.meter, &k, &v);
220 self.current_cache_size -= element_size;
221 }
222 None => break,
223 }
224 }
225 }
226
227 pub fn insert_write(&mut self, key: Key, value: StoredValue) {
229 let kb = KeyWithByteRepr::new(key);
230 self.prunes_cached.remove(&key);
231 self.muts_cached.insert(kb, value);
232 }
233
234 pub fn insert_prune(&mut self, key: Key) {
236 self.prunes_cached.insert(key);
237 }
238
239 pub fn get(&mut self, key: &Key) -> Option<&StoredValue> {
241 if self.prunes_cached.contains(key) {
242 return None;
245 }
246 let kb = KeyWithByteRepr::new(*key);
247 if let Some(value) = self.muts_cached.get(&kb) {
248 return Some(value);
249 };
250
251 self.reads_cached.get_refresh(key).map(|v| &*v)
252 }
253
254 fn get_muts_cached_by_byte_prefix(&self, prefix: &[u8]) -> Vec<Key> {
256 self.muts_cached
257 .range(prefix.to_vec()..)
258 .take_while(|(key, _)| key.starts_with(prefix))
259 .map(|(key, _)| key.to_key())
260 .collect()
261 }
262
263 pub fn is_pruned(&self, key: &Key) -> bool {
265 self.prunes_cached.contains(key)
266 }
267
268 pub(self) fn into_muts(self) -> (BTreeMap<KeyWithByteRepr, StoredValue>, BTreeSet<Key>) {
269 (self.muts_cached, self.prunes_cached)
270 }
271}
272
/// A `Key` paired with its serialized bytes.
///
/// Equality and ordering of this type are defined on the bytes only, which
/// lets maps keyed by it be searched by raw byte prefix.
#[derive(Debug, Clone)]
struct KeyWithByteRepr(Key, Vec<u8>);
278
279impl KeyWithByteRepr {
280 #[inline]
281 fn new(key: Key) -> Self {
282 let bytes = key.to_bytes().expect("should always serialize a Key");
283 KeyWithByteRepr(key, bytes)
284 }
285
286 #[inline]
287 fn starts_with(&self, prefix: &[u8]) -> bool {
288 self.1.starts_with(prefix)
289 }
290
291 #[inline]
292 fn to_key(&self) -> Key {
293 self.0
294 }
295}
296
297impl Borrow<Vec<u8>> for KeyWithByteRepr {
298 #[inline]
299 fn borrow(&self) -> &Vec<u8> {
300 &self.1
301 }
302}
303
304impl PartialEq for KeyWithByteRepr {
305 #[inline]
306 fn eq(&self, other: &Self) -> bool {
307 self.1 == other.1
308 }
309}
310
// Byte-wise equality (see the `PartialEq` impl) is a full equivalence relation.
impl Eq for KeyWithByteRepr {}
312
313impl PartialOrd for KeyWithByteRepr {
314 #[inline]
315 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
316 Some(self.cmp(other))
317 }
318}
319
320impl Ord for KeyWithByteRepr {
321 #[inline]
322 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
323 self.1.cmp(&other.1)
324 }
325}
326
/// The cache used by `TrackingCopy`: a `GenericTrackingCopyCache` whose read
/// cache is measured with the `HeapSize` meter.
pub type TrackingCopyCache = GenericTrackingCopyCache<HeapSize>;
329
/// A view over a state reader that caches reads, holds writes and prunes
/// locally, and records every operation as an effect.
#[derive(Clone)]
pub struct TrackingCopy<R> {
    /// Shared handle to the underlying state reader.
    reader: Arc<R>,
    /// Cached reads plus pending writes and prunes.
    cache: TrackingCopyCache,
    /// Transforms recorded by `read`, `write`, `add` and `prune`, in order.
    effects: Effects,
    /// Depth at which `query` gives up with a `DepthLimit` result.
    max_query_depth: u64,
    /// Messages recorded through `emit_message`, in order.
    messages: Messages,
    /// Flag supplied at construction and exposed through
    /// `enable_addressable_entity()`; it is not otherwise interpreted here.
    enable_addressable_entity: bool,
}
342
/// Outcome of `TrackingCopy::add`.
#[derive(Debug)]
pub enum AddResult {
    /// The addition was applied and recorded.
    Success,
    /// There is no value stored under the (normalized) key to add to.
    KeyNotFound(Key),
    /// The supplied or the stored value has a type that cannot take part in
    /// the addition.
    TypeMismatch(StoredValueTypeMismatch),
    /// A (de)serialization error occurred while handling the values.
    Serialization(bytesrepr::Error),
    /// Applying the transform failed for a reason other than a type mismatch
    /// or a serialization error.
    Transform(TransformError),
}
357
358impl From<CLValueError> for AddResult {
359 fn from(error: CLValueError) -> Self {
360 match error {
361 CLValueError::Serialization(error) => AddResult::Serialization(error),
362 CLValueError::Type(type_mismatch) => {
363 let expected = format!("{:?}", type_mismatch.expected);
364 let found = format!("{:?}", type_mismatch.found);
365 AddResult::TypeMismatch(StoredValueTypeMismatch::new(expected, found))
366 }
367 }
368 }
369}
370
/// The cache, effects and messages of a tracking copy, bundled as a tuple.
pub type TrackingCopyParts = (TrackingCopyCache, Effects, Messages);
373
374impl<R: StateReader<Key, StoredValue>> TrackingCopy<R>
375where
376 R: StateReader<Key, StoredValue, Error = GlobalStateError>,
377{
378 pub fn new(
380 reader: R,
381 max_query_depth: u64,
382 enable_addressable_entity: bool,
383 ) -> TrackingCopy<R> {
384 TrackingCopy {
385 reader: Arc::new(reader),
386 cache: GenericTrackingCopyCache::new(1024 * 16, HeapSize),
388 effects: Effects::new(),
389 max_query_depth,
390 messages: Vec::new(),
391 enable_addressable_entity,
392 }
393 }
394
395 pub fn reader(&self) -> &R {
397 &self.reader
398 }
399
    /// Returns another shared handle to the underlying state reader (bumps the
    /// `Arc` reference count; the reader itself is not cloned).
    pub fn shared_reader(&self) -> Arc<R> {
        Arc::clone(&self.reader)
    }
404
    /// Creates a child tracking copy that uses `self` as its reader.
    ///
    /// The child starts with an empty cache and no effects or messages, and
    /// inherits this copy's query depth limit and addressable-entity flag.
    /// Reads that miss the child's cache go through the `StateReader`
    /// implementation for `&TrackingCopy`.
    pub fn fork(&self) -> TrackingCopy<&TrackingCopy<R>> {
        TrackingCopy::new(self, self.max_query_depth, self.enable_addressable_entity)
    }
419
420 pub fn fork2(&self) -> Self {
431 TrackingCopy {
432 reader: Arc::clone(&self.reader),
433 cache: self.cache.clone(),
434 effects: self.effects.clone(),
435 max_query_depth: self.max_query_depth,
436 messages: self.messages.clone(),
437 enable_addressable_entity: self.enable_addressable_entity,
438 }
439 }
440
    /// Replaces this copy's effects, cache and messages wholesale with the
    /// supplied ones. Nothing is merged: the previous contents are discarded.
    pub fn apply_changes(
        &mut self,
        effects: Effects,
        cache: TrackingCopyCache,
        messages: Messages,
    ) {
        self.effects = effects;
        self.cache = cache;
        self.messages = messages;
    }
457
    /// Returns a clone of the effects recorded so far.
    pub fn effects(&self) -> Effects {
        self.effects.clone()
    }
462
    /// Returns a clone of the current cache (reads, pending writes and prunes).
    pub fn cache(&self) -> TrackingCopyCache {
        self.cache.clone()
    }
467
468 pub fn destructure(self) -> (Vec<(Key, StoredValue)>, BTreeSet<Key>, Effects) {
470 let (writes, prunes) = self.cache.into_muts();
471 let writes: Vec<(Key, StoredValue)> = writes.into_iter().map(|(k, v)| (k.0, v)).collect();
472
473 (writes, prunes, self.effects)
474 }
475
    /// Returns the addressable-entity flag this copy was created with.
    pub fn enable_addressable_entity(&self) -> bool {
        self.enable_addressable_entity
    }
480
481 pub fn get(&mut self, key: &Key) -> Result<Option<StoredValue>, TrackingCopyError> {
483 if let Some(value) = self.cache.get(key) {
484 return Ok(Some(value.to_owned()));
485 }
486 match self.reader.read(key) {
487 Ok(ret) => {
488 if let Some(value) = ret {
489 self.cache.insert_read(*key, value.to_owned());
490 Ok(Some(value))
491 } else {
492 Ok(None)
493 }
494 }
495 Err(err) => Err(TrackingCopyError::Storage(err)),
496 }
497 }
498
499 pub fn get_keys(&self, key_tag: &KeyTag) -> Result<BTreeSet<Key>, TrackingCopyError> {
501 self.get_by_byte_prefix(&[*key_tag as u8])
502 }
503
504 pub fn get_keys_by_prefix(
506 &self,
507 key_prefix: &KeyPrefix,
508 ) -> Result<BTreeSet<Key>, TrackingCopyError> {
509 let byte_prefix = key_prefix
510 .to_bytes()
511 .map_err(TrackingCopyError::BytesRepr)?;
512 self.get_by_byte_prefix(&byte_prefix)
513 }
514
515 pub(crate) fn get_by_byte_prefix(
517 &self,
518 byte_prefix: &[u8],
519 ) -> Result<BTreeSet<Key>, TrackingCopyError> {
520 let ret = self.keys_with_prefix(byte_prefix)?.into_iter().collect();
521 Ok(ret)
522 }
523
524 pub fn read(&mut self, key: &Key) -> Result<Option<StoredValue>, TrackingCopyError> {
526 let normalized_key = key.normalize();
527 if let Some(value) = self.get(&normalized_key)? {
528 self.effects
529 .push(TransformV2::new(normalized_key, TransformKindV2::Identity));
530 Ok(Some(value))
531 } else {
532 Ok(None)
533 }
534 }
535
536 pub fn read_first(&mut self, keys: &[&Key]) -> Result<Option<StoredValue>, TrackingCopyError> {
538 for key in keys {
539 if let Some(value) = self.read(key)? {
540 return Ok(Some(value));
541 }
542 }
543 Ok(None)
544 }
545
546 pub fn write(&mut self, key: Key, value: StoredValue) {
548 let normalized_key = key.normalize();
549 self.cache.insert_write(normalized_key, value.clone());
550 let transform = TransformV2::new(normalized_key, TransformKindV2::Write(value));
551 self.effects.push(transform);
552 }
553
554 #[allow(clippy::too_many_arguments)]
560 pub fn emit_message(
561 &mut self,
562 message_topic_key: Key,
563 message_topic_summary: StoredValue,
564 message_key: Key,
565 message_value: StoredValue,
566 block_message_count_value: StoredValue,
567 message: Message,
568 ) {
569 self.write(message_key, message_value);
570 self.write(message_topic_key, message_topic_summary);
571 self.write(
572 Key::BlockGlobal(BlockGlobalAddr::MessageCount),
573 block_message_count_value,
574 );
575 self.messages.push(message);
576 }
577
578 pub fn prune(&mut self, key: Key) {
580 let normalized_key = key.normalize();
581 self.cache.insert_prune(normalized_key);
582 self.effects.push(TransformV2::new(
583 normalized_key,
584 TransformKindV2::Prune(key),
585 ));
586 }
587
    /// Adds `value` to the value currently stored under `key`.
    ///
    /// The outer `Result` is `Err` only when fetching the current value
    /// fails. Every other outcome — including a missing key, an unsupported
    /// value type, or a failing transform — is reported through the returned
    /// `AddResult`.
    ///
    /// Supported inputs are `CLValue`s of type `I32`, `U64`, `U128`, `U256`,
    /// `U512`, or the named-key type (a name/key pair to be added to named
    /// keys).
    pub fn add(&mut self, key: Key, value: StoredValue) -> Result<AddResult, TrackingCopyError> {
        let normalized_key = key.normalize();
        // Uses `get`, so fetching the current value does not record an effect.
        let current_value = match self.get(&normalized_key)? {
            None => return Ok(AddResult::KeyNotFound(normalized_key)),
            Some(current_value) => current_value,
        };

        // Captured up front because `value` is moved by the match below.
        let type_name = value.type_name();
        let mismatch = || {
            Ok(AddResult::TypeMismatch(StoredValueTypeMismatch::new(
                "I32, U64, U128, U256, U512 or (String, Key) tuple".to_string(),
                type_name,
            )))
        };

        // Translate the supplied value into the transform that performs the addition.
        let transform_kind = match value {
            StoredValue::CLValue(cl_value) => match *cl_value.cl_type() {
                CLType::I32 => match cl_value.into_t() {
                    Ok(value) => TransformKindV2::AddInt32(value),
                    Err(error) => return Ok(AddResult::from(error)),
                },
                CLType::U64 => match cl_value.into_t() {
                    Ok(value) => TransformKindV2::AddUInt64(value),
                    Err(error) => return Ok(AddResult::from(error)),
                },
                CLType::U128 => match cl_value.into_t() {
                    Ok(value) => TransformKindV2::AddUInt128(value),
                    Err(error) => return Ok(AddResult::from(error)),
                },
                CLType::U256 => match cl_value.into_t() {
                    Ok(value) => TransformKindV2::AddUInt256(value),
                    Err(error) => return Ok(AddResult::from(error)),
                },
                CLType::U512 => match cl_value.into_t() {
                    Ok(value) => TransformKindV2::AddUInt512(value),
                    Err(error) => return Ok(AddResult::from(error)),
                },
                _ => {
                    // The only other accepted type is the named-key type.
                    if *cl_value.cl_type() == casper_types::named_key_type() {
                        match cl_value.into_t() {
                            Ok((name, key)) => {
                                let mut named_keys = NamedKeys::new();
                                named_keys.insert(name, key);
                                TransformKindV2::AddKeys(named_keys)
                            }
                            Err(error) => return Ok(AddResult::from(error)),
                        }
                    } else {
                        return mismatch();
                    }
                }
            },
            _ => return mismatch(),
        };

        // Apply the transform to the current value; only on success are the
        // cache and the effects updated.
        match transform_kind.clone().apply(current_value) {
            Ok(TransformInstruction::Store(new_value)) => {
                // The cache holds the resulting value, while the effects
                // record the additive transform itself.
                self.cache.insert_write(normalized_key, new_value);
                self.effects
                    .push(TransformV2::new(normalized_key, transform_kind));
                Ok(AddResult::Success)
            }
            Ok(TransformInstruction::Prune(key)) => {
                // The transform asked for the entry to be pruned; record a
                // prune instead of the additive transform.
                self.cache.insert_prune(normalized_key);
                self.effects.push(TransformV2::new(
                    normalized_key,
                    TransformKindV2::Prune(key),
                ));
                Ok(AddResult::Success)
            }
            Err(TransformError::TypeMismatch(type_mismatch)) => {
                Ok(AddResult::TypeMismatch(type_mismatch))
            }
            Err(TransformError::Serialization(error)) => Ok(AddResult::Serialization(error)),
            Err(transform_error) => Ok(AddResult::Transform(transform_error)),
        }
    }
669
    /// Returns a clone of the messages recorded so far.
    pub fn messages(&self) -> Messages {
        self.messages.clone()
    }
674
    /// Resolves `path` starting from `base_key` and returns the value found
    /// together with a Merkle proof for every key read along the way.
    ///
    /// Every lookup goes straight to the underlying reader via
    /// `read_with_proof`; this copy's cache is never consulted, so values
    /// written, added or pruned through this tracking copy are not visible
    /// here. No effects are recorded.
    ///
    /// Failures to resolve the path (missing value, circular reference, depth
    /// limit) are reported through `TrackingCopyQueryResult`; the outer
    /// `Result` is `Err` only when the reader itself fails.
    pub fn query(
        &self,
        base_key: Key,
        path: &[String],
    ) -> Result<TrackingCopyQueryResult, TrackingCopyError> {
        let mut query = Query::new(base_key, path);

        let mut proofs = Vec::new();

        loop {
            // Checked before every read, so a limit of 0 fails immediately.
            if query.depth >= self.max_query_depth {
                return Ok(query.into_depth_limit_result());
            }

            // `insert` returns false when the key was already visited.
            if !query.visited_keys.insert(query.current_key) {
                return Ok(query.into_circular_ref_result());
            }

            let stored_value = match self.reader.read_with_proof(&query.current_key)? {
                None => {
                    // This message is used for any key along the path, not
                    // only the base key.
                    return Ok(query.into_not_found_result("Failed to find base key"));
                }
                Some(stored_value) => stored_value,
            };

            let value = stored_value.value().to_owned();

            // Let the helper patch the value for the current key; presumably
            // it unwraps dictionary entries and leaves other values untouched
            // — TODO confirm against `handle_stored_dictionary_value`.
            let value = match handle_stored_dictionary_value(query.current_key, value) {
                Ok(patched_stored_value) => patched_stored_value,
                Err(error) => {
                    return Ok(query.into_not_found_result(&format!(
                        "Failed to retrieve dictionary value: {}",
                        error
                    )));
                }
            };

            proofs.push(stored_value);

            // Path exhausted: done, unless we are sitting on a named-key
            // entry, which still has to be followed to the key it points to.
            if query.unvisited_names.is_empty() && !query.current_key.is_named_key() {
                return Ok(TrackingCopyQueryResult::Success { value, proofs });
            }

            // Navigation decisions are made on the value as read (the one
            // inside the proof), not on the patched `value`.
            let stored_value: &StoredValue = proofs
                .last()
                .map(|r| r.value())
                .expect("but we just pushed");

            match stored_value {
                StoredValue::Account(account) => {
                    // Consume the next name and follow it through the
                    // account's named keys.
                    let mut maybe_msg_prefix: Option<String> = None;
                    if let Some(name) = query.next_name() {
                        if let Some(key) = account.named_keys().get(name) {
                            query.navigate(*key);
                        } else {
                            maybe_msg_prefix = Some(format!("Name {} not found in Account", name));
                        }
                    } else {
                        maybe_msg_prefix = Some("All names visited".to_string());
                    }
                    if let Some(msg_prefix) = maybe_msg_prefix {
                        return Ok(query.into_not_found_result(&msg_prefix));
                    }
                }
                StoredValue::Contract(contract) => {
                    // Same as for accounts, using the contract's named keys.
                    let mut maybe_msg_prefix: Option<String> = None;
                    if let Some(name) = query.next_name() {
                        if let Some(key) = contract.named_keys().get(name) {
                            query.navigate(*key);
                        } else {
                            maybe_msg_prefix = Some(format!("Name {} not found in Contract", name));
                        }
                    } else {
                        maybe_msg_prefix = Some("All names visited".to_string());
                    }
                    if let Some(msg_prefix) = maybe_msg_prefix {
                        return Ok(query.into_not_found_result(&msg_prefix));
                    }
                }
                StoredValue::AddressableEntity(_) => {
                    // Named keys of an addressable entity are separate
                    // entries: derive the named-key address from the entity
                    // address and the next name, and read that entry on the
                    // next iteration.
                    let current_key = query.current_key;
                    let mut maybe_msg_prefix: Option<String> = None;
                    if let Some(name) = query.next_name() {
                        if let Key::AddressableEntity(addr) = current_key {
                            let named_key_addr =
                                match NamedKeyAddr::new_from_string(addr, name.clone()) {
                                    Ok(named_key_addr) => Key::NamedKey(named_key_addr),
                                    Err(error) => {
                                        let msg_prefix = format!("{}", error);
                                        return Ok(query.into_not_found_result(&msg_prefix));
                                    }
                                };
                            // This hop does not increase the query depth.
                            query.navigate_for_named_key(named_key_addr);
                        } else {
                            maybe_msg_prefix = Some("Invalid base key".to_string());
                        }
                    } else {
                        maybe_msg_prefix = Some("All names visited".to_string());
                    }
                    if let Some(msg_prefix) = maybe_msg_prefix {
                        return Ok(query.into_not_found_result(&msg_prefix));
                    }
                }
                StoredValue::NamedKey(named_key_value) => {
                    match query.visited_names.last() {
                        // Reached through a name: the entry must carry that
                        // same name, and is then followed to its target key.
                        Some(expected_name) => match named_key_value.get_name() {
                            Ok(actual_name) => {
                                if &actual_name != expected_name {
                                    return Ok(query.into_not_found_result(
                                        "Queried and retrieved names do not match",
                                    ));
                                } else if let Ok(key) = named_key_value.get_key() {
                                    query.navigate(key)
                                } else {
                                    return Ok(query
                                        .into_not_found_result("Failed to parse CLValue as Key"));
                                }
                            }
                            Err(_) => {
                                return Ok(query
                                    .into_not_found_result("Failed to parse CLValue as String"));
                            }
                        },
                        // The named-key entry itself was queried directly
                        // with an empty path: return it as is.
                        None if path.is_empty() => {
                            return Ok(TrackingCopyQueryResult::Success { value, proofs });
                        }
                        None => return Ok(query.into_not_found_result("No visited names")),
                    }
                }
                // A `CLValue` holding a key is followed transparently,
                // without consuming a path component.
                StoredValue::CLValue(cl_value) if cl_value.cl_type() == &CLType::Key => {
                    if let Ok(key) = cl_value.to_owned().into_t::<Key>() {
                        query.navigate(key);
                    } else {
                        return Ok(query.into_not_found_result("Failed to parse CLValue as Key"));
                    }
                }
                // Every remaining kind of value cannot be navigated through
                // while names are still pending.
                StoredValue::CLValue(cl_value) => {
                    let msg_prefix = format!(
                        "Query cannot continue as {:?} is not an account, contract nor key to \
                         such. Value found",
                        cl_value
                    );
                    return Ok(query.into_not_found_result(&msg_prefix));
                }
                StoredValue::ContractWasm(_) => {
                    return Ok(query.into_not_found_result("ContractWasm value found."));
                }
                StoredValue::ContractPackage(_) => {
                    return Ok(query.into_not_found_result("ContractPackage value found."));
                }
                StoredValue::SmartContract(_) => {
                    return Ok(query.into_not_found_result("Package value found."));
                }
                StoredValue::ByteCode(_) => {
                    return Ok(query.into_not_found_result("ByteCode value found."));
                }
                StoredValue::Transfer(_) => {
                    return Ok(query.into_not_found_result("Legacy Transfer value found."));
                }
                StoredValue::DeployInfo(_) => {
                    return Ok(query.into_not_found_result("DeployInfo value found."));
                }
                StoredValue::EraInfo(_) => {
                    return Ok(query.into_not_found_result("EraInfo value found."));
                }
                StoredValue::Bid(_) => {
                    return Ok(query.into_not_found_result("Bid value found."));
                }
                StoredValue::BidKind(_) => {
                    return Ok(query.into_not_found_result("BidKind value found."));
                }
                StoredValue::Withdraw(_) => {
                    return Ok(query.into_not_found_result("WithdrawPurses value found."));
                }
                StoredValue::Unbonding(_) => {
                    return Ok(query.into_not_found_result("Unbonding value found."));
                }
                StoredValue::MessageTopic(_) => {
                    return Ok(query.into_not_found_result("MessageTopic value found."));
                }
                StoredValue::Message(_) => {
                    return Ok(query.into_not_found_result("Message value found."));
                }
                StoredValue::EntryPoint(_) => {
                    return Ok(query.into_not_found_result("EntryPoint value found."));
                }
                StoredValue::Prepayment(_) => {
                    return Ok(query.into_not_found_result("Prepayment value found."))
                }
                StoredValue::RawBytes(_) => {
                    return Ok(query.into_not_found_result("RawBytes value found."));
                }
            }
        }
    }
879}
880
881impl<R: StateReader<Key, StoredValue>> StateReader<Key, StoredValue> for &TrackingCopy<R> {
887 type Error = R::Error;
888
889 fn read(&self, key: &Key) -> Result<Option<StoredValue>, Self::Error> {
890 let kb = KeyWithByteRepr::new(*key);
891 if let Some(value) = self.cache.muts_cached.get(&kb) {
892 return Ok(Some(value.to_owned()));
893 }
894 if let Some(value) = self.reader.read(key)? {
895 Ok(Some(value))
896 } else {
897 Ok(None)
898 }
899 }
900
901 fn read_with_proof(
902 &self,
903 key: &Key,
904 ) -> Result<Option<TrieMerkleProof<Key, StoredValue>>, Self::Error> {
905 self.reader.read_with_proof(key)
906 }
907
908 fn keys_with_prefix(&self, byte_prefix: &[u8]) -> Result<Vec<Key>, Self::Error> {
909 let keys = self.reader.keys_with_prefix(byte_prefix)?;
910
911 let ret = keys
912 .into_iter()
913 .filter(|key| !self.cache.is_pruned(key))
915 .chain(self.cache.get_muts_cached_by_byte_prefix(byte_prefix))
917 .collect();
918 Ok(ret)
919 }
920}
921
/// Reasons why validating a query proof or a balance proof can fail.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum ValidationError {
    /// The number of proofs does not fit the supplied path or key trace
    /// (this also covers the case of no proofs at all).
    #[error("The path should not have a different length than the proof less one.")]
    PathLengthDifferentThanProofLessOne,

    /// A proof was made for a different key than the one expected at that
    /// position.
    #[error("The provided key does not match the key in the proof.")]
    UnexpectedKey,

    /// The proven value differs from the expected value.
    #[error("The provided value does not match the value in the proof.")]
    UnexpectedValue,

    /// The state hash computed from a proof differs from the expected hash.
    #[error("The proof hash is invalid.")]
    InvalidProofHash,

    /// The path could not be followed: a value without named keys was hit, or
    /// a name was missing from the named keys.
    #[error("The path went cold.")]
    PathCold,

    /// A (de)serialization error occurred.
    #[error("Serialization error: {0}")]
    BytesRepr(bytesrepr::Error),

    /// The purse key supplied for a balance proof is not a `URef`; carries
    /// the offending key.
    #[error("Key is not a URef")]
    KeyIsNotAURef(Key),

    /// The proven stored value could not be converted into a `CLValue`.
    #[error("Failed to convert stored value to key")]
    ValueToCLValueConversion,

    /// Converting a `CLValue` into the expected type failed.
    #[error("{0}")]
    CLValueError(CLValueError),
}
961
962impl From<CLValueError> for ValidationError {
963 fn from(err: CLValueError) -> Self {
964 ValidationError::CLValueError(err)
965 }
966}
967
968impl From<bytesrepr::Error> for ValidationError {
969 fn from(error: bytesrepr::Error) -> Self {
970 Self::BytesRepr(error)
971 }
972}
973
974pub fn validate_query_proof(
978 hash: &Digest,
979 proofs: &[TrieMerkleProof<Key, StoredValue>],
980 expected_first_key: &Key,
981 path: &[String],
982 expected_value: &StoredValue,
983) -> Result<(), ValidationError> {
984 if proofs.len() != path.len() + 1 {
985 return Err(ValidationError::PathLengthDifferentThanProofLessOne);
986 }
987
988 let mut proofs_iter = proofs.iter();
989 let first_proof = match proofs_iter.next() {
990 Some(proof) => proof,
991 None => {
992 return Err(ValidationError::PathLengthDifferentThanProofLessOne);
993 }
994 };
995
996 let mut path_components_iter = path.iter();
997
998 if first_proof.key() != &expected_first_key.normalize() {
999 return Err(ValidationError::UnexpectedKey);
1000 }
1001
1002 if hash != &compute_state_hash(first_proof)? {
1003 return Err(ValidationError::InvalidProofHash);
1004 }
1005
1006 let mut proof_value = first_proof.value();
1007
1008 for proof in proofs_iter {
1009 let named_keys = match proof_value {
1010 StoredValue::Account(account) => account.named_keys(),
1011 StoredValue::Contract(contract) => contract.named_keys(),
1012 _ => return Err(ValidationError::PathCold),
1013 };
1014
1015 let path_component = match path_components_iter.next() {
1016 Some(path_component) => path_component,
1017 None => return Err(ValidationError::PathCold),
1018 };
1019
1020 let key = match named_keys.get(path_component) {
1021 Some(key) => key,
1022 None => return Err(ValidationError::PathCold),
1023 };
1024
1025 if proof.key() != &key.normalize() {
1026 return Err(ValidationError::UnexpectedKey);
1027 }
1028
1029 if hash != &compute_state_hash(proof)? {
1030 return Err(ValidationError::InvalidProofHash);
1031 }
1032
1033 proof_value = proof.value();
1034 }
1035
1036 if proof_value != expected_value {
1037 return Err(ValidationError::UnexpectedValue);
1038 }
1039
1040 Ok(())
1041}
1042
1043pub fn validate_query_merkle_proof(
1047 hash: &Digest,
1048 proofs: &[TrieMerkleProof<Key, StoredValue>],
1049 expected_key_trace: &[Key],
1050 expected_value: &StoredValue,
1051) -> Result<(), ValidationError> {
1052 let expected_len = expected_key_trace.len();
1053 if proofs.len() != expected_len {
1054 return Err(ValidationError::PathLengthDifferentThanProofLessOne);
1055 }
1056
1057 let proof_keys: Vec<Key> = proofs.iter().map(|proof| *proof.key()).collect();
1058
1059 if !expected_key_trace.eq(&proof_keys) {
1060 return Err(ValidationError::UnexpectedKey);
1061 }
1062
1063 if expected_value != proofs[expected_len - 1].value() {
1064 return Err(ValidationError::UnexpectedValue);
1065 }
1066
1067 let mut proofs_iter = proofs.iter();
1068
1069 let first_proof = match proofs_iter.next() {
1070 Some(proof) => proof,
1071 None => return Err(ValidationError::PathLengthDifferentThanProofLessOne),
1072 };
1073
1074 if hash != &compute_state_hash(first_proof)? {
1075 return Err(ValidationError::InvalidProofHash);
1076 }
1077
1078 Ok(())
1079}
1080
1081pub fn validate_balance_proof(
1083 hash: &Digest,
1084 balance_proof: &TrieMerkleProof<Key, StoredValue>,
1085 expected_purse_key: Key,
1086 expected_motes: &U512,
1087) -> Result<(), ValidationError> {
1088 let expected_balance_key = expected_purse_key
1089 .into_uref()
1090 .map(|uref| Key::Balance(uref.addr()))
1091 .ok_or_else(|| ValidationError::KeyIsNotAURef(expected_purse_key.to_owned()))?;
1092
1093 if balance_proof.key() != &expected_balance_key.normalize() {
1094 return Err(ValidationError::UnexpectedKey);
1095 }
1096
1097 if hash != &compute_state_hash(balance_proof)? {
1098 return Err(ValidationError::InvalidProofHash);
1099 }
1100
1101 let balance_proof_stored_value = balance_proof.value().to_owned();
1102
1103 let balance_proof_clvalue: CLValue = balance_proof_stored_value
1104 .try_into()
1105 .map_err(|_| ValidationError::ValueToCLValueConversion)?;
1106
1107 let balance_motes: U512 = balance_proof_clvalue.into_t()?;
1108
1109 if expected_motes != &balance_motes {
1110 return Err(ValidationError::UnexpectedValue);
1111 }
1112
1113 Ok(())
1114}
1115
1116use crate::global_state::{
1117 error::Error,
1118 state::{
1119 lmdb::{make_temporary_global_state, LmdbGlobalStateView},
1120 StateProvider,
1121 },
1122};
1123use tempfile::TempDir;
1124
1125pub fn new_temporary_tracking_copy(
1127 initial_data: impl IntoIterator<Item = (Key, StoredValue)>,
1128 max_query_depth: Option<u64>,
1129 enable_addressable_entity: bool,
1130) -> (TrackingCopy<LmdbGlobalStateView>, TempDir) {
1131 let (global_state, state_root_hash, tempdir) = make_temporary_global_state(initial_data);
1132
1133 let reader = global_state
1134 .checkout(state_root_hash)
1135 .expect("Checkout should not throw errors.")
1136 .expect("Root hash should exist.");
1137
1138 let query_depth = max_query_depth.unwrap_or(DEFAULT_MAX_QUERY_DEPTH);
1139
1140 (
1141 TrackingCopy::new(reader, query_depth, enable_addressable_entity),
1142 tempdir,
1143 )
1144}