1mod byte_size;
5mod error;
6mod ext;
7mod ext_entity;
8mod meter;
9#[cfg(test)]
10mod tests;
11
12use std::{
13 borrow::Borrow,
14 collections::{BTreeMap, BTreeSet, HashSet, VecDeque},
15 convert::{From, TryInto},
16 fmt::Debug,
17 sync::Arc,
18};
19
20use linked_hash_map::LinkedHashMap;
21use thiserror::Error;
22use tracing::error;
23
24use crate::{
25 global_state::{
26 error::Error as GlobalStateError, state::StateReader,
27 trie_store::operations::compute_state_hash, DEFAULT_MAX_QUERY_DEPTH,
28 },
29 KeyPrefix,
30};
31use casper_types::{
32 addressable_entity::NamedKeyAddr,
33 bytesrepr::{self, ToBytes},
34 contract_messages::{Message, Messages},
35 contracts::NamedKeys,
36 execution::{Effects, TransformError, TransformInstruction, TransformKindV2, TransformV2},
37 global_state::TrieMerkleProof,
38 handle_stored_dictionary_value, BlockGlobalAddr, CLType, CLValue, CLValueError, Digest, Key,
39 KeyTag, StoredValue, StoredValueTypeMismatch, U512,
40};
41
42use self::meter::{heap_meter::HeapSize, Meter};
43pub use self::{
44 error::Error as TrackingCopyError,
45 ext::TrackingCopyExt,
46 ext_entity::{FeesPurseHandling, TrackingCopyEntityExt},
47};
48
/// Outcome of [`TrackingCopy::query`]: either the value found at the end of the query path,
/// together with the Merkle proofs gathered along the way, or the reason the query stopped.
#[derive(Debug)]
// `Success` carries a `StoredValue` plus a `Vec` of proofs, which is presumably much larger
// than the other variants; boxing was not chosen here.
#[allow(clippy::large_enum_variant)]
pub enum TrackingCopyQueryResult {
    /// The state root could not be found; `into_result` turns this into a storage
    /// `RootNotFound` error.
    RootNotFound,
    /// The query could not be resolved; the message says why and at which path.
    ValueNotFound(String),
    /// A key was reached a second time while following the path; the message names the key
    /// and the path.
    CircularReference(String),
    /// The query hit the configured maximum depth.
    DepthLimit {
        /// Depth reached when the query was stopped.
        depth: u64,
    },
    /// The query resolved to a value.
    Success {
        /// The value found (with dictionary values already unwrapped).
        value: StoredValue,
        /// One Merkle proof per key read, in the order the keys were visited.
        proofs: Vec<TrieMerkleProof<Key, StoredValue>>,
    },
}
72
73impl TrackingCopyQueryResult {
74 pub fn is_success(&self) -> bool {
76 matches!(self, TrackingCopyQueryResult::Success { .. })
77 }
78
79 pub fn into_result(self) -> Result<StoredValue, TrackingCopyError> {
81 match self {
82 TrackingCopyQueryResult::RootNotFound => {
83 Err(TrackingCopyError::Storage(Error::RootNotFound))
84 }
85 TrackingCopyQueryResult::ValueNotFound(msg) => {
86 Err(TrackingCopyError::ValueNotFound(msg))
87 }
88 TrackingCopyQueryResult::CircularReference(msg) => {
89 Err(TrackingCopyError::CircularReference(msg))
90 }
91 TrackingCopyQueryResult::DepthLimit { depth } => {
92 Err(TrackingCopyError::QueryDepthLimit { depth })
93 }
94 TrackingCopyQueryResult::Success { value, .. } => Ok(value),
95 }
96 }
97}
98
/// Mutable state of an in-progress [`TrackingCopy::query`] walk.
struct Query {
    /// The key the query started from, as given (not normalized); used when rendering paths.
    base_key: Key,
    /// Every key already read by this query; a repeat indicates a circular reference.
    visited_keys: HashSet<Key>,
    /// The (normalized) key to read next.
    current_key: Key,
    /// Path components not yet consumed, in order.
    unvisited_names: VecDeque<String>,
    /// Path components already consumed, in order; the last entry is the name most recently
    /// looked up.
    visited_names: Vec<String>,
    /// Number of `navigate` steps taken so far, compared against the maximum query depth.
    depth: u64,
}
116
117impl Query {
118 fn new(base_key: Key, path: &[String]) -> Self {
119 Query {
120 base_key,
121 current_key: base_key.normalize(),
122 unvisited_names: path.iter().cloned().collect(),
123 visited_names: Vec::new(),
124 visited_keys: HashSet::new(),
125 depth: 0,
126 }
127 }
128
129 fn next_name(&mut self) -> &String {
131 let next_name = self.unvisited_names.pop_front().unwrap();
132 self.visited_names.push(next_name);
133 self.visited_names.last().unwrap()
134 }
135
136 fn navigate(&mut self, key: Key) {
137 self.current_key = key.normalize();
138 self.depth += 1;
139 }
140
141 fn navigate_for_named_key(&mut self, named_key: Key) {
142 if let Key::NamedKey(_) = &named_key {
143 self.current_key = named_key.normalize();
144 }
145 }
146
147 fn into_not_found_result(self, msg_prefix: &str) -> TrackingCopyQueryResult {
148 let msg = format!("{} at path: {}", msg_prefix, self.current_path());
149 TrackingCopyQueryResult::ValueNotFound(msg)
150 }
151
152 fn into_circular_ref_result(self) -> TrackingCopyQueryResult {
153 let msg = format!(
154 "{:?} has formed a circular reference at path: {}",
155 self.current_key,
156 self.current_path()
157 );
158 TrackingCopyQueryResult::CircularReference(msg)
159 }
160
161 fn into_depth_limit_result(self) -> TrackingCopyQueryResult {
162 TrackingCopyQueryResult::DepthLimit { depth: self.depth }
163 }
164
165 fn current_path(&self) -> String {
166 let mut path = format!("{:?}", self.base_key);
167 for name in &self.visited_names {
168 path.push('/');
169 path.push_str(name);
170 }
171 path
172 }
173}
174
/// Cache of values read and written through a [`TrackingCopy`], generic over the [`Meter`]
/// used to size cached reads.
///
/// Reads live in a size-bounded, least-recently-used list. Writes (`muts_cached`) and prunes
/// (`prunes_cached`) are not subject to the size limit.
#[derive(Clone, Debug)]
pub struct GenericTrackingCopyCache<M: Copy + Debug> {
    /// Upper bound on the summed measured size of `reads_cached`.
    max_cache_size: usize,
    /// Current summed measured size of `reads_cached`.
    current_cache_size: usize,
    /// Values read from the underlying store; the front is the least recently used entry.
    reads_cached: LinkedHashMap<Key, StoredValue>,
    /// Pending writes, keyed by serialized key bytes so they can be scanned by byte prefix.
    muts_cached: BTreeMap<KeyWithByteRepr, StoredValue>,
    /// Keys marked as pruned.
    prunes_cached: BTreeSet<Key>,
    /// Meter used to measure the size of cached reads.
    meter: M,
}
187
188impl<M: Meter<Key, StoredValue> + Copy + Default> GenericTrackingCopyCache<M> {
189 pub fn new(max_cache_size: usize, meter: M) -> GenericTrackingCopyCache<M> {
194 GenericTrackingCopyCache {
195 max_cache_size,
196 current_cache_size: 0,
197 reads_cached: LinkedHashMap::new(),
198 muts_cached: BTreeMap::new(),
199 prunes_cached: BTreeSet::new(),
200 meter,
201 }
202 }
203
204 pub fn new_default(max_cache_size: usize) -> GenericTrackingCopyCache<M> {
208 GenericTrackingCopyCache::new(max_cache_size, M::default())
209 }
210
211 pub fn insert_read(&mut self, key: Key, value: StoredValue) {
213 let element_size = Meter::measure(&self.meter, &key, &value);
214 self.reads_cached.insert(key, value);
215 self.current_cache_size += element_size;
216 while self.current_cache_size > self.max_cache_size {
217 match self.reads_cached.pop_front() {
218 Some((k, v)) => {
219 let element_size = Meter::measure(&self.meter, &k, &v);
220 self.current_cache_size -= element_size;
221 }
222 None => break,
223 }
224 }
225 }
226
227 pub fn insert_write(&mut self, key: Key, value: StoredValue) {
229 let kb = KeyWithByteRepr::new(key);
230 self.prunes_cached.remove(&key);
231 self.muts_cached.insert(kb, value);
232 }
233
234 pub fn insert_prune(&mut self, key: Key) {
236 self.prunes_cached.insert(key);
237 }
238
239 pub fn get(&mut self, key: &Key) -> Option<&StoredValue> {
241 if self.prunes_cached.contains(key) {
242 return None;
245 }
246 let kb = KeyWithByteRepr::new(*key);
247 if let Some(value) = self.muts_cached.get(&kb) {
248 return Some(value);
249 };
250
251 self.reads_cached.get_refresh(key).map(|v| &*v)
252 }
253
254 fn get_muts_cached_by_byte_prefix(&self, prefix: &[u8]) -> Vec<Key> {
256 self.muts_cached
257 .range(prefix.to_vec()..)
258 .take_while(|(key, _)| key.starts_with(prefix))
259 .map(|(key, _)| key.to_key())
260 .collect()
261 }
262
263 pub fn is_pruned(&self, key: &Key) -> bool {
265 self.prunes_cached.contains(key)
266 }
267
268 pub(self) fn into_muts(self) -> (BTreeMap<KeyWithByteRepr, StoredValue>, BTreeSet<Key>) {
269 (self.muts_cached, self.prunes_cached)
270 }
271}
272
/// A [`Key`] paired with its serialized bytes.
///
/// Equality and ordering compare only the bytes. A `BTreeMap` keyed by this type is therefore
/// sorted by serialized key and can be range-scanned by byte prefix.
#[derive(Debug, Clone)]
struct KeyWithByteRepr(Key, Vec<u8>);
278
279impl KeyWithByteRepr {
280 #[inline]
281 fn new(key: Key) -> Self {
282 let bytes = key.to_bytes().expect("should always serialize a Key");
283 KeyWithByteRepr(key, bytes)
284 }
285
286 #[inline]
287 fn starts_with(&self, prefix: &[u8]) -> bool {
288 self.1.starts_with(prefix)
289 }
290
291 #[inline]
292 fn to_key(&self) -> Key {
293 self.0
294 }
295}
296
297impl Borrow<Vec<u8>> for KeyWithByteRepr {
298 #[inline]
299 fn borrow(&self) -> &Vec<u8> {
300 &self.1
301 }
302}
303
304impl PartialEq for KeyWithByteRepr {
305 #[inline]
306 fn eq(&self, other: &Self) -> bool {
307 self.1 == other.1
308 }
309}
310
311impl Eq for KeyWithByteRepr {}
312
313impl PartialOrd for KeyWithByteRepr {
314 #[inline]
315 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
316 Some(self.cmp(other))
317 }
318}
319
320impl Ord for KeyWithByteRepr {
321 #[inline]
322 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
323 self.1.cmp(&other.1)
324 }
325}
326
/// The read/write cache used by [`TrackingCopy`], with cached reads sized by [`HeapSize`].
pub type TrackingCopyCache = GenericTrackingCopyCache<HeapSize>;
329
/// A view over a [`StateReader`] that records reads, writes, additions and prunes as
/// [`Effects`] and serves later reads from its own cache.
///
/// Within this type the underlying reader is only ever read from.
#[derive(Clone)]
pub struct TrackingCopy<R> {
    /// Underlying state reader; shared via `Arc` so `fork2` copies can reuse it.
    reader: Arc<R>,
    /// Cached reads plus pending writes and prunes.
    cache: TrackingCopyCache,
    /// Transforms recorded so far, in order.
    effects: Effects,
    /// Maximum number of navigation steps `query` may take.
    max_query_depth: u64,
    /// Messages recorded by `emit_message`.
    messages: Messages,
    /// Flag supplied at construction and exposed via `enable_addressable_entity()`.
    /// NOTE(review): presumably selects addressable-entity vs legacy account/contract handling
    /// in the extension traits — confirm.
    enable_addressable_entity: bool,
}
342
/// Non-error outcome of [`TrackingCopy::add`]: success, or the reason the addition could not
/// be applied.
#[derive(Debug)]
pub enum AddResult {
    /// The addition was applied and recorded.
    Success,
    /// There is no value under the given (normalized) key to add to.
    KeyNotFound(Key),
    /// The addend is of an unsupported type, or does not match the current value's type.
    TypeMismatch(StoredValueTypeMismatch),
    /// A value could not be (de)serialized.
    Serialization(bytesrepr::Error),
    /// Applying the transform failed for another reason.
    Transform(TransformError),
}
357
358impl From<CLValueError> for AddResult {
359 fn from(error: CLValueError) -> Self {
360 match error {
361 CLValueError::Serialization(error) => AddResult::Serialization(error),
362 CLValueError::Type(type_mismatch) => {
363 let expected = format!("{:?}", type_mismatch.expected);
364 let found = format!("{:?}", type_mismatch.found);
365 AddResult::TypeMismatch(StoredValueTypeMismatch::new(expected, found))
366 }
367 }
368 }
369}
370
/// The separable state of a tracking copy: its cache, recorded effects and emitted messages.
pub type TrackingCopyParts = (TrackingCopyCache, Effects, Messages);
373
374impl<R: StateReader<Key, StoredValue>> TrackingCopy<R>
375where
376 R: StateReader<Key, StoredValue, Error = GlobalStateError>,
377{
378 pub fn new(
380 reader: R,
381 max_query_depth: u64,
382 enable_addressable_entity: bool,
383 ) -> TrackingCopy<R> {
384 TrackingCopy {
385 reader: Arc::new(reader),
386 cache: GenericTrackingCopyCache::new(1024 * 16, HeapSize),
388 effects: Effects::new(),
389 max_query_depth,
390 messages: Vec::new(),
391 enable_addressable_entity,
392 }
393 }
394
395 pub fn reader(&self) -> &R {
397 &self.reader
398 }
399
400 pub fn shared_reader(&self) -> Arc<R> {
402 Arc::clone(&self.reader)
403 }
404
405 pub fn fork(&self) -> TrackingCopy<&TrackingCopy<R>> {
417 TrackingCopy::new(self, self.max_query_depth, self.enable_addressable_entity)
418 }
419
420 pub fn fork2(&self) -> Self {
431 TrackingCopy {
432 reader: Arc::clone(&self.reader),
433 cache: self.cache.clone(),
434 effects: self.effects.clone(),
435 max_query_depth: self.max_query_depth,
436 messages: self.messages.clone(),
437 enable_addressable_entity: self.enable_addressable_entity,
438 }
439 }
440
441 pub fn apply_changes(
448 &mut self,
449 effects: Effects,
450 cache: TrackingCopyCache,
451 messages: Messages,
452 ) {
453 self.effects = effects;
454 self.cache = cache;
455 self.messages = messages;
456 }
457
458 pub fn effects(&self) -> Effects {
460 self.effects.clone()
461 }
462
463 pub fn cache(&self) -> TrackingCopyCache {
465 self.cache.clone()
466 }
467
468 pub fn destructure(self) -> (Vec<(Key, StoredValue)>, BTreeSet<Key>, Effects) {
470 let (writes, prunes) = self.cache.into_muts();
471 let writes: Vec<(Key, StoredValue)> = writes.into_iter().map(|(k, v)| (k.0, v)).collect();
472
473 (writes, prunes, self.effects)
474 }
475
476 pub fn enable_addressable_entity(&self) -> bool {
478 self.enable_addressable_entity
479 }
480
481 pub fn get(&mut self, key: &Key) -> Result<Option<StoredValue>, TrackingCopyError> {
483 if let Some(value) = self.cache.get(key) {
484 return Ok(Some(value.to_owned()));
485 }
486 match self.reader.read(key) {
487 Ok(ret) => {
488 if let Some(value) = ret {
489 self.cache.insert_read(*key, value.to_owned());
490 Ok(Some(value))
491 } else {
492 Ok(None)
493 }
494 }
495 Err(err) => Err(TrackingCopyError::Storage(err)),
496 }
497 }
498
499 pub fn get_keys(&self, key_tag: &KeyTag) -> Result<BTreeSet<Key>, TrackingCopyError> {
501 self.get_by_byte_prefix(&[*key_tag as u8])
502 }
503
504 pub fn get_keys_by_prefix(
506 &self,
507 key_prefix: &KeyPrefix,
508 ) -> Result<BTreeSet<Key>, TrackingCopyError> {
509 let byte_prefix = key_prefix
510 .to_bytes()
511 .map_err(TrackingCopyError::BytesRepr)?;
512 self.get_by_byte_prefix(&byte_prefix)
513 }
514
515 pub(crate) fn get_by_byte_prefix(
517 &self,
518 byte_prefix: &[u8],
519 ) -> Result<BTreeSet<Key>, TrackingCopyError> {
520 let ret = self.keys_with_prefix(byte_prefix)?.into_iter().collect();
521 Ok(ret)
522 }
523
524 pub fn read(&mut self, key: &Key) -> Result<Option<StoredValue>, TrackingCopyError> {
526 let normalized_key = key.normalize();
527 if let Some(value) = self.get(&normalized_key)? {
528 self.effects
529 .push(TransformV2::new(normalized_key, TransformKindV2::Identity));
530 Ok(Some(value))
531 } else {
532 Ok(None)
533 }
534 }
535
536 pub fn read_first(&mut self, keys: &[&Key]) -> Result<Option<StoredValue>, TrackingCopyError> {
538 for key in keys {
539 if let Some(value) = self.read(key)? {
540 return Ok(Some(value));
541 }
542 }
543 Ok(None)
544 }
545
546 pub fn write(&mut self, key: Key, value: StoredValue) {
548 let normalized_key = key.normalize();
549 self.cache.insert_write(normalized_key, value.clone());
550 let transform = TransformV2::new(normalized_key, TransformKindV2::Write(value));
551 self.effects.push(transform);
552 }
553
554 #[allow(clippy::too_many_arguments)]
560 pub fn emit_message(
561 &mut self,
562 message_topic_key: Key,
563 message_topic_summary: StoredValue,
564 message_key: Key,
565 message_value: StoredValue,
566 block_message_count_value: StoredValue,
567 message: Message,
568 ) {
569 self.write(message_key, message_value);
570 self.write(message_topic_key, message_topic_summary);
571 self.write(
572 Key::BlockGlobal(BlockGlobalAddr::MessageCount),
573 block_message_count_value,
574 );
575 self.messages.push(message);
576 }
577
578 pub fn prune(&mut self, key: Key) {
580 let normalized_key = key.normalize();
581 self.cache.insert_prune(normalized_key);
582 self.effects.push(TransformV2::new(
583 normalized_key,
584 TransformKindV2::Prune(key),
585 ));
586 }
587
588 pub fn add(&mut self, key: Key, value: StoredValue) -> Result<AddResult, TrackingCopyError> {
593 let normalized_key = key.normalize();
594 let current_value = match self.get(&normalized_key)? {
595 None => return Ok(AddResult::KeyNotFound(normalized_key)),
596 Some(current_value) => current_value,
597 };
598
599 let type_name = value.type_name();
600 let mismatch = || {
601 Ok(AddResult::TypeMismatch(StoredValueTypeMismatch::new(
602 "I32, U64, U128, U256, U512 or (String, Key) tuple".to_string(),
603 type_name,
604 )))
605 };
606
607 let transform_kind = match value {
608 StoredValue::CLValue(cl_value) => match *cl_value.cl_type() {
609 CLType::I32 => match cl_value.into_t() {
610 Ok(value) => TransformKindV2::AddInt32(value),
611 Err(error) => return Ok(AddResult::from(error)),
612 },
613 CLType::U64 => match cl_value.into_t() {
614 Ok(value) => TransformKindV2::AddUInt64(value),
615 Err(error) => return Ok(AddResult::from(error)),
616 },
617 CLType::U128 => match cl_value.into_t() {
618 Ok(value) => TransformKindV2::AddUInt128(value),
619 Err(error) => return Ok(AddResult::from(error)),
620 },
621 CLType::U256 => match cl_value.into_t() {
622 Ok(value) => TransformKindV2::AddUInt256(value),
623 Err(error) => return Ok(AddResult::from(error)),
624 },
625 CLType::U512 => match cl_value.into_t() {
626 Ok(value) => TransformKindV2::AddUInt512(value),
627 Err(error) => return Ok(AddResult::from(error)),
628 },
629 _ => {
630 if *cl_value.cl_type() == casper_types::named_key_type() {
631 match cl_value.into_t() {
632 Ok((name, key)) => {
633 let mut named_keys = NamedKeys::new();
634 named_keys.insert(name, key);
635 TransformKindV2::AddKeys(named_keys)
636 }
637 Err(error) => return Ok(AddResult::from(error)),
638 }
639 } else {
640 return mismatch();
641 }
642 }
643 },
644 _ => return mismatch(),
645 };
646
647 match transform_kind.clone().apply(current_value) {
648 Ok(TransformInstruction::Store(new_value)) => {
649 self.cache.insert_write(normalized_key, new_value);
650 self.effects
651 .push(TransformV2::new(normalized_key, transform_kind));
652 Ok(AddResult::Success)
653 }
654 Ok(TransformInstruction::Prune(key)) => {
655 self.cache.insert_prune(normalized_key);
656 self.effects.push(TransformV2::new(
657 normalized_key,
658 TransformKindV2::Prune(key),
659 ));
660 Ok(AddResult::Success)
661 }
662 Err(TransformError::TypeMismatch(type_mismatch)) => {
663 Ok(AddResult::TypeMismatch(type_mismatch))
664 }
665 Err(TransformError::Serialization(error)) => Ok(AddResult::Serialization(error)),
666 Err(transform_error) => Ok(AddResult::Transform(transform_error)),
667 }
668 }
669
670 pub fn messages(&self) -> Messages {
672 self.messages.clone()
673 }
674
675 pub fn query(
683 &self,
684 base_key: Key,
685 path: &[String],
686 ) -> Result<TrackingCopyQueryResult, TrackingCopyError> {
687 let mut query = Query::new(base_key, path);
688
689 let mut proofs = Vec::new();
690
691 loop {
692 if query.depth >= self.max_query_depth {
693 return Ok(query.into_depth_limit_result());
694 }
695
696 if !query.visited_keys.insert(query.current_key) {
697 return Ok(query.into_circular_ref_result());
698 }
699
700 let stored_value = match self.reader.read_with_proof(&query.current_key)? {
701 None => {
702 return Ok(query.into_not_found_result("Failed to find base key"));
703 }
704 Some(stored_value) => stored_value,
705 };
706
707 let value = stored_value.value().to_owned();
708
709 let value = match handle_stored_dictionary_value(query.current_key, value) {
712 Ok(patched_stored_value) => patched_stored_value,
713 Err(error) => {
714 return Ok(query.into_not_found_result(&format!(
715 "Failed to retrieve dictionary value: {}",
716 error
717 )));
718 }
719 };
720
721 proofs.push(stored_value);
722
723 if query.unvisited_names.is_empty() && !query.current_key.is_named_key() {
724 return Ok(TrackingCopyQueryResult::Success { value, proofs });
725 }
726
727 let stored_value: &StoredValue = proofs
728 .last()
729 .map(|r| r.value())
730 .expect("but we just pushed");
731
732 match stored_value {
733 StoredValue::Account(account) => {
734 let name = query.next_name();
735 if let Some(key) = account.named_keys().get(name) {
736 query.navigate(*key);
737 } else {
738 let msg_prefix = format!("Name {} not found in Account", name);
739 return Ok(query.into_not_found_result(&msg_prefix));
740 }
741 }
742 StoredValue::Contract(contract) => {
743 let name = query.next_name();
744 if let Some(key) = contract.named_keys().get(name) {
745 query.navigate(*key);
746 } else {
747 let msg_prefix = format!("Name {} not found in Contract", name);
748 return Ok(query.into_not_found_result(&msg_prefix));
749 }
750 }
751 StoredValue::NamedKey(named_key_value) => {
752 match query.visited_names.last() {
753 Some(expected_name) => match named_key_value.get_name() {
754 Ok(actual_name) => {
755 if &actual_name != expected_name {
756 return Ok(query.into_not_found_result(
757 "Queried and retrieved names do not match",
758 ));
759 } else if let Ok(key) = named_key_value.get_key() {
760 query.navigate(key)
761 } else {
762 return Ok(query
763 .into_not_found_result("Failed to parse CLValue as Key"));
764 }
765 }
766 Err(_) => {
767 return Ok(query
768 .into_not_found_result("Failed to parse CLValue as String"));
769 }
770 },
771 None if path.is_empty() => {
772 return Ok(TrackingCopyQueryResult::Success { value, proofs });
773 }
774 None => return Ok(query.into_not_found_result("No visited names")),
775 }
776 }
777 StoredValue::CLValue(cl_value) if cl_value.cl_type() == &CLType::Key => {
778 if let Ok(key) = cl_value.to_owned().into_t::<Key>() {
779 query.navigate(key);
780 } else {
781 return Ok(query.into_not_found_result("Failed to parse CLValue as Key"));
782 }
783 }
784 StoredValue::CLValue(cl_value) => {
785 let msg_prefix = format!(
786 "Query cannot continue as {:?} is not an account, contract nor key to \
787 such. Value found",
788 cl_value
789 );
790 return Ok(query.into_not_found_result(&msg_prefix));
791 }
792 StoredValue::AddressableEntity(_) => {
793 let current_key = query.current_key;
794 let name = query.next_name();
795
796 if let Key::AddressableEntity(addr) = current_key {
797 let named_key_addr = match NamedKeyAddr::new_from_string(addr, name.clone())
798 {
799 Ok(named_key_addr) => Key::NamedKey(named_key_addr),
800 Err(error) => {
801 let msg_prefix = format!("{}", error);
802 return Ok(query.into_not_found_result(&msg_prefix));
803 }
804 };
805 query.navigate_for_named_key(named_key_addr);
806 } else {
807 let msg_prefix = "Invalid base key".to_string();
808 return Ok(query.into_not_found_result(&msg_prefix));
809 }
810 }
811 StoredValue::ContractWasm(_) => {
812 return Ok(query.into_not_found_result("ContractWasm value found."));
813 }
814 StoredValue::ContractPackage(_) => {
815 return Ok(query.into_not_found_result("ContractPackage value found."));
816 }
817 StoredValue::SmartContract(_) => {
818 return Ok(query.into_not_found_result("Package value found."));
819 }
820 StoredValue::ByteCode(_) => {
821 return Ok(query.into_not_found_result("ByteCode value found."));
822 }
823 StoredValue::Transfer(_) => {
824 return Ok(query.into_not_found_result("Legacy Transfer value found."));
825 }
826 StoredValue::DeployInfo(_) => {
827 return Ok(query.into_not_found_result("DeployInfo value found."));
828 }
829 StoredValue::EraInfo(_) => {
830 return Ok(query.into_not_found_result("EraInfo value found."));
831 }
832 StoredValue::Bid(_) => {
833 return Ok(query.into_not_found_result("Bid value found."));
834 }
835 StoredValue::BidKind(_) => {
836 return Ok(query.into_not_found_result("BidKind value found."));
837 }
838 StoredValue::Withdraw(_) => {
839 return Ok(query.into_not_found_result("WithdrawPurses value found."));
840 }
841 StoredValue::Unbonding(_) => {
842 return Ok(query.into_not_found_result("Unbonding value found."));
843 }
844 StoredValue::MessageTopic(_) => {
845 return Ok(query.into_not_found_result("MessageTopic value found."));
846 }
847 StoredValue::Message(_) => {
848 return Ok(query.into_not_found_result("Message value found."));
849 }
850 StoredValue::EntryPoint(_) => {
851 return Ok(query.into_not_found_result("EntryPoint value found."));
852 }
853 StoredValue::Prepayment(_) => {
854 return Ok(query.into_not_found_result("Prepayment value found."))
855 }
856 StoredValue::RawBytes(_) => {
857 return Ok(query.into_not_found_result("RawBytes value found."));
858 }
859 }
860 }
861 }
862}
863
864impl<R: StateReader<Key, StoredValue>> StateReader<Key, StoredValue> for &TrackingCopy<R> {
870 type Error = R::Error;
871
872 fn read(&self, key: &Key) -> Result<Option<StoredValue>, Self::Error> {
873 let kb = KeyWithByteRepr::new(*key);
874 if let Some(value) = self.cache.muts_cached.get(&kb) {
875 return Ok(Some(value.to_owned()));
876 }
877 if let Some(value) = self.reader.read(key)? {
878 Ok(Some(value))
879 } else {
880 Ok(None)
881 }
882 }
883
884 fn read_with_proof(
885 &self,
886 key: &Key,
887 ) -> Result<Option<TrieMerkleProof<Key, StoredValue>>, Self::Error> {
888 self.reader.read_with_proof(key)
889 }
890
891 fn keys_with_prefix(&self, byte_prefix: &[u8]) -> Result<Vec<Key>, Self::Error> {
892 let keys = self.reader.keys_with_prefix(byte_prefix)?;
893
894 let ret = keys
895 .into_iter()
896 .filter(|key| !self.cache.is_pruned(key))
898 .chain(self.cache.get_muts_cached_by_byte_prefix(byte_prefix))
900 .collect();
901 Ok(ret)
902 }
903}
904
905#[derive(Error, Debug, PartialEq, Eq)]
907pub enum ValidationError {
908 #[error("The path should not have a different length than the proof less one.")]
910 PathLengthDifferentThanProofLessOne,
911
912 #[error("The provided key does not match the key in the proof.")]
914 UnexpectedKey,
915
916 #[error("The provided value does not match the value in the proof.")]
918 UnexpectedValue,
919
920 #[error("The proof hash is invalid.")]
922 InvalidProofHash,
923
924 #[error("The path went cold.")]
926 PathCold,
927
928 #[error("Serialization error: {0}")]
930 BytesRepr(bytesrepr::Error),
931
932 #[error("Key is not a URef")]
934 KeyIsNotAURef(Key),
935
936 #[error("Failed to convert stored value to key")]
938 ValueToCLValueConversion,
939
940 #[error("{0}")]
942 CLValueError(CLValueError),
943}
944
945impl From<CLValueError> for ValidationError {
946 fn from(err: CLValueError) -> Self {
947 ValidationError::CLValueError(err)
948 }
949}
950
951impl From<bytesrepr::Error> for ValidationError {
952 fn from(error: bytesrepr::Error) -> Self {
953 Self::BytesRepr(error)
954 }
955}
956
957pub fn validate_query_proof(
961 hash: &Digest,
962 proofs: &[TrieMerkleProof<Key, StoredValue>],
963 expected_first_key: &Key,
964 path: &[String],
965 expected_value: &StoredValue,
966) -> Result<(), ValidationError> {
967 if proofs.len() != path.len() + 1 {
968 return Err(ValidationError::PathLengthDifferentThanProofLessOne);
969 }
970
971 let mut proofs_iter = proofs.iter();
972 let mut path_components_iter = path.iter();
973
974 let first_proof = proofs_iter.next().unwrap();
976
977 if first_proof.key() != &expected_first_key.normalize() {
978 return Err(ValidationError::UnexpectedKey);
979 }
980
981 if hash != &compute_state_hash(first_proof)? {
982 return Err(ValidationError::InvalidProofHash);
983 }
984
985 let mut proof_value = first_proof.value();
986
987 for proof in proofs_iter {
988 let named_keys = match proof_value {
989 StoredValue::Account(account) => account.named_keys(),
990 StoredValue::Contract(contract) => contract.named_keys(),
991 _ => return Err(ValidationError::PathCold),
992 };
993
994 let path_component = match path_components_iter.next() {
995 Some(path_component) => path_component,
996 None => return Err(ValidationError::PathCold),
997 };
998
999 let key = match named_keys.get(path_component) {
1000 Some(key) => key,
1001 None => return Err(ValidationError::PathCold),
1002 };
1003
1004 if proof.key() != &key.normalize() {
1005 return Err(ValidationError::UnexpectedKey);
1006 }
1007
1008 if hash != &compute_state_hash(proof)? {
1009 return Err(ValidationError::InvalidProofHash);
1010 }
1011
1012 proof_value = proof.value();
1013 }
1014
1015 if proof_value != expected_value {
1016 return Err(ValidationError::UnexpectedValue);
1017 }
1018
1019 Ok(())
1020}
1021
1022pub fn validate_query_merkle_proof(
1026 hash: &Digest,
1027 proofs: &[TrieMerkleProof<Key, StoredValue>],
1028 expected_key_trace: &[Key],
1029 expected_value: &StoredValue,
1030) -> Result<(), ValidationError> {
1031 let expected_len = expected_key_trace.len();
1032 if proofs.len() != expected_len {
1033 return Err(ValidationError::PathLengthDifferentThanProofLessOne);
1034 }
1035
1036 let proof_keys: Vec<Key> = proofs.iter().map(|proof| *proof.key()).collect();
1037
1038 if !expected_key_trace.eq(&proof_keys) {
1039 return Err(ValidationError::UnexpectedKey);
1040 }
1041
1042 if expected_value != proofs[expected_len - 1].value() {
1043 return Err(ValidationError::UnexpectedValue);
1044 }
1045
1046 let mut proofs_iter = proofs.iter();
1047
1048 let first_proof = proofs_iter.next().unwrap();
1050
1051 if hash != &compute_state_hash(first_proof)? {
1052 return Err(ValidationError::InvalidProofHash);
1053 }
1054
1055 Ok(())
1056}
1057
1058pub fn validate_balance_proof(
1060 hash: &Digest,
1061 balance_proof: &TrieMerkleProof<Key, StoredValue>,
1062 expected_purse_key: Key,
1063 expected_motes: &U512,
1064) -> Result<(), ValidationError> {
1065 let expected_balance_key = expected_purse_key
1066 .into_uref()
1067 .map(|uref| Key::Balance(uref.addr()))
1068 .ok_or_else(|| ValidationError::KeyIsNotAURef(expected_purse_key.to_owned()))?;
1069
1070 if balance_proof.key() != &expected_balance_key.normalize() {
1071 return Err(ValidationError::UnexpectedKey);
1072 }
1073
1074 if hash != &compute_state_hash(balance_proof)? {
1075 return Err(ValidationError::InvalidProofHash);
1076 }
1077
1078 let balance_proof_stored_value = balance_proof.value().to_owned();
1079
1080 let balance_proof_clvalue: CLValue = balance_proof_stored_value
1081 .try_into()
1082 .map_err(|_| ValidationError::ValueToCLValueConversion)?;
1083
1084 let balance_motes: U512 = balance_proof_clvalue.into_t()?;
1085
1086 if expected_motes != &balance_motes {
1087 return Err(ValidationError::UnexpectedValue);
1088 }
1089
1090 Ok(())
1091}
1092
1093use crate::global_state::{
1094 error::Error,
1095 state::{
1096 lmdb::{make_temporary_global_state, LmdbGlobalStateView},
1097 StateProvider,
1098 },
1099};
1100use tempfile::TempDir;
1101
1102pub fn new_temporary_tracking_copy(
1104 initial_data: impl IntoIterator<Item = (Key, StoredValue)>,
1105 max_query_depth: Option<u64>,
1106 enable_addressable_entity: bool,
1107) -> (TrackingCopy<LmdbGlobalStateView>, TempDir) {
1108 let (global_state, state_root_hash, tempdir) = make_temporary_global_state(initial_data);
1109
1110 let reader = global_state
1111 .checkout(state_root_hash)
1112 .expect("Checkout should not throw errors.")
1113 .expect("Root hash should exist.");
1114
1115 let query_depth = max_query_depth.unwrap_or(DEFAULT_MAX_QUERY_DEPTH);
1116
1117 (
1118 TrackingCopy::new(reader, query_depth, enable_addressable_entity),
1119 tempdir,
1120 )
1121}