use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::Instant;
use tokio::{sync::oneshot, task::JoinHandle};
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn};
/// Delays running a task until a quiet period has passed since the most recent
/// [`Debouncer::debounce`] call, optionally bounded by a maximum wait.
///
/// All mutable state lives behind `Arc<Mutex<..>>`, so clones share the same timer
/// and running-task state: any clone can debounce, run, or stop the same task.
#[derive(Clone)]
pub struct Debouncer {
    /// Label included in log messages.
    name: Arc<str>,
    /// Quiet period a `debounce` call waits before the task runs.
    debounce_duration: Duration,
    /// Optional upper bound on how long a burst of `debounce` calls may postpone the task.
    max_debounce: Option<Duration>,
    /// How long a cancelled task, or a timer being replaced/stopped, gets before it is aborted.
    cancel_task_timeout: Duration,
    /// Token of the most recently started task (taken and cancelled by `stop`).
    current_task: Arc<Mutex<Option<TaskToken>>>,
    /// Timer armed by the latest `debounce` call, if any.
    timer_handle: Arc<Mutex<Option<TimerHandle>>>,
    /// Optional observer notified of debounce lifecycle events.
    event_handler: Option<EventHandler>,
    /// Debouncer-wide cancellation: pending timers exit without running their task and
    /// running tasks have their own token cancelled.
    cancel_token: CancellationToken,
}
/// Type-erased task factory stored by [`StoredTaskDebouncer`].
///
/// Each call receives the cancellation token for that run and returns a boxed `Send`
/// future. Because it is an `Fn` behind an `Arc`, the same factory can be invoked for
/// every run.
pub type StoredTask = Arc<
    dyn Fn(CancellationToken) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
>;
/// A [`Debouncer`] bundled with a stored closure, so callers only call `debounce()`
/// instead of supplying the task on every call.
#[derive(Clone)]
pub struct StoredTaskDebouncer {
    /// The underlying debouncer.
    debouncer: Debouncer,
    /// Task factory invoked for each run; replaceable with `set_task`.
    task_fn: StoredTask,
}
/// A task type that [`TaskDebouncer`] can run repeatedly.
pub trait DebouncedTask: Send + Sync + 'static {
    /// Runs the task once.
    ///
    /// Implementations should watch `token` and return promptly once it is cancelled.
    /// A task that is still running the debouncer's task timeout after cancellation is
    /// aborted.
    // NOTE(review): the `Sync` bound on the returned future looks stricter than needed;
    // the debouncer only spawns it, which requires `Send`. Consider relaxing it.
    fn execute(&self, token: CancellationToken) -> impl Future<Output = ()> + Send + Sync;
}
/// A [`Debouncer`] that runs a shared [`DebouncedTask`] instance.
#[derive(Clone)]
pub struct TaskDebouncer<T: DebouncedTask> {
    /// The underlying debouncer.
    debouncer: Debouncer,
    /// Task executed on each run; shared with any in-flight runs via `Arc`.
    task: Arc<T>,
}
/// Handle to the background timer armed by one `Debouncer::debounce` call.
#[derive(Debug)]
pub struct TimerHandle {
    /// When this timer was armed.
    debounced_at: Instant,
    /// Start of the burst this timer continues; `None` when no earlier timer was pending.
    first_debounce_at: Option<Instant>,
    /// Cancelling this stops the timer while it is still waiting.
    timer_token: CancellationToken,
    /// The spawned timer task; aborted if it does not exit in time when cancelled.
    join_handle: JoinHandle<()>,
    /// Receives the exit status reported by the timer.
    exit_rx: oneshot::Receiver<TaskExit>,
}
/// Bookkeeping for the most recently started task.
#[derive(Debug)]
pub struct TaskToken {
    /// When the task was spawned.
    pub started_at: Instant,
    /// Token that is cancelled to ask the task to stop.
    pub task_token: CancellationToken,
}
/// How a debounced run ended.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
pub enum TaskExit {
    /// The task completed before any cancellation was observed.
    Normal,
    /// The task was cancelled and then completed within the task timeout.
    Cancelled,
    /// The task was aborted after the task timeout, or its join handle reported an
    /// error (e.g. a panic).
    Aborted,
    /// The timer was cancelled before the task was started.
    NotStarted,
}
/// Lifecycle notifications delivered to the handler registered with `with_event_handler`.
#[derive(Clone, Debug)]
pub enum DebounceEvent {
    /// A `debounce` call armed a new timer.
    Debounced {
        /// When the call happened.
        instant: Instant,
        /// Start of the burst this call continues, if any.
        first_debounce_at: Option<Instant>,
        /// Reported end of this call's debounce wait.
        debounce_ends_at: Instant,
    },
    /// A task was spawned.
    Started {
        /// When the task was spawned.
        instant: Instant,
    },
    /// A previously started task finished.
    Ended {
        /// When the task's outcome was determined.
        instant: Instant,
        /// How the task ended.
        exit_status: TaskExit,
    },
}
/// Callback invoked synchronously from debouncer code for each [`DebounceEvent`]; it
/// should return quickly.
pub type EventHandler = Arc<dyn Fn(DebounceEvent) + Send + Sync + 'static>;
// Generates the builder methods (`with_task_timeout`, `with_max_wait`,
// `with_event_handler`) and `stop` for a wrapper type that keeps its debouncer in a
// field named `debouncer`. Every generated method simply forwards to that field.
macro_rules! impl_debouncer_builder {
    ($type:ty) => {
        impl $type {
            /// Forwards to the inner debouncer's `with_task_timeout`.
            pub fn with_task_timeout(self, task_timeout: Duration) -> Self {
                Self {
                    debouncer: self.debouncer.with_task_timeout(task_timeout),
                    ..self
                }
            }

            /// Forwards to the inner debouncer's `with_max_wait`.
            pub fn with_max_wait(self, max_wait: Duration) -> Self {
                Self {
                    debouncer: self.debouncer.with_max_wait(max_wait),
                    ..self
                }
            }

            /// Forwards to the inner debouncer's `with_event_handler`.
            pub fn with_event_handler<E: Fn(DebounceEvent) + Send + Sync + 'static>(
                self,
                event_handler: E,
            ) -> Self {
                Self {
                    debouncer: self.debouncer.with_event_handler(event_handler),
                    ..self
                }
            }

            /// Forwards to the inner debouncer's `stop`.
            pub async fn stop(&self) {
                self.debouncer.stop().await;
            }
        }
    };
}
impl Debouncer {
    /// Creates a debouncer that runs a task once `debounce_duration` has passed without
    /// another [`Debouncer::debounce`] call.
    ///
    /// * `cancel_token` - debouncer-wide cancellation. Once cancelled, pending timers exit
    ///   without running their task and running tasks have their own token cancelled.
    /// * `task_name` - label included in log messages.
    ///
    /// The task timeout defaults to two seconds and there is no maximum wait. See
    /// [`Debouncer::with_task_timeout`] and [`Debouncer::with_max_wait`].
    pub fn new(
        debounce_duration: Duration,
        cancel_token: CancellationToken,
        task_name: impl AsRef<str>,
    ) -> Self {
        Self {
            debounce_duration,
            max_debounce: None,
            cancel_task_timeout: Duration::from_secs(2),
            timer_handle: Arc::new(Mutex::new(None)),
            cancel_token,
            name: Arc::from(task_name.as_ref()),
            current_task: Arc::new(Mutex::new(None)),
            event_handler: None,
        }
    }

    /// Sets how long a cancelled task, or a timer being replaced or stopped, may take to
    /// finish before it is aborted.
    pub fn with_task_timeout(mut self, task_timeout: Duration) -> Self {
        self.cancel_task_timeout = task_timeout;
        self
    }

    /// Caps how long a burst of consecutive [`Debouncer::debounce`] calls can postpone
    /// the task, measured from the first call of the burst.
    pub fn with_max_wait(mut self, max_wait: Duration) -> Self {
        self.max_debounce = Some(max_wait);
        self
    }

    /// Registers a callback that is invoked synchronously for every [`DebounceEvent`].
    pub fn with_event_handler<E: Fn(DebounceEvent) + Send + Sync + 'static>(
        mut self,
        event_handler: E,
    ) -> Self {
        self.event_handler = Some(Arc::new(event_handler));
        self
    }

    /// Forwards `event` to the registered handler, if any.
    fn fire_event(&self, event: DebounceEvent) {
        if let Some(event_handler) = &self.event_handler {
            event_handler(event)
        }
    }

    /// Returns how long a timer armed at `now` must wait before running its task.
    ///
    /// The result is normally the debounce duration. It is shortened so that a burst
    /// which began at `first_debounce_at` never waits longer than `max_debounce` in
    /// total. A zero result means the task should run immediately.
    fn debounce_wait(&self, now: Instant, first_debounce_at: Option<Instant>) -> Duration {
        let (Some(max_wait), Some(first_debounce)) = (self.max_debounce, first_debounce_at)
        else {
            return self.debounce_duration;
        };
        let wait_period = now.saturating_duration_since(first_debounce);
        let remaining = max_wait.saturating_sub(wait_period);
        if remaining.is_zero() {
            trace!(
                "task {:?} exceeded the max debounce waiting period, {wait_period:?} >= {max_wait:?}",
                self.name
            );
        }
        // Clamp rather than sleeping the full debounce duration. Otherwise the max wait
        // could be overshot by up to one debounce period when no further call arrives.
        self.debounce_duration.min(remaining)
    }

    /// Cancels the previously started task (if any), spawns `task`, and waits for it to
    /// end.
    ///
    /// The task gets its own token. That token is cancelled when a newer task starts,
    /// when `stop` is called, or when the debouncer's token is cancelled. After
    /// cancellation the task has `cancel_task_timeout` to finish before it is aborted.
    /// `Started` and `Ended` events bracket the run.
    async fn spawn_task<Task, Fut>(&self, task: Task) -> TaskExit
    where
        Task: FnOnce(CancellationToken) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let task_token = CancellationToken::new();
        let now = {
            let mut current_task_token = self.current_task.lock().await;
            if let Some(current_token) = current_task_token.take() {
                current_token.task_token.cancel();
            }
            let now = Instant::now();
            *current_task_token = Some(TaskToken {
                started_at: now,
                task_token: task_token.clone(),
            });
            now
        };
        // `JoinHandle` is `Unpin`, so it can be polled through `&mut` without pinning.
        let mut task_handle = tokio::spawn({
            let task_token = task_token.clone();
            async move {
                task(task_token).await;
            }
        });
        self.fire_event(DebounceEvent::Started { instant: now });
        let exit_status = tokio::select! {
            result = &mut task_handle => {
                match result {
                    Ok(()) => TaskExit::Normal,
                    Err(_) => TaskExit::Aborted,
                }
            }
            _ = self.wait_for_cancellation(&task_token) => {
                // Give the cancelled task a grace period before aborting it.
                tokio::select! {
                    _ = tokio::time::sleep(self.cancel_task_timeout) => {
                        task_handle.abort();
                        TaskExit::Aborted
                    }
                    result = &mut task_handle => {
                        match result {
                            Ok(()) => TaskExit::Cancelled,
                            Err(_) => TaskExit::Aborted,
                        }
                    }
                }
            }
        };
        self.fire_event(DebounceEvent::Ended {
            exit_status,
            instant: Instant::now(),
        });
        exit_status
    }

    /// Resolves once `task_token` is cancelled. If the debouncer-wide token is cancelled
    /// first, the cancellation is propagated to `task_token`.
    async fn wait_for_cancellation(&self, task_token: &CancellationToken) {
        tokio::select! {
            _ = self.cancel_token.cancelled() => {
                trace!("task {:?} debouncer token cancelled", self.name);
                task_token.cancel();
            }
            _ = task_token.cancelled() => {
                trace!("task {:?} task token cancelled", self.name);
            }
        }
    }

    /// Body of the background timer armed by [`Debouncer::debounce`].
    ///
    /// Reports a `Debounced` event, then waits out the debounce period unless the timer
    /// token or the debouncer's token is cancelled first. If the wait completes, `f` is
    /// run via `spawn_task`. The resulting status is sent through `exit_tx`.
    async fn timer<F, Fut>(
        &self,
        now: Instant,
        first_debounce_at: Option<Instant>,
        exit_tx: oneshot::Sender<TaskExit>,
        timer_token: CancellationToken,
        f: F,
    ) where
        F: FnOnce(CancellationToken) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let wait = self.debounce_wait(now, first_debounce_at);
        self.fire_event(DebounceEvent::Debounced {
            instant: now,
            debounce_ends_at: now.checked_add(wait).unwrap_or(now),
            first_debounce_at,
        });
        let exit_status = tokio::select! {
            _ = timer_token.cancelled() => {
                trace!("task {:?} timer cancelled by next trigger", self.name);
                TaskExit::NotStarted
            }
            _ = self.cancel_token.cancelled() => {
                trace!("task {:?} cancelled while waiting to execute next task", self.name);
                TaskExit::NotStarted
            }
            _ = Self::wait_for_debounce(wait) => {
                self.spawn_task(f).await
            }
        };
        // The receiving handle may already have been dropped; nothing to do in that case.
        let _ = exit_tx.send(exit_status);
    }

    /// Sleeps for `wait`, completing immediately when it is zero.
    async fn wait_for_debounce(wait: Duration) {
        if !wait.is_zero() {
            tokio::time::sleep(wait).await;
        }
    }

    /// Returns `true` if the most recently started task was spawned at or after `instant`.
    async fn task_started_since(&self, instant: Instant) -> bool {
        self.current_task
            .lock()
            .await
            .as_ref()
            .is_some_and(|task| task.started_at >= instant)
    }

    /// Cancels any pending timer and runs `task` immediately, waiting for it to end.
    ///
    /// The pending timer gets up to the task timeout to exit before it is aborted.
    pub async fn run_now<Task, Fut>(&self, task: Task)
    where
        Task: FnOnce(CancellationToken) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let mut timer_handle = self.timer_handle.lock().await;
        if let Some(timer_handle) = timer_handle.take() {
            if !timer_handle
                .cancel_with_timeout(self.cancel_task_timeout)
                .await
            {
                warn!("task {:?} aborted timer handle", self.name);
            }
        }
        drop(timer_handle);
        self.spawn_task(task).await;
    }

    /// (Re)arms the timer so that `task` runs once the debounce period passes without
    /// another call. Any task passed to an earlier, still-pending call is replaced.
    ///
    /// Calls made while a timer is still waiting extend the same burst, and `max_wait`
    /// bounds that burst. Once a timer has started its task, the next call begins a new
    /// burst. If the previous timer is currently running its task, this call waits up to
    /// the task timeout for that timer to exit and then aborts it.
    pub async fn debounce<Task, Fut>(&self, task: Task)
    where
        Task: FnOnce(CancellationToken) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let timer_token = CancellationToken::new();
        let (exit_tx, exit_rx) = oneshot::channel();
        let mut timer_handle = self.timer_handle.lock().await;
        let first_debounce_at = match timer_handle.take() {
            Some(existing_timer) => {
                let armed_at = existing_timer.debounced_at;
                let burst_start = existing_timer.first_debounce_at.unwrap_or(armed_at);
                if !existing_timer
                    .cancel_with_timeout(self.cancel_task_timeout)
                    .await
                {
                    warn!("task {:?} aborted timer handle", self.name);
                }
                // A task that started after the previous timer was armed means that timer
                // already fired. Its burst is over, so this call begins a new one. Carrying
                // the old start forward would make every later call skip the debounce once
                // `max_wait` had elapsed since the very first call.
                if self.task_started_since(armed_at).await {
                    None
                } else {
                    Some(burst_start)
                }
            }
            None => None,
        };
        let now = Instant::now();
        let debouncer = self.clone();
        *timer_handle = Some(TimerHandle {
            exit_rx,
            timer_token: timer_token.clone(),
            debounced_at: now,
            first_debounce_at,
            join_handle: tokio::spawn(async move {
                debouncer
                    .timer(now, first_debounce_at, exit_tx, timer_token, task)
                    .await;
            }),
        });
    }

    /// Cancels the running task (if any) and the pending timer.
    ///
    /// Waits up to the task timeout for the timer to exit, aborting it afterwards. A task
    /// started through `run_now` is cancelled but not waited for here.
    pub async fn stop(&self) {
        debug!("task {:?} stopping...", self.name);
        if let Some(current_token) = self.current_task.lock().await.take() {
            trace!("task {:?} running, cancelling task...", self.name);
            current_token.task_token.cancel();
        }
        if let Some(timer_handle) = self.timer_handle.lock().await.take() {
            trace!("task {:?} waiting on timer....", self.name);
            if !timer_handle
                .cancel_with_timeout(self.cancel_task_timeout)
                .await
            {
                warn!("task {:?} aborted timer handle", self.name);
            }
        }
        debug!("task {:?} stopped", self.name);
    }
}
impl TimerHandle {
    /// Cancels the timer and waits up to `timeout` for it to report its exit.
    ///
    /// Returns `true` if the exit channel resolved within `timeout`, meaning a status
    /// was sent or the sender was dropped. Otherwise the timer task is aborted and
    /// `false` is returned.
    pub async fn cancel_with_timeout(self, timeout: Duration) -> bool {
        let Self {
            timer_token,
            join_handle,
            exit_rx,
            ..
        } = self;
        timer_token.cancel();
        match tokio::time::timeout(timeout, exit_rx).await {
            Ok(_exit) => true,
            Err(_elapsed) => {
                join_handle.abort();
                false
            }
        }
    }
}
impl StoredTaskDebouncer {
    /// Creates a debouncer that runs `task_fn` whenever a debounce period elapses.
    ///
    /// `task_type` labels log messages. `debouncer_token` cancels the debouncer as a
    /// whole.
    pub fn new<F, Fut>(
        debounce_timeout: Duration,
        debouncer_token: CancellationToken,
        task_type: impl AsRef<str>,
        task_fn: F,
    ) -> Self
    where
        F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let debouncer = Debouncer::new(debounce_timeout, debouncer_token, task_type);
        Self {
            debouncer,
            task_fn: Arc::new(move |token| Box::pin(task_fn(token))),
        }
    }

    /// Replaces the stored task for this instance.
    ///
    /// Later `debounce` calls use the new task. Timers armed earlier keep the task they
    /// captured, and clones keep their own task.
    pub fn set_task<F, Fut>(&mut self, task_fn: F)
    where
        F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.task_fn = Arc::new(move |token| Box::pin(task_fn(token)));
    }

    /// Debounces a run of the stored task.
    pub async fn debounce(&self) {
        let stored = Arc::clone(&self.task_fn);
        let run = move |token: CancellationToken| async move { stored(token).await };
        self.debouncer.debounce(run).await
    }
}
// Adds the forwarding builder methods (`with_task_timeout`, `with_max_wait`,
// `with_event_handler`) and `stop` to `StoredTaskDebouncer`.
impl_debouncer_builder!(StoredTaskDebouncer);
impl<T: DebouncedTask> TaskDebouncer<T> {
    /// Creates a debouncer that runs `task` whenever a debounce period elapses.
    ///
    /// `task_type` labels log messages. `debouncer_token` cancels the debouncer as a
    /// whole.
    pub fn new(
        debounce_timeout: Duration,
        debouncer_token: CancellationToken,
        task_type: impl AsRef<str>,
        task: T,
    ) -> Self {
        let debouncer = Debouncer::new(debounce_timeout, debouncer_token, task_type);
        Self {
            debouncer,
            task: Arc::new(task),
        }
    }

    /// Replaces the task for this instance.
    ///
    /// Timers armed earlier keep the task instance they captured.
    pub fn set_task(&mut self, task: T) {
        self.task = Arc::new(task);
    }

    /// Debounces a run of the current task.
    pub async fn debounce(&self) {
        let shared = Arc::clone(&self.task);
        let run = move |token: CancellationToken| async move { shared.execute(token).await };
        self.debouncer.debounce(run).await
    }
}
// Builder and `stop` methods, written out by hand because `impl_debouncer_builder!`
// only accepts non-generic types.
impl<T: DebouncedTask> TaskDebouncer<T> {
    /// Forwards to the inner debouncer's `with_task_timeout`.
    pub fn with_task_timeout(self, task_timeout: Duration) -> Self {
        Self {
            debouncer: self.debouncer.with_task_timeout(task_timeout),
            ..self
        }
    }

    /// Forwards to the inner debouncer's `with_max_wait`.
    pub fn with_max_wait(self, max_wait: Duration) -> Self {
        Self {
            debouncer: self.debouncer.with_max_wait(max_wait),
            ..self
        }
    }

    /// Forwards to the inner debouncer's `with_event_handler`.
    pub fn with_event_handler<E: Fn(DebounceEvent) + Send + Sync + 'static>(
        self,
        event_handler: E,
    ) -> Self {
        Self {
            debouncer: self.debouncer.with_event_handler(event_handler),
            ..self
        }
    }

    /// Forwards to the inner debouncer's `stop`.
    pub async fn stop(&self) {
        self.debouncer.stop().await;
    }
}