kestrel_protocol_timer/timer.rs
1use crate::config::{ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TimerCallback};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
10/// 完成通知接收器,用于接收定时器完成通知
11pub struct CompletionReceiver(pub oneshot::Receiver<()>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14///
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18 pub(crate) task_id: TaskId,
19 pub(crate) wheel: Arc<Mutex<Wheel>>,
20 pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24 pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<()>) -> Self {
25 Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26 }
27
28 /// 取消定时器
29 ///
30 /// # 返回
31 /// 如果任务存在且成功取消返回 true,否则返回 false
32 ///
33 /// # 示例
34 /// ```no_run
35 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
36 /// # use std::time::Duration;
37 /// # #[tokio::main]
38 /// # async fn main() {
39 /// let timer = TimerWheel::with_defaults();
40 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {});
41 /// let handle = timer.register(task);
42 ///
43 /// // 取消定时器
44 /// let success = handle.cancel();
45 /// println!("取消成功: {}", success);
46 /// # }
47 /// ```
48 pub fn cancel(&self) -> bool {
49 let mut wheel = self.wheel.lock();
50 wheel.cancel(self.task_id)
51 }
52
53 /// 获取完成通知接收器的可变引用
54 ///
55 /// # 示例
56 /// ```no_run
57 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
58 /// # use std::time::Duration;
59 /// # #[tokio::main]
60 /// # async fn main() {
61 /// let timer = TimerWheel::with_defaults();
62 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
63 /// println!("Timer fired!");
64 /// });
65 /// let handle = timer.register(task);
66 ///
67 /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
68 /// handle.into_completion_receiver().0.await.ok();
69 /// println!("Timer completed!");
70 /// # }
71 /// ```
72 pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
73 &mut self.completion_rx
74 }
75
76 /// 消耗句柄,返回完成通知接收器
77 ///
78 /// # 示例
79 /// ```no_run
80 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
81 /// # use std::time::Duration;
82 /// # #[tokio::main]
83 /// # async fn main() {
84 /// let timer = TimerWheel::with_defaults();
85 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
86 /// println!("Timer fired!");
87 /// });
88 /// let handle = timer.register(task);
89 ///
90 /// // 等待定时器完成
91 /// handle.into_completion_receiver().0.await.ok();
92 /// println!("Timer completed!");
93 /// # }
94 /// ```
95 pub fn into_completion_receiver(self) -> CompletionReceiver {
96 self.completion_rx
97 }
98}
99
100/// 批量定时器句柄,用于管理批量调度的定时器
101///
102/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
103///
104/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
105/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
106pub struct BatchHandle {
107 pub(crate) task_ids: Vec<TaskId>,
108 pub(crate) wheel: Arc<Mutex<Wheel>>,
109 pub(crate) completion_rxs: Vec<oneshot::Receiver<()>>,
110}
111
112impl BatchHandle {
113 pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<()>>) -> Self {
114 Self { task_ids, wheel, completion_rxs }
115 }
116
117 /// 批量取消所有定时器
118 ///
119 /// # 返回
120 /// 成功取消的任务数量
121 ///
122 /// # 示例
123 /// ```no_run
124 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
125 /// # use std::time::Duration;
126 /// # #[tokio::main]
127 /// # async fn main() {
128 /// let timer = TimerWheel::with_defaults();
129 /// let tasks = TimerWheel::create_batch(
130 /// (0..10).map(|_| (Duration::from_secs(1), || async {})).collect()
131 /// );
132 /// let batch = timer.register_batch(tasks);
133 ///
134 /// let cancelled = batch.cancel_all();
135 /// println!("取消了 {} 个定时器", cancelled);
136 /// # }
137 /// ```
138 pub fn cancel_all(self) -> usize {
139 let mut wheel = self.wheel.lock();
140 wheel.cancel_batch(&self.task_ids)
141 }
142
143 /// 将批量句柄转换为单个定时器句柄的 Vec
144 ///
145 /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
146 ///
147 /// # 示例
148 /// ```no_run
149 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
150 /// # use std::time::Duration;
151 /// # #[tokio::main]
152 /// # async fn main() {
153 /// let timer = TimerWheel::with_defaults();
154 /// let tasks = TimerWheel::create_batch(
155 /// (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
156 /// );
157 /// let batch = timer.register_batch(tasks);
158 ///
159 /// // 转换为独立的句柄
160 /// let handles = batch.into_handles();
161 /// for handle in handles {
162 /// // 可以单独操作每个句柄
163 /// }
164 /// # }
165 /// ```
166 pub fn into_handles(self) -> Vec<TimerHandle> {
167 self.task_ids
168 .into_iter()
169 .zip(self.completion_rxs.into_iter())
170 .map(|(task_id, rx)| {
171 TimerHandle::new(task_id, self.wheel.clone(), rx)
172 })
173 .collect()
174 }
175
176 /// 获取批量任务的数量
177 pub fn len(&self) -> usize {
178 self.task_ids.len()
179 }
180
181 /// 检查批量任务是否为空
182 pub fn is_empty(&self) -> bool {
183 self.task_ids.is_empty()
184 }
185
186 /// 获取所有任务 ID 的引用
187 pub fn task_ids(&self) -> &[TaskId] {
188 &self.task_ids
189 }
190
191 /// 获取所有完成通知接收器的引用
192 ///
193 /// # 返回
194 /// 所有任务的完成通知接收器列表引用
195 pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<()>> {
196 &mut self.completion_rxs
197 }
198
199 /// 消耗句柄,返回所有完成通知接收器
200 ///
201 /// # 返回
202 /// 所有任务的完成通知接收器列表
203 ///
204 /// # 示例
205 /// ```no_run
206 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
207 /// # use std::time::Duration;
208 /// # #[tokio::main]
209 /// # async fn main() {
210 /// let timer = TimerWheel::with_defaults();
211 /// let tasks = TimerWheel::create_batch(
212 /// (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
213 /// );
214 /// let batch = timer.register_batch(tasks);
215 ///
216 /// // 获取所有完成通知接收器
217 /// let receivers = batch.into_completion_receivers();
218 /// for rx in receivers {
219 /// tokio::spawn(async move {
220 /// if rx.await.is_ok() {
221 /// println!("A timer completed!");
222 /// }
223 /// });
224 /// }
225 /// # }
226 /// ```
227 pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<()>> {
228 self.completion_rxs
229 }
230}
231
232/// 实现 IntoIterator,允许直接迭代 BatchHandle
233///
234/// # 示例
235/// ```no_run
236/// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
237/// # use std::time::Duration;
238/// # #[tokio::main]
239/// # async fn main() {
240/// let timer = TimerWheel::with_defaults();
241/// let tasks = TimerWheel::create_batch(
242/// (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
243/// );
244/// let batch = timer.register_batch(tasks);
245///
246/// // 直接迭代,每个元素都是独立的 TimerHandle
247/// for handle in batch {
248/// // 可以单独操作每个句柄
249/// }
250/// # }
251/// ```
252impl IntoIterator for BatchHandle {
253 type Item = TimerHandle;
254 type IntoIter = BatchHandleIter;
255
256 fn into_iter(self) -> Self::IntoIter {
257 BatchHandleIter {
258 task_ids: self.task_ids.into_iter(),
259 completion_rxs: self.completion_rxs.into_iter(),
260 wheel: self.wheel,
261 }
262 }
263}
264
265/// BatchHandle 的迭代器
266pub struct BatchHandleIter {
267 task_ids: std::vec::IntoIter<TaskId>,
268 completion_rxs: std::vec::IntoIter<oneshot::Receiver<()>>,
269 wheel: Arc<Mutex<Wheel>>,
270}
271
272impl Iterator for BatchHandleIter {
273 type Item = TimerHandle;
274
275 fn next(&mut self) -> Option<Self::Item> {
276 match (self.task_ids.next(), self.completion_rxs.next()) {
277 (Some(task_id), Some(rx)) => {
278 Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
279 }
280 _ => None,
281 }
282 }
283
284 fn size_hint(&self) -> (usize, Option<usize>) {
285 self.task_ids.size_hint()
286 }
287}
288
289impl ExactSizeIterator for BatchHandleIter {
290 fn len(&self) -> usize {
291 self.task_ids.len()
292 }
293}
294
295/// 时间轮定时器管理器
296pub struct TimerWheel {
297 /// 时间轮唯一标识符
298
299 /// 时间轮实例(使用 Arc<Mutex> 包装以支持多线程访问)
300 wheel: Arc<Mutex<Wheel>>,
301
302 /// 后台 tick 循环任务句柄
303 tick_handle: Option<JoinHandle<()>>,
304}
305
306impl TimerWheel {
307 /// 创建新的定时器管理器
308 ///
309 /// # 参数
310 /// - `config`: 时间轮配置(已经过验证)
311 ///
312 /// # 示例
313 /// ```no_run
314 /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask};
315 /// use std::time::Duration;
316 ///
317 /// #[tokio::main]
318 /// async fn main() {
319 /// let config = WheelConfig::builder()
320 /// .tick_duration(Duration::from_millis(10))
321 /// .slot_count(512)
322 /// .build()
323 /// .unwrap();
324 /// let timer = TimerWheel::new(config);
325 ///
326 /// // 使用两步式 API
327 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {});
328 /// let handle = timer.register(task);
329 /// }
330 /// ```
331 pub fn new(config: WheelConfig) -> Self {
332 let tick_duration = config.tick_duration;
333 let wheel = Wheel::new(config);
334 let wheel = Arc::new(Mutex::new(wheel));
335 let wheel_clone = wheel.clone();
336
337 // 启动后台 tick 循环
338 let tick_handle = tokio::spawn(async move {
339 Self::tick_loop(wheel_clone, tick_duration).await;
340 });
341
342 Self {
343 wheel,
344 tick_handle: Some(tick_handle),
345 }
346 }
347
348 /// 创建带默认配置的定时器管理器
349 /// - tick 时长: 10ms
350 /// - 槽位数量: 512
351 ///
352 /// # 示例
353 /// ```no_run
354 /// use kestrel_protocol_timer::TimerWheel;
355 ///
356 /// #[tokio::main]
357 /// async fn main() {
358 /// let timer = TimerWheel::with_defaults();
359 /// }
360 /// ```
361 pub fn with_defaults() -> Self {
362 Self::new(WheelConfig::default())
363 }
364
365 /// 创建与此时间轮绑定的 TimerService(使用默认配置)
366 ///
367 /// # 返回
368 /// 绑定到此时间轮的 TimerService 实例
369 ///
370 /// # 示例
371 /// ```no_run
372 /// use kestrel_protocol_timer::{TimerWheel, TimerService};
373 /// use std::time::Duration;
374 ///
375 /// #[tokio::main]
376 /// async fn main() {
377 /// let timer = TimerWheel::with_defaults();
378 /// let mut service = timer.create_service();
379 ///
380 /// // 使用两步式 API 通过 service 批量调度定时器
381 /// let tasks = TimerService::create_batch(
382 /// (0..5).map(|_| (Duration::from_millis(100), || async {})).collect()
383 /// );
384 /// service.register_batch(tasks).await;
385 ///
386 /// // 接收超时通知
387 /// let mut rx = service.take_receiver().unwrap();
388 /// while let Some(task_id) = rx.recv().await {
389 /// println!("Task {:?} completed", task_id);
390 /// }
391 /// }
392 /// ```
393 pub fn create_service(&self) -> crate::service::TimerService {
394 crate::service::TimerService::new(self.wheel.clone(), ServiceConfig::default())
395 }
396
397 /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
398 ///
399 /// # 参数
400 /// - `config`: 服务配置
401 ///
402 /// # 返回
403 /// 绑定到此时间轮的 TimerService 实例
404 ///
405 /// # 示例
406 /// ```no_run
407 /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
408 ///
409 /// #[tokio::main]
410 /// async fn main() {
411 /// let timer = TimerWheel::with_defaults();
412 /// let config = ServiceConfig::builder()
413 /// .command_channel_capacity(1024)
414 /// .timeout_channel_capacity(2000)
415 /// .build()
416 /// .unwrap();
417 /// let service = timer.create_service_with_config(config);
418 /// }
419 /// ```
420 pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
421 crate::service::TimerService::new(self.wheel.clone(), config)
422 }
423
424 /// 创建定时器任务(申请阶段)
425 ///
426 /// # 参数
427 /// - `delay`: 延迟时间
428 /// - `callback`: 实现了 TimerCallback trait 的回调对象
429 ///
430 /// # 返回
431 /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
432 ///
433 /// # 示例
434 /// ```no_run
435 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
436 /// use std::time::Duration;
437 ///
438 /// #[tokio::main]
439 /// async fn main() {
440 /// let timer = TimerWheel::with_defaults();
441 ///
442 /// // 步骤 1: 创建任务
443 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
444 /// println!("Timer fired!");
445 /// });
446 ///
447 /// // 获取任务 ID
448 /// let task_id = task.get_id();
449 /// println!("Created task: {:?}", task_id);
450 ///
451 /// // 步骤 2: 注册任务
452 /// let handle = timer.register(task);
453 /// }
454 /// ```
455 pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
456 where
457 C: TimerCallback,
458 {
459 use std::sync::Arc;
460 let callback_wrapper = Arc::new(callback) as CallbackWrapper;
461 crate::task::TimerTask::new(delay, Some(callback_wrapper))
462 }
463
464 /// 批量创建定时器任务(申请阶段)
465 ///
466 /// # 参数
467 /// - `callbacks`: (延迟时间, 回调) 的元组列表
468 ///
469 /// # 返回
470 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
471 ///
472 /// # 示例
473 /// ```no_run
474 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
475 /// use std::time::Duration;
476 /// use std::sync::Arc;
477 /// use std::sync::atomic::{AtomicU32, Ordering};
478 ///
479 /// #[tokio::main]
480 /// async fn main() {
481 /// let timer = TimerWheel::with_defaults();
482 /// let counter = Arc::new(AtomicU32::new(0));
483 ///
484 /// // 步骤 1: 批量创建任务
485 /// let callbacks: Vec<_> = (0..3)
486 /// .map(|i| {
487 /// let counter = Arc::clone(&counter);
488 /// let delay = Duration::from_millis(100 + i * 100);
489 /// let callback = move || {
490 /// let counter = Arc::clone(&counter);
491 /// async move {
492 /// counter.fetch_add(1, Ordering::SeqCst);
493 /// }
494 /// };
495 /// (delay, callback)
496 /// })
497 /// .collect();
498 ///
499 /// let tasks = TimerWheel::create_batch(callbacks);
500 /// println!("Created {} tasks", tasks.len());
501 ///
502 /// // 步骤 2: 批量注册任务
503 /// let batch = timer.register_batch(tasks);
504 /// }
505 /// ```
506 pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
507 where
508 C: TimerCallback,
509 {
510 use std::sync::Arc;
511 callbacks
512 .into_iter()
513 .map(|(delay, callback)| {
514 let callback_wrapper = Arc::new(callback) as CallbackWrapper;
515 crate::task::TimerTask::new(delay, Some(callback_wrapper))
516 })
517 .collect()
518 }
519
520 /// 注册定时器任务到时间轮(注册阶段)
521 ///
522 /// # 参数
523 /// - `task`: 通过 `create_task()` 创建的任务
524 ///
525 /// # 返回
526 /// 返回定时器句柄,可用于取消定时器和接收完成通知
527 ///
528 /// # 示例
529 /// ```no_run
530 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
531 /// use std::time::Duration;
532 ///
533 /// #[tokio::main]
534 /// async fn main() {
535 /// let timer = TimerWheel::with_defaults();
536 ///
537 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
538 /// println!("Timer fired!");
539 /// });
540 /// let task_id = task.get_id();
541 ///
542 /// let handle = timer.register(task);
543 ///
544 /// // 等待定时器完成
545 /// handle.into_completion_receiver().0.await.ok();
546 /// }
547 /// ```
548 #[inline]
549 pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
550 let (completion_tx, completion_rx) = oneshot::channel();
551 let notifier = crate::task::CompletionNotifier(completion_tx);
552
553 let delay = task.delay;
554 let task_id = task.id;
555
556 // 单次加锁完成所有操作
557 {
558 let mut wheel_guard = self.wheel.lock();
559 wheel_guard.insert(delay, task, notifier);
560 }
561
562 TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
563 }
564
565 /// 批量注册定时器任务到时间轮(注册阶段)
566 ///
567 /// # 参数
568 /// - `tasks`: 通过 `create_batch()` 创建的任务列表
569 ///
570 /// # 返回
571 /// 返回批量定时器句柄
572 ///
573 /// # 示例
574 /// ```no_run
575 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
576 /// use std::time::Duration;
577 ///
578 /// #[tokio::main]
579 /// async fn main() {
580 /// let timer = TimerWheel::with_defaults();
581 ///
582 /// let callbacks: Vec<_> = (0..3)
583 /// .map(|_| (Duration::from_secs(1), || async {}))
584 /// .collect();
585 /// let tasks = TimerWheel::create_batch(callbacks);
586 ///
587 /// let batch = timer.register_batch(tasks);
588 /// println!("Registered {} timers", batch.len());
589 /// }
590 /// ```
591 #[inline]
592 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
593 let task_count = tasks.len();
594 let mut completion_rxs = Vec::with_capacity(task_count);
595 let mut task_ids = Vec::with_capacity(task_count);
596 let mut prepared_tasks = Vec::with_capacity(task_count);
597
598 // 步骤1: 准备所有 channels 和 notifiers(无锁)
599 // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
600 for task in tasks {
601 let (completion_tx, completion_rx) = oneshot::channel();
602 let notifier = crate::task::CompletionNotifier(completion_tx);
603
604 task_ids.push(task.id);
605 completion_rxs.push(completion_rx);
606 prepared_tasks.push((task.delay, task, notifier));
607 }
608
609 // 步骤2: 单次加锁,批量插入
610 {
611 let mut wheel_guard = self.wheel.lock();
612 wheel_guard.insert_batch(prepared_tasks);
613 }
614
615 BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
616 }
617
618 /// 取消定时器
619 ///
620 /// # 参数
621 /// - `task_id`: 任务 ID
622 ///
623 /// # 返回
624 /// 如果任务存在且成功取消返回 true,否则返回 false
625 ///
626 /// # 示例
627 /// ```no_run
628 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
629 /// use std::time::Duration;
630 ///
631 /// #[tokio::main]
632 /// async fn main() {
633 /// let timer = TimerWheel::with_defaults();
634 ///
635 /// let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
636 /// let task_id = task.get_id();
637 /// let _handle = timer.register(task);
638 ///
639 /// // 使用任务 ID 取消
640 /// let cancelled = timer.cancel(task_id);
641 /// println!("取消成功: {}", cancelled);
642 /// }
643 /// ```
644 #[inline]
645 pub fn cancel(&self, task_id: TaskId) -> bool {
646 let mut wheel = self.wheel.lock();
647 wheel.cancel(task_id)
648 }
649
650 /// 批量取消定时器
651 ///
652 /// # 参数
653 /// - `task_ids`: 要取消的任务 ID 列表
654 ///
655 /// # 返回
656 /// 成功取消的任务数量
657 ///
658 /// # 性能优势
659 /// - 批量处理减少锁竞争
660 /// - 内部优化批量取消操作
661 ///
662 /// # 示例
663 /// ```no_run
664 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
665 /// use std::time::Duration;
666 ///
667 /// #[tokio::main]
668 /// async fn main() {
669 /// let timer = TimerWheel::with_defaults();
670 ///
671 /// // 创建多个定时器
672 /// let task1 = TimerWheel::create_task(Duration::from_secs(10), || async {});
673 /// let task2 = TimerWheel::create_task(Duration::from_secs(10), || async {});
674 /// let task3 = TimerWheel::create_task(Duration::from_secs(10), || async {});
675 ///
676 /// let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
677 ///
678 /// let _h1 = timer.register(task1);
679 /// let _h2 = timer.register(task2);
680 /// let _h3 = timer.register(task3);
681 ///
682 /// // 批量取消
683 /// let cancelled = timer.cancel_batch(&task_ids);
684 /// println!("已取消 {} 个定时器", cancelled);
685 /// }
686 /// ```
687 #[inline]
688 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
689 let mut wheel = self.wheel.lock();
690 wheel.cancel_batch(task_ids)
691 }
692
693 /// 核心 tick 循环
694 async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
695 let mut interval = tokio::time::interval(tick_duration);
696 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
697
698 loop {
699 interval.tick().await;
700
701 // 推进时间轮并获取到期任务
702 let expired_tasks = {
703 let mut wheel_guard = wheel.lock();
704 wheel_guard.advance()
705 };
706
707 // 执行到期任务
708 for task in expired_tasks {
709 let callback = task.get_callback();
710
711 // 移动task的所有权来获取completion_notifier
712 let notifier = task.completion_notifier;
713
714 // 只有注册过的任务才有 notifier
715 if let Some(notifier) = notifier {
716 // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
717 if let Some(callback) = callback {
718 tokio::spawn(async move {
719 // 执行回调
720 let future = callback.call();
721 future.await;
722
723 // 回调执行完成后发送通知
724 let _ = notifier.0.send(());
725 });
726 } else {
727 // 如果没有回调,立即发送完成通知
728 let _ = notifier.0.send(());
729 }
730 }
731 }
732 }
733 }
734
735 /// 停止定时器管理器
736 pub async fn shutdown(mut self) {
737 if let Some(handle) = self.tick_handle.take() {
738 handle.abort();
739 let _ = handle.await;
740 }
741 }
742}
743
744impl Drop for TimerWheel {
745 fn drop(&mut self) {
746 if let Some(handle) = self.tick_handle.take() {
747 handle.abort();
748 }
749 }
750}
751
#[cfg(test)]
mod tests {
    // `super::*` also brings `Arc` and `Duration` into scope.
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    #[tokio::test]
    async fn test_schedule_once() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(Duration::from_millis(50), move || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
            }
        });
        let handle = timer.register(task);

        // Wait for the completion notification (sent after the callback has run)
        // instead of sleeping a fixed time. The timeout makes a broken timer fail fast.
        let completed = tokio::time::timeout(
            Duration::from_secs(1),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(matches!(completed, Ok(Ok(()))), "timer did not complete in time");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(Duration::from_millis(100), move || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
            }
        });
        let handle = timer.register(task);

        // Cancel through the handle right away
        assert!(handle.cancel());

        // Wait well past the deadline to make sure the timer never fires
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_cancel_by_task_id() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(Duration::from_millis(100), move || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
            }
        });
        let task_id = task.get_id();
        let _handle = timer.register(task);

        // Cancel through the manager using only the task ID
        assert!(timer.cancel(task_id));

        // Wait well past the deadline to make sure the timer never fires
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}
839