use crate::StrippableError;
use codec::{Decode, Encode};
use frame_support::PalletError;
use hash_db::{HashDB, Hasher, EMPTY_PREFIX};
use scale_info::TypeInfo;
use sp_std::{boxed::Box, collections::btree_set::BTreeSet, vec::Vec};
use sp_trie::{
read_trie_value, LayoutV1, MemoryDB, Recorder, StorageProof, Trie, TrieConfiguration,
TrieDBBuilder, TrieError, TrieHash,
};
/// Raw storage proof: a list of byte blobs, each being one proof node.
///
/// This is the form accepted by `StorageProofChecker::new` and produced by
/// `record_all_keys` and `craft_valid_storage_proof`.
pub type RawStorageProof = Vec<Vec<u8>>;
/// Storage proof size descriptor.
///
/// NOTE(review): this type is not used in the visible part of this file. The exact
/// meaning of the `u32` payloads (presumably a size in bytes) should be confirmed
/// against the code that consumes this enum.
#[derive(Clone, Copy, Debug)]
pub enum ProofSize {
/// Presumably requests a proof that is as small as possible — TODO confirm how the
/// payload is interpreted by callers.
Minimal(u32),
/// Presumably requests a proof that contains a large leaf value — TODO confirm how the
/// payload is interpreted by callers.
HasLargeLeaf(u32),
}
/// Checker that reads values out of a raw storage proof against a known trie root.
///
/// Every read performed through the checker is recorded, so that
/// `ensure_no_unused_nodes` can later verify that the number of distinct visited
/// nodes matches the number of nodes in the proof.
pub struct StorageProofChecker<H>
where
H: Hasher,
{
// Number of nodes in the raw proof that was passed to `new`.
proof_nodes_count: usize,
// Trie root that all reads are performed against.
root: H::Out,
// In-memory database built from the proof nodes.
db: MemoryDB<H>,
// Records the nodes touched by `read_value`; drained by `ensure_no_unused_nodes`.
recorder: Recorder<LayoutV1<H>>,
}
impl<H> StorageProofChecker<H>
where
H: Hasher,
{
pub fn new(root: H::Out, proof: RawStorageProof) -> Result<Self, Error> {
let proof_nodes_count = proof.len();
let proof = StorageProof::new(proof);
if proof_nodes_count != proof.iter_nodes().count() {
return Err(Error::DuplicateNodesInProof)
}
let db = proof.into_memory_db();
if !db.contains(&root, EMPTY_PREFIX) {
return Err(Error::StorageRootMismatch)
}
let recorder = Recorder::default();
let checker = StorageProofChecker { proof_nodes_count, root, db, recorder };
Ok(checker)
}
pub fn ensure_no_unused_nodes(mut self) -> Result<(), Error> {
let visited_nodes = self
.recorder
.drain()
.into_iter()
.map(|record| record.data)
.collect::<BTreeSet<_>>();
let visited_nodes_count = visited_nodes.len();
if self.proof_nodes_count == visited_nodes_count {
Ok(())
} else {
Err(Error::UnusedNodesInTheProof)
}
}
pub fn read_value(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>, Error> {
read_trie_value::<LayoutV1<H>, _>(&self.db, &self.root, key, Some(&mut self.recorder), None)
.map_err(|_| Error::StorageValueUnavailable)
}
pub fn read_and_decode_value<T: Decode>(&mut self, key: &[u8]) -> Result<Option<T>, Error> {
self.read_value(key).and_then(|v| {
v.map(|v| T::decode(&mut &v[..]).map_err(|e| Error::StorageValueDecodeFailed(e.into())))
.transpose()
})
}
pub fn read_and_decode_mandatory_value<T: Decode>(&mut self, key: &[u8]) -> Result<T, Error> {
self.read_and_decode_value(key)?.ok_or(Error::StorageValueEmpty)
}
pub fn read_and_decode_opt_value<T: Decode>(&mut self, key: &[u8]) -> Result<Option<T>, Error> {
match self.read_and_decode_value(key) {
Ok(outbound_lane_data) => Ok(outbound_lane_data),
Err(Error::StorageValueUnavailable) => Ok(None),
Err(e) => Err(e),
}
}
}
/// Errors returned by `StorageProofChecker`.
#[derive(Encode, Decode, Clone, Eq, PartialEq, PalletError, Debug, TypeInfo)]
pub enum Error {
/// Returned by `StorageProofChecker::new` when the node count of the constructed
/// `StorageProof` differs from the number of raw proof entries.
DuplicateNodesInProof,
/// Returned by `StorageProofChecker::ensure_no_unused_nodes` when the number of
/// distinct visited nodes differs from the number of proof nodes.
UnusedNodesInTheProof,
/// Returned by `StorageProofChecker::new` when the database built from the proof
/// does not contain the expected root.
StorageRootMismatch,
/// Returned by `StorageProofChecker::read_value` when the underlying trie read fails.
StorageValueUnavailable,
/// Returned by `StorageProofChecker::read_and_decode_mandatory_value` when the read
/// yields no value.
StorageValueEmpty,
/// Returned by `StorageProofChecker::read_and_decode_value` when the value bytes
/// cannot be decoded into the requested type; wraps the codec error.
StorageValueDecodeFailed(StrippableError<codec::Error>),
}
/// Build a storage root and a matching read proof over a small in-memory backend.
///
/// The backend holds `key1`, `key2`, `key3`, `key4` and `key11`. The returned proof
/// is generated for reads of `key1`, `key2`, `key4` and `key22` (the latter is not
/// among the inserted keys).
///
/// # Panics
///
/// Panics if proof generation fails.
#[cfg(feature = "std")]
pub fn craft_valid_storage_proof() -> (sp_core::H256, RawStorageProof) {
    use sp_state_machine::{backend::Backend, prove_read, InMemoryBackend};

    let state_version = sp_runtime::StateVersion::default();

    // Top-level (no child trie) storage entries that make up the test state.
    let entries = vec![
        (None, vec![(b"key1".to_vec(), Some(b"value1".to_vec()))]),
        (None, vec![(b"key2".to_vec(), Some(b"value2".to_vec()))]),
        (None, vec![(b"key3".to_vec(), Some(b"value3".to_vec()))]),
        (None, vec![(b"key4".to_vec(), Some((42u64, 42u32, 42u16, 42u8).encode()))]),
        (None, vec![(b"key11".to_vec(), Some(vec![0u8; 32]))]),
    ];
    let backend = <InMemoryBackend<sp_core::Blake2Hasher>>::from((entries, state_version));

    // The root must be computed before the backend is moved into `prove_read`.
    let root = backend.storage_root(std::iter::empty(), state_version).0;

    let keys_to_prove = [&b"key1"[..], &b"key2"[..], &b"key4"[..], &b"key22"[..]];
    let read_proof = prove_read(backend, &keys_to_prove).unwrap();

    let raw_proof = read_proof.into_nodes().into_iter().collect();
    (root, raw_proof)
}
/// Walk every key of the trie at `root` in `db` and return the recorded nodes as a
/// raw storage proof.
///
/// Each key yielded by the trie iterator is additionally looked up with `get`, with
/// the recorder attached. The recorded node data is deduplicated through a
/// `BTreeSet` before being returned, so the resulting nodes are unique and sorted.
///
/// # Errors
///
/// Returns the trie error if creating the iterator, advancing it, or reading a key
/// fails.
pub fn record_all_keys<L: TrieConfiguration, DB>(
    db: &DB,
    root: &TrieHash<L>,
) -> Result<RawStorageProof, Box<TrieError<L>>>
where
    DB: hash_db::HashDBRef<L::Hash, trie_db::DBValue>,
{
    let mut recorder = Recorder::<L>::new();
    let trie = TrieDBBuilder::<L>::new(db, root).with_recorder(&mut recorder).build();

    for entry in trie.iter()? {
        let (key, _value) = entry?;
        trie.get(&key)?;
    }

    // Collapse repeated records of the same node before building the proof.
    let unique_nodes: BTreeSet<Vec<u8>> =
        recorder.drain().into_iter().map(|record| record.data.to_vec()).collect();

    Ok(unique_nodes.into_iter().collect())
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use codec::Encode;

    /// Checker specialization used by all tests in this module.
    type Checker = StorageProofChecker<sp_core::Blake2Hasher>;

    /// Reads through a valid proof return the expected values, and a random root is
    /// rejected.
    #[test]
    fn storage_proof_check() {
        let (root, proof) = craft_valid_storage_proof();
        let encoded_tuple = (42u64, 42u32, 42u16, 42u8).encode();

        // Raw reads of keys covered (or not covered) by the proof.
        let mut checker = Checker::new(root, proof.clone()).unwrap();
        assert_eq!(checker.read_value(b"key1"), Ok(Some(b"value1".to_vec())));
        assert_eq!(checker.read_value(b"key2"), Ok(Some(b"value2".to_vec())));
        assert_eq!(checker.read_value(b"key4"), Ok(Some(encoded_tuple)));
        assert_eq!(checker.read_value(b"key11111"), Err(Error::StorageValueUnavailable));
        assert_eq!(checker.read_value(b"key22"), Ok(None));

        // Decoding reads: matching type succeeds, mismatching type fails.
        assert_eq!(checker.read_and_decode_value(b"key4"), Ok(Some((42u64, 42u32, 42u16, 42u8))));
        let wrong_type_read = checker.read_and_decode_value::<[u8; 64]>(b"key4");
        assert!(matches!(wrong_type_read, Err(Error::StorageValueDecodeFailed(_))));

        // A proof checked against an unrelated root must be rejected.
        let unrelated_root = sp_core::H256::random();
        assert_eq!(Checker::new(unrelated_root, proof).err(), Some(Error::StorageRootMismatch));
    }

    /// A proof that contains the same node twice is rejected at construction.
    #[test]
    fn proof_with_duplicate_items_is_rejected() {
        let (root, mut proof) = craft_valid_storage_proof();
        let duplicated_node = proof.first().unwrap().clone();
        proof.push(duplicated_node);

        let outcome = Checker::new(root, proof).map(drop);
        assert_eq!(outcome, Err(Error::DuplicateNodesInProof));
    }

    /// `ensure_no_unused_nodes` passes once all proved keys are read and fails when
    /// nothing has been read.
    #[test]
    fn proof_with_unused_items_is_rejected() {
        let (root, proof) = craft_valid_storage_proof();

        // Read every key the proof was generated for.
        let mut fully_read = Checker::new(root, proof.clone()).unwrap();
        for key in [&b"key1"[..], &b"key2"[..], &b"key4"[..], &b"key22"[..]] {
            fully_read.read_value(key).unwrap();
        }
        assert_eq!(fully_read.ensure_no_unused_nodes(), Ok(()));

        // No reads at all: no node has been visited.
        let untouched = Checker::new(root, proof).unwrap();
        assert_eq!(untouched.ensure_no_unused_nodes(), Err(Error::UnusedNodesInTheProof));
    }
}