1use crate::task::{CallbackWrapper, TaskId, TimerCallback};
2use crate::timer::{BatchHandle, TimerHandle};
3use crate::error::TimerError;
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use rustc_hash::FxHashSet;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13
/// Messages sent from `TimerService` to the background `ServiceActor`
/// over the bounded command channel.
enum ServiceCommand {
    /// Hand over a batch of timers; the actor awaits each task's completion
    /// receiver.
    AddBatchHandle(BatchHandle),
    /// Hand over a single timer; the actor awaits its completion receiver.
    AddTimerHandle(TimerHandle),
    /// Drop the listed task ids from the actor's active-task set.
    RemoveTasks {
        /// Ids to remove from the actor's bookkeeping.
        task_ids: Vec<TaskId>,
    },
    /// Ask the actor to leave its event loop.
    Shutdown,
}
27
/// Front end for a background actor that watches timers and reports their
/// task ids on a single channel.
pub struct TimerService {
    /// Sender half of the command channel consumed by the actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiver of task ids forwarded by the actor; `None` once it has been
    /// handed out through `take_receiver`.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor task; taken by `shutdown`, otherwise
    /// aborted when the service is dropped.
    actor_handle: Option<JoinHandle<()>>,
    /// Shared timer wheel used for scheduling and cancellation.
    wheel: Arc<Mutex<Wheel>>,
}
71
72impl TimerService {
73 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>) -> Self {
92 let (command_tx, command_rx) = mpsc::channel(512);
94 let (timeout_tx, timeout_rx) = mpsc::channel(1000);
95
96 let actor = ServiceActor::new(command_rx, timeout_tx);
97 let actor_handle = tokio::spawn(async move {
98 actor.run().await;
99 });
100
101 Self {
102 command_tx,
103 timeout_rx: Some(timeout_rx),
104 actor_handle: Some(actor_handle),
105 wheel,
106 }
107 }
108
109 async fn add_batch_handle(&self, batch: BatchHandle) -> Result<(), TimerError> {
111 self.command_tx
112 .send(ServiceCommand::AddBatchHandle(batch))
113 .await
114 .map_err(|_| TimerError::ChannelClosed)
115 }
116
117 async fn add_timer_handle(&self, handle: TimerHandle) -> Result<(), TimerError> {
119 self.command_tx
120 .send(ServiceCommand::AddTimerHandle(handle))
121 .await
122 .map_err(|_| TimerError::ChannelClosed)
123 }
124
125 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
149 self.timeout_rx.take()
150 }
151
152 pub async fn cancel_task(&self, task_id: TaskId) -> Result<bool, String> {
183 let success = {
186 let mut wheel = self.wheel.lock();
187 wheel.cancel(task_id)
188 };
189
190 if success {
192 let _ = self.command_tx
193 .send(ServiceCommand::RemoveTasks {
194 task_ids: vec![task_id]
195 })
196 .await;
197 }
198
199 Ok(success)
200 }
201
202 pub async fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
232 if task_ids.is_empty() {
233 return 0;
234 }
235
236 let cancelled_count = {
238 let mut wheel = self.wheel.lock();
239 wheel.cancel_batch(task_ids)
240 };
241
242 let _ = self.command_tx
244 .send(ServiceCommand::RemoveTasks {
245 task_ids: task_ids.to_vec()
246 })
247 .await;
248
249 cancelled_count
250 }
251
252 pub async fn schedule_once<C>(&self, delay: Duration, callback: C) -> Result<TaskId, TimerError>
281 where
282 C: TimerCallback,
283 {
284 let handle = self.create_timer_handle(delay, Some(Arc::new(callback)))?;
286 let task_id = handle.task_id();
287
288 self.add_timer_handle(handle).await?;
290
291 Ok(task_id)
292 }
293
294 pub async fn schedule_once_batch<C>(&self, callbacks: Vec<(Duration, C)>) -> Result<Vec<TaskId>, TimerError>
325 where
326 C: TimerCallback,
327 {
328 let batch_handle = self.create_batch_handle(callbacks)?;
330 let task_ids = batch_handle.task_ids().to_vec();
331
332 self.add_batch_handle(batch_handle).await?;
334
335 Ok(task_ids)
336 }
337
338 pub async fn schedule_once_notify(&self, delay: Duration) -> Result<TaskId, TimerError> {
365 let handle = self.create_timer_handle(delay, None)?;
367 let task_id = handle.task_id();
368
369 self.add_timer_handle(handle).await?;
371
372 Ok(task_id)
373 }
374
375 fn create_timer_handle(
377 &self,
378 delay: Duration,
379 callback: Option<CallbackWrapper>,
380 ) -> Result<TimerHandle, TimerError> {
381 crate::timer::TimerWheel::create_timer_handle_internal(
382 &self.wheel,
383 delay,
384 callback
385 )
386 }
387
388 fn create_batch_handle<C>(
390 &self,
391 callbacks: Vec<(Duration, C)>,
392 ) -> Result<BatchHandle, TimerError>
393 where
394 C: TimerCallback,
395 {
396 crate::timer::TimerWheel::create_batch_handle_internal(
397 &self.wheel,
398 callbacks
399 )
400 }
401
402 pub async fn shutdown(mut self) {
418 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
419 if let Some(handle) = self.actor_handle.take() {
420 let _ = handle.await;
421 }
422 }
423}
424
425
426impl Drop for TimerService {
427 fn drop(&mut self) {
428 if let Some(handle) = self.actor_handle.take() {
429 handle.abort();
430 }
431 }
432}
433
/// Background task state: receives commands from `TimerService` and reports
/// task ids on the timeout channel.
struct ServiceActor {
    /// Incoming commands from the service.
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Outgoing task-id notifications.
    timeout_tx: mpsc::Sender<TaskId>,
    /// Ids of tasks handed to the actor that have neither resolved nor been
    /// removed yet.
    active_tasks: FxHashSet<TaskId>,
}
443
444impl ServiceActor {
445 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
446 Self {
447 command_rx,
448 timeout_tx,
449 active_tasks: FxHashSet::default(),
450 }
451 }
452
453 async fn run(mut self) {
454 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
457
458 loop {
459 tokio::select! {
460 Some((task_id, _result)) = futures.next() => {
462 let _ = self.timeout_tx.send(task_id).await;
464 self.active_tasks.remove(&task_id);
466 }
468
469 Some(cmd) = self.command_rx.recv() => {
471 match cmd {
472 ServiceCommand::AddBatchHandle(batch) => {
473 let BatchHandle {
474 task_ids,
475 completion_rxs,
476 ..
477 } = batch;
478
479 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
481 self.active_tasks.insert(task_id);
483
484 let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
485 (task_id, rx.await)
486 });
487 futures.push(future);
488 }
489 }
490 ServiceCommand::AddTimerHandle(handle) => {
491 let TimerHandle{
492 task_id,
493 completion_rx,
494 ..
495 } = handle;
496
497 self.active_tasks.insert(task_id);
499
500 let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
502 (task_id, completion_rx.0.await)
503 });
504 futures.push(future);
505 }
506 ServiceCommand::RemoveTasks { task_ids } => {
507 for task_id in task_ids {
510 self.active_tasks.remove(&task_id);
511 }
512 }
513 ServiceCommand::Shutdown => {
514 break;
515 }
516 }
517 }
518
519 else => {
521 break;
522 }
523 }
524 }
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Waits up to 200 ms for the next notification and returns its task id,
    /// panicking if none arrives or the channel is closed.
    async fn expect_notification(rx: &mut tokio::sync::mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    /// A service can be created from a default wheel.
    #[tokio::test]
    async fn test_service_creation() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let _service = wheel.create_service();
    }

    /// A handle scheduled on the wheel and added to the service is reported
    /// on the receiver.
    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let mut service = wheel.create_service();

        let handle = wheel
            .schedule_once(Duration::from_millis(50), || async {})
            .await
            .unwrap();
        let expected = handle.task_id();

        service.add_timer_handle(handle).await.unwrap();

        let mut notifications = service.take_receiver().unwrap();
        let received = expect_notification(&mut notifications).await;

        assert_eq!(received, expected);
    }

    /// Shutting down with pending long timers completes.
    #[tokio::test]
    async fn test_shutdown() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let service = wheel.create_service();

        let _first = service
            .schedule_once(Duration::from_secs(10), || async {})
            .await
            .unwrap();
        let _second = service
            .schedule_once(Duration::from_secs(10), || async {})
            .await
            .unwrap();

        service.shutdown().await;
    }

    /// Cancelling succeeds once and reports `false` the second time.
    #[tokio::test]
    async fn test_cancel_task() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let service = wheel.create_service();

        let handle = wheel
            .schedule_once(Duration::from_secs(10), || async {})
            .await
            .unwrap();
        let id = handle.task_id();

        service.add_timer_handle(handle).await.unwrap();

        let first_attempt = service.cancel_task(id).await.unwrap();
        assert!(first_attempt, "Task should be cancelled successfully");

        let second_attempt = service.cancel_task(id).await.unwrap();
        assert!(!second_attempt, "Task should not exist anymore");
    }

    /// Cancelling an id that was never scheduled reports `false`.
    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let service = wheel.create_service();

        let handle = wheel
            .schedule_once(Duration::from_millis(50), || async {})
            .await
            .unwrap();
        service.add_timer_handle(handle).await.unwrap();

        let unknown_id = TaskId::new();
        let outcome = service.cancel_task(unknown_id).await.unwrap();
        assert!(!outcome, "Nonexistent task should not be cancelled");
    }

    /// After a task has fired it can no longer be cancelled.
    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let mut service = wheel.create_service();

        let handle = wheel
            .schedule_once(Duration::from_millis(50), || async {})
            .await
            .unwrap();
        let id = handle.task_id();

        service.add_timer_handle(handle).await.unwrap();

        let mut notifications = service.take_receiver().unwrap();
        let received = expect_notification(&mut notifications).await;

        assert_eq!(received, id);

        // Give the actor a moment to finish its bookkeeping.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let outcome = service.cancel_task(id).await.unwrap();
        assert!(!outcome, "Timed out task should not exist anymore");
    }

    /// A cancelled task never runs its callback and cannot be cancelled twice.
    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let service = wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_callback = Arc::clone(&hits);
        let handle = wheel
            .schedule_once(Duration::from_secs(10), move || {
                let hits = Arc::clone(&hits_for_callback);
                async move {
                    hits.fetch_add(1, Ordering::SeqCst);
                }
            })
            .await
            .unwrap();
        let id = handle.task_id();

        service.add_timer_handle(handle).await.unwrap();

        let first_attempt = service.cancel_task(id).await.unwrap();
        assert!(first_attempt, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        let second_attempt = service.cancel_task(id).await.unwrap();
        assert!(!second_attempt, "Task should have been removed from active_tasks");
    }

    /// Scheduling through the service runs the callback and reports the id.
    #[tokio::test]
    async fn test_schedule_once_direct() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let mut service = wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_callback = Arc::clone(&hits);
        let id = service
            .schedule_once(Duration::from_millis(50), move || {
                let hits = Arc::clone(&hits_for_callback);
                async move {
                    hits.fetch_add(1, Ordering::SeqCst);
                }
            })
            .await
            .unwrap();

        let mut notifications = service.take_receiver().unwrap();
        let received = expect_notification(&mut notifications).await;

        assert_eq!(received, id);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    /// A batch of three timers yields three notifications and three callback
    /// runs.
    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let mut service = wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| {
                let hits = Arc::clone(&hits);
                (Duration::from_millis(50), move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let ids = service.schedule_once_batch(callbacks).await.unwrap();
        assert_eq!(ids.len(), 3);

        let mut received_count = 0;
        let mut notifications = service.take_receiver().unwrap();

        // Stop early on a closed channel or a 200 ms gap.
        while received_count < 3 {
            match tokio::time::timeout(Duration::from_millis(200), notifications.recv()).await {
                Ok(Some(_id)) => received_count += 1,
                Ok(None) | Err(_) => break,
            }
        }

        assert_eq!(received_count, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    /// A notify-only timer is reported on the receiver.
    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let mut service = wheel.create_service();

        let id = service
            .schedule_once_notify(Duration::from_millis(50))
            .await
            .unwrap();

        let mut notifications = service.take_receiver().unwrap();
        let received = expect_notification(&mut notifications).await;

        assert_eq!(received, id);
    }

    /// A timer scheduled and cancelled through the service never runs.
    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let service = wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_callback = Arc::clone(&hits);
        let id = service
            .schedule_once(Duration::from_secs(10), move || {
                let hits = Arc::clone(&hits_for_callback);
                async move {
                    hits.fetch_add(1, Ordering::SeqCst);
                }
            })
            .await
            .unwrap();

        let outcome = service.cancel_task(id).await.unwrap();
        assert!(outcome, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    /// Cancelling a whole batch reports every task and runs no callback.
    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let service = wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let hits = Arc::clone(&hits);
                (Duration::from_secs(10), move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let ids = service.schedule_once_batch(callbacks).await.unwrap();
        assert_eq!(ids.len(), 10);

        let cancelled = service.cancel_batch(&ids).await;
        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    /// Cancelling half of a batch reports exactly that half.
    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let service = wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let hits = Arc::clone(&hits);
                (Duration::from_secs(10), move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let ids = service.schedule_once_batch(callbacks).await.unwrap();

        let first_half: Vec<_> = ids.iter().take(5).copied().collect();
        let cancelled = service.cancel_batch(&first_half).await;
        assert_eq!(cancelled, 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    /// An empty id list cancels nothing.
    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let wheel = TimerWheel::with_defaults().unwrap();
        let service = wheel.create_service();

        let no_ids: Vec<TaskId> = vec![];
        let cancelled = service.cancel_batch(&no_ids).await;
        assert_eq!(cancelled, 0, "No tasks should be cancelled");
    }
}
872