kestrel_protocol_timer/
timer.rs

1use crate::config::{ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TimerCallback};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
/// Receiving side of a timer's completion notification.
///
/// Thin newtype around a `oneshot::Receiver<()>`; the inner receiver is public, so
/// callers can await it directly through `.0`.
pub struct CompletionReceiver(pub oneshot::Receiver<()>);
12
/// Handle used to manage the lifecycle of a single timer.
///
/// Note: this type deliberately does not implement `Clone`, so that the same timer
/// cannot be cancelled through several copies of its handle.
/// Each timer should have exactly one owner.
pub struct TimerHandle {
    /// Identifier of the task this handle refers to.
    pub(crate) task_id: TaskId,
    /// Shared wheel the task lives in; locked whenever the handle cancels the task.
    pub(crate) wheel: Arc<Mutex<Wheel>>,
    /// Receiver for the task's completion notification.
    pub(crate) completion_rx: CompletionReceiver,
}
22
23impl TimerHandle {
24    pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<()>) -> Self {
25        Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26    }
27
28    /// 取消定时器
29    ///
30    /// # 返回
31    /// 如果任务存在且成功取消返回 true,否则返回 false
32    ///
33    /// # 示例
34    /// ```no_run
35    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
36    /// # use std::time::Duration;
37    /// # #[tokio::main]
38    /// # async fn main() {
39    /// let timer = TimerWheel::with_defaults();
40    /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {});
41    /// let handle = timer.register(task);
42    /// 
43    /// // 取消定时器
44    /// let success = handle.cancel();
45    /// println!("取消成功: {}", success);
46    /// # }
47    /// ```
48    pub fn cancel(&self) -> bool {
49        let mut wheel = self.wheel.lock();
50        wheel.cancel(self.task_id)
51    }
52
53    /// 获取完成通知接收器的可变引用
54    ///
55    /// # 示例
56    /// ```no_run
57    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
58    /// # use std::time::Duration;
59    /// # #[tokio::main]
60    /// # async fn main() {
61    /// let timer = TimerWheel::with_defaults();
62    /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
63    ///     println!("Timer fired!");
64    /// });
65    /// let handle = timer.register(task);
66    /// 
67    /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
68    /// handle.into_completion_receiver().0.await.ok();
69    /// println!("Timer completed!");
70    /// # }
71    /// ```
72    pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
73        &mut self.completion_rx
74    }
75
76    /// 消耗句柄,返回完成通知接收器
77    ///
78    /// # 示例
79    /// ```no_run
80    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
81    /// # use std::time::Duration;
82    /// # #[tokio::main]
83    /// # async fn main() {
84    /// let timer = TimerWheel::with_defaults();
85    /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
86    ///     println!("Timer fired!");
87    /// });
88    /// let handle = timer.register(task);
89    /// 
90    /// // 等待定时器完成
91    /// handle.into_completion_receiver().0.await.ok();
92    /// println!("Timer completed!");
93    /// # }
94    /// ```
95    pub fn into_completion_receiver(self) -> CompletionReceiver {
96        self.completion_rx
97    }
98}
99
/// Handle for a batch of timers that were scheduled together.
///
/// All timers in the batch share a single wheel reference, which reduces memory
/// overhead, while still offering batch operations and iterator access.
///
/// Note: this type deliberately does not implement `Clone`, so that the same batch
/// cannot be cancelled through several copies of its handle.
/// To work with individual timer handles, convert it with `into_iter()` or `into_handles()`.
pub struct BatchHandle {
    /// Identifiers of the tasks in this batch.
    pub(crate) task_ids: Vec<TaskId>,
    /// Shared wheel the tasks live in.
    pub(crate) wheel: Arc<Mutex<Wheel>>,
    /// Completion receivers, one per task.
    pub(crate) completion_rxs: Vec<oneshot::Receiver<()>>,
}
111
112impl BatchHandle {
113    pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<()>>) -> Self {
114        Self { task_ids, wheel, completion_rxs }
115    }
116
117    /// 批量取消所有定时器
118    ///
119    /// # 返回
120    /// 成功取消的任务数量
121    ///
122    /// # 示例
123    /// ```no_run
124    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
125    /// # use std::time::Duration;
126    /// # #[tokio::main]
127    /// # async fn main() {
128    /// let timer = TimerWheel::with_defaults();
129    /// let tasks = TimerWheel::create_batch(
130    ///     (0..10).map(|_| (Duration::from_secs(1), || async {})).collect()
131    /// );
132    /// let batch = timer.register_batch(tasks);
133    /// 
134    /// let cancelled = batch.cancel_all();
135    /// println!("取消了 {} 个定时器", cancelled);
136    /// # }
137    /// ```
138    pub fn cancel_all(self) -> usize {
139        let mut wheel = self.wheel.lock();
140        wheel.cancel_batch(&self.task_ids)
141    }
142
143    /// 将批量句柄转换为单个定时器句柄的 Vec
144    ///
145    /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
146    ///
147    /// # 示例
148    /// ```no_run
149    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
150    /// # use std::time::Duration;
151    /// # #[tokio::main]
152    /// # async fn main() {
153    /// let timer = TimerWheel::with_defaults();
154    /// let tasks = TimerWheel::create_batch(
155    ///     (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
156    /// );
157    /// let batch = timer.register_batch(tasks);
158    /// 
159    /// // 转换为独立的句柄
160    /// let handles = batch.into_handles();
161    /// for handle in handles {
162    ///     // 可以单独操作每个句柄
163    /// }
164    /// # }
165    /// ```
166    pub fn into_handles(self) -> Vec<TimerHandle> {
167        self.task_ids
168            .into_iter()
169            .zip(self.completion_rxs.into_iter())
170            .map(|(task_id, rx)| {
171                TimerHandle::new(task_id, self.wheel.clone(), rx)
172            })
173            .collect()
174    }
175
176    /// 获取批量任务的数量
177    pub fn len(&self) -> usize {
178        self.task_ids.len()
179    }
180
181    /// 检查批量任务是否为空
182    pub fn is_empty(&self) -> bool {
183        self.task_ids.is_empty()
184    }
185
186    /// 获取所有任务 ID 的引用
187    pub fn task_ids(&self) -> &[TaskId] {
188        &self.task_ids
189    }
190
191    /// 获取所有完成通知接收器的引用
192    ///
193    /// # 返回
194    /// 所有任务的完成通知接收器列表引用
195    pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<()>> {
196        &mut self.completion_rxs
197    }
198
199    /// 消耗句柄,返回所有完成通知接收器
200    ///
201    /// # 返回
202    /// 所有任务的完成通知接收器列表
203    ///
204    /// # 示例
205    /// ```no_run
206    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
207    /// # use std::time::Duration;
208    /// # #[tokio::main]
209    /// # async fn main() {
210    /// let timer = TimerWheel::with_defaults();
211    /// let tasks = TimerWheel::create_batch(
212    ///     (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
213    /// );
214    /// let batch = timer.register_batch(tasks);
215    /// 
216    /// // 获取所有完成通知接收器
217    /// let receivers = batch.into_completion_receivers();
218    /// for rx in receivers {
219    ///     tokio::spawn(async move {
220    ///         if rx.await.is_ok() {
221    ///             println!("A timer completed!");
222    ///         }
223    ///     });
224    /// }
225    /// # }
226    /// ```
227    pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<()>> {
228        self.completion_rxs
229    }
230}
231
232/// 实现 IntoIterator,允许直接迭代 BatchHandle
233/// 
234/// # 示例
235/// ```no_run
236/// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
237/// # use std::time::Duration;
238/// # #[tokio::main]
239/// # async fn main() {
240/// let timer = TimerWheel::with_defaults();
241/// let tasks = TimerWheel::create_batch(
242///     (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
243/// );
244/// let batch = timer.register_batch(tasks);
245/// 
246/// // 直接迭代,每个元素都是独立的 TimerHandle
247/// for handle in batch {
248///     // 可以单独操作每个句柄
249/// }
250/// # }
251/// ```
252impl IntoIterator for BatchHandle {
253    type Item = TimerHandle;
254    type IntoIter = BatchHandleIter;
255
256    fn into_iter(self) -> Self::IntoIter {
257        BatchHandleIter {
258            task_ids: self.task_ids.into_iter(),
259            completion_rxs: self.completion_rxs.into_iter(),
260            wheel: self.wheel,
261        }
262    }
263}
264
265/// BatchHandle 的迭代器
266pub struct BatchHandleIter {
267    task_ids: std::vec::IntoIter<TaskId>,
268    completion_rxs: std::vec::IntoIter<oneshot::Receiver<()>>,
269    wheel: Arc<Mutex<Wheel>>,
270}
271
272impl Iterator for BatchHandleIter {
273    type Item = TimerHandle;
274
275    fn next(&mut self) -> Option<Self::Item> {
276        match (self.task_ids.next(), self.completion_rxs.next()) {
277            (Some(task_id), Some(rx)) => {
278                Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
279            }
280            _ => None,
281        }
282    }
283
284    fn size_hint(&self) -> (usize, Option<usize>) {
285        self.task_ids.size_hint()
286    }
287}
288
289impl ExactSizeIterator for BatchHandleIter {
290    fn len(&self) -> usize {
291        self.task_ids.len()
292    }
293}
294
/// Timing-wheel based timer manager.
pub struct TimerWheel {
    /// The wheel instance (wrapped in `Arc<Mutex>` so it can be accessed from multiple threads).
    wheel: Arc<Mutex<Wheel>>,
    
    /// Join handle of the background tick-loop task.
    tick_handle: Option<JoinHandle<()>>,
}
305
306impl TimerWheel {
307    /// 创建新的定时器管理器
308    ///
309    /// # 参数
310    /// - `config`: 时间轮配置(已经过验证)
311    ///
312    /// # 示例
313    /// ```no_run
314    /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask};
315    /// use std::time::Duration;
316    ///
317    /// #[tokio::main]
318    /// async fn main() {
319    ///     let config = WheelConfig::builder()
320    ///         .tick_duration(Duration::from_millis(10))
321    ///         .slot_count(512)
322    ///         .build()
323    ///         .unwrap();
324    ///     let timer = TimerWheel::new(config);
325    ///     
326    ///     // 使用两步式 API
327    ///     let task = TimerWheel::create_task(Duration::from_secs(1), || async {});
328    ///     let handle = timer.register(task);
329    /// }
330    /// ```
331    pub fn new(config: WheelConfig) -> Self {
332        let tick_duration = config.tick_duration;
333        let wheel = Wheel::new(config);
334        let wheel = Arc::new(Mutex::new(wheel));
335        let wheel_clone = wheel.clone();
336
337        // 启动后台 tick 循环
338        let tick_handle = tokio::spawn(async move {
339            Self::tick_loop(wheel_clone, tick_duration).await;
340        });
341
342        Self {
343            wheel,
344            tick_handle: Some(tick_handle),
345        }
346    }
347
    /// Creates a timer manager with the default wheel configuration
    /// (`WheelConfig::default()`), which the original documentation lists as:
    /// - tick duration: 10ms
    /// - slot count: 512
    ///
    /// # Examples
    /// ```no_run
    /// use kestrel_protocol_timer::TimerWheel;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let timer = TimerWheel::with_defaults();
    /// }
    /// ```
    pub fn with_defaults() -> Self {
        Self::new(WheelConfig::default())
    }
364
365    /// 创建与此时间轮绑定的 TimerService(使用默认配置)
366    ///
367    /// # 返回
368    /// 绑定到此时间轮的 TimerService 实例
369    ///
370    /// # 示例
371    /// ```no_run
372    /// use kestrel_protocol_timer::{TimerWheel, TimerService};
373    /// use std::time::Duration;
374    ///
375    /// #[tokio::main]
376    /// async fn main() {
377    ///     let timer = TimerWheel::with_defaults();
378    ///     let mut service = timer.create_service();
379    ///     
380    ///     // 使用两步式 API 通过 service 批量调度定时器
381    ///     let tasks = TimerService::create_batch(
382    ///         (0..5).map(|_| (Duration::from_millis(100), || async {})).collect()
383    ///     );
384    ///     service.register_batch(tasks).await;
385    ///     
386    ///     // 接收超时通知
387    ///     let mut rx = service.take_receiver().unwrap();
388    ///     while let Some(task_id) = rx.recv().await {
389    ///         println!("Task {:?} completed", task_id);
390    ///     }
391    /// }
392    /// ```
393    pub fn create_service(&self) -> crate::service::TimerService {
394        crate::service::TimerService::new(self.wheel.clone(), ServiceConfig::default())
395    }
396    
    /// Creates a `TimerService` bound to this wheel, using a custom service configuration.
    ///
    /// # Arguments
    /// - `config`: service configuration
    ///
    /// # Returns
    /// A `TimerService` instance bound to this wheel.
    ///
    /// # Examples
    /// ```no_run
    /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let timer = TimerWheel::with_defaults();
    ///     let config = ServiceConfig::builder()
    ///         .command_channel_capacity(1024)
    ///         .timeout_channel_capacity(2000)
    ///         .build()
    ///         .unwrap();
    ///     let service = timer.create_service_with_config(config);
    /// }
    /// ```
    pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
        // The service receives its own clone of the shared wheel handle.
        crate::service::TimerService::new(self.wheel.clone(), config)
    }
423
424    /// 创建定时器任务(申请阶段)
425    /// 
426    /// # 参数
427    /// - `delay`: 延迟时间
428    /// - `callback`: 实现了 TimerCallback trait 的回调对象
429    /// 
430    /// # 返回
431    /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
432    /// 
433    /// # 示例
434    /// ```no_run
435    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
436    /// use std::time::Duration;
437    /// 
438    /// #[tokio::main]
439    /// async fn main() {
440    ///     let timer = TimerWheel::with_defaults();
441    ///     
442    ///     // 步骤 1: 创建任务
443    ///     let task = TimerWheel::create_task(Duration::from_secs(1), || async {
444    ///         println!("Timer fired!");
445    ///     });
446    ///     
447    ///     // 获取任务 ID
448    ///     let task_id = task.get_id();
449    ///     println!("Created task: {:?}", task_id);
450    ///     
451    ///     // 步骤 2: 注册任务
452    ///     let handle = timer.register(task);
453    /// }
454    /// ```
455    pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
456    where
457        C: TimerCallback,
458    {
459        use std::sync::Arc;
460        let callback_wrapper = Arc::new(callback) as CallbackWrapper;
461        crate::task::TimerTask::new(delay, Some(callback_wrapper))
462    }
463    
464    /// 批量创建定时器任务(申请阶段)
465    /// 
466    /// # 参数
467    /// - `callbacks`: (延迟时间, 回调) 的元组列表
468    /// 
469    /// # 返回
470    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
471    /// 
472    /// # 示例
473    /// ```no_run
474    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
475    /// use std::time::Duration;
476    /// use std::sync::Arc;
477    /// use std::sync::atomic::{AtomicU32, Ordering};
478    /// 
479    /// #[tokio::main]
480    /// async fn main() {
481    ///     let timer = TimerWheel::with_defaults();
482    ///     let counter = Arc::new(AtomicU32::new(0));
483    ///     
484    ///     // 步骤 1: 批量创建任务
485    ///     let callbacks: Vec<_> = (0..3)
486    ///         .map(|i| {
487    ///             let counter = Arc::clone(&counter);
488    ///             let delay = Duration::from_millis(100 + i * 100);
489    ///             let callback = move || {
490    ///                 let counter = Arc::clone(&counter);
491    ///                 async move {
492    ///                     counter.fetch_add(1, Ordering::SeqCst);
493    ///                 }
494    ///             };
495    ///             (delay, callback)
496    ///         })
497    ///         .collect();
498    ///     
499    ///     let tasks = TimerWheel::create_batch(callbacks);
500    ///     println!("Created {} tasks", tasks.len());
501    ///     
502    ///     // 步骤 2: 批量注册任务
503    ///     let batch = timer.register_batch(tasks);
504    /// }
505    /// ```
506    pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
507    where
508        C: TimerCallback,
509    {
510        use std::sync::Arc;
511        callbacks
512            .into_iter()
513            .map(|(delay, callback)| {
514                let callback_wrapper = Arc::new(callback) as CallbackWrapper;
515                crate::task::TimerTask::new(delay, Some(callback_wrapper))
516            })
517            .collect()
518    }
519    
520    /// 注册定时器任务到时间轮(注册阶段)
521    /// 
522    /// # 参数
523    /// - `task`: 通过 `create_task()` 创建的任务
524    /// 
525    /// # 返回
526    /// 返回定时器句柄,可用于取消定时器和接收完成通知
527    /// 
528    /// # 示例
529    /// ```no_run
530    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
531    /// use std::time::Duration;
532    /// 
533    /// #[tokio::main]
534    /// async fn main() {
535    ///     let timer = TimerWheel::with_defaults();
536    ///     
537    ///     let task = TimerWheel::create_task(Duration::from_secs(1), || async {
538    ///         println!("Timer fired!");
539    ///     });
540    ///     let task_id = task.get_id();
541    ///     
542    ///     let handle = timer.register(task);
543    ///     
544    ///     // 等待定时器完成
545    ///     handle.into_completion_receiver().0.await.ok();
546    /// }
547    /// ```
548    #[inline]
549    pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
550        let (completion_tx, completion_rx) = oneshot::channel();
551        let notifier = crate::task::CompletionNotifier(completion_tx);
552        
553        let delay = task.delay;
554        let task_id = task.id;
555        
556        // 单次加锁完成所有操作
557        {
558            let mut wheel_guard = self.wheel.lock();
559            wheel_guard.insert(delay, task, notifier);
560        }
561        
562        TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
563    }
564    
565    /// 批量注册定时器任务到时间轮(注册阶段)
566    /// 
567    /// # 参数
568    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
569    /// 
570    /// # 返回
571    /// 返回批量定时器句柄
572    /// 
573    /// # 示例
574    /// ```no_run
575    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
576    /// use std::time::Duration;
577    /// 
578    /// #[tokio::main]
579    /// async fn main() {
580    ///     let timer = TimerWheel::with_defaults();
581    ///     
582    ///     let callbacks: Vec<_> = (0..3)
583    ///         .map(|_| (Duration::from_secs(1), || async {}))
584    ///         .collect();
585    ///     let tasks = TimerWheel::create_batch(callbacks);
586    ///     
587    ///     let batch = timer.register_batch(tasks);
588    ///     println!("Registered {} timers", batch.len());
589    /// }
590    /// ```
591    #[inline]
592    pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
593        let task_count = tasks.len();
594        let mut completion_rxs = Vec::with_capacity(task_count);
595        let mut task_ids = Vec::with_capacity(task_count);
596        let mut prepared_tasks = Vec::with_capacity(task_count);
597        
598        // 步骤1: 准备所有 channels 和 notifiers(无锁)
599        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
600        for task in tasks {
601            let (completion_tx, completion_rx) = oneshot::channel();
602            let notifier = crate::task::CompletionNotifier(completion_tx);
603            
604            task_ids.push(task.id);
605            completion_rxs.push(completion_rx);
606            prepared_tasks.push((task.delay, task, notifier));
607        }
608        
609        // 步骤2: 单次加锁,批量插入
610        {
611            let mut wheel_guard = self.wheel.lock();
612            wheel_guard.insert_batch(prepared_tasks);
613        }
614        
615        BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
616    }
617
618    /// 取消定时器
619    ///
620    /// # 参数
621    /// - `task_id`: 任务 ID
622    ///
623    /// # 返回
624    /// 如果任务存在且成功取消返回 true,否则返回 false
625    /// 
626    /// # 示例
627    /// ```no_run
628    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
629    /// use std::time::Duration;
630    ///
631    /// #[tokio::main]
632    /// async fn main() {
633    ///     let timer = TimerWheel::with_defaults();
634    ///     
635    ///     let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
636    ///     let task_id = task.get_id();
637    ///     let _handle = timer.register(task);
638    ///     
639    ///     // 使用任务 ID 取消
640    ///     let cancelled = timer.cancel(task_id);
641    ///     println!("取消成功: {}", cancelled);
642    /// }
643    /// ```
644    #[inline]
645    pub fn cancel(&self, task_id: TaskId) -> bool {
646        let mut wheel = self.wheel.lock();
647        wheel.cancel(task_id)
648    }
649
650    /// 批量取消定时器
651    ///
652    /// # 参数
653    /// - `task_ids`: 要取消的任务 ID 列表
654    ///
655    /// # 返回
656    /// 成功取消的任务数量
657    ///
658    /// # 性能优势
659    /// - 批量处理减少锁竞争
660    /// - 内部优化批量取消操作
661    ///
662    /// # 示例
663    /// ```no_run
664    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
665    /// use std::time::Duration;
666    ///
667    /// #[tokio::main]
668    /// async fn main() {
669    ///     let timer = TimerWheel::with_defaults();
670    ///     
671    ///     // 创建多个定时器
672    ///     let task1 = TimerWheel::create_task(Duration::from_secs(10), || async {});
673    ///     let task2 = TimerWheel::create_task(Duration::from_secs(10), || async {});
674    ///     let task3 = TimerWheel::create_task(Duration::from_secs(10), || async {});
675    ///     
676    ///     let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
677    ///     
678    ///     let _h1 = timer.register(task1);
679    ///     let _h2 = timer.register(task2);
680    ///     let _h3 = timer.register(task3);
681    ///     
682    ///     // 批量取消
683    ///     let cancelled = timer.cancel_batch(&task_ids);
684    ///     println!("已取消 {} 个定时器", cancelled);
685    /// }
686    /// ```
687    #[inline]
688    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
689        let mut wheel = self.wheel.lock();
690        wheel.cancel_batch(task_ids)
691    }
692    
693    /// 核心 tick 循环
694    async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
695        let mut interval = tokio::time::interval(tick_duration);
696        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
697
698        loop {
699            interval.tick().await;
700
701            // 推进时间轮并获取到期任务
702            let expired_tasks = {
703                let mut wheel_guard = wheel.lock();
704                wheel_guard.advance()
705            };
706
707            // 执行到期任务
708            for task in expired_tasks {
709                let callback = task.get_callback();
710                
711                // 移动task的所有权来获取completion_notifier
712                let notifier = task.completion_notifier;
713                
714                // 只有注册过的任务才有 notifier
715                if let Some(notifier) = notifier {
716                    // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
717                    if let Some(callback) = callback {
718                        tokio::spawn(async move {
719                            // 执行回调
720                            let future = callback.call();
721                            future.await;
722                            
723                            // 回调执行完成后发送通知
724                            let _ = notifier.0.send(());
725                        });
726                    } else {
727                        // 如果没有回调,立即发送完成通知
728                        let _ = notifier.0.send(());
729                    }
730                }
731            }
732        }
733    }
734
    /// Stops the timer manager.
    ///
    /// Aborts the background tick task (if it is still present) and waits for the
    /// task to finish before returning.
    pub async fn shutdown(mut self) {
        // `take()` leaves `None` behind, so the `Drop` impl has nothing left to abort.
        if let Some(handle) = self.tick_handle.take() {
            handle.abort();
            // The join result of the aborted task is intentionally discarded.
            let _ = handle.await;
        }
    }
742}
743
// Dropping the manager aborts the background tick task if `shutdown` has not
// already taken it.
impl Drop for TimerWheel {
    fn drop(&mut self) {
        // `drop` cannot await, so the task is aborted without waiting for it to finish.
        if let Some(handle) = self.tick_handle.take() {
            handle.abort();
        }
    }
}
751
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a task whose callback increments a shared counter each time it runs,
    /// and returns the task together with that counter.
    fn counting_task(delay: Duration) -> (crate::task::TimerTask, Arc<AtomicU32>) {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(delay, move || {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
            }
        });

        (task, counter)
    }

    /// Constructing a wheel with default settings must not panic.
    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    /// A registered timer fires exactly once and signals completion afterwards.
    #[tokio::test]
    async fn test_schedule_once() {
        let timer = TimerWheel::with_defaults();
        let (task, counter) = counting_task(Duration::from_millis(50));
        let handle = timer.register(task);

        // Wait on the completion notification instead of a fixed sleep: the tick
        // loop sends it only after the callback has finished, so the counter check
        // below does not depend on scheduling margins. The timeout bounds the test.
        let completion = handle.into_completion_receiver().0;
        tokio::time::timeout(Duration::from_secs(1), completion)
            .await
            .expect("timer did not fire within one second")
            .expect("completion notifier was dropped without firing");

        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Cancelling through the `TimerHandle` prevents the callback from running.
    #[tokio::test]
    async fn test_cancel_timer() {
        let timer = TimerWheel::with_defaults();
        let (task, counter) = counting_task(Duration::from_millis(100));
        let handle = timer.register(task);

        // Cancel right away
        let cancel_result = handle.cancel();
        assert!(cancel_result);

        // Wait well past the deadline to make sure the timer does not fire
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    /// Cancelling by task id through `TimerWheel::cancel` prevents the callback
    /// from running (previously this test duplicated `test_cancel_timer`).
    #[tokio::test]
    async fn test_cancel_immediate() {
        let timer = TimerWheel::with_defaults();
        let (task, counter) = counting_task(Duration::from_millis(100));
        let task_id = task.id;
        let _handle = timer.register(task);

        // Cancel right away, using the wheel-level API
        let cancel_result = timer.cancel(task_id);
        assert!(cancel_result);

        // Wait well past the deadline to make sure the timer does not fire
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}
839