use crate::config::ServiceConfig;
use crate::task::{CallbackWrapper, TaskId};
use crate::{TaskNotification, TimerTask, TimerWheel};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
/// Registers a long-delay one-shot task, cancels it by id, and checks that a
/// second cancellation of the same id reports that nothing was removed.
#[tokio::test]
async fn test_cancel_task() {
    let timer = TimerWheel::with_defaults();
    let service = timer.create_service(ServiceConfig::default());

    let handle = service.allocate_handle();
    let id = handle.task_id();
    service
        .register(handle, TimerTask::new_oneshot(Duration::from_secs(10), None))
        .unwrap();

    // First cancellation removes the pending task.
    assert!(service.cancel_task(id), "Task should be cancelled successfully");
    // Second cancellation finds nothing left to remove.
    assert!(!service.cancel_task(id), "Task should not exist anymore");
}
/// Cancelling an id whose handle was allocated but never registered must report
/// failure, even while another task is registered on the same service.
#[tokio::test]
async fn test_cancel_nonexistent_task() {
    let timer = TimerWheel::with_defaults();
    let service = timer.create_service(ServiceConfig::default());

    // A genuinely registered task, so the service is not empty.
    let registered = service.allocate_handle();
    service
        .register(
            registered,
            TimerTask::new_oneshot(Duration::from_millis(50), None),
        )
        .unwrap();

    // Allocated but never registered; kept bound until the end of the test.
    let unregistered = service.allocate_handle();
    let unregistered_id = unregistered.task_id();

    assert!(
        !service.cancel_task(unregistered_id),
        "Nonexistent task should not be cancelled"
    );
}
/// Cancelling a callback-bearing task must not run the callback, and the task
/// must be gone afterwards so that a repeat cancellation fails.
#[tokio::test]
async fn test_cancel_task_spawns_background_task() {
    let timer = TimerWheel::with_defaults();
    let service = timer.create_service(ServiceConfig::default());

    let hits = Arc::new(AtomicU32::new(0));
    let callback = {
        let hits = Arc::clone(&hits);
        CallbackWrapper::new(move || {
            let hits = Arc::clone(&hits);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        })
    };

    let handle = service.allocate_handle();
    let id = handle.task_id();
    service
        .register(
            handle,
            TimerTask::new_oneshot(Duration::from_secs(10), Some(callback)),
        )
        .unwrap();

    assert!(service.cancel_task(id), "Task should be cancelled successfully");

    // The 10s timer delay is far beyond this wait, so this only checks that
    // cancellation itself did not trigger the callback.
    tokio::time::sleep(Duration::from_millis(100)).await;
    let observed = hits.load(Ordering::SeqCst);
    assert_eq!(observed, 0, "Callback should not have been executed");

    assert!(
        !service.cancel_task(id),
        "Task should have been removed from active_tasks"
    );
}
/// Schedules a one-shot task carrying a counting callback, cancels it directly
/// by id, and checks that the callback has not run.
#[tokio::test]
async fn test_schedule_and_cancel_direct() {
    let timer = TimerWheel::with_defaults();
    let service = timer.create_service(ServiceConfig::default());

    let counter = Arc::new(AtomicU32::new(0));
    let handle = service.allocate_handle();
    let task_id = handle.task_id();

    let shared = Arc::clone(&counter);
    let on_fire = CallbackWrapper::new(move || {
        let shared = Arc::clone(&shared);
        async move {
            shared.fetch_add(1, Ordering::SeqCst);
        }
    });
    let task = TimerTask::new_oneshot(Duration::from_secs(10), Some(on_fire));
    service.register(handle, task).unwrap();

    let was_cancelled = service.cancel_task(task_id);
    assert!(was_cancelled, "Task should be cancelled successfully");

    // Brief pause so any callback work started by cancellation could run.
    tokio::time::sleep(Duration::from_millis(100)).await;
    let fired = counter.load(Ordering::SeqCst);
    assert_eq!(fired, 0, "Callback should not have been executed");
}
/// Registers ten callback-bearing tasks in one batch and cancels all of them in
/// one batch; all ten cancellations must succeed and no callback may run.
#[tokio::test]
async fn test_cancel_batch_direct() {
    let timer = TimerWheel::with_defaults();
    let service = timer.create_service(ServiceConfig::default());
    let fired = Arc::new(AtomicU32::new(0));

    let handles = service.allocate_handles(10);
    let ids: Vec<_> = handles.iter().map(|h| h.task_id()).collect();

    let mut tasks = Vec::new();
    for _ in 0..10 {
        let fired = Arc::clone(&fired);
        let callback = CallbackWrapper::new(move || {
            let fired = Arc::clone(&fired);
            async move {
                fired.fetch_add(1, Ordering::SeqCst);
            }
        });
        tasks.push(TimerTask::new_oneshot(
            Duration::from_secs(10),
            Some(callback),
        ));
    }
    assert_eq!(ids.len(), 10);

    service.register_batch(handles, tasks).unwrap();
    let removed = service.cancel_batch(&ids);
    assert_eq!(removed, 10, "All 10 tasks should be cancelled");

    tokio::time::sleep(Duration::from_millis(100)).await;
    let calls = fired.load(Ordering::SeqCst);
    assert_eq!(calls, 0, "No callbacks should have been executed");
}
/// Registers ten callback-bearing tasks, cancels only the first five in one batch,
/// and verifies that the batch affected exactly those five.
///
/// Every task uses a 10s delay while the test only waits 100ms, so the counter
/// check alone cannot distinguish "five cancelled" from "none" or "all". The
/// follow-up bookkeeping checks do that:
/// - re-cancelling the first half must find nothing;
/// - cancelling the second half must still find all five.
#[tokio::test]
async fn test_cancel_batch_partial() {
    let timer = TimerWheel::with_defaults();
    let service = timer.create_service(ServiceConfig::default());
    let counter = Arc::new(AtomicU32::new(0));
    let handles = service.allocate_handles(10);
    let task_ids: Vec<_> = handles.iter().map(|h| h.task_id()).collect();
    let tasks: Vec<_> = (0..10)
        .map(|_| {
            let counter = Arc::clone(&counter);
            TimerTask::new_oneshot(
                Duration::from_secs(10),
                Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })),
            )
        })
        .collect();
    service.register_batch(handles, tasks).unwrap();

    let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
    let to_keep: Vec<_> = task_ids.iter().skip(5).copied().collect();

    let cancelled = service.cancel_batch(&to_cancel);
    assert_eq!(cancelled, 5, "5 tasks should be cancelled");

    tokio::time::sleep(Duration::from_millis(100)).await;
    assert_eq!(
        counter.load(Ordering::SeqCst),
        0,
        "Cancelled tasks should not execute"
    );

    // The cancelled half has been removed, so a second batch finds nothing.
    assert_eq!(
        service.cancel_batch(&to_cancel),
        0,
        "Already-cancelled tasks should not be cancelled again"
    );
    // The other half was outside the partial batch and must still be active.
    assert_eq!(
        service.cancel_batch(&to_keep),
        5,
        "Tasks outside the batch should still be active"
    );
}
/// An empty id list is a valid batch and must cancel nothing.
#[tokio::test]
async fn test_cancel_batch_empty() {
    let timer = TimerWheel::with_defaults();
    let service = timer.create_service(ServiceConfig::default());

    let no_ids = Vec::<TaskId>::new();
    let removed = service.cancel_batch(&no_ids);
    assert_eq!(removed, 0, "No tasks should be cancelled");
}
/// A cancelled task must never surface on the timeout receiver. Only the task
/// that actually expires is reported, and nothing arrives after it.
#[tokio::test]
async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
    let timer = TimerWheel::with_defaults();
    let mut service = timer.create_service(ServiceConfig::default());

    // Long-delay task that is cancelled before it can fire.
    let long_handle = service.allocate_handle();
    let long_id = long_handle.task_id();
    service
        .register(
            long_handle,
            TimerTask::new_oneshot(Duration::from_secs(10), None),
        )
        .unwrap();

    // Short-delay task that is expected to expire and be reported.
    let short_handle = service.allocate_handle();
    let short_id = short_handle.task_id();
    service
        .register(
            short_handle,
            TimerTask::new_oneshot(Duration::from_millis(50), None),
        )
        .unwrap();

    assert!(service.cancel_task(long_id), "Task should be cancelled");

    let receiver = service.take_receiver().unwrap();

    let first = tokio::time::timeout(Duration::from_millis(200), receiver.recv()).await;
    let notification = first
        .expect("Should receive timeout notification")
        .expect("Should receive Some value");
    assert_eq!(
        notification,
        TaskNotification::OneShot(short_id),
        "Should only receive expired task notification"
    );

    // Anything arriving within this window would be a leaked notification.
    let trailing = tokio::time::timeout(Duration::from_millis(50), receiver.recv()).await;
    assert!(
        trailing.is_err(),
        "Should not receive any more notifications"
    );
}