1use crate::config::ServiceConfig;
2use crate::task::{TaskId, TimerCallback};
3use crate::timer::{BatchHandle, TimerHandle};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13enum ServiceCommand {
15 AddBatchHandle(BatchHandle),
17 AddTimerHandle(TimerHandle),
19 Shutdown,
21}
22
23pub struct TimerService {
58 command_tx: mpsc::Sender<ServiceCommand>,
60 timeout_rx: Option<mpsc::Receiver<TaskId>>,
62 actor_handle: Option<JoinHandle<()>>,
64 wheel: Arc<Mutex<Wheel>>,
66}
67
68impl TimerService {
69 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
89 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
90 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
91
92 let actor = ServiceActor::new(command_rx, timeout_tx);
93 let actor_handle = tokio::spawn(async move {
94 actor.run().await;
95 });
96
97 Self {
98 command_tx,
99 timeout_rx: Some(timeout_rx),
100 actor_handle: Some(actor_handle),
101 wheel,
102 }
103 }
104
105 async fn add_batch_handle(&self, batch: BatchHandle) {
107 let _ = self.command_tx
108 .send(ServiceCommand::AddBatchHandle(batch))
109 .await;
110 }
111
112 async fn add_timer_handle(&self, handle: TimerHandle) {
114 let _ = self.command_tx
115 .send(ServiceCommand::AddTimerHandle(handle))
116 .await;
117 }
118
119 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
143 self.timeout_rx.take()
144 }
145
146 #[inline]
179 pub async fn cancel_task(&self, task_id: TaskId) -> bool {
180 let mut wheel = self.wheel.lock();
183 wheel.cancel(task_id)
184 }
185
186 #[inline]
218 pub async fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
219 if task_ids.is_empty() {
220 return 0;
221 }
222
223 let mut wheel = self.wheel.lock();
226 wheel.cancel_batch(task_ids)
227 }
228
229 pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
260 where
261 C: TimerCallback,
262 {
263 crate::timer::TimerWheel::create_task(delay, callback)
264 }
265
266 pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
298 where
299 C: TimerCallback,
300 {
301 crate::timer::TimerWheel::create_batch(callbacks)
302 }
303
304 #[inline]
327 pub async fn register(&self, task: crate::task::TimerTask) {
328 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
329 let notifier = crate::task::CompletionNotifier(completion_tx);
330
331 let delay = task.delay;
332 let task_id = task.id;
333
334 {
336 let mut wheel_guard = self.wheel.lock();
337 wheel_guard.insert(delay, task, notifier);
338 }
339
340 let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
342 self.add_timer_handle(handle).await;
343 }
344
345 #[inline]
368 pub async fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) {
369 let task_count = tasks.len();
370 let mut completion_rxs = Vec::with_capacity(task_count);
371 let mut task_ids = Vec::with_capacity(task_count);
372 let mut prepared_tasks = Vec::with_capacity(task_count);
373
374 for task in tasks {
377 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
378 let notifier = crate::task::CompletionNotifier(completion_tx);
379
380 task_ids.push(task.id);
381 completion_rxs.push(completion_rx);
382 prepared_tasks.push((task.delay, task, notifier));
383 }
384
385 {
387 let mut wheel_guard = self.wheel.lock();
388 wheel_guard.insert_batch(prepared_tasks);
389 }
390
391 let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
393 self.add_batch_handle(batch_handle).await;
394 }
395
396 pub async fn shutdown(mut self) {
412 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
413 if let Some(handle) = self.actor_handle.take() {
414 let _ = handle.await;
415 }
416 }
417}
418
419
420impl Drop for TimerService {
421 fn drop(&mut self) {
422 if let Some(handle) = self.actor_handle.take() {
423 handle.abort();
424 }
425 }
426}
427
428struct ServiceActor {
430 command_rx: mpsc::Receiver<ServiceCommand>,
432 timeout_tx: mpsc::Sender<TaskId>,
434}
435
436impl ServiceActor {
437 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
438 Self {
439 command_rx,
440 timeout_tx,
441 }
442 }
443
444 async fn run(mut self) {
445 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
448
449 loop {
450 tokio::select! {
451 Some((task_id, _result)) = futures.next() => {
453 let _ = self.timeout_tx.send(task_id).await;
455 }
457
458 Some(cmd) = self.command_rx.recv() => {
460 match cmd {
461 ServiceCommand::AddBatchHandle(batch) => {
462 let BatchHandle {
463 task_ids,
464 completion_rxs,
465 ..
466 } = batch;
467
468 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
470 let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
471 (task_id, rx.await)
472 });
473 futures.push(future);
474 }
475 }
476 ServiceCommand::AddTimerHandle(handle) => {
477 let TimerHandle{
478 task_id,
479 completion_rx,
480 ..
481 } = handle;
482
483 let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
485 (task_id, completion_rx.0.await)
486 });
487 futures.push(future);
488 }
489 ServiceCommand::Shutdown => {
490 break;
491 }
492 }
493 }
494
495 else => {
497 break;
498 }
499 }
500 }
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Waits up to 200ms for the next timeout notification and returns its task id.
    async fn next_timeout(rx: &mut tokio::sync::mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    #[tokio::test]
    async fn test_service_creation() {
        let timer_wheel = TimerWheel::with_defaults();
        let _service = timer_wheel.create_service();
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer_wheel = TimerWheel::with_defaults();
        let mut service = timer_wheel.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let expected_id = task.get_id();
        service.add_timer_handle(timer_wheel.register(task)).await;

        let mut timeouts = service.take_receiver().unwrap();
        let fired_id = next_timeout(&mut timeouts).await;
        assert_eq!(fired_id, expected_id);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let timer_wheel = TimerWheel::with_defaults();
        let service = timer_wheel.create_service();

        let long_tasks = vec![
            TimerService::create_task(Duration::from_secs(10), || async {}),
            TimerService::create_task(Duration::from_secs(10), || async {}),
        ];
        for long_task in long_tasks {
            service.register(long_task).await;
        }

        service.shutdown().await;
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let timer_wheel = TimerWheel::with_defaults();
        let service = timer_wheel.create_service();

        let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
        let id = task.get_id();
        service.add_timer_handle(timer_wheel.register(task)).await;

        let first_attempt = service.cancel_task(id).await;
        assert!(first_attempt, "Task should be cancelled successfully");

        let second_attempt = service.cancel_task(id).await;
        assert!(!second_attempt, "Task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer_wheel = TimerWheel::with_defaults();
        let service = timer_wheel.create_service();

        let registered = TimerWheel::create_task(Duration::from_millis(50), || async {});
        service.add_timer_handle(timer_wheel.register(registered)).await;

        let never_registered = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let outcome = service.cancel_task(never_registered.get_id()).await;
        assert!(!outcome, "Nonexistent task should not be cancelled");
    }

    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer_wheel = TimerWheel::with_defaults();
        let mut service = timer_wheel.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let id = task.get_id();
        service.add_timer_handle(timer_wheel.register(task)).await;

        let mut timeouts = service.take_receiver().unwrap();
        let fired_id = next_timeout(&mut timeouts).await;
        assert_eq!(fired_id, id);

        // Short pause before checking that the fired task is gone from the wheel.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let outcome = service.cancel_task(id).await;
        assert!(!outcome, "Timed out task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer_wheel = TimerWheel::with_defaults();
        let service = timer_wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_cb = Arc::clone(&hits);
        let task = TimerWheel::create_task(Duration::from_secs(10), move || {
            let hits = Arc::clone(&hits_for_cb);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });
        let id = task.get_id();
        service.add_timer_handle(timer_wheel.register(task)).await;

        let first_attempt = service.cancel_task(id).await;
        assert!(first_attempt, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        let second_attempt = service.cancel_task(id).await;
        assert!(!second_attempt, "Task should have been removed from active_tasks");
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer_wheel = TimerWheel::with_defaults();
        let mut service = timer_wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_cb = Arc::clone(&hits);
        let task = TimerService::create_task(Duration::from_millis(50), move || {
            let hits = Arc::clone(&hits_for_cb);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });
        let id = task.get_id();
        service.register(task).await;

        let mut timeouts = service.take_receiver().unwrap();
        let fired_id = next_timeout(&mut timeouts).await;
        assert_eq!(fired_id, id);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer_wheel = TimerWheel::with_defaults();
        let mut service = timer_wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| {
                let hits = Arc::clone(&hits);
                (Duration::from_millis(50), move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).await;

        let mut received = 0;
        let mut timeouts = service.take_receiver().unwrap();

        // Stop early on a closed channel or a 200ms gap between notifications.
        while received < 3 {
            match tokio::time::timeout(Duration::from_millis(200), timeouts.recv()).await {
                Ok(Some(_)) => received += 1,
                Ok(None) | Err(_) => break,
            }
        }

        assert_eq!(received, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer_wheel = TimerWheel::with_defaults();
        let mut service = timer_wheel.create_service();

        // A task without a callback: only the timeout notification is observable.
        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
        let id = task.get_id();
        service.register(task).await;

        let mut timeouts = service.take_receiver().unwrap();
        let fired_id = next_timeout(&mut timeouts).await;
        assert_eq!(fired_id, id);
    }

    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer_wheel = TimerWheel::with_defaults();
        let service = timer_wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_cb = Arc::clone(&hits);
        let task = TimerService::create_task(Duration::from_secs(10), move || {
            let hits = Arc::clone(&hits_for_cb);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });
        let id = task.get_id();
        service.register(task).await;

        let outcome = service.cancel_task(id).await;
        assert!(outcome, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer_wheel = TimerWheel::with_defaults();
        let service = timer_wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let hits = Arc::clone(&hits);
                (Duration::from_secs(10), move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let ids: Vec<_> = tasks.iter().map(|task| task.get_id()).collect();
        assert_eq!(ids.len(), 10);
        service.register_batch(tasks).await;

        let cancelled_count = service.cancel_batch(&ids).await;
        assert_eq!(cancelled_count, 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer_wheel = TimerWheel::with_defaults();
        let service = timer_wheel.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let hits = Arc::clone(&hits);
                (Duration::from_secs(10), move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let ids: Vec<_> = tasks.iter().map(|task| task.get_id()).collect();
        service.register_batch(tasks).await;

        let first_half: Vec<_> = ids.iter().copied().take(5).collect();
        let cancelled_count = service.cancel_batch(&first_half).await;
        assert_eq!(cancelled_count, 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer_wheel = TimerWheel::with_defaults();
        let service = timer_wheel.create_service();

        let no_ids: Vec<TaskId> = Vec::new();
        let cancelled_count = service.cancel_batch(&no_ids).await;
        assert_eq!(cancelled_count, 0, "No tasks should be cancelled");
    }
}
868