use alloc::{
    collections::{btree_map::Entry, BTreeMap},
    string::ToString,
    vec::Vec,
};
use miden_crypto::EMPTY_WORD;
use super::{
    AccountDeltaError, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
    Word,
};
use crate::Digest;
/// Slot index that a storage delta may never touch, either as a value update or as a map update
/// (the highest possible `u8` index, 255).
const IMMUTABLE_STORAGE_SLOT: u8 = 255;
/// Changes to an account's storage, split into value-slot updates and storage-map updates.
///
/// `new` and `merge` check that no update targets `IMMUTABLE_STORAGE_SLOT` and that no slot
/// index has both a value update and a map update.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct AccountStorageDelta {
    /// New values for value slots, keyed by slot index. A value equal to `EMPTY_WORD` is
    /// serialized as a "cleared" slot.
    slots: BTreeMap<u8, Word>,
    /// Key/value changes for storage maps, keyed by the slot index holding the map.
    maps: BTreeMap<u8, StorageMapDelta>,
}
impl AccountStorageDelta {
    /// Creates a storage delta from value-slot updates and storage-map updates.
    ///
    /// # Errors
    /// Returns an error if either collection contains `IMMUTABLE_STORAGE_SLOT`, or if the same
    /// slot index appears both as a value update and as a map update.
    pub fn new(
        slots: BTreeMap<u8, Word>,
        maps: BTreeMap<u8, StorageMapDelta>,
    ) -> Result<Self, AccountDeltaError> {
        let delta = Self { slots, maps };
        delta.validate().map(|()| delta)
    }

    /// Returns the value-slot updates, keyed by slot index.
    pub fn slots(&self) -> &BTreeMap<u8, Word> {
        &self.slots
    }

    /// Returns the storage-map updates, keyed by slot index.
    pub fn maps(&self) -> &BTreeMap<u8, StorageMapDelta> {
        &self.maps
    }

    /// Returns `true` when there are no value-slot updates and no storage-map updates.
    ///
    /// A map update whose [`StorageMapDelta`] is itself empty still counts as an update.
    pub fn is_empty(&self) -> bool {
        let no_slot_updates = self.slots.is_empty();
        let no_map_updates = self.maps.is_empty();
        no_slot_updates && no_map_updates
    }

    /// Records `new_slot_value` as the new value of `slot_index`, replacing any earlier update.
    ///
    /// No validation happens here: the index is not checked against `IMMUTABLE_STORAGE_SLOT`
    /// or against existing map updates for the same slot.
    pub fn set_item(&mut self, slot_index: u8, new_slot_value: Word) {
        let _ = self.slots.insert(slot_index, new_slot_value);
    }

    /// Records `new_value` for `key` in the map update of `slot_index`, creating an empty map
    /// update for that slot first if none exists. Like `set_item`, this does not validate.
    pub fn set_map_item(&mut self, slot_index: u8, key: Digest, new_value: Word) {
        let map_delta = self.maps.entry(slot_index).or_default();
        map_delta.insert(key, new_value);
    }

    /// Merges `other` into `self`; entries from `other` take precedence.
    ///
    /// Value-slot updates from `other` overwrite those in `self`. Map updates for a slot already
    /// present in `self` are merged key-by-key via [`StorageMapDelta::merge`].
    ///
    /// # Errors
    /// Returns an error if the merged delta fails validation. The merge is applied before
    /// validation, so on error `self` is left holding the merged contents.
    pub fn merge(&mut self, other: Self) -> Result<(), AccountDeltaError> {
        let Self { slots: incoming_slots, maps: incoming_maps } = other;

        self.slots.extend(incoming_slots);

        for (slot, incoming) in incoming_maps {
            match self.maps.entry(slot) {
                Entry::Occupied(mut existing) => existing.get_mut().merge(incoming),
                Entry::Vacant(vacant) => {
                    vacant.insert(incoming);
                },
            }
        }

        self.validate()
    }

    /// Checks the delta's invariants.
    ///
    /// If either collection contains `IMMUTABLE_STORAGE_SLOT`, this fails with
    /// `ImmutableStorageSlot`. Otherwise it fails with `DuplicateStorageItemUpdate` for the
    /// lowest slot index that has both a value update and a map update.
    fn validate(&self) -> Result<(), AccountDeltaError> {
        let touches_immutable_slot = self.slots.contains_key(&IMMUTABLE_STORAGE_SLOT)
            || self.maps.contains_key(&IMMUTABLE_STORAGE_SLOT);
        if touches_immutable_slot {
            return Err(AccountDeltaError::ImmutableStorageSlot(IMMUTABLE_STORAGE_SLOT as usize));
        }

        // `maps.keys()` is ascending, so the first hit is the lowest conflicting slot.
        match self.maps.keys().find(|slot| self.slots.contains_key(*slot)) {
            Some(&slot) => Err(AccountDeltaError::DuplicateStorageItemUpdate(slot as usize)),
            None => Ok(()),
        }
    }
}
#[cfg(any(feature = "testing", test))]
impl AccountStorageDelta {
    /// Builds a delta directly from iterators, without any validation.
    ///
    /// Each slot in `cleared_items` is recorded with `EMPTY_WORD`. Entries from `updated_items`
    /// are applied afterwards, so they win over a cleared entry for the same slot.
    pub fn from_iters(
        cleared_items: impl IntoIterator<Item = u8>,
        updated_items: impl IntoIterator<Item = (u8, Word)>,
        updated_maps: impl IntoIterator<Item = (u8, StorageMapDelta)>,
    ) -> Self {
        let mut slots: BTreeMap<u8, Word> =
            cleared_items.into_iter().map(|slot| (slot, EMPTY_WORD)).collect();
        slots.extend(updated_items);

        let maps: BTreeMap<u8, StorageMapDelta> = updated_maps.into_iter().collect();

        Self { slots, maps }
    }
}
impl Serializable for AccountStorageDelta {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let cleared: Vec<u8> = self
            .slots
            .iter()
            .filter(|&(_, value)| (value == &EMPTY_WORD))
            .map(|(slot, _)| *slot)
            .collect();
        let updated: Vec<_> =
            self.slots.iter().filter(|&(_, value)| value != &EMPTY_WORD).collect();
        target.write_u8(cleared.len() as u8);
        target.write_many(cleared.iter());
        target.write_u8(updated.len() as u8);
        target.write_many(updated.iter());
        target.write_u8(self.maps.len() as u8);
        target.write_many(self.maps.iter());
    }
}
impl Deserializable for AccountStorageDelta {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mut slots = BTreeMap::new();
        let num_cleared_items = source.read_u8()? as usize;
        for _ in 0..num_cleared_items {
            let cleared_slot = source.read_u8()?;
            slots.insert(cleared_slot, EMPTY_WORD);
        }
        let num_updated_items = source.read_u8()? as usize;
        for _ in 0..num_updated_items {
            let (updated_slot, updated_value) = source.read()?;
            slots.insert(updated_slot, updated_value);
        }
        let num_maps = source.read_u8()? as usize;
        let maps = source.read_many::<(u8, StorageMapDelta)>(num_maps)?.into_iter().collect();
        Self::new(slots, maps).map_err(|err| DeserializationError::InvalidValue(err.to_string()))
    }
}
/// Changes to a single storage map, keyed by map key.
///
/// A value equal to `EMPTY_WORD` is serialized as a "cleared" key.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct StorageMapDelta(BTreeMap<Digest, Word>);
impl StorageMapDelta {
    pub fn new(map: BTreeMap<Digest, Word>) -> Self {
        Self(map)
    }
    pub fn leaves(&self) -> &BTreeMap<Digest, Word> {
        &self.0
    }
    pub fn insert(&mut self, key: Digest, value: Word) {
        self.0.insert(key, value);
    }
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
    pub fn merge(&mut self, other: Self) {
        self.0.extend(other.0);
    }
}
#[cfg(any(feature = "testing", test))]
impl StorageMapDelta {
    /// Builds a map delta from iterators.
    ///
    /// Each key in `cleared_leaves` is recorded with `EMPTY_WORD`. Entries from `updated_leaves`
    /// are applied afterwards, so they win over a cleared entry for the same key. Keys are
    /// converted from `Word` to `Digest` via `Into`.
    pub fn from_iters(
        cleared_leaves: impl IntoIterator<Item = Word>,
        updated_leaves: impl IntoIterator<Item = (Word, Word)>,
    ) -> Self {
        let mut entries: BTreeMap<Digest, Word> = cleared_leaves
            .into_iter()
            .map(|cleared_key| (cleared_key.into(), EMPTY_WORD))
            .collect();

        entries.extend(
            updated_leaves
                .into_iter()
                .map(|(updated_key, updated_value)| (updated_key.into(), updated_value)),
        );

        Self(entries)
    }
}
impl Serializable for StorageMapDelta {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let cleared: Vec<&Digest> = self
            .0
            .iter()
            .filter(|&(_, value)| value == &EMPTY_WORD)
            .map(|(key, _)| key)
            .collect();
        let updated: Vec<_> = self.0.iter().filter(|&(_, value)| value != &EMPTY_WORD).collect();
        target.write_usize(cleared.len());
        target.write_many(cleared.iter());
        target.write_usize(updated.len());
        target.write_many(updated.iter());
    }
}
impl Deserializable for StorageMapDelta {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mut map = BTreeMap::new();
        let cleared_count = source.read_usize()?;
        for _ in 0..cleared_count {
            let cleared_key = source.read()?;
            map.insert(cleared_key, EMPTY_WORD);
        }
        let updated_count = source.read_usize()?;
        for _ in 0..updated_count {
            let (updated_key, updated_value) = source.read()?;
            map.insert(updated_key, updated_value);
        }
        Ok(Self::new(map))
    }
}
#[cfg(test)]
mod tests {
    use super::{AccountStorageDelta, Deserializable, Serializable};
    use crate::{
        accounts::StorageMapDelta, testing::storage::AccountStorageDeltaBuilder, ONE, ZERO,
    };

    /// Checks `validate` and that deserialization rejects every invalid delta it is given.
    #[test]
    fn account_storage_delta_validation() {
        // A valid delta validates and round-trips through serialization.
        let delta = AccountStorageDelta::from_iters(
            [1, 2, 3],
            [(4, [ONE, ONE, ONE, ONE]), (5, [ONE, ONE, ONE, ZERO])],
            [],
        );
        assert!(delta.validate().is_ok());
        let bytes = delta.to_bytes();
        assert_eq!(AccountStorageDelta::read_from_bytes(&bytes), Ok(delta));

        // Clearing the immutable slot (255) is rejected.
        let delta = AccountStorageDelta::from_iters([1, 2, 255], [], []);
        assert!(delta.validate().is_err());
        let bytes = delta.to_bytes();
        assert!(AccountStorageDelta::read_from_bytes(&bytes).is_err());

        // Updating the immutable slot (255) is rejected.
        let delta = AccountStorageDelta::from_iters(
            [],
            [(4, [ONE, ONE, ONE, ONE]), (255, [ONE, ONE, ONE, ZERO])],
            [],
        );
        assert!(delta.validate().is_err());
        let bytes = delta.to_bytes();
        assert!(AccountStorageDelta::read_from_bytes(&bytes).is_err());

        // A slot that is both cleared and map-updated (slot 1) is rejected.
        let delta = AccountStorageDelta::from_iters(
            [1, 2, 3],
            [(2, [ONE, ONE, ONE, ONE]), (5, [ONE, ONE, ONE, ZERO])],
            [(1, StorageMapDelta::default())],
        );
        assert!(delta.validate().is_err());
        let bytes = delta.to_bytes();
        assert!(AccountStorageDelta::read_from_bytes(&bytes).is_err());

        // A slot that is both value-updated and map-updated (slot 2) is rejected.
        let delta = AccountStorageDelta::from_iters(
            [1, 3],
            [(2, [ONE, ONE, ONE, ONE]), (5, [ONE, ONE, ONE, ZERO])],
            [(2, StorageMapDelta::default())],
        );
        assert!(delta.validate().is_err());
        let bytes = delta.to_bytes();
        assert!(AccountStorageDelta::read_from_bytes(&bytes).is_err());
    }

    #[test]
    fn test_is_empty() {
        let storage_delta = AccountStorageDelta::default();
        assert!(storage_delta.is_empty());

        let storage_delta = AccountStorageDelta::from_iters([1], [], []);
        assert!(!storage_delta.is_empty());

        let storage_delta = AccountStorageDelta::from_iters([], [(2, [ONE, ONE, ONE, ONE])], []);
        assert!(!storage_delta.is_empty());

        let storage_delta =
            AccountStorageDelta::from_iters([], [], [(3, StorageMapDelta::default())]);
        assert!(!storage_delta.is_empty());
    }

    #[test]
    fn test_serde_account_storage_delta() {
        let storage_delta = AccountStorageDelta::default();
        let serialized = storage_delta.to_bytes();
        let deserialized = AccountStorageDelta::read_from_bytes(&serialized).unwrap();
        assert_eq!(deserialized, storage_delta);

        let storage_delta = AccountStorageDelta::from_iters([1], [], []);
        let serialized = storage_delta.to_bytes();
        let deserialized = AccountStorageDelta::read_from_bytes(&serialized).unwrap();
        assert_eq!(deserialized, storage_delta);

        let storage_delta = AccountStorageDelta::from_iters([], [(2, [ONE, ONE, ONE, ONE])], []);
        let serialized = storage_delta.to_bytes();
        let deserialized = AccountStorageDelta::read_from_bytes(&serialized).unwrap();
        assert_eq!(deserialized, storage_delta);

        let storage_delta =
            AccountStorageDelta::from_iters([], [], [(3, StorageMapDelta::default())]);
        let serialized = storage_delta.to_bytes();
        let deserialized = AccountStorageDelta::read_from_bytes(&serialized).unwrap();
        assert_eq!(deserialized, storage_delta);
    }

    #[test]
    fn test_serde_storage_map_delta() {
        let storage_map_delta = StorageMapDelta::default();
        let serialized = storage_map_delta.to_bytes();
        let deserialized = StorageMapDelta::read_from_bytes(&serialized).unwrap();
        assert_eq!(deserialized, storage_map_delta);

        let storage_map_delta = StorageMapDelta::from_iters([[ONE, ONE, ONE, ONE]], []);
        let serialized = storage_map_delta.to_bytes();
        let deserialized = StorageMapDelta::read_from_bytes(&serialized).unwrap();
        assert_eq!(deserialized, storage_map_delta);

        let storage_map_delta =
            StorageMapDelta::from_iters([], [([ZERO, ZERO, ZERO, ZERO], [ONE, ONE, ONE, ONE])]);
        let serialized = storage_map_delta.to_bytes();
        let deserialized = StorageMapDelta::read_from_bytes(&serialized).unwrap();
        assert_eq!(deserialized, storage_map_delta);
    }

    /// Merging value-slot updates: the later delta's state for a slot (set or cleared) wins.
    #[rstest::rstest]
    #[case::some_some(Some(1), Some(2), Some(2))]
    #[case::none_some(None, Some(2), Some(2))]
    #[case::some_none(Some(1), None, None)]
    #[test]
    fn merge_items(#[case] x: Option<u64>, #[case] y: Option<u64>, #[case] expected: Option<u64>) {
        /// Builds a delta that sets slot 123 to `item`, or clears it when `item` is `None`.
        fn create_delta(item: Option<u64>) -> AccountStorageDelta {
            const SLOT: u8 = 123;
            let item = item.map(|x| (SLOT, [vm_core::Felt::new(x), ZERO, ZERO, ZERO]));

            AccountStorageDeltaBuilder::default()
                .add_cleared_items(item.is_none().then_some(SLOT))
                .add_updated_items(item)
                .build()
                .unwrap()
        }

        let mut delta_x = create_delta(x);
        let delta_y = create_delta(y);
        let expected = create_delta(expected);

        delta_x.merge(delta_y).unwrap();

        assert_eq!(delta_x, expected);
    }

    /// Merging map updates: the later delta's state for a key (set or cleared) wins.
    #[rstest::rstest]
    #[case::some_some(Some(1), Some(2), Some(2))]
    #[case::none_some(None, Some(2), Some(2))]
    #[case::some_none(Some(1), None, None)]
    #[test]
    fn merge_maps(#[case] x: Option<u64>, #[case] y: Option<u64>, #[case] expected: Option<u64>) {
        /// Builds a map delta that sets a fixed key to `value`, or clears it when `None`.
        fn create_delta(value: Option<u64>) -> StorageMapDelta {
            let key = [vm_core::Felt::new(10), ZERO, ZERO, ZERO];
            match value {
                Some(value) => StorageMapDelta::from_iters(
                    [],
                    [(key, [vm_core::Felt::new(value), ZERO, ZERO, ZERO])],
                ),
                None => StorageMapDelta::from_iters([key], []),
            }
        }

        let mut delta_x = create_delta(x);
        let delta_y = create_delta(y);
        let expected = create_delta(expected);

        delta_x.merge(delta_y);

        assert_eq!(delta_x, expected);
    }
}