use super::ChainGuards;
use futures::FutureExt;
use linera_base::identifiers::ChainId;
use std::{
mem,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::{sync::Barrier, time::sleep};
// Holding a guard shows up in `active_guards`, and dropping it removes that entry again.
#[tokio::test]
async fn dropped_guard_does_not_leak() {
    let guards = ChainGuards::default();
    let held = guards.guard(ChainId::root(0)).await;

    assert_eq!(guards.active_guards(), 1);

    mem::drop(held);
    assert_eq!(guards.active_guards(), 0);
}
#[tokio::test]
async fn guard_can_be_obtained_later_again() {
let chain_id = ChainId::root(0);
let guards = ChainGuards::default();
guards.guard(chain_id).await;
assert!(guards.guard(chain_id).now_or_never().is_some());
}
// Two tasks asking for the same chain must be serialized: the second one only gets its
// guard after the first one has finished.
#[tokio::test(start_paused = true)]
async fn prevents_concurrent_access_to_the_same_chain() {
    let chain_id = ChainId::root(0);
    let harness = ConcurrentAccessTest::default();

    let observed = harness
        .spawn_two_tasks_to_obtain_guards_for(chain_id, chain_id)
        .await;

    assert_eq!(observed, Access::Sequential);
}
// Guards for distinct chains do not block each other: the second task gets its guard
// while the first one is still holding its own.
#[tokio::test(start_paused = true)]
async fn allows_concurrent_access_to_different_chains() {
    let (first_chain, second_chain) = (ChainId::root(0), ChainId::root(1));
    let harness = ConcurrentAccessTest::default();

    let observed = harness
        .spawn_two_tasks_to_obtain_guards_for(first_chain, second_chain)
        .await;

    assert_eq!(observed, Access::Concurrent);
}
/// Harness that runs two tasks acquiring guards from one `ChainGuards` instance and
/// reports whether the second task obtained its guard before the first task finished.
#[derive(Clone)]
pub struct ConcurrentAccessTest {
    /// Source of the guards both tasks acquire.
    guards: ChainGuards,
    /// Two-party rendezvous: the first task reaches it while holding its guard, and the
    /// second task waits on it before requesting its own guard.
    after_first_guard_is_obtained: Arc<Barrier>,
    /// Set by the first task right before it lets go of its guard.
    first_task_finished: Arc<AtomicBool>,
}
/// Outcome observed by [`ConcurrentAccessTest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Access {
    /// The second task held its guard while the first task had not yet finished.
    Concurrent,
    /// The second task only obtained its guard after the first task had finished.
    Sequential,
}
impl Default for ConcurrentAccessTest {
    /// Creates a harness with default guards, a barrier sized for the two tasks, and the
    /// "first task finished" flag cleared.
    fn default() -> Self {
        let rendezvous = Arc::new(Barrier::new(2));
        let finished_flag = Arc::new(AtomicBool::new(false));

        Self {
            guards: ChainGuards::default(),
            after_first_guard_is_obtained: rendezvous,
            first_task_finished: finished_flag,
        }
    }
}
impl ConcurrentAccessTest {
    /// Spawns one task acquiring a guard for `first_chain` and another acquiring a guard
    /// for `second_chain`, waits for both, and returns what the second task observed.
    ///
    /// # Panics
    ///
    /// Panics if either spawned task fails (e.g. panics) instead of completing.
    pub async fn spawn_two_tasks_to_obtain_guards_for(
        self,
        first_chain: ChainId,
        second_chain: ChainId,
    ) -> Access {
        let first_harness = self.clone();
        let first_handle = tokio::spawn(first_harness.run_first_task(first_chain));
        let second_handle = tokio::spawn(self.run_second_task(second_chain));

        first_handle.await.expect("First task failed");
        second_handle.await.expect("Second task failed")
    }

    /// Acquires a guard, meets the second task at the barrier, keeps the guard for ten
    /// seconds, then marks itself finished. The guard is only released when this
    /// function returns, i.e. after the flag has been set.
    async fn run_first_task(self, chain_id: ChainId) {
        let _held = self.guards.guard(chain_id).await;
        self.after_first_guard_is_obtained.wait().await;

        // The tests using this harness run with a paused clock, so this sleep does not
        // cost real wall-clock time; it just keeps the guard held for a long while.
        sleep(Duration::from_secs(10)).await;

        // Release pairs with the Acquire load in `run_second_task`.
        self.first_task_finished.store(true, Ordering::Release);
    }

    /// Waits until the first task holds its guard, then requests a guard itself and
    /// reports whether the first task had already finished by the time it succeeded.
    async fn run_second_task(self, chain_id: ChainId) -> Access {
        self.after_first_guard_is_obtained.wait().await;
        let _held = self.guards.guard(chain_id).await;

        if self.first_task_finished.load(Ordering::Acquire) {
            Access::Sequential
        } else {
            Access::Concurrent
        }
    }
}