1use crate::config::ServiceConfig;
2use crate::task::{TaskCompletionReason, TaskId, TimerCallback};
3use crate::timer::{BatchHandle, TimerHandle};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
/// Messages sent from `TimerService` to its background `ServiceActor`.
enum ServiceCommand {
    /// Start watching the completion receivers of a batch of registered tasks.
    AddBatchHandle(BatchHandle),
    /// Start watching the completion receiver of a single registered task.
    AddTimerHandle(TimerHandle),
    /// Stop the actor loop; completions still pending are no longer watched.
    Shutdown,
}
22
/// Front-end over a shared timer wheel plus a spawned actor that watches task
/// completions and forwards the ids of expired tasks to a channel.
///
/// Note: fields drop in declaration order after `Drop::drop` runs, so their
/// order here is significant.
pub struct TimerService {
    /// Sends commands (new handles to watch, shutdown) to the background actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiving end for ids of expired tasks; `None` once handed out by `take_receiver`.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor; `None` once taken by `shutdown` or `drop`.
    actor_handle: Option<JoinHandle<()>>,
    /// Wheel shared with the creator of this service; locked for every insert,
    /// cancel and postpone operation.
    wheel: Arc<Mutex<Wheel>>,
}
67
68impl TimerService {
69 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
89 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
90 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
91
92 let actor = ServiceActor::new(command_rx, timeout_tx);
93 let actor_handle = tokio::spawn(async move {
94 actor.run().await;
95 });
96
97 Self {
98 command_tx,
99 timeout_rx: Some(timeout_rx),
100 actor_handle: Some(actor_handle),
101 wheel,
102 }
103 }
104
105 async fn add_batch_handle(&self, batch: BatchHandle) {
107 let _ = self.command_tx
108 .send(ServiceCommand::AddBatchHandle(batch))
109 .await;
110 }
111
112 async fn add_timer_handle(&self, handle: TimerHandle) {
114 let _ = self.command_tx
115 .send(ServiceCommand::AddTimerHandle(handle))
116 .await;
117 }
118
119 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
143 self.timeout_rx.take()
144 }
145
146 #[inline]
178 pub fn cancel_task(&self, task_id: TaskId) -> bool {
179 let mut wheel = self.wheel.lock();
182 wheel.cancel(task_id)
183 }
184
185 #[inline]
217 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
218 if task_ids.is_empty() {
219 return 0;
220 }
221
222 let mut wheel = self.wheel.lock();
225 wheel.cancel_batch(task_ids)
226 }
227
228 #[inline]
265 pub fn postpone_task(&self, task_id: TaskId, new_delay: Duration) -> bool {
266 let mut wheel = self.wheel.lock();
269 wheel.postpone(task_id, new_delay, None)
270 }
271
272 #[inline]
313 pub fn postpone_task_with_callback<C>(
314 &self,
315 task_id: TaskId,
316 new_delay: Duration,
317 callback: C,
318 ) -> bool
319 where
320 C: TimerCallback,
321 {
322 use std::sync::Arc;
323 let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
324 let mut wheel = self.wheel.lock();
325 wheel.postpone(task_id, new_delay, Some(callback_wrapper))
326 }
327
328 #[inline]
362 pub fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
363 if updates.is_empty() {
364 return 0;
365 }
366
367 let updates_vec: Vec<_> = updates
368 .iter()
369 .map(|(task_id, delay)| (*task_id, *delay, None))
370 .collect();
371 let mut wheel = self.wheel.lock();
372 wheel.postpone_batch(updates_vec)
373 }
374
375 #[inline]
414 pub fn postpone_batch_with_callbacks<C>(
415 &self,
416 updates: Vec<(TaskId, Duration, C)>,
417 ) -> usize
418 where
419 C: TimerCallback,
420 {
421 if updates.is_empty() {
422 return 0;
423 }
424
425 use std::sync::Arc;
426 let updates_vec: Vec<_> = updates
427 .into_iter()
428 .map(|(task_id, delay, callback)| {
429 let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
430 (task_id, delay, Some(callback_wrapper))
431 })
432 .collect();
433 let mut wheel = self.wheel.lock();
434 wheel.postpone_batch(updates_vec)
435 }
436
437 pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
468 where
469 C: TimerCallback,
470 {
471 crate::timer::TimerWheel::create_task(delay, callback)
472 }
473
474 pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
506 where
507 C: TimerCallback,
508 {
509 crate::timer::TimerWheel::create_batch(callbacks)
510 }
511
512 #[inline]
535 pub async fn register(&self, task: crate::task::TimerTask) {
536 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
537 let notifier = crate::task::CompletionNotifier(completion_tx);
538
539 let delay = task.delay;
540 let task_id = task.id;
541
542 {
544 let mut wheel_guard = self.wheel.lock();
545 wheel_guard.insert(delay, task, notifier);
546 }
547
548 let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
550 self.add_timer_handle(handle).await;
551 }
552
553 #[inline]
576 pub async fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) {
577 let task_count = tasks.len();
578 let mut completion_rxs = Vec::with_capacity(task_count);
579 let mut task_ids = Vec::with_capacity(task_count);
580 let mut prepared_tasks = Vec::with_capacity(task_count);
581
582 for task in tasks {
585 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
586 let notifier = crate::task::CompletionNotifier(completion_tx);
587
588 task_ids.push(task.id);
589 completion_rxs.push(completion_rx);
590 prepared_tasks.push((task.delay, task, notifier));
591 }
592
593 {
595 let mut wheel_guard = self.wheel.lock();
596 wheel_guard.insert_batch(prepared_tasks);
597 }
598
599 let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
601 self.add_batch_handle(batch_handle).await;
602 }
603
604 pub async fn shutdown(mut self) {
620 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
621 if let Some(handle) = self.actor_handle.take() {
622 let _ = handle.await;
623 }
624 }
625}
626
627
628impl Drop for TimerService {
629 fn drop(&mut self) {
630 if let Some(handle) = self.actor_handle.take() {
631 handle.abort();
632 }
633 }
634}
635
/// Background task that watches task completions and forwards expired ids.
struct ServiceActor {
    /// Commands from `TimerService` (handles to watch, shutdown).
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Bounded channel on which the ids of expired tasks are forwarded.
    timeout_tx: mpsc::Sender<TaskId>,
}
643
644impl ServiceActor {
645 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
646 Self {
647 command_rx,
648 timeout_tx,
649 }
650 }
651
652 async fn run(mut self) {
653 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
656
657 loop {
658 tokio::select! {
659 Some((task_id, result)) = futures.next() => {
661 if let Ok(TaskCompletionReason::Expired) = result {
663 let _ = self.timeout_tx.send(task_id).await;
664 }
665 }
667
668 Some(cmd) = self.command_rx.recv() => {
670 match cmd {
671 ServiceCommand::AddBatchHandle(batch) => {
672 let BatchHandle {
673 task_ids,
674 completion_rxs,
675 ..
676 } = batch;
677
678 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
680 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
681 (task_id, rx.await)
682 });
683 futures.push(future);
684 }
685 }
686 ServiceCommand::AddTimerHandle(handle) => {
687 let TimerHandle{
688 task_id,
689 completion_rx,
690 ..
691 } = handle;
692
693 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
695 (task_id, completion_rx.0.await)
696 });
697 futures.push(future);
698 }
699 ServiceCommand::Shutdown => {
700 break;
701 }
702 }
703 }
704
705 else => {
707 break;
708 }
709 }
710 }
711 }
712}
713
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Upper bound on how long a test waits for any single timeout notification.
    const RECV_WINDOW: Duration = Duration::from_millis(200);

    /// Receives the next expired task id.
    ///
    /// Panics if nothing arrives within `RECV_WINDOW` or the channel is closed.
    async fn next_expired(rx: &mut tokio::sync::mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(RECV_WINDOW, rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    /// Receives up to `wanted` expired task ids and returns how many arrived.
    ///
    /// Stops early if a single wait exceeds `RECV_WINDOW` or the channel closes.
    async fn count_expired(
        rx: &mut tokio::sync::mpsc::Receiver<TaskId>,
        wanted: usize,
    ) -> usize {
        let mut seen = 0;
        while seen < wanted {
            match tokio::time::timeout(RECV_WINDOW, rx.recv()).await {
                Ok(Some(_)) => seen += 1,
                Ok(None) | Err(_) => break,
            }
        }
        seen
    }

    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service();
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        let mut rx = service.take_receiver().unwrap();
        let expired_id = next_expired(&mut rx).await;
        assert_eq!(expired_id, task_id);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let long_lived = [
            TimerService::create_task(Duration::from_secs(10), || async {}),
            TimerService::create_task(Duration::from_secs(10), || async {}),
        ];
        for task in long_lived {
            service.register(task).await;
        }

        service.shutdown().await;
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");
        assert!(!service.cancel_task(task_id), "Task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        service.add_timer_handle(timer.register(task)).await;

        // Never registered, so the service cannot know about it.
        let fake_task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let fake_task_id = fake_task.get_id();
        assert!(
            !service.cancel_task(fake_task_id),
            "Nonexistent task should not be cancelled"
        );
    }

    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        let mut rx = service.take_receiver().unwrap();
        let expired_id = next_expired(&mut rx).await;
        assert_eq!(expired_id, task_id);

        // Short pause before probing the wheel for the fired task.
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(
            !service.cancel_task(task_id),
            "Timed out task should not exist anymore"
        );
    }

    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_task = Arc::clone(&hits);
        let task = TimerWheel::create_task(Duration::from_secs(10), move || {
            let hits = Arc::clone(&hits_for_task);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        assert!(
            !service.cancel_task(task_id),
            "Task should have been removed from active_tasks"
        );
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_task = Arc::clone(&hits);
        let task = TimerService::create_task(Duration::from_millis(50), move || {
            let hits = Arc::clone(&hits_for_task);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });
        let task_id = task.get_id();
        service.register(task).await;

        let mut rx = service.take_receiver().unwrap();
        let expired_id = next_expired(&mut rx).await;
        assert_eq!(expired_id, task_id);

        // Allow the callback to finish before checking its effect.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| {
                let hits = Arc::clone(&hits);
                (Duration::from_millis(50), move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).await;

        let mut rx = service.take_receiver().unwrap();
        let expired = count_expired(&mut rx, 3).await;
        assert_eq!(expired, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        // A task with no callback still produces a timeout notification.
        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).await;

        let mut rx = service.take_receiver().unwrap();
        let expired_id = next_expired(&mut rx).await;
        assert_eq!(expired_id, task_id);
    }

    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_task = Arc::clone(&hits);
        let task = TimerService::create_task(Duration::from_secs(10), move || {
            let hits = Arc::clone(&hits_for_task);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });
        let task_id = task.get_id();
        service.register(task).await;

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let hits = Arc::clone(&hits);
                (Duration::from_secs(10), move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        assert_eq!(task_ids.len(), 10);
        service.register_batch(tasks).await;

        let cancelled = service.cancel_batch(&task_ids);
        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let hits = Arc::clone(&hits);
                (Duration::from_secs(10), move || {
                    let hits = Arc::clone(&hits);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        service.register_batch(tasks).await;

        // Cancel only the first half.
        let first_half: Vec<_> = task_ids.iter().take(5).copied().collect();
        let cancelled = service.cancel_batch(&first_half);
        assert_eq!(cancelled, 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let nothing: Vec<TaskId> = vec![];
        assert_eq!(service.cancel_batch(&nothing), 0, "No tasks should be cancelled");
    }

    #[tokio::test]
    async fn test_postpone_task() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_task = Arc::clone(&hits);
        let task = TimerService::create_task(Duration::from_millis(50), move || {
            let hits = Arc::clone(&hits_for_task);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });
        let task_id = task.get_id();
        service.register(task).await;

        assert!(
            service.postpone_task(task_id, Duration::from_millis(150)),
            "Task should be postponed successfully"
        );

        // Past the original 50ms deadline, but before the postponed one.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        let expired_id = next_expired(&mut rx).await;
        assert_eq!(expired_id, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_postpone_task_with_callback() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_original = Arc::clone(&hits);
        let task = TimerService::create_task(Duration::from_millis(50), move || {
            let hits = Arc::clone(&hits_for_original);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });
        let task_id = task.get_id();
        service.register(task).await;

        // The replacement callback adds 10, so the final count reveals which one ran.
        let hits_for_replacement = Arc::clone(&hits);
        let postponed = service.postpone_task_with_callback(
            task_id,
            Duration::from_millis(100),
            move || {
                let hits = Arc::clone(&hits_for_replacement);
                async move {
                    hits.fetch_add(10, Ordering::SeqCst);
                }
            },
        );
        assert!(postponed, "Task should be postponed successfully");

        let mut rx = service.take_receiver().unwrap();
        let expired_id = next_expired(&mut rx).await;
        assert_eq!(expired_id, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(hits.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        // Created but never registered.
        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
        let fake_task_id = fake_task.get_id();
        assert!(
            !service.postpone_task(fake_task_id, Duration::from_millis(100)),
            "Nonexistent task should not be postponed"
        );
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let mut updates = Vec::new();
        for _ in 0..3 {
            let hits_for_task = Arc::clone(&hits);
            let task = TimerService::create_task(Duration::from_millis(50), move || {
                let hits = Arc::clone(&hits_for_task);
                async move {
                    hits.fetch_add(1, Ordering::SeqCst);
                }
            });
            updates.push((task.get_id(), Duration::from_millis(150)));
            service.register(task).await;
        }

        assert_eq!(
            service.postpone_batch(&updates),
            3,
            "All 3 tasks should be postponed"
        );

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        let expired = count_expired(&mut rx, 3).await;
        assert_eq!(expired, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        // Register no-op tasks; the postponed callbacks are the only ones that count.
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(Duration::from_millis(50), || async {});
            task_ids.push(task.get_id());
            service.register(task).await;
        }

        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| {
                let hits_for_task = Arc::clone(&hits);
                (id, Duration::from_millis(150), move || {
                    let hits = Arc::clone(&hits_for_task);
                    async move {
                        hits.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        assert_eq!(
            service.postpone_batch_with_callbacks(updates),
            3,
            "All 3 tasks should be postponed"
        );

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        let expired = count_expired(&mut rx, 3).await;
        assert_eq!(expired, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let nothing: Vec<(TaskId, Duration)> = vec![];
        assert_eq!(service.postpone_batch(&nothing), 0, "No tasks should be postponed");
    }

    #[tokio::test]
    async fn test_postpone_keeps_timeout_notification_valid() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let hits = Arc::new(AtomicU32::new(0));

        let hits_for_task = Arc::clone(&hits);
        let task = TimerService::create_task(Duration::from_millis(50), move || {
            let hits = Arc::clone(&hits_for_task);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
            }
        });
        let task_id = task.get_id();
        service.register(task).await;

        service.postpone_task(task_id, Duration::from_millis(100));

        let mut rx = service.take_receiver().unwrap();
        let expired_id = next_expired(&mut rx).await;
        assert_eq!(
            expired_id, task_id,
            "Timeout notification should still work after postpone"
        );

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let slow = TimerService::create_task(Duration::from_secs(10), || async {});
        let slow_id = slow.get_id();
        service.register(slow).await;

        let fast = TimerService::create_task(Duration::from_millis(50), || async {});
        let fast_id = fast.get_id();
        service.register(fast).await;

        assert!(service.cancel_task(slow_id), "Task should be cancelled");

        let mut rx = service.take_receiver().unwrap();
        let expired_id = next_expired(&mut rx).await;
        assert_eq!(expired_id, fast_id, "Should only receive expired task notification");

        // The cancelled task must never show up on the timeout channel.
        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(no_more.is_err(), "Should not receive any more notifications");
    }
}
1376