use casper_types::{TimeDiff, Timestamp};
use std::time::Duration;
use tokio::{select, sync::Mutex, time};
use tokio_util::sync::CancellationToken;
/// Bookkeeping for the currently scheduled termination countdown.
struct TerminationData {
/// Moment at which the pending countdown will cancel the terminator's token.
terminate_at: Timestamp,
/// Cancelling this token stops the pending countdown task without cancelling the
/// terminator's own token.
stop_countdown: CancellationToken,
}
/// Cancels a [`CancellationToken`] once a deadline passes. The deadline can be pushed
/// later but never earlier.
pub(super) struct ConnectionTerminator {
/// Token cancelled when the countdown elapses; clones are handed out by
/// `get_cancellation_token`.
cancellation_token: CancellationToken,
/// The currently scheduled countdown, if any. This is an async mutex because the guard is
/// held across an `.await` in `terminate_at`.
countdown_data: Mutex<Option<TerminationData>>,
}
impl ConnectionTerminator {
    /// Schedules cancellation of this terminator's token at `in_terminate_at`.
    ///
    /// The scheduled moment can only be moved later, never earlier. Returns `true` if a new
    /// countdown to `in_terminate_at` was started (replacing any previous one). Returns
    /// `false`, leaving the current schedule in place, when:
    /// * `in_terminate_at` is not in the future,
    /// * a countdown to a later moment is already scheduled, or
    /// * the cancellation token has already been cancelled.
    pub(super) async fn terminate_at(&self, in_terminate_at: Timestamp) -> bool {
        let mut countdown_data_guard = self.countdown_data.lock().await;
        // Read the clock only once the lock is held. Waiting for the lock can take an
        // arbitrary amount of time, and a reading taken before it would make the countdown
        // overshoot `in_terminate_at` by however long we waited.
        let now = Timestamp::now();
        if in_terminate_at <= now {
            // Checked before touching the existing countdown so that a rejected call never
            // leaves things without an active countdown.
            return false;
        }
        if let Some(existing) = countdown_data_guard.as_ref() {
            if in_terminate_at < existing.terminate_at {
                // Never shorten an already-scheduled termination.
                return false;
            }
            existing.stop_countdown.cancel();
        }
        if self.cancellation_token.is_cancelled() {
            return false;
        }
        // Cannot underflow: `in_terminate_at > now` was checked above.
        let terminate_in = Duration::from_millis(in_terminate_at.millis() - now.millis());
        let stop_countdown = self
            .spawn_termination_countdown(terminate_in, self.cancellation_token.clone())
            .await;
        *countdown_data_guard = Some(TerminationData {
            terminate_at: in_terminate_at,
            stop_countdown,
        });
        true
    }

    /// Requests termination `delay_by` from now. Return value as for `terminate_at`: `false`
    /// means the request was ignored (e.g. a later termination is already scheduled).
    pub(crate) async fn delay_termination(&self, delay_by: TimeDiff) -> bool {
        let terminate_at = Timestamp::now() + delay_by;
        self.terminate_at(terminate_at).await
    }

    /// Creates a terminator with a fresh, uncancelled token and no countdown scheduled.
    pub(super) fn new() -> Self {
        let cancellation_token = CancellationToken::new();
        ConnectionTerminator {
            cancellation_token,
            countdown_data: Mutex::new(None),
        }
    }

    /// Returns a clone of the token that gets cancelled when the countdown elapses.
    pub(super) fn get_cancellation_token(&self) -> CancellationToken {
        self.cancellation_token.clone()
    }

    /// Spawns a detached task that cancels `cancellation_token` after `terminate_in`.
    ///
    /// Returns a token that stops the countdown early when cancelled. Nothing in this impl
    /// stops the task when `self` is dropped; it ends only when the sleep elapses or the
    /// returned token is cancelled.
    async fn spawn_termination_countdown(
        &self,
        terminate_in: Duration,
        cancellation_token: CancellationToken,
    ) -> CancellationToken {
        let cancel_countdown = CancellationToken::new();
        let cancel_countdown_to_move = cancel_countdown.clone();
        tokio::task::spawn(async move {
            select! {
                _ = time::sleep(terminate_in) => {
                    cancellation_token.cancel()
                },
                _ = cancel_countdown_to_move.cancelled() => {
                },
            }
        });
        cancel_countdown
    }
}
#[cfg(test)]
mod tests {
    use super::ConnectionTerminator;
    use casper_types::{TimeDiff, Timestamp};
    use std::time::Duration;
    use tokio::{select, time::sleep};

    /// Upper bound on how long a test waits for termination before failing.
    const TEST_TIMEOUT: Duration = Duration::from_secs(10);

    /// Waits for `terminator`'s token to be cancelled and asserts that the time elapsed since
    /// `start` lies in `[min, max]`. Panics if the token is not cancelled within
    /// `TEST_TIMEOUT`.
    async fn assert_terminated_between(
        terminator: &ConnectionTerminator,
        start: Timestamp,
        min: TimeDiff,
        max: TimeDiff,
    ) {
        let cancellation_token = terminator.get_cancellation_token();
        select! {
            _ = cancellation_token.cancelled() => {
                let elapsed = start.elapsed();
                assert!(elapsed >= min, "terminated too early: {:?} < {:?}", elapsed, min);
                assert!(elapsed <= max, "terminated too late: {:?} > {:?}", elapsed, max);
            },
            _ = sleep(TEST_TIMEOUT) => {
                panic!("terminator was not cancelled within {:?}", TEST_TIMEOUT)
            },
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn should_fail_setting_expiration_in_past() {
        let terminator = ConnectionTerminator::new();
        let in_past = Timestamp::from(1);
        assert!(!terminator.terminate_at(in_past).await);

        let initial_inactivity = TimeDiff::from_seconds(1);
        let terminator = ConnectionTerminator::new();
        let now = Timestamp::now();
        assert!(terminator.terminate_at(now + initial_inactivity).await);
        // `now` has already passed by the time of this call.
        assert!(!terminator.terminate_at(now).await);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn should_fail_setting_expiration_when_already_cancelled() {
        let terminator = ConnectionTerminator::new();
        let now = Timestamp::now();
        assert!(terminator.terminate_at(now + TimeDiff::from_seconds(1)).await);
        assert_terminated_between(
            &terminator,
            now,
            TimeDiff::from_seconds(1),
            TimeDiff::from_millis(1500),
        )
        .await;

        let now = Timestamp::now();
        assert!(!terminator.terminate_at(now + TimeDiff::from_seconds(10)).await);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn should_cancel_after_enough_inactivity() {
        let terminator = ConnectionTerminator::new();
        let now = Timestamp::now();
        assert!(terminator.terminate_at(now + TimeDiff::from_seconds(1)).await);
        assert_terminated_between(
            &terminator,
            now,
            TimeDiff::from_seconds(1),
            TimeDiff::from_millis(1500),
        )
        .await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn should_cancel_after_extended_time() {
        let terminator = ConnectionTerminator::new();
        let now = Timestamp::now();
        assert!(terminator.terminate_at(now + TimeDiff::from_seconds(1)).await);
        sleep(Duration::from_millis(100)).await;
        assert!(terminator.delay_termination(TimeDiff::from_seconds(2)).await);
        assert_terminated_between(
            &terminator,
            now,
            TimeDiff::from_seconds(2),
            TimeDiff::from_millis(2500),
        )
        .await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn should_cancel_after_multiple_time_extensions() {
        let terminator = ConnectionTerminator::new();
        let now = Timestamp::now();
        assert!(terminator.terminate_at(now + TimeDiff::from_seconds(1)).await);
        sleep(Duration::from_millis(100)).await;
        assert!(terminator.delay_termination(TimeDiff::from_seconds(2)).await);
        sleep(Duration::from_millis(100)).await;
        assert!(terminator.delay_termination(TimeDiff::from_seconds(3)).await);
        assert_terminated_between(
            &terminator,
            now,
            TimeDiff::from_seconds(3),
            TimeDiff::from_millis(4000),
        )
        .await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn should_not_shorten_termination_time() {
        let terminator = ConnectionTerminator::new();
        let now = Timestamp::now();
        assert!(terminator.terminate_at(now + TimeDiff::from_seconds(1)).await);
        sleep(Duration::from_millis(100)).await;
        assert!(terminator.delay_termination(TimeDiff::from_seconds(2)).await);
        sleep(Duration::from_millis(100)).await;
        // Would move termination earlier than the currently scheduled one: must be rejected.
        assert!(!terminator.delay_termination(TimeDiff::from_seconds(1)).await);
        assert_terminated_between(
            &terminator,
            now,
            TimeDiff::from_seconds(2),
            TimeDiff::from_millis(2500),
        )
        .await;
    }
}