use std::sync::{Arc, Mutex, MutexGuard};
use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;
use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult, UpdatableState};
// Unit tests for this module live in `versioned_state_test.rs` and are compiled only in test
// builds.
#[cfg(test)]
#[path = "versioned_state_test.rs"]
pub mod versioned_state_test;
/// Panic message used by `VersionedState::validate_reads` when a key that was read by a
/// transaction has no value in the corresponding versioned storage.
const READ_ERR: &str = "Error: read value missing in the versioned storage";
/// Multi-versioned state layered on top of `initial_state`, with every write tagged by the
/// index of the transaction that produced it.
///
/// Values that are not found in a versioned storage are fetched from `initial_state` by the
/// `StateReader` impl of `VersionedStateProxy` and recorded via `set_initial_value`.
#[derive(Debug)]
pub struct VersionedState<S: StateReader> {
    // The underlying state this versioned view was built on.
    initial_state: S,
    // Contract storage values, keyed by (contract address, storage key).
    storage: VersionedStorage<(ContractAddress, StorageKey), Felt>,
    // Nonce of each contract address.
    nonces: VersionedStorage<ContractAddress, Nonce>,
    // Class hash associated with each contract address.
    class_hashes: VersionedStorage<ContractAddress, ClassHash>,
    // Compiled class hash associated with each class hash.
    compiled_class_hashes: VersionedStorage<ClassHash, CompiledClassHash>,
    // Whether each class hash is declared. `validate_reads` and `apply_writes` assert that this
    // agrees with the presence of an entry in `compiled_contract_classes`.
    declared_contracts: VersionedStorage<ClassHash, bool>,
    // Compiled classes by class hash. This storage is not part of `StateMaps`; its writes are
    // committed separately in `commit_chunk_and_recover_block_state`.
    compiled_contract_classes: VersionedStorage<ClassHash, RunnableCompiledClass>,
}
impl<S: StateReader> VersionedState<S> {
    /// Builds a versioned state on top of `initial_state`, with every versioned storage empty.
    pub fn new(initial_state: S) -> Self {
        Self {
            initial_state,
            storage: Default::default(),
            nonces: Default::default(),
            class_hashes: Default::default(),
            compiled_class_hashes: Default::default(),
            compiled_contract_classes: Default::default(),
            declared_contracts: Default::default(),
        }
    }

    /// Gathers, for every storage that is part of [`StateMaps`], the result of that storage's
    /// `get_writes_up_to_index(tx_index)`.
    ///
    /// Compiled contract classes are not included; callers fetch them separately.
    fn get_writes_up_to_index(&mut self, tx_index: TxIndex) -> StateMaps {
        let storage = self.storage.get_writes_up_to_index(tx_index);
        let nonces = self.nonces.get_writes_up_to_index(tx_index);
        let class_hashes = self.class_hashes.get_writes_up_to_index(tx_index);
        let compiled_class_hashes = self.compiled_class_hashes.get_writes_up_to_index(tx_index);
        let declared_contracts = self.declared_contracts.get_writes_up_to_index(tx_index);
        StateMaps { storage, nonces, class_hashes, compiled_class_hashes, declared_contracts }
    }

    /// Gathers the writes recorded by exactly transaction `tx_index` (test/testing builds only).
    #[cfg(any(feature = "testing", test))]
    pub fn get_writes_of_index(&self, tx_index: TxIndex) -> StateMaps {
        let storage = self.storage.get_writes_of_index(tx_index);
        let nonces = self.nonces.get_writes_of_index(tx_index);
        let class_hashes = self.class_hashes.get_writes_of_index(tx_index);
        let compiled_class_hashes = self.compiled_class_hashes.get_writes_of_index(tx_index);
        let declared_contracts = self.declared_contracts.get_writes_of_index(tx_index);
        StateMaps { storage, nonces, class_hashes, compiled_class_hashes, declared_contracts }
    }

    /// Checks whether the values recorded in `reads` by transaction `tx_index` are still what
    /// the versioned storages return at version `tx_index - 1`.
    ///
    /// Returns `true` immediately for `tx_index == 0`. Otherwise returns `false` at the first
    /// mismatch, checking storage, nonces, class hashes, compiled class hashes, and declared
    /// contracts, in that order.
    ///
    /// # Panics
    /// Panics with [`READ_ERR`] if a read key has no value in its versioned storage. It also
    /// panics if `declared_contracts` disagrees with the presence of a compiled class for the
    /// same class hash.
    fn validate_reads(&mut self, tx_index: TxIndex, reads: &StateMaps) -> bool {
        // The first transaction has no earlier version to compare its reads against.
        let prev = match tx_index.checked_sub(1) {
            Some(prev) => prev,
            None => return true,
        };

        // `&&` keeps the original section order and stops at the first failing section.
        reads.storage.iter().all(|(&cell, &expected)| {
            self.storage.read(prev, cell).expect(READ_ERR) == expected
        }) && reads.nonces.iter().all(|(&address, &expected)| {
            self.nonces.read(prev, address).expect(READ_ERR) == expected
        }) && reads.class_hashes.iter().all(|(&address, &expected)| {
            self.class_hashes.read(prev, address).expect(READ_ERR) == expected
        }) && reads.compiled_class_hashes.iter().all(|(&class_hash, &expected)| {
            self.compiled_class_hashes.read(prev, class_hash).expect(READ_ERR) == expected
        }) && reads.declared_contracts.iter().all(|(&class_hash, &expected)| {
            let is_declared = self.declared_contracts.read(prev, class_hash).expect(READ_ERR);
            assert_eq!(
                is_declared,
                self.compiled_contract_classes.read(prev, class_hash).is_some(),
                "The declared contracts mapping should match the compiled contract classes \
                 mapping."
            );
            is_declared == expected
        })
    }

    /// Records `writes` and `class_hash_to_class` as the writes of transaction `tx_index`.
    ///
    /// # Panics
    /// Panics if a written `declared_contracts` flag disagrees with the presence of a compiled
    /// class at the same version.
    fn apply_writes(
        &mut self,
        tx_index: TxIndex,
        writes: &StateMaps,
        class_hash_to_class: &ContractClassMapping,
    ) {
        for (&cell, &value) in writes.storage.iter() {
            self.storage.write(tx_index, cell, value);
        }
        for (&address, &nonce) in writes.nonces.iter() {
            self.nonces.write(tx_index, address, nonce);
        }
        for (&address, &class_hash) in writes.class_hashes.iter() {
            self.class_hashes.write(tx_index, address, class_hash);
        }
        for (&class_hash, &compiled_hash) in writes.compiled_class_hashes.iter() {
            self.compiled_class_hashes.write(tx_index, class_hash, compiled_hash);
        }
        // Classes must be written before `declared_contracts`: the assertion below reads them
        // back at the same version.
        for (&class_hash, class) in class_hash_to_class {
            self.compiled_contract_classes.write(tx_index, class_hash, class.clone());
        }
        for (&class_hash, &is_declared) in writes.declared_contracts.iter() {
            self.declared_contracts.write(tx_index, class_hash, is_declared);
            assert_eq!(
                is_declared,
                self.compiled_contract_classes.read(tx_index, class_hash).is_some(),
                "The declared contracts mapping should match the compiled contract classes \
                 mapping."
            );
        }
    }

    /// Removes the writes of transaction `tx_index` for every key in `writes` and every class
    /// hash in `class_hash_to_class`.
    fn delete_writes(
        &mut self,
        tx_index: TxIndex,
        writes: &StateMaps,
        class_hash_to_class: &ContractClassMapping,
    ) {
        for &cell in writes.storage.keys() {
            self.storage.delete_write(cell, tx_index);
        }
        for &address in writes.nonces.keys() {
            self.nonces.delete_write(address, tx_index);
        }
        for &address in writes.class_hashes.keys() {
            self.class_hashes.delete_write(address, tx_index);
        }
        for &class_hash in writes.compiled_class_hashes.keys() {
            self.compiled_class_hashes.delete_write(class_hash, tx_index);
        }
        for &class_hash in writes.declared_contracts.keys() {
            self.declared_contracts.delete_write(class_hash, tx_index);
        }
        for &class_hash in class_hash_to_class.keys() {
            self.compiled_contract_classes.delete_write(class_hash, tx_index);
        }
    }

    /// Consumes the versioned state and returns the underlying initial state; all versioned
    /// data is dropped.
    fn into_initial_state(self) -> S {
        let Self { initial_state, .. } = self;
        initial_state
    }
}
impl<U: UpdatableState> VersionedState<U> {
    /// Consumes the versioned state and returns the initial state with the committed writes
    /// applied.
    ///
    /// The applied writes are whatever `get_writes_up_to_index(n_committed_txs)` reports for
    /// the `StateMaps` storages and for compiled contract classes. These are presumably the
    /// writes of transactions `0..n_committed_txs`; confirm against `VersionedStorage`.
    pub fn commit_chunk_and_recover_block_state(mut self, n_committed_txs: usize) -> U {
        // Both collections must be taken before `self` is consumed below.
        let committed_writes = self.get_writes_up_to_index(n_committed_txs);
        let committed_classes =
            self.compiled_contract_classes.get_writes_up_to_index(n_committed_txs);

        let mut block_state = self.into_initial_state();
        block_state.apply_writes(&committed_writes, &committed_classes);
        block_state
    }
}
/// Error reported by versioned-state operations that cannot proceed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VersionedStateError {
    /// The versioned state is no longer available (its `Option` slot is empty, e.g. after
    /// `ThreadSafeVersionedState::into_inner_state`), so execution against it has halted.
    ExecutionHalted,
}

impl std::fmt::Display for VersionedStateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ExecutionHalted => {
                f.write_str("execution halted: the versioned state was already consumed")
            }
        }
    }
}

// Lets callers propagate this error through `Box<dyn Error>` and other error-trait-based APIs.
impl std::error::Error for VersionedStateError {}
/// Holder for a [`VersionedState`] that can be taken out; `None` means it was already consumed.
/// After that, reads fail with `StateError::StateReadError`, validation and deletion fail with
/// `VersionedStateError::ExecutionHalted`, and `apply_writes` is a no-op.
pub struct OptionalVersionedState<S: StateReader>(Option<VersionedState<S>>);
impl<S: StateReader> OptionalVersionedState<S> {
    /// Wraps a fresh [`VersionedState`] built over `state` (test/testing builds only).
    #[cfg(any(feature = "testing", test))]
    pub fn new(state: S) -> Self {
        Self(Some(VersionedState::new(state)))
    }

    /// Borrows the inner state, panicking if it was already consumed (test/testing builds only).
    #[cfg(any(feature = "testing", test))]
    pub fn inner_unwrap(&self) -> &VersionedState<S> {
        self.0.as_ref().unwrap()
    }

    /// Mutable access to the inner state for `StateReader` calls; errors if it was consumed.
    fn inner_mut(&mut self) -> StateResult<&mut VersionedState<S>> {
        match &mut self.0 {
            Some(state) => Ok(state),
            None => {
                Err(StateError::StateReadError("Versioned state was already consumed.".into()))
            }
        }
    }

    /// Mutable access to the inner state, reporting `ExecutionHalted` if it was consumed.
    fn inner_mut_or_versioned_state_error(
        &mut self,
    ) -> Result<&mut VersionedState<S>, VersionedStateError> {
        match &mut self.0 {
            Some(state) => Ok(state),
            None => Err(VersionedStateError::ExecutionHalted),
        }
    }

    /// Forwards to `VersionedState::validate_reads`, or fails with `ExecutionHalted`.
    fn validate_reads(
        &mut self,
        tx_index: TxIndex,
        reads: &StateMaps,
    ) -> Result<bool, VersionedStateError> {
        self.inner_mut_or_versioned_state_error()
            .map(|state| state.validate_reads(tx_index, reads))
    }

    /// Forwards to `VersionedState::delete_writes`, or fails with `ExecutionHalted`.
    fn delete_writes(
        &mut self,
        tx_index: TxIndex,
        writes: &StateMaps,
        class_hash_to_class: &ContractClassMapping,
    ) -> Result<(), VersionedStateError> {
        self.inner_mut_or_versioned_state_error()
            .map(|state| state.delete_writes(tx_index, writes, class_hash_to_class))
    }

    /// Forwards to `VersionedState::apply_writes`; silently does nothing once consumed.
    fn apply_writes(
        &mut self,
        tx_index: TxIndex,
        writes: &StateMaps,
        class_hash_to_class: &ContractClassMapping,
    ) {
        if let Some(state) = &mut self.0 {
            state.apply_writes(tx_index, writes, class_hash_to_class);
        }
    }
}
/// Shared handle to an [`OptionalVersionedState`] behind `Arc<Mutex<..>>`; clones of the handle
/// refer to the same underlying state.
pub struct ThreadSafeVersionedState<S: StateReader>(Arc<Mutex<OptionalVersionedState<S>>>);
/// Lock guard over the shared [`OptionalVersionedState`].
pub type LockedVersionedState<'a, S> = MutexGuard<'a, OptionalVersionedState<S>>;
impl<S: StateReader> ThreadSafeVersionedState<S> {
pub fn new(versioned_state: VersionedState<S>) -> Self {
ThreadSafeVersionedState(Mutex::new(OptionalVersionedState(Some(versioned_state))).into())
}
pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy<S> {
VersionedStateProxy { tx_index, state: self.0.clone() }
}
pub fn into_inner_state(&self) -> VersionedState<S> {
let mut opt_version_state = self.0.lock().expect("Failed to acquire state lock.");
opt_version_state.0.take().expect("Versioned state was already consumed.")
}
}
// Written by hand because `#[derive(Clone)]` would add an unnecessary `S: Clone` bound.
impl<S: StateReader> Clone for ThreadSafeVersionedState<S> {
    /// Returns another handle to the same shared state; only the `Arc` reference count changes.
    fn clone(&self) -> Self {
        let Self(shared) = self;
        Self(Arc::clone(shared))
    }
}
/// View of the shared versioned state pinned to a single transaction index.
pub struct VersionedStateProxy<S: StateReader> {
    /// Version used for every read, validation, and write made through this proxy.
    pub tx_index: TxIndex,
    /// The shared state; `ThreadSafeVersionedState::pin_version` fills this with a clone of its
    /// own `Arc`.
    pub state: Arc<Mutex<OptionalVersionedState<S>>>,
}
impl<S: StateReader> VersionedStateProxy<S> {
    /// Locks the shared state; panics if the mutex is poisoned.
    fn state(&self) -> LockedVersionedState<'_, S> {
        self.state.lock().expect("Failed to acquire state lock.")
    }

    /// Validates `reads` as made by this proxy's transaction; see
    /// `VersionedState::validate_reads`.
    pub fn validate_reads(&self, reads: &StateMaps) -> Result<bool, VersionedStateError> {
        let mut locked = self.state();
        locked.validate_reads(self.tx_index, reads)
    }

    /// Deletes this proxy's transaction writes for the keys in `writes` and
    /// `class_hash_to_class`.
    pub fn delete_writes(
        &self,
        writes: &StateMaps,
        class_hash_to_class: &ContractClassMapping,
    ) -> Result<(), VersionedStateError> {
        let mut locked = self.state();
        locked.delete_writes(self.tx_index, writes, class_hash_to_class)
    }
}
impl<S: StateReader> UpdatableState for VersionedStateProxy<S> {
    /// Records `writes` and `class_hash_to_class` as the writes of this proxy's transaction.
    /// Does nothing if the shared state has already been consumed.
    fn apply_writes(&mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping) {
        let tx_index = self.tx_index;
        let mut locked = self.state();
        locked.apply_writes(tx_index, writes, class_hash_to_class);
    }
}
impl<S: StateReader> StateReader for VersionedStateProxy<S> {
    /// Reads a storage cell at version `self.tx_index`. On a miss, the value is fetched from
    /// the initial state and recorded with `set_initial_value`.
    fn get_storage_at(
        &self,
        contract_address: ContractAddress,
        key: StorageKey,
    ) -> StateResult<Felt> {
        let cell = (contract_address, key);
        let mut locked = self.state();
        let versioned = locked.inner_mut()?;
        if let Some(value) = versioned.storage.read(self.tx_index, cell) {
            return Ok(value);
        }
        let initial_value = versioned.initial_state.get_storage_at(contract_address, key)?;
        versioned.storage.set_initial_value(cell, initial_value);
        Ok(initial_value)
    }

    /// Reads a nonce at version `self.tx_index`, falling back to (and recording) the initial
    /// state value.
    fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
        let mut locked = self.state();
        let versioned = locked.inner_mut()?;
        if let Some(nonce) = versioned.nonces.read(self.tx_index, contract_address) {
            return Ok(nonce);
        }
        let initial_nonce = versioned.initial_state.get_nonce_at(contract_address)?;
        versioned.nonces.set_initial_value(contract_address, initial_nonce);
        Ok(initial_nonce)
    }

    /// Reads a contract's class hash at version `self.tx_index`, falling back to (and
    /// recording) the initial state value.
    fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
        let mut locked = self.state();
        let versioned = locked.inner_mut()?;
        if let Some(class_hash) = versioned.class_hashes.read(self.tx_index, contract_address) {
            return Ok(class_hash);
        }
        let initial_class_hash = versioned.initial_state.get_class_hash_at(contract_address)?;
        versioned.class_hashes.set_initial_value(contract_address, initial_class_hash);
        Ok(initial_class_hash)
    }

    /// Reads a compiled class hash at version `self.tx_index`, falling back to (and recording)
    /// the initial state value.
    fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
        let mut locked = self.state();
        let versioned = locked.inner_mut()?;
        if let Some(compiled_hash) = versioned.compiled_class_hashes.read(self.tx_index, class_hash)
        {
            return Ok(compiled_hash);
        }
        let initial_hash = versioned.initial_state.get_compiled_class_hash(class_hash)?;
        versioned.compiled_class_hashes.set_initial_value(class_hash, initial_hash);
        Ok(initial_hash)
    }

    /// Delegates directly to the initial state.
    ///
    /// NOTE(review): unlike the other getters, this read is not recorded in any versioned
    /// storage, so `validate_reads` cannot see it. Confirm that this is intended.
    fn get_compiled_class_hash_v2(
        &self,
        class_hash: ClassHash,
        compiled_class: &RunnableCompiledClass,
    ) -> StateResult<CompiledClassHash> {
        let mut locked = self.state();
        let versioned = locked.inner_mut()?;
        versioned.initial_state.get_compiled_class_hash_v2(class_hash, compiled_class)
    }

    /// Reads a compiled class at version `self.tx_index`.
    ///
    /// On a miss, the initial state is queried:
    /// - On success, the class is marked declared and cached.
    /// - On `UndeclaredClassHash`, the hash is marked undeclared and its compiled class hash
    ///   is recorded as zero before the error is returned.
    /// - Any other error is returned unchanged.
    fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
        let mut locked = self.state();
        let state = locked.inner_mut()?;
        if let Some(cached) = state.compiled_contract_classes.read(self.tx_index, class_hash) {
            return Ok(cached);
        }
        match state.initial_state.get_compiled_class(class_hash) {
            Ok(compiled_class) => {
                state.declared_contracts.set_initial_value(class_hash, true);
                state
                    .compiled_contract_classes
                    .set_initial_value(class_hash, compiled_class.clone());
                Ok(compiled_class)
            }
            // NOTE(review): this records the hash carried by the error, presumably equal to
            // `class_hash`. Confirm that this holds.
            Err(StateError::UndeclaredClassHash(undeclared_hash)) => {
                state.declared_contracts.set_initial_value(undeclared_hash, false);
                state
                    .compiled_class_hashes
                    .set_initial_value(undeclared_hash, CompiledClassHash(Felt::ZERO));
                Err(StateError::UndeclaredClassHash(undeclared_hash))
            }
            Err(other) => Err(other),
        }
    }
}