1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
/// Messages sent from a [`TimerService`] handle to its background `ServiceActor`.
enum ServiceCommand {
    /// Start watching the completion channels of a batch of tasks registered together.
    ///
    /// The actor zips the two vectors, so `task_ids[i]` belongs to `completion_rxs[i]`.
    AddBatchHandle {
        /// Ids of the registered tasks, in registration order.
        task_ids: Vec<TaskId>,
        /// Completion receivers, index-aligned with `task_ids`.
        completion_rxs: Vec<tokio::sync::oneshot::Receiver<TaskCompletionReason>>,
    },
    /// Start watching the completion channel of a single registered task.
    AddTimerHandle {
        /// Id forwarded to the timeout channel if the task completes as `Expired`.
        task_id: TaskId,
        /// Resolves with the reason the task completed.
        completion_rx: tokio::sync::oneshot::Receiver<TaskCompletionReason>,
    },
    /// Ask the actor to leave its event loop.
    Shutdown,
}
28
/// Handle to a timer wheel plus a background actor that reports task expirations.
///
/// Scheduling operations (`register`, `cancel_task`, `postpone`, ...) act directly on the
/// wheel under its mutex. The spawned actor only watches each task's completion channel and
/// forwards the ids of expired tasks to the receiver obtained via `take_receiver`.
pub struct TimerService {
    /// Sends registration and shutdown commands to the background actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiving end for expired task ids; `None` once handed out by `take_receiver`.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor; taken by `shutdown`, otherwise aborted on drop.
    actor_handle: Option<JoinHandle<()>>,
    /// Timer wheel this service schedules into (an `Arc`, so other owners may share it).
    wheel: Arc<Mutex<Wheel>>,
}
78
79impl TimerService {
80 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
100 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
101 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
102
103 let actor = ServiceActor::new(command_rx, timeout_tx);
104 let actor_handle = tokio::spawn(async move {
105 actor.run().await;
106 });
107
108 Self {
109 command_tx,
110 timeout_rx: Some(timeout_rx),
111 actor_handle: Some(actor_handle),
112 wheel,
113 }
114 }
115
116 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
140 self.timeout_rx.take()
141 }
142
143 #[inline]
179 pub fn cancel_task(&self, task_id: TaskId) -> bool {
180 let mut wheel = self.wheel.lock();
183 wheel.cancel(task_id)
184 }
185
186 #[inline]
224 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
225 if task_ids.is_empty() {
226 return 0;
227 }
228
229 let mut wheel = self.wheel.lock();
232 wheel.cancel_batch(task_ids)
233 }
234
235 #[inline]
279 pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
280 let mut wheel = self.wheel.lock();
281 wheel.postpone(task_id, new_delay, callback)
282 }
283
284 #[inline]
324 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
325 if updates.is_empty() {
326 return 0;
327 }
328
329 let mut wheel = self.wheel.lock();
330 wheel.postpone_batch(updates)
331 }
332
333 #[inline]
375 pub fn postpone_batch_with_callbacks(
376 &self,
377 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
378 ) -> usize {
379 if updates.is_empty() {
380 return 0;
381 }
382
383 let mut wheel = self.wheel.lock();
384 wheel.postpone_batch_with_callbacks(updates)
385 }
386
387 #[inline]
420 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
421 crate::timer::TimerWheel::create_task(delay, callback)
422 }
423
424 #[inline]
455 pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask> {
456 crate::timer::TimerWheel::create_batch(delays)
457 }
458
459 #[inline]
495 pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
496 crate::timer::TimerWheel::create_batch_with_callbacks(callbacks)
497 }
498
499 #[inline]
528 pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
529 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
530 let notifier = crate::task::CompletionNotifier(completion_tx);
531
532 let task_id = task.id;
533
534 {
536 let mut wheel_guard = self.wheel.lock();
537 wheel_guard.insert(task, notifier);
538 }
539
540 self.command_tx
542 .try_send(ServiceCommand::AddTimerHandle {
543 task_id,
544 completion_rx,
545 })
546 .map_err(|_| TimerError::RegisterFailed)?;
547
548 Ok(())
549 }
550
551 #[inline]
584 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
585 let task_count = tasks.len();
586 let mut completion_rxs = Vec::with_capacity(task_count);
587 let mut task_ids = Vec::with_capacity(task_count);
588 let mut prepared_tasks = Vec::with_capacity(task_count);
589
590 for task in tasks {
593 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
594 let notifier = crate::task::CompletionNotifier(completion_tx);
595
596 task_ids.push(task.id);
597 completion_rxs.push(completion_rx);
598 prepared_tasks.push((task, notifier));
599 }
600
601 {
603 let mut wheel_guard = self.wheel.lock();
604 wheel_guard.insert_batch(prepared_tasks);
605 }
606
607 self.command_tx
609 .try_send(ServiceCommand::AddBatchHandle {
610 task_ids,
611 completion_rxs,
612 })
613 .map_err(|_| TimerError::RegisterFailed)?;
614
615 Ok(())
616 }
617
618 pub async fn shutdown(mut self) {
634 let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
635 if let Some(handle) = self.actor_handle.take() {
636 let _ = handle.await;
637 }
638 }
639}
640
641
642impl Drop for TimerService {
643 fn drop(&mut self) {
644 if let Some(handle) = self.actor_handle.take() {
645 handle.abort();
646 }
647 }
648}
649
/// State of the background task spawned by `TimerService::new`.
struct ServiceActor {
    /// Commands (registrations and shutdown) sent by the owning `TimerService`.
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Sink for the ids of tasks whose completion reason was `Expired`.
    timeout_tx: mpsc::Sender<TaskId>,
}
657
658impl ServiceActor {
659 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
660 Self {
661 command_rx,
662 timeout_tx,
663 }
664 }
665
666 async fn run(mut self) {
667 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
670
671 loop {
672 tokio::select! {
673 Some((task_id, result)) = futures.next() => {
675 if let Ok(TaskCompletionReason::Expired) = result {
677 let _ = self.timeout_tx.send(task_id).await;
678 }
679 }
681
682 Some(cmd) = self.command_rx.recv() => {
684 match cmd {
685 ServiceCommand::AddBatchHandle { task_ids, completion_rxs } => {
686 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
688 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
689 (task_id, rx.await)
690 });
691 futures.push(future);
692 }
693 }
694 ServiceCommand::AddTimerHandle { task_id, completion_rx } => {
695
696 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
698 (task_id, completion_rx.await)
699 });
700 futures.push(future);
701 }
702 ServiceCommand::Shutdown => {
703 break;
704 }
705 }
706 }
707
708 else => {
710 break;
711 }
712 }
713 }
714 }
715}
716
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Upper bound on how long a test waits for any single timeout notification.
    const RECV_TIMEOUT: Duration = Duration::from_millis(200);

    /// Builds a callback that adds `amount` to `counter` every time it runs.
    fn counting_callback(counter: &Arc<AtomicU32>, amount: u32) -> CallbackWrapper {
        let counter = Arc::clone(counter);
        CallbackWrapper::new(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(amount, Ordering::SeqCst);
            }
        })
    }

    /// Waits for exactly one timeout notification and returns the expired task id.
    ///
    /// Panics if nothing arrives within `RECV_TIMEOUT` or the channel is closed.
    async fn recv_one(rx: &mut mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    /// Collects up to `expected` timeout notifications and returns how many arrived.
    ///
    /// Stops early when one wait exceeds `RECV_TIMEOUT` or the channel closes, so a missing
    /// notification shows up as a short count instead of a hang.
    async fn recv_up_to(rx: &mut mpsc::Receiver<TaskId>, expected: usize) -> usize {
        let mut received = 0;
        while received < expected {
            match tokio::time::timeout(RECV_TIMEOUT, rx.recv()).await {
                Ok(Some(_)) => received += 1,
                Ok(None) | Err(_) => break,
            }
        }
        received
    }

    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service(ServiceConfig::default());
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(|| async {})),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        for _ in 0..2 {
            let task = TimerService::create_task(Duration::from_secs(10), None);
            service.register(task).unwrap();
        }

        service.shutdown().await;
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(Duration::from_secs(10), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");
        assert!(!service.cancel_task(task_id), "Task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(Duration::from_millis(50), None);
        service.register(task).unwrap();

        // Created but never registered, so the service has never seen this id.
        let fake_task_id = TimerService::create_task(Duration::from_millis(50), None).get_id();
        assert!(
            !service.cancel_task(fake_task_id),
            "Nonexistent task should not be cancelled"
        );
    }

    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);

        // Small grace period before checking that the expired task is gone.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(!service.cancel_task(task_id), "Timed out task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_task_spawns_background_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_secs(10),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        assert!(
            !service.cancel_task(task_id),
            "Task should have been removed from active_tasks"
        );
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| (Duration::from_millis(50), Some(counting_callback(&counter, 1))))
            .collect();
        let tasks = TimerService::create_batch_with_callbacks(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_up_to(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        let task = TimerService::create_task(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_secs(10),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(service.cancel_task(task_id), "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), Some(counting_callback(&counter, 1))))
            .collect();
        let tasks = TimerService::create_batch_with_callbacks(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        assert_eq!(task_ids.len(), 10);
        service.register_batch(tasks).unwrap();

        assert_eq!(service.cancel_batch(&task_ids), 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| (Duration::from_secs(10), Some(counting_callback(&counter, 1))))
            .collect();
        let tasks = TimerService::create_batch_with_callbacks(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        service.register_batch(tasks).unwrap();

        assert_eq!(service.cancel_batch(&task_ids[..5]), 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
    }

    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let empty: Vec<TaskId> = Vec::new();
        assert_eq!(service.cancel_batch(&empty), 0, "No tasks should be cancelled");
    }

    #[tokio::test]
    async fn test_postpone() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        // The replacement adds 10 instead of 1, so the final count shows which callback ran.
        let postponed = service.postpone(
            task_id,
            Duration::from_millis(100),
            Some(counting_callback(&counter, 10)),
        );
        assert!(postponed, "Task should be postponed successfully");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_one(&mut rx).await, task_id);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let fake_task_id = TimerService::create_task(Duration::from_millis(50), None).get_id();
        let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed, "Nonexistent task should not be postponed");
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let mut updates = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                Some(counting_callback(&counter, 1)),
            );
            updates.push((task.get_id(), Duration::from_millis(150), None));
            service.register(task).unwrap();
        }

        let postponed = service.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_up_to(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(Duration::from_millis(50), None);
            task_ids.push(task.get_id());
            service.register(task).unwrap();
        }

        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| (id, Duration::from_millis(150), Some(counting_callback(&counter, 1))))
            .collect();

        let postponed = service.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_up_to(&mut rx, 3).await, 3);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service(ServiceConfig::default());

        let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = Vec::new();
        let postponed = service.postpone_batch_with_callbacks(empty);
        assert_eq!(postponed, 0, "No tasks should be postponed");
    }

    #[tokio::test]
    async fn test_postpone_keeps_timeout_notification_valid() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerService::create_task(
            Duration::from_millis(50),
            Some(counting_callback(&counter, 1)),
        );
        let task_id = task.get_id();
        service.register(task).unwrap();

        assert!(
            service.postpone(task_id, Duration::from_millis(100), None),
            "Task should be postponed successfully"
        );

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            recv_one(&mut rx).await,
            task_id,
            "Timeout notification should still work after postpone"
        );

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        let long_task = TimerService::create_task(Duration::from_secs(10), None);
        let long_task_id = long_task.get_id();
        service.register(long_task).unwrap();

        let short_task = TimerService::create_task(Duration::from_millis(50), None);
        let short_task_id = short_task.get_id();
        service.register(short_task).unwrap();

        assert!(service.cancel_task(long_task_id), "Task should be cancelled");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            recv_one(&mut rx).await,
            short_task_id,
            "Should only receive expired task notification"
        );

        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        assert!(no_more.is_err(), "Should not receive any more notifications");
    }

    #[tokio::test]
    async fn test_take_receiver_twice() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        assert!(service.take_receiver().is_some(), "First take_receiver should return Some");
        assert!(service.take_receiver().is_none(), "Second take_receiver should return None");
    }

    #[tokio::test]
    async fn test_postpone_batch_without_callbacks() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());
        let counter = Arc::new(AtomicU32::new(0));

        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerService::create_task(
                Duration::from_millis(50),
                Some(counting_callback(&counter, 1)),
            );
            task_ids.push(task.get_id());
            service.register(task).unwrap();
        }

        let updates: Vec<_> = task_ids
            .iter()
            .map(|&id| (id, Duration::from_millis(150)))
            .collect();
        let postponed = service.postpone_batch(updates);
        assert_eq!(postponed, 3, "All 3 tasks should be postponed");

        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callbacks should not fire yet");

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(
            recv_up_to(&mut rx, 3).await,
            3,
            "Should receive 3 timeout notifications"
        );

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3, "All callbacks should execute");
    }
}
1402