1use crate::config::ServiceConfig;
2use crate::task::{CallbackWrapper, TaskId, TimerCallback};
3use crate::timer::{BatchHandle, TimerHandle};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use rustc_hash::FxHashSet;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13
14enum ServiceCommand {
16 AddBatchHandle(BatchHandle),
18 AddTimerHandle(TimerHandle),
20 RemoveTasks {
22 task_ids: Vec<TaskId>,
23 },
24 Shutdown,
26}
27
28pub struct TimerService {
62 command_tx: mpsc::Sender<ServiceCommand>,
64 timeout_rx: Option<mpsc::Receiver<TaskId>>,
66 actor_handle: Option<JoinHandle<()>>,
68 wheel: Arc<Mutex<Wheel>>,
70}
71
72impl TimerService {
73 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
93 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
94 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
95
96 let actor = ServiceActor::new(command_rx, timeout_tx);
97 let actor_handle = tokio::spawn(async move {
98 actor.run().await;
99 });
100
101 Self {
102 command_tx,
103 timeout_rx: Some(timeout_rx),
104 actor_handle: Some(actor_handle),
105 wheel,
106 }
107 }
108
109 async fn add_batch_handle(&self, batch: BatchHandle) {
111 let _ = self.command_tx
112 .send(ServiceCommand::AddBatchHandle(batch))
113 .await;
114 }
115
116 async fn add_timer_handle(&self, handle: TimerHandle) {
118 let _ = self.command_tx
119 .send(ServiceCommand::AddTimerHandle(handle))
120 .await;
121 }
122
123 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
147 self.timeout_rx.take()
148 }
149
150 pub async fn cancel_task(&self, task_id: TaskId) -> bool {
181 let success = {
184 let mut wheel = self.wheel.lock();
185 wheel.cancel(task_id)
186 };
187
188 if success {
190 let _ = self.command_tx
191 .send(ServiceCommand::RemoveTasks {
192 task_ids: vec![task_id]
193 })
194 .await;
195 }
196
197 success
198 }
199
200 pub async fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
230 if task_ids.is_empty() {
231 return 0;
232 }
233
234 let cancelled_count = {
236 let mut wheel = self.wheel.lock();
237 wheel.cancel_batch(task_ids)
238 };
239
240 let _ = self.command_tx
242 .send(ServiceCommand::RemoveTasks {
243 task_ids: task_ids.to_vec()
244 })
245 .await;
246
247 cancelled_count
248 }
249
250 pub async fn schedule_once<C>(&self, delay: Duration, callback: C) -> TaskId
278 where
279 C: TimerCallback,
280 {
281 let handle = self.create_timer_handle(delay, Some(Arc::new(callback)));
283 let task_id = handle.task_id();
284
285 self.add_timer_handle(handle).await;
287
288 task_id
289 }
290
291 pub async fn schedule_once_batch<C>(&self, callbacks: Vec<(Duration, C)>) -> Vec<TaskId>
321 where
322 C: TimerCallback,
323 {
324 let batch_handle = self.create_batch_handle(callbacks);
326 let task_ids = batch_handle.task_ids().to_vec();
327
328 self.add_batch_handle(batch_handle).await;
330
331 task_ids
332 }
333
334 pub async fn schedule_once_notify(&self, delay: Duration) -> TaskId {
360 let handle = self.create_timer_handle(delay, None);
362 let task_id = handle.task_id();
363
364 self.add_timer_handle(handle).await;
366
367 task_id
368 }
369
370 fn create_timer_handle(
372 &self,
373 delay: Duration,
374 callback: Option<CallbackWrapper>,
375 ) -> TimerHandle {
376 crate::timer::TimerWheel::create_timer_handle_internal(
377 &self.wheel,
378 delay,
379 callback
380 )
381 }
382
383 fn create_batch_handle<C>(
385 &self,
386 callbacks: Vec<(Duration, C)>,
387 ) -> BatchHandle
388 where
389 C: TimerCallback,
390 {
391 crate::timer::TimerWheel::create_batch_handle_internal(
392 &self.wheel,
393 callbacks
394 )
395 }
396
397 pub async fn shutdown(mut self) {
413 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
414 if let Some(handle) = self.actor_handle.take() {
415 let _ = handle.await;
416 }
417 }
418}
419
420
421impl Drop for TimerService {
422 fn drop(&mut self) {
423 if let Some(handle) = self.actor_handle.take() {
424 handle.abort();
425 }
426 }
427}
428
429struct ServiceActor {
431 command_rx: mpsc::Receiver<ServiceCommand>,
433 timeout_tx: mpsc::Sender<TaskId>,
435 active_tasks: FxHashSet<TaskId>,
437}
438
439impl ServiceActor {
440 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
441 Self {
442 command_rx,
443 timeout_tx,
444 active_tasks: FxHashSet::default(),
445 }
446 }
447
448 async fn run(mut self) {
449 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
452
453 loop {
454 tokio::select! {
455 Some((task_id, _result)) = futures.next() => {
457 let _ = self.timeout_tx.send(task_id).await;
459 self.active_tasks.remove(&task_id);
461 }
463
464 Some(cmd) = self.command_rx.recv() => {
466 match cmd {
467 ServiceCommand::AddBatchHandle(batch) => {
468 let BatchHandle {
469 task_ids,
470 completion_rxs,
471 ..
472 } = batch;
473
474 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
476 self.active_tasks.insert(task_id);
478
479 let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
480 (task_id, rx.await)
481 });
482 futures.push(future);
483 }
484 }
485 ServiceCommand::AddTimerHandle(handle) => {
486 let TimerHandle{
487 task_id,
488 completion_rx,
489 ..
490 } = handle;
491
492 self.active_tasks.insert(task_id);
494
495 let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
497 (task_id, completion_rx.0.await)
498 });
499 futures.push(future);
500 }
501 ServiceCommand::RemoveTasks { task_ids } => {
502 for task_id in task_ids {
505 self.active_tasks.remove(&task_id);
506 }
507 }
508 ServiceCommand::Shutdown => {
509 break;
510 }
511 }
512 }
513
514 else => {
516 break;
517 }
518 }
519 }
520 }
521}
522
#[cfg(test)]
mod tests {
    //! Async tests for `TimerService`. Several tests rely on wall-clock timing
    //! (50 ms timers observed within 200 ms windows, 10 s timers that must not fire).

    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Creating a service from a default wheel must not panic.
    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service();
    }


    /// A 50 ms timer created on the wheel and handed to the service via
    /// `add_timer_handle` must be reported on the receiver within 200 ms.
    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        // Schedule on the wheel directly, then register the handle with the service.
        let handle = timer.schedule_once(Duration::from_millis(50), || async {}).await;
        let task_id = handle.task_id();

        service.add_timer_handle(handle).await;

        // Outer `expect`: the 200 ms timeout elapsed; inner `expect`: channel closed.
        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);
    }


    /// `shutdown` must return while two 10 s timers are still pending.
    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let _task_id1 = service.schedule_once(Duration::from_secs(10), || async {}).await;
        let _task_id2 = service.schedule_once(Duration::from_secs(10), || async {}).await;

        service.shutdown().await;
    }



    /// Cancelling a pending task succeeds once; a second cancel of the same ID
    /// reports `false`.
    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        // 10 s delay keeps the task pending for the whole test.
        let handle = timer.schedule_once(Duration::from_secs(10), || async {}).await;
        let task_id = handle.task_id();

        service.add_timer_handle(handle).await;

        let cancelled = service.cancel_task(task_id).await;
        assert!(cancelled, "Task should be cancelled successfully");

        let cancelled_again = service.cancel_task(task_id).await;
        assert!(!cancelled_again, "Task should not exist anymore");
    }

    /// Cancelling an ID that was never scheduled reports `false`.
    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let handle = timer.schedule_once(Duration::from_millis(50), || async {}).await;
        service.add_timer_handle(handle).await;

        // Assumes `TaskId::new()` yields an ID distinct from the one scheduled above.
        let fake_task_id = TaskId::new();
        let cancelled = service.cancel_task(fake_task_id).await;
        assert!(!cancelled, "Nonexistent task should not be cancelled");
    }


    /// After a task's timeout notification arrives, cancelling that task reports `false`.
    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let handle = timer.schedule_once(Duration::from_millis(50), || async {}).await;
        let task_id = handle.task_id();

        service.add_timer_handle(handle).await;

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);

        // Short pause before cancelling — presumably to let post-fire bookkeeping settle.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let cancelled = service.cancel_task(task_id).await;
        assert!(!cancelled, "Timed out task should not exist anymore");
    }

    /// Cancelling a pending task keeps its callback from running within 100 ms, and a
    /// second cancel of the same ID reports `false`.
    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        // The callback increments `counter`, so any execution is observable.
        let counter_clone = Arc::clone(&counter);
        let handle = timer.schedule_once(
            Duration::from_secs(10),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        ).await;
        let task_id = handle.task_id();

        service.add_timer_handle(handle).await;

        let cancelled = service.cancel_task(task_id).await;
        assert!(cancelled, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        // NOTE(review): `cancel_task`'s result comes from the wheel, not from the actor's
        // `active_tasks`, despite what the assertion message says.
        let cancelled_again = service.cancel_task(task_id).await;
        assert!(!cancelled_again, "Task should have been removed from active_tasks");
    }

    /// `schedule_once` on the service: the ID is reported within 200 ms and the callback
    /// runs exactly once.
    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        let task_id = service.schedule_once(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        ).await;

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);

        // Give the callback time to complete before checking its side effect.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// `schedule_once_batch` with three 50 ms callbacks: three notifications arrive and
    /// each callback runs once.
    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_millis(50), move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let task_ids = service.schedule_once_batch(callbacks).await;
        assert_eq!(task_ids.len(), 3);

        let mut received_count = 0;
        let mut rx = service.take_receiver().unwrap();

        // Each `recv` gets its own 200 ms window; a timeout or closed channel ends the loop
        // early and the count assertion below fails.
        while received_count < 3 {
            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
                Ok(Some(_task_id)) => {
                    received_count += 1;
                }
                Ok(None) => break,
                Err(_) => break,
            }
        }

        assert_eq!(received_count, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// `schedule_once_notify` (no callback) still produces a notification within 200 ms.
    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task_id = service.schedule_once_notify(Duration::from_millis(50)).await;

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);
    }

    /// A task scheduled through the service and then cancelled does not run its callback
    /// within 100 ms.
    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        let task_id = service.schedule_once(
            Duration::from_secs(10),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        ).await;

        let cancelled = service.cancel_task(task_id).await;
        assert!(cancelled, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    /// `cancel_batch` over all ten pending tasks reports 10 and no callback runs.
    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_secs(10), move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let task_ids = service.schedule_once_batch(callbacks).await;
        assert_eq!(task_ids.len(), 10);

        let cancelled = service.cancel_batch(&task_ids).await;
        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    /// Cancelling the first five of ten pending tasks reports 5.
    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_secs(10), move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let task_ids = service.schedule_once_batch(callbacks).await;

        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
        let cancelled = service.cancel_batch(&to_cancel).await;
        assert_eq!(cancelled, 5, "5 tasks should be cancelled");

        // The five uncancelled tasks have a 10 s delay, so none of the ten callbacks is
        // expected to run inside this 100 ms window.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    /// `cancel_batch` with an empty slice reports 0.
    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let empty: Vec<TaskId> = vec![];
        let cancelled = service.cancel_batch(&empty).await;
        assert_eq!(cancelled, 0, "No tasks should be cancelled");
    }
}
867