kestrel_protocol_timer/timer.rs
1use crate::config::{ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, CompletionNotifier, TaskId, TimerCallback, TimerTask};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
10/// 完成通知接收器,用于接收定时器完成通知
11pub struct CompletionReceiver(pub oneshot::Receiver<()>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14///
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18 pub(crate) task_id: TaskId,
19 pub(crate) wheel: Arc<Mutex<Wheel>>,
20 pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24 pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<()>) -> Self {
25 Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26 }
27
28 /// 取消定时器
29 ///
30 /// # 返回
31 /// 如果任务存在且成功取消返回 true,否则返回 false
32 ///
33 /// # 示例
34 /// ```no_run
35 /// # use kestrel_protocol_timer::TimerWheel;
36 /// # use std::time::Duration;
37 /// # #[tokio::main]
38 /// # async fn main() {
39 /// let timer = TimerWheel::with_defaults();
40 /// let handle = timer.schedule_once(Duration::from_secs(1), || async {}).await;
41 ///
42 /// // 取消定时器
43 /// let success = handle.cancel();
44 /// println!("取消成功: {}", success);
45 /// # }
46 /// ```
47 pub fn cancel(&self) -> bool {
48 let mut wheel = self.wheel.lock();
49 wheel.cancel(self.task_id)
50 }
51
52 /// 获取任务 ID
53 pub fn task_id(&self) -> TaskId {
54 self.task_id
55 }
56
57 /// 获取完成通知接收器的可变引用
58 ///
59 /// # 示例
60 /// ```no_run
61 /// # use kestrel_protocol_timer::TimerWheel;
62 /// # use std::time::Duration;
63 /// # #[tokio::main]
64 /// # async fn main() {
65 /// let timer = TimerWheel::with_defaults();
66 /// let handle = timer.schedule_once(Duration::from_secs(1), || async {
67 /// println!("Timer fired!");
68 /// }).await;
69 ///
70 /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
71 /// handle.into_completion_receiver().0.await.ok();
72 /// println!("Timer completed!");
73 /// # }
74 /// ```
75 pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
76 &mut self.completion_rx
77 }
78
79 /// 消耗句柄,返回完成通知接收器
80 ///
81 /// # 示例
82 /// ```no_run
83 /// # use kestrel_protocol_timer::TimerWheel;
84 /// # use std::time::Duration;
85 /// # #[tokio::main]
86 /// # async fn main() {
87 /// let timer = TimerWheel::with_defaults();
88 /// let handle = timer.schedule_once(Duration::from_secs(1), || async {
89 /// println!("Timer fired!");
90 /// }).await;
91 ///
92 /// // 等待定时器完成
93 /// handle.into_completion_receiver().0.await.ok();
94 /// println!("Timer completed!");
95 /// # }
96 /// ```
97 pub fn into_completion_receiver(self) -> CompletionReceiver {
98 self.completion_rx
99 }
100}
101
102/// 批量定时器句柄,用于管理批量调度的定时器
103///
104/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
105///
106/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
107/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
108pub struct BatchHandle {
109 pub(crate) task_ids: Vec<TaskId>,
110 pub(crate) wheel: Arc<Mutex<Wheel>>,
111 pub(crate) completion_rxs: Vec<oneshot::Receiver<()>>,
112}
113
114impl BatchHandle {
115 pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<()>>) -> Self {
116 Self { task_ids, wheel, completion_rxs }
117 }
118
119 /// 批量取消所有定时器
120 ///
121 /// # 返回
122 /// 成功取消的任务数量
123 ///
124 /// # 示例
125 /// ```no_run
126 /// # use kestrel_protocol_timer::TimerWheel;
127 /// # use std::time::Duration;
128 /// # #[tokio::main]
129 /// # async fn main() {
130 /// let timer = TimerWheel::with_defaults();
131 /// let callbacks: Vec<_> = (0..10)
132 /// .map(|_| (Duration::from_secs(1), || async {}))
133 /// .collect();
134 /// let batch = timer.schedule_once_batch(callbacks).await;
135 ///
136 /// let cancelled = batch.cancel_all();
137 /// println!("取消了 {} 个定时器", cancelled);
138 /// # }
139 /// ```
140 pub fn cancel_all(self) -> usize {
141 let mut wheel = self.wheel.lock();
142 wheel.cancel_batch(&self.task_ids)
143 }
144
145 /// 将批量句柄转换为单个定时器句柄的 Vec
146 ///
147 /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
148 ///
149 /// # 示例
150 /// ```no_run
151 /// # use kestrel_protocol_timer::TimerWheel;
152 /// # use std::time::Duration;
153 /// # #[tokio::main]
154 /// # async fn main() {
155 /// let timer = TimerWheel::with_defaults();
156 /// let callbacks: Vec<_> = (0..3)
157 /// .map(|_| (Duration::from_secs(1), || async {}))
158 /// .collect();
159 /// let batch = timer.schedule_once_batch(callbacks).await;
160 ///
161 /// // 转换为独立的句柄
162 /// let handles = batch.into_handles();
163 /// for handle in handles {
164 /// // 可以单独操作每个句柄
165 /// }
166 /// # }
167 /// ```
168 pub fn into_handles(self) -> Vec<TimerHandle> {
169 self.task_ids
170 .into_iter()
171 .zip(self.completion_rxs.into_iter())
172 .map(|(task_id, rx)| {
173 TimerHandle::new(task_id, self.wheel.clone(), rx)
174 })
175 .collect()
176 }
177
178 /// 获取批量任务的数量
179 pub fn len(&self) -> usize {
180 self.task_ids.len()
181 }
182
183 /// 检查批量任务是否为空
184 pub fn is_empty(&self) -> bool {
185 self.task_ids.is_empty()
186 }
187
188 /// 获取所有任务 ID 的引用
189 pub fn task_ids(&self) -> &[TaskId] {
190 &self.task_ids
191 }
192
193 /// 获取所有完成通知接收器的引用
194 ///
195 /// # 返回
196 /// 所有任务的完成通知接收器列表引用
197 pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<()>> {
198 &mut self.completion_rxs
199 }
200
201 /// 消耗句柄,返回所有完成通知接收器
202 ///
203 /// # 返回
204 /// 所有任务的完成通知接收器列表
205 ///
206 /// # 示例
207 /// ```no_run
208 /// # use kestrel_protocol_timer::TimerWheel;
209 /// # use std::time::Duration;
210 /// # #[tokio::main]
211 /// # async fn main() {
212 /// let timer = TimerWheel::with_defaults();
213 /// let callbacks: Vec<_> = (0..3)
214 /// .map(|_| (Duration::from_secs(1), || async {}))
215 /// .collect();
216 /// let batch = timer.schedule_once_batch(callbacks).await;
217 ///
218 /// // 获取所有完成通知接收器
219 /// let receivers = batch.into_completion_receivers();
220 /// for rx in receivers {
221 /// tokio::spawn(async move {
222 /// if rx.await.is_ok() {
223 /// println!("A timer completed!");
224 /// }
225 /// });
226 /// }
227 /// # }
228 /// ```
229 pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<()>> {
230 self.completion_rxs
231 }
232}
233
234/// 实现 IntoIterator,允许直接迭代 BatchHandle
235///
236/// # 示例
237/// ```no_run
238/// # use kestrel_protocol_timer::TimerWheel;
239/// # use std::time::Duration;
240/// # #[tokio::main]
241/// # async fn main() {
242/// let timer = TimerWheel::with_defaults();
243/// let callbacks: Vec<_> = (0..3)
244/// .map(|_| (Duration::from_secs(1), || async {}))
245/// .collect();
246/// let batch = timer.schedule_once_batch(callbacks).await;
247///
248/// // 直接迭代,每个元素都是独立的 TimerHandle
249/// for handle in batch {
250/// // 可以单独操作每个句柄
251/// }
252/// # }
253/// ```
254impl IntoIterator for BatchHandle {
255 type Item = TimerHandle;
256 type IntoIter = BatchHandleIter;
257
258 fn into_iter(self) -> Self::IntoIter {
259 BatchHandleIter {
260 task_ids: self.task_ids.into_iter(),
261 completion_rxs: self.completion_rxs.into_iter(),
262 wheel: self.wheel,
263 }
264 }
265}
266
267/// BatchHandle 的迭代器
268pub struct BatchHandleIter {
269 task_ids: std::vec::IntoIter<TaskId>,
270 completion_rxs: std::vec::IntoIter<oneshot::Receiver<()>>,
271 wheel: Arc<Mutex<Wheel>>,
272}
273
274impl Iterator for BatchHandleIter {
275 type Item = TimerHandle;
276
277 fn next(&mut self) -> Option<Self::Item> {
278 match (self.task_ids.next(), self.completion_rxs.next()) {
279 (Some(task_id), Some(rx)) => {
280 Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
281 }
282 _ => None,
283 }
284 }
285
286 fn size_hint(&self) -> (usize, Option<usize>) {
287 self.task_ids.size_hint()
288 }
289}
290
291impl ExactSizeIterator for BatchHandleIter {
292 fn len(&self) -> usize {
293 self.task_ids.len()
294 }
295}
296
297/// 时间轮定时器管理器
298pub struct TimerWheel {
299 /// 时间轮唯一标识符
300
301 /// 时间轮实例(使用 Arc<Mutex> 包装以支持多线程访问)
302 wheel: Arc<Mutex<Wheel>>,
303
304 /// 后台 tick 循环任务句柄
305 tick_handle: Option<JoinHandle<()>>,
306}
307
308impl TimerWheel {
309 /// 创建新的定时器管理器
310 ///
311 /// # 参数
312 /// - `config`: 时间轮配置(已经过验证)
313 ///
314 /// # 示例
315 /// ```no_run
316 /// use kestrel_protocol_timer::{TimerWheel, WheelConfig};
317 /// use std::time::Duration;
318 ///
319 /// #[tokio::main]
320 /// async fn main() {
321 /// let config = WheelConfig::builder()
322 /// .tick_duration(Duration::from_millis(10))
323 /// .slot_count(512)
324 /// .build()
325 /// .unwrap();
326 /// let timer = TimerWheel::new(config);
327 /// }
328 /// ```
329 pub fn new(config: WheelConfig) -> Self {
330 let tick_duration = config.tick_duration;
331 let wheel = Wheel::new(config);
332 let wheel = Arc::new(Mutex::new(wheel));
333 let wheel_clone = wheel.clone();
334
335 // 启动后台 tick 循环
336 let tick_handle = tokio::spawn(async move {
337 Self::tick_loop(wheel_clone, tick_duration).await;
338 });
339
340 Self {
341 wheel,
342 tick_handle: Some(tick_handle),
343 }
344 }
345
346 /// 创建带默认配置的定时器管理器
347 /// - tick 时长: 10ms
348 /// - 槽位数量: 512
349 ///
350 /// # 示例
351 /// ```no_run
352 /// use kestrel_protocol_timer::TimerWheel;
353 ///
354 /// #[tokio::main]
355 /// async fn main() {
356 /// let timer = TimerWheel::with_defaults();
357 /// }
358 /// ```
359 pub fn with_defaults() -> Self {
360 Self::new(WheelConfig::default())
361 }
362
363 /// 创建与此时间轮绑定的 TimerService(使用默认配置)
364 ///
365 /// # 返回
366 /// 绑定到此时间轮的 TimerService 实例
367 ///
368 /// # 示例
369 /// ```no_run
370 /// use kestrel_protocol_timer::TimerWheel;
371 /// use std::time::Duration;
372 ///
373 /// #[tokio::main]
374 /// async fn main() {
375 /// let timer = TimerWheel::with_defaults();
376 /// let mut service = timer.create_service();
377 ///
378 /// // 直接通过 service 批量调度定时器
379 /// let callbacks: Vec<_> = (0..5)
380 /// .map(|_| (Duration::from_millis(100), || async {}))
381 /// .collect();
382 /// service.schedule_once_batch(callbacks).await;
383 ///
384 /// // 接收超时通知
385 /// let mut rx = service.take_receiver().unwrap();
386 /// while let Some(task_id) = rx.recv().await {
387 /// println!("Task {:?} completed", task_id);
388 /// }
389 /// }
390 /// ```
391 pub fn create_service(&self) -> crate::service::TimerService {
392 crate::service::TimerService::new(self.wheel.clone(), ServiceConfig::default())
393 }
394
395 /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
396 ///
397 /// # 参数
398 /// - `config`: 服务配置
399 ///
400 /// # 返回
401 /// 绑定到此时间轮的 TimerService 实例
402 ///
403 /// # 示例
404 /// ```no_run
405 /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
406 ///
407 /// #[tokio::main]
408 /// async fn main() {
409 /// let timer = TimerWheel::with_defaults();
410 /// let config = ServiceConfig::builder()
411 /// .command_channel_capacity(1024)
412 /// .timeout_channel_capacity(2000)
413 /// .build()
414 /// .unwrap();
415 /// let service = timer.create_service_with_config(config);
416 /// }
417 /// ```
418 pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
419 crate::service::TimerService::new(self.wheel.clone(), config)
420 }
421
422 /// 内部辅助方法:创建定时器句柄
423 ///
424 /// 由 TimerWheel 和 TimerService 共用
425 pub(crate) fn create_timer_handle_internal(
426 wheel: &Arc<Mutex<Wheel>>,
427 delay: Duration,
428 callback: Option<CallbackWrapper>,
429 ) -> TimerHandle {
430 let (completion_tx, completion_rx) = oneshot::channel();
431 let notifier = CompletionNotifier(completion_tx);
432
433 let task = TimerTask::once(0, 0, callback, notifier);
434
435 let task_id = {
436 let mut wheel_guard = wheel.lock();
437 wheel_guard.insert(delay, task)
438 };
439
440 TimerHandle::new(task_id, wheel.clone(), completion_rx)
441 }
442
443 /// 内部辅助方法:创建批量定时器句柄
444 ///
445 /// 由 TimerWheel 和 TimerService 共用
446 pub(crate) fn create_batch_handle_internal<C>(
447 wheel: &Arc<Mutex<Wheel>>,
448 callbacks: Vec<(Duration, C)>,
449 ) -> BatchHandle
450 where
451 C: TimerCallback,
452 {
453 use std::sync::Arc;
454 let mut completion_rxs = Vec::with_capacity(callbacks.len());
455
456 let tasks: Vec<(Duration, TimerTask)> = callbacks
457 .into_iter()
458 .map(|(delay, callback)| {
459 let callback_wrapper = Arc::new(callback) as CallbackWrapper;
460 let (completion_tx, completion_rx) = oneshot::channel();
461 completion_rxs.push(completion_rx);
462 let notifier = CompletionNotifier(completion_tx);
463 let task = TimerTask::once(0, 0, Some(callback_wrapper), notifier);
464 (delay, task)
465 })
466 .collect();
467
468 let task_ids = {
469 let mut wheel_guard = wheel.lock();
470 wheel_guard.insert_batch(tasks)
471 };
472
473 BatchHandle::new(task_ids, wheel.clone(), completion_rxs)
474 }
475
476 /// 调度一次性定时器
477 ///
478 /// # 参数
479 /// - `delay`: 延迟时间
480 /// - `callback`: 实现了 TimerCallback trait 的回调对象
481 ///
482 /// # 返回
483 /// 返回定时器句柄,可用于取消定时器
484 ///
485 /// # 示例
486 /// ```no_run
487 /// use kestrel_protocol_timer::TimerWheel;
488 /// use std::time::Duration;
489 /// use std::sync::Arc;
490 ///
491 /// #[tokio::main]
492 /// async fn main() {
493 /// let timer = TimerWheel::with_defaults();
494 ///
495 /// let handle = timer.schedule_once(Duration::from_secs(1), || async {
496 /// println!("Timer fired!");
497 /// }).await;
498 ///
499 /// tokio::time::sleep(Duration::from_secs(2)).await;
500 /// }
501 /// ```
502 pub async fn schedule_once<C>(&self, delay: Duration, callback: C) -> TimerHandle
503 where
504 C: TimerCallback,
505 {
506 use std::sync::Arc;
507 let callback_wrapper = Arc::new(callback) as CallbackWrapper;
508 Self::create_timer_handle_internal(&self.wheel, delay, Some(callback_wrapper))
509 }
510
511 /// 批量调度一次性定时器
512 ///
513 /// # 参数
514 /// - `tasks`: (延迟时间, 回调) 的元组列表
515 ///
516 /// # 返回
517 /// 返回批量定时器句柄
518 ///
519 /// # 性能优势
520 /// - 批量处理减少锁竞争
521 /// - 内部优化批量插入操作
522 /// - 共享 Wheel 引用减少内存开销
523 ///
524 /// # 示例
525 /// ```no_run
526 /// use kestrel_protocol_timer::TimerWheel;
527 /// use std::time::Duration;
528 /// use std::sync::Arc;
529 /// use std::sync::atomic::{AtomicU32, Ordering};
530 ///
531 /// #[tokio::main]
532 /// async fn main() {
533 /// let timer = TimerWheel::with_defaults();
534 /// let counter = Arc::new(AtomicU32::new(0));
535 ///
536 /// // 动态生成批量回调
537 /// let callbacks: Vec<(Duration, _)> = (0..3)
538 /// .map(|i| {
539 /// let counter = Arc::clone(&counter);
540 /// let delay = Duration::from_millis(100 + i * 100);
541 /// let callback = move || {
542 /// let counter = Arc::clone(&counter);
543 /// async move {
544 /// counter.fetch_add(1, Ordering::SeqCst);
545 /// }
546 /// };
547 /// (delay, callback)
548 /// })
549 /// .collect();
550 ///
551 /// let batch = timer.schedule_once_batch(callbacks).await;
552 /// println!("Scheduled {} timers", batch.len());
553 ///
554 /// // 批量取消所有定时器
555 /// let cancelled = batch.cancel_all();
556 /// println!("Cancelled {} timers", cancelled);
557 /// }
558 /// ```
559 pub async fn schedule_once_batch<C>(&self, callbacks: Vec<(Duration, C)>) -> BatchHandle
560 where
561 C: TimerCallback,
562 {
563 Self::create_batch_handle_internal(&self.wheel, callbacks)
564 }
565
566
567 /// 调度一次性通知定时器(无回调,仅通知)
568 ///
569 /// # 参数
570 /// - `delay`: 延迟时间
571 ///
572 /// # 返回
573 /// 返回定时器句柄,可通过 `into_completion_receiver()` 获取通知接收器
574 ///
575 /// # 示例
576 /// ```no_run
577 /// use kestrel_protocol_timer::TimerWheel;
578 /// use std::time::Duration;
579 ///
580 /// #[tokio::main]
581 /// async fn main() {
582 /// let timer = TimerWheel::with_defaults();
583 ///
584 /// let handle = timer.schedule_once_notify(Duration::from_secs(1)).await;
585 ///
586 /// // 获取完成通知接收器
587 /// handle.into_completion_receiver().0.await.ok();
588 /// println!("Timer completed!");
589 /// }
590 /// ```
591 pub async fn schedule_once_notify(&self, delay: Duration) -> TimerHandle {
592 Self::create_timer_handle_internal(&self.wheel, delay, None)
593 }
594
595 /// 取消定时器
596 ///
597 /// # 参数
598 /// - `task_id`: 任务 ID
599 ///
600 /// # 返回
601 /// 如果任务存在且成功取消返回 true,否则返回 false
602 pub fn cancel(&self, task_id: TaskId) -> bool {
603 let mut wheel = self.wheel.lock();
604 wheel.cancel(task_id)
605 }
606
607 /// 批量取消定时器
608 ///
609 /// # 参数
610 /// - `task_ids`: 要取消的任务 ID 列表
611 ///
612 /// # 返回
613 /// 成功取消的任务数量
614 ///
615 /// # 性能优势
616 /// - 批量处理减少锁竞争
617 /// - 内部优化批量取消操作
618 ///
619 /// # 示例
620 /// ```no_run
621 /// use kestrel_protocol_timer::TimerWheel;
622 /// use std::time::Duration;
623 ///
624 /// #[tokio::main]
625 /// async fn main() {
626 /// let timer = TimerWheel::with_defaults();
627 ///
628 /// // 创建多个定时器
629 /// let handle1 = timer.schedule_once(Duration::from_secs(10), || async {}).await;
630 /// let handle2 = timer.schedule_once(Duration::from_secs(10), || async {}).await;
631 /// let handle3 = timer.schedule_once(Duration::from_secs(10), || async {}).await;
632 ///
633 /// // 批量取消
634 /// let task_ids = vec![handle1.task_id(), handle2.task_id(), handle3.task_id()];
635 /// let cancelled = timer.cancel_batch(&task_ids);
636 /// println!("已取消 {} 个定时器", cancelled);
637 /// }
638 /// ```
639 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
640 let mut wheel = self.wheel.lock();
641 wheel.cancel_batch(task_ids)
642 }
643
644 /// 核心 tick 循环
645 async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
646 let mut interval = tokio::time::interval(tick_duration);
647 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
648
649 loop {
650 interval.tick().await;
651
652 // 推进时间轮并获取到期任务
653 let expired_tasks = {
654 let mut wheel_guard = wheel.lock();
655 wheel_guard.advance()
656 };
657
658 // 执行到期任务
659 for task in expired_tasks {
660 let callback = task.get_callback();
661
662 // 移动task的所有权来获取completion_notifier
663 let notifier = task.completion_notifier;
664
665 // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
666 if let Some(callback) = callback {
667 tokio::spawn(async move {
668 // 执行回调
669 let future = callback.call();
670 future.await;
671
672 // 回调执行完成后发送通知
673 let _ = notifier.0.send(());
674 });
675 } else {
676 // 如果没有回调,立即发送完成通知
677 let _ = notifier.0.send(());
678 }
679 }
680 }
681 }
682
683 /// 停止定时器管理器
684 pub async fn shutdown(mut self) {
685 if let Some(handle) = self.tick_handle.take() {
686 handle.abort();
687 let _ = handle.await;
688 }
689 }
690}
691
692impl Drop for TimerWheel {
693 fn drop(&mut self) {
694 if let Some(handle) = self.tick_handle.take() {
695 handle.abort();
696 }
697 }
698}
699
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    #[tokio::test]
    async fn test_schedule_once() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let _handle = timer
            .schedule_once(Duration::from_millis(50), move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .await;

        // Wait for the timer to fire
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let handle = timer
            .schedule_once(Duration::from_millis(100), move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .await;

        // Cancel right away
        let cancel_result = handle.cancel();
        assert!(cancel_result);

        // Wait long enough to be sure the timer does not fire
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_cancel_immediate() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        // A timer due within a single tick, cancelled before the test yields.
        let handle = timer
            .schedule_once(Duration::from_millis(10), move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .await;

        assert!(handle.cancel());
        // The task no longer exists, so a second cancel must report failure.
        assert!(!handle.cancel());

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_schedule_once_notify() {
        let timer = TimerWheel::with_defaults();
        let handle = timer.schedule_once_notify(Duration::from_millis(20)).await;

        // A notify-only timer must signal completion once it fires.
        let result = tokio::time::timeout(
            Duration::from_secs(1),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(matches!(result, Ok(Ok(()))));
    }

    #[tokio::test]
    async fn test_batch_cancel_all() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| {
                let counter = Arc::clone(&counter);
                let callback = move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                };
                (Duration::from_millis(100), callback)
            })
            .collect();

        let batch = timer.schedule_once_batch(callbacks).await;
        assert_eq!(batch.len(), 3);
        assert_eq!(batch.cancel_all(), 3);

        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_batch_into_iter() {
        let timer = TimerWheel::with_defaults();
        let callbacks: Vec<_> = (0..3)
            .map(|_| (Duration::from_secs(10), || async {}))
            .collect();
        let batch = timer.schedule_once_batch(callbacks).await;

        let iter = batch.into_iter();
        assert_eq!(iter.len(), 3);

        // Each per-timer handle cancels its own task exactly once.
        let handles: Vec<TimerHandle> = iter.collect();
        assert_eq!(handles.len(), 3);
        assert!(handles.iter().all(|handle| handle.cancel()));
        assert!(handles.iter().all(|handle| !handle.cancel()));
    }
}
784