casper_storage/tracking_copy/
mod.rs

//! This module defines the `TrackingCopy`, a utility that caches operations on global state so
//! that the underlying state remains unmodified while callers can still interact with it as if
//! the modifications had been applied.
4mod byte_size;
5mod error;
6mod ext;
7mod ext_entity;
8mod meter;
9#[cfg(test)]
10mod tests;
11
12use std::{
13    borrow::Borrow,
14    collections::{BTreeMap, BTreeSet, HashSet, VecDeque},
15    convert::{From, TryInto},
16    fmt::Debug,
17    sync::Arc,
18};
19
20use linked_hash_map::LinkedHashMap;
21use thiserror::Error;
22use tracing::error;
23
24use crate::{
25    global_state::{
26        error::Error as GlobalStateError, state::StateReader,
27        trie_store::operations::compute_state_hash, DEFAULT_MAX_QUERY_DEPTH,
28    },
29    KeyPrefix,
30};
31use casper_types::{
32    addressable_entity::NamedKeyAddr,
33    bytesrepr::{self, ToBytes},
34    contract_messages::{Message, Messages},
35    contracts::NamedKeys,
36    execution::{Effects, TransformError, TransformInstruction, TransformKindV2, TransformV2},
37    global_state::TrieMerkleProof,
38    handle_stored_dictionary_value, BlockGlobalAddr, CLType, CLValue, CLValueError, Digest, Key,
39    KeyTag, StoredValue, StoredValueTypeMismatch, U512,
40};
41
42use self::meter::{heap_meter::HeapSize, Meter};
43pub use self::{
44    error::Error as TrackingCopyError,
45    ext::TrackingCopyExt,
46    ext_entity::{FeesPurseHandling, TrackingCopyEntityExt},
47};
48
49/// Result of a query on a `TrackingCopy`.
50#[derive(Debug)]
51#[allow(clippy::large_enum_variant)]
52pub enum TrackingCopyQueryResult {
53    /// Invalid state root hash.
54    RootNotFound,
55    /// The value wasn't found.
56    ValueNotFound(String),
57    /// A circular reference was found in the state while traversing it.
58    CircularReference(String),
59    /// The query reached the depth limit.
60    DepthLimit {
61        /// The depth reached.
62        depth: u64,
63    },
64    /// The query was successful.
65    Success {
66        /// The value read from the state.
67        value: StoredValue,
68        /// Merkle proofs for the value.
69        proofs: Vec<TrieMerkleProof<Key, StoredValue>>,
70    },
71}
72
73impl TrackingCopyQueryResult {
74    /// Is this a successful query?
75    pub fn is_success(&self) -> bool {
76        matches!(self, TrackingCopyQueryResult::Success { .. })
77    }
78
79    /// As result.
80    pub fn into_result(self) -> Result<StoredValue, TrackingCopyError> {
81        match self {
82            TrackingCopyQueryResult::RootNotFound => {
83                Err(TrackingCopyError::Storage(Error::RootNotFound))
84            }
85            TrackingCopyQueryResult::ValueNotFound(msg) => {
86                Err(TrackingCopyError::ValueNotFound(msg))
87            }
88            TrackingCopyQueryResult::CircularReference(msg) => {
89                Err(TrackingCopyError::CircularReference(msg))
90            }
91            TrackingCopyQueryResult::DepthLimit { depth } => {
92                Err(TrackingCopyError::QueryDepthLimit { depth })
93            }
94            TrackingCopyQueryResult::Success { value, .. } => Ok(value),
95        }
96    }
97}
98
99/// Struct containing state relating to a given query.
100struct Query {
101    /// The key from where the search starts.
102    base_key: Key,
103    /// A collection of normalized keys which have been visited during the search.
104    visited_keys: HashSet<Key>,
105    /// The key currently being processed.
106    current_key: Key,
107    /// Path components which have not yet been followed, held in the same order in which they were
108    /// provided to the `query()` call.
109    unvisited_names: VecDeque<String>,
110    /// Path components which have been followed, held in the same order in which they were
111    /// provided to the `query()` call.
112    visited_names: Vec<String>,
113    /// Current depth of the query.
114    depth: u64,
115}
116
117impl Query {
118    fn new(base_key: Key, path: &[String]) -> Self {
119        Query {
120            base_key,
121            current_key: base_key.normalize(),
122            unvisited_names: path.iter().cloned().collect(),
123            visited_names: Vec::new(),
124            visited_keys: HashSet::new(),
125            depth: 0,
126        }
127    }
128
129    /// Panics if `unvisited_names` is empty.
130    fn next_name(&mut self) -> Option<&String> {
131        let next_name = self.unvisited_names.pop_front()?;
132        self.visited_names.push(next_name);
133        self.visited_names.last()
134    }
135
136    fn navigate(&mut self, key: Key) {
137        self.current_key = key.normalize();
138        self.depth += 1;
139    }
140
141    fn navigate_for_named_key(&mut self, named_key: Key) {
142        if let Key::NamedKey(_) = &named_key {
143            self.current_key = named_key.normalize();
144        }
145    }
146
147    fn into_not_found_result(self, msg_prefix: &str) -> TrackingCopyQueryResult {
148        let msg = format!("{} at path: {}", msg_prefix, self.current_path());
149        TrackingCopyQueryResult::ValueNotFound(msg)
150    }
151
152    fn into_circular_ref_result(self) -> TrackingCopyQueryResult {
153        let msg = format!(
154            "{:?} has formed a circular reference at path: {}",
155            self.current_key,
156            self.current_path()
157        );
158        TrackingCopyQueryResult::CircularReference(msg)
159    }
160
161    fn into_depth_limit_result(self) -> TrackingCopyQueryResult {
162        TrackingCopyQueryResult::DepthLimit { depth: self.depth }
163    }
164
165    fn current_path(&self) -> String {
166        let mut path = format!("{:?}", self.base_key);
167        for name in &self.visited_names {
168            path.push('/');
169            path.push_str(name);
170        }
171        path
172    }
173}
174
175/// Keeps track of already accessed keys.
176/// We deliberately separate cached Reads from cached mutations
177/// because we want to invalidate Reads' cache so it doesn't grow too fast.
178#[derive(Clone, Debug)]
179pub struct GenericTrackingCopyCache<M: Copy + Debug> {
180    max_cache_size: usize,
181    current_cache_size: usize,
182    reads_cached: LinkedHashMap<Key, StoredValue>,
183    muts_cached: BTreeMap<KeyWithByteRepr, StoredValue>,
184    prunes_cached: BTreeSet<Key>,
185    meter: M,
186}
187
188impl<M: Meter<Key, StoredValue> + Copy + Default> GenericTrackingCopyCache<M> {
189    /// Creates instance of `TrackingCopyCache` with specified `max_cache_size`,
190    /// above which least-recently-used elements of the cache are invalidated.
191    /// Measurements of elements' "size" is done with the usage of `Meter`
192    /// instance.
193    pub fn new(max_cache_size: usize, meter: M) -> GenericTrackingCopyCache<M> {
194        GenericTrackingCopyCache {
195            max_cache_size,
196            current_cache_size: 0,
197            reads_cached: LinkedHashMap::new(),
198            muts_cached: BTreeMap::new(),
199            prunes_cached: BTreeSet::new(),
200            meter,
201        }
202    }
203
204    /// Creates instance of `TrackingCopyCache` with specified `max_cache_size`, above which
205    /// least-recently-used elements of the cache are invalidated. Measurements of elements' "size"
206    /// is done with the usage of default `Meter` instance.
207    pub fn new_default(max_cache_size: usize) -> GenericTrackingCopyCache<M> {
208        GenericTrackingCopyCache::new(max_cache_size, M::default())
209    }
210
211    /// Inserts `key` and `value` pair to Read cache.
212    pub fn insert_read(&mut self, key: Key, value: StoredValue) {
213        let element_size = Meter::measure(&self.meter, &key, &value);
214        self.reads_cached.insert(key, value);
215        self.current_cache_size += element_size;
216        while self.current_cache_size > self.max_cache_size {
217            match self.reads_cached.pop_front() {
218                Some((k, v)) => {
219                    let element_size = Meter::measure(&self.meter, &k, &v);
220                    self.current_cache_size -= element_size;
221                }
222                None => break,
223            }
224        }
225    }
226
227    /// Inserts `key` and `value` pair to Write/Add cache.
228    pub fn insert_write(&mut self, key: Key, value: StoredValue) {
229        let kb = KeyWithByteRepr::new(key);
230        self.prunes_cached.remove(&key);
231        self.muts_cached.insert(kb, value);
232    }
233
234    /// Inserts `key` and `value` pair to Write/Add cache.
235    pub fn insert_prune(&mut self, key: Key) {
236        self.prunes_cached.insert(key);
237    }
238
239    /// Gets value from `key` in the cache.
240    pub fn get(&mut self, key: &Key) -> Option<&StoredValue> {
241        if self.prunes_cached.contains(key) {
242            // the item is marked for pruning and therefore
243            // is no longer accessible.
244            return None;
245        }
246        let kb = KeyWithByteRepr::new(*key);
247        if let Some(value) = self.muts_cached.get(&kb) {
248            return Some(value);
249        };
250
251        self.reads_cached.get_refresh(key).map(|v| &*v)
252    }
253
254    /// Get cached items by prefix.
255    fn get_muts_cached_by_byte_prefix(&self, prefix: &[u8]) -> Vec<Key> {
256        self.muts_cached
257            .range(prefix.to_vec()..)
258            .take_while(|(key, _)| key.starts_with(prefix))
259            .map(|(key, _)| key.to_key())
260            .collect()
261    }
262
263    /// Does the prune cache contain key.
264    pub fn is_pruned(&self, key: &Key) -> bool {
265        self.prunes_cached.contains(key)
266    }
267
268    pub(self) fn into_muts(self) -> (BTreeMap<KeyWithByteRepr, StoredValue>, BTreeSet<Key>) {
269        (self.muts_cached, self.prunes_cached)
270    }
271}
272
273/// A helper type for `TrackingCopyCache` that allows convenient storage and access
274/// to keys as bytes.
275/// Its equality and ordering is based on the byte representation of the key.
276#[derive(Debug, Clone)]
277struct KeyWithByteRepr(Key, Vec<u8>);
278
279impl KeyWithByteRepr {
280    #[inline]
281    fn new(key: Key) -> Self {
282        let bytes = key.to_bytes().expect("should always serialize a Key");
283        KeyWithByteRepr(key, bytes)
284    }
285
286    #[inline]
287    fn starts_with(&self, prefix: &[u8]) -> bool {
288        self.1.starts_with(prefix)
289    }
290
291    #[inline]
292    fn to_key(&self) -> Key {
293        self.0
294    }
295}
296
297impl Borrow<Vec<u8>> for KeyWithByteRepr {
298    #[inline]
299    fn borrow(&self) -> &Vec<u8> {
300        &self.1
301    }
302}
303
304impl PartialEq for KeyWithByteRepr {
305    #[inline]
306    fn eq(&self, other: &Self) -> bool {
307        self.1 == other.1
308    }
309}
310
311impl Eq for KeyWithByteRepr {}
312
313impl PartialOrd for KeyWithByteRepr {
314    #[inline]
315    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
316        Some(self.cmp(other))
317    }
318}
319
320impl Ord for KeyWithByteRepr {
321    #[inline]
322    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
323        self.1.cmp(&other.1)
324    }
325}
326
327/// An alias for a `TrackingCopyCache` with `HeapSize` as the meter.
328pub type TrackingCopyCache = GenericTrackingCopyCache<HeapSize>;
329
330/// An interface for the global state that caches all operations (reads and writes) instead of
331/// applying them directly to the state. This way the state remains unmodified, while the user can
332/// interact with it as if it was being modified in real time.
333#[derive(Clone)]
334pub struct TrackingCopy<R> {
335    reader: Arc<R>,
336    cache: TrackingCopyCache,
337    effects: Effects,
338    max_query_depth: u64,
339    messages: Messages,
340    enable_addressable_entity: bool,
341}
342
343/// Result of executing an "add" operation on a value in the state.
344#[derive(Debug)]
345pub enum AddResult {
346    /// The operation was successful.
347    Success,
348    /// The key was not found.
349    KeyNotFound(Key),
350    /// There was a type mismatch between the stored value and the value being added.
351    TypeMismatch(StoredValueTypeMismatch),
352    /// Serialization error.
353    Serialization(bytesrepr::Error),
354    /// Transform error.
355    Transform(TransformError),
356}
357
358impl From<CLValueError> for AddResult {
359    fn from(error: CLValueError) -> Self {
360        match error {
361            CLValueError::Serialization(error) => AddResult::Serialization(error),
362            CLValueError::Type(type_mismatch) => {
363                let expected = format!("{:?}", type_mismatch.expected);
364                let found = format!("{:?}", type_mismatch.found);
365                AddResult::TypeMismatch(StoredValueTypeMismatch::new(expected, found))
366            }
367        }
368    }
369}
370
371/// A helper type for `TrackingCopy` that represents a key-value pair.
372pub type TrackingCopyParts = (TrackingCopyCache, Effects, Messages);
373
374impl<R: StateReader<Key, StoredValue>> TrackingCopy<R>
375where
376    R: StateReader<Key, StoredValue, Error = GlobalStateError>,
377{
378    /// Creates a new `TrackingCopy` using the `reader` as the interface to the state.
379    pub fn new(
380        reader: R,
381        max_query_depth: u64,
382        enable_addressable_entity: bool,
383    ) -> TrackingCopy<R> {
384        TrackingCopy {
385            reader: Arc::new(reader),
386            // TODO: Should `max_cache_size` be a fraction of wasm memory limit?
387            cache: GenericTrackingCopyCache::new(1024 * 16, HeapSize),
388            effects: Effects::new(),
389            max_query_depth,
390            messages: Vec::new(),
391            enable_addressable_entity,
392        }
393    }
394
395    /// Returns the `reader` used to access the state.
396    pub fn reader(&self) -> &R {
397        &self.reader
398    }
399
400    /// Returns a shared reference to the `reader` used to access the state.
401    pub fn shared_reader(&self) -> Arc<R> {
402        Arc::clone(&self.reader)
403    }
404
405    /// Creates a new `TrackingCopy` using the `reader` as the interface to the state.
406    /// Returns a new `TrackingCopy` instance that is a snapshot of the current state, allowing
407    /// further changes to be made.
408    ///
409    /// This method creates a new `TrackingCopy` using the current instance (including its
410    /// mutations) as the base state to read against. Mutations made to the new `TrackingCopy`
411    /// will not impact the original instance.
412    ///
413    /// Note: Currently, there is no `join` or `merge` function to bring changes from a fork back to
414    /// the main `TrackingCopy`. Therefore, forking should be done repeatedly, which is
415    /// suboptimal and will be improved in the future.
416    pub fn fork(&self) -> TrackingCopy<&TrackingCopy<R>> {
417        TrackingCopy::new(self, self.max_query_depth, self.enable_addressable_entity)
418    }
419
420    /// Returns a new `TrackingCopy` instance that is a snapshot of the current state, allowing
421    /// further changes to be made.
422    ///
423    /// This method creates a new `TrackingCopy` using the current instance (including its
424    /// mutations) as the base state to read against. Mutations made to the new `TrackingCopy`
425    /// will not impact the original instance.
426    ///
427    /// Note: Currently, there is no `join` or `merge` function to bring changes from a fork back to
428    /// the main `TrackingCopy`. This method is an alternative to the `fork` method and is
429    /// provided for clarity and consistency in naming.
430    pub fn fork2(&self) -> Self {
431        TrackingCopy {
432            reader: Arc::clone(&self.reader),
433            cache: self.cache.clone(),
434            effects: self.effects.clone(),
435            max_query_depth: self.max_query_depth,
436            messages: self.messages.clone(),
437            enable_addressable_entity: self.enable_addressable_entity,
438        }
439    }
440
441    /// Applies the changes to the state.
442    ///
443    /// This is a low-level function that should be used only by the execution engine. The purpose
444    /// of this function is to apply the changes to the state from a forked tracking copy. Once
445    /// caller decides that the changes are valid, they can be applied to the state and the
446    /// processing can resume.
447    pub fn apply_changes(
448        &mut self,
449        effects: Effects,
450        cache: TrackingCopyCache,
451        messages: Messages,
452    ) {
453        self.effects = effects;
454        self.cache = cache;
455        self.messages = messages;
456    }
457
458    /// Returns a copy of the execution effects cached by this instance.
459    pub fn effects(&self) -> Effects {
460        self.effects.clone()
461    }
462
463    /// Returns copy of cache.
464    pub fn cache(&self) -> TrackingCopyCache {
465        self.cache.clone()
466    }
467
468    /// Destructure cached entries.
469    pub fn destructure(self) -> (Vec<(Key, StoredValue)>, BTreeSet<Key>, Effects) {
470        let (writes, prunes) = self.cache.into_muts();
471        let writes: Vec<(Key, StoredValue)> = writes.into_iter().map(|(k, v)| (k.0, v)).collect();
472
473        (writes, prunes, self.effects)
474    }
475
476    /// Enable the addressable entity and migrate accounts/contracts to entities.
477    pub fn enable_addressable_entity(&self) -> bool {
478        self.enable_addressable_entity
479    }
480
481    /// Get record by key.
482    pub fn get(&mut self, key: &Key) -> Result<Option<StoredValue>, TrackingCopyError> {
483        if let Some(value) = self.cache.get(key) {
484            return Ok(Some(value.to_owned()));
485        }
486        match self.reader.read(key) {
487            Ok(ret) => {
488                if let Some(value) = ret {
489                    self.cache.insert_read(*key, value.to_owned());
490                    Ok(Some(value))
491                } else {
492                    Ok(None)
493                }
494            }
495            Err(err) => Err(TrackingCopyError::Storage(err)),
496        }
497    }
498
499    /// Gets the set of keys in the state whose tag is `key_tag`.
500    pub fn get_keys(&self, key_tag: &KeyTag) -> Result<BTreeSet<Key>, TrackingCopyError> {
501        self.get_by_byte_prefix(&[*key_tag as u8])
502    }
503
504    /// Get keys by prefix.
505    pub fn get_keys_by_prefix(
506        &self,
507        key_prefix: &KeyPrefix,
508    ) -> Result<BTreeSet<Key>, TrackingCopyError> {
509        let byte_prefix = key_prefix
510            .to_bytes()
511            .map_err(TrackingCopyError::BytesRepr)?;
512        self.get_by_byte_prefix(&byte_prefix)
513    }
514
515    /// Gets the set of keys in the state by a byte prefix.
516    pub(crate) fn get_by_byte_prefix(
517        &self,
518        byte_prefix: &[u8],
519    ) -> Result<BTreeSet<Key>, TrackingCopyError> {
520        let ret = self.keys_with_prefix(byte_prefix)?.into_iter().collect();
521        Ok(ret)
522    }
523
524    /// Reads the value stored under `key`.
525    pub fn read(&mut self, key: &Key) -> Result<Option<StoredValue>, TrackingCopyError> {
526        let normalized_key = key.normalize();
527        if let Some(value) = self.get(&normalized_key)? {
528            self.effects
529                .push(TransformV2::new(normalized_key, TransformKindV2::Identity));
530            Ok(Some(value))
531        } else {
532            Ok(None)
533        }
534    }
535
536    /// Reads the first value stored under the keys in `keys`.
537    pub fn read_first(&mut self, keys: &[&Key]) -> Result<Option<StoredValue>, TrackingCopyError> {
538        for key in keys {
539            if let Some(value) = self.read(key)? {
540                return Ok(Some(value));
541            }
542        }
543        Ok(None)
544    }
545
546    /// Writes `value` under `key`. Note that the written value is only cached.
547    pub fn write(&mut self, key: Key, value: StoredValue) {
548        let normalized_key = key.normalize();
549        self.cache.insert_write(normalized_key, value.clone());
550        let transform = TransformV2::new(normalized_key, TransformKindV2::Write(value));
551        self.effects.push(transform);
552    }
553
554    /// Caches the emitted message and writes the message topic summary under the specified key.
555    ///
556    /// This function does not check the types for the key and the value so the caller should
557    /// correctly set the type. The `message_topic_key` should be of the `Key::MessageTopic`
558    /// variant and the `message_topic_summary` should be of the `StoredValue::Message` variant.
559    #[allow(clippy::too_many_arguments)]
560    pub fn emit_message(
561        &mut self,
562        message_topic_key: Key,
563        message_topic_summary: StoredValue,
564        message_key: Key,
565        message_value: StoredValue,
566        block_message_count_value: StoredValue,
567        message: Message,
568    ) {
569        self.write(message_key, message_value);
570        self.write(message_topic_key, message_topic_summary);
571        self.write(
572            Key::BlockGlobal(BlockGlobalAddr::MessageCount),
573            block_message_count_value,
574        );
575        self.messages.push(message);
576    }
577
578    /// Prunes a `key`.
579    pub fn prune(&mut self, key: Key) {
580        let normalized_key = key.normalize();
581        self.cache.insert_prune(normalized_key);
582        self.effects.push(TransformV2::new(
583            normalized_key,
584            TransformKindV2::Prune(key),
585        ));
586    }
587
588    /// Ok(None) represents missing key to which we want to "add" some value.
589    /// Ok(Some(unit)) represents successful operation.
590    /// Err(error) is reserved for unexpected errors when accessing global
591    /// state.
592    pub fn add(&mut self, key: Key, value: StoredValue) -> Result<AddResult, TrackingCopyError> {
593        let normalized_key = key.normalize();
594        let current_value = match self.get(&normalized_key)? {
595            None => return Ok(AddResult::KeyNotFound(normalized_key)),
596            Some(current_value) => current_value,
597        };
598
599        let type_name = value.type_name();
600        let mismatch = || {
601            Ok(AddResult::TypeMismatch(StoredValueTypeMismatch::new(
602                "I32, U64, U128, U256, U512 or (String, Key) tuple".to_string(),
603                type_name,
604            )))
605        };
606
607        let transform_kind = match value {
608            StoredValue::CLValue(cl_value) => match *cl_value.cl_type() {
609                CLType::I32 => match cl_value.into_t() {
610                    Ok(value) => TransformKindV2::AddInt32(value),
611                    Err(error) => return Ok(AddResult::from(error)),
612                },
613                CLType::U64 => match cl_value.into_t() {
614                    Ok(value) => TransformKindV2::AddUInt64(value),
615                    Err(error) => return Ok(AddResult::from(error)),
616                },
617                CLType::U128 => match cl_value.into_t() {
618                    Ok(value) => TransformKindV2::AddUInt128(value),
619                    Err(error) => return Ok(AddResult::from(error)),
620                },
621                CLType::U256 => match cl_value.into_t() {
622                    Ok(value) => TransformKindV2::AddUInt256(value),
623                    Err(error) => return Ok(AddResult::from(error)),
624                },
625                CLType::U512 => match cl_value.into_t() {
626                    Ok(value) => TransformKindV2::AddUInt512(value),
627                    Err(error) => return Ok(AddResult::from(error)),
628                },
629                _ => {
630                    if *cl_value.cl_type() == casper_types::named_key_type() {
631                        match cl_value.into_t() {
632                            Ok((name, key)) => {
633                                let mut named_keys = NamedKeys::new();
634                                named_keys.insert(name, key);
635                                TransformKindV2::AddKeys(named_keys)
636                            }
637                            Err(error) => return Ok(AddResult::from(error)),
638                        }
639                    } else {
640                        return mismatch();
641                    }
642                }
643            },
644            _ => return mismatch(),
645        };
646
647        match transform_kind.clone().apply(current_value) {
648            Ok(TransformInstruction::Store(new_value)) => {
649                self.cache.insert_write(normalized_key, new_value);
650                self.effects
651                    .push(TransformV2::new(normalized_key, transform_kind));
652                Ok(AddResult::Success)
653            }
654            Ok(TransformInstruction::Prune(key)) => {
655                self.cache.insert_prune(normalized_key);
656                self.effects.push(TransformV2::new(
657                    normalized_key,
658                    TransformKindV2::Prune(key),
659                ));
660                Ok(AddResult::Success)
661            }
662            Err(TransformError::TypeMismatch(type_mismatch)) => {
663                Ok(AddResult::TypeMismatch(type_mismatch))
664            }
665            Err(TransformError::Serialization(error)) => Ok(AddResult::Serialization(error)),
666            Err(transform_error) => Ok(AddResult::Transform(transform_error)),
667        }
668    }
669
670    /// Returns a copy of the messages cached by this instance.
671    pub fn messages(&self) -> Messages {
672        self.messages.clone()
673    }
674
675    /// Calling `query()` avoids calling into `self.cache`, so this will not return any values
676    /// written or mutated in this `TrackingCopy` via previous calls to `write()` or `add()`, since
677    /// these updates are only held in `self.cache`.
678    ///
679    /// The intent is that `query()` is only used to satisfy `QueryRequest`s made to the server.
680    /// Other EE internal use cases should call `read()` or `get()` in order to retrieve cached
681    /// values.
682    pub fn query(
683        &self,
684        base_key: Key,
685        path: &[String],
686    ) -> Result<TrackingCopyQueryResult, TrackingCopyError> {
687        let mut query = Query::new(base_key, path);
688
689        let mut proofs = Vec::new();
690
691        loop {
692            if query.depth >= self.max_query_depth {
693                return Ok(query.into_depth_limit_result());
694            }
695
696            if !query.visited_keys.insert(query.current_key) {
697                return Ok(query.into_circular_ref_result());
698            }
699
700            let stored_value = match self.reader.read_with_proof(&query.current_key)? {
701                None => {
702                    return Ok(query.into_not_found_result("Failed to find base key"));
703                }
704                Some(stored_value) => stored_value,
705            };
706
707            let value = stored_value.value().to_owned();
708
709            // Following code does a patching on the `StoredValue` to get an inner
710            // `DictionaryValue` for dictionaries only.
711            let value = match handle_stored_dictionary_value(query.current_key, value) {
712                Ok(patched_stored_value) => patched_stored_value,
713                Err(error) => {
714                    return Ok(query.into_not_found_result(&format!(
715                        "Failed to retrieve dictionary value: {}",
716                        error
717                    )));
718                }
719            };
720
721            proofs.push(stored_value);
722
723            if query.unvisited_names.is_empty() && !query.current_key.is_named_key() {
724                return Ok(TrackingCopyQueryResult::Success { value, proofs });
725            }
726
727            let stored_value: &StoredValue = proofs
728                .last()
729                .map(|r| r.value())
730                .expect("but we just pushed");
731
732            match stored_value {
733                StoredValue::Account(account) => {
734                    let mut maybe_msg_prefix: Option<String> = None;
735                    if let Some(name) = query.next_name() {
736                        if let Some(key) = account.named_keys().get(name) {
737                            query.navigate(*key);
738                        } else {
739                            maybe_msg_prefix = Some(format!("Name {} not found in Account", name));
740                        }
741                    } else {
742                        maybe_msg_prefix = Some("All names visited".to_string());
743                    }
744                    if let Some(msg_prefix) = maybe_msg_prefix {
745                        return Ok(query.into_not_found_result(&msg_prefix));
746                    }
747                }
748                StoredValue::Contract(contract) => {
749                    let mut maybe_msg_prefix: Option<String> = None;
750                    if let Some(name) = query.next_name() {
751                        if let Some(key) = contract.named_keys().get(name) {
752                            query.navigate(*key);
753                        } else {
754                            maybe_msg_prefix = Some(format!("Name {} not found in Contract", name));
755                        }
756                    } else {
757                        maybe_msg_prefix = Some("All names visited".to_string());
758                    }
759                    if let Some(msg_prefix) = maybe_msg_prefix {
760                        return Ok(query.into_not_found_result(&msg_prefix));
761                    }
762                }
763                StoredValue::AddressableEntity(_) => {
764                    let current_key = query.current_key;
765                    let mut maybe_msg_prefix: Option<String> = None;
766                    if let Some(name) = query.next_name() {
767                        if let Key::AddressableEntity(addr) = current_key {
768                            let named_key_addr =
769                                match NamedKeyAddr::new_from_string(addr, name.clone()) {
770                                    Ok(named_key_addr) => Key::NamedKey(named_key_addr),
771                                    Err(error) => {
772                                        let msg_prefix = format!("{}", error);
773                                        return Ok(query.into_not_found_result(&msg_prefix));
774                                    }
775                                };
776                            query.navigate_for_named_key(named_key_addr);
777                        } else {
778                            maybe_msg_prefix = Some("Invalid base key".to_string());
779                        }
780                    } else {
781                        maybe_msg_prefix = Some("All names visited".to_string());
782                    }
783                    if let Some(msg_prefix) = maybe_msg_prefix {
784                        return Ok(query.into_not_found_result(&msg_prefix));
785                    }
786                }
787                StoredValue::NamedKey(named_key_value) => {
788                    match query.visited_names.last() {
789                        Some(expected_name) => match named_key_value.get_name() {
790                            Ok(actual_name) => {
791                                if &actual_name != expected_name {
792                                    return Ok(query.into_not_found_result(
793                                        "Queried and retrieved names do not match",
794                                    ));
795                                } else if let Ok(key) = named_key_value.get_key() {
796                                    query.navigate(key)
797                                } else {
798                                    return Ok(query
799                                        .into_not_found_result("Failed to parse CLValue as Key"));
800                                }
801                            }
802                            Err(_) => {
803                                return Ok(query
804                                    .into_not_found_result("Failed to parse CLValue as String"));
805                            }
806                        },
807                        None if path.is_empty() => {
808                            return Ok(TrackingCopyQueryResult::Success { value, proofs });
809                        }
810                        None => return Ok(query.into_not_found_result("No visited names")),
811                    }
812                }
813                StoredValue::CLValue(cl_value) if cl_value.cl_type() == &CLType::Key => {
814                    if let Ok(key) = cl_value.to_owned().into_t::<Key>() {
815                        query.navigate(key);
816                    } else {
817                        return Ok(query.into_not_found_result("Failed to parse CLValue as Key"));
818                    }
819                }
820                StoredValue::CLValue(cl_value) => {
821                    let msg_prefix = format!(
822                        "Query cannot continue as {:?} is not an account, contract nor key to \
823                        such.  Value found",
824                        cl_value
825                    );
826                    return Ok(query.into_not_found_result(&msg_prefix));
827                }
828                StoredValue::ContractWasm(_) => {
829                    return Ok(query.into_not_found_result("ContractWasm value found."));
830                }
831                StoredValue::ContractPackage(_) => {
832                    return Ok(query.into_not_found_result("ContractPackage value found."));
833                }
834                StoredValue::SmartContract(_) => {
835                    return Ok(query.into_not_found_result("Package value found."));
836                }
837                StoredValue::ByteCode(_) => {
838                    return Ok(query.into_not_found_result("ByteCode value found."));
839                }
840                StoredValue::Transfer(_) => {
841                    return Ok(query.into_not_found_result("Legacy Transfer value found."));
842                }
843                StoredValue::DeployInfo(_) => {
844                    return Ok(query.into_not_found_result("DeployInfo value found."));
845                }
846                StoredValue::EraInfo(_) => {
847                    return Ok(query.into_not_found_result("EraInfo value found."));
848                }
849                StoredValue::Bid(_) => {
850                    return Ok(query.into_not_found_result("Bid value found."));
851                }
852                StoredValue::BidKind(_) => {
853                    return Ok(query.into_not_found_result("BidKind value found."));
854                }
855                StoredValue::Withdraw(_) => {
856                    return Ok(query.into_not_found_result("WithdrawPurses value found."));
857                }
858                StoredValue::Unbonding(_) => {
859                    return Ok(query.into_not_found_result("Unbonding value found."));
860                }
861                StoredValue::MessageTopic(_) => {
862                    return Ok(query.into_not_found_result("MessageTopic value found."));
863                }
864                StoredValue::Message(_) => {
865                    return Ok(query.into_not_found_result("Message value found."));
866                }
867                StoredValue::EntryPoint(_) => {
868                    return Ok(query.into_not_found_result("EntryPoint value found."));
869                }
870                StoredValue::Prepayment(_) => {
871                    return Ok(query.into_not_found_result("Prepayment value found."))
872                }
873                StoredValue::RawBytes(_) => {
874                    return Ok(query.into_not_found_result("RawBytes value found."));
875                }
876            }
877        }
878    }
879}
880
881/// The purpose of this implementation is to allow a "snapshot" mechanism for
882/// TrackingCopy. The state of a TrackingCopy (including the effects of
883/// any transforms it has accumulated) can be read using an immutable
884/// reference to that TrackingCopy via this trait implementation. See
885/// `TrackingCopy::fork` for more information.
886impl<R: StateReader<Key, StoredValue>> StateReader<Key, StoredValue> for &TrackingCopy<R> {
887    type Error = R::Error;
888
889    fn read(&self, key: &Key) -> Result<Option<StoredValue>, Self::Error> {
890        let kb = KeyWithByteRepr::new(*key);
891        if let Some(value) = self.cache.muts_cached.get(&kb) {
892            return Ok(Some(value.to_owned()));
893        }
894        if let Some(value) = self.reader.read(key)? {
895            Ok(Some(value))
896        } else {
897            Ok(None)
898        }
899    }
900
901    fn read_with_proof(
902        &self,
903        key: &Key,
904    ) -> Result<Option<TrieMerkleProof<Key, StoredValue>>, Self::Error> {
905        self.reader.read_with_proof(key)
906    }
907
908    fn keys_with_prefix(&self, byte_prefix: &[u8]) -> Result<Vec<Key>, Self::Error> {
909        let keys = self.reader.keys_with_prefix(byte_prefix)?;
910
911        let ret = keys
912            .into_iter()
913            // don't include keys marked for pruning
914            .filter(|key| !self.cache.is_pruned(key))
915            // there may be newly inserted keys which have not been committed yet
916            .chain(self.cache.get_muts_cached_by_byte_prefix(byte_prefix))
917            .collect();
918        Ok(ret)
919    }
920}
921
922/// Error conditions of a proof validation.
923#[derive(Error, Debug, PartialEq, Eq)]
924pub enum ValidationError {
925    /// The path should not have a different length than the proof less one.
926    #[error("The path should not have a different length than the proof less one.")]
927    PathLengthDifferentThanProofLessOne,
928
929    /// The provided key does not match the key in the proof.
930    #[error("The provided key does not match the key in the proof.")]
931    UnexpectedKey,
932
933    /// The provided value does not match the value in the proof.
934    #[error("The provided value does not match the value in the proof.")]
935    UnexpectedValue,
936
937    /// The proof hash is invalid.
938    #[error("The proof hash is invalid.")]
939    InvalidProofHash,
940
941    /// The path went cold.
942    #[error("The path went cold.")]
943    PathCold,
944
945    /// (De)serialization error.
946    #[error("Serialization error: {0}")]
947    BytesRepr(bytesrepr::Error),
948
949    /// Key is not a URef.
950    #[error("Key is not a URef")]
951    KeyIsNotAURef(Key),
952
953    /// Error converting a stored value to a [`Key`].
954    #[error("Failed to convert stored value to key")]
955    ValueToCLValueConversion,
956
957    /// CLValue conversion error.
958    #[error("{0}")]
959    CLValueError(CLValueError),
960}
961
962impl From<CLValueError> for ValidationError {
963    fn from(err: CLValueError) -> Self {
964        ValidationError::CLValueError(err)
965    }
966}
967
968impl From<bytesrepr::Error> for ValidationError {
969    fn from(error: bytesrepr::Error) -> Self {
970        Self::BytesRepr(error)
971    }
972}
973
974/// Validates proof of the query.
975///
976/// Returns [`ValidationError`] for any of
977pub fn validate_query_proof(
978    hash: &Digest,
979    proofs: &[TrieMerkleProof<Key, StoredValue>],
980    expected_first_key: &Key,
981    path: &[String],
982    expected_value: &StoredValue,
983) -> Result<(), ValidationError> {
984    if proofs.len() != path.len() + 1 {
985        return Err(ValidationError::PathLengthDifferentThanProofLessOne);
986    }
987
988    let mut proofs_iter = proofs.iter();
989    let first_proof = match proofs_iter.next() {
990        Some(proof) => proof,
991        None => {
992            return Err(ValidationError::PathLengthDifferentThanProofLessOne);
993        }
994    };
995
996    let mut path_components_iter = path.iter();
997
998    if first_proof.key() != &expected_first_key.normalize() {
999        return Err(ValidationError::UnexpectedKey);
1000    }
1001
1002    if hash != &compute_state_hash(first_proof)? {
1003        return Err(ValidationError::InvalidProofHash);
1004    }
1005
1006    let mut proof_value = first_proof.value();
1007
1008    for proof in proofs_iter {
1009        let named_keys = match proof_value {
1010            StoredValue::Account(account) => account.named_keys(),
1011            StoredValue::Contract(contract) => contract.named_keys(),
1012            _ => return Err(ValidationError::PathCold),
1013        };
1014
1015        let path_component = match path_components_iter.next() {
1016            Some(path_component) => path_component,
1017            None => return Err(ValidationError::PathCold),
1018        };
1019
1020        let key = match named_keys.get(path_component) {
1021            Some(key) => key,
1022            None => return Err(ValidationError::PathCold),
1023        };
1024
1025        if proof.key() != &key.normalize() {
1026            return Err(ValidationError::UnexpectedKey);
1027        }
1028
1029        if hash != &compute_state_hash(proof)? {
1030            return Err(ValidationError::InvalidProofHash);
1031        }
1032
1033        proof_value = proof.value();
1034    }
1035
1036    if proof_value != expected_value {
1037        return Err(ValidationError::UnexpectedValue);
1038    }
1039
1040    Ok(())
1041}
1042
1043/// Validates proof of the query.
1044///
1045/// Returns [`ValidationError`] for any of
1046pub fn validate_query_merkle_proof(
1047    hash: &Digest,
1048    proofs: &[TrieMerkleProof<Key, StoredValue>],
1049    expected_key_trace: &[Key],
1050    expected_value: &StoredValue,
1051) -> Result<(), ValidationError> {
1052    let expected_len = expected_key_trace.len();
1053    if proofs.len() != expected_len {
1054        return Err(ValidationError::PathLengthDifferentThanProofLessOne);
1055    }
1056
1057    let proof_keys: Vec<Key> = proofs.iter().map(|proof| *proof.key()).collect();
1058
1059    if !expected_key_trace.eq(&proof_keys) {
1060        return Err(ValidationError::UnexpectedKey);
1061    }
1062
1063    if expected_value != proofs[expected_len - 1].value() {
1064        return Err(ValidationError::UnexpectedValue);
1065    }
1066
1067    let mut proofs_iter = proofs.iter();
1068
1069    let first_proof = match proofs_iter.next() {
1070        Some(proof) => proof,
1071        None => return Err(ValidationError::PathLengthDifferentThanProofLessOne),
1072    };
1073
1074    if hash != &compute_state_hash(first_proof)? {
1075        return Err(ValidationError::InvalidProofHash);
1076    }
1077
1078    Ok(())
1079}
1080
1081/// Validates a proof of a balance request.
1082pub fn validate_balance_proof(
1083    hash: &Digest,
1084    balance_proof: &TrieMerkleProof<Key, StoredValue>,
1085    expected_purse_key: Key,
1086    expected_motes: &U512,
1087) -> Result<(), ValidationError> {
1088    let expected_balance_key = expected_purse_key
1089        .into_uref()
1090        .map(|uref| Key::Balance(uref.addr()))
1091        .ok_or_else(|| ValidationError::KeyIsNotAURef(expected_purse_key.to_owned()))?;
1092
1093    if balance_proof.key() != &expected_balance_key.normalize() {
1094        return Err(ValidationError::UnexpectedKey);
1095    }
1096
1097    if hash != &compute_state_hash(balance_proof)? {
1098        return Err(ValidationError::InvalidProofHash);
1099    }
1100
1101    let balance_proof_stored_value = balance_proof.value().to_owned();
1102
1103    let balance_proof_clvalue: CLValue = balance_proof_stored_value
1104        .try_into()
1105        .map_err(|_| ValidationError::ValueToCLValueConversion)?;
1106
1107    let balance_motes: U512 = balance_proof_clvalue.into_t()?;
1108
1109    if expected_motes != &balance_motes {
1110        return Err(ValidationError::UnexpectedValue);
1111    }
1112
1113    Ok(())
1114}
1115
1116use crate::global_state::{
1117    error::Error,
1118    state::{
1119        lmdb::{make_temporary_global_state, LmdbGlobalStateView},
1120        StateProvider,
1121    },
1122};
1123use tempfile::TempDir;
1124
1125/// Creates a temp global state with initial state and checks out a tracking copy on it.
1126pub fn new_temporary_tracking_copy(
1127    initial_data: impl IntoIterator<Item = (Key, StoredValue)>,
1128    max_query_depth: Option<u64>,
1129    enable_addressable_entity: bool,
1130) -> (TrackingCopy<LmdbGlobalStateView>, TempDir) {
1131    let (global_state, state_root_hash, tempdir) = make_temporary_global_state(initial_data);
1132
1133    let reader = global_state
1134        .checkout(state_root_hash)
1135        .expect("Checkout should not throw errors.")
1136        .expect("Root hash should exist.");
1137
1138    let query_depth = max_query_depth.unwrap_or(DEFAULT_MAX_QUERY_DEPTH);
1139
1140    (
1141        TrackingCopy::new(reader, query_depth, enable_addressable_entity),
1142        tempdir,
1143    )
1144}