Skip to main content

midnight_storage/
storage.rs

1// This file is part of midnight-ledger.
2// Copyright (C) 2025 Midnight Foundation
3// SPDX-License-Identifier: Apache-2.0
4// Licensed under the Apache License, Version 2.0 (the "License");
5// You may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! storage containers
15
16use crate as storage;
17use crate::arena::{ArenaHash, ArenaKey, Sp};
18use crate::db::{DB, InMemoryDB};
19use crate::merkle_patricia_trie::Annotation;
20use crate::merkle_patricia_trie::MerklePatriciaTrie;
21use crate::merkle_patricia_trie::Semigroup;
22use crate::storable::{Loader, SizeAnn};
23use crate::{DefaultDB, Storable};
24use base_crypto::time::Timestamp;
25use crypto::digest::Digest;
26use derive_where::derive_where;
27#[cfg(feature = "proptest")]
28use proptest::arbitrary::Arbitrary;
29use rand::distributions::{Distribution, Standard};
30#[cfg(feature = "proptest")]
31use serialize::NoStrategy;
32use serialize::{Deserializable, Serializable, Tagged, tag_enforcement_test};
33use sha2::Sha256;
34use std::borrow::Borrow;
35use std::fmt::{Debug, Formatter};
36use std::hash::Hash;
37use std::marker::PhantomData;
38use std::ops::Deref;
39use std::sync::Arc;
40
41pub use storage_core::storage::*;
42
/// Storage backed by an in-memory hash map, indexed by SHA256 hashes
pub type InMemoryStorage = Storage<InMemoryDB<Sha256>>;
45
/// A map from key hashes to values
#[derive(Storable)]
#[derive_where(Clone, Eq, PartialEq; V, A)]
#[storable(db = D, invariant = HashMap::invariant)]
pub struct HashMap<
    K: Serializable + Storable<D>,
    V: Storable<D>,
    D: DB = DefaultDB,
    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)> = SizeAnn,
>(
    // Inner representation: a `Map` keyed by the hash of the serialized key.
    // The original key is stored alongside the value so it can be recovered on
    // iteration and re-checked by `HashMap::invariant`.
    #[cfg(feature = "public-internal-structure")]
    #[allow(clippy::type_complexity)]
    pub Map<ArenaHash<D::Hasher>, (Sp<K, D>, Sp<V, D>), D, A>,
    #[cfg(not(feature = "public-internal-structure"))]
    #[allow(clippy::type_complexity)]
    Map<ArenaHash<D::Hasher>, (Sp<K, D>, Sp<V, D>), D, A>,
);
63
impl<
    K: Serializable + Storable<D> + Tagged,
    V: Storable<D> + Tagged,
    D: DB,
    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)> + Tagged,
> Tagged for HashMap<K, V, D, A>
{
    /// Tag embedding the key, value, and annotation tags, so maps over
    /// different element types are distinguishable.
    fn tag() -> std::borrow::Cow<'static, str> {
        format!("hash-map({},{},{})", K::tag(), V::tag(), A::tag()).into()
    }
    /// Delegates to the inner `Map`'s unique factor; the binary layout is
    /// exactly that of the wrapped `Map`.
    fn tag_unique_factor() -> String {
        <Map<ArenaHash<D::Hasher>, (Sp<K, D>, Sp<V, D>), D, A>>::tag_unique_factor()
    }
}
tag_enforcement_test!(HashMap<(), ()>);
79
impl<
    K: Serializable + Storable<D>,
    V: Storable<D> + PartialEq,
    D: DB,
    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)>,
> Hash for HashMap<K, V, D, A>
{
    // Hashing delegates to the inner `Map`.
    // NOTE(review): the `V: PartialEq` bound presumably mirrors the inner
    // `Map`'s `Hash` requirements — confirm before relaxing it.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}
91
92impl<
93    K: Serializable + Storable<D> + PartialOrd,
94    V: Storable<D> + PartialOrd,
95    D: DB,
96    A: Storable<D> + PartialOrd + Annotation<(Sp<K, D>, Sp<V, D>)>,
97> PartialOrd for HashMap<K, V, D, A>
98{
99    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
100        self.0.partial_cmp(&other.0)
101    }
102}
103
104impl<
105    K: Serializable + Storable<D> + Ord,
106    V: Storable<D> + Ord,
107    D: DB,
108    A: Storable<D> + Ord + Annotation<(Sp<K, D>, Sp<V, D>)>,
109> Ord for HashMap<K, V, D, A>
110{
111    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
112        self.0.cmp(&other.0)
113    }
114}
115
116impl<
117    K: Debug + Serializable + Storable<D>,
118    V: Debug + Storable<D>,
119    D: DB,
120    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)>,
121> Debug for HashMap<K, V, D, A>
122{
123    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
124        f.debug_map()
125            .entries(self.iter().map(|kv| (kv.0.clone(), kv.1.clone())))
126            .finish()
127    }
128}
129
#[cfg(feature = "proptest")]
impl<
    K: Storable<D> + Debug + Serializable,
    V: Storable<D> + Debug,
    D: DB,
    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)>,
> Arbitrary for HashMap<K, V, D, A>
where
    Standard: Distribution<V> + Distribution<K>,
{
    // Proptest integration: generation is delegated to the `Distribution`
    // impl via `NoStrategy`, rather than a structural strategy.
    type Strategy = NoStrategy<HashMap<K, V, D, A>>;
    type Parameters = ();

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        NoStrategy(PhantomData)
    }
}
147
148impl<
149    D: DB,
150    K: Serializable + Storable<D>,
151    V: Storable<D>,
152    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)>,
153> Distribution<HashMap<K, V, D, A>> for Standard
154where
155    Standard: Distribution<V> + Distribution<K>,
156{
157    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> HashMap<K, V, D, A> {
158        let mut map = HashMap::new();
159        let size: usize = rng.gen_range(0..8);
160
161        for _ in 0..size {
162            map = map.insert(rng.r#gen(), rng.r#gen())
163        }
164
165        map
166    }
167}
168
169impl<
170    K: serde::Serialize + Serializable + Storable<D>,
171    V: serde::Serialize + Storable<D>,
172    D: DB,
173    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)>,
174> serde::Serialize for HashMap<K, V, D, A>
175{
176    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
177        ser.collect_map(
178            self.iter()
179                .map(|kv| (kv.0.deref().clone(), kv.1.deref().clone())),
180        )
181    }
182}
183
184struct HashMapVisitor<K, V, D, A>(PhantomData<(K, V, D, A)>);
185
186impl<
187    'de,
188    K: serde::Deserialize<'de> + Serializable + Storable<D>,
189    V: serde::Deserialize<'de> + Storable<D>,
190    D: DB,
191    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)>,
192> serde::de::Visitor<'de> for HashMapVisitor<K, V, D, A>
193{
194    type Value = HashMap<K, V, D, A>;
195
196    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
197        write!(formatter, "a hashmap")
198    }
199
200    fn visit_map<ACC: serde::de::MapAccess<'de>>(
201        self,
202        mut seq: ACC,
203    ) -> Result<HashMap<K, V, D, A>, ACC::Error> {
204        std::iter::from_fn(|| seq.next_entry::<K, V>().transpose()).collect()
205    }
206}
207
impl<
    'de,
    K: serde::Deserialize<'de> + Serializable + Storable<D1>,
    V: serde::Deserialize<'de> + Storable<D1>,
    D1: DB,
    A: Storable<D1> + Annotation<(Sp<K, D1>, Sp<V, D1>)>,
> serde::Deserialize<'de> for HashMap<K, V, D1, A>
{
    // `D1` names the backing database so the deserializer's own `D` type
    // parameter below doesn't shadow it.
    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        de.deserialize_map(HashMapVisitor(PhantomData))
    }
}
220
221impl<
222    K: Serializable + Storable<D>,
223    V: Storable<D>,
224    D: DB,
225    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)>,
226> Default for HashMap<K, V, D, A>
227{
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
impl<
    K: Serializable + Storable<D>,
    V: Storable<D>,
    D: DB,
    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)>,
> HashMap<K, V, D, A>
{
    /// Creates an empty map
    pub fn new() -> Self {
        Self(Map::new())
    }

    /// Computes the inner `Map` key for `key`: the digest of `key`'s
    /// serialized bytes under the database's hasher.
    fn gen_key(key: &K) -> ArenaHash<D::Hasher> {
        let mut hasher = D::Hasher::default();
        let mut bytes: std::vec::Vec<u8> = std::vec::Vec::new();
        K::serialize(key, &mut bytes).expect("HashMap key should be serializable");
        hasher.update(bytes);
        ArenaHash(hasher.finalize())
    }

    /// Insert object value in map, keyed with the hash of object key. Overwrites
    /// any preexisting object under the same key
    #[must_use]
    pub fn insert(&self, key: K, value: V) -> Self {
        // Both key and value are allocated in the backing arena; the original
        // key is stored alongside the value so that iteration can recover it
        // and `invariant` can re-check the hash.
        HashMap(self.0.insert(
            Self::gen_key(&key),
            (
                self.0.mpt.0.arena.alloc(key),
                self.0.mpt.0.arena.alloc(value),
            ),
        ))
    }

    /// Storable invariant: every entry's stored hash must equal the hash of
    /// its stored key (guards against corrupted or forged serialized data).
    fn invariant(&self) -> Result<(), std::io::Error> {
        for (hash, v) in self.0.iter() {
            let key = &*v.0;
            let hash2 = Self::gen_key(key);
            if hash != hash2 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "hashmap key doesn't match serialized hash",
                ));
            }
        }
        Ok(())
    }

    /// Get object keyed by the hash of object key
    pub fn get(&self, key: &K) -> Option<Sp<V, D>> {
        self.0.get(&Self::gen_key(key)).map(|(_, v)| v.clone())
    }

    /// Remove object keyed by the hash of object key
    #[must_use]
    pub fn remove(&self, key: &K) -> Self {
        HashMap(self.0.remove(&Self::gen_key(key)))
    }

    /// Check if the map contains a key.
    pub fn contains_key(&self, key: &K) -> bool {
        self.0.contains_key(&Self::gen_key(key))
    }

    /// Consume internal pointers, returning only the leaves left dangling by this.
    /// Used for custom `Drop` implementations.
    pub fn into_inner_for_drop(self) -> impl Iterator<Item = (Option<K>, Option<V>)> {
        self.0.into_inner_for_drop().filter_map(|(k, v)| {
            let (k, v) = (Sp::into_inner(k), Sp::into_inner(v));
            // Entries whose key and value are both still referenced elsewhere
            // yield nothing; otherwise report whichever halves were released.
            if k.is_none() && v.is_none() {
                None
            } else {
                Some((k, v))
            }
        })
    }

    /// Iterate over the key value pairs in the hash map
    #[allow(clippy::type_complexity)]
    pub fn iter(
        &self,
    ) -> impl Iterator<Item = Sp<(Sp<K, D>, Sp<V, D>), D>> + use<'_, K, V, D, A> + '_ {
        self.0.iter().map(|(_, v)| v)
    }

    /// Number of elements in the map
    pub fn size(&self) -> usize {
        self.0.size()
    }

    /// Returns true if empty
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns keys (cloned out of their `Sp` wrappers)
    pub fn keys(&self) -> impl Iterator<Item = K> + use<'_, K, V, D, A> + '_ {
        self.iter().map(|x| x.0.deref().clone())
    }

    /// Returns values (cloned out of their `Sp` wrappers)
    pub fn values(&self) -> impl Iterator<Item = V> + use<'_, K, V, D, A> + '_ {
        self.iter().map(|x| x.1.deref().clone())
    }

    /// Retrieve the annotation on the root of the trie
    pub fn ann(&self) -> A {
        self.0.ann()
    }
}
342
343impl<K, V, D, A> FromIterator<(K, V)> for HashMap<K, V, D, A>
344where
345    K: Serializable + Storable<D>,
346    V: Storable<D>,
347    D: DB,
348    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)>,
349{
350    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
351        iter.into_iter()
352            .fold(HashMap::new(), |map, (k, v)| map.insert(k, v))
353    }
354}
355
356impl<
357    K: Serializable + Deserializable + Storable<D>,
358    V: Storable<D>,
359    D: DB,
360    A: Storable<D> + Annotation<(Sp<K, D>, Sp<V, D>)> + Annotation<V>,
361> From<Map<K, V, D, A>> for HashMap<K, V, D, A>
362{
363    fn from(value: Map<K, V, D, A>) -> Self {
364        let mut hashmap = HashMap::new();
365
366        for (k, v) in value.iter() {
367            hashmap = hashmap.insert(k, v.deref().clone());
368        }
369
370        hashmap
371    }
372}
373
374/// Iterator type
375pub struct HashMapIntoIter<K, V, D: DB>
376where
377    K: 'static,
378    V: 'static,
379{
380    #[allow(clippy::type_complexity)]
381    inner: std::vec::IntoIter<(ArenaHash<D::Hasher>, (Sp<K, D>, Sp<V, D>))>,
382}
383
384impl<K, V, D> Iterator for HashMapIntoIter<K, V, D>
385where
386    K: Serializable + Storable<D> + Clone + 'static,
387    V: Storable<D> + Clone + 'static,
388    D: DB,
389{
390    type Item = (K, V);
391
392    fn next(&mut self) -> Option<Self::Item> {
393        self.inner
394            .next()
395            .map(|(_arena_key, (sp_key, sp_val))| ((*sp_key).clone(), (*sp_val).clone()))
396    }
397}
398
399impl<K, V, D> IntoIterator for HashMap<K, V, D>
400where
401    K: Serializable + Storable<D> + Clone + 'static,
402    V: Storable<D> + Clone + 'static,
403    D: DB,
404{
405    type Item = (K, V);
406    type IntoIter = HashMapIntoIter<K, V, D>;
407
408    fn into_iter(self) -> Self::IntoIter {
409        HashMapIntoIter {
410            inner: self.0.into_iter(),
411        }
412    }
413}
414
/// A set. Uses `HashMap` under the hood.
#[derive(Storable)]
#[derive_where(Clone, Eq, PartialEq, PartialOrd, Ord, Hash; V, A)]
#[storable(db = D)]
#[tag = "hash-set"]
// Elements are stored as the keys of the inner map; the values are all unit.
pub struct HashSet<
    V: Storable<D> + Serializable,
    D: DB = DefaultDB,
    A: Storable<D> + Annotation<(Sp<V, D>, Sp<(), D>)> = SizeAnn,
>(pub HashMap<V, (), D, A>);
tag_enforcement_test!(HashSet<()>);
426
427impl<
428    V: serde::Serialize + Serializable + Storable<D>,
429    D: DB,
430    A: Storable<D> + Annotation<(Sp<V, D>, Sp<(), D>)>,
431> serde::Serialize for HashSet<V, D, A>
432{
433    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
434        ser.collect_seq(self.iter().map(|v| (**v).clone()))
435    }
436}
437
impl<V: Storable<D> + Serializable, D: DB, A: Storable<D> + Annotation<(Sp<V, D>, Sp<(), D>)>>
    HashSet<V, D, A>
{
    /// Creates an empty set
    pub fn new() -> Self {
        Self(HashMap::new())
    }

    /// Insert object value. Overwrites
    /// any preexisting object under the same value
    #[must_use]
    pub fn insert(&self, value: V) -> Self {
        HashSet(self.0.insert(value, ()))
    }

    /// Remove object
    #[must_use]
    pub fn remove(&self, value: &V) -> Self {
        HashSet(self.0.remove(value))
    }

    /// Check if the set contains a value.
    pub fn member(&self, value: &V) -> bool {
        self.0.contains_key(value)
    }

    /// Check if a `HashSet` is the subset of another `HashSet`.
    /// Returns `true` when every element of `self` is also in `other`.
    pub fn is_subset(&self, other: &HashSet<V, D, A>) -> bool {
        self.iter().all(|x| other.member(&x))
    }

    /// Union with another set
    pub fn union(&self, other: &HashSet<V, D, A>) -> HashSet<V, D, A>
    where
        V: Clone,
    {
        other
            .iter()
            .fold(self.clone(), |acc, x| acc.insert(x.deref().deref().clone()))
    }

    /// Iterate over the key value pairs in the hash set
    // NOTE(review): each element is wrapped in a fresh `Arc` around a cloned
    // `Sp`; the extra `Arc` looks redundant — confirm whether callers rely on
    // it before changing the return type.
    pub fn iter(&self) -> impl Iterator<Item = Arc<Sp<V, D>>> + '_
    where
        V: Clone,
    {
        self.0.iter().map(|v| Arc::new(v.0.clone()))
    }

    /// Number of elements in the set
    pub fn size(&self) -> usize {
        self.0.size()
    }

    /// Returns true if empty
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Retrieve the annotation on the root of the trie
    pub fn ann(&self) -> A {
        self.0.ann()
    }
}
502
503impl<V: Storable<D> + Serializable, D: DB, A: Storable<D> + Annotation<(Sp<V, D>, Sp<(), D>)>>
504    Default for HashSet<V, D, A>
505{
506    fn default() -> Self {
507        Self::new()
508    }
509}
510
511impl<
512    V: Debug + Storable<D> + Serializable,
513    D: DB,
514    A: Storable<D> + Annotation<(Sp<V, D>, Sp<(), D>)>,
515> Debug for HashSet<V, D, A>
516{
517    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
518        self.0.fmt(f)
519    }
520}
521
#[cfg(feature = "proptest")]
impl<
    V: Storable<D> + Serializable + Debug,
    D: DB,
    A: Storable<D> + Annotation<(Sp<V, D>, Sp<(), D>)>,
> Arbitrary for HashSet<V, D, A>
where
    Standard: Distribution<V>,
{
    // Generation is delegated to the `Distribution` impl via `NoStrategy`.
    type Strategy = NoStrategy<HashSet<V, D, A>>;
    type Parameters = ();

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        NoStrategy(PhantomData)
    }
}
538
539impl<V: Storable<D> + Serializable, D: DB, A: Storable<D> + Annotation<(Sp<V, D>, Sp<(), D>)>>
540    Distribution<HashSet<V, D, A>> for Standard
541where
542    Standard: Distribution<V>,
543{
544    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> HashSet<V, D, A> {
545        let mut set = HashSet::new();
546        let size: usize = rng.gen_range(0..8);
547
548        for _ in 0..size {
549            set = set.insert(rng.r#gen())
550        }
551
552        set
553    }
554}
555
556impl<V, D, A> FromIterator<V> for HashSet<V, D, A>
557where
558    V: Storable<D> + Serializable,
559    D: DB,
560    A: Storable<D> + Annotation<(Sp<V, D>, Sp<(), D>)>,
561{
562    fn from_iter<T: IntoIterator<Item = V>>(iter: T) -> Self {
563        iter.into_iter()
564            .fold(HashSet::new(), |set, item| set.insert(item))
565    }
566}
567
/// An array built from a `MerklePatriciaTrie`
#[derive_where(Clone; V)]
#[derive(Storable)]
#[storable(db = D, invariant = Array::invariant)]
#[tag = "mpt-array[v1]"]
pub struct Array<V: Storable<D>, D: DB = DefaultDB>(
    // Array wraps MPT in an Sp to guarantee it only has one child.
    // Elements live at trie paths derived from their index via
    // `index_to_nibbles`, occupying indices `0..len()` contiguously.
    #[cfg(feature = "public-internal-structure")]
    #[storable(child)]
    pub Sp<MerklePatriciaTrie<V, D>, D>,
    #[cfg(not(feature = "public-internal-structure"))]
    #[storable(child)]
    Sp<MerklePatriciaTrie<V, D>, D>,
);
tag_enforcement_test!(Array<()>);
583
584impl<V: Storable<D> + Debug, D: DB> Debug for Array<V, D> {
585    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
586        f.debug_list().entries(self.iter()).finish()
587    }
588}
589
590impl<V: Storable<D>, D: DB> From<&std::vec::Vec<V>> for Array<V, D> {
591    fn from(value: &std::vec::Vec<V>) -> Self {
592        value.clone().into()
593    }
594}
595
596impl<const N: usize, V: Storable<D>, D: DB> From<[V; N]> for Array<V, D> {
597    fn from(value: [V; N]) -> Self {
598        Array::from_iter(value)
599    }
600}
601
602impl<V: Storable<D>, D: DB> From<std::vec::Vec<V>> for Array<V, D> {
603    fn from(value: std::vec::Vec<V>) -> Self {
604        Array::from_iter(value)
605    }
606}
607
608impl<V: Storable<D>, D: DB> From<&Array<V, D>> for std::vec::Vec<V> {
609    fn from(value: &Array<V, D>) -> Self {
610        value.iter().map(|x| (*x).clone()).collect()
611    }
612}
613
614impl<V: Storable<D>, D: DB> From<Array<V, D>> for std::vec::Vec<V> {
615    fn from(value: Array<V, D>) -> Self {
616        (&value).into()
617    }
618}
619
620impl<V: Storable<D>, D: DB> std::iter::FromIterator<V> for Array<V, D> {
621    fn from_iter<I: IntoIterator<Item = V>>(iter: I) -> Self {
622        let mut arr = Array::new();
623        for item in iter {
624            arr = arr.push(item);
625        }
626        arr
627    }
628}
629
630impl<V: Storable<D>, D: DB> Distribution<Array<V, D>> for Standard
631where
632    Standard: Distribution<V>,
633{
634    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Array<V, D> {
635        let mut array = Array::new();
636        let len = rng.r#gen::<u8>();
637        for _ in 0..len {
638            array = array.push(rng.r#gen())
639        }
640        array
641    }
642}
643
#[cfg(feature = "proptest")]
impl<V: Debug + Storable<D>, D: DB> Arbitrary for Array<V, D>
where
    Standard: Distribution<V>,
{
    // Generation is delegated to the `Distribution` impl via `NoStrategy`.
    type Strategy = NoStrategy<Array<V, D>>;
    type Parameters = ();

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        NoStrategy(PhantomData)
    }
}
656
657impl<V: Storable<D> + PartialEq, D: DB> PartialEq for Array<V, D> {
658    fn eq(&self, other: &Self) -> bool {
659        self.0.eq(&other.0)
660    }
661}
662
663impl<V: Storable<D> + Eq, D: DB> Eq for Array<V, D> {}
664
665impl<V: Storable<D> + PartialOrd, D: DB> PartialOrd for Array<V, D> {
666    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
667        self.0.partial_cmp(&other.0)
668    }
669}
670
671impl<V: Storable<D> + Ord, D: DB> Ord for Array<V, D> {
672    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
673        self.0.cmp(&other.0)
674    }
675}
676
677impl<V: Storable<D>, D: DB> Hash for Array<V, D> {
678    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
679        self.0.deref().hash(state);
680    }
681}
682
683impl<V: Storable<D>, D: DB> Default for Array<V, D> {
684    fn default() -> Self {
685        Self::new()
686    }
687}
688
impl<V: Storable<D>, D: DB> Array<V, D> {
    // Convert array index into u4 nibbles for use as mpt path.
    //
    // Drops leading zero nibbles, so that small arrays will have short paths.
    fn index_to_nibbles(i: usize) -> Vec<u8> {
        let nibbles = to_nibbles(&BigEndianU64(i as u64));
        nibbles.into_iter().skip_while(|x| *x == 0).collect()
    }

    // Inverse of `index_to_nibbles`: decodes an MPT path back into an index.
    // Rejects non-canonical paths (leading zero nibbles) and paths longer
    // than the 16 nibbles of a u64.
    fn nibbles_to_index(raw_nibbles: &[u8]) -> Result<u64, std::io::Error> {
        if raw_nibbles.first() == Some(&0) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "nibbles in array should not have leading zeroes",
            ));
        }
        let mut nibbles = [0u8; 16];
        if raw_nibbles.len() > 16 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "too long key for index in array",
            ));
        }
        // Right-align the nibbles, restoring the leading zeroes dropped by
        // `index_to_nibbles`.
        nibbles[(16 - raw_nibbles.len())..].copy_from_slice(raw_nibbles);
        let val: BigEndianU64 = from_nibbles(&nibbles)?;
        Ok(val.0)
    }

    // Storable invariant: every key in the trie must decode to an index
    // strictly below `len()`, i.e. the occupied indices are exactly `0..len()`.
    fn invariant(&self) -> Result<(), std::io::Error> {
        let len = self.len() as u64;
        self.0
            .iter()
            .map(|(k, _)| Self::nibbles_to_index(&k))
            .try_for_each(|n| {
                if n? >= len {
                    Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "index out of range for array on deserialization",
                    ))
                } else {
                    Ok(())
                }
            })
    }

    /// Construct an empty new array
    pub fn new() -> Self {
        Array(Sp::new(MerklePatriciaTrie::new()))
    }

    /// Generates a new [Array] from a value slice
    pub fn new_from_slice(values: &[V]) -> Self {
        let mut array = Array::<V, D>::new();
        for v in values.iter() {
            array = array.push(v.clone());
        }
        array
    }

    /// Number of elements in Array.
    ///
    /// The elements are stored at indices `0..len()`.
    // NOTE(review): this clones the whole trie before calling `size` —
    // presumably `size` consumes its receiver; confirm, and avoid the clone
    // if a borrowing equivalent exists.
    pub fn len(&self) -> usize {
        self.0.deref().clone().size()
    }

    /// Get element at `index`. Returns `None` if `index` is out of bounds.
    pub fn get(&self, index: usize) -> Option<&V> {
        self.0.lookup(&Self::index_to_nibbles(index))
    }

    /// Insert element at index.
    ///
    /// Must be an existing index, or returns `None`. Use `push` if you want to
    /// grow the array.
    #[must_use]
    pub fn insert(&self, index: usize, value: V) -> Option<Self> {
        if index >= self.len() {
            return None; // Index out of bounds
        }
        Some(Array(Sp::new(
            self.0.insert(&Self::index_to_nibbles(index), value),
        )))
    }

    /// Appends an element to the end of the array, growing the length by 1.
    #[must_use]
    pub fn push(&self, value: V) -> Self {
        let index = self.len();
        Self(Sp::new(
            self.0.insert(&Self::index_to_nibbles(index), value),
        ))
    }

    /// Consume internal pointers, returning only the leaves left dangling by this.
    /// Used for custom `Drop` implementations.
    pub fn into_inner_for_drop(self) -> impl Iterator<Item = V> {
        Sp::into_inner(self.0)
            .into_iter()
            .flat_map(MerklePatriciaTrie::into_inner_for_drop)
    }

    /// Iterate over the elements in the array as `Sp<V>`s
    pub fn iter(&self) -> ArrayIter<'_, V, D> {
        ArrayIter::new(self)
    }

    /// Iterate over the elements in the array as `&V` references
    pub fn iter_deref(&self) -> impl Iterator<Item = &V> {
        (0..self.len()).filter_map(|i| self.get(i))
    }

    /// Returns true if empty
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
806
807impl<V: Storable<D> + serde::Serialize, D: DB> serde::Serialize for Array<V, D> {
808    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
809        ser.collect_seq(self.iter().map(|v| v.deref().clone()))
810    }
811}
812
813struct ArrayVisitor<V, D>(PhantomData<(V, D)>);
814
815impl<'de, V: Storable<D> + serde::Deserialize<'de>, D: DB> serde::de::Visitor<'de>
816    for ArrayVisitor<V, D>
817{
818    type Value = Array<V, D>;
819
820    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
821        write!(formatter, "an array")
822    }
823
824    fn visit_seq<A: serde::de::SeqAccess<'de>>(self, mut seq: A) -> Result<Array<V, D>, A::Error> {
825        Ok(Array::<V, D>::from(
826            &std::iter::from_fn(|| seq.next_element::<V>().transpose())
827                .collect::<Result<std::vec::Vec<V>, A::Error>>()?,
828        ))
829    }
830}
831
impl<'de, V: Storable<D1> + serde::Deserialize<'de>, D1: DB> serde::Deserialize<'de>
    for Array<V, D1>
{
    // `D1` names the backing database so the deserializer's own `D` type
    // parameter below doesn't shadow it.
    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
        de.deserialize_seq(ArrayVisitor(PhantomData))
    }
}
839
840/// An iterator over `in_memory::Array`
841pub struct ArrayIter<'a, V: Storable<D>, D: DB> {
842    array: &'a Array<V, D>,
843    next_index: usize,
844}
845
846impl<'a, V: Storable<D>, D: DB> ArrayIter<'a, V, D> {
847    fn new(array: &'a Array<V, D>) -> Self {
848        ArrayIter {
849            array,
850            next_index: 0,
851        }
852    }
853}
854
855impl<V: Storable<D>, D: DB> Iterator for ArrayIter<'_, V, D> {
856    type Item = Sp<V, D>;
857
858    fn next(&mut self) -> Option<Self::Item> {
859        let result = self
860            .array
861            .0
862            .lookup_sp(&Array::<V, D>::index_to_nibbles(self.next_index));
863        self.next_index += 1;
864        result
865    }
866}
867
#[derive(Storable)]
#[derive_where(Clone, Eq, PartialEq, PartialOrd, Ord, Hash; V)]
#[storable(db = D)]
#[tag = "multi-set[v1]"]
/// A set with quantity. Often known as a bag.
pub struct MultiSet<V: Serializable + Storable<D>, D: DB> {
    // Maps each element to its strictly positive multiplicity; elements whose
    // count reaches zero are removed from the map entirely.
    #[cfg(feature = "public-internal-structure")]
    pub elements: HashMap<V, u32, D>,
    #[cfg(not(feature = "public-internal-structure"))]
    elements: HashMap<V, u32, D>,
}
tag_enforcement_test!(MultiSet<(), DefaultDB>);

impl<V: Serializable + Storable<D>, D: DB> Default for MultiSet<V, D> {
    /// The empty multiset.
    fn default() -> Self {
        Self::new()
    }
}
886
impl<V: Serializable + Storable<D>, D: DB> MultiSet<V, D> {
    /// Create a new, empty `MultiSet`
    pub fn new() -> Self {
        MultiSet {
            elements: HashMap::new(),
        }
    }

    /// Insert an element with a quantity of one or, if the element is already in the set, increase its quantity by one
    #[must_use]
    pub fn insert(&self, element: V) -> Self {
        // TODO: add an `entry` fn to `HashMap` to avoid this get-then-insert
        let current_count = self.elements.get(&element).map(|x| *x.deref()).unwrap_or(0);
        MultiSet {
            elements: self.elements.insert(element, current_count + 1),
        }
    }

    /// Decrement the count of an element, removing it if its count becomes 0
    #[must_use]
    pub fn remove(&self, element: &V) -> Self {
        self.remove_n(element, 1)
    }

    /// Decrement the count of an element by `n`, removing it if its count becomes 0
    ///
    /// The subtraction saturates: removing more copies than are present (or
    /// removing an absent element) simply leaves the element absent.
    #[must_use]
    pub fn remove_n(&self, element: &V, n: u32) -> Self {
        let current_count = self.elements.get(element).map(|x| *x.deref()).unwrap_or(0);
        let result = u32::checked_sub(current_count, n).unwrap_or(0);
        if result == 0 {
            MultiSet {
                elements: self.elements.remove(element),
            }
        } else {
            MultiSet {
                elements: self.elements.insert(element.clone(), result),
            }
        }
    }

    /// How many of a given element are in the structure? Returns 0 when the element isn't present
    pub fn count(&self, element: &V) -> u32 {
        match self.elements.get(element) {
            Some(i) => *i.deref(),
            None => 0,
        }
    }

    /// Check whether `other` is a (multi)subset of `self`, i.e. every element
    /// of `other` occurs in `self` with at least the same multiplicity.
    pub fn has_subset(&self, other: &MultiSet<V, D>) -> bool {
        for element_x_other_count in other.elements.iter() {
            let self_count = self.count(element_x_other_count.0.deref());
            if self_count < *element_x_other_count.1.deref() {
                return false;
            }
        }
        true
    }

    /// Check if the set contains a value.
    pub fn member(&self, element: &V) -> bool {
        self.elements.contains_key(element)
    }
}
951
/// A one-element collection.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serializable)]
pub struct Identity<V>(pub V);

impl<V: Storable<D>, D: DB> Storable<D> for Identity<V> {
    // Storable is fully transparent: an `Identity<V>` has exactly the
    // children and binary representation of the wrapped `V`.
    fn children(&self) -> std::vec::Vec<ArenaKey<<D as DB>::Hasher>> {
        self.0.children()
    }
    fn from_binary_repr<R: std::io::Read>(
        reader: &mut R,
        child_hashes: &mut impl Iterator<Item = ArenaKey<<D as DB>::Hasher>>,
        loader: &impl Loader<D>,
    ) -> Result<Self, std::io::Error>
    where
        Self: Sized,
    {
        V::from_binary_repr(reader, child_hashes, loader).map(Identity)
    }
    fn to_binary_repr<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error>
    where
        Self: Sized,
    {
        self.0.to_binary_repr(writer)
    }
}
977
978impl<V: Tagged> Tagged for Identity<V> {
979    fn tag() -> std::borrow::Cow<'static, str> {
980        V::tag()
981    }
982    fn tag_unique_factor() -> String {
983        V::tag_unique_factor()
984    }
985}
986
987impl<V> Distribution<Identity<V>> for Standard
988where
989    Standard: Distribution<V>,
990{
991    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Identity<V> {
992        let v = <Standard as Distribution<V>>::sample(self, rng);
993        Identity(v)
994    }
995}
996
997impl<V> From<V> for Identity<V> {
998    fn from(v: V) -> Self {
999        Identity(v)
1000    }
1001}
1002
impl<T: Storable<D> + Serializable, D: DB, A: Storable<D> + Annotation<(Sp<T, D>, Sp<(), D>)>>
    Semigroup for HashSet<T, D, A>
{
    // The semigroup operation is set union.
    fn append(&self, other: &Self) -> Self {
        self.union(other)
    }
}

/// An abstract container of items
pub trait Container<D: DB> {
    /// The contained type
    type Item: Storable<D> + Clone + PartialEq + Eq + PartialOrd + Ord + Hash;
    /// Gets an iterator over the `Container`'s items
    fn iter_items(self) -> impl Iterator<Item = Self::Item>;
    /// Wrap a single item in a `Container`
    fn once(_: Self::Item) -> Self;
}
1020
1021impl<T: Storable<D> + Clone + PartialEq + Eq + PartialOrd + Ord + Hash, D: DB> Container<D>
1022    for Identity<T>
1023{
1024    type Item = T;
1025    fn iter_items(self) -> impl Iterator<Item = Self::Item> {
1026        std::iter::once(self.0)
1027    }
1028
1029    fn once(item: Self::Item) -> Self {
1030        Self(item)
1031    }
1032}
1033
1034impl<
1035    T: Serializable + Storable<D> + Clone + PartialEq + Eq + PartialOrd + Ord + Hash,
1036    D: DB,
1037    A: Storable<D> + Annotation<(Sp<T, D>, Sp<(), D>)>,
1038> Container<D> for HashSet<T, D, A>
1039{
1040    type Item = T;
1041
1042    fn iter_items(self) -> impl Iterator<Item = Self::Item> {
1043        self.0.keys().collect::<Vec<_>>().into_iter()
1044    }
1045
1046    fn once(item: Self::Item) -> Self {
1047        Self::new().insert(item)
1048    }
1049}
1050
/// A `u64` wrapper that serialises as big-endian bytes, so that
/// lexicographic ordering of encoded trie paths matches numeric ordering
/// (enabling predecessor search and pruning in `TimeFilterMap`).
#[cfg(feature = "public-internal-structure")]
#[derive(Debug)]
pub struct BigEndianU64(u64);
/// A `u64` wrapper that serialises as big-endian bytes, so that
/// lexicographic ordering of encoded trie paths matches numeric ordering
/// (enabling predecessor search and pruning in `TimeFilterMap`).
#[cfg(not(feature = "public-internal-structure"))]
#[derive(Debug)]
struct BigEndianU64(u64);
1057
1058impl Serializable for BigEndianU64 {
1059    fn serialize(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
1060        writer.write_all(&self.0.to_be_bytes())
1061    }
1062    fn serialized_size(&self) -> usize {
1063        8
1064    }
1065}
1066
1067impl Deserializable for BigEndianU64 {
1068    fn deserialize(
1069        reader: &mut impl std::io::Read,
1070        _recursion_depth: u32,
1071    ) -> std::io::Result<Self> {
1072        let mut buf = [0u8; 8];
1073        reader.read_exact(&mut buf[..])?;
1074        Ok(BigEndianU64(u64::from_be_bytes(buf)))
1075    }
1076}
1077
/// A mapping from `Timestamp`s to values
///
/// `Timestamp`s are big-endian encoded to allow for efficient predecessor
/// searching and pruning.
#[derive(Storable)]
#[derive_where(Clone, Eq, PartialEq, PartialOrd, Ord, Hash; C)]
#[storable(db = D)]
pub struct TimeFilterMap<C: Serializable + Storable<D>, D: DB>
where
    C: Container<D> + Serializable + Storable<D>,
    <C as Container<D>>::Item: Serializable + Storable<D>,
{
    // Primary store: values keyed by big-endian-encoded timestamp seconds.
    #[cfg(feature = "public-internal-structure")]
    pub time_map: Map<BigEndianU64, C, D>,
    // Multiset of all items currently held in `time_map`, maintained in
    // lockstep by `insert`/`upsert`/`filter` for O(log n) membership checks.
    #[cfg(feature = "public-internal-structure")]
    pub set: MultiSet<<C as Container<D>>::Item, D>,
    #[cfg(not(feature = "public-internal-structure"))]
    time_map: Map<BigEndianU64, C, D>,
    #[cfg(not(feature = "public-internal-structure"))]
    set: MultiSet<<C as Container<D>>::Item, D>,
}
1099impl<C: Serializable + Storable<D>, D: DB> Tagged for TimeFilterMap<C, D>
1100where
1101    C: Container<D> + Serializable + Storable<D> + Tagged,
1102    <C as Container<D>>::Item: Serializable + Storable<D> + Tagged,
1103{
1104    fn tag() -> std::borrow::Cow<'static, str> {
1105        format!("time-filter-map[v1]({})", C::tag()).into()
1106    }
1107    fn tag_unique_factor() -> String {
1108        format!(
1109            "({},{})",
1110            C::tag(),
1111            <MultiSet<<C as Container<D>>::Item, D>>::tag()
1112        )
1113    }
1114}
1115tag_enforcement_test!(TimeFilterMap<Identity<()>, DefaultDB>);
1116
1117impl<V: Container<D> + Debug + Storable<D> + Serializable, D: DB> Debug for TimeFilterMap<V, D>
1118where
1119    <V as Container<D>>::Item: Serializable,
1120{
1121    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1122        self.time_map.fmt(f)
1123    }
1124}
1125
1126impl<C: Container<D> + Debug + Storable<D> + Serializable, D: DB> Default for TimeFilterMap<C, D>
1127where
1128    <C as Container<D>>::Item: Serializable,
1129{
1130    fn default() -> Self {
1131        Self::new()
1132    }
1133}
1134
1135impl<C: Container<D> + Debug + Storable<D> + Serializable, D: DB> TimeFilterMap<C, D>
1136where
1137    <C as Container<D>>::Item: Serializable,
1138{
1139    /// Create a new `TimeFilterMap`
1140    pub fn new() -> Self {
1141        TimeFilterMap {
1142            time_map: Map::new(),
1143            set: MultiSet::new(),
1144        }
1145    }
1146
1147    /// Return either a value precisely at the provided `Timestamp`, or the value at the next-earliest `Timestamp`, if one exists
1148    pub fn get(&self, ts: Timestamp) -> Option<&C> {
1149        let ts = BigEndianU64(ts.to_secs());
1150        match self.time_map.get(&ts) {
1151            Some(res) => Some(res),
1152            None => self.time_map.find_predecessor(&ts).map(|(_, v)| v),
1153        }
1154    }
1155
1156    /// Insert a value at the given `Timestamp`. If an entry at the `Timestamp` already exists, its value is replaced.
1157    ///
1158    /// Note: Despite the value being a `Container` item, this method does _not_ append to existing entries, it replaces them.
1159    #[must_use]
1160    pub fn insert(&self, ts: Timestamp, v: <C as Container<D>>::Item) -> Self {
1161        let mut res = self.clone();
1162        if let Some(x) = self.time_map.get(&BigEndianU64(ts.to_secs())) {
1163            for val in x.clone().iter_items() {
1164                res.set = res.set.remove(&val);
1165            }
1166        }
1167
1168        res.time_map = res
1169            .time_map
1170            .insert(BigEndianU64(ts.to_secs()), C::once(v.clone()));
1171        res.set = res.set.insert(v);
1172        res
1173    }
1174
1175    /// Insert a value at the given `Timestamp`. If an entry at the `Timestamp` already exists, this value is appended to it.
1176    #[must_use]
1177    pub fn upsert_one(&self, ts: Timestamp, v: <C as Container<D>>::Item) -> Self
1178    where
1179        C: Semigroup + Default,
1180    {
1181        self.upsert(ts, &C::once(v))
1182    }
1183
1184    /// Insert or update all values into a `Container` into our `TimeFilterMap`
1185    #[must_use]
1186    pub fn upsert(&self, ts: Timestamp, v: &C) -> Self
1187    where
1188        C: Semigroup + Default,
1189    {
1190        let xs = self
1191            .time_map
1192            .get(&BigEndianU64(ts.to_secs()))
1193            .cloned()
1194            .unwrap_or_default();
1195        let mut res = self.clone();
1196        for new_val in v.clone().iter_items() {
1197            res.set = res.set.insert(new_val.clone());
1198        }
1199        res.time_map = res
1200            .time_map
1201            .insert(BigEndianU64(ts.to_secs()), xs.append(v));
1202
1203        res
1204    }
1205
1206    /// Check if the `TimeFilterMap` contains a value
1207    pub fn contains(&self, v: &<C as Container<D>>::Item) -> bool {
1208        self.set.member(v)
1209    }
1210
1211    /// Check if `TimeFilterMap` contains all values in a `Container`
1212    pub fn contains_all(&self, v: C) -> bool {
1213        v.iter_items().all(|val| self.set.member(&val))
1214    }
1215
1216    /// Removes all entries with keys before the `cutoff_timestamp`
1217    #[must_use]
1218    pub fn filter(&self, cutoff_timestamp: Timestamp) -> Self
1219    where
1220        MerklePatriciaTrie<C, D>: 'static + Clone,
1221    {
1222        let cutoff_key = to_nibbles(&BigEndianU64(cutoff_timestamp.to_secs()));
1223
1224        let mut res = self.clone();
1225        let (new_mpt, removed_items_for_set) = self.time_map.mpt.prune(&cutoff_key);
1226        res.time_map.mpt = Sp::new(new_mpt);
1227
1228        for items in removed_items_for_set {
1229            for item in items.deref().clone().iter_items() {
1230                res.set = res.set.remove(&item);
1231            }
1232        }
1233        res
1234    }
1235}
1236
/// A persistently stored map, guaranteeing O(1) clones and log-time
/// modifications.
///
/// Keys are addressed in the underlying Merkle-Patricia trie by the nibble
/// encoding of their serialised form; `K` itself is never stored, hence the
/// `PhantomData` marker.
#[derive_where(PartialEq, Eq, PartialOrd, Ord; V, A)]
#[derive_where(Hash, Clone)]
pub struct Map<K, V: Storable<D>, D: DB = DefaultDB, A: Storable<D> + Annotation<V> = SizeAnn> {
    // The backing trie, keyed by nibble-encoded `K`.
    #[cfg(feature = "public-internal-structure")]
    pub mpt: Sp<MerklePatriciaTrie<V, D, A>, D>,
    // Zero-sized marker tying the map to its (unstored) key type.
    #[cfg(feature = "public-internal-structure")]
    pub key_type: PhantomData<K>,
    #[cfg(not(feature = "public-internal-structure"))]
    pub(crate) mpt: Sp<MerklePatriciaTrie<V, D, A>, D>,
    #[cfg(not(feature = "public-internal-structure"))]
    key_type: PhantomData<K>,
}
1251
1252impl<K: Tagged, V: Storable<D> + Tagged, D: DB, A: Storable<D> + Annotation<V> + Tagged> Tagged
1253    for Map<K, V, D, A>
1254{
1255    fn tag() -> std::borrow::Cow<'static, str> {
1256        format!("mpt-map({},{},{})", K::tag(), V::tag(), A::tag()).into()
1257    }
1258    fn tag_unique_factor() -> String {
1259        <MerklePatriciaTrie<V, D, A>>::tag_unique_factor()
1260    }
1261}
1262tag_enforcement_test!(Map<(), ()>);
1263
1264impl<
1265    K: Sync + Send + 'static + Deserializable,
1266    V: Storable<D>,
1267    D: DB,
1268    A: Storable<D> + Annotation<V>,
1269> Storable<D> for Map<K, V, D, A>
1270{
1271    /// Rather than in-lining the wrapped MPT it is a child such that we know the public Map has
1272    /// only a single child element (rather than up to 16)
1273    fn children(&self) -> std::vec::Vec<ArenaKey<D::Hasher>> {
1274        vec![Sp::as_child(&self.mpt)]
1275    }
1276
1277    fn to_binary_repr<W: std::io::Write>(&self, _writer: &mut W) -> Result<(), std::io::Error>
1278    where
1279        Self: Sized,
1280    {
1281        Ok(())
1282    }
1283
1284    fn from_binary_repr<R: std::io::Read>(
1285        _reader: &mut R,
1286        child_hashes: &mut impl Iterator<Item = ArenaKey<D::Hasher>>,
1287        loader: &impl Loader<D>,
1288    ) -> Result<Self, std::io::Error>
1289    where
1290        Self: Sized,
1291    {
1292        let res = Self {
1293            mpt: loader.get_next(child_hashes)?,
1294            key_type: PhantomData,
1295        };
1296        loader.do_check(res)
1297    }
1298
1299    fn check_invariant(&self) -> Result<(), std::io::Error> {
1300        self.mpt
1301            .iter()
1302            .try_for_each(|(k, _)| from_nibbles::<K>(&k).and(Ok(())))
1303    }
1304}
1305
1306impl<
1307    K: Sync + Send + 'static + Deserializable,
1308    V: Storable<D>,
1309    D: DB,
1310    A: Storable<D> + Annotation<V>,
1311> Serializable for Map<K, V, D, A>
1312{
1313    fn serialize(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
1314        Sp::new(self.clone()).serialize(writer)
1315    }
1316    fn serialized_size(&self) -> usize {
1317        Sp::new(self.clone()).serialized_size()
1318    }
1319}
1320
1321impl<
1322    K: Sync + Send + 'static + Deserializable,
1323    V: Storable<D>,
1324    D: DB,
1325    A: Storable<D> + Annotation<V>,
1326> Deserializable for Map<K, V, D, A>
1327{
1328    fn deserialize(reader: &mut impl std::io::Read, recursion_depth: u32) -> std::io::Result<Self> {
1329        <Sp<Map<K, V, D, A>, D> as Deserializable>::deserialize(reader, recursion_depth)
1330            .map(|s| (*s).clone())
1331    }
1332}
1333
1334impl<K, V, D, A> FromIterator<(K, V)> for Map<K, V, D, A>
1335where
1336    K: Serializable + Deserializable,
1337    V: Storable<D>,
1338    D: DB,
1339    A: Storable<D> + Annotation<V>,
1340{
1341    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
1342        iter.into_iter()
1343            .fold(Map::new(), |map, (k, v)| map.insert(k, v))
1344    }
1345}
1346
#[cfg(feature = "proptest")]
impl<
    K: Serializable + Deserializable + Debug,
    V: Storable<D> + Debug,
    D: DB,
    A: Storable<D> + Annotation<V>,
> Arbitrary for Map<K, V, D, A>
where
    Standard: Distribution<V> + Distribution<K>,
{
    // NOTE(review): `NoStrategy` is a marker strategy — presumably generation
    // is driven by the `Distribution<Map<...>>` impl below; confirm against
    // the `serialize::NoStrategy` definition.
    type Strategy = NoStrategy<Map<K, V, D, A>>;
    type Parameters = ();

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        NoStrategy(PhantomData)
    }
}
1364
1365impl<K: Serializable + Deserializable, V: Storable<D>, D: DB, A: Storable<D> + Annotation<V>>
1366    Distribution<Map<K, V, D, A>> for Standard
1367where
1368    Standard: Distribution<V> + Distribution<K>,
1369{
1370    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Map<K, V, D, A> {
1371        let mut map = Map::new();
1372        let size: usize = rng.gen_range(0..8);
1373
1374        for _ in 0..size {
1375            map = map.insert(rng.r#gen(), rng.r#gen());
1376        }
1377        map
1378    }
1379}
1380
1381impl<
1382    K: serde::Serialize + Serializable + Deserializable,
1383    V: Storable<D> + serde::Serialize,
1384    D: DB,
1385    A: Storable<D> + Annotation<V>,
1386> serde::Serialize for Map<K, V, D, A>
1387{
1388    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
1389        ser.collect_map(self.iter().map(|kv| (kv.0, kv.1.deref().clone())))
1390    }
1391}
1392
/// serde visitor used to deserialize a [`Map`] from a serde map.
struct MapVisitor<K, V, D, A>(PhantomData<(K, V, D, A)>);
1394
1395impl<
1396    'de,
1397    K: serde::Deserialize<'de> + Serializable + Deserializable,
1398    V: Storable<D> + serde::Deserialize<'de>,
1399    D: DB,
1400    A: Storable<D> + Annotation<V>,
1401> serde::de::Visitor<'de> for MapVisitor<K, V, D, A>
1402{
1403    type Value = Map<K, V, D, A>;
1404
1405    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
1406        write!(formatter, "a map")
1407    }
1408
1409    fn visit_map<ACC: serde::de::MapAccess<'de>>(
1410        self,
1411        mut seq: ACC,
1412    ) -> Result<Map<K, V, D, A>, ACC::Error> {
1413        std::iter::from_fn(|| seq.next_entry::<K, V>().transpose()).collect()
1414    }
1415}
1416
1417impl<
1418    'de,
1419    K: serde::Deserialize<'de> + Serializable + Deserializable,
1420    V: serde::Deserialize<'de> + Storable<D1>,
1421    D1: DB,
1422    A: Storable<D1> + Annotation<V>,
1423> serde::Deserialize<'de> for Map<K, V, D1, A>
1424{
1425    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
1426        de.deserialize_map(MapVisitor(PhantomData))
1427    }
1428}
1429
1430fn to_nibbles<T: Serializable>(value: &T) -> std::vec::Vec<u8> {
1431    let mut bytes = std::vec::Vec::new();
1432    T::serialize(value, &mut bytes).unwrap();
1433    let mut nibbles = std::vec::Vec::new();
1434    for b in bytes {
1435        nibbles.push((b & 0xf0) >> 4);
1436        nibbles.push(b & 0x0f);
1437    }
1438
1439    nibbles
1440}
1441
1442fn from_nibbles<T: Deserializable>(value: &[u8]) -> std::io::Result<T> {
1443    if value.iter().any(|v| *v >= 16) {
1444        return Err(std::io::Error::new(
1445            std::io::ErrorKind::InvalidData,
1446            "nibble out of range",
1447        ));
1448    }
1449    let bytes = value
1450        .chunks(2)
1451        .map(|nibbles_pair| {
1452            if nibbles_pair.len() != 2 {
1453                return Err(std::io::Error::new(
1454                    std::io::ErrorKind::InvalidData,
1455                    "nibble array must have even length",
1456                ));
1457            }
1458            Ok((nibbles_pair[0] << 4) | nibbles_pair[1])
1459        })
1460        .collect::<Result<std::vec::Vec<u8>, std::io::Error>>()?;
1461    T::deserialize(&mut &bytes[..], 0)
1462}
1463
1464impl<K: Serializable + Deserializable, V: Storable<D>, D: DB, A: Storable<D> + Annotation<V>>
1465    Default for Map<K, V, D, A>
1466{
1467    fn default() -> Self {
1468        Self::new()
1469    }
1470}
1471
impl<K: Serializable + Deserializable, V: Storable<D>, D: DB, A: Storable<D> + Annotation<V>>
    Map<K, V, D, A>
{
    /// Returns an empty map.
    pub fn new() -> Self {
        Self {
            mpt: Sp::new(MerklePatriciaTrie::new()),
            key_type: PhantomData,
        }
    }

    /// Insert a key-value pair into the map. Must be `O(log(|self|))`.
    ///
    /// Persistent: returns a new map, leaving `self` unmodified.
    #[must_use]
    pub fn insert(&self, key: K, value: V) -> Self {
        // Keys are addressed in the trie by their nibble-encoded serialisation.
        Map {
            mpt: Sp::new(self.mpt.insert(&to_nibbles(&key), value)),
            key_type: self.key_type,
        }
    }

    /// Remove a key from the map. Must be `O(log(|self|))`
    ///
    /// Persistent: returns a new map, leaving `self` unmodified.
    #[must_use]
    pub fn remove(&self, key: &K) -> Self {
        Map {
            mpt: Sp::new(self.mpt.remove(&to_nibbles(&key))),
            key_type: self.key_type,
        }
    }

    /// Consume internal pointers, returning only the leaves left dangling by this.
    /// Used for custom `Drop` implementations.
    pub fn into_inner_for_drop(self) -> impl Iterator<Item = V> {
        Sp::into_inner(self.mpt)
            .into_iter()
            .flat_map(MerklePatriciaTrie::into_inner_for_drop)
    }

    /// Iterate over the key-value pairs in the map in a deterministic, but unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (K, Sp<V, D>)> + use<'_, K, V, D, A> + '_ {
        self.mpt.iter().filter_map(|(p, v)| {
            // The path should always decode as nibbles if the map is well
            // formed, but at the moment an ill-formed map can be created by
            // deserialization; entries with undecodable paths are skipped.
            let key = from_nibbles::<K>(&p).ok()?;
            Some((key, v))
        })
    }

    /// Iterator over the keys in the map in a deterministic, but unspecified order.
    pub fn keys(&self) -> impl Iterator<Item = K> + use<'_, K, V, D, A> + '_ {
        self.iter().map(|(k, _)| k)
    }

    /// Check if the map contains a key. Must be `O(log(|self|))`.
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: Ord + Serializable,
    {
        self.mpt.lookup(&to_nibbles(key)).is_some()
    }

    /// Retrieve the value stored at a key, if applicable. Must be `O(log(|self|))`.
    pub fn get<Q>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: Serializable,
    {
        self.mpt.lookup(&to_nibbles(&key))
    }

    /// Lookup as Sp instead of raw `V` value.
    pub fn lookup_sp<Q>(&self, key: &Q) -> Option<Sp<V, D>>
    where
        Q: Serializable,
    {
        self.mpt.lookup_sp(&to_nibbles(key))
    }

    /// Check if the map is empty. Must be O(1).
    pub fn is_empty(&self) -> bool {
        self.mpt.is_empty()
    }

    /// Retrieve the number of key-value pairs in the map. Must be O(1).
    pub fn size(&self) -> usize {
        // NOTE(review): clones the root trie value in order to call `size`;
        // presumably cheap given the annotated root, but confirm against the
        // `MerklePatriciaTrie::size` signature.
        self.mpt.deref().clone().size()
    }

    // Rewrap a raw trie in a `Map` carrying this map's key type.
    fn build_from_mpt(&self, mpt: MerklePatriciaTrie<V, D, A>) -> Self {
        Self {
            mpt: Sp::new(mpt),
            key_type: self.key_type,
        }
    }

    /// Retrieve the annotation on the root of the trie
    pub fn ann(&self) -> A {
        self.mpt.ann()
    }
}
1573
1574impl<V: Storable<D>, D: DB, A: Storable<D> + Annotation<V>> Map<BigEndianU64, V, D, A> {
1575    /// Find the nearest predecessor to a given `target_path`
1576    pub fn find_predecessor<'a>(
1577        &'a self,
1578        target_path: &BigEndianU64,
1579    ) -> Option<(std::vec::Vec<u8>, &'a V)> {
1580        let target_nibbles = to_nibbles(&target_path);
1581        self.mpt.find_predecessor(target_nibbles.as_slice())
1582    }
1583
1584    /// Prunes all paths which are lexicographically less than or equal to `target_path`.
1585    /// Returns the updated tree, and a vector of the removed leaves.
1586    ///
1587    /// # Panics
1588    ///
1589    /// If any values in `target_path` are not `u4` nibbles, i.e. larger than
1590    /// 15.
1591    pub fn prune(&self, target_path: &[u8]) -> (Self, std::vec::Vec<Sp<V, D>>) {
1592        let (mpt, removed) = self.mpt.prune(target_path);
1593        (self.build_from_mpt(mpt), removed)
1594    }
1595}
1596
/// Debug-printing helper: a value that either decoded successfully or
/// failed to decode (the error itself is not retained).
enum Decodable<T> {
    Yes(T),
    No,
}
1601
1602impl<T, E> From<Result<T, E>> for Decodable<T> {
1603    fn from(value: Result<T, E>) -> Self {
1604        match value {
1605            Ok(v) => Decodable::Yes(v),
1606            Err(_) => Decodable::No,
1607        }
1608    }
1609}
1610
1611impl<T: Debug> Debug for Decodable<T> {
1612    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1613        match self {
1614            Decodable::Yes(v) => v.fmt(f),
1615            Decodable::No => write!(f, "!decode error!"),
1616        }
1617    }
1618}
1619
1620impl<K: Deserializable + Debug, V: Storable<D> + Debug, D: DB, A: Storable<D> + Annotation<V>> Debug
1621    for Map<K, V, D, A>
1622{
1623    fn fmt(&self, formatter: &mut Formatter) -> std::fmt::Result {
1624        formatter
1625            .debug_map()
1626            .entries(
1627                self.mpt
1628                    .iter()
1629                    .map(|(k, v)| (Decodable::from(from_nibbles::<K>(&k)), v)),
1630            )
1631            .finish()
1632    }
1633}
1634
1635// TODO: remove and use clones at IntoIter callsite
1636impl<
1637    K: Clone + Serializable + Deserializable,
1638    V: Clone + Storable<D>,
1639    D: DB,
1640    A: Storable<D> + Annotation<V>,
1641> IntoIterator for Map<K, V, D, A>
1642{
1643    type Item = (K, V);
1644    type IntoIter = std::vec::IntoIter<(K, V)>;
1645
1646    fn into_iter(self) -> Self::IntoIter {
1647        self.iter()
1648            .map(|(k, x)| (k, x.deref().clone()))
1649            .collect::<std::vec::Vec<_>>()
1650            .into_iter()
1651    }
1652}
1653
1654#[cfg(test)]
1655mod tests {
1656    use super::*;
1657    use crate::storable::SMALL_OBJECT_LIMIT;
1658
1659    #[test]
1660    fn iter_map() {
1661        let mut map = Map::<_, _>::new();
1662        map = map.insert(1, 4);
1663        map = map.insert(2, 5);
1664        map = map.insert(3, 6);
1665        for (k, v) in map.iter() {
1666            match (k, Sp::deref(&v)) {
1667                (1, 4) | (2, 5) | (3, 6) => {}
1668                _ => unreachable!(),
1669            }
1670        }
1671    }
1672
1673    #[test]
1674    fn array_get() {
1675        let array: super::Array<_> = vec![0, 1, 2, 3].into();
1676        assert_eq!(array.get(0).cloned(), Some(0));
1677        assert_eq!(array.get(1).cloned(), Some(1));
1678        assert_eq!(array.get(2).cloned(), Some(2));
1679        assert_eq!(array.get(3).cloned(), Some(3));
1680        assert!(array.get(4).is_none());
1681        assert!(array.get(5).is_none());
1682        assert!(array.get(6).is_none());
1683        assert!(array.get(7).is_none());
1684    }
1685
1686    #[test]
1687    fn array_push() {
1688        let mut array = super::Array::<u32>::new();
1689        assert_eq!(array.len(), 0);
1690        array = array.push(0);
1691        assert_eq!(array.len(), 1);
1692        array = array.push(1);
1693        assert_eq!(array.len(), 2);
1694        assert_eq!(array, vec![0, 1].into());
1695    }
1696
1697    #[test]
1698    fn array_with_more_than_16_elements() {
1699        let _: super::Array<_> = (0..1024).collect();
1700    }
1701
1702    #[test]
1703    fn array_index_to_nibbles_is_big_endian() {
1704        assert_eq!(Array::<u8>::index_to_nibbles(0), Vec::<u8>::new());
1705        assert_eq!(Array::<u8>::index_to_nibbles(1), vec![1]);
1706        assert_eq!(Array::<u8>::index_to_nibbles(15), vec![15]);
1707        assert_eq!(Array::<u8>::index_to_nibbles(16), vec![1, 0]);
1708        assert_eq!(Array::<u8>::index_to_nibbles(255), vec![15, 15]);
1709        assert_eq!(Array::<u8>::index_to_nibbles(256), vec![1, 0, 0]);
1710        assert_eq!(
1711            Array::<u8>::index_to_nibbles((1 << 12) - 1),
1712            vec![15, 15, 15]
1713        );
1714        assert_eq!(Array::<u8>::index_to_nibbles(1 << 12), vec![1, 0, 0, 0]);
1715        assert_eq!(
1716            Array::<u8>::index_to_nibbles((1 << 32) - 1),
1717            vec![15; 32 / 4]
1718        );
1719        let mut expected = vec![0; 32 / 4 + 1];
1720        expected[0] = 1;
1721        assert_eq!(Array::<u8>::index_to_nibbles(1 << 32), expected);
1722    }
1723
1724    #[test]
1725    fn test_map_iterators() {
1726        let map = Map::<_, _>::new()
1727            .insert(40026u64, 12u64)
1728            .insert(12u64, 40026u64);
1729        let mut keys = map.keys().collect::<std::vec::Vec<_>>();
1730        keys.sort();
1731        assert_eq!(keys, vec![12u64, 40026u64]);
1732        let mut entries = map
1733            .iter()
1734            .map(|(k, v)| (k, *(v.deref())))
1735            .collect::<std::vec::Vec<_>>();
1736        entries.sort();
1737        assert_eq!(entries, vec![(12u64, 40026u64), (40026u64, 12u64)]);
1738    }
1739
1740    #[test]
1741    fn test_hashmap() {
1742        let mut hashmap = HashMap::<_, _>::new()
1743            .insert(40026u64, 12u64)
1744            .insert(12u64, 40026u64);
1745
1746        assert_eq!(hashmap.get(&40026u64).map(|sp| *(sp.deref())), Some(12u64));
1747        assert_eq!(hashmap.get(&12u64).map(|sp| *(sp.deref())), Some(40026u64));
1748        hashmap = hashmap.remove(&12u64);
1749        assert_eq!(hashmap.get(&12u64), None);
1750    }
1751
1752    #[test]
1753    fn test_predecessor_when_target_is_branch_prefix_time_map() {
1754        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1755
1756        time_map = time_map.insert(Timestamp::from_secs(1), 1);
1757        time_map = time_map.insert(Timestamp::from_secs(256), 256);
1758        time_map = time_map.insert(Timestamp::from_secs(512), 512);
1759
1760        assert_eq!(
1761            time_map.get(Timestamp::from_secs(257)).copied(),
1762            Some(Identity(256))
1763        );
1764    }
1765
1766    #[test]
1767    fn test_smoke_time_map() {
1768        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1769        time_map = time_map.insert(Timestamp::from_secs(1), 1);
1770        time_map = time_map.insert(Timestamp::from_secs(2), 2);
1771
1772        assert_eq!(time_map.get(Timestamp::from_secs(0)).copied(), None);
1773        assert_eq!(
1774            time_map.get(Timestamp::from_secs(1)).copied(),
1775            Some(Identity(1))
1776        );
1777        assert_eq!(
1778            time_map.get(Timestamp::from_secs(2)).copied(),
1779            Some(Identity(2))
1780        );
1781        assert_eq!(
1782            time_map.get(Timestamp::from_secs(3)).copied(),
1783            Some(Identity(2))
1784        );
1785        assert!(!time_map.contains(&0));
1786        assert!(time_map.contains(&1));
1787        assert!(time_map.contains(&2));
1788        assert!(!time_map.contains(&3));
1789        // Drop all things before the first item
1790        time_map = time_map.filter(Timestamp::from_secs(2));
1791        // First item should be gone now
1792        assert_eq!(time_map.get(Timestamp::from_secs(1)).copied(), None);
1793        assert_eq!(
1794            time_map.get(Timestamp::from_secs(2)).copied(),
1795            Some(Identity(2))
1796        );
1797        assert_eq!(
1798            time_map.get(Timestamp::from_secs(3)).copied(),
1799            Some(Identity(2))
1800        );
1801        assert!(!time_map.contains(&0));
1802        assert!(!time_map.contains(&1));
1803        assert!(time_map.contains(&2));
1804        assert!(!time_map.contains(&3));
1805
1806        // Fails if to_nibbles bit emission order is reversed (as it was originally)
1807        time_map = time_map.insert(Timestamp::from_secs(16), 16);
1808        assert_eq!(
1809            time_map.get(Timestamp::from_secs(16)).copied(),
1810            Some(Identity(16))
1811        );
1812        assert_eq!(
1813            time_map.get(Timestamp::from_secs(17)).copied(),
1814            Some(Identity(16))
1815        );
1816
1817        time_map = time_map.filter(Timestamp::from_secs(2));
1818
1819        assert_eq!(time_map.get(Timestamp::from_secs(1)).copied(), None);
1820        assert_eq!(
1821            time_map.get(Timestamp::from_secs(2)).copied(),
1822            Some(Identity(2))
1823        );
1824        assert_eq!(
1825            time_map.get(Timestamp::from_secs(3)).copied(),
1826            Some(Identity(2))
1827        );
1828        assert_eq!(
1829            time_map.get(Timestamp::from_secs(17)).copied(),
1830            Some(Identity(16))
1831        );
1832
1833        assert!(!time_map.contains(&0));
1834        assert!(!time_map.contains(&1));
1835        assert!(time_map.contains(&2));
1836        assert!(!time_map.contains(&3));
1837        assert!(time_map.contains(&16));
1838
1839        // Fails if little-endian encoded during serialisation
1840        time_map = time_map.insert(Timestamp::from_secs(256), 256);
1841
1842        assert_eq!(
1843            time_map.get(Timestamp::from_secs(256)).copied(),
1844            Some(Identity(256))
1845        );
1846        assert_eq!(
1847            time_map.get(Timestamp::from_secs(257)).copied(),
1848            Some(Identity(256))
1849        );
1850
1851        time_map = time_map.filter(Timestamp::from_secs(2));
1852
1853        assert_eq!(time_map.get(Timestamp::from_secs(1)).copied(), None);
1854        assert_eq!(
1855            time_map.get(Timestamp::from_secs(2)).copied(),
1856            Some(Identity(2))
1857        );
1858        assert_eq!(
1859            time_map.get(Timestamp::from_secs(3)).copied(),
1860            Some(Identity(2))
1861        );
1862        assert_eq!(
1863            time_map.get(Timestamp::from_secs(257)).copied(),
1864            Some(Identity(256))
1865        );
1866
1867        assert!(!time_map.contains(&0));
1868        assert!(!time_map.contains(&1));
1869        assert!(time_map.contains(&2));
1870        assert!(!time_map.contains(&3));
1871        assert!(time_map.contains(&256));
1872    }
1873
1874    #[test]
1875    fn test_get_empty_time_map() {
1876        let time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1877        assert_eq!(time_map.get(Timestamp::from_secs(100)), None);
1878    }
1879
1880    #[test]
1881    fn test_insert_duplicate_value_allowed_time_map() {
1882        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1883        assert_eq!(0, time_map.set.count(&100));
1884        time_map = time_map.insert(Timestamp::from_secs(10), 100);
1885        assert_eq!(1, time_map.set.count(&100));
1886        time_map = time_map.insert(Timestamp::from_secs(20), 100);
1887        assert_eq!(2, time_map.set.count(&100));
1888    }
1889
1890    #[test]
1891    fn test_insert_filter_clears_set_via_duplicate_logic_time_map() {
1892        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1893        time_map = time_map.insert(Timestamp::from_secs(10), 100);
1894        time_map = time_map.filter(Timestamp::from_secs(11));
1895        let _ = time_map.insert(Timestamp::from_secs(20), 100); // Should NOT panic
1896    }
1897
1898    #[test]
1899    fn test_replace_existing_timestamp() {
1900        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1901        time_map = time_map.insert(Timestamp::from_secs(100), 1);
1902        assert_eq!(
1903            time_map.get(Timestamp::from_secs(100)).copied(),
1904            Some(Identity(1))
1905        );
1906        assert!(time_map.contains(&1));
1907        time_map = time_map.insert(Timestamp::from_secs(100), 2);
1908        assert!(!time_map.contains(&1));
1909        assert!(time_map.contains(&2));
1910        assert_eq!(
1911            time_map.get(Timestamp::from_secs(100)).copied(),
1912            Some(Identity(2))
1913        );
1914        assert_eq!(
1915            time_map.get(Timestamp::from_secs(101)).copied(),
1916            Some(Identity(2))
1917        );
1918    }
1919
1920    #[test]
1921    fn test_filter_below_minimum_key_time_map() {
1922        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1923        time_map = time_map.insert(Timestamp::from_secs(10), 10);
1924        time_map = time_map.insert(Timestamp::from_secs(20), 20);
1925        time_map = time_map.filter(Timestamp::from_secs(9)); // Cutoff before any existing keys
1926
1927        assert!(time_map.contains(&10));
1928        assert!(time_map.contains(&20));
1929        assert_eq!(
1930            time_map.get(Timestamp::from_secs(10)).copied(),
1931            Some(Identity(10))
1932        );
1933        assert_eq!(
1934            time_map.get(Timestamp::from_secs(20)).copied(),
1935            Some(Identity(20))
1936        );
1937    }
1938
1939    #[test]
1940    fn test_filter_cutoff_above_maximum_key_time_map() {
1941        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1942        time_map = time_map.insert(Timestamp::from_secs(10), 10);
1943        time_map = time_map.insert(Timestamp::from_secs(20), 20);
1944        time_map = time_map.filter(Timestamp::from_secs(21)); // Cutoff after latest key
1945
1946        assert!(!time_map.contains(&10));
1947        assert!(!time_map.contains(&20));
1948        assert_eq!(time_map.get(Timestamp::from_secs(10)), None);
1949        assert_eq!(time_map.get(Timestamp::from_secs(20)), None);
1950    }
1951
1952    #[test]
1953    fn test_filter_exact_match_time_map() {
1954        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1955        time_map = time_map.insert(Timestamp::from_secs(10), 10);
1956        time_map = time_map.insert(Timestamp::from_secs(20), 20);
1957        time_map = time_map.insert(Timestamp::from_secs(30), 30);
1958
1959        time_map = time_map.filter(Timestamp::from_secs(20)); // Prunes keys strictly < 20 (shouldn't remove 20)
1960
1961        assert!(!time_map.contains(&10));
1962        assert!(time_map.contains(&20));
1963        assert!(time_map.contains(&30));
1964        assert_eq!(time_map.get(Timestamp::from_secs(10)), None);
1965        assert_eq!(
1966            time_map.get(Timestamp::from_secs(20)).copied(),
1967            Some(Identity(20))
1968        );
1969        assert_eq!(
1970            time_map.get(Timestamp::from_secs(21)).copied(),
1971            Some(Identity(20))
1972        );
1973        assert_eq!(
1974            time_map.get(Timestamp::from_secs(30)).copied(),
1975            Some(Identity(30))
1976        );
1977    }
1978
1979    #[test]
1980    fn test_multiple_filters_time_map() {
1981        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
1982        time_map = time_map.insert(Timestamp::from_secs(10), 10);
1983        time_map = time_map.insert(Timestamp::from_secs(20), 20);
1984        time_map = time_map.insert(Timestamp::from_secs(30), 30);
1985        time_map = time_map.insert(Timestamp::from_secs(40), 40);
1986
1987        time_map = time_map.filter(Timestamp::from_secs(20)); // Removes 10
1988        assert!(!time_map.contains(&10));
1989        assert!(time_map.contains(&20));
1990        assert_eq!(
1991            time_map.get(Timestamp::from_secs(30)).copied(),
1992            Some(Identity(30))
1993        );
1994        assert_eq!(
1995            time_map.get(Timestamp::from_secs(31)).copied(),
1996            Some(Identity(30))
1997        );
1998
1999        time_map = time_map.filter(Timestamp::from_secs(35)); // Removes 20, 30
2000        assert!(!time_map.contains(&20));
2001        assert!(!time_map.contains(&30));
2002        assert!(time_map.contains(&40));
2003        assert_eq!(time_map.get(Timestamp::from_secs(39)), None);
2004        assert_eq!(
2005            time_map.get(Timestamp::from_secs(40)).copied(),
2006            Some(Identity(40))
2007        );
2008        assert_eq!(
2009            time_map.get(Timestamp::from_secs(41)).copied(),
2010            Some(Identity(40))
2011        );
2012
2013        time_map = time_map.filter(Timestamp::from_secs(41)); // Removes 40. The map is now empty.
2014        assert!(!time_map.contains(&40));
2015        assert_eq!(time_map.get(Timestamp::from_secs(40)), None);
2016    }
2017
2018    #[test]
2019    fn test_zero_timestamp_time_map() {
2020        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
2021        time_map = time_map.insert(Timestamp::from_secs(0), 0);
2022        time_map = time_map.insert(Timestamp::from_secs(10), 10);
2023
2024        assert_eq!(
2025            time_map.get(Timestamp::from_secs(0)).copied(),
2026            Some(Identity(0))
2027        );
2028        assert_eq!(
2029            time_map.get(Timestamp::from_secs(5)).copied(),
2030            Some(Identity(0))
2031        );
2032        assert!(time_map.contains(&0));
2033
2034        time_map = time_map.filter(Timestamp::from_secs(0));
2035        assert_eq!(
2036            time_map.get(Timestamp::from_secs(0)).copied(),
2037            Some(Identity(0))
2038        );
2039        assert!(time_map.contains(&0));
2040
2041        time_map = time_map.filter(Timestamp::from_secs(1));
2042        assert_eq!(time_map.get(Timestamp::from_secs(0)), None);
2043        assert!(!time_map.contains(&0));
2044        assert_eq!(time_map.get(Timestamp::from_secs(5)), None);
2045        assert_eq!(
2046            time_map.get(Timestamp::from_secs(10)).copied(),
2047            Some(Identity(10))
2048        );
2049    }
2050
2051    #[test]
2052    fn test_large_key_differences_time_map() {
2053        let mut time_map = TimeFilterMap::<Identity<i32>, InMemoryDB>::new();
2054        time_map = time_map.insert(Timestamp::from_secs(10), 10);
2055        time_map = time_map.insert(Timestamp::from_secs(1000000), 1000000);
2056
2057        assert_eq!(time_map.get(Timestamp::from_secs(9)).copied(), None);
2058        assert_eq!(
2059            time_map.get(Timestamp::from_secs(10)).copied(),
2060            Some(Identity(10))
2061        );
2062        assert_eq!(
2063            time_map.get(Timestamp::from_secs(11)).copied(),
2064            Some(Identity(10))
2065        );
2066        assert_eq!(
2067            time_map.get(Timestamp::from_secs(999999)).copied(),
2068            Some(Identity(10))
2069        );
2070        assert_eq!(
2071            time_map.get(Timestamp::from_secs(1000000)).copied(),
2072            Some(Identity(1000000))
2073        );
2074        assert_eq!(
2075            time_map.get(Timestamp::from_secs(1000001)).copied(),
2076            Some(Identity(1000000))
2077        );
2078        assert!(time_map.contains(&10));
2079        assert!(time_map.contains(&1000000));
2080
2081        time_map = time_map.filter(Timestamp::from_secs(10));
2082        assert_eq!(
2083            time_map.get(Timestamp::from_secs(10)).copied(),
2084            Some(Identity(10))
2085        );
2086        assert_eq!(
2087            time_map.get(Timestamp::from_secs(1000000)).copied(),
2088            Some(Identity(1000000))
2089        );
2090        assert!(time_map.contains(&10));
2091        assert!(time_map.contains(&1000000));
2092
2093        time_map = time_map.filter(Timestamp::from_secs(1000000));
2094        assert_eq!(time_map.get(Timestamp::from_secs(10)), None);
2095        assert_eq!(
2096            time_map.get(Timestamp::from_secs(1000000)).copied(),
2097            Some(Identity(1000000))
2098        );
2099        assert!(!time_map.contains(&10));
2100        assert!(time_map.contains(&1000000));
2101
2102        time_map = time_map.filter(Timestamp::from_secs(99999999));
2103        assert_eq!(time_map.get(Timestamp::from_secs(1000000)), None);
2104        assert!(!time_map.contains(&1000000));
2105    }
2106
    /// Test default storage APIs, including using `WrappedDB` for isolation.
    #[test]
    fn test_default_storage() {
        // Create isolated storage types for `DefaultDB`. Each `WrappedDB`
        // newtype gets its own global default-storage slot, so this test
        // cannot collide with other tests using `DefaultDB` directly.
        struct Tag1;
        type D1 = WrappedDB<DefaultDB, Tag1>;
        struct Tag2;
        type D2 = WrappedDB<DefaultDB, Tag2>;

        // Check that implicitly creating default storage of type InMemoryDB (if
        // necessary) works, by requesting it. Since in theory some other test
        // thread could have explicitly set the InMemoryDB default storage, this
        // test is not accurate. An accurate test could:
        //
        // - hold the STORAGES lock
        // - remove any existing InMemoryDB that was set by implicit usage in another test thread
        // - check that we get a new one implicitly
        // - reinsert the old one, if any
        // - drop the lock
        //
        // But the implicit InMemoryDB is just a hack for testing anyway, so
        // we'll just check that it's set and not worry about how :)
        {
            default_storage::<InMemoryDB>();
            assert!(try_get_default_storage::<InMemoryDB>().is_some());
        }

        // Check that default storages of other db types are not created
        // implicitly: requesting one that was never set panics.
        assert!(try_get_default_storage::<D1>().is_none());
        let result = std::panic::catch_unwind(|| {
            default_storage::<D1>();
        });
        assert!(result.is_err());

        // Create a default storage of type D1 and allocate an object in it;
        // the allocation must then be retrievable via the default storage.
        let b1 = set_default_storage::<D1>(Storage::<D1>::default).unwrap();
        let s1 = b1.arena.alloc([42u8; SMALL_OBJECT_LIMIT]);
        assert!(
            default_storage::<D1>()
                .get::<[u8; SMALL_OBJECT_LIMIT]>(&s1.as_typed_key())
                .is_ok()
        );

        // Check that D1 and D2 have disjoint default storages, even tho they're
        // the same underlying database type: D2's storage can't see s1.
        set_default_storage::<D2>(Storage::<D2>::default).unwrap();
        assert!(
            default_storage::<D2>()
                .get::<[u8; SMALL_OBJECT_LIMIT]>(&s1.as_typed_key())
                .is_err()
        );

        // Drop the D1 default storage and see that we can create a new one.
        // The fresh storage must not contain the old allocation.
        unsafe_drop_default_storage::<D1>();
        assert!(try_get_default_storage::<D1>().is_none());
        set_default_storage::<D1>(Storage::<D1>::default).unwrap();
        assert!(
            default_storage::<D1>()
                .get::<[u8; SMALL_OBJECT_LIMIT]>(&s1.as_typed_key())
                .is_err()
        );

        // Check that dropping the default storage for D1 didn't affect existing
        // references: `b1` still resolves the key, while the new default can't.
        assert!(
            b1.get::<[u8; SMALL_OBJECT_LIMIT]>(&s1.as_typed_key())
                .is_ok()
        );
        assert!(
            default_storage::<D1>()
                .get::<[u8; SMALL_OBJECT_LIMIT]>(&s1.as_typed_key())
                .is_err()
        );

        // Check that we can restore the original D1 default storage (unlikely
        // use case ...). `Arc::into_inner` succeeds only if `b1` was the last
        // remaining reference.
        let s = Arc::into_inner(b1).expect("we should have the only reference");
        unsafe_drop_default_storage::<D1>();
        set_default_storage::<D1>(|| s).unwrap();
        assert!(
            default_storage::<D1>()
                .get::<[u8; SMALL_OBJECT_LIMIT]>(&s1.as_typed_key())
                .is_ok()
        );
    }
2194
2195    #[cfg(feature = "sqlite")]
2196    #[test]
2197    fn persist_to_disk_sqldb() {
2198        use crate::{DefaultHasher, db::SqlDB};
2199
2200        let path = tempfile::NamedTempFile::new().unwrap().into_temp_path();
2201        test_persist_to_disk::<SqlDB<DefaultHasher>>(|| SqlDB::exclusive_file(&path));
2202    }
2203
2204    #[cfg(feature = "parity-db")]
2205    #[test]
2206    fn persist_to_disk_paritydb() {
2207        use crate::{DefaultHasher, db::ParityDb};
2208
2209        let path = tempfile::TempDir::new().unwrap().keep();
2210        test_persist_to_disk::<ParityDb<DefaultHasher>>(|| ParityDb::open(&path));
2211    }
2212
    /// Test that persisting objects to disk works:
    ///
    /// - create a first storage backed by a first db
    /// - create an object, persist it, and flush the db
    /// - create a second storage, backed by a second db, pointing to the same
    ///   file as the first db
    /// - reload the object from the second storage and check its correctness
    ///
    /// This incidentally includes a test of `WrappedDB` and
    /// `set_default_storage`.
    ///
    /// This test doesn't make sense for `InMemoryDB`, because that DB doesn't
    /// persist to disk.
    #[cfg(any(feature = "sqlite", feature = "parity-db"))]
    fn test_persist_to_disk<D: DB>(mk_db: impl Fn() -> D) {
        // Create a unique wrapper type for D, to avoid conflicts with
        // other tests running using D.
        struct Tag;
        type W<D> = WrappedDB<D, Tag>;

        // Compute key in a block so that everything else gets dropped. Need to
        // drop everything to avoid needing non-exclusive access to the DB.
        let key1 = {
            let db1: W<D> = WrappedDB::wrap(mk_db());
            let storage1 = Storage::new(DEFAULT_CACHE_SIZE, db1);
            let storage1 = set_default_storage(|| storage1).unwrap();
            let arena = &storage1.arena;
            let vals1 = vec![1u8, 1, 2, 3, 5];
            let array1: super::Array<_, W<D>> = vals1.into();
            let mut sp1 = arena.alloc(array1.clone());
            // Mark the allocation persistent, then flush pending changes so
            // they reach the on-disk file before it is reopened below.
            sp1.persist();
            storage1.with_backend(|backend| backend.flush_all_changes_to_db());
            sp1.as_typed_key()
        };
        unsafe_drop_default_storage::<W<D>>();
        // NOTE(review): presumably this gives the dropped DB time to release
        // its exclusive hold on the file before reopening — confirm whether
        // the fixed 1s sleep is still required.
        std::thread::sleep(std::time::Duration::from_secs(1));

        // Reopen the same underlying file with a fresh db/storage pair and
        // reload the persisted object by its key; it must equal the original.
        let db2: W<D> = WrappedDB::wrap(mk_db());
        let storage2 = Storage::new(DEFAULT_CACHE_SIZE, db2);
        let storage2 = set_default_storage(|| storage2).unwrap();
        let array1 = storage2.arena.get::<super::Array<_, _>>(&key1).unwrap();
        let vals2 = vec![1u8, 1, 2, 3, 5];
        let array2: super::Array<_, W<D>> = vals2.into();
        assert_eq!(*array1, array2);
    }
2258
2259    // Test that malformed map with odd-length nibbles no longer cause panics in
2260    // iteration.
2261    //
2262    // This test was created to demonstrate a crash that has since been fixed in PR#612.
2263    #[test]
2264    fn deserialization_malicious_map() {
2265        use crate::arena::Sp;
2266        use crate::merkle_patricia_trie::{MerklePatriciaTrie, Node};
2267        use serialize::{Deserializable, Serializable};
2268
2269        // Create a malformed Extension node with odd-length nibbles.  This
2270        // bypasses normal validation by directly constructing the node.
2271        let leaf: Node<u32> = Node::Leaf {
2272            ann: SizeAnn(0),
2273            value: Sp::new(42u32),
2274        };
2275        let extension_node = Node::Extension {
2276            ann: SizeAnn(1),
2277            compressed_path: vec![1, 2, 3], // 3 nibbles = odd length!
2278            child: Sp::new(leaf),
2279        };
2280        let mpt = MerklePatriciaTrie(Sp::new(extension_node));
2281        let malformed_map = Map {
2282            mpt: Sp::new(mpt),
2283            key_type: std::marker::PhantomData::<u32>,
2284        };
2285
2286        // Serialize the malformed map to get the attack vector.
2287        // This simulates what an attacker could send as serialized data.
2288        let mut serialized = std::vec::Vec::new();
2289        malformed_map.serialize(&mut serialized).unwrap();
2290
2291        // Deserialize using public API
2292        let mut cursor = std::io::Cursor::new(&serialized);
2293        assert!(Map::<u32, u32>::deserialize(&mut cursor, 0).is_err());
2294    }
2295
2296    #[test]
2297    fn init_many() {
2298        for _ in 0..10_000 {
2299            Array::<()>::new();
2300        }
2301    }
2302}