use anyhow::{bail, Result};
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::{
sync::{
mpsc::{self, Sender},
Mutex,
},
task::JoinHandle,
};
/// Shared, lock-protected queue of wait segments. The background task pops segments from
/// the back and waits on each in turn, while `add`/`sub` edit the queue as it runs.
type DurationVec = Arc<Mutex<Vec<Duration>>>;
/// A timeout whose remaining duration can be extended or shortened while it is running.
///
/// A spawned Tokio task waits through the queued durations. It then either invokes a
/// callback (`new`) or sends on a caller-supplied channel (`with_sender`). It does neither
/// if `cancel` set the cancelled flag.
pub struct DynTimeout {
    /// Set by `cancel`; checked by the background task before it fires.
    cancelled: Arc<AtomicBool>,
    /// Pending wait segments shared with the background task. `add`/`sub` treat an empty
    /// queue as "timeout already reached".
    durations: DurationVec,
    /// Wake-up channel into the background task; `cancel` sends on it to interrupt the
    /// current wait.
    sender: mpsc::Sender<()>,
    /// Handle of the spawned background task; `cancel` sets it to `None`.
    /// NOTE(review): dropping a Tokio `JoinHandle` detaches the task rather than aborting it.
    thread: Option<JoinHandle<()>>,
    /// Receives one message when the background task finishes; awaited by `wait`.
    receiver: mpsc::Receiver<()>,
    /// Optional cap on queued waiting time, consulted by `add`.
    max_waiting_time: Option<Duration>,
}
impl DynTimeout {
pub fn new(dur: Duration, callback: fn() -> ()) -> Self {
let durations: DurationVec = Arc::new(Mutex::new(vec![Duration::ZERO, dur]));
let thread_vec = durations.clone();
let cancelled = Arc::new(AtomicBool::new(false));
let thread_cancelled = cancelled.clone();
let (sender, mut receiver) = mpsc::channel::<()>(1);
let (tx, rx) = mpsc::channel::<()>(1);
Self {
cancelled,
durations,
sender,
receiver: rx,
thread: Some(tokio::task::spawn(async move {
loop {
let dur = {
match thread_vec.lock().await.pop() {
Some(dur) => dur,
None => break,
}
};
let _ = tokio::time::timeout(dur, async { receiver.recv().await }).await;
}
if !thread_cancelled.load(Ordering::Relaxed) {
callback();
}
tx.send(()).await.unwrap();
})),
max_waiting_time: None,
}
}
pub fn with_sender(dur: Duration, sender_in: Sender<()>) -> Self {
let durations: DurationVec = Arc::new(Mutex::new(vec![Duration::ZERO, dur]));
let thread_vec = durations.clone();
let cancelled = Arc::new(AtomicBool::new(false));
let thread_cancelled = cancelled.clone();
let (sender, mut receiver) = mpsc::channel::<()>(1);
let (tx, rx) = mpsc::channel::<()>(1);
Self {
cancelled,
durations,
sender,
receiver: rx,
thread: Some(tokio::task::spawn(async move {
loop {
let dur = {
match thread_vec.lock().await.pop() {
Some(dur) => dur,
None => break,
}
};
let _ = tokio::time::timeout(dur, async { receiver.recv().await }).await;
}
if !thread_cancelled.load(Ordering::Relaxed) {
sender_in.send(()).await.unwrap();
}
tx.send(()).await.unwrap();
})),
max_waiting_time: None,
}
}
pub fn set_max_waiting_time(&mut self, duration: Duration) {
self.max_waiting_time = Some(duration)
}
pub async fn add(&self, dur: Duration) -> Result<()> {
let mut durations = self.durations.lock().await;
if durations.is_empty() {
bail!("Timeout already reached")
}
if let Some(m) = self.max_waiting_time {
let mut tt = Duration::from_millis(0);
for d in durations.iter() {
tt += *d;
}
if tt >= m {
return Ok(());
}
}
durations.push(dur);
Ok(())
}
pub async fn sub(&self, dur: Duration) -> Result<()> {
let mut durations = self.durations.lock().await;
if durations.is_empty() {
bail!("Timeout already reached")
}
let mut pop_dur = Duration::default();
while pop_dur < dur && durations.len() > 1 {
pop_dur += durations.pop().unwrap();
}
if pop_dur > dur {
durations.push(pop_dur - dur);
}
Ok(())
}
pub async fn cancel(&mut self) -> Result<()> {
self.cancelled.store(true, Ordering::Relaxed);
self.durations.lock().await.clear();
self.sender.send(()).await?;
self.thread = None;
Ok(())
}
pub async fn wait(&mut self) -> Result<()> {
self.receiver.recv().await;
Ok(())
}
}