1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{TaskCompletionReason, TaskId, TimerCallback};
4use crate::timer::{BatchHandle, TimerHandle};
5use crate::wheel::Wheel;
6use futures::stream::{FuturesUnordered, StreamExt};
7use futures::future::BoxFuture;
8use parking_lot::Mutex;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13
/// Messages sent from a `TimerService` to its background `ServiceActor`.
enum ServiceCommand {
    /// Start watching the completion receivers of every task in a batch.
    AddBatchHandle(BatchHandle),
    /// Start watching the completion receiver of a single task.
    AddTimerHandle(TimerHandle),
    /// Ask the actor to leave its event loop.
    Shutdown,
}
23
/// Front end for scheduling tasks on a shared timer wheel and for receiving
/// the ids of tasks that expired.
///
/// A background actor (spawned in `new`) watches the completion channel of
/// every registered task and forwards the ids of expired tasks to the
/// receiver returned by `take_receiver`.
pub struct TimerService {
    /// Sending side of the command channel read by the background actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiving side of the expiry notifications; `None` once it has been
    /// handed out through `take_receiver`.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor; taken by `shutdown` or by `Drop`.
    actor_handle: Option<JoinHandle<()>>,
    /// Timer wheel shared with the component that created this service.
    wheel: Arc<Mutex<Wheel>>,
}
68
69impl TimerService {
70 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
90 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
91 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
92
93 let actor = ServiceActor::new(command_rx, timeout_tx);
94 let actor_handle = tokio::spawn(async move {
95 actor.run().await;
96 });
97
98 Self {
99 command_tx,
100 timeout_rx: Some(timeout_rx),
101 actor_handle: Some(actor_handle),
102 wheel,
103 }
104 }
105
    /// Hands out the receiver on which the ids of expired tasks are delivered.
    ///
    /// The receiver is moved out of the service, so this returns `Some` on the
    /// first call and `None` on every later call.
    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
        self.timeout_rx.take()
    }
132
133 #[inline]
165 pub fn cancel_task(&self, task_id: TaskId) -> bool {
166 let mut wheel = self.wheel.lock();
169 wheel.cancel(task_id)
170 }
171
172 #[inline]
204 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
205 if task_ids.is_empty() {
206 return 0;
207 }
208
209 let mut wheel = self.wheel.lock();
212 wheel.cancel_batch(task_ids)
213 }
214
215 #[inline]
252 pub fn postpone_task(&self, task_id: TaskId, new_delay: Duration) -> bool {
253 let mut wheel = self.wheel.lock();
256 wheel.postpone(task_id, new_delay, None)
257 }
258
259 #[inline]
300 pub fn postpone_task_with_callback<C>(
301 &self,
302 task_id: TaskId,
303 new_delay: Duration,
304 callback: C,
305 ) -> bool
306 where
307 C: TimerCallback,
308 {
309 use std::sync::Arc;
310 let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
311 let mut wheel = self.wheel.lock();
312 wheel.postpone(task_id, new_delay, Some(callback_wrapper))
313 }
314
315 #[inline]
349 pub fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
350 if updates.is_empty() {
351 return 0;
352 }
353
354 let updates_vec: Vec<_> = updates
355 .iter()
356 .map(|(task_id, delay)| (*task_id, *delay, None))
357 .collect();
358 let mut wheel = self.wheel.lock();
359 wheel.postpone_batch(updates_vec)
360 }
361
362 #[inline]
401 pub fn postpone_batch_with_callbacks<C>(
402 &self,
403 updates: Vec<(TaskId, Duration, C)>,
404 ) -> usize
405 where
406 C: TimerCallback,
407 {
408 if updates.is_empty() {
409 return 0;
410 }
411
412 use std::sync::Arc;
413 let updates_vec: Vec<_> = updates
414 .into_iter()
415 .map(|(task_id, delay, callback)| {
416 let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
417 (task_id, delay, Some(callback_wrapper))
418 })
419 .collect();
420 let mut wheel = self.wheel.lock();
421 wheel.postpone_batch(updates_vec)
422 }
423
424 pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
455 where
456 C: TimerCallback,
457 {
458 crate::timer::TimerWheel::create_task(delay, callback)
459 }
460
461 pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
493 where
494 C: TimerCallback,
495 {
496 crate::timer::TimerWheel::create_batch(callbacks)
497 }
498
499 #[inline]
526 pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
527 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
528 let notifier = crate::task::CompletionNotifier(completion_tx);
529
530 let delay = task.delay;
531 let task_id = task.id;
532
533 {
535 let mut wheel_guard = self.wheel.lock();
536 wheel_guard.insert(delay, task, notifier);
537 }
538
539 let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
541 self.command_tx
542 .try_send(ServiceCommand::AddTimerHandle(handle))
543 .map_err(|_| TimerError::RegisterFailed)?;
544
545 Ok(())
546 }
547
548 #[inline]
575 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
576 let task_count = tasks.len();
577 let mut completion_rxs = Vec::with_capacity(task_count);
578 let mut task_ids = Vec::with_capacity(task_count);
579 let mut prepared_tasks = Vec::with_capacity(task_count);
580
581 for task in tasks {
584 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
585 let notifier = crate::task::CompletionNotifier(completion_tx);
586
587 task_ids.push(task.id);
588 completion_rxs.push(completion_rx);
589 prepared_tasks.push((task.delay, task, notifier));
590 }
591
592 {
594 let mut wheel_guard = self.wheel.lock();
595 wheel_guard.insert_batch(prepared_tasks);
596 }
597
598 let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
600 self.command_tx
601 .try_send(ServiceCommand::AddBatchHandle(batch_handle))
602 .map_err(|_| TimerError::RegisterFailed)?;
603
604 Ok(())
605 }
606
607 pub async fn shutdown(mut self) {
623 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
624 if let Some(handle) = self.actor_handle.take() {
625 let _ = handle.await;
626 }
627 }
628}
629
630
631impl Drop for TimerService {
632 fn drop(&mut self) {
633 if let Some(handle) = self.actor_handle.take() {
634 handle.abort();
635 }
636 }
637}
638
/// Background half of a `TimerService`: receives commands and forwards the
/// ids of expired tasks.
struct ServiceActor {
    /// Commands coming from the owning `TimerService`.
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Channel on which the ids of expired tasks are published.
    timeout_tx: mpsc::Sender<TaskId>,
}
646
647impl ServiceActor {
    /// Creates an actor that reads commands from `command_rx` and publishes
    /// expired task ids on `timeout_tx`. Nothing runs until `run` is polled.
    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
        Self {
            command_rx,
            timeout_tx,
        }
    }
654
655 async fn run(mut self) {
656 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
659
660 loop {
661 tokio::select! {
662 Some((task_id, result)) = futures.next() => {
664 if let Ok(TaskCompletionReason::Expired) = result {
666 let _ = self.timeout_tx.send(task_id).await;
667 }
668 }
670
671 Some(cmd) = self.command_rx.recv() => {
673 match cmd {
674 ServiceCommand::AddBatchHandle(batch) => {
675 let BatchHandle {
676 task_ids,
677 completion_rxs,
678 ..
679 } = batch;
680
681 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
683 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
684 (task_id, rx.await)
685 });
686 futures.push(future);
687 }
688 }
689 ServiceCommand::AddTimerHandle(handle) => {
690 let TimerHandle{
691 task_id,
692 completion_rx,
693 ..
694 } = handle;
695
696 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
698 (task_id, completion_rx.0.await)
699 });
700 futures.push(future);
701 }
702 ServiceCommand::Shutdown => {
703 break;
704 }
705 }
706 }
707
708 else => {
710 break;
711 }
712 }
713 }
714 }
715}
716
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Builds a callback that adds `step` to `counter` every time it runs.
    fn counting_callback(counter: Arc<AtomicU32>, step: u32) -> impl TimerCallback {
        move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(step, Ordering::SeqCst);
            }
        }
    }

    /// Waits up to 200 ms for the next expiry notification and returns its id.
    async fn expect_notification(rx: &mut mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    /// Counts notifications until `wanted` were seen, the channel closed, or
    /// 200 ms passed without a new one.
    async fn count_notifications(rx: &mut mpsc::Receiver<TaskId>, wanted: usize) -> usize {
        let mut seen = 0;
        while seen < wanted {
            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
                Ok(Some(_task_id)) => seen += 1,
                Ok(None) | Err(_) => break,
            }
        }
        seen
    }

    /// A service can be created from a default wheel.
    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service();
    }

    /// A registered task's id arrives on the timeout receiver once it expires.
    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerService::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = expect_notification(&mut rx).await;

        assert_eq!(received_task_id, task_id);
    }

    /// Shutting down with pending long-running tasks completes.
    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        for _ in 0..2 {
            let task = TimerService::create_task(Duration::from_secs(10), || async {});
            service.register(task).unwrap();
        }

        service.shutdown().await;
    }

    /// Cancelling succeeds once and fails on the second attempt.
    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerService::create_task(Duration::from_secs(10), || async {});
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(
            service.cancel_task(task_id),
            "Task should be cancelled successfully"
        );
        assert!(
            !service.cancel_task(task_id),
            "Task should not exist anymore"
        );
    }

    /// Cancelling an id that was never registered reports `false`.
    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerService::create_task(Duration::from_millis(50), || async {});
        service.register(task).unwrap();

        // Created but never registered, so the wheel does not know this id.
        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
        let fake_task_id = fake_task.get_id();

        assert!(
            !service.cancel_task(fake_task_id),
            "Nonexistent task should not be cancelled"
        );
    }

    /// An expired task can no longer be cancelled.
    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerService::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = expect_notification(&mut rx).await;
        assert_eq!(received_task_id, task_id);

        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(
            !service.cancel_task(task_id),
            "Timed out task should not exist anymore"
        );
    }

    /// A cancelled task never runs its callback and cannot be cancelled twice.
    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_secs(10),
            counting_callback(Arc::clone(&counter), 1),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(
            service.cancel_task(task_id),
            "Task should be cancelled successfully"
        );

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        assert!(
            !service.cancel_task(task_id),
            "Task should have been removed from active_tasks"
        );
    }

    /// A single task both notifies the receiver and runs its callback once.
    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            counting_callback(Arc::clone(&counter), 1),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = expect_notification(&mut rx).await;
        assert_eq!(received_task_id, task_id);

        // Give the callback a moment to finish after the notification.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Every task of a batch notifies the receiver and runs its callback.
    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| {
                (
                    Duration::from_millis(50),
                    counting_callback(Arc::clone(&counter), 1),
                )
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).unwrap();

        let mut rx = service.take_receiver().unwrap();
        let received_count = count_notifications(&mut rx, 3).await;
        assert_eq!(received_count, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A task without a callback still produces an expiry notification.
    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = expect_notification(&mut rx).await;

        assert_eq!(received_task_id, task_id);
    }

    /// Cancelling right after registering prevents the callback from running.
    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_secs(10),
            counting_callback(Arc::clone(&counter), 1),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(
            service.cancel_task(task_id),
            "Task should be cancelled successfully"
        );

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    /// A whole batch can be cancelled in one call.
    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                (
                    Duration::from_secs(10),
                    counting_callback(Arc::clone(&counter), 1),
                )
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        assert_eq!(task_ids.len(), 10);
        service.register_batch(tasks).unwrap();

        let cancelled = service.cancel_batch(&task_ids);
        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    /// Cancelling half of a batch reports exactly that half.
    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                (
                    Duration::from_secs(10),
                    counting_callback(Arc::clone(&counter), 1),
                )
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        service.register_batch(tasks).unwrap();

        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
        let cancelled = service.cancel_batch(&to_cancel);
        assert_eq!(cancelled, 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    /// Cancelling an empty list is a no-op.
    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let empty: Vec<TaskId> = vec![];
        let cancelled = service.cancel_batch(&empty);
        assert_eq!(cancelled, 0, "No tasks should be cancelled");
    }

    /// A postponed task does not fire at its old deadline but does at the new one.
    #[tokio::test]
    async fn test_postpone_task() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            counting_callback(Arc::clone(&counter), 1),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let postponed = service.postpone_task(task_id, Duration::from_millis(150));
        assert!(postponed, "Task should be postponed successfully");

        // Past the original 50 ms deadline, before the new 150 ms one.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = expect_notification(&mut rx).await;
        assert_eq!(received_task_id, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Postponing with a callback replaces the original callback.
    #[tokio::test]
    async fn test_postpone_task_with_callback() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            counting_callback(Arc::clone(&counter), 1),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        // The replacement adds 10, so the final value tells which one ran.
        let postponed = service.postpone_task_with_callback(
            task_id,
            Duration::from_millis(100),
            counting_callback(Arc::clone(&counter), 10),
        );
        assert!(postponed, "Task should be postponed successfully");

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = expect_notification(&mut rx).await;
        assert_eq!(received_task_id, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    /// Postponing an id that was never registered reports `false`.
    #[tokio::test]
    async fn test_postpone_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
        let fake_task_id = fake_task.get_id();

        let postponed = service.postpone_task(fake_task_id, Duration::from_millis(100));
        assert!(!postponed, "Nonexistent task should not be postponed");
    }

    /// Several tasks can be postponed in one call and all fire afterwards.
    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                counting_callback(Arc::clone(&counter), 1),
            );
            task_ids.push((task.get_id(), Duration::from_millis(150)));
            service.register(task).unwrap();
        }

        let postponed = service.postpone_batch(&task_ids);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        let received_count = count_notifications(&mut rx, 3).await;
        assert_eq!(received_count, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// Batch postponing with callbacks installs the new callbacks.
    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            // The original callbacks do nothing; only replacements count.
            let task = TimerService::create_task(Duration::from_millis(50), || async {});
            task_ids.push(task.get_id());
            service.register(task).unwrap();
        }

        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| {
                (
                    id,
                    Duration::from_millis(150),
                    counting_callback(Arc::clone(&counter), 1),
                )
            })
            .collect();

        let postponed = service.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        let received_count = count_notifications(&mut rx, 3).await;
        assert_eq!(received_count, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// Postponing an empty list is a no-op.
    #[tokio::test]
    async fn test_postpone_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let empty: Vec<(TaskId, Duration)> = vec![];
        let postponed = service.postpone_batch(&empty);
        assert_eq!(postponed, 0, "No tasks should be postponed");
    }

    /// The expiry notification is still delivered after a postpone.
    #[tokio::test]
    async fn test_postpone_keeps_timeout_notification_valid() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            counting_callback(Arc::clone(&counter), 1),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        service.postpone_task(task_id, Duration::from_millis(100));

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = expect_notification(&mut rx).await;
        assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Only expired tasks are forwarded; cancelled ones are not.
    #[tokio::test]
    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
        let task1_id = task1.get_id();
        service.register(task1).unwrap();

        let task2 = TimerService::create_task(Duration::from_millis(50), || async {});
        let task2_id = task2.get_id();
        service.register(task2).unwrap();

        let cancelled = service.cancel_task(task1_id);
        assert!(cancelled, "Task should be cancelled");

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = expect_notification(&mut rx).await;
        assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");

        // A second notification would mean the cancelled task was forwarded.
        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(no_more.is_err(), "Should not receive any more notifications");
    }
}
1374