1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
4use crate::timer::{BatchHandle, TimerHandle};
5use crate::wheel::Wheel;
6use futures::stream::{FuturesUnordered, StreamExt};
7use futures::future::BoxFuture;
8use parking_lot::Mutex;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13
14enum ServiceCommand {
16 AddBatchHandle(BatchHandle),
18 AddTimerHandle(TimerHandle),
20 Shutdown,
22}
23
/// Front end for registering, cancelling and postponing tasks on a shared timing wheel.
///
/// A spawned background actor watches each registered task's completion channel and
/// forwards the IDs of expired tasks to a channel obtained via `take_receiver`.
pub struct TimerService {
    /// Carries new task handles and the shutdown request to the actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiver of expired task IDs; `None` once `take_receiver` has handed it out.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the actor task; `None` after `shutdown` or `drop` has taken it.
    actor_handle: Option<JoinHandle<()>>,
    /// Wheel that tasks are inserted into; a clone of this `Arc` is given to every
    /// `TimerHandle` / `BatchHandle` created at registration.
    wheel: Arc<Mutex<Wheel>>,
}
69
70impl TimerService {
71 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
91 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
92 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
93
94 let actor = ServiceActor::new(command_rx, timeout_tx);
95 let actor_handle = tokio::spawn(async move {
96 actor.run().await;
97 });
98
99 Self {
100 command_tx,
101 timeout_rx: Some(timeout_rx),
102 actor_handle: Some(actor_handle),
103 wheel,
104 }
105 }
106
107 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
131 self.timeout_rx.take()
132 }
133
134 #[inline]
168 pub fn cancel_task(&self, task_id: TaskId) -> bool {
169 let mut wheel = self.wheel.lock();
172 wheel.cancel(task_id)
173 }
174
175 #[inline]
208 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
209 if task_ids.is_empty() {
210 return 0;
211 }
212
213 let mut wheel = self.wheel.lock();
216 wheel.cancel_batch(task_ids)
217 }
218
219 #[inline]
263 pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
264 let mut wheel = self.wheel.lock();
265 wheel.postpone(task_id, new_delay, callback)
266 }
267
268 #[inline]
303 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
304 if updates.is_empty() {
305 return 0;
306 }
307
308 let mut wheel = self.wheel.lock();
309 wheel.postpone_batch(updates)
310 }
311
312 #[inline]
353 pub fn postpone_batch_with_callbacks(
354 &self,
355 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
356 ) -> usize {
357 if updates.is_empty() {
358 return 0;
359 }
360
361 let mut wheel = self.wheel.lock();
362 wheel.postpone_batch_with_callbacks(updates)
363 }
364
365 #[inline]
398 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
399 crate::timer::TimerWheel::create_task(delay, callback)
400 }
401
402 #[inline]
438 pub fn create_batch(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
439 crate::timer::TimerWheel::create_batch(callbacks)
440 }
441
442 #[inline]
471 pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
472 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
473 let notifier = crate::task::CompletionNotifier(completion_tx);
474
475 let delay = task.delay;
476 let task_id = task.id;
477
478 {
480 let mut wheel_guard = self.wheel.lock();
481 wheel_guard.insert(delay, task, notifier);
482 }
483
484 let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
486 self.command_tx
487 .try_send(ServiceCommand::AddTimerHandle(handle))
488 .map_err(|_| TimerError::RegisterFailed)?;
489
490 Ok(())
491 }
492
493 #[inline]
521 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
522 let task_count = tasks.len();
523 let mut completion_rxs = Vec::with_capacity(task_count);
524 let mut task_ids = Vec::with_capacity(task_count);
525 let mut prepared_tasks = Vec::with_capacity(task_count);
526
527 for task in tasks {
530 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
531 let notifier = crate::task::CompletionNotifier(completion_tx);
532
533 task_ids.push(task.id);
534 completion_rxs.push(completion_rx);
535 prepared_tasks.push((task.delay, task, notifier));
536 }
537
538 {
540 let mut wheel_guard = self.wheel.lock();
541 wheel_guard.insert_batch(prepared_tasks);
542 }
543
544 let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
546 self.command_tx
547 .try_send(ServiceCommand::AddBatchHandle(batch_handle))
548 .map_err(|_| TimerError::RegisterFailed)?;
549
550 Ok(())
551 }
552
553 pub async fn shutdown(mut self) {
569 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
570 if let Some(handle) = self.actor_handle.take() {
571 let _ = handle.await;
572 }
573 }
574}
575
576
577impl Drop for TimerService {
578 fn drop(&mut self) {
579 if let Some(handle) = self.actor_handle.take() {
580 handle.abort();
581 }
582 }
583}
584
/// Background task that tracks the completion receivers of registered tasks and forwards
/// the IDs of expired tasks to the timeout channel.
struct ServiceActor {
    /// Commands from `TimerService`: new timer/batch handles and the shutdown request.
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Destination for IDs of tasks that completed with `TaskCompletionReason::Expired`.
    timeout_tx: mpsc::Sender<TaskId>,
}
592
593impl ServiceActor {
594 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
595 Self {
596 command_rx,
597 timeout_tx,
598 }
599 }
600
601 async fn run(mut self) {
602 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
605
606 loop {
607 tokio::select! {
608 Some((task_id, result)) = futures.next() => {
610 if let Ok(TaskCompletionReason::Expired) = result {
612 let _ = self.timeout_tx.send(task_id).await;
613 }
614 }
616
617 Some(cmd) = self.command_rx.recv() => {
619 match cmd {
620 ServiceCommand::AddBatchHandle(batch) => {
621 let BatchHandle {
622 task_ids,
623 completion_rxs,
624 ..
625 } = batch;
626
627 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
629 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
630 (task_id, rx.await)
631 });
632 futures.push(future);
633 }
634 }
635 ServiceCommand::AddTimerHandle(handle) => {
636 let TimerHandle{
637 task_id,
638 completion_rx,
639 ..
640 } = handle;
641
642 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
644 (task_id, completion_rx.0.await)
645 });
646 futures.push(future);
647 }
648 ServiceCommand::Shutdown => {
649 break;
650 }
651 }
652 }
653
654 else => {
656 break;
657 }
658 }
659 }
660 }
661}
662
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Builds a callback that adds `amount` to `counter` every time it runs.
    fn counting_callback(counter: &Arc<AtomicU32>, amount: u32) -> CallbackWrapper {
        let counter = Arc::clone(counter);
        CallbackWrapper::new(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(amount, Ordering::SeqCst);
            }
        })
    }

    /// Waits up to 200 ms for the next expired task ID, panicking if none arrives.
    async fn next_expired(rx: &mut mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    /// Collects up to `expected` expired task IDs. It stops early when a single wait
    /// exceeds 200 ms or the channel closes, and returns how many IDs arrived.
    async fn receive_up_to(rx: &mut mpsc::Receiver<TaskId>, expected: usize) -> usize {
        let mut received = 0;
        while received < expected {
            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
                Ok(Some(_task_id)) => received += 1,
                Ok(None) | Err(_) => break,
            }
        }
        received
    }

    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service();
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(|| async {})),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_take_receiver_only_once() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        assert!(service.take_receiver().is_some());
        assert!(service.take_receiver().is_none(), "Receiver can only be taken once");
    }

    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task1 = TimerService::create_task(Duration::from_secs(10), None);
        let task2 = TimerService::create_task(Duration::from_secs(10), None);
        service.register(task1).unwrap();
        service.register(task2).unwrap();

        service.shutdown().await;
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerService::create_task(Duration::from_secs(10), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");
        assert!(!service.cancel_task(task_id), "Task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerService::create_task(Duration::from_millis(50), None);
        service.register(task).unwrap();

        // Created but never registered, so the wheel has never seen this ID.
        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();
        assert!(!service.cancel_task(fake_task_id), "Nonexistent task should not be cancelled");
    }

    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerService::create_task(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(!service.cancel_task(task_id), "Timed out task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_secs(10),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        assert!(
            !service.cancel_task(task_id),
            "Task should have been removed from active_tasks"
        );
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| (Duration::from_millis(50), Some(counting_callback(&counter, 1))))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(receive_up_to(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_register_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        assert!(
            service.register_batch(Vec::new()).is_ok(),
            "Registering an empty batch should succeed"
        );
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_secs(10),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), Some(counting_callback(&counter, 1))))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        assert_eq!(task_ids.len(), 10);
        service.register_batch(tasks).unwrap();

        assert_eq!(service.cancel_batch(&task_ids), 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), Some(counting_callback(&counter, 1))))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        service.register_batch(tasks).unwrap();

        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
        assert_eq!(service.cancel_batch(&to_cancel), 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let empty: Vec<TaskId> = vec![];
        assert_eq!(service.cancel_batch(&empty), 0, "No tasks should be cancelled");
    }

    #[tokio::test]
    async fn test_postpone() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        // The replacement callback adds 10, so a final count of 10 proves the original
        // (+1) callback was swapped out rather than run.
        let postponed = service.postpone(
            task_id,
            Duration::from_millis(100),
            Some(counting_callback(&counter, 10)),
        );
        assert!(postponed, "Task should be postponed successfully");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();
        let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed, "Nonexistent task should not be postponed");
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut updates = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                Some(counting_callback(&counter, 1)),
            );
            updates.push((task.get_id(), Duration::from_millis(150)));
            service.register(task).unwrap();
        }

        // `postpone_batch` takes no callbacks; this test expects every task to keep the
        // callback it was created with.
        let postponed = service.postpone_batch(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(receive_up_to(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_none_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut updates = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                Some(counting_callback(&counter, 1)),
            );
            updates.push((task.get_id(), Duration::from_millis(150), None));
            service.register(task).unwrap();
        }

        let postponed = service.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(receive_up_to(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(Duration::from_millis(50), None);
            task_ids.push(task.get_id());
            service.register(task).unwrap();
        }

        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| (id, Duration::from_millis(150), Some(counting_callback(&counter, 1))))
            .collect();

        let postponed = service.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(receive_up_to(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
        assert_eq!(
            service.postpone_batch_with_callbacks(empty),
            0,
            "No tasks should be postponed"
        );
        assert_eq!(service.postpone_batch(Vec::new()), 0, "No tasks should be postponed");
    }

    #[tokio::test]
    async fn test_postpone_keeps_timeout_notification_valid() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(
            service.postpone(task_id, Duration::from_millis(100), None),
            "Task should be postponed successfully"
        );

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            next_expired(&mut rx).await,
            task_id,
            "Timeout notification should still work after postpone"
        );

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task1 = TimerService::create_task(Duration::from_secs(10), None);
        let task1_id = task1.get_id();
        service.register(task1).unwrap();

        let task2 = TimerService::create_task(Duration::from_millis(50), None);
        let task2_id = task2.get_id();
        service.register(task2).unwrap();

        assert!(service.cancel_task(task1_id), "Task should be cancelled");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            next_expired(&mut rx).await,
            task2_id,
            "Should only receive expired task notification"
        );

        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(no_more.is_err(), "Should not receive any more notifications");
    }
}
1278