1use crate::config::ServiceConfig;
2use crate::task::{TaskId, TimerCallback};
3use crate::timer::{BatchHandle, TimerHandle};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13enum ServiceCommand {
15 AddBatchHandle(BatchHandle),
17 AddTimerHandle(TimerHandle),
19 Shutdown,
21}
22
23pub struct TimerService {
58 command_tx: mpsc::Sender<ServiceCommand>,
60 timeout_rx: Option<mpsc::Receiver<TaskId>>,
62 actor_handle: Option<JoinHandle<()>>,
64 wheel: Arc<Mutex<Wheel>>,
66}
67
68impl TimerService {
69 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
89 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
90 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
91
92 let actor = ServiceActor::new(command_rx, timeout_tx);
93 let actor_handle = tokio::spawn(async move {
94 actor.run().await;
95 });
96
97 Self {
98 command_tx,
99 timeout_rx: Some(timeout_rx),
100 actor_handle: Some(actor_handle),
101 wheel,
102 }
103 }
104
105 async fn add_batch_handle(&self, batch: BatchHandle) {
107 let _ = self.command_tx
108 .send(ServiceCommand::AddBatchHandle(batch))
109 .await;
110 }
111
112 async fn add_timer_handle(&self, handle: TimerHandle) {
114 let _ = self.command_tx
115 .send(ServiceCommand::AddTimerHandle(handle))
116 .await;
117 }
118
119 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
143 self.timeout_rx.take()
144 }
145
146 #[inline]
179 pub async fn cancel_task(&self, task_id: TaskId) -> bool {
180 let mut wheel = self.wheel.lock();
183 wheel.cancel(task_id)
184 }
185
186 #[inline]
218 pub async fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
219 if task_ids.is_empty() {
220 return 0;
221 }
222
223 let mut wheel = self.wheel.lock();
226 wheel.cancel_batch(task_ids)
227 }
228
229 #[inline]
266 pub async fn postpone_task(&self, task_id: TaskId, new_delay: Duration) -> bool {
267 let mut wheel = self.wheel.lock();
270 wheel.postpone(task_id, new_delay, None)
271 }
272
273 #[inline]
314 pub async fn postpone_task_with_callback<C>(
315 &self,
316 task_id: TaskId,
317 new_delay: Duration,
318 callback: C,
319 ) -> bool
320 where
321 C: TimerCallback,
322 {
323 use std::sync::Arc;
324 let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
325 let mut wheel = self.wheel.lock();
326 wheel.postpone(task_id, new_delay, Some(callback_wrapper))
327 }
328
329 #[inline]
363 pub async fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
364 if updates.is_empty() {
365 return 0;
366 }
367
368 let updates_vec: Vec<_> = updates
369 .iter()
370 .map(|(task_id, delay)| (*task_id, *delay, None))
371 .collect();
372 let mut wheel = self.wheel.lock();
373 wheel.postpone_batch(updates_vec)
374 }
375
376 #[inline]
415 pub async fn postpone_batch_with_callbacks<C>(
416 &self,
417 updates: Vec<(TaskId, Duration, C)>,
418 ) -> usize
419 where
420 C: TimerCallback,
421 {
422 if updates.is_empty() {
423 return 0;
424 }
425
426 use std::sync::Arc;
427 let updates_vec: Vec<_> = updates
428 .into_iter()
429 .map(|(task_id, delay, callback)| {
430 let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
431 (task_id, delay, Some(callback_wrapper))
432 })
433 .collect();
434 let mut wheel = self.wheel.lock();
435 wheel.postpone_batch(updates_vec)
436 }
437
438 pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
469 where
470 C: TimerCallback,
471 {
472 crate::timer::TimerWheel::create_task(delay, callback)
473 }
474
475 pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
507 where
508 C: TimerCallback,
509 {
510 crate::timer::TimerWheel::create_batch(callbacks)
511 }
512
513 #[inline]
536 pub async fn register(&self, task: crate::task::TimerTask) {
537 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
538 let notifier = crate::task::CompletionNotifier(completion_tx);
539
540 let delay = task.delay;
541 let task_id = task.id;
542
543 {
545 let mut wheel_guard = self.wheel.lock();
546 wheel_guard.insert(delay, task, notifier);
547 }
548
549 let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
551 self.add_timer_handle(handle).await;
552 }
553
554 #[inline]
577 pub async fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) {
578 let task_count = tasks.len();
579 let mut completion_rxs = Vec::with_capacity(task_count);
580 let mut task_ids = Vec::with_capacity(task_count);
581 let mut prepared_tasks = Vec::with_capacity(task_count);
582
583 for task in tasks {
586 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
587 let notifier = crate::task::CompletionNotifier(completion_tx);
588
589 task_ids.push(task.id);
590 completion_rxs.push(completion_rx);
591 prepared_tasks.push((task.delay, task, notifier));
592 }
593
594 {
596 let mut wheel_guard = self.wheel.lock();
597 wheel_guard.insert_batch(prepared_tasks);
598 }
599
600 let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
602 self.add_batch_handle(batch_handle).await;
603 }
604
605 pub async fn shutdown(mut self) {
621 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
622 if let Some(handle) = self.actor_handle.take() {
623 let _ = handle.await;
624 }
625 }
626}
627
628
629impl Drop for TimerService {
630 fn drop(&mut self) {
631 if let Some(handle) = self.actor_handle.take() {
632 handle.abort();
633 }
634 }
635}
636
637struct ServiceActor {
639 command_rx: mpsc::Receiver<ServiceCommand>,
641 timeout_tx: mpsc::Sender<TaskId>,
643}
644
645impl ServiceActor {
646 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
647 Self {
648 command_rx,
649 timeout_tx,
650 }
651 }
652
653 async fn run(mut self) {
654 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
657
658 loop {
659 tokio::select! {
660 Some((task_id, _result)) = futures.next() => {
662 let _ = self.timeout_tx.send(task_id).await;
664 }
666
667 Some(cmd) = self.command_rx.recv() => {
669 match cmd {
670 ServiceCommand::AddBatchHandle(batch) => {
671 let BatchHandle {
672 task_ids,
673 completion_rxs,
674 ..
675 } = batch;
676
677 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
679 let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
680 (task_id, rx.await)
681 });
682 futures.push(future);
683 }
684 }
685 ServiceCommand::AddTimerHandle(handle) => {
686 let TimerHandle{
687 task_id,
688 completion_rx,
689 ..
690 } = handle;
691
692 let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
694 (task_id, completion_rx.0.await)
695 });
696 futures.push(future);
697 }
698 ServiceCommand::Shutdown => {
699 break;
700 }
701 }
702 }
703
704 else => {
706 break;
707 }
708 }
709 }
710 }
711}
712
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Upper bound on how long a test waits for any single timeout notification.
    const RECV_TIMEOUT: Duration = Duration::from_millis(200);

    /// Builds a callback that adds `$amount` to the `Arc<AtomicU32>` `$counter` every
    /// time it runs.
    macro_rules! counting_callback {
        ($counter:expr, $amount:expr) => {{
            let counter = Arc::clone(&$counter);
            move || {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add($amount, Ordering::SeqCst);
                }
            }
        }};
    }

    /// Waits for one notification and returns its task id; panics if none arrives in
    /// time or the channel is closed.
    async fn expect_timeout(rx: &mut tokio::sync::mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    /// Receives up to `expected` notifications, stopping early on a timeout or a closed
    /// channel, and returns how many arrived.
    async fn count_timeouts(
        rx: &mut tokio::sync::mpsc::Receiver<TaskId>,
        expected: usize,
    ) -> usize {
        let mut received = 0;
        while received < expected {
            match tokio::time::timeout(RECV_TIMEOUT, rx.recv()).await {
                Ok(Some(_)) => received += 1,
                Ok(None) | Err(_) => break,
            }
        }
        received
    }

    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service();
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(expect_timeout(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let tasks = vec![
            TimerService::create_task(Duration::from_secs(10), || async {}),
            TimerService::create_task(Duration::from_secs(10), || async {}),
        ];
        for task in tasks {
            service.register(task).await;
        }

        service.shutdown().await;
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        assert!(
            service.cancel_task(task_id).await,
            "Task should be cancelled successfully"
        );
        assert!(
            !service.cancel_task(task_id).await,
            "Task should not exist anymore"
        );
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        service.add_timer_handle(timer.register(task)).await;

        let fake_task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let fake_task_id = fake_task.get_id();
        assert!(
            !service.cancel_task(fake_task_id).await,
            "Nonexistent task should not be cancelled"
        );
    }

    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(expect_timeout(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(
            !service.cancel_task(task_id).await,
            "Timed out task should not exist anymore"
        );
    }

    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_secs(10), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        assert!(
            service.cancel_task(task_id).await,
            "Task should be cancelled successfully"
        );

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "Callback should not have been executed"
        );

        assert!(
            !service.cancel_task(task_id).await,
            "Task should have been removed from active_tasks"
        );
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_millis(50), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(expect_timeout(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| (Duration::from_millis(50), counting_callback!(counter, 1)))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(count_timeouts(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(expect_timeout(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_secs(10), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        assert!(
            service.cancel_task(task_id).await,
            "Task should be cancelled successfully"
        );

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "Callback should not have been executed"
        );
    }

    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), counting_callback!(counter, 1)))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        assert_eq!(task_ids.len(), 10);
        service.register_batch(tasks).await;

        assert_eq!(
            service.cancel_batch(&task_ids).await,
            10,
            "All 10 tasks should be cancelled"
        );

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "No callbacks should have been executed"
        );
    }

    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), counting_callback!(counter, 1)))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        service.register_batch(tasks).await;

        let to_cancel = &task_ids[..5];
        assert_eq!(
            service.cancel_batch(to_cancel).await,
            5,
            "5 tasks should be cancelled"
        );

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "Cancelled tasks should not execute"
        );
    }

    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let empty: Vec<TaskId> = Vec::new();
        assert_eq!(
            service.cancel_batch(&empty).await,
            0,
            "No tasks should be cancelled"
        );
    }

    #[tokio::test]
    async fn test_postpone_task() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_millis(50), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        assert!(
            service.postpone_task(task_id, Duration::from_millis(150)).await,
            "Task should be postponed successfully"
        );

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(expect_timeout(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_postpone_task_with_callback() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_millis(50), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        let postponed = service
            .postpone_task_with_callback(
                task_id,
                Duration::from_millis(100),
                counting_callback!(counter, 10),
            )
            .await;
        assert!(postponed, "Task should be postponed successfully");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(expect_timeout(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;

        // Only the replacement callback (+10) may have run, not the original (+1).
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
        let fake_task_id = fake_task.get_id();
        assert!(
            !service
                .postpone_task(fake_task_id, Duration::from_millis(100))
                .await,
            "Nonexistent task should not be postponed"
        );
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut updates = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                counting_callback!(counter, 1),
            );
            updates.push((task.get_id(), Duration::from_millis(150)));
            service.register(task).await;
        }

        assert_eq!(
            service.postpone_batch(&updates).await,
            3,
            "All 3 tasks should be postponed"
        );

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(count_timeouts(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(Duration::from_millis(50), || async {});
            task_ids.push(task.get_id());
            service.register(task).await;
        }

        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| (id, Duration::from_millis(150), counting_callback!(counter, 1)))
            .collect();

        assert_eq!(
            service.postpone_batch_with_callbacks(updates).await,
            3,
            "All 3 tasks should be postponed"
        );

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(count_timeouts(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let empty: Vec<(TaskId, Duration)> = Vec::new();
        assert_eq!(
            service.postpone_batch(&empty).await,
            0,
            "No tasks should be postponed"
        );
    }

    #[tokio::test]
    async fn test_postpone_keeps_timeout_notification_valid() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_millis(50), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        service.postpone_task(task_id, Duration::from_millis(100)).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            expect_timeout(&mut rx).await,
            task_id,
            "Timeout notification should still work after postpone"
        );

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
1342