// kestrel_protocol_timer/timer.rs

1use crate::config::{BatchConfig, ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TaskCompletionReason};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
10/// 完成通知接收器,用于接收定时器完成通知
11pub struct CompletionReceiver(pub oneshot::Receiver<TaskCompletionReason>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14/// 
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18    pub(crate) task_id: TaskId,
19    pub(crate) wheel: Arc<Mutex<Wheel>>,
20    pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24    pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<TaskCompletionReason>) -> Self {
25        Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26    }
27
28    /// 取消定时器
29    ///
30    /// # 返回
31    /// 如果任务存在且成功取消返回 true,否则返回 false
32    ///
33    /// # 示例
34    /// ```no_run
35    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
36    /// # use std::time::Duration;
37    /// # 
38    /// # #[tokio::main]
39    /// # async fn main() {
40    /// let timer = TimerWheel::with_defaults();
41    /// let callback = Some(CallbackWrapper::new(|| async {}));
42    /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
43    /// let handle = timer.register(task);
44    /// 
45    /// // 取消定时器
46    /// let success = handle.cancel();
47    /// println!("取消成功: {}", success);
48    /// # }
49    /// ```
50    pub fn cancel(&self) -> bool {
51        let mut wheel = self.wheel.lock();
52        wheel.cancel(self.task_id)
53    }
54
55    /// 获取完成通知接收器的可变引用
56    ///
57    /// # 示例
58    /// ```no_run
59    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
60    /// # use std::time::Duration;
61    /// # 
62    /// # #[tokio::main]
63    /// # async fn main() {
64    /// let timer = TimerWheel::with_defaults();
65    /// let callback = Some(CallbackWrapper::new(|| async {
66    ///     println!("Timer fired!");
67    /// }));
68    /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
69    /// let handle = timer.register(task);
70    /// 
71    /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
72    /// handle.into_completion_receiver().0.await.ok();
73    /// println!("Timer completed!");
74    /// # }
75    /// ```
76    pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
77        &mut self.completion_rx
78    }
79
80    /// 消耗句柄,返回完成通知接收器
81    ///
82    /// # 示例
83    /// ```no_run
84    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
85    /// # use std::time::Duration;
86    /// # 
87    /// # #[tokio::main]
88    /// # async fn main() {
89    /// let timer = TimerWheel::with_defaults();
90    /// let callback = Some(CallbackWrapper::new(|| async {
91    ///     println!("Timer fired!");
92    /// }));
93    /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
94    /// let handle = timer.register(task);
95    /// 
96    /// // 等待定时器完成
97    /// handle.into_completion_receiver().0.await.ok();
98    /// println!("Timer completed!");
99    /// # }
100    /// ```
101    pub fn into_completion_receiver(self) -> CompletionReceiver {
102        self.completion_rx
103    }
104}
105
106/// 批量定时器句柄,用于管理批量调度的定时器
107/// 
108/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
109/// 
110/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
111/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
112pub struct BatchHandle {
113    pub(crate) task_ids: Vec<TaskId>,
114    pub(crate) wheel: Arc<Mutex<Wheel>>,
115    pub(crate) completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>,
116}
117
118impl BatchHandle {
119    pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>) -> Self {
120        Self { task_ids, wheel, completion_rxs }
121    }
122
123    /// 批量取消所有定时器
124    ///
125    /// # 返回
126    /// 成功取消的任务数量
127    ///
128    /// # 示例
129    /// ```no_run
130    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
131    /// # use std::time::Duration;
132    /// # 
133    /// # #[tokio::main]
134    /// # async fn main() {
135    /// let timer = TimerWheel::with_defaults();
136    /// let delays: Vec<Duration> = (0..10)
137    ///     .map(|_| Duration::from_secs(1))
138    ///     .collect();
139    /// let tasks = TimerWheel::create_batch(delays);
140    /// let batch = timer.register_batch(tasks);
141    /// 
142    /// let cancelled = batch.cancel_all();
143    /// println!("取消了 {} 个定时器", cancelled);
144    /// # }
145    /// ```
146    pub fn cancel_all(self) -> usize {
147        let mut wheel = self.wheel.lock();
148        wheel.cancel_batch(&self.task_ids)
149    }
150
151    /// 将批量句柄转换为单个定时器句柄的 Vec
152    ///
153    /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
154    ///
155    /// # 示例
156    /// ```no_run
157    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
158    /// # use std::time::Duration;
159    /// # 
160    /// # #[tokio::main]
161    /// # async fn main() {
162    /// let timer = TimerWheel::with_defaults();
163    /// let delays: Vec<Duration> = (0..3)
164    ///     .map(|_| Duration::from_secs(1))
165    ///     .collect();
166    /// let tasks = TimerWheel::create_batch(delays);
167    /// let batch = timer.register_batch(tasks);
168    /// 
169    /// // 转换为独立的句柄
170    /// let handles = batch.into_handles();
171    /// for handle in handles {
172    ///     // 可以单独操作每个句柄
173    /// }
174    /// # }
175    /// ```
176    pub fn into_handles(self) -> Vec<TimerHandle> {
177        self.task_ids
178            .into_iter()
179            .zip(self.completion_rxs.into_iter())
180            .map(|(task_id, rx)| {
181                TimerHandle::new(task_id, self.wheel.clone(), rx)
182            })
183            .collect()
184    }
185
186    /// 获取批量任务的数量
187    pub fn len(&self) -> usize {
188        self.task_ids.len()
189    }
190
191    /// 检查批量任务是否为空
192    pub fn is_empty(&self) -> bool {
193        self.task_ids.is_empty()
194    }
195
196    /// 获取所有任务 ID 的引用
197    pub fn task_ids(&self) -> &[TaskId] {
198        &self.task_ids
199    }
200
201    /// 获取所有完成通知接收器的引用
202    ///
203    /// # 返回
204    /// 所有任务的完成通知接收器列表引用
205    pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<TaskCompletionReason>> {
206        &mut self.completion_rxs
207    }
208
209    /// 消耗句柄,返回所有完成通知接收器
210    ///
211    /// # 返回
212    /// 所有任务的完成通知接收器列表
213    ///
214    /// # 示例
215    /// ```no_run
216    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
217    /// # use std::time::Duration;
218    /// # 
219    /// # #[tokio::main]
220    /// # async fn main() {
221    /// let timer = TimerWheel::with_defaults();
222    /// let delays: Vec<Duration> = (0..3)
223    ///     .map(|_| Duration::from_secs(1))
224    ///     .collect();
225    /// let tasks = TimerWheel::create_batch(delays);
226    /// let batch = timer.register_batch(tasks);
227    /// 
228    /// // 获取所有完成通知接收器
229    /// let receivers = batch.into_completion_receivers();
230    /// for rx in receivers {
231    ///     tokio::spawn(async move {
232    ///         if rx.await.is_ok() {
233    ///             println!("A timer completed!");
234    ///         }
235    ///     });
236    /// }
237    /// # }
238    /// ```
239    pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<TaskCompletionReason>> {
240        self.completion_rxs
241    }
242}
243
244/// 实现 IntoIterator,允许直接迭代 BatchHandle
245/// 
246/// # 示例
247/// ```no_run
248/// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
249/// # use std::time::Duration;
250/// # 
251/// # #[tokio::main]
252/// # async fn main() {
253/// let timer = TimerWheel::with_defaults();
254/// let delays: Vec<Duration> = (0..3)
255///     .map(|_| Duration::from_secs(1))
256///     .collect();
257/// let tasks = TimerWheel::create_batch(delays);
258/// let batch = timer.register_batch(tasks);
259/// 
260/// // 直接迭代,每个元素都是独立的 TimerHandle
261/// for handle in batch {
262///     // 可以单独操作每个句柄
263/// }
264/// # }
265/// ```
266impl IntoIterator for BatchHandle {
267    type Item = TimerHandle;
268    type IntoIter = BatchHandleIter;
269
270    fn into_iter(self) -> Self::IntoIter {
271        BatchHandleIter {
272            task_ids: self.task_ids.into_iter(),
273            completion_rxs: self.completion_rxs.into_iter(),
274            wheel: self.wheel,
275        }
276    }
277}
278
279/// BatchHandle 的迭代器
280pub struct BatchHandleIter {
281    task_ids: std::vec::IntoIter<TaskId>,
282    completion_rxs: std::vec::IntoIter<oneshot::Receiver<TaskCompletionReason>>,
283    wheel: Arc<Mutex<Wheel>>,
284}
285
286impl Iterator for BatchHandleIter {
287    type Item = TimerHandle;
288
289    fn next(&mut self) -> Option<Self::Item> {
290        match (self.task_ids.next(), self.completion_rxs.next()) {
291            (Some(task_id), Some(rx)) => {
292                Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
293            }
294            _ => None,
295        }
296    }
297
298    fn size_hint(&self) -> (usize, Option<usize>) {
299        self.task_ids.size_hint()
300    }
301}
302
303impl ExactSizeIterator for BatchHandleIter {
304    fn len(&self) -> usize {
305        self.task_ids.len()
306    }
307}
308
309/// 时间轮定时器管理器
310pub struct TimerWheel {
311    /// 时间轮唯一标识符
312    
313    /// 时间轮实例(使用 Arc<Mutex> 包装以支持多线程访问)
314    wheel: Arc<Mutex<Wheel>>,
315    
316    /// 后台 tick 循环任务句柄
317    tick_handle: Option<JoinHandle<()>>,
318}
319
320impl TimerWheel {
321    /// 创建新的定时器管理器
322    ///
323    /// # 参数
324    /// - `config`: 时间轮配置(已经过验证)
325    ///
326    /// # 示例
327    /// ```no_run
328    /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask, BatchConfig};
329    /// use std::time::Duration;
330    ///
331    /// #[tokio::main]
332    /// async fn main() {
333    ///     let config = WheelConfig::builder()
334    ///         .tick_duration(Duration::from_millis(10))
335    ///         .slot_count(512)
336    ///         .build()
337    ///         .unwrap();
338    ///     let timer = TimerWheel::new(config, BatchConfig::default());
339    ///     
340    ///     // 使用两步式 API
341    ///     let task = TimerWheel::create_task(Duration::from_secs(1), None);
342    ///     let handle = timer.register(task);
343    /// }
344    /// ```
345    pub fn new(config: WheelConfig, batch_config: BatchConfig) -> Self {
346        let tick_duration = config.tick_duration;
347        let wheel = Wheel::new(config, batch_config);
348        let wheel = Arc::new(Mutex::new(wheel));
349        let wheel_clone = wheel.clone();
350
351        // 启动后台 tick 循环
352        let tick_handle = tokio::spawn(async move {
353            Self::tick_loop(wheel_clone, tick_duration).await;
354        });
355
356        Self {
357            wheel,
358            tick_handle: Some(tick_handle),
359        }
360    }
361
362    /// 创建带默认配置的定时器管理器
363    /// - tick 时长: 10ms
364    /// - 槽位数量: 512
365    ///
366    /// # 示例
367    /// ```no_run
368    /// use kestrel_protocol_timer::TimerWheel;
369    ///
370    /// #[tokio::main]
371    /// async fn main() {
372    ///     let timer = TimerWheel::with_defaults();
373    /// }
374    /// ```
375    pub fn with_defaults() -> Self {
376        Self::new(WheelConfig::default(), BatchConfig::default())
377    }
378
379    /// 创建与此时间轮绑定的 TimerService(使用默认配置)
380    ///
381    /// # 返回
382    /// 绑定到此时间轮的 TimerService 实例
383    ///
384    /// # 参数
385    /// - `service_config`: 服务配置
386    ///
387    /// # 示例
388    /// ```no_run
389    /// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
390    /// use std::time::Duration;
391    /// 
392    ///
393    /// #[tokio::main]
394    /// async fn main() {
395    ///     let timer = TimerWheel::with_defaults();
396    ///     let mut service = timer.create_service(ServiceConfig::default());
397    ///     
398    ///     // 使用两步式 API 通过 service 批量调度定时器
399    ///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
400    ///         .map(|_| (Duration::from_millis(100), Some(CallbackWrapper::new(|| async {}))))
401    ///         .collect();
402    ///     let tasks = TimerService::create_batch_with_callbacks(callbacks);
403    ///     service.register_batch(tasks).unwrap();
404    ///     
405    ///     // 接收超时通知
406    ///     let mut rx = service.take_receiver().unwrap();
407    ///     while let Some(task_id) = rx.recv().await {
408    ///         println!("Task {:?} completed", task_id);
409    ///     }
410    /// }
411    /// ```
412    pub fn create_service(&self, service_config: ServiceConfig) -> crate::service::TimerService {
413        crate::service::TimerService::new(self.wheel.clone(), service_config)
414    }
415    
416    /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
417    ///
418    /// # 参数
419    /// - `config`: 服务配置
420    ///
421    /// # 返回
422    /// 绑定到此时间轮的 TimerService 实例
423    ///
424    /// # 示例
425    /// ```no_run
426    /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
427    ///
428    /// #[tokio::main]
429    /// async fn main() {
430    ///     let timer = TimerWheel::with_defaults();
431    ///     let config = ServiceConfig::builder()
432    ///         .command_channel_capacity(1024)
433    ///         .timeout_channel_capacity(2000)
434    ///         .build()
435    ///         .unwrap();
436    ///     let service = timer.create_service_with_config(config);
437    /// }
438    /// ```
439    pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
440        crate::service::TimerService::new(self.wheel.clone(), config)
441    }
442
443    /// 创建定时器任务(申请阶段)
444    /// 
445    /// # 参数
446    /// - `delay`: 延迟时间
447    /// - `callback`: 实现了 TimerCallback trait 的回调对象
448    /// 
449    /// # 返回
450    /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
451    /// 
452    /// # 示例
453    /// ```no_run
454    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
455    /// use std::time::Duration;
456    /// 
457    /// 
458    /// #[tokio::main]
459    /// async fn main() {
460    ///     let timer = TimerWheel::with_defaults();
461    ///     
462    ///     // 步骤 1: 创建任务
463    ///     let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
464    ///         println!("Timer fired!");
465    ///     })));
466    ///     
467    ///     // 获取任务 ID
468    ///     let task_id = task.get_id();
469    ///     println!("Created task: {:?}", task_id);
470    ///     
471    ///     // 步骤 2: 注册任务
472    ///     let handle = timer.register(task);
473    /// }
474    /// ```
475    #[inline]
476    pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
477        crate::task::TimerTask::new(delay, callback)
478    }
479    
480    /// 批量创建定时器任务(申请阶段)
481    /// 
482    /// # 参数
483    /// - `callbacks`: (延迟时间, 回调) 的元组列表
484    /// 
485    /// # 返回
486    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
487    /// 
488    /// # 示例
489    /// ```no_run
490    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
491    /// use std::time::Duration;
492    /// use std::sync::Arc;
493    /// use std::sync::atomic::{AtomicU32, Ordering};
494    /// 
495    /// #[tokio::main]
496    /// async fn main() {
497    ///     let timer = TimerWheel::with_defaults();
498    ///     let counter = Arc::new(AtomicU32::new(0));
499    ///     
500    ///     // 步骤 1: 批量创建任务
501    ///     let delays: Vec<Duration> = (0..3)
502    ///         .map(|_| Duration::from_millis(100))
503    ///         .collect();
504    ///     
505    ///     let tasks = TimerWheel::create_batch(delays);
506    ///     println!("Created {} tasks", tasks.len());
507    ///     
508    ///     // 步骤 2: 批量注册任务
509    ///     let batch = timer.register_batch(tasks);
510    /// }
511    /// ```
512    #[inline]
513    pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask>
514    {
515        delays
516            .into_iter()
517            .map(|delay| crate::task::TimerTask::new(delay, None))
518            .collect()
519    }
520
521    /// 批量创建定时器任务(申请阶段)
522    /// 
523    /// # 参数
524    /// - `callbacks`: (延迟时间, 回调) 的元组列表
525    /// 
526    /// # 返回
527    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
528    /// 
529    /// # 示例
530    /// ```no_run
531    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
532    /// use std::time::Duration;
533    /// use std::sync::Arc;
534    /// use std::sync::atomic::{AtomicU32, Ordering};
535    /// 
536    /// #[tokio::main]
537    /// async fn main() {
538    ///     let timer = TimerWheel::with_defaults();
539    ///     let counter = Arc::new(AtomicU32::new(0));
540    ///     
541    ///     // 步骤 1: 批量创建任务
542    ///     let delays: Vec<Duration> = (0..3)
543    ///         .map(|_| Duration::from_millis(100))
544    ///         .collect();
545    ///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = delays
546    ///         .into_iter()
547    ///         .map(|delay| {
548    ///             let counter = Arc::clone(&counter);
549    ///             let callback = Some(CallbackWrapper::new(move || {
550    ///                 let counter = Arc::clone(&counter);
551    ///                 async move {
552    ///                     counter.fetch_add(1, Ordering::SeqCst);
553    ///                 }
554    ///             }));
555    ///             (delay, callback)
556    ///         })
557    ///         .collect();
558    ///     
559    ///     let tasks = TimerWheel::create_batch_with_callbacks(callbacks);
560    ///     println!("Created {} tasks", tasks.len());
561    ///     
562    ///     // 步骤 2: 批量注册任务
563    ///     let batch = timer.register_batch(tasks);
564    /// }
565    /// ```
566    #[inline]
567    pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask>
568    {
569        callbacks
570            .into_iter()
571            .map(|(delay, callback)| crate::task::TimerTask::new(delay, callback))
572            .collect()
573    }
574    
575    /// 注册定时器任务到时间轮(注册阶段)
576    /// 
577    /// # 参数
578    /// - `task`: 通过 `create_task()` 创建的任务
579    /// 
580    /// # 返回
581    /// 返回定时器句柄,可用于取消定时器和接收完成通知
582    /// 
583    /// # 示例
584    /// ```no_run
585    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
586    /// 
587    /// use std::time::Duration;
588    /// 
589    /// #[tokio::main]
590    /// async fn main() {
591    ///     let timer = TimerWheel::with_defaults();
592    ///     
593    ///     let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
594    ///         println!("Timer fired!");
595    ///     })));
596    ///     let task_id = task.get_id();
597    ///     
598    ///     let handle = timer.register(task);
599    ///     
600    ///     // 等待定时器完成
601    ///     handle.into_completion_receiver().0.await.ok();
602    /// }
603    /// ```
604    #[inline]
605    pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
606        let (completion_tx, completion_rx) = oneshot::channel();
607        let notifier = crate::task::CompletionNotifier(completion_tx);
608        
609        let task_id = task.id;
610        
611        // 单次加锁完成所有操作
612        {
613            let mut wheel_guard = self.wheel.lock();
614            wheel_guard.insert(task, notifier);
615        }
616        
617        TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
618    }
619    
620    /// 批量注册定时器任务到时间轮(注册阶段)
621    /// 
622    /// # 参数
623    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
624    /// 
625    /// # 返回
626    /// 返回批量定时器句柄
627    /// 
628    /// # 示例
629    /// ```no_run
630    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
631    /// use std::time::Duration;
632    /// 
633    /// #[tokio::main]
634    /// async fn main() {
635    ///     let timer = TimerWheel::with_defaults();
636    ///     
637    ///     let delays: Vec<Duration> = (0..3)
638    ///         .map(|_| Duration::from_secs(1))
639    ///         .collect();
640    ///     let tasks = TimerWheel::create_batch(delays);
641    ///     
642    ///     let batch = timer.register_batch(tasks);
643    ///     println!("Registered {} timers", batch.len());
644    /// }
645    /// ```
646    #[inline]
647    pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
648        let task_count = tasks.len();
649        let mut completion_rxs = Vec::with_capacity(task_count);
650        let mut task_ids = Vec::with_capacity(task_count);
651        let mut prepared_tasks = Vec::with_capacity(task_count);
652        
653        // 步骤1: 准备所有 channels 和 notifiers(无锁)
654        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
655        for task in tasks {
656            let (completion_tx, completion_rx) = oneshot::channel();
657            let notifier = crate::task::CompletionNotifier(completion_tx);
658            
659            task_ids.push(task.id);
660            completion_rxs.push(completion_rx);
661            prepared_tasks.push((task, notifier));
662        }
663        
664        // 步骤2: 单次加锁,批量插入
665        {
666            let mut wheel_guard = self.wheel.lock();
667            wheel_guard.insert_batch(prepared_tasks);
668        }
669        
670        BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
671    }
672
673    /// 取消定时器
674    ///
675    /// # 参数
676    /// - `task_id`: 任务 ID
677    ///
678    /// # 返回
679    /// 如果任务存在且成功取消返回 true,否则返回 false
680    /// 
681    /// # 示例
682    /// ```no_run
683    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
684    /// 
685    /// use std::time::Duration;
686    ///
687    /// #[tokio::main]
688    /// async fn main() {
689    ///     let timer = TimerWheel::with_defaults();
690    ///     
691    ///     let task = TimerWheel::create_task(Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
692    ///         println!("Timer fired!");
693    ///     })));
694    ///     let task_id = task.get_id();
695    ///     let _handle = timer.register(task);
696    ///     
697    ///     // 使用任务 ID 取消
698    ///     let cancelled = timer.cancel(task_id);
699    ///     println!("取消成功: {}", cancelled);
700    /// }
701    /// ```
702    #[inline]
703    pub fn cancel(&self, task_id: TaskId) -> bool {
704        let mut wheel = self.wheel.lock();
705        wheel.cancel(task_id)
706    }
707
708    /// 批量取消定时器
709    ///
710    /// # 参数
711    /// - `task_ids`: 要取消的任务 ID 列表
712    ///
713    /// # 返回
714    /// 成功取消的任务数量
715    ///
716    /// # 性能优势
717    /// - 批量处理减少锁竞争
718    /// - 内部优化批量取消操作
719    ///
720    /// # 示例
721    /// ```no_run
722    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
723    /// use std::time::Duration;
724    ///
725    /// #[tokio::main]
726    /// async fn main() {
727    ///     let timer = TimerWheel::with_defaults();
728    ///     
729    ///     // 创建多个定时器
730    ///     let task1 = TimerWheel::create_task(Duration::from_secs(10), None);
731    ///     let task2 = TimerWheel::create_task(Duration::from_secs(10), None);
732    ///     let task3 = TimerWheel::create_task(Duration::from_secs(10), None);
733    ///     
734    ///     let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
735    ///     
736    ///     let _h1 = timer.register(task1);
737    ///     let _h2 = timer.register(task2);
738    ///     let _h3 = timer.register(task3);
739    ///     
740    ///     // 批量取消
741    ///     let cancelled = timer.cancel_batch(&task_ids);
742    ///     println!("已取消 {} 个定时器", cancelled);
743    /// }
744    /// ```
745    #[inline]
746    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
747        let mut wheel = self.wheel.lock();
748        wheel.cancel_batch(task_ids)
749    }
750
751    /// 推迟定时器
752    ///
753    /// # 参数
754    /// - `task_id`: 要推迟的任务 ID
755    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
756    /// - `callback`: 新的回调函数,传入 `None` 保持原回调不变,传入 `Some` 替换为新回调
757    ///
758    /// # 返回
759    /// 如果任务存在且成功推迟返回 true,否则返回 false
760    ///
761    /// # 注意
762    /// - 推迟后任务 ID 保持不变
763    /// - 原有的 completion_receiver 仍然有效
764    ///
765    /// # 示例
766    ///
767    /// ## 保持原回调
768    /// ```no_run
769    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
770    /// use std::time::Duration;
771    /// 
772    ///
773    /// #[tokio::main]
774    /// async fn main() {
775    ///     let timer = TimerWheel::with_defaults();
776    ///     
777    ///     let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
778    ///         println!("Timer fired!");
779    ///     })));
780    ///     let task_id = task.get_id();
781    ///     let _handle = timer.register(task);
782    ///     
783    ///     // 推迟到 10 秒后触发(保持原回调)
784    ///     let success = timer.postpone(task_id, Duration::from_secs(10), None);
785    ///     println!("推迟成功: {}", success);
786    /// }
787    /// ```
788    ///
789    /// ## 替换为新回调
790    /// ```no_run
791    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
792    /// use std::time::Duration;
793    ///
794    /// #[tokio::main]
795    /// async fn main() {
796    ///     let timer = TimerWheel::with_defaults();
797    ///     
798    ///     let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
799    ///         println!("Original callback!");
800    ///     })));
801    ///     let task_id = task.get_id();
802    ///     let _handle = timer.register(task);
803    ///     
804    ///     // 推迟到 10 秒后触发(并替换为新回调)
805    ///     let success = timer.postpone(task_id, Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
806    ///         println!("New callback!");
807    ///     })));
808    ///     println!("推迟成功: {}", success);
809    /// }
810    /// ```
811    #[inline]
812    pub fn postpone(
813        &self,
814        task_id: TaskId,
815        new_delay: Duration,
816        callback: Option<CallbackWrapper>,
817    ) -> bool {
818        let mut wheel = self.wheel.lock();
819        wheel.postpone(task_id, new_delay, callback)
820    }
821
822    /// 批量推迟定时器(保持原回调)
823    ///
824    /// # 参数
825    /// - `updates`: (任务ID, 新延迟) 的元组列表
826    ///
827    /// # 返回
828    /// 成功推迟的任务数量
829    ///
830    /// # 注意
831    /// - 此方法会保持所有任务的原回调不变
832    /// - 如需替换回调,请使用 `postpone_batch_with_callbacks`
833    ///
834    /// # 性能优势
835    /// - 批量处理减少锁竞争
836    /// - 内部优化批量推迟操作
837    ///
838    /// # 示例
839    /// ```no_run
840    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
841    /// use std::time::Duration;
842    ///
843    /// #[tokio::main]
844    /// async fn main() {
845    ///     let timer = TimerWheel::with_defaults();
846    ///     
847    ///     // 创建多个带回调的定时器
848    ///     let task1 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
849    ///         println!("Task 1 fired!");
850    ///     })));
851    ///     let task2 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
852    ///         println!("Task 2 fired!");
853    ///     })));
854    ///     let task3 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
855    ///         println!("Task 3 fired!");
856    ///     })));
857    ///     
858    ///     let task_ids = vec![
859    ///         (task1.get_id(), Duration::from_secs(10)),
860    ///         (task2.get_id(), Duration::from_secs(15)),
861    ///         (task3.get_id(), Duration::from_secs(20)),
862    ///     ];
863    ///     
864    ///     timer.register(task1);
865    ///     timer.register(task2);
866    ///     timer.register(task3);
867    ///     
868    ///     // 批量推迟(保持原回调)
869    ///     let postponed = timer.postpone_batch(task_ids);
870    ///     println!("已推迟 {} 个定时器", postponed);
871    /// }
872    /// ```
873    #[inline]
874    pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
875        let mut wheel = self.wheel.lock();
876        wheel.postpone_batch(updates)
877    }
878
879    /// 批量推迟定时器(替换回调)
880    ///
881    /// # 参数
882    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
883    ///
884    /// # 返回
885    /// 成功推迟的任务数量
886    ///
887    /// # 性能优势
888    /// - 批量处理减少锁竞争
889    /// - 内部优化批量推迟操作
890    ///
891    /// # 示例
892    /// ```no_run
893    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
894    /// use std::time::Duration;
895    /// use std::sync::Arc;
896    /// use std::sync::atomic::{AtomicU32, Ordering};
897    ///
898    /// #[tokio::main]
899    /// async fn main() {
900    ///     let timer = TimerWheel::with_defaults();
901    ///     let counter = Arc::new(AtomicU32::new(0));
902    ///     
903    ///     // 创建多个定时器
904    ///     let task1 = TimerWheel::create_task(Duration::from_secs(5), None);
905    ///     let task2 = TimerWheel::create_task(Duration::from_secs(5), None);
906    ///     
907    ///     let id1 = task1.get_id();
908    ///     let id2 = task2.get_id();
909    ///     
910    ///     timer.register(task1);
911    ///     timer.register(task2);
912    ///     
913    ///     // 批量推迟并替换回调
914    ///     let updates: Vec<_> = vec![id1, id2]
915    ///         .into_iter()
916    ///         .map(|id| {
917    ///             let counter = Arc::clone(&counter);
918    ///             (id, Duration::from_secs(10), Some(CallbackWrapper::new(move || {
919    ///                 let counter = Arc::clone(&counter);
920    ///                 async move { counter.fetch_add(1, Ordering::SeqCst); }
921    ///             })))
922    ///         })
923    ///         .collect();
924    ///     let postponed = timer.postpone_batch_with_callbacks(updates);
925    ///     println!("已推迟 {} 个定时器", postponed);
926    /// }
927    /// ```
928    #[inline]
929    pub fn postpone_batch_with_callbacks(
930        &self,
931        updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
932    ) -> usize {
933        let mut wheel = self.wheel.lock();
934        wheel.postpone_batch_with_callbacks(updates.to_vec())
935    }
936    
937    /// 核心 tick 循环
938    async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
939        let mut interval = tokio::time::interval(tick_duration);
940        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
941
942        loop {
943            interval.tick().await;
944
945            // 推进时间轮并获取到期任务
946            let expired_tasks = {
947                let mut wheel_guard = wheel.lock();
948                wheel_guard.advance()
949            };
950
951            // 执行到期任务
952            for task in expired_tasks {
953                let callback = task.get_callback();
954                
955                // 移动task的所有权来获取completion_notifier
956                let notifier = task.completion_notifier;
957                
958                // 只有注册过的任务才有 notifier
959                if let Some(notifier) = notifier {
960                    // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
961                    if let Some(callback) = callback {
962                        tokio::spawn(async move {
963                            // 执行回调
964                            let future = callback.call();
965                            future.await;
966                            
967                            // 回调执行完成后发送通知
968                            let _ = notifier.0.send(TaskCompletionReason::Expired);
969                        });
970                    } else {
971                        // 如果没有回调,立即发送完成通知
972                        let _ = notifier.0.send(TaskCompletionReason::Expired);
973                    }
974                }
975            }
976        }
977    }
978
979    /// 停止定时器管理器
980    pub async fn shutdown(mut self) {
981        if let Some(handle) = self.tick_handle.take() {
982            handle.abort();
983            let _ = handle.await;
984        }
985    }
986}
987
988impl Drop for TimerWheel {
989    fn drop(&mut self) {
990        if let Some(handle) = self.tick_handle.take() {
991            handle.abort();
992        }
993    }
994}
995
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a callback that adds `delta` to `counter` each time it fires.
    fn counting_callback(counter: &Arc<AtomicU32>, delta: u32) -> Option<CallbackWrapper> {
        let counter = Arc::clone(counter);
        Some(CallbackWrapper::new(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(delta, Ordering::SeqCst);
            }
        }))
    }

    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    #[tokio::test]
    async fn test_schedule_once() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
        let _handle = timer.register(task);

        // Wait for the timer to fire.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(100), counting_callback(&counter, 1));
        let handle = timer.register(task);

        // Let the wheel tick for a while (well before the 100ms deadline), then cancel a
        // timer that is still pending. `test_cancel_immediate` covers cancelling right away.
        tokio::time::sleep(Duration::from_millis(30)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert!(handle.cancel());

        // Wait well past the original deadline to make sure the callback never runs.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_cancel_immediate() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(100), counting_callback(&counter, 1));
        let handle = timer.register(task);

        // Cancel immediately after registration.
        assert!(handle.cancel());

        // Wait well past the original deadline to make sure the callback never runs.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_postpone_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Push the deadline out to 150ms.
        assert!(timer.postpone(task_id, Duration::from_millis(150), None));

        // The original 50ms deadline has passed; the task must not have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the postponed deadline (~150ms after the postpone call).
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(matches!(result, Ok(Ok(TaskCompletionReason::Expired))));

        // `Expired` is only sent after the callback future completes, so the counter is
        // already up to date here.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_postpone_with_callback() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // The original callback adds 1 ...
        let task =
            TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
        let task_id = task.get_id();
        let handle = timer.register(task);

        // ... and postponing replaces it with one that adds 10.
        let postponed = timer.postpone(
            task_id,
            Duration::from_millis(100),
            counting_callback(&counter, 10),
        );
        assert!(postponed);

        // Wait for the postponed deadline (100ms plus slack).
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(matches!(result, Ok(Ok(TaskCompletionReason::Expired))));

        // Only the replacement callback ran: +10, not +1 or +11.
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_timer() {
        let timer = TimerWheel::with_defaults();

        // Create a task but never register it, so the wheel does not know its id.
        let fake_task = TimerWheel::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();

        let postponed = timer.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed);
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Register 3 tasks due in 50ms; each will be postponed to 150ms.
        let mut updates = Vec::new();
        for _ in 0..3 {
            let task =
                TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
            updates.push((task.get_id(), Duration::from_millis(150)));
            timer.register(task);
        }

        assert_eq!(timer.postpone_batch(updates), 3);

        // The original 50ms deadline has passed; nothing must have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait past the postponed deadline, plus slack for the spawned callbacks to run
        // (there is no completion receiver to synchronise on here).
        tokio::time::sleep(Duration::from_millis(220)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Register 3 callback-less tasks due in 50ms.
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerWheel::create_task(Duration::from_millis(50), None);
            task_ids.push(task.get_id());
            timer.register(task);
        }

        // Postpone all of them to 150ms and attach a counting callback to each.
        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| (id, Duration::from_millis(150), counting_callback(&counter, 1)))
            .collect();

        assert_eq!(timer.postpone_batch_with_callbacks(updates), 3);

        // The original 50ms deadline has passed; nothing must have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait past the postponed deadline, plus slack for the spawned callbacks to run.
        tokio::time::sleep(Duration::from_millis(220)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_keeps_completion_receiver_valid() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
        let task_id = task.get_id();
        let handle = timer.register(task);

        assert!(timer.postpone(task_id, Duration::from_millis(100), None));

        // The receiver obtained at registration must still be notified after postponing
        // (100ms deadline plus slack).
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(
            matches!(result, Ok(Ok(TaskCompletionReason::Expired))),
            "Completion receiver should still work after postpone"
        );

        // `Expired` is only sent after the callback future completes.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
1302