use std::collections::HashMap;
use crate::error::{InternalError, InvalidStateError};
use crate::state::merkle::{node::Node, MerkleRadixLeafReadError, MerkleRadixLeafReader};
use crate::state::{
Prune, Read, StateChange, StatePruneError, StateReadError, StateWriteError, Write,
};
use super::backend;
use super::encode_and_hash;
use super::{
store::{MerkleRadixStore, SqlMerkleRadixStore},
MerkleRadixOverlay, MerkleRadixPruner, SqlMerkleState, SqlMerkleStateBuildError,
SqlMerkleStateBuilder,
};
impl SqlMerkleStateBuilder<backend::SqliteBackend> {
    /// Consumes the builder and produces a `SqlMerkleState` bound to a SQLite backend.
    ///
    /// When tree creation was requested on the builder, the named tree is fetched or created
    /// using the hex-encoded hash of a default `Node` as its initial state root. Otherwise the
    /// tree is looked up by name and must already exist.
    ///
    /// # Errors
    ///
    /// Returns a `SqlMerkleStateBuildError` if:
    ///
    /// * no backend was provided,
    /// * no tree name was provided,
    /// * encoding/hashing the default node fails,
    /// * a store operation fails, or
    /// * creation was not requested and no tree with the given name exists.
    pub fn build(self) -> Result<SqlMerkleState<backend::SqliteBackend>, SqlMerkleStateBuildError> {
        let backend = match self.backend {
            Some(backend) => backend,
            None => {
                return Err(InvalidStateError::with_message("must provide a backend".into()).into())
            }
        };

        let tree_name = match self.tree_name {
            Some(name) => name,
            None => {
                return Err(
                    InvalidStateError::with_message("must provide a tree name".into()).into(),
                )
            }
        };

        let store = SqlMerkleRadixStore::new(&backend);

        // The hash is computed up front, regardless of which branch below is taken.
        let (initial_state_root_hash, _) = encode_and_hash(Node::default())?;

        let tree_id: i64 = if self.create_tree {
            let initial_root_hex = hex::encode(&initial_state_root_hash);
            store.get_or_create_tree(&tree_name, &initial_root_hex)?
        } else {
            match store.get_tree_id_by_name(&tree_name)? {
                Some(existing_id) => existing_id,
                None => {
                    return Err(InvalidStateError::with_message(
                        "must provide the name of an existing tree".into(),
                    )
                    .into())
                }
            }
        };

        Ok(SqlMerkleState { backend, tree_id })
    }
}
impl SqlMerkleState<backend::SqliteBackend> {
    /// Deletes the tree identified by this state's tree ID from the backing store.
    ///
    /// The state instance is consumed, so it cannot be used after its tree has been removed.
    ///
    /// # Errors
    ///
    /// Returns an `InternalError` if the underlying store fails to delete the tree.
    pub fn delete_tree(self) -> Result<(), InternalError> {
        SqlMerkleRadixStore::new(&self.backend).delete_tree(self.tree_id)?;

        Ok(())
    }
}
impl Write for SqlMerkleState<backend::SqliteBackend> {
    type StateId = String;
    type Key = String;
    type Value = Vec<u8>;

    /// Applies `state_changes` on top of the state identified by `state_id` and writes the
    /// resulting updates through the overlay, returning the new state ID.
    ///
    /// # Errors
    ///
    /// Returns `StateWriteError::StorageError` if generating or writing the updates fails.
    fn commit(
        &self,
        state_id: &Self::StateId,
        state_changes: &[StateChange],
    ) -> Result<Self::StateId, StateWriteError> {
        let store = SqlMerkleRadixStore::new(&self.backend);
        let overlay = MerkleRadixOverlay::new(self.tree_id, state_id.as_str(), store);

        let (next_state_id, tree_update) = match overlay.generate_updates(state_changes) {
            Ok(generated) => generated,
            Err(err) => return Err(StateWriteError::StorageError(Box::new(err))),
        };

        if let Err(err) = overlay.write_updates(&next_state_id, tree_update) {
            return Err(StateWriteError::StorageError(Box::new(err)));
        }

        Ok(next_state_id)
    }

    /// Computes the state ID that applying `state_changes` to `state_id` would produce.
    ///
    /// The updates are generated but never written; the generated tree update is discarded.
    ///
    /// # Errors
    ///
    /// Returns `StateWriteError::StorageError` if generating the updates fails.
    fn compute_state_id(
        &self,
        state_id: &Self::StateId,
        state_changes: &[StateChange],
    ) -> Result<Self::StateId, StateWriteError> {
        let store = SqlMerkleRadixStore::new(&self.backend);
        let overlay = MerkleRadixOverlay::new(self.tree_id, state_id.as_str(), store);

        match overlay.generate_updates(state_changes) {
            Ok((next_state_id, _)) => Ok(next_state_id),
            Err(err) => Err(StateWriteError::StorageError(Box::new(err))),
        }
    }
}
impl Read for SqlMerkleState<backend::SqliteBackend> {
    type StateId = String;
    type Key = String;
    type Value = Vec<u8>;

    /// Fetches the entries for `keys` as of the state identified by `state_id`.
    ///
    /// # Errors
    ///
    /// * `StateReadError::InvalidStateId` if the overlay reports that the root does not exist.
    /// * `StateReadError::StorageError` if checking the root or reading the entries fails.
    fn get(
        &self,
        state_id: &Self::StateId,
        keys: &[Self::Key],
    ) -> Result<HashMap<Self::Key, Self::Value>, StateReadError> {
        let store = SqlMerkleRadixStore::new(&self.backend);
        let overlay = MerkleRadixOverlay::new(self.tree_id, state_id.as_str(), store);

        // The root is verified first so that an unknown state ID is reported as such rather
        // than being passed on to the entry lookup.
        match overlay.has_root() {
            Ok(true) => overlay
                .get_entries(keys)
                .map_err(|err| StateReadError::StorageError(Box::new(err))),
            Ok(false) => Err(StateReadError::InvalidStateId(state_id.into())),
            Err(err) => Err(StateReadError::StorageError(Box::new(err))),
        }
    }

    /// Returns a boxed clone of this reader as a `Read` trait object.
    fn clone_box(
        &self,
    ) -> Box<dyn Read<StateId = Self::StateId, Key = Self::Key, Value = Self::Value>> {
        let duplicate = self.clone();
        Box::new(duplicate)
    }
}
impl Prune for SqlMerkleState<backend::SqliteBackend> {
    type StateId = String;
    type Key = String;
    type Value = Vec<u8>;

    /// Prunes the given state IDs from this tree, returning the keys reported by the pruner.
    ///
    /// # Errors
    ///
    /// Returns `StatePruneError::StorageError` if the underlying prune operation fails.
    fn prune(&self, state_ids: Vec<Self::StateId>) -> Result<Vec<Self::Key>, StatePruneError> {
        let store = SqlMerkleRadixStore::new(&self.backend);
        let pruner = MerkleRadixPruner::new(self.tree_id, store);

        pruner
            .prune(&state_ids)
            .map_err(|err| StatePruneError::StorageError(Box::new(err)))
    }
}
/// Shorthand for a `Result` whose error type is `MerkleRadixLeafReadError`.
type IterResult<T> = Result<T, MerkleRadixLeafReadError>;
/// A boxed iterator whose items are `IterResult<T>` values.
type LeafIter<T> = Box<dyn Iterator<Item = IterResult<T>>>;
impl MerkleRadixLeafReader for SqlMerkleState<backend::SqliteBackend> {
    /// Returns an iterator over the leaf entries listed by the store for `state_id`,
    /// restricted by the optional `subtree` argument.
    ///
    /// If `state_id` equals the initial state root hash, an empty iterator is returned without
    /// querying the store. Otherwise all entries are listed up front and then yielded as
    /// `Ok` items.
    ///
    /// # Errors
    ///
    /// Returns a `MerkleRadixLeafReadError` if the initial state root hash cannot be computed
    /// or if listing the entries fails.
    fn leaves(
        &self,
        state_id: &Self::StateId,
        subtree: Option<&str>,
    ) -> IterResult<LeafIter<(Self::Key, Self::Value)>> {
        let initial_root = self.initial_state_root_hash()?;
        if *state_id == initial_root {
            return Ok(Box::new(std::iter::empty()));
        }

        let store = SqlMerkleRadixStore::new(&self.backend);
        let entries = store.list_entries(self.tree_id, state_id, subtree)?;

        Ok(Box::new(entries.into_iter().map(Ok)))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    use crate::state::merkle::sql::backend::SqliteBackendBuilder;
    use crate::state::merkle::sql::migration::MigrationManager;

    /// Verifies that trees sharing one backend are isolated from each other:
    ///
    /// 1. Create tree "test-1" and commit a single value to it.
    /// 2. Verify the value can be read from "test-1" at the new root.
    /// 3. Create tree "test-2" on the same backend.
    /// 4. Verify that reading the "test-1" root through "test-2" is an error.
    #[test]
    fn test_multiple_trees() -> Result<(), Box<dyn std::error::Error>> {
        let backend = SqliteBackendBuilder::new().with_memory_database().build()?;
        backend.run_migrations()?;

        let tree_1 = SqlMerkleStateBuilder::new()
            .with_backend(backend.clone())
            .with_tree("test-1")
            .create_tree_if_necessary()
            .build()?;

        let initial_state_root_hash = tree_1.initial_state_root_hash()?;

        let state_change_set = StateChange::Set {
            key: "1234".to_string(),
            value: b"state_value".to_vec(),
        };

        // Propagate failures with `?`, as the other Result-returning tests do, instead of
        // panicking via `unwrap`.
        let new_root = tree_1.commit(&initial_state_root_hash, &[state_change_set])?;
        assert_read_value_at_address(&tree_1, &new_root, "1234", Some("state_value"));

        let tree_2 = SqlMerkleStateBuilder::new()
            .with_backend(backend)
            .with_tree("test-2")
            .create_tree_if_necessary()
            .build()?;

        assert!(tree_2.get(&new_root, &["1234".to_string()]).is_err());

        Ok(())
    }

    /// Verifies that building a state for a tree that does not exist fails unless the builder
    /// was explicitly asked to create it.
    #[test]
    fn test_build_fails_without_explicit_create() -> Result<(), Box<dyn std::error::Error>> {
        let backend = SqliteBackendBuilder::new().with_memory_database().build()?;
        backend.run_migrations()?;

        // This is the last use of `backend`, so it is moved rather than cloned.
        assert!(SqlMerkleStateBuilder::new()
            .with_backend(backend)
            .with_tree("test-1")
            .build()
            .is_err());

        Ok(())
    }

    /// Verifies tree deletion:
    ///
    /// 1. Create tree "test-1", commit a value, and verify it can be read.
    /// 2. Drop the state, rebuild it without requesting creation, and verify the value is
    ///    still readable.
    /// 3. Delete the tree.
    /// 4. Verify that rebuilding without requesting creation now fails.
    #[test]
    fn test_delete_tree() -> Result<(), Box<dyn std::error::Error>> {
        let backend = SqliteBackendBuilder::new().with_memory_database().build()?;
        backend.run_migrations()?;

        let state = SqlMerkleStateBuilder::new()
            .with_backend(backend.clone())
            .with_tree("test-1")
            .create_tree_if_necessary()
            .build()?;

        let initial_state_root_hash = state.initial_state_root_hash()?;

        let state_change_set = StateChange::Set {
            key: "1234".to_string(),
            value: b"state_value".to_vec(),
        };

        let new_root = state.commit(&initial_state_root_hash, &[state_change_set])?;
        assert_read_value_at_address(&state, &new_root, "1234", Some("state_value"));

        drop(state);

        let state = SqlMerkleStateBuilder::new()
            .with_backend(backend.clone())
            .with_tree("test-1")
            .build()?;

        assert_read_value_at_address(&state, &new_root, "1234", Some("state_value"));

        state.delete_tree()?;

        assert!(
            SqlMerkleStateBuilder::new()
                .with_backend(backend)
                .with_tree("test-1")
                .build()
                .is_err(),
            "The tree should no longer exist"
        );

        Ok(())
    }

    /// Asserts that reading `address` at `root_hash` yields `expected_value`, where `None`
    /// means no entry was returned for the address.
    ///
    /// Panics if the read itself fails or if the stored bytes are not valid UTF-8.
    fn assert_read_value_at_address<R>(
        merkle_read: &R,
        root_hash: &str,
        address: &str,
        expected_value: Option<&str>,
    ) where
        R: Read<StateId = String, Key = String, Value = Vec<u8>>,
    {
        let value = merkle_read
            .get(&root_hash.to_string(), &[address.to_string()])
            .map(|mut values| {
                values.remove(address).map(|value| {
                    String::from_utf8(value).expect("could not convert bytes to string")
                })
            });

        match value {
            Ok(value) => assert_eq!(expected_value, value.as_deref()),
            Err(err) => panic!("value at address {} produced an error: {}", address, err),
        }
    }
}