use crate::config::ServiceConfig;
use crate::error::TimerError;
use crate::task::{CallbackWrapper, CompletionReceiver, TaskCompletion, TaskId};
use crate::wheel::Wheel;
use crate::{BatchHandle, TimerHandle};
use futures::future::BoxFuture;
use futures::stream::{FuturesUnordered, StreamExt};
use lite_sync::{
oneshot::lite::{Receiver, Sender, channel},
spsc,
};
use parking_lot::Mutex;
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinHandle;
/// Notification emitted by the service's actor when a watched task reports that it fired
/// (`TaskCompletion::Called`).
///
/// The variant records whether the task was registered as a one-shot or a periodic task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskNotification {
    /// A one-shot task fired.
    OneShot(TaskId),
    /// A periodic task fired; the actor keeps watching it for further reports.
    Periodic(TaskId),
}
impl TaskNotification {
pub fn task_id(&self) -> TaskId {
match self {
TaskNotification::OneShot(id) => *id,
TaskNotification::Periodic(id) => *id,
}
}
pub fn is_oneshot(&self) -> bool {
matches!(self, TaskNotification::OneShot(_))
}
pub fn is_periodic(&self) -> bool {
matches!(self, TaskNotification::Periodic(_))
}
}
/// Messages sent from [`TimerService`] to its actor over the command channel.
///
/// Each command hands over the completion receivers the actor should start watching.
enum ServiceCommand {
    /// Watch every receiver in `completion_rxs`; `task_ids` is paired with it by position.
    AddBatchHandle {
        task_ids: Vec<TaskId>,
        completion_rxs: Vec<CompletionReceiver>,
    },
    /// Watch the completion receiver of a single task.
    AddTimerHandle {
        task_id: TaskId,
        completion_rx: CompletionReceiver,
    },
}
/// Front end for registering timer tasks whose completions are reported over a channel.
///
/// Tasks live in the shared wheel; a background actor watches their completion receivers
/// and forwards fired tasks as [`TaskNotification`]s.
pub struct TimerService {
    /// Sends registration commands to the actor.
    command_tx: spsc::Sender<ServiceCommand, 32>,
    /// Receiver of fired-task notifications; handed out once by `take_receiver`.
    timeout_rx: Option<spsc::Receiver<TaskNotification, 32>>,
    /// Handle of the spawned actor; taken by `shutdown`, otherwise aborted on drop.
    actor_handle: Option<JoinHandle<()>>,
    /// Wheel shared with the timer and with the handles returned to callers.
    wheel: Arc<Mutex<Wheel>>,
    /// One-shot signal that asks the actor to stop; consumed by `shutdown`.
    shutdown_tx: Option<Sender<()>>,
}
impl TimerService {
    /// Reserves one task handle from the shared wheel.
    ///
    /// Pass the handle to [`TimerService::register`] together with the task to schedule.
    pub fn allocate_handle(&self) -> crate::task::TaskHandle {
        self.wheel.lock().allocate_handle()
    }

    /// Reserves `count` task handles under a single acquisition of the wheel lock.
    ///
    /// Intended for use with [`TimerService::register_batch`].
    pub fn allocate_handles(&self, count: usize) -> Vec<crate::task::TaskHandle> {
        self.wheel.lock().allocate_handles(count)
    }

    /// Builds the service around `wheel` and spawns its background actor.
    ///
    /// Channel capacities come from `config`. The actor is started with `tokio::spawn`, so
    /// this must be called from within a Tokio runtime.
    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
        let (command_tx, command_rx) = spsc::channel(config.command_channel_capacity);
        let (timeout_tx, timeout_rx) = spsc::channel(config.timeout_channel_capacity);
        let (shutdown_tx, shutdown_rx) = channel::<()>();
        let actor = ServiceActor::new(command_rx, timeout_tx, shutdown_rx);
        let actor_handle = tokio::spawn(async move {
            actor.run().await;
        });
        Self {
            command_tx,
            timeout_rx: Some(timeout_rx),
            actor_handle: Some(actor_handle),
            wheel,
            shutdown_tx: Some(shutdown_tx),
        }
    }

    /// Takes the receiver on which fired tasks are reported.
    ///
    /// Returns `Some` on the first call and `None` on every later call.
    pub fn take_receiver(&mut self) -> Option<spsc::Receiver<TaskNotification, 32>> {
        self.timeout_rx.take()
    }

    /// Cancels a single task in the wheel.
    ///
    /// Returns the wheel's verdict; a task that no longer exists (for example one that has
    /// already fired) yields `false`.
    #[inline]
    pub fn cancel_task(&self, task_id: TaskId) -> bool {
        let mut wheel = self.wheel.lock();
        wheel.cancel(task_id)
    }

    /// Cancels several tasks under one wheel lock.
    ///
    /// Returns `0` without locking when `task_ids` is empty; otherwise the count reported by
    /// the wheel (presumably the number of tasks actually cancelled — confirm in `Wheel`).
    #[inline]
    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
        if task_ids.is_empty() {
            return 0;
        }
        let mut wheel = self.wheel.lock();
        wheel.cancel_batch(task_ids)
    }

    /// Reschedules a task to fire after `new_delay`, optionally replacing its callback.
    ///
    /// Returns the wheel's result (presumably `false` if the task was not found).
    #[inline]
    pub fn postpone(
        &self,
        task_id: TaskId,
        new_delay: Duration,
        callback: Option<CallbackWrapper>,
    ) -> bool {
        let mut wheel = self.wheel.lock();
        wheel.postpone(task_id, new_delay, callback)
    }

    /// Reschedules several tasks under one wheel lock, keeping their callbacks.
    ///
    /// Returns `0` without locking when `updates` is empty; otherwise the wheel's count.
    #[inline]
    pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
        if updates.is_empty() {
            return 0;
        }
        let mut wheel = self.wheel.lock();
        wheel.postpone_batch(updates)
    }

    /// Reschedules several tasks under one wheel lock, optionally replacing callbacks.
    ///
    /// Returns `0` without locking when `updates` is empty; otherwise the wheel's count.
    #[inline]
    pub fn postpone_batch_with_callbacks(
        &self,
        updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
    ) -> usize {
        if updates.is_empty() {
            return 0;
        }
        let mut wheel = self.wheel.lock();
        wheel.postpone_batch_with_callbacks(updates)
    }

    /// Schedules `task` under `handle` and asks the actor to watch for its completion.
    ///
    /// # Errors
    ///
    /// Returns [`TimerError::RegisterFailed`] if the command could not be queued for the
    /// actor (`try_send` failed — presumably a full command channel or a stopped actor).
    /// The task is then removed from the wheel again, so a failed registration leaves
    /// nothing scheduled.
    #[inline]
    pub fn register(
        &self,
        handle: crate::task::TaskHandle,
        task: crate::task::TimerTask,
    ) -> Result<TimerHandle, TimerError> {
        let task_id = handle.task_id();
        let (task, completion_rx) =
            crate::task::TimerTaskWithCompletionNotifier::from_timer_task(task);
        {
            let mut wheel_guard = self.wheel.lock();
            wheel_guard.insert(handle, task);
        }
        let command = ServiceCommand::AddTimerHandle {
            task_id,
            completion_rx,
        };
        if self.command_tx.try_send(command).is_err() {
            // The caller is told the task was not registered and no one would ever observe
            // its completion, so undo the insertion instead of leaving an orphaned timer.
            self.wheel.lock().cancel(task_id);
            return Err(TimerError::RegisterFailed);
        }
        Ok(TimerHandle::new(task_id, self.wheel.clone()))
    }

    /// Schedules several tasks at once, pairing `handles[i]` with `tasks[i]`.
    ///
    /// # Errors
    ///
    /// - [`TimerError::BatchLengthMismatch`] if `handles` and `tasks` differ in length;
    ///   nothing is scheduled in that case.
    /// - Any error returned by the wheel's `insert_batch`.
    /// - [`TimerError::RegisterFailed`] if the command could not be queued for the actor;
    ///   the freshly inserted tasks are cancelled again in that case.
    #[inline]
    pub fn register_batch(
        &self,
        handles: Vec<crate::task::TaskHandle>,
        tasks: Vec<crate::task::TimerTask>,
    ) -> Result<BatchHandle, TimerError> {
        if handles.len() != tasks.len() {
            return Err(TimerError::BatchLengthMismatch {
                handles_len: handles.len(),
                tasks_len: tasks.len(),
            });
        }
        let task_count = tasks.len();
        let mut completion_rxs = Vec::with_capacity(task_count);
        let mut task_ids = Vec::with_capacity(task_count);
        let mut prepared_handles = Vec::with_capacity(task_count);
        let mut prepared_tasks = Vec::with_capacity(task_count);
        for (handle, task) in handles.into_iter().zip(tasks) {
            let task_id = handle.task_id();
            let (task, completion_rx) =
                crate::task::TimerTaskWithCompletionNotifier::from_timer_task(task);
            task_ids.push(task_id);
            completion_rxs.push(completion_rx);
            prepared_handles.push(handle);
            prepared_tasks.push(task);
        }
        {
            let mut wheel_guard = self.wheel.lock();
            wheel_guard.insert_batch(prepared_handles, prepared_tasks)?;
        }
        let command = ServiceCommand::AddBatchHandle {
            // The actor and the returned `BatchHandle` each need their own copy of the ids.
            task_ids: task_ids.clone(),
            completion_rxs,
        };
        if self.command_tx.try_send(command).is_err() {
            // Same rollback as `register`: an `Err` must not leave live timers behind.
            self.wheel.lock().cancel_batch(&task_ids);
            return Err(TimerError::RegisterFailed);
        }
        Ok(BatchHandle::new(task_ids, self.wheel.clone()))
    }

    /// Signals the actor to stop and waits for its task to finish.
    ///
    /// An unclaimed notification receiver is dropped before waiting. Depending on the
    /// `spsc` channel's close semantics (TODO confirm), this lets a send that is blocked on
    /// a full channel observe the closed receiver instead of waiting forever.
    pub async fn shutdown(mut self) {
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            shutdown_tx.notify(());
        }
        drop(self.timeout_rx.take());
        if let Some(handle) = self.actor_handle.take() {
            let _ = handle.await;
        }
    }
}
impl Drop for TimerService {
    /// Aborts the actor if this service still owns its handle.
    ///
    /// That is the case when the service is dropped without calling
    /// [`TimerService::shutdown`], which takes the handle itself.
    fn drop(&mut self) {
        let Some(actor) = self.actor_handle.take() else {
            return;
        };
        actor.abort();
    }
}
/// Background task that watches task completions and forwards fired tasks as notifications.
struct ServiceActor {
    /// Registration commands coming from [`TimerService`].
    command_rx: spsc::Receiver<ServiceCommand, 32>,
    /// Outgoing notifications for fired tasks.
    timeout_tx: spsc::Sender<TaskNotification, 32>,
    /// Resolves when [`TimerService::shutdown`] signals the actor to stop.
    shutdown_rx: Receiver<()>,
}
impl ServiceActor {
fn new(
command_rx: spsc::Receiver<ServiceCommand, 32>,
timeout_tx: spsc::Sender<TaskNotification, 32>,
shutdown_rx: Receiver<()>,
) -> Self {
Self {
command_rx,
timeout_tx,
shutdown_rx,
}
}
async fn run(self) {
let mut oneshot_futures: FuturesUnordered<BoxFuture<'static, (TaskId, TaskCompletion)>> =
FuturesUnordered::new();
type PeriodicFutureResult = (
TaskId,
Option<TaskCompletion>,
crate::task::PeriodicCompletionReceiver,
);
let mut periodic_futures: FuturesUnordered<BoxFuture<'static, PeriodicFutureResult>> =
FuturesUnordered::new();
let mut shutdown_rx = self.shutdown_rx;
loop {
tokio::select! {
_ = &mut shutdown_rx => {
break;
}
Some((task_id, completion)) = oneshot_futures.next() => {
if completion == TaskCompletion::Called {
let _ = self.timeout_tx.send(TaskNotification::OneShot(task_id)).await;
}
}
Some((task_id, reason, mut receiver)) = periodic_futures.next() => {
if let Some(TaskCompletion::Called) = reason {
let _ = self.timeout_tx.send(TaskNotification::Periodic(task_id)).await;
let future: BoxFuture<'static, PeriodicFutureResult> = Box::pin(async move {
let reason = receiver.recv().await;
(task_id, reason, receiver)
});
periodic_futures.push(future);
}
}
Some(cmd) = self.command_rx.recv() => {
match cmd {
ServiceCommand::AddBatchHandle { task_ids, completion_rxs } => {
for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
match rx {
crate::task::CompletionReceiver::OneShot(receiver) => {
let future: BoxFuture<'static, (TaskId, TaskCompletion)> = Box::pin(async move {
(task_id, receiver.recv().await.unwrap())
});
oneshot_futures.push(future);
},
crate::task::CompletionReceiver::Periodic(mut receiver) => {
let future: BoxFuture<'static, PeriodicFutureResult> = Box::pin(async move {
let reason = receiver.recv().await;
(task_id, reason, receiver)
});
periodic_futures.push(future);
}
}
}
}
ServiceCommand::AddTimerHandle { task_id, completion_rx } => {
match completion_rx {
crate::task::CompletionReceiver::OneShot(receiver) => {
let future: BoxFuture<'static, (TaskId, TaskCompletion)> = Box::pin(async move {
(task_id, receiver.recv().await.unwrap())
});
oneshot_futures.push(future);
},
crate::task::CompletionReceiver::Periodic(mut receiver) => {
let future: BoxFuture<'static, PeriodicFutureResult> = Box::pin(async move {
let reason = receiver.recv().await;
(task_id, reason, receiver)
});
periodic_futures.push(future);
}
}
}
}
}
else => {
break;
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TimerTask, TimerWheel};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// How long a test waits for a single notification before failing.
    const NOTIFICATION_WAIT: Duration = Duration::from_millis(200);

    /// Awaits the next notification on `rx`. Fails the test if none arrives within
    /// [`NOTIFICATION_WAIT`] or if the channel yields no value.
    async fn next_notification(rx: &spsc::Receiver<TaskNotification, 32>) -> TaskNotification {
        tokio::time::timeout(NOTIFICATION_WAIT, rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    #[tokio::test]
    async fn test_service_creation() {
        let wheel = TimerWheel::with_defaults();
        let _service = wheel.create_service(ServiceConfig::default());
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let wheel = TimerWheel::with_defaults();
        let mut service = wheel.create_service(ServiceConfig::default());
        let handle = service.allocate_handle();
        let expected = TaskNotification::OneShot(handle.task_id());
        let task = TimerTask::new_oneshot(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(|| async {})),
        );
        service.register(handle, task).unwrap();
        let rx = service.take_receiver().unwrap();
        assert_eq!(next_notification(&rx).await, expected);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let wheel = TimerWheel::with_defaults();
        let service = wheel.create_service(ServiceConfig::default());
        // Two long-delay tasks are still pending when the service shuts down.
        let handles = [service.allocate_handle(), service.allocate_handle()];
        for handle in handles {
            let task = TimerTask::new_oneshot(Duration::from_secs(10), None);
            service.register(handle, task).unwrap();
        }
        service.shutdown().await;
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let wheel = TimerWheel::with_defaults();
        let mut service = wheel.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));
        let callback_counter = Arc::clone(&counter);
        let handle = service.allocate_handle();
        let expected = TaskNotification::OneShot(handle.task_id());
        let callback = CallbackWrapper::new(move || {
            let counter = Arc::clone(&callback_counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
            }
        });
        let task = TimerTask::new_oneshot(Duration::from_millis(50), Some(callback));
        service.register(handle, task).unwrap();
        let rx = service.take_receiver().unwrap();
        assert_eq!(next_notification(&rx).await, expected);
        // Grace period: presumably the callback may still be running when the notification
        // arrives.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let wheel = TimerWheel::with_defaults();
        let mut service = wheel.create_service(ServiceConfig::default());
        let handle = service.allocate_handle();
        let expected = TaskNotification::OneShot(handle.task_id());
        let task = TimerTask::new_oneshot(Duration::from_millis(50), None);
        service.register(handle, task).unwrap();
        let rx = service.take_receiver().unwrap();
        assert_eq!(next_notification(&rx).await, expected);
    }

    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let wheel = TimerWheel::with_defaults();
        let mut service = wheel.create_service(ServiceConfig::default());
        let handle = service.allocate_handle();
        let task_id = handle.task_id();
        let task = TimerTask::new_oneshot(Duration::from_millis(50), None);
        service.register(handle, task).unwrap();
        let rx = service.take_receiver().unwrap();
        assert_eq!(next_notification(&rx).await, TaskNotification::OneShot(task_id));
        tokio::time::sleep(Duration::from_millis(10)).await;
        // Once fired, the task is gone from the wheel, so cancelling it must fail.
        let cancelled = service.cancel_task(task_id);
        assert!(!cancelled, "Timed out task should not exist anymore");
    }

    #[tokio::test]
    async fn test_take_receiver_twice() {
        let wheel = TimerWheel::with_defaults();
        let mut service = wheel.create_service(ServiceConfig::default());
        let first = service.take_receiver();
        assert!(first.is_some(), "First take_receiver should return Some");
        let second = service.take_receiver();
        assert!(second.is_none(), "Second take_receiver should return None");
    }
}