use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// Owning handle to a spawned Tokio task, paired with a [`CancellationToken`]
/// for cooperative cancellation.
///
/// Dropping the handle cancels the token, so the value must be kept alive for
/// as long as the task is meant to keep running.
#[derive(Debug)]
#[must_use = "dropping an AsyncHandle immediately cancels its task"]
pub struct AsyncHandle<T> {
    /// Join handle of the spawned task.
    handle: JoinHandle<T>,
    /// Token cancelled by `cancel()` and when the handle is dropped.
    cancel_token: CancellationToken,
}
impl<T> AsyncHandle<T> {
    /// Wraps a spawned task together with its cancellation token.
    ///
    /// The handle takes over the token's lifecycle: the token is cancelled by
    /// [`cancel`](Self::cancel), by [`abort`](Self::abort), and when the handle
    /// is dropped.
    pub(crate) fn new(handle: JoinHandle<T>, cancel_token: CancellationToken) -> Self {
        Self {
            handle,
            cancel_token,
        }
    }

    /// Requests cooperative cancellation by cancelling the token.
    ///
    /// This does not stop the task by itself; [`is_finished`](Self::is_finished)
    /// may still return `false` right afterwards.
    pub fn cancel(&self) {
        self.cancel_token.cancel();
    }

    /// Returns `true` once cancellation has been requested through the token.
    ///
    /// This says nothing about whether the task has actually stopped; see
    /// [`is_finished`](Self::is_finished) for that.
    pub fn is_cancelled(&self) -> bool {
        self.cancel_token.is_cancelled()
    }

    /// Returns `true` while the task has not finished and cancellation has not
    /// been requested.
    ///
    /// A task that has been asked to cancel but has not exited yet reports
    /// `false` here and `false` from [`is_finished`](Self::is_finished).
    pub fn is_running(&self) -> bool {
        !self.handle.is_finished() && !self.cancel_token.is_cancelled()
    }

    /// Returns `true` once the task has stopped (completed, panicked, or aborted).
    pub fn is_finished(&self) -> bool {
        self.handle.is_finished()
    }

    /// Forcibly aborts the task.
    ///
    /// The token is cancelled first, for the same reasons as when the handle is
    /// dropped:
    /// - The handle's state is consistent immediately: `is_cancelled()` is `true`
    ///   and `is_running()` is `false`, even before the runtime has processed
    ///   the abort.
    /// - Any other work sharing the token via
    ///   [`cancellation_token`](Self::cancellation_token) is signalled too.
    pub fn abort(&self) {
        self.cancel_token.cancel();
        self.handle.abort();
    }

    /// Returns the cancellation token associated with this handle.
    pub fn cancellation_token(&self) -> &CancellationToken {
        &self.cancel_token
    }
}
impl<T> Drop for AsyncHandle<T> {
fn drop(&mut self) {
self.cancel_token.cancel();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a test waits for the spawned task to settle.
    /// Deliberately generous so slow CI machines do not cause spurious failures;
    /// a passing test returns as soon as the task completes.
    const SETTLE_TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_async_handle_cancel() {
        let token = CancellationToken::new();
        let token_clone = token.clone();
        let handle = tokio::spawn(async move {
            loop {
                if token_clone.is_cancelled() {
                    return "Cancelled";
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        });
        let mut async_handle = AsyncHandle::new(handle, token);
        assert!(async_handle.is_running());
        async_handle.cancel();
        assert!(async_handle.is_cancelled());
        // Await the task itself rather than sleeping a fixed amount, so the test
        // does not depend on scheduler timing. `JoinHandle` is `Unpin`, so it can
        // be awaited through `&mut` without moving it out of the Drop type.
        let result = tokio::time::timeout(SETTLE_TIMEOUT, &mut async_handle.handle)
            .await
            .expect("task did not observe cancellation in time")
            .expect("task panicked");
        assert_eq!(result, "Cancelled");
        assert!(async_handle.is_finished());
    }

    #[tokio::test]
    async fn test_async_handle_drop_cancels() {
        let token = CancellationToken::new();
        let token_clone = token.clone();
        let token_check = token.clone();
        let handle = tokio::spawn(async move {
            loop {
                if token_clone.is_cancelled() {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        });
        {
            let _async_handle = AsyncHandle::new(handle, token);
        }
        assert!(token_check.is_cancelled());
    }

    #[tokio::test]
    async fn test_async_handle_abort() {
        let token = CancellationToken::new();
        let token_clone = token.clone();
        let handle = tokio::spawn(async move {
            loop {
                if token_clone.is_cancelled() {
                    return "Cancelled";
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        });
        let mut async_handle = AsyncHandle::new(handle, token);
        async_handle.abort();
        // The abort is issued before the task is ever polled, so awaiting it
        // must yield a cancelled JoinError rather than a normal return.
        let join_error = tokio::time::timeout(SETTLE_TIMEOUT, &mut async_handle.handle)
            .await
            .expect("aborted task did not finish in time")
            .expect_err("aborted task should not complete normally");
        assert!(join_error.is_cancelled());
        assert!(async_handle.is_finished());
    }
}