1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
4use crate::timer::{BatchHandle, TimerHandle};
5use crate::wheel::Wheel;
6use futures::stream::{FuturesUnordered, StreamExt};
7use futures::future::BoxFuture;
8use parking_lot::Mutex;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13
14enum ServiceCommand {
16 AddBatchHandle(BatchHandle),
18 AddTimerHandle(TimerHandle),
20 Shutdown,
22}
23
/// Timer service: registers tasks in a shared wheel and runs a background actor that
/// reports the ids of tasks which completed with `TaskCompletionReason::Expired`.
pub struct TimerService {
    /// Sends registration and shutdown commands to the background actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiving side of the expired-task-id channel; `None` once `take_receiver` was called.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor; `None` after `shutdown` (or `Drop`) took it.
    actor_handle: Option<JoinHandle<()>>,
    /// Wheel the tasks are inserted into; an `Arc` clone is also handed to every handle
    /// created here (presumably shared with the `TimerWheel` that created the service).
    wheel: Arc<Mutex<Wheel>>,
}
73
74impl TimerService {
75 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
95 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
96 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
97
98 let actor = ServiceActor::new(command_rx, timeout_tx);
99 let actor_handle = tokio::spawn(async move {
100 actor.run().await;
101 });
102
103 Self {
104 command_tx,
105 timeout_rx: Some(timeout_rx),
106 actor_handle: Some(actor_handle),
107 wheel,
108 }
109 }
110
111 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
135 self.timeout_rx.take()
136 }
137
138 #[inline]
174 pub fn cancel_task(&self, task_id: TaskId) -> bool {
175 let mut wheel = self.wheel.lock();
178 wheel.cancel(task_id)
179 }
180
181 #[inline]
219 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
220 if task_ids.is_empty() {
221 return 0;
222 }
223
224 let mut wheel = self.wheel.lock();
227 wheel.cancel_batch(task_ids)
228 }
229
230 #[inline]
274 pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
275 let mut wheel = self.wheel.lock();
276 wheel.postpone(task_id, new_delay, callback)
277 }
278
279 #[inline]
319 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
320 if updates.is_empty() {
321 return 0;
322 }
323
324 let mut wheel = self.wheel.lock();
325 wheel.postpone_batch(updates)
326 }
327
328 #[inline]
370 pub fn postpone_batch_with_callbacks(
371 &self,
372 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
373 ) -> usize {
374 if updates.is_empty() {
375 return 0;
376 }
377
378 let mut wheel = self.wheel.lock();
379 wheel.postpone_batch_with_callbacks(updates)
380 }
381
382 #[inline]
415 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
416 crate::timer::TimerWheel::create_task(delay, callback)
417 }
418
419 #[inline]
455 pub fn create_batch(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
456 crate::timer::TimerWheel::create_batch(callbacks)
457 }
458
459 #[inline]
488 pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
489 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
490 let notifier = crate::task::CompletionNotifier(completion_tx);
491
492 let delay = task.delay;
493 let task_id = task.id;
494
495 {
497 let mut wheel_guard = self.wheel.lock();
498 wheel_guard.insert(delay, task, notifier);
499 }
500
501 let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
503 self.command_tx
504 .try_send(ServiceCommand::AddTimerHandle(handle))
505 .map_err(|_| TimerError::RegisterFailed)?;
506
507 Ok(())
508 }
509
510 #[inline]
543 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
544 let task_count = tasks.len();
545 let mut completion_rxs = Vec::with_capacity(task_count);
546 let mut task_ids = Vec::with_capacity(task_count);
547 let mut prepared_tasks = Vec::with_capacity(task_count);
548
549 for task in tasks {
552 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
553 let notifier = crate::task::CompletionNotifier(completion_tx);
554
555 task_ids.push(task.id);
556 completion_rxs.push(completion_rx);
557 prepared_tasks.push((task.delay, task, notifier));
558 }
559
560 {
562 let mut wheel_guard = self.wheel.lock();
563 wheel_guard.insert_batch(prepared_tasks);
564 }
565
566 let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
568 self.command_tx
569 .try_send(ServiceCommand::AddBatchHandle(batch_handle))
570 .map_err(|_| TimerError::RegisterFailed)?;
571
572 Ok(())
573 }
574
575 pub async fn shutdown(mut self) {
591 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
592 if let Some(handle) = self.actor_handle.take() {
593 let _ = handle.await;
594 }
595 }
596}
597
598
599impl Drop for TimerService {
600 fn drop(&mut self) {
601 if let Some(handle) = self.actor_handle.take() {
602 handle.abort();
603 }
604 }
605}
606
/// Background task owned by a `TimerService`: awaits task completions and forwards
/// the ids of expired tasks.
struct ServiceActor {
    /// Commands (new handles to watch, shutdown) from the owning `TimerService`.
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Destination for ids of tasks that completed with `TaskCompletionReason::Expired`.
    timeout_tx: mpsc::Sender<TaskId>,
}
614
615impl ServiceActor {
616 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
617 Self {
618 command_rx,
619 timeout_tx,
620 }
621 }
622
623 async fn run(mut self) {
624 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
627
628 loop {
629 tokio::select! {
630 Some((task_id, result)) = futures.next() => {
632 if let Ok(TaskCompletionReason::Expired) = result {
634 let _ = self.timeout_tx.send(task_id).await;
635 }
636 }
638
639 Some(cmd) = self.command_rx.recv() => {
641 match cmd {
642 ServiceCommand::AddBatchHandle(batch) => {
643 let BatchHandle {
644 task_ids,
645 completion_rxs,
646 ..
647 } = batch;
648
649 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
651 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
652 (task_id, rx.await)
653 });
654 futures.push(future);
655 }
656 }
657 ServiceCommand::AddTimerHandle(handle) => {
658 let TimerHandle{
659 task_id,
660 completion_rx,
661 ..
662 } = handle;
663
664 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
666 (task_id, completion_rx.0.await)
667 });
668 futures.push(future);
669 }
670 ServiceCommand::Shutdown => {
671 break;
672 }
673 }
674 }
675
676 else => {
678 break;
679 }
680 }
681 }
682 }
683}
684
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Upper bound for waiting on a single timeout notification.
    const RECV_TIMEOUT: Duration = Duration::from_millis(200);

    /// Builds a callback that adds `amount` to `counter` each time it runs.
    fn counting_callback(counter: &Arc<AtomicU32>, amount: u32) -> CallbackWrapper {
        let counter = Arc::clone(counter);
        CallbackWrapper::new(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(amount, Ordering::SeqCst);
            }
        })
    }

    /// Waits (bounded by `RECV_TIMEOUT`) for exactly one expired task id; panics otherwise.
    async fn recv_one(rx: &mut mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    /// Collects up to `expected` notifications, stopping early on a timeout or closed channel.
    async fn recv_count(rx: &mut mpsc::Receiver<TaskId>, expected: usize) -> usize {
        let mut received = 0;
        while received < expected {
            match tokio::time::timeout(RECV_TIMEOUT, rx.recv()).await {
                Ok(Some(_task_id)) => received += 1,
                Ok(None) | Err(_) => break,
            }
        }
        received
    }

    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service();
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(|| async {})),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task1 = TimerService::create_task(Duration::from_secs(10), None);
        let task2 = TimerService::create_task(Duration::from_secs(10), None);
        service.register(task1).unwrap();
        service.register(task2).unwrap();

        service.shutdown().await;
    }

    #[tokio::test]
    async fn test_register_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        assert!(service.register_batch(Vec::new()).is_ok(), "Empty batch should register trivially");
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerService::create_task(Duration::from_secs(10), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");
        assert!(!service.cancel_task(task_id), "Task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let task = TimerService::create_task(Duration::from_millis(50), None);
        service.register(task).unwrap();

        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();
        assert!(!service.cancel_task(fake_task_id), "Nonexistent task should not be cancelled");
    }

    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerService::create_task(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(!service.cancel_task(task_id), "Timed out task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_secs(10),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        assert!(!service.cancel_task(task_id), "Task should have been removed from active_tasks");
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| (Duration::from_millis(50), Some(counting_callback(&counter, 1))))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_count(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task = TimerService::create_task(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_secs(10),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), Some(counting_callback(&counter, 1))))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        assert_eq!(task_ids.len(), 10);
        service.register_batch(tasks).unwrap();

        let cancelled = service.cancel_batch(&task_ids);
        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), Some(counting_callback(&counter, 1))))
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        service.register_batch(tasks).unwrap();

        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
        let cancelled = service.cancel_batch(&to_cancel);
        assert_eq!(cancelled, 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let empty: Vec<TaskId> = vec![];
        let cancelled = service.cancel_batch(&empty);
        assert_eq!(cancelled, 0, "No tasks should be cancelled");
    }

    #[tokio::test]
    async fn test_postpone() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        // The replacement callback adds 10, so the final count shows which callback ran.
        let postponed = service.postpone(
            task_id,
            Duration::from_millis(100),
            Some(counting_callback(&counter, 10)),
        );
        assert!(postponed, "Task should be postponed successfully");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();
        let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed, "Nonexistent task should not be postponed");
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut updates = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                Some(counting_callback(&counter, 1)),
            );
            // `None` keeps the callback the task was created with.
            updates.push((task.get_id(), Duration::from_millis(150), None));
            service.register(task).unwrap();
        }

        let postponed = service.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_count(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(Duration::from_millis(50), None);
            task_ids.push(task.get_id());
            service.register(task).unwrap();
        }

        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| (id, Duration::from_millis(150), Some(counting_callback(&counter, 1))))
            .collect();

        let postponed = service.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_count(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
        let postponed = service.postpone_batch_with_callbacks(empty);
        assert_eq!(postponed, 0, "No tasks should be postponed");
    }

    #[tokio::test]
    async fn test_postpone_keeps_timeout_notification_valid() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        // Assert the postpone actually happened; otherwise the notification below would
        // arrive from the original 50ms schedule and the test would prove nothing.
        let postponed = service.postpone(task_id, Duration::from_millis(100), None);
        assert!(postponed, "Task should be postponed successfully");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            recv_one(&mut rx).await,
            task_id,
            "Timeout notification should still work after postpone"
        );

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        let task1 = TimerService::create_task(Duration::from_secs(10), None);
        let task1_id = task1.get_id();
        service.register(task1).unwrap();

        let task2 = TimerService::create_task(Duration::from_millis(50), None);
        let task2_id = task2.get_id();
        service.register(task2).unwrap();

        assert!(service.cancel_task(task1_id), "Task should be cancelled");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            recv_one(&mut rx).await,
            task2_id,
            "Should only receive expired task notification"
        );

        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(no_more.is_err(), "Should not receive any more notifications");
    }

    #[tokio::test]
    async fn test_take_receiver_twice() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        assert!(service.take_receiver().is_some(), "First take_receiver should return Some");
        assert!(service.take_receiver().is_none(), "Second take_receiver should return None");
    }

    #[tokio::test]
    async fn test_postpone_batch_without_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                Some(counting_callback(&counter, 1)),
            );
            task_ids.push(task.get_id());
            service.register(task).unwrap();
        }

        let updates: Vec<_> = task_ids
            .iter()
            .map(|&id| (id, Duration::from_millis(150)))
            .collect();
        let postponed = service.postpone_batch(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callbacks should not fire yet");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_count(&mut rx, 3).await, 3, "Should receive 3 timeout notifications");

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3, "All callbacks should execute");
    }
}
1370