use serde::{Serialize, de::DeserializeOwned};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, VecDeque},
error::Error,
fmt::Debug,
hash::Hash,
sync::Arc,
};
use crate::{Diff, remote::Remote};
/// Bounds that every key type of a [`StateMap`] has to satisfy.
///
/// Keys live in a `HashMap`, are sorted and exposed as raw bytes when the
/// state hash is computed, and must be serde-serializable.
pub trait StateMapKey:
    Hash + Eq + Ord + Clone + AsRef<[u8]> + Serialize + DeserializeOwned + Send + 'static
{
}

/// Blanket implementation: any type meeting the bounds above is a key.
impl<T> StateMapKey for T where
    T: Hash + Eq + Ord + Clone + AsRef<[u8]> + Serialize + DeserializeOwned + Send + 'static
{
}
/// Bounds that every value type of a [`StateMap`] has to satisfy.
///
/// Values are exposed as raw bytes when the state hash is computed and must be
/// serde-serializable.
pub trait StateMapValue: Clone + AsRef<[u8]> + Serialize + DeserializeOwned + Send + 'static {}

/// Blanket implementation: any type meeting the bounds above is a value.
impl<T> StateMapValue for T where
    T: Clone + AsRef<[u8]> + Serialize + DeserializeOwned + Send + 'static
{
}
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum StateMapError {
#[error(
"The StateMap is frozen, modifications to keys is not possible, only the data can be mutated using updates"
)]
Frozen,
#[error(
"The StateMap is not frozen, Please freeze the StateMap before doing any post init operations."
)]
Unfrozen,
#[error("Error while trying to fetch updates: {0}")]
FetchUpdateError(String),
#[error("Key not found in the statemap")]
KeyNotFound,
#[error("Empty diff, no items in diff")]
EmptyDiff,
#[error("Unkown error")]
UnknownError(String),
}
impl From<Box<dyn Error>> for StateMapError {
fn from(value: Box<dyn Error>) -> Self {
Self::UnknownError(value.to_string())
}
}
/// A key/value map whose key set is locked by [`StateMap::freeze`] and whose
/// values are then kept in sync through a [`Remote`].
pub struct StateMap<K, T> {
    /// Current value for every key.
    data: HashMap<K, T>,
    /// Keys in the order they were pushed; sorted in place when the hash is computed.
    order: Vec<K>,
    /// Id of the latest update applied to `data`.
    update_id: u64,
    /// SHA-256 over the sorted key/value pairs; `Some` means the map is frozen.
    hash: Option<[u8; 32]>,
    /// Backend used to initialise on freeze and to fetch updates.
    remote: Arc<dyn Remote<K, T>>,
    /// When true, local writes remain allowed after freezing.
    is_master: bool,
    /// History of recorded diffs that `get_diff` serves ranges from.
    diffs: VecDeque<Diff<K, T>>,
}
impl<K, T> StateMap<K, T> {
    /// Creates an empty, unfrozen, non-master map backed by `remote`.
    pub fn new(remote: Arc<dyn Remote<K, T>>) -> Self {
        Self {
            remote,
            is_master: false,
            update_id: 0,
            hash: None,
            data: HashMap::new(),
            order: Vec::new(),
            diffs: VecDeque::new(),
        }
    }

    /// Id of the latest update applied to this map.
    pub fn get_update_id(&self) -> u64 {
        let current = self.update_id;
        current
    }

    /// Whether [`StateMap::freeze`] has stored a hash.
    pub fn is_frozen(&self) -> bool {
        matches!(self.hash, Some(_))
    }

    /// Marks this map as master (or not); only allowed before freezing.
    ///
    /// # Errors
    ///
    /// Returns [`StateMapError::Frozen`] if the map is already frozen.
    pub fn set_master(&mut self, is_master: bool) -> Result<(), StateMapError> {
        if self.is_frozen() {
            return Err(StateMapError::Frozen);
        }
        self.is_master = is_master;
        Ok(())
    }

    /// The hash stored at freeze time, or `None` while unfrozen.
    pub fn hash(&self) -> Option<&[u8; 32]> {
        self.hash.as_ref()
    }
}
impl<K: StateMapKey, T: StateMapValue> StateMap<K, T> {
    /// Adds `key` with `value` while the map is still being built.
    ///
    /// Pushing a key that is already present replaces its value. The key is
    /// recorded in the insertion order only once, so it contributes a single
    /// entry to the hash computed by [`StateMap::freeze`].
    ///
    /// # Errors
    ///
    /// Returns [`StateMapError::Frozen`] once the map has been frozen.
    pub fn push(&mut self, key: K, value: T) -> Result<(), StateMapError> {
        if self.is_frozen() {
            return Err(StateMapError::Frozen);
        }
        // `insert` hands back the previous value for an existing key. Only a
        // brand-new key is appended to `order`; appending a duplicate would
        // make `calculate_hash` feed that key into the digest twice.
        if self.data.insert(key.clone(), value).is_none() {
            self.order.push(key);
        }
        Ok(())
    }

    /// Computes a SHA-256 digest over every key/value pair in ascending key order.
    ///
    /// Each key and each value is prefixed with its byte length as a big-endian
    /// `u64`, so the boundaries between entries are unambiguous. Sorts `order`
    /// in place as a side effect.
    ///
    /// # Panics
    ///
    /// Panics if a key in `order` has no entry in `data` (an internal invariant).
    fn calculate_hash(&mut self) -> [u8; 32] {
        self.order.sort();
        let mut hasher = Sha256::new();
        for key in &self.order {
            let value = self
                .data
                .get(key)
                .expect("No value found for a Key, this should never happen");
            hasher.update((key.as_ref().len() as u64).to_be_bytes());
            hasher.update(key);
            hasher.update((value.as_ref().len() as u64).to_be_bytes());
            hasher.update(value);
        }
        hasher.finalize().into()
    }

    /// Locks the key set and initialises the remote.
    ///
    /// Stores the state hash, appends a full diff of the current state to the
    /// history, then calls the remote's `init`. Calling this on an already
    /// frozen map does nothing. If `init` fails, the stored hash and the
    /// appended diff are rolled back, so the map stays unfrozen and `freeze`
    /// can be retried.
    ///
    /// # Errors
    ///
    /// Propagates the remote `init` error, converted into [`StateMapError`].
    pub fn freeze(&mut self) -> Result<(), StateMapError> {
        if self.is_frozen() {
            return Ok(());
        }
        self.hash = Some(self.calculate_hash());
        self.diffs.push_back(Diff::full_diff(self));
        let init_result = self.remote.init(self).map_err(StateMapError::from);
        if let Err(err) = init_result {
            // Without this rollback the map would report itself frozen against a
            // remote that was never initialised, and a retry would silently skip
            // `init`.
            self.diffs.pop_back();
            self.hash = None;
            return Err(err);
        }
        Ok(())
    }

    /// Advances this map to `update_id` by fetching and applying remote updates.
    ///
    /// Does nothing if `update_id` is not ahead of the current update id. After
    /// a successful fetch, the update id becomes the diff's `upto_update_id`.
    ///
    /// # Errors
    ///
    /// * [`StateMapError::Unfrozen`] if the map has not been frozen.
    /// * [`StateMapError::FetchUpdateError`] if the fetch fails, or if the diff
    ///   starts after the current update id.
    /// * [`StateMapError::EmptyDiff`] if the fetched diff has no items.
    pub fn set_update_id(&mut self, update_id: u64) -> Result<(), StateMapError> {
        if update_id <= self.update_id {
            return Ok(());
        }
        let Some(hash) = self.hash else {
            return Err(StateMapError::Unfrozen);
        };
        let diff = self
            .remote
            .fetch_updates(self.update_id, update_id, &hash)
            .map_err(|e| StateMapError::FetchUpdateError(e.to_string()))?;
        if diff.is_empty() {
            return Err(StateMapError::EmptyDiff);
        }
        // A diff starting after our current id would leave a gap in the history.
        if diff.from_update_id() > self.update_id {
            return Err(StateMapError::FetchUpdateError(format!(
                "The from update id is not consistent, diff from_update_id: {}, self update_id: {}",
                diff.from_update_id(),
                self.update_id
            )));
        }
        self.update_id = diff.upto_update_id();
        for (k, v) in diff.get_diff() {
            self.data.insert(k.clone(), v.clone());
        }
        Ok(())
    }

    /// Returns the current value for `key`, if present.
    pub fn get(&self, key: &K) -> Option<&T> {
        self.data.get(key)
    }

    /// Overwrites the value of an existing key and records a one-entry diff.
    ///
    /// Allowed while unfrozen, or at any time on a master. Each successful call
    /// bumps the update id by one.
    ///
    /// # Errors
    ///
    /// * [`StateMapError::KeyNotFound`] if `key` is not in the map (checked first).
    /// * [`StateMapError::Frozen`] if the map is frozen and not master.
    pub fn set(&mut self, key: K, value: T) -> Result<(), StateMapError> {
        let writable = !self.is_frozen() || self.is_master;
        // A single lookup both proves the key exists and gives the slot to write.
        let slot = self
            .data
            .get_mut(&key)
            .ok_or(StateMapError::KeyNotFound)?;
        if !writable {
            return Err(StateMapError::Frozen);
        }
        *slot = value.clone();
        let from_update_id = self.update_id;
        self.update_id += 1;
        self.diffs.push_back(Diff::new(
            [(key, value)],
            false,
            from_update_id,
            self.update_id,
        ));
        Ok(())
    }

    /// Starts a batch of value changes that are applied together on commit.
    ///
    /// # Errors
    ///
    /// Returns [`StateMapError::Frozen`] if the map is frozen and not master.
    pub fn begin_modification(&mut self) -> Result<StateMapModifier<'_, K, T>, StateMapError> {
        if !self.is_frozen() || self.is_master {
            Ok(StateMapModifier::new(self))
        } else {
            Err(StateMapError::Frozen)
        }
    }

    /// Returns the changes between `from_update_id` and `upto_update_id`, served
    /// from the recorded diff history.
    ///
    /// * An empty or inverted range (`upto_update_id <= from_update_id`) yields
    ///   an empty diff.
    /// * If the history is empty or cannot serve a range that starts exactly at
    ///   `from_update_id`, a full diff of the current state is returned.
    /// * If the oldest recorded diff already ends after `upto_update_id`, an
    ///   empty diff is returned.
    /// * Otherwise, the diffs from the one starting at `from_update_id` through
    ///   the newest one ending at or before `upto_update_id` are merged.
    pub fn get_diff(&self, from_update_id: u64, upto_update_id: u64) -> Diff<K, T> {
        if upto_update_id <= from_update_id {
            return Diff::empty();
        }
        let (Some(first), Some(last)) = (self.diffs.front(), self.diffs.back()) else {
            // No history recorded yet; the old code panicked here.
            return Diff::full_diff(self);
        };

        let mut start = 0;
        for (idx, diff) in self.diffs.iter().enumerate() {
            if diff.from_update_id() == from_update_id {
                start = idx;
                break;
            }
            if diff.from_update_id() > from_update_id {
                return Diff::full_diff(self);
            }
        }
        if last.from_update_id() < from_update_id {
            return Diff::full_diff(self);
        }

        let mut end = None;
        for (idx, diff) in self.diffs.iter().enumerate().rev() {
            if diff.upto_update_id() == upto_update_id {
                end = Some(idx);
                break;
            }
            if diff.upto_update_id() < upto_update_id {
                return self.merge_span(start, idx);
            }
        }
        if first.upto_update_id() > upto_update_id {
            return Diff::empty();
        }
        match end {
            Some(end) => self.merge_span(start, end),
            None => Diff::empty(),
        }
    }

    /// Merges the recorded diffs at indices `start..=end`.
    ///
    /// An inverted span yields an empty diff instead of panicking inside
    /// `VecDeque::range`.
    fn merge_span(&self, start: usize, end: usize) -> Diff<K, T> {
        if end < start {
            return Diff::empty();
        }
        Diff::merge(self.diffs.range(start..=end).cloned())
    }
}
impl<'a, K, T> IntoIterator for &'a StateMap<K, T> {
type Item = (&'a K, &'a T);
type IntoIter = std::collections::hash_map::Iter<'a, K, T>;
fn into_iter(self) -> Self::IntoIter {
self.data.iter()
}
}
/// Consuming iteration over `(key, value)` pairs; all other state is dropped.
impl<K, T> IntoIterator for StateMap<K, T> {
    type Item = (K, T);
    type IntoIter = std::collections::hash_map::IntoIter<K, T>;

    fn into_iter(self) -> Self::IntoIter {
        let Self { data, .. } = self;
        data.into_iter()
    }
}
/// Debug output showing data, update id, hex-encoded hash, master flag and
/// diff history (`order` and `remote` are omitted).
impl<K: Debug, T: Debug> Debug for StateMap<K, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hash_hex = self.hash.as_ref().map(hex::encode);
        let mut out = f.debug_struct("StateMap");
        out.field("data", &self.data);
        out.field("update_id", &self.update_id);
        out.field("hash", &hash_hex);
        out.field("is_master", &self.is_master);
        out.field("diffs", &self.diffs);
        out.finish()
    }
}
/// A batch of pending value changes against a [`StateMap`].
///
/// Changes become visible in the map only when the batch is committed.
pub struct StateMapModifier<'a, K, T> {
    /// The map the batch will be applied to.
    statemap: &'a mut StateMap<K, T>,
    /// Staged new values, keyed by (already existing) map key.
    diff: HashMap<K, T>,
}
impl<'a, K: StateMapKey, T: StateMapValue> StateMapModifier<'a, K, T> {
    /// Starts an empty batch against `statemap`.
    fn new(statemap: &'a mut StateMap<K, T>) -> Self {
        Self {
            statemap,
            diff: HashMap::new(),
        }
    }

    /// Returns the staged value for `key` if this batch set one; otherwise
    /// returns the map's current value.
    pub fn get(&self, key: &K) -> Option<&T> {
        self.diff.get(key).or_else(|| self.statemap.get(key))
    }

    /// Stages `value` for an existing `key`, replacing any earlier staged value.
    ///
    /// Returns `self` so that calls can be chained.
    ///
    /// # Errors
    ///
    /// Returns [`StateMapError::KeyNotFound`] if `key` is not in the map.
    pub fn set(&mut self, key: K, value: T) -> Result<&mut Self, StateMapError> {
        if !self.statemap.data.contains_key(&key) {
            return Err(StateMapError::KeyNotFound);
        }
        self.diff.insert(key, value);
        Ok(self)
    }

    /// Applies all staged values to the map and records them as one diff.
    ///
    /// Bumps the map's update id by exactly one. An empty batch is a no-op:
    /// recording an empty diff would give replicas an update id whose fetch
    /// fails with [`StateMapError::EmptyDiff`].
    ///
    /// # Panics
    ///
    /// Panics if a staged key is missing from the map. `set` only stages keys
    /// that exist, so this is an internal invariant.
    pub fn commit(self) {
        if self.diff.is_empty() {
            return;
        }
        for (key, value) in &self.diff {
            // Overwrite in place; no key clone is needed because the key must exist.
            let slot = self
                .statemap
                .data
                .get_mut(key)
                .expect("Ensure no new keys are added");
            *slot = value.clone();
        }
        let from_update_id = self.statemap.update_id;
        self.statemap.update_id += 1;
        self.statemap.diffs.push_back(Diff::new(
            self.diff,
            false,
            from_update_id,
            self.statemap.update_id,
        ));
    }
}