1use crate::config::ServiceConfig;
2use crate::task::{TaskCompletionReason, TaskId, TimerCallback};
3use crate::timer::{BatchHandle, TimerHandle};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13enum ServiceCommand {
15 AddBatchHandle(BatchHandle),
17 AddTimerHandle(TimerHandle),
19 Shutdown,
21}
22
23pub struct TimerService {
58 command_tx: mpsc::Sender<ServiceCommand>,
60 timeout_rx: Option<mpsc::Receiver<TaskId>>,
62 actor_handle: Option<JoinHandle<()>>,
64 wheel: Arc<Mutex<Wheel>>,
66}
67
68impl TimerService {
69 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
89 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
90 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
91
92 let actor = ServiceActor::new(command_rx, timeout_tx);
93 let actor_handle = tokio::spawn(async move {
94 actor.run().await;
95 });
96
97 Self {
98 command_tx,
99 timeout_rx: Some(timeout_rx),
100 actor_handle: Some(actor_handle),
101 wheel,
102 }
103 }
104
105 async fn add_batch_handle(&self, batch: BatchHandle) {
107 let _ = self.command_tx
108 .send(ServiceCommand::AddBatchHandle(batch))
109 .await;
110 }
111
112 async fn add_timer_handle(&self, handle: TimerHandle) {
114 let _ = self.command_tx
115 .send(ServiceCommand::AddTimerHandle(handle))
116 .await;
117 }
118
119 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
143 self.timeout_rx.take()
144 }
145
146 #[inline]
179 pub async fn cancel_task(&self, task_id: TaskId) -> bool {
180 let mut wheel = self.wheel.lock();
183 wheel.cancel(task_id)
184 }
185
186 #[inline]
218 pub async fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
219 if task_ids.is_empty() {
220 return 0;
221 }
222
223 let mut wheel = self.wheel.lock();
226 wheel.cancel_batch(task_ids)
227 }
228
229 #[inline]
266 pub async fn postpone_task(&self, task_id: TaskId, new_delay: Duration) -> bool {
267 let mut wheel = self.wheel.lock();
270 wheel.postpone(task_id, new_delay, None)
271 }
272
273 #[inline]
314 pub async fn postpone_task_with_callback<C>(
315 &self,
316 task_id: TaskId,
317 new_delay: Duration,
318 callback: C,
319 ) -> bool
320 where
321 C: TimerCallback,
322 {
323 use std::sync::Arc;
324 let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
325 let mut wheel = self.wheel.lock();
326 wheel.postpone(task_id, new_delay, Some(callback_wrapper))
327 }
328
329 #[inline]
363 pub async fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
364 if updates.is_empty() {
365 return 0;
366 }
367
368 let updates_vec: Vec<_> = updates
369 .iter()
370 .map(|(task_id, delay)| (*task_id, *delay, None))
371 .collect();
372 let mut wheel = self.wheel.lock();
373 wheel.postpone_batch(updates_vec)
374 }
375
376 #[inline]
415 pub async fn postpone_batch_with_callbacks<C>(
416 &self,
417 updates: Vec<(TaskId, Duration, C)>,
418 ) -> usize
419 where
420 C: TimerCallback,
421 {
422 if updates.is_empty() {
423 return 0;
424 }
425
426 use std::sync::Arc;
427 let updates_vec: Vec<_> = updates
428 .into_iter()
429 .map(|(task_id, delay, callback)| {
430 let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
431 (task_id, delay, Some(callback_wrapper))
432 })
433 .collect();
434 let mut wheel = self.wheel.lock();
435 wheel.postpone_batch(updates_vec)
436 }
437
438 pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
469 where
470 C: TimerCallback,
471 {
472 crate::timer::TimerWheel::create_task(delay, callback)
473 }
474
475 pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
507 where
508 C: TimerCallback,
509 {
510 crate::timer::TimerWheel::create_batch(callbacks)
511 }
512
513 #[inline]
536 pub async fn register(&self, task: crate::task::TimerTask) {
537 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
538 let notifier = crate::task::CompletionNotifier(completion_tx);
539
540 let delay = task.delay;
541 let task_id = task.id;
542
543 {
545 let mut wheel_guard = self.wheel.lock();
546 wheel_guard.insert(delay, task, notifier);
547 }
548
549 let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
551 self.add_timer_handle(handle).await;
552 }
553
554 #[inline]
577 pub async fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) {
578 let task_count = tasks.len();
579 let mut completion_rxs = Vec::with_capacity(task_count);
580 let mut task_ids = Vec::with_capacity(task_count);
581 let mut prepared_tasks = Vec::with_capacity(task_count);
582
583 for task in tasks {
586 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
587 let notifier = crate::task::CompletionNotifier(completion_tx);
588
589 task_ids.push(task.id);
590 completion_rxs.push(completion_rx);
591 prepared_tasks.push((task.delay, task, notifier));
592 }
593
594 {
596 let mut wheel_guard = self.wheel.lock();
597 wheel_guard.insert_batch(prepared_tasks);
598 }
599
600 let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
602 self.add_batch_handle(batch_handle).await;
603 }
604
605 pub async fn shutdown(mut self) {
621 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
622 if let Some(handle) = self.actor_handle.take() {
623 let _ = handle.await;
624 }
625 }
626}
627
628
629impl Drop for TimerService {
630 fn drop(&mut self) {
631 if let Some(handle) = self.actor_handle.take() {
632 handle.abort();
633 }
634 }
635}
636
637struct ServiceActor {
639 command_rx: mpsc::Receiver<ServiceCommand>,
641 timeout_tx: mpsc::Sender<TaskId>,
643}
644
645impl ServiceActor {
646 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
647 Self {
648 command_rx,
649 timeout_tx,
650 }
651 }
652
653 async fn run(mut self) {
654 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
657
658 loop {
659 tokio::select! {
660 Some((task_id, result)) = futures.next() => {
662 if let Ok(TaskCompletionReason::Expired) = result {
664 let _ = self.timeout_tx.send(task_id).await;
665 }
666 }
668
669 Some(cmd) = self.command_rx.recv() => {
671 match cmd {
672 ServiceCommand::AddBatchHandle(batch) => {
673 let BatchHandle {
674 task_ids,
675 completion_rxs,
676 ..
677 } = batch;
678
679 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
681 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
682 (task_id, rx.await)
683 });
684 futures.push(future);
685 }
686 }
687 ServiceCommand::AddTimerHandle(handle) => {
688 let TimerHandle{
689 task_id,
690 completion_rx,
691 ..
692 } = handle;
693
694 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
696 (task_id, completion_rx.0.await)
697 });
698 futures.push(future);
699 }
700 ServiceCommand::Shutdown => {
701 break;
702 }
703 }
704 }
705
706 else => {
708 break;
709 }
710 }
711 }
712 }
713}
714
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Builds a callback that adds `$by` to the shared `$counter` each time it fires.
    ///
    /// A macro (rather than a helper fn) keeps the callback a plain closure, so it
    /// satisfies `TimerCallback` exactly like a hand-written closure would.
    macro_rules! counting_callback {
        ($counter:expr, $by:expr) => {{
            let counter = Arc::clone(&$counter);
            move || {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add($by, Ordering::SeqCst);
                }
            }
        }};
    }

    /// Awaits the next expired task id, failing the test if none arrives within 200 ms.
    async fn next_expired(rx: &mut tokio::sync::mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    /// Counts expiry notifications until `expected` have arrived, the channel closes,
    /// or 200 ms pass without a new one.
    async fn count_expired(
        rx: &mut tokio::sync::mpsc::Receiver<TaskId>,
        expected: usize,
    ) -> usize {
        let mut seen = 0;
        while seen < expected {
            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
                Ok(Some(_)) => seen += 1,
                Ok(None) | Err(_) => break,
            }
        }
        seen
    }

    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service();
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
        let task2 = TimerService::create_task(Duration::from_secs(10), || async {});
        service.register(task1).await;
        service.register(task2).await;

        // Bound the wait so a regression shows up as a failure, not a hung test run.
        let finished = tokio::time::timeout(Duration::from_secs(2), service.shutdown()).await;
        assert!(finished.is_ok(), "shutdown should complete promptly");
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        let cancelled = service.cancel_task(task_id).await;
        assert!(cancelled, "Task should be cancelled successfully");

        let cancelled_again = service.cancel_task(task_id).await;
        assert!(!cancelled_again, "Task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        service.add_timer_handle(timer.register(task)).await;

        let fake_task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let cancelled = service.cancel_task(fake_task.get_id()).await;
        assert!(!cancelled, "Nonexistent task should not be cancelled");
    }

    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(10)).await;

        let cancelled = service.cancel_task(task_id).await;
        assert!(!cancelled, "Timed out task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerWheel::create_task(Duration::from_secs(10), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.add_timer_handle(timer.register(task)).await;

        let cancelled = service.cancel_task(task_id).await;
        assert!(cancelled, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        let cancelled_again = service.cancel_task(task_id).await;
        assert!(!cancelled_again, "Task should have been removed from active_tasks");
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_millis(50), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| (Duration::from_millis(50), counting_callback!(counter, 1)))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(count_expired(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_secs(10), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        let cancelled = service.cancel_task(task_id).await;
        assert!(cancelled, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), counting_callback!(counter, 1)))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        assert_eq!(task_ids.len(), 10);
        service.register_batch(tasks).await;

        let cancelled = service.cancel_batch(&task_ids).await;
        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), counting_callback!(counter, 1)))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        service.register_batch(tasks).await;

        let cancelled = service.cancel_batch(&task_ids[..5]).await;
        assert_eq!(cancelled, 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let cancelled = service.cancel_batch(&[]).await;
        assert_eq!(cancelled, 0, "No tasks should be cancelled");
    }

    #[tokio::test]
    async fn test_postpone_task() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_millis(50), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        let postponed = service.postpone_task(task_id, Duration::from_millis(150)).await;
        assert!(postponed, "Task should be postponed successfully");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_postpone_task_with_callback() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_millis(50), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        let postponed = service
            .postpone_task_with_callback(
                task_id,
                Duration::from_millis(100),
                counting_callback!(counter, 10),
            )
            .await;
        assert!(postponed, "Task should be postponed successfully");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(next_expired(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;
        // Only the replacement callback (+10) should have run.
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
        let postponed = service
            .postpone_task(fake_task.get_id(), Duration::from_millis(100))
            .await;
        assert!(!postponed, "Nonexistent task should not be postponed");
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut updates = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                counting_callback!(counter, 1),
            );
            updates.push((task.get_id(), Duration::from_millis(150)));
            service.register(task).await;
        }

        let postponed = service.postpone_batch(&updates).await;
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(count_expired(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(Duration::from_millis(50), || async {});
            task_ids.push(task.get_id());
            service.register(task).await;
        }

        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| (id, Duration::from_millis(150), counting_callback!(counter, 1)))
            .collect();

        let postponed = service.postpone_batch_with_callbacks(updates).await;
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(count_expired(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let postponed = service.postpone_batch(&[]).await;
        assert_eq!(postponed, 0, "No tasks should be postponed");
    }

    #[tokio::test]
    async fn test_postpone_keeps_timeout_notification_valid() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerService::create_task(Duration::from_millis(50), counting_callback!(counter, 1));
        let task_id = task.get_id();
        service.register(task).await;

        service.postpone_task(task_id, Duration::from_millis(100)).await;

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            next_expired(&mut rx).await,
            task_id,
            "Timeout notification should still work after postpone"
        );

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
        let task1_id = task1.get_id();
        service.register(task1).await;

        let task2 = TimerService::create_task(Duration::from_millis(50), || async {});
        let task2_id = task2.get_id();
        service.register(task2).await;

        let cancelled = service.cancel_task(task1_id).await;
        assert!(cancelled, "Task should be cancelled");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            next_expired(&mut rx).await,
            task2_id,
            "Should only receive expired task notification"
        );

        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(no_more.is_err(), "Should not receive any more notifications");
    }
}
1377