use std::collections::HashMap;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use rstest::rstest;
use starknet_api::{contract_address, felt, storage_key};
use crate::abi::sierra_types::{SierraType, SierraU128};
use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus};
use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE};
use crate::concurrency::versioned_state::ThreadSafeVersionedState;
use crate::concurrency::TxIndex;
use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps};
use crate::state::state_api::UpdatableState;
use crate::test_utils::dict_state_reader::DictStateReader;
const CONTRACT_ADDRESS: &str = "0x18031991";
const STORAGE_KEY: u8 = 27;
fn validate(
scheduler: &Scheduler,
versioned_state: &ThreadSafeVersionedState<CachedState<DictStateReader>>,
tx_index: TxIndex,
commit_phase: bool,
) -> bool {
let state_proxy = versioned_state.pin_version(tx_index);
let (reads, writes) = get_reads_writes_for(Task::ValidationTask(tx_index), versioned_state);
let read_set_valid = state_proxy.validate_reads(&reads).unwrap();
let aborted = !read_set_valid && scheduler.try_validation_abort(tx_index, commit_phase);
if aborted {
state_proxy.delete_writes(&writes, &ContractClassMapping::default()).unwrap();
scheduler.finish_abort(tx_index);
}
read_set_valid
}
#[rstest]
fn scheduler_flow_test(
#[values(1, 2, 4, 32, 64, 128)] num_threads: u8,
) {
let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE));
let versioned_state =
safe_versioned_state_for_testing(CachedState::from(DictStateReader::default()));
let mut handles = vec![];
for _ in 0..num_threads {
let scheduler = Arc::clone(&scheduler);
let versioned_state = versioned_state.clone();
let handle = std::thread::spawn(move || {
let mut task = Task::AskForTask;
loop {
if let Some(mut transaction_committer) = scheduler.try_enter_commit_phase() {
while let Some(tx_index) = transaction_committer.try_commit() {
if !validate(&scheduler, &versioned_state, tx_index, true) {
transaction_committer.uncommit();
break;
}
if tx_index == DEFAULT_CHUNK_SIZE - 1 {
scheduler.halt();
break;
}
}
}
task = match task {
Task::ExecutionTask(tx_index) => {
let (_, writes) =
get_reads_writes_for(Task::ExecutionTask(tx_index), &versioned_state);
versioned_state
.pin_version(tx_index)
.apply_writes(&writes, &ContractClassMapping::default());
scheduler.finish_execution(tx_index);
Task::AskForTask
}
Task::ValidationTask(tx_index) => {
validate(&scheduler, &versioned_state, tx_index, false);
Task::AskForTask
}
Task::NoTaskAvailable => Task::AskForTask,
Task::AskForTask => scheduler.next_task(),
Task::Done => break,
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
assert!(scheduler.execution_index.load(Ordering::Acquire) >= DEFAULT_CHUNK_SIZE);
assert!(*scheduler.commit_index.lock().unwrap() == DEFAULT_CHUNK_SIZE);
assert!(scheduler.get_n_committed_txs() == DEFAULT_CHUNK_SIZE);
assert!(scheduler.done_marker.load(Ordering::Acquire));
let inner_versioned_state = versioned_state.into_inner_state();
for tx_index in 0..DEFAULT_CHUNK_SIZE {
assert_eq!(scheduler.get_tx_status(tx_index), TransactionStatus::Committed);
let storage_writes = inner_versioned_state.get_writes_of_index(tx_index).storage;
assert_eq!(
*storage_writes
.get(&(contract_address!(CONTRACT_ADDRESS), storage_key!(STORAGE_KEY)))
.unwrap(),
felt!(format!("{:x}", tx_index + 1).as_str())
);
}
}
fn get_reads_writes_for(
task: Task,
versioned_state: &ThreadSafeVersionedState<CachedState<DictStateReader>>,
) -> (StateMaps, StateMaps) {
match task {
Task::ExecutionTask(tx_index) => {
let state_proxy = match tx_index {
0 => {
return (
state_maps_with_single_storage_entry(0),
state_maps_with_single_storage_entry(1),
);
}
_ => versioned_state.pin_version(tx_index - 1),
};
let tx_written_value = SierraU128::from_storage(
&state_proxy,
&contract_address!(CONTRACT_ADDRESS),
&storage_key!(STORAGE_KEY),
)
.unwrap()
.as_value();
(
state_maps_with_single_storage_entry(tx_written_value),
state_maps_with_single_storage_entry(tx_written_value + 1),
)
}
Task::ValidationTask(tx_index) => {
let state_proxy = versioned_state.pin_version(tx_index);
let tx_written_value = SierraU128::from_storage(
&state_proxy,
&contract_address!(CONTRACT_ADDRESS),
&storage_key!(STORAGE_KEY),
)
.unwrap()
.as_value();
(
state_maps_with_single_storage_entry(tx_written_value - 1),
state_maps_with_single_storage_entry(tx_written_value),
)
}
_ => panic!("Only execution and validation tasks shold be used here."),
}
}
fn state_maps_with_single_storage_entry(value: u128) -> StateMaps {
StateMaps {
storage: HashMap::from([(
(contract_address!(CONTRACT_ADDRESS), storage_key!(STORAGE_KEY)),
felt!(value),
)]),
..Default::default()
}
}