casper_storage/tracking_copy/
mod.rs

1//! This module defines the `TrackingCopy` - a utility that caches operations on the state, so that
2//! the underlying state remains unmodified, but it can be interacted with as if the modifications
3//! were applied on it.
4mod byte_size;
5mod error;
6mod ext;
7mod ext_entity;
8mod meter;
9#[cfg(test)]
10mod tests;
11
12use std::{
13    borrow::Borrow,
14    collections::{BTreeMap, BTreeSet, HashSet, VecDeque},
15    convert::{From, TryInto},
16    fmt::Debug,
17    sync::Arc,
18};
19
20use linked_hash_map::LinkedHashMap;
21use thiserror::Error;
22use tracing::error;
23
24use crate::{
25    global_state::{
26        error::Error as GlobalStateError, state::StateReader,
27        trie_store::operations::compute_state_hash, DEFAULT_MAX_QUERY_DEPTH,
28    },
29    KeyPrefix,
30};
31use casper_types::{
32    addressable_entity::NamedKeyAddr,
33    bytesrepr::{self, ToBytes},
34    contract_messages::{Message, Messages},
35    contracts::NamedKeys,
36    execution::{Effects, TransformError, TransformInstruction, TransformKindV2, TransformV2},
37    global_state::TrieMerkleProof,
38    handle_stored_dictionary_value, BlockGlobalAddr, CLType, CLValue, CLValueError, Digest, Key,
39    KeyTag, StoredValue, StoredValueTypeMismatch, U512,
40};
41
42use self::meter::{heap_meter::HeapSize, Meter};
43pub use self::{
44    error::Error as TrackingCopyError,
45    ext::TrackingCopyExt,
46    ext_entity::{FeesPurseHandling, TrackingCopyEntityExt},
47};
48
49/// Result of a query on a `TrackingCopy`.
50#[derive(Debug)]
51#[allow(clippy::large_enum_variant)]
52pub enum TrackingCopyQueryResult {
53    /// Invalid state root hash.
54    RootNotFound,
55    /// The value wasn't found.
56    ValueNotFound(String),
57    /// A circular reference was found in the state while traversing it.
58    CircularReference(String),
59    /// The query reached the depth limit.
60    DepthLimit {
61        /// The depth reached.
62        depth: u64,
63    },
64    /// The query was successful.
65    Success {
66        /// The value read from the state.
67        value: StoredValue,
68        /// Merkle proofs for the value.
69        proofs: Vec<TrieMerkleProof<Key, StoredValue>>,
70    },
71}
72
73impl TrackingCopyQueryResult {
74    /// Is this a successful query?
75    pub fn is_success(&self) -> bool {
76        matches!(self, TrackingCopyQueryResult::Success { .. })
77    }
78
79    /// As result.
80    pub fn into_result(self) -> Result<StoredValue, TrackingCopyError> {
81        match self {
82            TrackingCopyQueryResult::RootNotFound => {
83                Err(TrackingCopyError::Storage(Error::RootNotFound))
84            }
85            TrackingCopyQueryResult::ValueNotFound(msg) => {
86                Err(TrackingCopyError::ValueNotFound(msg))
87            }
88            TrackingCopyQueryResult::CircularReference(msg) => {
89                Err(TrackingCopyError::CircularReference(msg))
90            }
91            TrackingCopyQueryResult::DepthLimit { depth } => {
92                Err(TrackingCopyError::QueryDepthLimit { depth })
93            }
94            TrackingCopyQueryResult::Success { value, .. } => Ok(value),
95        }
96    }
97}
98
99/// Struct containing state relating to a given query.
100struct Query {
101    /// The key from where the search starts.
102    base_key: Key,
103    /// A collection of normalized keys which have been visited during the search.
104    visited_keys: HashSet<Key>,
105    /// The key currently being processed.
106    current_key: Key,
107    /// Path components which have not yet been followed, held in the same order in which they were
108    /// provided to the `query()` call.
109    unvisited_names: VecDeque<String>,
110    /// Path components which have been followed, held in the same order in which they were
111    /// provided to the `query()` call.
112    visited_names: Vec<String>,
113    /// Current depth of the query.
114    depth: u64,
115}
116
117impl Query {
118    fn new(base_key: Key, path: &[String]) -> Self {
119        Query {
120            base_key,
121            current_key: base_key.normalize(),
122            unvisited_names: path.iter().cloned().collect(),
123            visited_names: Vec::new(),
124            visited_keys: HashSet::new(),
125            depth: 0,
126        }
127    }
128
129    /// Panics if `unvisited_names` is empty.
130    fn next_name(&mut self) -> &String {
131        let next_name = self.unvisited_names.pop_front().unwrap();
132        self.visited_names.push(next_name);
133        self.visited_names.last().unwrap()
134    }
135
136    fn navigate(&mut self, key: Key) {
137        self.current_key = key.normalize();
138        self.depth += 1;
139    }
140
141    fn navigate_for_named_key(&mut self, named_key: Key) {
142        if let Key::NamedKey(_) = &named_key {
143            self.current_key = named_key.normalize();
144        }
145    }
146
147    fn into_not_found_result(self, msg_prefix: &str) -> TrackingCopyQueryResult {
148        let msg = format!("{} at path: {}", msg_prefix, self.current_path());
149        TrackingCopyQueryResult::ValueNotFound(msg)
150    }
151
152    fn into_circular_ref_result(self) -> TrackingCopyQueryResult {
153        let msg = format!(
154            "{:?} has formed a circular reference at path: {}",
155            self.current_key,
156            self.current_path()
157        );
158        TrackingCopyQueryResult::CircularReference(msg)
159    }
160
161    fn into_depth_limit_result(self) -> TrackingCopyQueryResult {
162        TrackingCopyQueryResult::DepthLimit { depth: self.depth }
163    }
164
165    fn current_path(&self) -> String {
166        let mut path = format!("{:?}", self.base_key);
167        for name in &self.visited_names {
168            path.push('/');
169            path.push_str(name);
170        }
171        path
172    }
173}
174
175/// Keeps track of already accessed keys.
176/// We deliberately separate cached Reads from cached mutations
177/// because we want to invalidate Reads' cache so it doesn't grow too fast.
178#[derive(Clone, Debug)]
179pub struct GenericTrackingCopyCache<M: Copy + Debug> {
180    max_cache_size: usize,
181    current_cache_size: usize,
182    reads_cached: LinkedHashMap<Key, StoredValue>,
183    muts_cached: BTreeMap<KeyWithByteRepr, StoredValue>,
184    prunes_cached: BTreeSet<Key>,
185    meter: M,
186}
187
188impl<M: Meter<Key, StoredValue> + Copy + Default> GenericTrackingCopyCache<M> {
189    /// Creates instance of `TrackingCopyCache` with specified `max_cache_size`,
190    /// above which least-recently-used elements of the cache are invalidated.
191    /// Measurements of elements' "size" is done with the usage of `Meter`
192    /// instance.
193    pub fn new(max_cache_size: usize, meter: M) -> GenericTrackingCopyCache<M> {
194        GenericTrackingCopyCache {
195            max_cache_size,
196            current_cache_size: 0,
197            reads_cached: LinkedHashMap::new(),
198            muts_cached: BTreeMap::new(),
199            prunes_cached: BTreeSet::new(),
200            meter,
201        }
202    }
203
204    /// Creates instance of `TrackingCopyCache` with specified `max_cache_size`, above which
205    /// least-recently-used elements of the cache are invalidated. Measurements of elements' "size"
206    /// is done with the usage of default `Meter` instance.
207    pub fn new_default(max_cache_size: usize) -> GenericTrackingCopyCache<M> {
208        GenericTrackingCopyCache::new(max_cache_size, M::default())
209    }
210
211    /// Inserts `key` and `value` pair to Read cache.
212    pub fn insert_read(&mut self, key: Key, value: StoredValue) {
213        let element_size = Meter::measure(&self.meter, &key, &value);
214        self.reads_cached.insert(key, value);
215        self.current_cache_size += element_size;
216        while self.current_cache_size > self.max_cache_size {
217            match self.reads_cached.pop_front() {
218                Some((k, v)) => {
219                    let element_size = Meter::measure(&self.meter, &k, &v);
220                    self.current_cache_size -= element_size;
221                }
222                None => break,
223            }
224        }
225    }
226
227    /// Inserts `key` and `value` pair to Write/Add cache.
228    pub fn insert_write(&mut self, key: Key, value: StoredValue) {
229        let kb = KeyWithByteRepr::new(key);
230        self.prunes_cached.remove(&key);
231        self.muts_cached.insert(kb, value);
232    }
233
234    /// Inserts `key` and `value` pair to Write/Add cache.
235    pub fn insert_prune(&mut self, key: Key) {
236        self.prunes_cached.insert(key);
237    }
238
239    /// Gets value from `key` in the cache.
240    pub fn get(&mut self, key: &Key) -> Option<&StoredValue> {
241        if self.prunes_cached.contains(key) {
242            // the item is marked for pruning and therefore
243            // is no longer accessible.
244            return None;
245        }
246        let kb = KeyWithByteRepr::new(*key);
247        if let Some(value) = self.muts_cached.get(&kb) {
248            return Some(value);
249        };
250
251        self.reads_cached.get_refresh(key).map(|v| &*v)
252    }
253
254    /// Get cached items by prefix.
255    fn get_muts_cached_by_byte_prefix(&self, prefix: &[u8]) -> Vec<Key> {
256        self.muts_cached
257            .range(prefix.to_vec()..)
258            .take_while(|(key, _)| key.starts_with(prefix))
259            .map(|(key, _)| key.to_key())
260            .collect()
261    }
262
263    /// Does the prune cache contain key.
264    pub fn is_pruned(&self, key: &Key) -> bool {
265        self.prunes_cached.contains(key)
266    }
267
268    pub(self) fn into_muts(self) -> (BTreeMap<KeyWithByteRepr, StoredValue>, BTreeSet<Key>) {
269        (self.muts_cached, self.prunes_cached)
270    }
271}
272
273/// A helper type for `TrackingCopyCache` that allows convenient storage and access
274/// to keys as bytes.
275/// Its equality and ordering is based on the byte representation of the key.
276#[derive(Debug, Clone)]
277struct KeyWithByteRepr(Key, Vec<u8>);
278
279impl KeyWithByteRepr {
280    #[inline]
281    fn new(key: Key) -> Self {
282        let bytes = key.to_bytes().expect("should always serialize a Key");
283        KeyWithByteRepr(key, bytes)
284    }
285
286    #[inline]
287    fn starts_with(&self, prefix: &[u8]) -> bool {
288        self.1.starts_with(prefix)
289    }
290
291    #[inline]
292    fn to_key(&self) -> Key {
293        self.0
294    }
295}
296
297impl Borrow<Vec<u8>> for KeyWithByteRepr {
298    #[inline]
299    fn borrow(&self) -> &Vec<u8> {
300        &self.1
301    }
302}
303
304impl PartialEq for KeyWithByteRepr {
305    #[inline]
306    fn eq(&self, other: &Self) -> bool {
307        self.1 == other.1
308    }
309}
310
311impl Eq for KeyWithByteRepr {}
312
313impl PartialOrd for KeyWithByteRepr {
314    #[inline]
315    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
316        Some(self.cmp(other))
317    }
318}
319
320impl Ord for KeyWithByteRepr {
321    #[inline]
322    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
323        self.1.cmp(&other.1)
324    }
325}
326
327/// An alias for a `TrackingCopyCache` with `HeapSize` as the meter.
328pub type TrackingCopyCache = GenericTrackingCopyCache<HeapSize>;
329
330/// An interface for the global state that caches all operations (reads and writes) instead of
331/// applying them directly to the state. This way the state remains unmodified, while the user can
332/// interact with it as if it was being modified in real time.
333#[derive(Clone)]
334pub struct TrackingCopy<R> {
335    reader: Arc<R>,
336    cache: TrackingCopyCache,
337    effects: Effects,
338    max_query_depth: u64,
339    messages: Messages,
340    enable_addressable_entity: bool,
341}
342
343/// Result of executing an "add" operation on a value in the state.
344#[derive(Debug)]
345pub enum AddResult {
346    /// The operation was successful.
347    Success,
348    /// The key was not found.
349    KeyNotFound(Key),
350    /// There was a type mismatch between the stored value and the value being added.
351    TypeMismatch(StoredValueTypeMismatch),
352    /// Serialization error.
353    Serialization(bytesrepr::Error),
354    /// Transform error.
355    Transform(TransformError),
356}
357
358impl From<CLValueError> for AddResult {
359    fn from(error: CLValueError) -> Self {
360        match error {
361            CLValueError::Serialization(error) => AddResult::Serialization(error),
362            CLValueError::Type(type_mismatch) => {
363                let expected = format!("{:?}", type_mismatch.expected);
364                let found = format!("{:?}", type_mismatch.found);
365                AddResult::TypeMismatch(StoredValueTypeMismatch::new(expected, found))
366            }
367        }
368    }
369}
370
371/// A helper type for `TrackingCopy` that represents a key-value pair.
372pub type TrackingCopyParts = (TrackingCopyCache, Effects, Messages);
373
374impl<R: StateReader<Key, StoredValue>> TrackingCopy<R>
375where
376    R: StateReader<Key, StoredValue, Error = GlobalStateError>,
377{
378    /// Creates a new `TrackingCopy` using the `reader` as the interface to the state.
379    pub fn new(
380        reader: R,
381        max_query_depth: u64,
382        enable_addressable_entity: bool,
383    ) -> TrackingCopy<R> {
384        TrackingCopy {
385            reader: Arc::new(reader),
386            // TODO: Should `max_cache_size` be a fraction of wasm memory limit?
387            cache: GenericTrackingCopyCache::new(1024 * 16, HeapSize),
388            effects: Effects::new(),
389            max_query_depth,
390            messages: Vec::new(),
391            enable_addressable_entity,
392        }
393    }
394
395    /// Returns the `reader` used to access the state.
396    pub fn reader(&self) -> &R {
397        &self.reader
398    }
399
400    /// Returns a shared reference to the `reader` used to access the state.
401    pub fn shared_reader(&self) -> Arc<R> {
402        Arc::clone(&self.reader)
403    }
404
405    /// Creates a new `TrackingCopy` using the `reader` as the interface to the state.
406    /// Returns a new `TrackingCopy` instance that is a snapshot of the current state, allowing
407    /// further changes to be made.
408    ///
409    /// This method creates a new `TrackingCopy` using the current instance (including its
410    /// mutations) as the base state to read against. Mutations made to the new `TrackingCopy`
411    /// will not impact the original instance.
412    ///
413    /// Note: Currently, there is no `join` or `merge` function to bring changes from a fork back to
414    /// the main `TrackingCopy`. Therefore, forking should be done repeatedly, which is
415    /// suboptimal and will be improved in the future.
416    pub fn fork(&self) -> TrackingCopy<&TrackingCopy<R>> {
417        TrackingCopy::new(self, self.max_query_depth, self.enable_addressable_entity)
418    }
419
420    /// Returns a new `TrackingCopy` instance that is a snapshot of the current state, allowing
421    /// further changes to be made.
422    ///
423    /// This method creates a new `TrackingCopy` using the current instance (including its
424    /// mutations) as the base state to read against. Mutations made to the new `TrackingCopy`
425    /// will not impact the original instance.
426    ///
427    /// Note: Currently, there is no `join` or `merge` function to bring changes from a fork back to
428    /// the main `TrackingCopy`. This method is an alternative to the `fork` method and is
429    /// provided for clarity and consistency in naming.
430    pub fn fork2(&self) -> Self {
431        TrackingCopy {
432            reader: Arc::clone(&self.reader),
433            cache: self.cache.clone(),
434            effects: self.effects.clone(),
435            max_query_depth: self.max_query_depth,
436            messages: self.messages.clone(),
437            enable_addressable_entity: self.enable_addressable_entity,
438        }
439    }
440
441    /// Applies the changes to the state.
442    ///
443    /// This is a low-level function that should be used only by the execution engine. The purpose
444    /// of this function is to apply the changes to the state from a forked tracking copy. Once
445    /// caller decides that the changes are valid, they can be applied to the state and the
446    /// processing can resume.
447    pub fn apply_changes(
448        &mut self,
449        effects: Effects,
450        cache: TrackingCopyCache,
451        messages: Messages,
452    ) {
453        self.effects = effects;
454        self.cache = cache;
455        self.messages = messages;
456    }
457
458    /// Returns a copy of the execution effects cached by this instance.
459    pub fn effects(&self) -> Effects {
460        self.effects.clone()
461    }
462
463    /// Returns copy of cache.
464    pub fn cache(&self) -> TrackingCopyCache {
465        self.cache.clone()
466    }
467
468    /// Destructure cached entries.
469    pub fn destructure(self) -> (Vec<(Key, StoredValue)>, BTreeSet<Key>, Effects) {
470        let (writes, prunes) = self.cache.into_muts();
471        let writes: Vec<(Key, StoredValue)> = writes.into_iter().map(|(k, v)| (k.0, v)).collect();
472
473        (writes, prunes, self.effects)
474    }
475
476    /// Enable the addressable entity and migrate accounts/contracts to entities.
477    pub fn enable_addressable_entity(&self) -> bool {
478        self.enable_addressable_entity
479    }
480
481    /// Get record by key.
482    pub fn get(&mut self, key: &Key) -> Result<Option<StoredValue>, TrackingCopyError> {
483        if let Some(value) = self.cache.get(key) {
484            return Ok(Some(value.to_owned()));
485        }
486        match self.reader.read(key) {
487            Ok(ret) => {
488                if let Some(value) = ret {
489                    self.cache.insert_read(*key, value.to_owned());
490                    Ok(Some(value))
491                } else {
492                    Ok(None)
493                }
494            }
495            Err(err) => Err(TrackingCopyError::Storage(err)),
496        }
497    }
498
499    /// Gets the set of keys in the state whose tag is `key_tag`.
500    pub fn get_keys(&self, key_tag: &KeyTag) -> Result<BTreeSet<Key>, TrackingCopyError> {
501        self.get_by_byte_prefix(&[*key_tag as u8])
502    }
503
504    /// Get keys by prefix.
505    pub fn get_keys_by_prefix(
506        &self,
507        key_prefix: &KeyPrefix,
508    ) -> Result<BTreeSet<Key>, TrackingCopyError> {
509        let byte_prefix = key_prefix
510            .to_bytes()
511            .map_err(TrackingCopyError::BytesRepr)?;
512        self.get_by_byte_prefix(&byte_prefix)
513    }
514
515    /// Gets the set of keys in the state by a byte prefix.
516    pub(crate) fn get_by_byte_prefix(
517        &self,
518        byte_prefix: &[u8],
519    ) -> Result<BTreeSet<Key>, TrackingCopyError> {
520        let ret = self.keys_with_prefix(byte_prefix)?.into_iter().collect();
521        Ok(ret)
522    }
523
524    /// Reads the value stored under `key`.
525    pub fn read(&mut self, key: &Key) -> Result<Option<StoredValue>, TrackingCopyError> {
526        let normalized_key = key.normalize();
527        if let Some(value) = self.get(&normalized_key)? {
528            self.effects
529                .push(TransformV2::new(normalized_key, TransformKindV2::Identity));
530            Ok(Some(value))
531        } else {
532            Ok(None)
533        }
534    }
535
536    /// Reads the first value stored under the keys in `keys`.
537    pub fn read_first(&mut self, keys: &[&Key]) -> Result<Option<StoredValue>, TrackingCopyError> {
538        for key in keys {
539            if let Some(value) = self.read(key)? {
540                return Ok(Some(value));
541            }
542        }
543        Ok(None)
544    }
545
546    /// Writes `value` under `key`. Note that the written value is only cached.
547    pub fn write(&mut self, key: Key, value: StoredValue) {
548        let normalized_key = key.normalize();
549        self.cache.insert_write(normalized_key, value.clone());
550        let transform = TransformV2::new(normalized_key, TransformKindV2::Write(value));
551        self.effects.push(transform);
552    }
553
554    /// Caches the emitted message and writes the message topic summary under the specified key.
555    ///
556    /// This function does not check the types for the key and the value so the caller should
557    /// correctly set the type. The `message_topic_key` should be of the `Key::MessageTopic`
558    /// variant and the `message_topic_summary` should be of the `StoredValue::Message` variant.
559    #[allow(clippy::too_many_arguments)]
560    pub fn emit_message(
561        &mut self,
562        message_topic_key: Key,
563        message_topic_summary: StoredValue,
564        message_key: Key,
565        message_value: StoredValue,
566        block_message_count_value: StoredValue,
567        message: Message,
568    ) {
569        self.write(message_key, message_value);
570        self.write(message_topic_key, message_topic_summary);
571        self.write(
572            Key::BlockGlobal(BlockGlobalAddr::MessageCount),
573            block_message_count_value,
574        );
575        self.messages.push(message);
576    }
577
578    /// Prunes a `key`.
579    pub fn prune(&mut self, key: Key) {
580        let normalized_key = key.normalize();
581        self.cache.insert_prune(normalized_key);
582        self.effects.push(TransformV2::new(
583            normalized_key,
584            TransformKindV2::Prune(key),
585        ));
586    }
587
588    /// Ok(None) represents missing key to which we want to "add" some value.
589    /// Ok(Some(unit)) represents successful operation.
590    /// Err(error) is reserved for unexpected errors when accessing global
591    /// state.
592    pub fn add(&mut self, key: Key, value: StoredValue) -> Result<AddResult, TrackingCopyError> {
593        let normalized_key = key.normalize();
594        let current_value = match self.get(&normalized_key)? {
595            None => return Ok(AddResult::KeyNotFound(normalized_key)),
596            Some(current_value) => current_value,
597        };
598
599        let type_name = value.type_name();
600        let mismatch = || {
601            Ok(AddResult::TypeMismatch(StoredValueTypeMismatch::new(
602                "I32, U64, U128, U256, U512 or (String, Key) tuple".to_string(),
603                type_name,
604            )))
605        };
606
607        let transform_kind = match value {
608            StoredValue::CLValue(cl_value) => match *cl_value.cl_type() {
609                CLType::I32 => match cl_value.into_t() {
610                    Ok(value) => TransformKindV2::AddInt32(value),
611                    Err(error) => return Ok(AddResult::from(error)),
612                },
613                CLType::U64 => match cl_value.into_t() {
614                    Ok(value) => TransformKindV2::AddUInt64(value),
615                    Err(error) => return Ok(AddResult::from(error)),
616                },
617                CLType::U128 => match cl_value.into_t() {
618                    Ok(value) => TransformKindV2::AddUInt128(value),
619                    Err(error) => return Ok(AddResult::from(error)),
620                },
621                CLType::U256 => match cl_value.into_t() {
622                    Ok(value) => TransformKindV2::AddUInt256(value),
623                    Err(error) => return Ok(AddResult::from(error)),
624                },
625                CLType::U512 => match cl_value.into_t() {
626                    Ok(value) => TransformKindV2::AddUInt512(value),
627                    Err(error) => return Ok(AddResult::from(error)),
628                },
629                _ => {
630                    if *cl_value.cl_type() == casper_types::named_key_type() {
631                        match cl_value.into_t() {
632                            Ok((name, key)) => {
633                                let mut named_keys = NamedKeys::new();
634                                named_keys.insert(name, key);
635                                TransformKindV2::AddKeys(named_keys)
636                            }
637                            Err(error) => return Ok(AddResult::from(error)),
638                        }
639                    } else {
640                        return mismatch();
641                    }
642                }
643            },
644            _ => return mismatch(),
645        };
646
647        match transform_kind.clone().apply(current_value) {
648            Ok(TransformInstruction::Store(new_value)) => {
649                self.cache.insert_write(normalized_key, new_value);
650                self.effects
651                    .push(TransformV2::new(normalized_key, transform_kind));
652                Ok(AddResult::Success)
653            }
654            Ok(TransformInstruction::Prune(key)) => {
655                self.cache.insert_prune(normalized_key);
656                self.effects.push(TransformV2::new(
657                    normalized_key,
658                    TransformKindV2::Prune(key),
659                ));
660                Ok(AddResult::Success)
661            }
662            Err(TransformError::TypeMismatch(type_mismatch)) => {
663                Ok(AddResult::TypeMismatch(type_mismatch))
664            }
665            Err(TransformError::Serialization(error)) => Ok(AddResult::Serialization(error)),
666            Err(transform_error) => Ok(AddResult::Transform(transform_error)),
667        }
668    }
669
670    /// Returns a copy of the messages cached by this instance.
671    pub fn messages(&self) -> Messages {
672        self.messages.clone()
673    }
674
675    /// Calling `query()` avoids calling into `self.cache`, so this will not return any values
676    /// written or mutated in this `TrackingCopy` via previous calls to `write()` or `add()`, since
677    /// these updates are only held in `self.cache`.
678    ///
679    /// The intent is that `query()` is only used to satisfy `QueryRequest`s made to the server.
680    /// Other EE internal use cases should call `read()` or `get()` in order to retrieve cached
681    /// values.
682    pub fn query(
683        &self,
684        base_key: Key,
685        path: &[String],
686    ) -> Result<TrackingCopyQueryResult, TrackingCopyError> {
687        let mut query = Query::new(base_key, path);
688
689        let mut proofs = Vec::new();
690
691        loop {
692            if query.depth >= self.max_query_depth {
693                return Ok(query.into_depth_limit_result());
694            }
695
696            if !query.visited_keys.insert(query.current_key) {
697                return Ok(query.into_circular_ref_result());
698            }
699
700            let stored_value = match self.reader.read_with_proof(&query.current_key)? {
701                None => {
702                    return Ok(query.into_not_found_result("Failed to find base key"));
703                }
704                Some(stored_value) => stored_value,
705            };
706
707            let value = stored_value.value().to_owned();
708
709            // Following code does a patching on the `StoredValue` that unwraps an inner
710            // `DictionaryValue` for dictionaries only.
711            let value = match handle_stored_dictionary_value(query.current_key, value) {
712                Ok(patched_stored_value) => patched_stored_value,
713                Err(error) => {
714                    return Ok(query.into_not_found_result(&format!(
715                        "Failed to retrieve dictionary value: {}",
716                        error
717                    )));
718                }
719            };
720
721            proofs.push(stored_value);
722
723            if query.unvisited_names.is_empty() && !query.current_key.is_named_key() {
724                return Ok(TrackingCopyQueryResult::Success { value, proofs });
725            }
726
727            let stored_value: &StoredValue = proofs
728                .last()
729                .map(|r| r.value())
730                .expect("but we just pushed");
731
732            match stored_value {
733                StoredValue::Account(account) => {
734                    let name = query.next_name();
735                    if let Some(key) = account.named_keys().get(name) {
736                        query.navigate(*key);
737                    } else {
738                        let msg_prefix = format!("Name {} not found in Account", name);
739                        return Ok(query.into_not_found_result(&msg_prefix));
740                    }
741                }
742                StoredValue::Contract(contract) => {
743                    let name = query.next_name();
744                    if let Some(key) = contract.named_keys().get(name) {
745                        query.navigate(*key);
746                    } else {
747                        let msg_prefix = format!("Name {} not found in Contract", name);
748                        return Ok(query.into_not_found_result(&msg_prefix));
749                    }
750                }
751                StoredValue::NamedKey(named_key_value) => {
752                    match query.visited_names.last() {
753                        Some(expected_name) => match named_key_value.get_name() {
754                            Ok(actual_name) => {
755                                if &actual_name != expected_name {
756                                    return Ok(query.into_not_found_result(
757                                        "Queried and retrieved names do not match",
758                                    ));
759                                } else if let Ok(key) = named_key_value.get_key() {
760                                    query.navigate(key)
761                                } else {
762                                    return Ok(query
763                                        .into_not_found_result("Failed to parse CLValue as Key"));
764                                }
765                            }
766                            Err(_) => {
767                                return Ok(query
768                                    .into_not_found_result("Failed to parse CLValue as String"));
769                            }
770                        },
771                        None if path.is_empty() => {
772                            return Ok(TrackingCopyQueryResult::Success { value, proofs });
773                        }
774                        None => return Ok(query.into_not_found_result("No visited names")),
775                    }
776                }
777                StoredValue::CLValue(cl_value) if cl_value.cl_type() == &CLType::Key => {
778                    if let Ok(key) = cl_value.to_owned().into_t::<Key>() {
779                        query.navigate(key);
780                    } else {
781                        return Ok(query.into_not_found_result("Failed to parse CLValue as Key"));
782                    }
783                }
784                StoredValue::CLValue(cl_value) => {
785                    let msg_prefix = format!(
786                        "Query cannot continue as {:?} is not an account, contract nor key to \
787                        such.  Value found",
788                        cl_value
789                    );
790                    return Ok(query.into_not_found_result(&msg_prefix));
791                }
792                StoredValue::AddressableEntity(_) => {
793                    let current_key = query.current_key;
794                    let name = query.next_name();
795
796                    if let Key::AddressableEntity(addr) = current_key {
797                        let named_key_addr = match NamedKeyAddr::new_from_string(addr, name.clone())
798                        {
799                            Ok(named_key_addr) => Key::NamedKey(named_key_addr),
800                            Err(error) => {
801                                let msg_prefix = format!("{}", error);
802                                return Ok(query.into_not_found_result(&msg_prefix));
803                            }
804                        };
805                        query.navigate_for_named_key(named_key_addr);
806                    } else {
807                        let msg_prefix = "Invalid base key".to_string();
808                        return Ok(query.into_not_found_result(&msg_prefix));
809                    }
810                }
811                StoredValue::ContractWasm(_) => {
812                    return Ok(query.into_not_found_result("ContractWasm value found."));
813                }
814                StoredValue::ContractPackage(_) => {
815                    return Ok(query.into_not_found_result("ContractPackage value found."));
816                }
817                StoredValue::SmartContract(_) => {
818                    return Ok(query.into_not_found_result("Package value found."));
819                }
820                StoredValue::ByteCode(_) => {
821                    return Ok(query.into_not_found_result("ByteCode value found."));
822                }
823                StoredValue::Transfer(_) => {
824                    return Ok(query.into_not_found_result("Legacy Transfer value found."));
825                }
826                StoredValue::DeployInfo(_) => {
827                    return Ok(query.into_not_found_result("DeployInfo value found."));
828                }
829                StoredValue::EraInfo(_) => {
830                    return Ok(query.into_not_found_result("EraInfo value found."));
831                }
832                StoredValue::Bid(_) => {
833                    return Ok(query.into_not_found_result("Bid value found."));
834                }
835                StoredValue::BidKind(_) => {
836                    return Ok(query.into_not_found_result("BidKind value found."));
837                }
838                StoredValue::Withdraw(_) => {
839                    return Ok(query.into_not_found_result("WithdrawPurses value found."));
840                }
841                StoredValue::Unbonding(_) => {
842                    return Ok(query.into_not_found_result("Unbonding value found."));
843                }
844                StoredValue::MessageTopic(_) => {
845                    return Ok(query.into_not_found_result("MessageTopic value found."));
846                }
847                StoredValue::Message(_) => {
848                    return Ok(query.into_not_found_result("Message value found."));
849                }
850                StoredValue::EntryPoint(_) => {
851                    return Ok(query.into_not_found_result("EntryPoint value found."));
852                }
853                StoredValue::Prepayment(_) => {
854                    return Ok(query.into_not_found_result("Prepayment value found."))
855                }
856                StoredValue::RawBytes(_) => {
857                    return Ok(query.into_not_found_result("RawBytes value found."));
858                }
859            }
860        }
861    }
862}
863
864/// The purpose of this implementation is to allow a "snapshot" mechanism for
865/// TrackingCopy. The state of a TrackingCopy (including the effects of
866/// any transforms it has accumulated) can be read using an immutable
867/// reference to that TrackingCopy via this trait implementation. See
868/// `TrackingCopy::fork` for more information.
869impl<R: StateReader<Key, StoredValue>> StateReader<Key, StoredValue> for &TrackingCopy<R> {
870    type Error = R::Error;
871
872    fn read(&self, key: &Key) -> Result<Option<StoredValue>, Self::Error> {
873        let kb = KeyWithByteRepr::new(*key);
874        if let Some(value) = self.cache.muts_cached.get(&kb) {
875            return Ok(Some(value.to_owned()));
876        }
877        if let Some(value) = self.reader.read(key)? {
878            Ok(Some(value))
879        } else {
880            Ok(None)
881        }
882    }
883
884    fn read_with_proof(
885        &self,
886        key: &Key,
887    ) -> Result<Option<TrieMerkleProof<Key, StoredValue>>, Self::Error> {
888        self.reader.read_with_proof(key)
889    }
890
891    fn keys_with_prefix(&self, byte_prefix: &[u8]) -> Result<Vec<Key>, Self::Error> {
892        let keys = self.reader.keys_with_prefix(byte_prefix)?;
893
894        let ret = keys
895            .into_iter()
896            // don't include keys marked for pruning
897            .filter(|key| !self.cache.is_pruned(key))
898            // there may be newly inserted keys which have not been committed yet
899            .chain(self.cache.get_muts_cached_by_byte_prefix(byte_prefix))
900            .collect();
901        Ok(ret)
902    }
903}
904
905/// Error conditions of a proof validation.
906#[derive(Error, Debug, PartialEq, Eq)]
907pub enum ValidationError {
908    /// The path should not have a different length than the proof less one.
909    #[error("The path should not have a different length than the proof less one.")]
910    PathLengthDifferentThanProofLessOne,
911
912    /// The provided key does not match the key in the proof.
913    #[error("The provided key does not match the key in the proof.")]
914    UnexpectedKey,
915
916    /// The provided value does not match the value in the proof.
917    #[error("The provided value does not match the value in the proof.")]
918    UnexpectedValue,
919
920    /// The proof hash is invalid.
921    #[error("The proof hash is invalid.")]
922    InvalidProofHash,
923
924    /// The path went cold.
925    #[error("The path went cold.")]
926    PathCold,
927
928    /// (De)serialization error.
929    #[error("Serialization error: {0}")]
930    BytesRepr(bytesrepr::Error),
931
932    /// Key is not a URef.
933    #[error("Key is not a URef")]
934    KeyIsNotAURef(Key),
935
936    /// Error converting a stored value to a [`Key`].
937    #[error("Failed to convert stored value to key")]
938    ValueToCLValueConversion,
939
940    /// CLValue conversion error.
941    #[error("{0}")]
942    CLValueError(CLValueError),
943}
944
945impl From<CLValueError> for ValidationError {
946    fn from(err: CLValueError) -> Self {
947        ValidationError::CLValueError(err)
948    }
949}
950
951impl From<bytesrepr::Error> for ValidationError {
952    fn from(error: bytesrepr::Error) -> Self {
953        Self::BytesRepr(error)
954    }
955}
956
957/// Validates proof of the query.
958///
959/// Returns [`ValidationError`] for any of
960pub fn validate_query_proof(
961    hash: &Digest,
962    proofs: &[TrieMerkleProof<Key, StoredValue>],
963    expected_first_key: &Key,
964    path: &[String],
965    expected_value: &StoredValue,
966) -> Result<(), ValidationError> {
967    if proofs.len() != path.len() + 1 {
968        return Err(ValidationError::PathLengthDifferentThanProofLessOne);
969    }
970
971    let mut proofs_iter = proofs.iter();
972    let mut path_components_iter = path.iter();
973
974    // length check above means we are safe to unwrap here
975    let first_proof = proofs_iter.next().unwrap();
976
977    if first_proof.key() != &expected_first_key.normalize() {
978        return Err(ValidationError::UnexpectedKey);
979    }
980
981    if hash != &compute_state_hash(first_proof)? {
982        return Err(ValidationError::InvalidProofHash);
983    }
984
985    let mut proof_value = first_proof.value();
986
987    for proof in proofs_iter {
988        let named_keys = match proof_value {
989            StoredValue::Account(account) => account.named_keys(),
990            StoredValue::Contract(contract) => contract.named_keys(),
991            _ => return Err(ValidationError::PathCold),
992        };
993
994        let path_component = match path_components_iter.next() {
995            Some(path_component) => path_component,
996            None => return Err(ValidationError::PathCold),
997        };
998
999        let key = match named_keys.get(path_component) {
1000            Some(key) => key,
1001            None => return Err(ValidationError::PathCold),
1002        };
1003
1004        if proof.key() != &key.normalize() {
1005            return Err(ValidationError::UnexpectedKey);
1006        }
1007
1008        if hash != &compute_state_hash(proof)? {
1009            return Err(ValidationError::InvalidProofHash);
1010        }
1011
1012        proof_value = proof.value();
1013    }
1014
1015    if proof_value != expected_value {
1016        return Err(ValidationError::UnexpectedValue);
1017    }
1018
1019    Ok(())
1020}
1021
1022/// Validates proof of the query.
1023///
1024/// Returns [`ValidationError`] for any of
1025pub fn validate_query_merkle_proof(
1026    hash: &Digest,
1027    proofs: &[TrieMerkleProof<Key, StoredValue>],
1028    expected_key_trace: &[Key],
1029    expected_value: &StoredValue,
1030) -> Result<(), ValidationError> {
1031    let expected_len = expected_key_trace.len();
1032    if proofs.len() != expected_len {
1033        return Err(ValidationError::PathLengthDifferentThanProofLessOne);
1034    }
1035
1036    let proof_keys: Vec<Key> = proofs.iter().map(|proof| *proof.key()).collect();
1037
1038    if !expected_key_trace.eq(&proof_keys) {
1039        return Err(ValidationError::UnexpectedKey);
1040    }
1041
1042    if expected_value != proofs[expected_len - 1].value() {
1043        return Err(ValidationError::UnexpectedValue);
1044    }
1045
1046    let mut proofs_iter = proofs.iter();
1047
1048    // length check above means we are safe to unwrap here
1049    let first_proof = proofs_iter.next().unwrap();
1050
1051    if hash != &compute_state_hash(first_proof)? {
1052        return Err(ValidationError::InvalidProofHash);
1053    }
1054
1055    Ok(())
1056}
1057
1058/// Validates a proof of a balance request.
1059pub fn validate_balance_proof(
1060    hash: &Digest,
1061    balance_proof: &TrieMerkleProof<Key, StoredValue>,
1062    expected_purse_key: Key,
1063    expected_motes: &U512,
1064) -> Result<(), ValidationError> {
1065    let expected_balance_key = expected_purse_key
1066        .into_uref()
1067        .map(|uref| Key::Balance(uref.addr()))
1068        .ok_or_else(|| ValidationError::KeyIsNotAURef(expected_purse_key.to_owned()))?;
1069
1070    if balance_proof.key() != &expected_balance_key.normalize() {
1071        return Err(ValidationError::UnexpectedKey);
1072    }
1073
1074    if hash != &compute_state_hash(balance_proof)? {
1075        return Err(ValidationError::InvalidProofHash);
1076    }
1077
1078    let balance_proof_stored_value = balance_proof.value().to_owned();
1079
1080    let balance_proof_clvalue: CLValue = balance_proof_stored_value
1081        .try_into()
1082        .map_err(|_| ValidationError::ValueToCLValueConversion)?;
1083
1084    let balance_motes: U512 = balance_proof_clvalue.into_t()?;
1085
1086    if expected_motes != &balance_motes {
1087        return Err(ValidationError::UnexpectedValue);
1088    }
1089
1090    Ok(())
1091}
1092
1093use crate::global_state::{
1094    error::Error,
1095    state::{
1096        lmdb::{make_temporary_global_state, LmdbGlobalStateView},
1097        StateProvider,
1098    },
1099};
1100use tempfile::TempDir;
1101
1102/// Creates a temp global state with initial state and checks out a tracking copy on it.
1103pub fn new_temporary_tracking_copy(
1104    initial_data: impl IntoIterator<Item = (Key, StoredValue)>,
1105    max_query_depth: Option<u64>,
1106    enable_addressable_entity: bool,
1107) -> (TrackingCopy<LmdbGlobalStateView>, TempDir) {
1108    let (global_state, state_root_hash, tempdir) = make_temporary_global_state(initial_data);
1109
1110    let reader = global_state
1111        .checkout(state_root_hash)
1112        .expect("Checkout should not throw errors.")
1113        .expect("Root hash should exist.");
1114
1115    let query_depth = max_query_depth.unwrap_or(DEFAULT_MAX_QUERY_DEPTH);
1116
1117    (
1118        TrackingCopy::new(reader, query_depth, enable_addressable_entity),
1119        tempdir,
1120    )
1121}