1use crate::config::{BatchConfig, WheelConfig};
2use crate::task::{TaskCompletionReason, TaskId, TaskLocation, TimerTask};
3use rustc_hash::FxHashMap;
4use std::time::Duration;
5
6struct WheelLayer {
10 slots: Vec<Vec<TimerTask>>,
14
15 current_tick: u64,
19
20 slot_count: usize,
24
25 tick_duration: Duration,
29
30 tick_duration_ms: u64,
34
35 slot_mask: usize,
39}
40
41impl WheelLayer {
42 fn new(slot_count: usize, tick_duration: Duration) -> Self {
46 let mut slots = Vec::with_capacity(slot_count);
47 for _ in 0..slot_count {
53 slots.push(Vec::with_capacity(4));
54 }
55
56 let tick_duration_ms = tick_duration.as_millis() as u64;
57 let slot_mask = slot_count - 1;
58
59 Self {
60 slots,
61 current_tick: 0,
62 slot_count,
63 tick_duration,
64 tick_duration_ms,
65 slot_mask,
66 }
67 }
68
69 fn delay_to_ticks(&self, delay: Duration) -> u64 {
73 let ticks = delay.as_millis() as u64 / self.tick_duration.as_millis() as u64;
74 ticks.max(1) }
76}
77
78pub struct Wheel {
82 l0: WheelLayer,
86
87 l1: WheelLayer,
91
92 l1_tick_ratio: u64,
96
97 task_index: FxHashMap<TaskId, TaskLocation>,
101
102 batch_config: BatchConfig,
106
107 l0_capacity_ms: u64,
111
112 l1_capacity_ticks: u64,
116}
117
118impl Wheel {
119 pub fn new(config: WheelConfig, batch_config: BatchConfig) -> Self {
137 let l0 = WheelLayer::new(config.l0_slot_count, config.l0_tick_duration);
138 let l1 = WheelLayer::new(config.l1_slot_count, config.l1_tick_duration);
139
140 let l1_tick_ratio = l1.tick_duration_ms / l0.tick_duration_ms;
143
144 let l0_capacity_ms = (l0.slot_count as u64) * l0.tick_duration_ms;
147 let l1_capacity_ticks = l1.slot_count as u64;
148
149 Self {
150 l0,
151 l1,
152 l1_tick_ratio,
153 task_index: FxHashMap::default(),
154 batch_config,
155 l0_capacity_ms,
156 l1_capacity_ticks,
157 }
158 }
159
160 #[allow(dead_code)]
164 pub fn current_tick(&self) -> u64 {
165 self.l0.current_tick
166 }
167
168 #[allow(dead_code)]
172 pub fn tick_duration(&self) -> Duration {
173 self.l0.tick_duration
174 }
175
176 #[allow(dead_code)]
180 pub fn slot_count(&self) -> usize {
181 self.l0.slot_count
182 }
183
184 #[allow(dead_code)]
188 pub fn delay_to_ticks(&self, delay: Duration) -> u64 {
189 self.l0.delay_to_ticks(delay)
190 }
191
192 #[inline(always)]
208 fn determine_layer(&self, delay: Duration) -> (u8, u64, u32) {
209 let delay_ms = delay.as_millis() as u64;
210
211 if delay_ms < self.l0_capacity_ms {
214 let l0_ticks = (delay_ms / self.l0.tick_duration_ms).max(1);
215 return (0, l0_ticks, 0);
216 }
217
218 let l1_ticks = (delay_ms / self.l1.tick_duration_ms).max(1);
221
222 if l1_ticks < self.l1_capacity_ticks {
223 (1, l1_ticks, 0)
224 } else {
225 let rounds = (l1_ticks / self.l1_capacity_ticks) as u32;
226 (1, l1_ticks, rounds)
227 }
228 }
229
230 #[inline]
260 pub fn insert(&mut self, mut task: TimerTask, notifier: crate::task::CompletionNotifier) -> TaskId {
261 let (level, ticks, rounds) = self.determine_layer(task.delay);
262
263 let (current_tick, slot_mask, slots) = match level {
266 0 => (self.l0.current_tick, self.l0.slot_mask, &mut self.l0.slots),
267 _ => (self.l1.current_tick, self.l1.slot_mask, &mut self.l1.slots),
268 };
269
270 let total_ticks = current_tick + ticks;
271 let slot_index = (total_ticks as usize) & slot_mask;
272
273 task.prepare_for_registration(notifier, total_ticks, rounds);
276
277 let task_id = task.id;
278
279 let vec_index = slots[slot_index].len();
282 let location = TaskLocation::new(level, slot_index, vec_index);
283
284 slots[slot_index].push(task);
287
288 self.task_index.insert(task_id, location);
291
292 task_id
293 }
294
295 #[inline]
319 pub fn insert_batch(&mut self, tasks: Vec<(TimerTask, crate::task::CompletionNotifier)>) -> Vec<TaskId> {
320 let task_count = tasks.len();
321
322 self.task_index.reserve(task_count);
325
326 let mut task_ids = Vec::with_capacity(task_count);
327
328 for (mut task, notifier) in tasks {
329 let (level, ticks, rounds) = self.determine_layer(task.delay);
330
331 let (current_tick, slot_mask, slots) = match level {
334 0 => (self.l0.current_tick, self.l0.slot_mask, &mut self.l0.slots),
335 _ => (self.l1.current_tick, self.l1.slot_mask, &mut self.l1.slots),
336 };
337
338 let total_ticks = current_tick + ticks;
339 let slot_index = (total_ticks as usize) & slot_mask;
340
341 task.prepare_for_registration(notifier, total_ticks, rounds);
344
345 let task_id = task.id;
346
347 let vec_index = slots[slot_index].len();
350 let location = TaskLocation::new(level, slot_index, vec_index);
351
352 slots[slot_index].push(task);
355
356 self.task_index.insert(task_id, location);
359
360 task_ids.push(task_id);
361 }
362
363 task_ids
364 }
365
366 #[inline]
382 pub fn cancel(&mut self, task_id: TaskId) -> bool {
383 let location = match self.task_index.remove(&task_id) {
386 Some(loc) => loc,
387 None => return false,
388 };
389
390 let slot = match location.level {
393 0 => &mut self.l0.slots[location.slot_index],
394 _ => &mut self.l1.slots[location.slot_index],
395 };
396
397 if location.vec_index >= slot.len() || slot[location.vec_index].id != task_id {
400 self.task_index.insert(task_id, location);
403 return false;
404 }
405
406 if let Some(notifier) = slot[location.vec_index].completion_notifier.take() {
409 let _ = notifier.0.send(TaskCompletionReason::Cancelled);
410 }
411
412 let removed_task = slot.swap_remove(location.vec_index);
415
416 if location.vec_index < slot.len() {
419 let swapped_task_id = slot[location.vec_index].id;
420 if let Some(swapped_location) = self.task_index.get_mut(&swapped_task_id) {
423 swapped_location.vec_index = location.vec_index;
424 }
425 }
426
427 debug_assert_eq!(removed_task.id, task_id);
430 true
431 }
432
433 #[inline]
461 pub fn cancel_batch(&mut self, task_ids: &[TaskId]) -> usize {
462 let mut cancelled_count = 0;
463
464 if task_ids.len() <= self.batch_config.small_batch_threshold {
467 for &task_id in task_ids {
468 if self.cancel(task_id) {
469 cancelled_count += 1;
470 }
471 }
472 return cancelled_count;
473 }
474
475 let l0_slot_count = self.l0.slot_count;
480 let l1_slot_count = self.l1.slot_count;
481
482 let mut l0_tasks_by_slot: Vec<Vec<(TaskId, usize)>> = vec![Vec::new(); l0_slot_count];
483 let mut l1_tasks_by_slot: Vec<Vec<(TaskId, usize)>> = vec![Vec::new(); l1_slot_count];
484
485 for &task_id in task_ids {
488 if let Some(location) = self.task_index.get(&task_id) {
489 if location.level == 0 {
490 l0_tasks_by_slot[location.slot_index].push((task_id, location.vec_index));
491 } else {
492 l1_tasks_by_slot[location.slot_index].push((task_id, location.vec_index));
493 }
494 }
495 }
496
497 for (slot_index, tasks) in l0_tasks_by_slot.iter_mut().enumerate() {
500 if tasks.is_empty() {
501 continue;
502 }
503
504 tasks.sort_unstable_by(|a, b| b.1.cmp(&a.1));
507
508 let slot = &mut self.l0.slots[slot_index];
509
510 for &(task_id, vec_index) in tasks.iter() {
511 if vec_index < slot.len() && slot[vec_index].id == task_id {
512 if let Some(notifier) = slot[vec_index].completion_notifier.take() {
513 let _ = notifier.0.send(TaskCompletionReason::Cancelled);
514 }
515
516 slot.swap_remove(vec_index);
517
518 if vec_index < slot.len() {
519 let swapped_task_id = slot[vec_index].id;
520 if let Some(swapped_location) = self.task_index.get_mut(&swapped_task_id) {
521 swapped_location.vec_index = vec_index;
522 }
523 }
524
525 self.task_index.remove(&task_id);
526 cancelled_count += 1;
527 }
528 }
529 }
530
531 for (slot_index, tasks) in l1_tasks_by_slot.iter_mut().enumerate() {
534 if tasks.is_empty() {
535 continue;
536 }
537
538 tasks.sort_unstable_by(|a, b| b.1.cmp(&a.1));
539
540 let slot = &mut self.l1.slots[slot_index];
541
542 for &(task_id, vec_index) in tasks.iter() {
543 if vec_index < slot.len() && slot[vec_index].id == task_id {
544 if let Some(notifier) = slot[vec_index].completion_notifier.take() {
545 let _ = notifier.0.send(TaskCompletionReason::Cancelled);
546 }
547
548 slot.swap_remove(vec_index);
549
550 if vec_index < slot.len() {
551 let swapped_task_id = slot[vec_index].id;
552 if let Some(swapped_location) = self.task_index.get_mut(&swapped_task_id) {
553 swapped_location.vec_index = vec_index;
554 }
555 }
556
557 self.task_index.remove(&task_id);
558 cancelled_count += 1;
559 }
560 }
561 }
562
563 cancelled_count
564 }
565
566 pub fn advance(&mut self) -> Vec<TimerTask> {
586 self.l0.current_tick += 1;
589
590 let mut expired_tasks = Vec::new();
591
592 let l0_slot_index = (self.l0.current_tick as usize) & self.l0.slot_mask;
595 let l0_slot = &mut self.l0.slots[l0_slot_index];
596
597 let i = 0;
598 while i < l0_slot.len() {
599 let task = &l0_slot[i];
600
601 self.task_index.remove(&task.id);
604
605 let expired_task = l0_slot.swap_remove(i);
608
609 if i < l0_slot.len() {
612 let swapped_task_id = l0_slot[i].id;
613 if let Some(swapped_location) = self.task_index.get_mut(&swapped_task_id) {
614 swapped_location.vec_index = i;
615 }
616 }
617
618 expired_tasks.push(expired_task);
619 }
620
621 if self.l0.current_tick % self.l1_tick_ratio == 0 {
626 self.l1.current_tick += 1;
627 let l1_slot_index = (self.l1.current_tick as usize) & self.l1.slot_mask;
628 let l1_slot = &mut self.l1.slots[l1_slot_index];
629
630 let mut tasks_to_demote = Vec::new();
633 let mut i = 0;
634 while i < l1_slot.len() {
635 let task = &mut l1_slot[i];
636
637 if task.rounds > 0 {
638 task.rounds -= 1;
641 if let Some(location) = self.task_index.get_mut(&task.id) {
642 location.vec_index = i;
643 }
644 i += 1;
645 } else {
646 self.task_index.remove(&task.id);
649 let task_to_demote = l1_slot.swap_remove(i);
650
651 if i < l1_slot.len() {
652 let swapped_task_id = l1_slot[i].id;
653 if let Some(swapped_location) = self.task_index.get_mut(&swapped_task_id) {
654 swapped_location.vec_index = i;
655 }
656 }
657
658 tasks_to_demote.push(task_to_demote);
659 }
660 }
661
662 self.demote_tasks(tasks_to_demote);
665 }
666
667 expired_tasks
668 }
669
670 fn demote_tasks(&mut self, tasks: Vec<TimerTask>) {
678 for task in tasks {
679 let l1_tick_ratio = self.l1_tick_ratio;
684
685 let l1_deadline = task.deadline_tick;
688
689 let l0_deadline_tick = l1_deadline * l1_tick_ratio;
692 let l0_current_tick = self.l0.current_tick;
693
694 let remaining_l0_ticks = if l0_deadline_tick > l0_current_tick {
697 l0_deadline_tick - l0_current_tick
698 } else {
699 1 };
701
702 let target_l0_tick = l0_current_tick + remaining_l0_ticks;
705 let l0_slot_index = (target_l0_tick as usize) & self.l0.slot_mask;
706
707 let task_id = task.id;
708 let vec_index = self.l0.slots[l0_slot_index].len();
709 let location = TaskLocation::new(0, l0_slot_index, vec_index);
710
711 self.l0.slots[l0_slot_index].push(task);
714 self.task_index.insert(task_id, location);
715 }
716 }
717
718 #[allow(dead_code)]
722 pub fn is_empty(&self) -> bool {
723 self.task_index.is_empty()
724 }
725
726 #[inline]
762 pub fn postpone(
763 &mut self,
764 task_id: TaskId,
765 new_delay: Duration,
766 new_callback: Option<crate::task::CallbackWrapper>,
767 ) -> bool {
768 let old_location = match self.task_index.remove(&task_id) {
771 Some(loc) => loc,
772 None => return false,
773 };
774
775 let slot = match old_location.level {
778 0 => &mut self.l0.slots[old_location.slot_index],
779 _ => &mut self.l1.slots[old_location.slot_index],
780 };
781
782 if old_location.vec_index >= slot.len() || slot[old_location.vec_index].id != task_id {
785 self.task_index.insert(task_id, old_location);
788 return false;
789 }
790
791 let mut task = slot.swap_remove(old_location.vec_index);
794
795 if old_location.vec_index < slot.len() {
798 let swapped_task_id = slot[old_location.vec_index].id;
799 if let Some(swapped_location) = self.task_index.get_mut(&swapped_task_id) {
800 swapped_location.vec_index = old_location.vec_index;
801 }
802 }
803
804 task.delay = new_delay;
807 if let Some(callback) = new_callback {
808 task.callback = Some(callback);
809 }
810
811 let (new_level, ticks, new_rounds) = self.determine_layer(new_delay);
814
815 let (current_tick, slot_mask, slots) = match new_level {
818 0 => (self.l0.current_tick, self.l0.slot_mask, &mut self.l0.slots),
819 _ => (self.l1.current_tick, self.l1.slot_mask, &mut self.l1.slots),
820 };
821
822 let total_ticks = current_tick + ticks;
823 let new_slot_index = (total_ticks as usize) & slot_mask;
824
825 task.deadline_tick = total_ticks;
828 task.rounds = new_rounds;
829
830 let new_vec_index = slots[new_slot_index].len();
833 let new_location = TaskLocation::new(new_level, new_slot_index, new_vec_index);
834
835 slots[new_slot_index].push(task);
836 self.task_index.insert(task_id, new_location);
837
838 true
839 }
840
841 #[inline]
871 pub fn postpone_batch(
872 &mut self,
873 updates: Vec<(TaskId, Duration)>,
874 ) -> usize {
875 let mut postponed_count = 0;
876
877 for (task_id, new_delay) in updates {
878 if self.postpone(task_id, new_delay, None) {
879 postponed_count += 1;
880 }
881 }
882
883 postponed_count
884 }
885
886 pub fn postpone_batch_with_callbacks(
902 &mut self,
903 updates: Vec<(TaskId, Duration, Option<crate::task::CallbackWrapper>)>,
904 ) -> usize {
905 let mut postponed_count = 0;
906
907 for (task_id, new_delay, new_callback) in updates {
908 if self.postpone(task_id, new_delay, new_callback) {
909 postponed_count += 1;
910 }
911 }
912
913 postponed_count
914 }
915}
916
/// Unit tests for the two-level timing wheel.
///
/// Timing expectations in these tests assume the default configuration. As the
/// assertions themselves show, that is a 10 ms L0 tick (100 ms -> 10 ticks),
/// 512 L0 slots, 64 L1 slots and 100 L0 ticks per L1 tick.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::CallbackWrapper;

    // A freshly built wheel is empty and starts at tick 0.
    #[test]
    fn test_wheel_creation() {
        let wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());
        assert_eq!(wheel.slot_count(), 512);
        assert_eq!(wheel.current_tick(), 0);
        assert!(wheel.is_empty());
    }

    // Checks the default L1 geometry: 64 slots, 100 L0 ticks per L1 tick.
    #[test]
    fn test_hierarchical_wheel_creation() {
        let config = WheelConfig::default();

        let wheel = Wheel::new(config, BatchConfig::default());
        assert_eq!(wheel.slot_count(), 512);
        assert_eq!(wheel.current_tick(), 0);
        assert!(wheel.is_empty());
        assert_eq!(wheel.l1.slot_count, 64);
        assert_eq!(wheel.l1_tick_ratio, 100);
    }

    // The builder rejects a 15 ms L1 tick on top of a 10 ms L0 tick (presumably
    // because it is not a whole multiple) and accepts a 1 s L1 tick.
    #[test]
    fn test_hierarchical_config_validation() {
        let result = WheelConfig::builder()
            .l0_tick_duration(Duration::from_millis(10))
            .l0_slot_count(512)
            .l1_tick_duration(Duration::from_millis(15))
            .l1_slot_count(64)
            .build();

        assert!(result.is_err());

        let result = WheelConfig::builder()
            .l0_tick_duration(Duration::from_millis(10))
            .l0_slot_count(512)
            .l1_tick_duration(Duration::from_secs(1))
            .l1_slot_count(64)
            .build();

        assert!(result.is_ok());
    }

    // Short delays go to L0; long ones to L1, with rounds > 0 only beyond one
    // L1 rotation.
    #[test]
    fn test_layer_determination() {
        let config = WheelConfig::default();

        let wheel = Wheel::new(config, BatchConfig::default());

        let (level, _, rounds) = wheel.determine_layer(Duration::from_millis(100));
        assert_eq!(level, 0);
        assert_eq!(rounds, 0);

        let (level, _, rounds) = wheel.determine_layer(Duration::from_secs(10));
        assert_eq!(level, 1);
        assert_eq!(rounds, 0);

        let (level, _, rounds) = wheel.determine_layer(Duration::from_secs(120));
        assert_eq!(level, 1);
        assert!(rounds > 0);
    }

    // A 100 ms task lives in L0 and is returned by `advance` within 10 ticks.
    #[test]
    fn test_hierarchical_insert_and_advance() {
        use crate::task::{TimerTask, CompletionNotifier};

        let config = WheelConfig::default();

        let mut wheel = Wheel::new(config, BatchConfig::default());

        let callback = CallbackWrapper::new(|| async {});
        let (tx, _rx) = tokio::sync::oneshot::channel();
        let task = TimerTask::new(Duration::from_millis(100), Some(callback));
        let task_id = wheel.insert(task, CompletionNotifier(tx));

        let location = wheel.task_index.get(&task_id).unwrap();
        assert_eq!(location.level, 0);

        for _ in 0..10 {
            let expired = wheel.advance();
            if !expired.is_empty() {
                assert_eq!(expired.len(), 1);
                assert_eq!(expired[0].id, task_id);
                return;
            }
        }
        panic!("Task should have expired");
    }

    // 6000 ms exceeds L0's capacity (512 slots x 10 ms = 5120 ms), so the task
    // starts in L1 and must be seen in L0 before the 600th tick is passed.
    #[test]
    fn test_hierarchical_l1_to_l0_demotion() {
        use crate::task::{TimerTask, CompletionNotifier};

        let config = WheelConfig::builder()
            .l0_tick_duration(Duration::from_millis(10))
            .l0_slot_count(512)
            .l1_tick_duration(Duration::from_millis(100))
            .l1_slot_count(64)
            .build()
            .unwrap();

        let mut wheel = Wheel::new(config, BatchConfig::default());
        let l1_tick_ratio = wheel.l1_tick_ratio;
        assert_eq!(l1_tick_ratio, 10);
        let callback = CallbackWrapper::new(|| async {});
        let (tx, _rx) = tokio::sync::oneshot::channel();
        let task = TimerTask::new(Duration::from_millis(6000), Some(callback));
        let task_id = wheel.insert(task, CompletionNotifier(tx));

        let location = wheel.task_index.get(&task_id).unwrap();
        assert_eq!(location.level, 1);

        let mut demoted = false;
        for i in 0..610 {
            wheel.advance();

            // Once the task has been moved down, its index entry reports level 0.
            if let Some(location) = wheel.task_index.get(&task_id) {
                if location.level == 0 && !demoted {
                    demoted = true;
                    println!("Task demoted to L0 at L0 tick {}", i);
                }
            }
        }

        assert!(demoted, "Task should have been demoted from L1 to L0");
    }

    // Cancel works for tasks in either layer and clears their index entries.
    #[test]
    fn test_cross_layer_cancel() {
        use crate::task::{TimerTask, CompletionNotifier};

        let config = WheelConfig::default();

        let mut wheel = Wheel::new(config, BatchConfig::default());

        let callback1 = CallbackWrapper::new(|| async {});
        let (tx1, _rx1) = tokio::sync::oneshot::channel();
        let task1 = TimerTask::new(Duration::from_millis(100), Some(callback1));
        let task_id1 = wheel.insert(task1, CompletionNotifier(tx1));

        let callback2 = CallbackWrapper::new(|| async {});
        let (tx2, _rx2) = tokio::sync::oneshot::channel();
        let task2 = TimerTask::new(Duration::from_secs(10), Some(callback2));
        let task_id2 = wheel.insert(task2, CompletionNotifier(tx2));

        assert_eq!(wheel.task_index.get(&task_id1).unwrap().level, 0);
        assert_eq!(wheel.task_index.get(&task_id2).unwrap().level, 1);

        assert!(wheel.cancel(task_id1));
        assert!(wheel.task_index.get(&task_id1).is_none());

        assert!(wheel.cancel(task_id2));
        assert!(wheel.task_index.get(&task_id2).is_none());

        assert!(wheel.is_empty());
    }

    // Postponing moves a task between layers according to the new delay.
    #[test]
    fn test_cross_layer_postpone() {
        use crate::task::{TimerTask, CompletionNotifier};

        let config = WheelConfig::default();

        let mut wheel = Wheel::new(config, BatchConfig::default());

        let callback = CallbackWrapper::new(|| async {});
        let (tx, _rx) = tokio::sync::oneshot::channel();
        let task = TimerTask::new(Duration::from_millis(100), Some(callback));
        let task_id = wheel.insert(task, CompletionNotifier(tx));

        assert_eq!(wheel.task_index.get(&task_id).unwrap().level, 0);

        assert!(wheel.postpone(task_id, Duration::from_secs(10), None));

        assert_eq!(wheel.task_index.get(&task_id).unwrap().level, 1);

        assert!(wheel.postpone(task_id, Duration::from_millis(200), None));

        assert_eq!(wheel.task_index.get(&task_id).unwrap().level, 0);
    }

    // Delays are rounded down to whole 10 ms ticks, with a minimum of one tick.
    #[test]
    fn test_delay_to_ticks() {
        let wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());
        assert_eq!(wheel.delay_to_ticks(Duration::from_millis(100)), 10);
        assert_eq!(wheel.delay_to_ticks(Duration::from_millis(50)), 5);
        assert_eq!(wheel.delay_to_ticks(Duration::from_millis(1)), 1);
    }

    // A non-power-of-two L0 slot count is rejected with a specific error.
    #[test]
    fn test_wheel_invalid_slot_count() {
        let result = WheelConfig::builder()
            .l0_slot_count(100)
            .build();
        assert!(result.is_err());
        if let Err(crate::error::TimerError::InvalidSlotCount { slot_count, reason }) = result {
            assert_eq!(slot_count, 100);
            assert_eq!(reason, "L0 layer slot count must be power of 2");
        } else {
            panic!("Expected InvalidSlotCount error");
        }
    }

    // Batch insert returns one id per task.
    #[test]
    fn test_insert_batch() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let tasks: Vec<(TimerTask, CompletionNotifier)> = (0..10)
            .map(|i| {
                let callback = CallbackWrapper::new(|| async {});
                let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
                let notifier = CompletionNotifier(completion_tx);
                let task = TimerTask::new(Duration::from_millis(100 + i * 10), Some(callback));
                (task, notifier)
            })
            .collect();

        let task_ids = wheel.insert_batch(tasks);

        assert_eq!(task_ids.len(), 10);
        assert!(!wheel.is_empty());
    }

    // Batch cancel counts only tasks that were still pending; a repeat cancel
    // of the same ids returns 0.
    #[test]
    fn test_cancel_batch() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let mut task_ids = Vec::new();
        for i in 0..10 {
            let callback = CallbackWrapper::new(|| async {});
            let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
            let notifier = CompletionNotifier(completion_tx);
            let task = TimerTask::new(Duration::from_millis(100 + i * 10), Some(callback));
            let task_id = wheel.insert(task, notifier);
            task_ids.push(task_id);
        }

        assert_eq!(task_ids.len(), 10);

        let to_cancel = &task_ids[0..5];
        let cancelled_count = wheel.cancel_batch(to_cancel);

        assert_eq!(cancelled_count, 5);

        let cancelled_again = wheel.cancel_batch(to_cancel);
        assert_eq!(cancelled_again, 0);

        let remaining = &task_ids[5..10];
        let cancelled_remaining = wheel.cancel_batch(remaining);
        assert_eq!(cancelled_remaining, 5);

        assert!(wheel.is_empty());
    }

    // 20 tasks with identical delays share one slot; cancelling them all in one
    // batch exercises the swap_remove bookkeeping within a single slot.
    #[test]
    fn test_batch_operations_same_slot() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let mut task_ids = Vec::new();
        for _ in 0..20 {
            let callback = CallbackWrapper::new(|| async {});
            let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
            let notifier = CompletionNotifier(completion_tx);
            let task = TimerTask::new(Duration::from_millis(100), Some(callback));
            let task_id = wheel.insert(task, notifier);
            task_ids.push(task_id);
        }

        let cancelled_count = wheel.cancel_batch(&task_ids);
        assert_eq!(cancelled_count, 20);
        assert!(wheel.is_empty());
    }

    // Postponing 100 ms -> 200 ms: nothing fires in the first 10 ticks, and the
    // task fires within the next 10.
    #[test]
    fn test_postpone_single_task() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let callback = CallbackWrapper::new(|| async {});
        let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
        let notifier = CompletionNotifier(completion_tx);
        let task = TimerTask::new(Duration::from_millis(100), Some(callback));
        let task_id = wheel.insert(task, notifier);

        let postponed = wheel.postpone(task_id, Duration::from_millis(200), None);
        assert!(postponed);

        assert!(!wheel.is_empty());

        for _ in 0..10 {
            let expired = wheel.advance();
            assert!(expired.is_empty());
        }

        let mut triggered = false;
        for _ in 0..10 {
            let expired = wheel.advance();
            if !expired.is_empty() {
                assert_eq!(expired.len(), 1);
                assert_eq!(expired[0].id, task_id);
                triggered = true;
                break;
            }
        }
        assert!(triggered);
    }

    // Postponing with a new callback, here to a shorter 50 ms delay, still fires
    // the same task id.
    #[test]
    fn test_postpone_with_new_callback() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let old_callback = CallbackWrapper::new(|| async {});
        let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
        let notifier = CompletionNotifier(completion_tx);
        let task = TimerTask::new(Duration::from_millis(100), Some(old_callback.clone()));
        let task_id = wheel.insert(task, notifier);

        let new_callback = CallbackWrapper::new(|| async {});
        let postponed = wheel.postpone(task_id, Duration::from_millis(50), Some(new_callback));
        assert!(postponed);

        let mut triggered = false;
        for i in 0..5 {
            let expired = wheel.advance();
            if !expired.is_empty() {
                assert_eq!(expired.len(), 1, "On the {}th advance, there should be 1 task triggered", i + 1);
                assert_eq!(expired[0].id, task_id);
                triggered = true;
                break;
            }
        }
        assert!(triggered, "Task should be triggered within 5 ticks");
    }

    // Postponing an id that was never inserted reports failure.
    #[test]
    fn test_postpone_nonexistent_task() {
        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let fake_task_id = TaskId::new();
        let postponed = wheel.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed);
    }

    // Five 50 ms tasks postponed to 150 ms: none fire in ticks 1-5, and all
    // fire by tick 15.
    #[test]
    fn test_postpone_batch() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let mut task_ids = Vec::new();
        for _ in 0..5 {
            let callback = CallbackWrapper::new(|| async {});
            let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
            let notifier = CompletionNotifier(completion_tx);
            let task = TimerTask::new(Duration::from_millis(50), Some(callback));
            let task_id = wheel.insert(task, notifier);
            task_ids.push(task_id);
        }

        let updates: Vec<_> = task_ids
            .iter()
            .map(|&id| (id, Duration::from_millis(150)))
            .collect();
        let postponed_count = wheel.postpone_batch(updates);
        assert_eq!(postponed_count, 5);

        for _ in 0..5 {
            let expired = wheel.advance();
            assert!(expired.is_empty(), "The first 5 ticks should not have tasks triggered");
        }

        let mut total_triggered = 0;
        for _ in 0..10 {
            let expired = wheel.advance();
            total_triggered += expired.len();
        }
        assert_eq!(total_triggered, 5, "There should be 5 tasks triggered on the 15th tick");
    }

    // A batch containing an unknown id postpones only the known ones. The
    // untouched tasks still fire at 50 ms, the postponed ones at 150 ms.
    #[test]
    fn test_postpone_batch_partial() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let mut task_ids = Vec::new();
        for _ in 0..10 {
            let callback = CallbackWrapper::new(|| async {});
            let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
            let notifier = CompletionNotifier(completion_tx);
            let task = TimerTask::new(Duration::from_millis(50), Some(callback));
            let task_id = wheel.insert(task, notifier);
            task_ids.push(task_id);
        }

        let fake_task_id = TaskId::new();
        let mut updates: Vec<_> = task_ids[0..5]
            .iter()
            .map(|&id| (id, Duration::from_millis(150)))
            .collect();
        updates.push((fake_task_id, Duration::from_millis(150)));

        let postponed_count = wheel.postpone_batch(updates);
        assert_eq!(postponed_count, 5, "There should be 5 tasks successfully postponed (fake_task_id failed)");
        let mut triggered_at_50ms = 0;
        for _ in 0..5 {
            let expired = wheel.advance();
            triggered_at_50ms += expired.len();
        }
        assert_eq!(triggered_at_50ms, 5, "There should be 5 tasks that were not postponed triggered on the 5th tick");
        let mut triggered_at_150ms = 0;
        for _ in 0..10 {
            let expired = wheel.advance();
            triggered_at_150ms += expired.len();
        }
        assert_eq!(triggered_at_150ms, 5, "There should be 5 tasks that were postponed triggered on the 15th tick");
    }

    // 120 s exceeds one L1 rotation (64 slots x 1 s), so the task needs
    // multiple rounds. After 6400 L0 ticks (one L1 rotation) it must not have
    // left L1 early, and it must still fire eventually.
    #[test]
    fn test_multi_round_tasks() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let callback = CallbackWrapper::new(|| async {});
        let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
        let notifier = CompletionNotifier(completion_tx);
        let task = TimerTask::new(Duration::from_secs(120), Some(callback));
        let task_id = wheel.insert(task, notifier);

        let location = wheel.task_index.get(&task_id).unwrap();
        assert_eq!(location.level, 1);

        for _ in 0..6400 {
            let _expired = wheel.advance();
        }

        let location = wheel.task_index.get(&task_id);
        if let Some(loc) = location {
            assert_eq!(loc.level, 1);
        }

        let mut triggered = false;
        for _ in 0..6000 {
            let expired = wheel.advance();
            if expired.iter().any(|t| t.id == task_id) {
                triggered = true;
                break;
            }
        }
        assert!(triggered, "Task should be triggered in the second round of L1");
    }

    // A delay shorter than one tick still waits exactly one tick.
    #[test]
    fn test_minimum_delay() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let callback = CallbackWrapper::new(|| async {});
        let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
        let notifier = CompletionNotifier(completion_tx);
        let task = TimerTask::new(Duration::from_millis(1), Some(callback));
        let task_id: TaskId = wheel.insert(task, notifier);

        let expired = wheel.advance();
        assert_eq!(expired.len(), 1, "Minimum delay task should be triggered after 1 tick");
        assert_eq!(expired[0].id, task_id);
    }

    // All batch operations accept empty input.
    #[test]
    fn test_empty_batch_operations() {
        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let task_ids = wheel.insert_batch(vec![]);
        assert_eq!(task_ids.len(), 0);

        let cancelled = wheel.cancel_batch(&[]);
        assert_eq!(cancelled, 0);

        let postponed = wheel.postpone_batch(vec![]);
        assert_eq!(postponed, 0);
    }

    // Only the last postpone (50 ms) determines when the task fires.
    #[test]
    fn test_postpone_same_task_multiple_times() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let callback = CallbackWrapper::new(|| async {});
        let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
        let notifier = CompletionNotifier(completion_tx);
        let task = TimerTask::new(Duration::from_millis(100), Some(callback));
        let task_id = wheel.insert(task, notifier);

        let postponed = wheel.postpone(task_id, Duration::from_millis(200), None);
        assert!(postponed, "First postpone should succeed");

        let postponed = wheel.postpone(task_id, Duration::from_millis(300), None);
        assert!(postponed, "Second postpone should succeed");

        let postponed = wheel.postpone(task_id, Duration::from_millis(50), None);
        assert!(postponed, "Third postpone should succeed");

        let mut triggered = false;
        for _ in 0..5 {
            let expired = wheel.advance();
            if !expired.is_empty() {
                assert_eq!(expired.len(), 1);
                assert_eq!(expired[0].id, task_id);
                triggered = true;
                break;
            }
        }
        assert!(triggered, "Task should be triggered at the last postpone time");
    }

    // Advancing an empty wheel returns nothing and counts ticks. 100 ticks also
    // cross one L1 tick boundary with the default ratio.
    #[test]
    fn test_advance_empty_slots() {
        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        for _ in 0..100 {
            let expired = wheel.advance();
            assert!(expired.is_empty(), "Empty slots should not return any tasks");
        }

        assert_eq!(wheel.current_tick(), 100, "current_tick should correctly increment");
    }

    // A task can be cancelled after it was postponed, and then never fires.
    #[test]
    fn test_cancel_after_postpone() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let callback = CallbackWrapper::new(|| async {});
        let (completion_tx, _completion_rx) = tokio::sync::oneshot::channel();
        let notifier = CompletionNotifier(completion_tx);
        let task = TimerTask::new(Duration::from_millis(100), Some(callback));
        let task_id = wheel.insert(task, notifier);

        let postponed = wheel.postpone(task_id, Duration::from_millis(200), None);
        assert!(postponed, "Postpone should succeed");

        let cancelled = wheel.cancel(task_id);
        assert!(cancelled, "Cancel should succeed");

        for _ in 0..20 {
            let expired = wheel.advance();
            assert!(expired.is_empty(), "Cancelled task should not trigger");
        }

        assert!(wheel.is_empty(), "Wheel should be empty");
    }

    // 5110 ms is 511 L0 ticks, one short of the 512-slot rotation. The task must
    // stay in L0 and fire on tick 511, not wrap around early.
    #[test]
    fn test_slot_boundary() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let callback1 = CallbackWrapper::new(|| async {});
        let (tx1, _rx1) = tokio::sync::oneshot::channel();
        let task1 = TimerTask::new(Duration::from_millis(10), Some(callback1));
        let task_id_1 = wheel.insert(task1, CompletionNotifier(tx1));

        let callback2 = CallbackWrapper::new(|| async {});
        let (tx2, _rx2) = tokio::sync::oneshot::channel();
        let task2 = TimerTask::new(Duration::from_millis(5110), Some(callback2));
        let task_id_2 = wheel.insert(task2, CompletionNotifier(tx2));

        let expired = wheel.advance();
        assert_eq!(expired.len(), 1, "First task should trigger on tick 1");
        assert_eq!(expired[0].id, task_id_1);

        let mut triggered = false;
        for i in 0..510 {
            let expired = wheel.advance();
            if !expired.is_empty() {
                assert_eq!(expired.len(), 1, "The {}th advance should trigger the second task", i + 2);
                assert_eq!(expired[0].id, task_id_2);
                triggered = true;
                break;
            }
        }
        assert!(triggered, "Second task should trigger on tick 511");
        assert!(wheel.is_empty(), "All tasks should have been triggered");
    }

    // With a threshold of 5, a 3-id batch takes the one-by-one path and a 7-id
    // batch the grouped path. Both must cancel every id.
    #[test]
    fn test_batch_cancel_small_threshold() {
        use crate::task::{TimerTask, CompletionNotifier};

        let batch_config = BatchConfig {
            small_batch_threshold: 5,
        };
        let mut wheel = Wheel::new(WheelConfig::default(), batch_config);

        let mut task_ids = Vec::new();
        for _ in 0..10 {
            let callback = CallbackWrapper::new(|| async {});
            let (tx, _rx) = tokio::sync::oneshot::channel();
            let task = TimerTask::new(Duration::from_millis(100), Some(callback));
            let task_id = wheel.insert(task, CompletionNotifier(tx));
            task_ids.push(task_id);
        }

        let cancelled = wheel.cancel_batch(&task_ids[0..3]);
        assert_eq!(cancelled, 3);

        let cancelled = wheel.cancel_batch(&task_ids[3..10]);
        assert_eq!(cancelled, 7);

        assert!(wheel.is_empty());
    }

    // Every inserted task gets a distinct id.
    #[test]
    fn test_task_id_uniqueness() {
        use crate::task::{TimerTask, CompletionNotifier};

        let mut wheel = Wheel::new(WheelConfig::default(), BatchConfig::default());

        let mut task_ids = std::collections::HashSet::new();
        for _ in 0..100 {
            let callback = CallbackWrapper::new(|| async {});
            let (tx, _rx) = tokio::sync::oneshot::channel();
            let task = TimerTask::new(Duration::from_millis(100), Some(callback));
            let task_id = wheel.insert(task, CompletionNotifier(tx));

            assert!(task_ids.insert(task_id), "TaskId should be unique");
        }

        assert_eq!(task_ids.len(), 100);
    }
}
1660