use crate::{
core::errors::state_errors::StateError,
services::api::contract_classes::compiled_class::CompiledClass,
state::{
cached_state::UNINITIALIZED_CLASS_HASH, state_api::StateReader, state_cache::StorageEntry,
},
utils::{Address, ClassHash, CompiledClassHash},
};
use cairo_vm::felt::Felt252;
use getset::{Getters, MutGetters};
use std::collections::HashMap;
/// A [`StateReader`] whose entire state lives in in-memory [`HashMap`]s.
///
/// Every table is a public field. The struct-level `getset` attribute also generates a
/// public `<field>_mut()` accessor for each field.
#[derive(Clone, Debug, Default, Eq, PartialEq, Getters, MutGetters)]
#[getset(get_mut = "pub")]
pub struct InMemoryStateReader {
    /// Class hash recorded for each contract address.
    pub address_to_class_hash: HashMap<Address, ClassHash>,
    /// Nonce recorded for each contract address.
    pub address_to_nonce: HashMap<Address, Felt252>,
    /// Storage value recorded for each storage entry.
    pub address_to_storage: HashMap<StorageEntry, Felt252>,
    /// Compiled classes. The reader looks entries up both by class hash and by
    /// compiled class hash.
    pub class_hash_to_compiled_class: HashMap<ClassHash, CompiledClass>,
    /// Compiled class hash recorded for each class hash.
    pub class_hash_to_compiled_class_hash: HashMap<ClassHash, CompiledClassHash>,
}
impl InMemoryStateReader {
pub const fn new(
address_to_class_hash: HashMap<Address, ClassHash>,
address_to_nonce: HashMap<Address, Felt252>,
address_to_storage: HashMap<StorageEntry, Felt252>,
class_hash_to_compiled_class: HashMap<ClassHash, CompiledClass>,
class_hash_to_compiled_class_hash: HashMap<ClassHash, CompiledClassHash>,
) -> Self {
Self {
address_to_class_hash,
address_to_nonce,
address_to_storage,
class_hash_to_compiled_class,
class_hash_to_compiled_class_hash,
}
}
fn get_compiled_class(
&self,
compiled_class_hash: &CompiledClassHash,
) -> Result<CompiledClass, StateError> {
match self.class_hash_to_compiled_class.get(compiled_class_hash) {
Some(compiled_class) => Ok(compiled_class.clone()),
None => Err(StateError::NoneCompiledClass(*compiled_class_hash)),
}
}
}
impl StateReader for InMemoryStateReader {
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError> {
Ok(self
.address_to_class_hash
.get(contract_address)
.cloned()
.unwrap_or_default())
}
fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError> {
Ok(self
.address_to_nonce
.get(contract_address)
.cloned()
.unwrap_or_default())
}
fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result<Felt252, StateError> {
Ok(self
.address_to_storage
.get(storage_entry)
.cloned()
.unwrap_or_default())
}
fn get_compiled_class_hash(
&self,
class_hash: &ClassHash,
) -> Result<CompiledClassHash, StateError> {
self.class_hash_to_compiled_class_hash
.get(class_hash)
.ok_or(StateError::NoneCompiledHash(*class_hash))
.copied()
}
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
if let Some(compiled_class) = self.class_hash_to_compiled_class.get(class_hash) {
return Ok(compiled_class.clone());
}
let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
if compiled_class_hash != *UNINITIALIZED_CLASS_HASH {
let compiled_class = self.get_compiled_class(&compiled_class_hash)?;
Ok(compiled_class)
} else {
Err(StateError::MissingCasmClass(compiled_class_hash))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::services::api::contract_classes::deprecated_contract_class::ContractClass;
    use num_traits::{One, Zero};
    use std::sync::Arc;

    /// Builds a reader with all-empty tables through the explicit constructor.
    fn empty_reader() -> InMemoryStateReader {
        InMemoryStateReader::new(
            HashMap::new(),
            HashMap::new(),
            HashMap::new(),
            HashMap::new(),
            HashMap::new(),
        )
    }

    #[test]
    fn get_class_hash_at_returns_zero_if_missing() {
        let reader = InMemoryStateReader::default();
        let unknown_address = Address(Felt252::one());

        let hash_bytes = reader.get_class_hash_at(&unknown_address).unwrap();

        assert!(Felt252::from_bytes_be(&hash_bytes).is_zero())
    }

    #[test]
    fn get_storage_returns_zero_if_missing() {
        let reader = InMemoryStateReader::default();
        let unknown_entry = (Address(Felt252::one()), Felt252::one().to_be_bytes());

        let value = reader.get_storage_at(&unknown_entry).unwrap();

        assert!(value.is_zero())
    }

    #[test]
    fn get_contract_state_test() {
        let contract_address = Address(37810.into());
        let class_hash = [1; 32];
        let nonce = Felt252::new(109);
        let storage_entry = (contract_address.clone(), [8; 32]);
        let storage_value = Felt252::new(800);

        let mut reader = empty_reader();
        reader
            .address_to_class_hash
            .insert(contract_address.clone(), class_hash);
        reader
            .address_to_nonce
            .insert(contract_address.clone(), nonce.clone());
        reader
            .address_to_storage
            .insert(storage_entry.clone(), storage_value.clone());

        assert_eq!(reader.get_class_hash_at(&contract_address).unwrap(), class_hash);
        assert_eq!(reader.get_nonce_at(&contract_address).unwrap(), nonce);
        assert_eq!(reader.get_storage_at(&storage_entry).unwrap(), storage_value);
    }

    #[test]
    fn get_contract_class_test() {
        let class_hash = [0; 32];
        let expected_class =
            ContractClass::from_path("starknet_programs/raw_contract_classes/class_with_abi.json")
                .unwrap();

        let mut reader = empty_reader();
        reader.class_hash_to_compiled_class.insert(
            class_hash,
            CompiledClass::Deprecated(Arc::new(expected_class.clone())),
        );

        let fetched = reader.get_contract_class(&class_hash).unwrap();
        assert_eq!(fetched.try_into(), Ok(expected_class))
    }
}