use pin_project_lite::pin_project;
use std::{future::Future, pin::Pin, task::Poll, time::Duration};
use tokio::time::{Instant, Sleep};
/// Creates a [`PausableSleep`] that starts out running with the given `duration`.
///
/// This is a thin convenience wrapper around the private `PausableSleep::new`.
pub(crate) fn pausable_sleep(duration: Duration) -> PausableSleep {
PausableSleep::new(duration)
}
pin_project! {
/// A sleep future that can be paused, resumed and reset.
///
/// While paused, the wrapped `Sleep` is given a far-future deadline and the
/// time that was left is kept in `pause_state`.
#[derive(Debug)]
pub(crate) struct PausableSleep {
// The wrapped tokio timer; it is the only structurally pinned field.
#[pin]
sleep: Sleep,
// Duration most recently passed to `new` or `reset`; reused by `reset_last_duration`.
duration: Duration,
// Whether the timer is running or paused (and, if paused, the time left).
pause_state: SleepPauseState,
}
}
impl PausableSleep {
fn new(duration: Duration) -> Self {
Self {
sleep: tokio::time::sleep(duration),
duration,
pause_state: SleepPauseState::Running,
}
}
#[cfg_attr(not(unix), expect(dead_code))]
pub(crate) fn is_paused(&self) -> bool {
matches!(self.pause_state, SleepPauseState::Paused { .. })
}
pub(crate) fn pause(self: Pin<&mut Self>) {
let this = self.project();
match &*this.pause_state {
SleepPauseState::Running => {
let deadline = this.sleep.deadline();
this.sleep.reset(far_future());
let remaining = deadline.duration_since(Instant::now());
*this.pause_state = SleepPauseState::Paused { remaining };
}
SleepPauseState::Paused { remaining } => {
panic!(
"illegal state transition: pause() called while sleep was paused (remaining = {remaining:?})"
);
}
}
}
pub(crate) fn resume(self: Pin<&mut Self>) {
let this = self.project();
match &*this.pause_state {
SleepPauseState::Paused { remaining } => {
this.sleep.reset(Instant::now() + *remaining);
*this.pause_state = SleepPauseState::Running;
}
SleepPauseState::Running => {
panic!("illegal state transition: resume() called while sleep was running");
}
}
}
pub(crate) fn reset(self: Pin<&mut Self>, duration: Duration) {
let this = self.project();
*this.duration = duration;
match this.pause_state {
SleepPauseState::Running => {
this.sleep.reset(Instant::now() + duration);
}
SleepPauseState::Paused { remaining } => {
*remaining = duration;
}
}
}
pub(crate) fn reset_last_duration(self: Pin<&mut Self>) {
let duration = self.duration;
self.reset(duration);
}
}
impl Future for PausableSleep {
    type Output = ();

    /// Polls the wrapped `Sleep` and forwards its result unchanged.
    ///
    /// No pause-specific logic lives here: `pause()` re-arms the wrapped sleep
    /// with a far-future deadline, and this method simply reports whatever
    /// that sleep reports.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        self.project().sleep.poll(cx)
    }
}
/// Run state of a `PausableSleep`.
#[derive(Debug, PartialEq, Eq)]
enum SleepPauseState {
/// The wrapped sleep is armed with its real deadline.
Running,
/// The timer is paused; `remaining` is the time to wait once it is resumed.
Paused { remaining: Duration },
}
/// Returns an instant `far_future_duration()` from now.
///
/// `PausableSleep::pause` uses it as the deadline of the wrapped sleep while
/// the timer is paused.
fn far_future() -> Instant {
    let offset = far_future_duration();
    Instant::now() + offset
}
/// Returns the span treated as "far in the future": 30 years of 365 days,
/// expressed in whole seconds.
pub(crate) const fn far_future_duration() -> Duration {
    const SECS_PER_DAY: u64 = 86400;
    const DAYS_PER_YEAR: u64 = 365;
    const YEARS: u64 = 30;
    Duration::from_secs(SECS_PER_DAY * DAYS_PER_YEAR * YEARS)
}
#[cfg(test)]
mod tests {
use super::*;
// Checks that `reset` on a paused sleep only replaces the stored remaining
// time, and that the new duration starts counting once the sleep is resumed.
#[tokio::test]
async fn reset_on_sleep() {
const TICK: Duration = Duration::from_millis(500);
// Start with a very short sleep and pause it right away.
let mut sleep = std::pin::pin!(pausable_sleep(Duration::from_millis(1)));
sleep.as_mut().pause();
assert!(
!sleep.as_mut().sleep.is_elapsed(),
"underlying sleep has been suspended"
);
// Resetting while paused must update `remaining` without re-arming the timer.
sleep.as_mut().reset(TICK);
assert_eq!(
sleep.as_ref().pause_state,
SleepPauseState::Paused { remaining: TICK }
);
assert!(
!sleep.as_mut().sleep.is_elapsed(),
"underlying sleep is still suspended"
);
// Wait longer than the new duration: a paused sleep must not elapse.
tokio::time::sleep(2 * TICK).await;
assert!(
!sleep.as_mut().sleep.is_elapsed(),
"underlying sleep is still suspended after waiting 2 ticks"
);
// After resuming, the full reset duration has to pass before completion.
let now = Instant::now();
sleep.as_mut().resume();
sleep.as_mut().await;
assert!(
sleep.as_mut().sleep.is_elapsed(),
"underlying sleep has finally elapsed"
);
assert!(now.elapsed() >= TICK);
}
}