1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13enum ServiceCommand {
15 AddBatchHandle {
17 task_ids: Vec<TaskId>,
18 completion_rxs: Vec<tokio::sync::oneshot::Receiver<TaskCompletionReason>>,
19 },
20 AddTimerHandle {
22 task_id: TaskId,
23 completion_rx: tokio::sync::oneshot::Receiver<TaskCompletionReason>,
24 },
25}
26
/// Handle to the timer service.
///
/// Register, cancel and postpone operations lock the shared [`Wheel`] directly.
/// A background actor task (`ServiceActor`) watches the completion of
/// registered tasks and forwards the IDs of expired ones to a channel that can
/// be obtained with [`TimerService::take_receiver`].
pub struct TimerService {
    // Sends watch requests (`ServiceCommand`) to the background actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    // Receiver of expired task IDs; `None` once handed out by `take_receiver`.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    // Join handle of the spawned actor; taken by `shutdown`, aborted on drop.
    actor_handle: Option<JoinHandle<()>>,
    // Timing wheel holding the scheduled tasks, shared through an `Arc`.
    wheel: Arc<Mutex<Wheel>>,
    // Shutdown signal for the actor; taken and fired by `shutdown`.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
78
79impl TimerService {
80 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
100 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
101 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
102
103 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
104 let actor = ServiceActor::new(command_rx, timeout_tx, shutdown_rx);
105 let actor_handle = tokio::spawn(async move {
106 actor.run().await;
107 });
108
109 Self {
110 command_tx,
111 timeout_rx: Some(timeout_rx),
112 actor_handle: Some(actor_handle),
113 wheel,
114 shutdown_tx: Some(shutdown_tx),
115 }
116 }
117
118 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
142 self.timeout_rx.take()
143 }
144
145 #[inline]
181 pub fn cancel_task(&self, task_id: TaskId) -> bool {
182 let mut wheel = self.wheel.lock();
185 wheel.cancel(task_id)
186 }
187
188 #[inline]
226 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
227 if task_ids.is_empty() {
228 return 0;
229 }
230
231 let mut wheel = self.wheel.lock();
234 wheel.cancel_batch(task_ids)
235 }
236
237 #[inline]
281 pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
282 let mut wheel = self.wheel.lock();
283 wheel.postpone(task_id, new_delay, callback)
284 }
285
286 #[inline]
326 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
327 if updates.is_empty() {
328 return 0;
329 }
330
331 let mut wheel = self.wheel.lock();
332 wheel.postpone_batch(updates)
333 }
334
335 #[inline]
377 pub fn postpone_batch_with_callbacks(
378 &self,
379 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
380 ) -> usize {
381 if updates.is_empty() {
382 return 0;
383 }
384
385 let mut wheel = self.wheel.lock();
386 wheel.postpone_batch_with_callbacks(updates)
387 }
388
389 #[inline]
422 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
423 crate::timer::TimerWheel::create_task(delay, callback)
424 }
425
426 #[inline]
457 pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask> {
458 crate::timer::TimerWheel::create_batch(delays)
459 }
460
461 #[inline]
497 pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
498 crate::timer::TimerWheel::create_batch_with_callbacks(callbacks)
499 }
500
501 #[inline]
530 pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
531 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
532 let notifier = crate::task::CompletionNotifier(completion_tx);
533
534 let task_id = task.id;
535
536 {
538 let mut wheel_guard = self.wheel.lock();
539 wheel_guard.insert(task, notifier);
540 }
541
542 self.command_tx
544 .try_send(ServiceCommand::AddTimerHandle {
545 task_id,
546 completion_rx,
547 })
548 .map_err(|_| TimerError::RegisterFailed)?;
549
550 Ok(())
551 }
552
553 #[inline]
586 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
587 let task_count = tasks.len();
588 let mut completion_rxs = Vec::with_capacity(task_count);
589 let mut task_ids = Vec::with_capacity(task_count);
590 let mut prepared_tasks = Vec::with_capacity(task_count);
591
592 for task in tasks {
595 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
596 let notifier = crate::task::CompletionNotifier(completion_tx);
597
598 task_ids.push(task.id);
599 completion_rxs.push(completion_rx);
600 prepared_tasks.push((task, notifier));
601 }
602
603 {
605 let mut wheel_guard = self.wheel.lock();
606 wheel_guard.insert_batch(prepared_tasks);
607 }
608
609 self.command_tx
611 .try_send(ServiceCommand::AddBatchHandle {
612 task_ids,
613 completion_rxs,
614 })
615 .map_err(|_| TimerError::RegisterFailed)?;
616
617 Ok(())
618 }
619
620 pub async fn shutdown(mut self) {
636 if let Some(shutdown_tx) = self.shutdown_tx.take() {
637 let _ = shutdown_tx.send(());
638 }
639 if let Some(handle) = self.actor_handle.take() {
640 let _ = handle.await;
641 }
642 }
643}
644
645
646impl Drop for TimerService {
647 fn drop(&mut self) {
648 if let Some(handle) = self.actor_handle.take() {
649 handle.abort();
650 }
651 }
652}
653
/// Background task that turns task completions into expiry notifications.
///
/// It receives watch requests from [`TimerService`] on `command_rx`, awaits the
/// contained completion receivers, and sends the IDs of expired tasks on
/// `timeout_tx` until `shutdown_rx` resolves.
struct ServiceActor {
    // Watch requests coming from the service handle.
    command_rx: mpsc::Receiver<ServiceCommand>,
    // Sending half of the expiry-notification channel.
    timeout_tx: mpsc::Sender<TaskId>,
    // Resolves when the service requests shutdown; ends the actor loop.
    shutdown_rx: tokio::sync::oneshot::Receiver<()>,
}
663
664impl ServiceActor {
665 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>, shutdown_rx: tokio::sync::oneshot::Receiver<()>) -> Self {
666 Self {
667 command_rx,
668 timeout_tx,
669 shutdown_rx,
670 }
671 }
672
673 async fn run(mut self) {
674 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
677
678 let mut shutdown_rx = self.shutdown_rx;
680
681 loop {
682 tokio::select! {
683 _ = &mut shutdown_rx => {
685 break;
687 }
688
689 Some((task_id, result)) = futures.next() => {
691 if let Ok(TaskCompletionReason::Expired) = result {
693 let _ = self.timeout_tx.send(task_id).await;
694 }
695 }
697
698 Some(cmd) = self.command_rx.recv() => {
700 match cmd {
701 ServiceCommand::AddBatchHandle { task_ids, completion_rxs } => {
702 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
704 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
705 (task_id, rx.await)
706 });
707 futures.push(future);
708 }
709 }
710 ServiceCommand::AddTimerHandle { task_id, completion_rx } => {
711
712 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
714 (task_id, completion_rx.await)
715 });
716 futures.push(future);
717 }
718 }
719 }
720
721 else => {
723 break;
724 }
725 }
726 }
727 }
728}
729
// Tests for `TimerService`. Most rely on real timing: tasks use delays of tens
// of milliseconds, and the tests wait on the timeout receiver with
// `tokio::time::timeout` or sleep before checking callback side effects.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    // Creating a service inside a Tokio runtime must not panic.
    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service(ServiceConfig::default());
    }

    // A registered task's ID arrives on the timeout receiver once it expires.
    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(Duration::from_millis(50), Some(CallbackWrapper::new(|| async {})));
        let task_id = task.get_id();

        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);
    }

    // `shutdown` must return even while registered tasks are still pending.
    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let task1 = TimerService::create_task(Duration::from_secs(10), None);
        let task2 = TimerService::create_task(Duration::from_secs(10), None);
        service.register(task1).unwrap();
        service.register(task2).unwrap();

        service.shutdown().await;
    }

    // Cancelling succeeds once; a second cancel of the same ID reports `false`.
    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(Duration::from_secs(10), None);
        let task_id = task.get_id();

        service.register(task).unwrap();

        let cancelled = service.cancel_task(task_id);
        assert!(cancelled, "Task should be cancelled successfully");

        let cancelled_again = service.cancel_task(task_id);
        assert!(!cancelled_again, "Task should not exist anymore");
    }

    // A task that was created but never registered cannot be cancelled.
    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(Duration::from_millis(50), None);
        service.register(task).unwrap();

        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();
        let cancelled = service.cancel_task(fake_task_id);
        assert!(!cancelled, "Nonexistent task should not be cancelled");
    }

    // Once a task has expired it can no longer be cancelled.
    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(Duration::from_millis(50), None);
        let task_id = task.get_id();

        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);

        // Short grace period, presumably to let any post-expiry cleanup settle.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let cancelled = service.cancel_task(task_id);
        assert!(!cancelled, "Timed out task should not exist anymore");
    }

    // Cancelling a task with a callback keeps the callback from running and
    // removes the task.
    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        let task = TimerService::create_task(
            Duration::from_secs(10),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();

        service.register(task).unwrap();

        let cancelled = service.cancel_task(task_id);
        assert!(cancelled, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        let cancelled_again = service.cancel_task(task_id);
        assert!(!cancelled_again, "Task should have been removed from active_tasks");
    }

    // An expiring task both notifies the receiver and runs its callback once.
    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);

        // Presumably the callback may still be running when the notification
        // arrives, so wait a little before checking its side effect.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // Batch registration: all three tasks notify and run their callbacks.
    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_millis(50), Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })))
            })
            .collect();

        let tasks = TimerService::create_batch_with_callbacks(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).unwrap();

        let mut received_count = 0;
        let mut rx = service.take_receiver().unwrap();

        // Stop early on channel close or a 200ms gap; the assert below reports it.
        while received_count < 3 {
            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
                Ok(Some(_task_id)) => {
                    received_count += 1;
                }
                Ok(None) => break,
                Err(_) => break,
            }
        }

        assert_eq!(received_count, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // A task without a callback still produces an expiry notification.
    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);
    }

    // Scheduling and immediately cancelling: the callback never runs.
    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        let task = TimerService::create_task(
            Duration::from_secs(10),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let cancelled = service.cancel_task(task_id);
        assert!(cancelled, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    // `cancel_batch` cancels every task of a registered batch.
    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })))
            })
            .collect();

        let tasks = TimerService::create_batch_with_callbacks(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        assert_eq!(task_ids.len(), 10);
        service.register_batch(tasks).unwrap();

        let cancelled = service.cancel_batch(&task_ids);
        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    // Cancelling half of a batch reports 5. The other five have a 10s delay,
    // so no callback runs during the 100ms wait either.
    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })))
            })
            .collect();

        let tasks = TimerService::create_batch_with_callbacks(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        service.register_batch(tasks).unwrap();

        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
        let cancelled = service.cancel_batch(&to_cancel);
        assert_eq!(cancelled, 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    // An empty cancel batch is a no-op returning 0.
    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let empty: Vec<TaskId> = vec![];
        let cancelled = service.cancel_batch(&empty);
        assert_eq!(cancelled, 0, "No tasks should be cancelled");
    }

    // Postponing with a new callback replaces the original one: the counter
    // ends at 10 (new callback only), not 1 or 11.
    #[tokio::test]
    async fn test_postpone() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone1 = Arc::clone(&counter);
        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone1);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let counter_clone2 = Arc::clone(&counter);
        let postponed = service.postpone(
            task_id,
            Duration::from_millis(100),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone2);
                async move {
                    counter.fetch_add(10, Ordering::SeqCst);
                }
            }))
        );
        assert!(postponed, "Task should be postponed successfully");

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    // Postponing an unregistered task reports `false`.
    #[tokio::test]
    async fn test_postpone_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();
        let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed, "Nonexistent task should not be postponed");
    }

    // Batch postpone through `postpone_batch_with_callbacks` with `None`
    // callbacks: nothing fires at the original 50ms, and the original
    // callbacks run after the new 150ms delay.
    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            let task = TimerService::create_task(
                Duration::from_millis(50),
                Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })),
            );
            task_ids.push((task.get_id(), Duration::from_millis(150), None));
            service.register(task).unwrap();
        }

        let postponed = service.postpone_batch_with_callbacks(task_ids);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut received_count = 0;
        let mut rx = service.take_receiver().unwrap();

        while received_count < 3 {
            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
                Ok(Some(_task_id)) => {
                    received_count += 1;
                }
                Ok(None) => break,
                Err(_) => break,
            }
        }

        assert_eq!(received_count, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // Tasks registered without callbacks get callbacks attached by the batch
    // postpone; those callbacks run after the new delay.
    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                None,
            );
            task_ids.push(task.get_id());
            service.register(task).unwrap();
        }

        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| {
                let counter_clone = Arc::clone(&counter);
                (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })))
            })
            .collect();

        let postponed = service.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut received_count = 0;
        let mut rx = service.take_receiver().unwrap();

        while received_count < 3 {
            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
                Ok(Some(_task_id)) => {
                    received_count += 1;
                }
                Ok(None) => break,
                Err(_) => break,
            }
        }

        assert_eq!(received_count, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // An empty postpone batch is a no-op returning 0.
    #[tokio::test]
    async fn test_postpone_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
        let postponed = service.postpone_batch_with_callbacks(empty);
        assert_eq!(postponed, 0, "No tasks should be postponed");
    }

    // Postponing with `None` keeps both the expiry notification and the
    // original callback working.
    #[tokio::test]
    async fn test_postpone_keeps_timeout_notification_valid() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        service.postpone(task_id, Duration::from_millis(100), None);

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // Only expired tasks are forwarded: the cancelled task never appears on
    // the receiver, and nothing else arrives after the expired one.
    #[tokio::test]
    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        let task1 = TimerService::create_task(Duration::from_secs(10), None);
        let task1_id = task1.get_id();
        service.register(task1).unwrap();

        let task2 = TimerService::create_task(Duration::from_millis(50), None);
        let task2_id = task2.get_id();
        service.register(task2).unwrap();

        let cancelled = service.cancel_task(task1_id);
        assert!(cancelled, "Task should be cancelled");

        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");

        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(no_more.is_err(), "Should not receive any more notifications");
    }

    // The timeout receiver can be taken exactly once.
    #[tokio::test]
    async fn test_take_receiver_twice() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        let rx1 = service.take_receiver();
        assert!(rx1.is_some(), "First take_receiver should return Some");

        let rx2 = service.take_receiver();
        assert!(rx2.is_none(), "Second take_receiver should return None");
    }

    // `postpone_batch` (no callback argument) keeps the original callbacks,
    // which run after the new 150ms delay.
    #[tokio::test]
    async fn test_postpone_batch_without_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            let task = TimerService::create_task(
                Duration::from_millis(50),
                Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })),
            );
            task_ids.push(task.get_id());
            service.register(task).unwrap();
        }

        let updates: Vec<_> = task_ids
            .iter()
            .map(|&id| (id, Duration::from_millis(150)))
            .collect();
        let postponed = service.postpone_batch(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callbacks should not fire yet");

        let mut received_count = 0;
        let mut rx = service.take_receiver().unwrap();

        while received_count < 3 {
            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
                Ok(Some(_task_id)) => {
                    received_count += 1;
                }
                Ok(None) => break,
                Err(_) => break,
            }
        }

        assert_eq!(received_count, 3, "Should receive 3 timeout notifications");

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3, "All callbacks should execute");
    }
}
1415