kestrel_protocol_timer/
timer.rs

1use crate::config::{ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TaskCompletionReason};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
10/// 完成通知接收器,用于接收定时器完成通知
11pub struct CompletionReceiver(pub oneshot::Receiver<TaskCompletionReason>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14/// 
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18    pub(crate) task_id: TaskId,
19    pub(crate) wheel: Arc<Mutex<Wheel>>,
20    pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24    pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<TaskCompletionReason>) -> Self {
25        Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26    }
27
28    /// 取消定时器
29    ///
30    /// # 返回
31    /// 如果任务存在且成功取消返回 true,否则返回 false
32    ///
33    /// # 示例
34    /// ```no_run
35    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
36    /// # use std::time::Duration;
37    /// # 
38    /// # #[tokio::main]
39    /// # async fn main() {
40    /// let timer = TimerWheel::with_defaults();
41    /// let callback = Some(CallbackWrapper::new(|| async {}));
42    /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
43    /// let handle = timer.register(task);
44    /// 
45    /// // 取消定时器
46    /// let success = handle.cancel();
47    /// println!("取消成功: {}", success);
48    /// # }
49    /// ```
50    pub fn cancel(&self) -> bool {
51        let mut wheel = self.wheel.lock();
52        wheel.cancel(self.task_id)
53    }
54
55    /// 获取完成通知接收器的可变引用
56    ///
57    /// # 示例
58    /// ```no_run
59    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
60    /// # use std::time::Duration;
61    /// # 
62    /// # #[tokio::main]
63    /// # async fn main() {
64    /// let timer = TimerWheel::with_defaults();
65    /// let callback = Some(CallbackWrapper::new(|| async {
66    ///     println!("Timer fired!");
67    /// }));
68    /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
69    /// let handle = timer.register(task);
70    /// 
71    /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
72    /// handle.into_completion_receiver().0.await.ok();
73    /// println!("Timer completed!");
74    /// # }
75    /// ```
76    pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
77        &mut self.completion_rx
78    }
79
80    /// 消耗句柄,返回完成通知接收器
81    ///
82    /// # 示例
83    /// ```no_run
84    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
85    /// # use std::time::Duration;
86    /// # 
87    /// # #[tokio::main]
88    /// # async fn main() {
89    /// let timer = TimerWheel::with_defaults();
90    /// let callback = Some(CallbackWrapper::new(|| async {
91    ///     println!("Timer fired!");
92    /// }));
93    /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
94    /// let handle = timer.register(task);
95    /// 
96    /// // 等待定时器完成
97    /// handle.into_completion_receiver().0.await.ok();
98    /// println!("Timer completed!");
99    /// # }
100    /// ```
101    pub fn into_completion_receiver(self) -> CompletionReceiver {
102        self.completion_rx
103    }
104}
105
106/// 批量定时器句柄,用于管理批量调度的定时器
107/// 
108/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
109/// 
110/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
111/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
112pub struct BatchHandle {
113    pub(crate) task_ids: Vec<TaskId>,
114    pub(crate) wheel: Arc<Mutex<Wheel>>,
115    pub(crate) completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>,
116}
117
118impl BatchHandle {
119    pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>) -> Self {
120        Self { task_ids, wheel, completion_rxs }
121    }
122
123    /// 批量取消所有定时器
124    ///
125    /// # 返回
126    /// 成功取消的任务数量
127    ///
128    /// # 示例
129    /// ```no_run
130    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
131    /// # use std::time::Duration;
132    /// # 
133    /// # #[tokio::main]
134    /// # async fn main() {
135    /// let timer = TimerWheel::with_defaults();
136    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..10)
137    ///     .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
138    ///     .collect();
139    /// let tasks = TimerWheel::create_batch(callbacks);
140    /// let batch = timer.register_batch(tasks);
141    /// 
142    /// let cancelled = batch.cancel_all();
143    /// println!("取消了 {} 个定时器", cancelled);
144    /// # }
145    /// ```
146    pub fn cancel_all(self) -> usize {
147        let mut wheel = self.wheel.lock();
148        wheel.cancel_batch(&self.task_ids)
149    }
150
151    /// 将批量句柄转换为单个定时器句柄的 Vec
152    ///
153    /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
154    ///
155    /// # 示例
156    /// ```no_run
157    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
158    /// # use std::time::Duration;
159    /// # 
160    /// # #[tokio::main]
161    /// # async fn main() {
162    /// let timer = TimerWheel::with_defaults();
163    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
164    ///     .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
165    ///     .collect();
166    /// let tasks = TimerWheel::create_batch(callbacks);
167    /// let batch = timer.register_batch(tasks);
168    /// 
169    /// // 转换为独立的句柄
170    /// let handles = batch.into_handles();
171    /// for handle in handles {
172    ///     // 可以单独操作每个句柄
173    /// }
174    /// # }
175    /// ```
176    pub fn into_handles(self) -> Vec<TimerHandle> {
177        self.task_ids
178            .into_iter()
179            .zip(self.completion_rxs.into_iter())
180            .map(|(task_id, rx)| {
181                TimerHandle::new(task_id, self.wheel.clone(), rx)
182            })
183            .collect()
184    }
185
186    /// 获取批量任务的数量
187    pub fn len(&self) -> usize {
188        self.task_ids.len()
189    }
190
191    /// 检查批量任务是否为空
192    pub fn is_empty(&self) -> bool {
193        self.task_ids.is_empty()
194    }
195
196    /// 获取所有任务 ID 的引用
197    pub fn task_ids(&self) -> &[TaskId] {
198        &self.task_ids
199    }
200
201    /// 获取所有完成通知接收器的引用
202    ///
203    /// # 返回
204    /// 所有任务的完成通知接收器列表引用
205    pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<TaskCompletionReason>> {
206        &mut self.completion_rxs
207    }
208
209    /// 消耗句柄,返回所有完成通知接收器
210    ///
211    /// # 返回
212    /// 所有任务的完成通知接收器列表
213    ///
214    /// # 示例
215    /// ```no_run
216    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
217    /// # use std::time::Duration;
218    /// # 
219    /// # #[tokio::main]
220    /// # async fn main() {
221    /// let timer = TimerWheel::with_defaults();
222    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
223    ///     .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
224    ///     .collect();
225    /// let tasks = TimerWheel::create_batch(callbacks);
226    /// let batch = timer.register_batch(tasks);
227    /// 
228    /// // 获取所有完成通知接收器
229    /// let receivers = batch.into_completion_receivers();
230    /// for rx in receivers {
231    ///     tokio::spawn(async move {
232    ///         if rx.await.is_ok() {
233    ///             println!("A timer completed!");
234    ///         }
235    ///     });
236    /// }
237    /// # }
238    /// ```
239    pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<TaskCompletionReason>> {
240        self.completion_rxs
241    }
242}
243
244/// 实现 IntoIterator,允许直接迭代 BatchHandle
245/// 
246/// # 示例
247/// ```no_run
248/// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
249/// # use std::time::Duration;
250/// # 
251/// # #[tokio::main]
252/// # async fn main() {
253/// let timer = TimerWheel::with_defaults();
254/// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
255///     .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
256///     .collect();
257/// let tasks = TimerWheel::create_batch(callbacks);
258/// let batch = timer.register_batch(tasks);
259/// 
260/// // 直接迭代,每个元素都是独立的 TimerHandle
261/// for handle in batch {
262///     // 可以单独操作每个句柄
263/// }
264/// # }
265/// ```
266impl IntoIterator for BatchHandle {
267    type Item = TimerHandle;
268    type IntoIter = BatchHandleIter;
269
270    fn into_iter(self) -> Self::IntoIter {
271        BatchHandleIter {
272            task_ids: self.task_ids.into_iter(),
273            completion_rxs: self.completion_rxs.into_iter(),
274            wheel: self.wheel,
275        }
276    }
277}
278
279/// BatchHandle 的迭代器
280pub struct BatchHandleIter {
281    task_ids: std::vec::IntoIter<TaskId>,
282    completion_rxs: std::vec::IntoIter<oneshot::Receiver<TaskCompletionReason>>,
283    wheel: Arc<Mutex<Wheel>>,
284}
285
286impl Iterator for BatchHandleIter {
287    type Item = TimerHandle;
288
289    fn next(&mut self) -> Option<Self::Item> {
290        match (self.task_ids.next(), self.completion_rxs.next()) {
291            (Some(task_id), Some(rx)) => {
292                Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
293            }
294            _ => None,
295        }
296    }
297
298    fn size_hint(&self) -> (usize, Option<usize>) {
299        self.task_ids.size_hint()
300    }
301}
302
303impl ExactSizeIterator for BatchHandleIter {
304    fn len(&self) -> usize {
305        self.task_ids.len()
306    }
307}
308
309/// 时间轮定时器管理器
310pub struct TimerWheel {
311    /// 时间轮唯一标识符
312    
313    /// 时间轮实例(使用 Arc<Mutex> 包装以支持多线程访问)
314    wheel: Arc<Mutex<Wheel>>,
315    
316    /// 后台 tick 循环任务句柄
317    tick_handle: Option<JoinHandle<()>>,
318}
319
320impl TimerWheel {
321    /// 创建新的定时器管理器
322    ///
323    /// # 参数
324    /// - `config`: 时间轮配置(已经过验证)
325    ///
326    /// # 示例
327    /// ```no_run
328    /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask};
329    /// use std::time::Duration;
330    ///
331    /// #[tokio::main]
332    /// async fn main() {
333    ///     let config = WheelConfig::builder()
334    ///         .tick_duration(Duration::from_millis(10))
335    ///         .slot_count(512)
336    ///         .build()
337    ///         .unwrap();
338    ///     let timer = TimerWheel::new(config);
339    ///     
340    ///     // 使用两步式 API
341    ///     let task = TimerWheel::create_task(Duration::from_secs(1), None);
342    ///     let handle = timer.register(task);
343    /// }
344    /// ```
345    pub fn new(config: WheelConfig) -> Self {
346        let tick_duration = config.tick_duration;
347        let wheel = Wheel::new(config);
348        let wheel = Arc::new(Mutex::new(wheel));
349        let wheel_clone = wheel.clone();
350
351        // 启动后台 tick 循环
352        let tick_handle = tokio::spawn(async move {
353            Self::tick_loop(wheel_clone, tick_duration).await;
354        });
355
356        Self {
357            wheel,
358            tick_handle: Some(tick_handle),
359        }
360    }
361
362    /// 创建带默认配置的定时器管理器
363    /// - tick 时长: 10ms
364    /// - 槽位数量: 512
365    ///
366    /// # 示例
367    /// ```no_run
368    /// use kestrel_protocol_timer::TimerWheel;
369    ///
370    /// #[tokio::main]
371    /// async fn main() {
372    ///     let timer = TimerWheel::with_defaults();
373    /// }
374    /// ```
375    pub fn with_defaults() -> Self {
376        Self::new(WheelConfig::default())
377    }
378
379    /// 创建与此时间轮绑定的 TimerService(使用默认配置)
380    ///
381    /// # 返回
382    /// 绑定到此时间轮的 TimerService 实例
383    ///
384    /// # 示例
385    /// ```no_run
386    /// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
387    /// use std::time::Duration;
388    /// 
389    ///
390    /// #[tokio::main]
391    /// async fn main() {
392    ///     let timer = TimerWheel::with_defaults();
393    ///     let mut service = timer.create_service();
394    ///     
395    ///     // 使用两步式 API 通过 service 批量调度定时器
396    ///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
397    ///         .map(|_| (Duration::from_millis(100), Some(CallbackWrapper::new(|| async {}))))
398    ///         .collect();
399    ///     let tasks = TimerService::create_batch(callbacks);
400    ///     service.register_batch(tasks).unwrap();
401    ///     
402    ///     // 接收超时通知
403    ///     let mut rx = service.take_receiver().unwrap();
404    ///     while let Some(task_id) = rx.recv().await {
405    ///         println!("Task {:?} completed", task_id);
406    ///     }
407    /// }
408    /// ```
409    pub fn create_service(&self) -> crate::service::TimerService {
410        crate::service::TimerService::new(self.wheel.clone(), ServiceConfig::default())
411    }
412    
413    /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
414    ///
415    /// # 参数
416    /// - `config`: 服务配置
417    ///
418    /// # 返回
419    /// 绑定到此时间轮的 TimerService 实例
420    ///
421    /// # 示例
422    /// ```no_run
423    /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
424    ///
425    /// #[tokio::main]
426    /// async fn main() {
427    ///     let timer = TimerWheel::with_defaults();
428    ///     let config = ServiceConfig::builder()
429    ///         .command_channel_capacity(1024)
430    ///         .timeout_channel_capacity(2000)
431    ///         .build()
432    ///         .unwrap();
433    ///     let service = timer.create_service_with_config(config);
434    /// }
435    /// ```
436    pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
437        crate::service::TimerService::new(self.wheel.clone(), config)
438    }
439
440    /// 创建定时器任务(申请阶段)
441    /// 
442    /// # 参数
443    /// - `delay`: 延迟时间
444    /// - `callback`: 实现了 TimerCallback trait 的回调对象
445    /// 
446    /// # 返回
447    /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
448    /// 
449    /// # 示例
450    /// ```no_run
451    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
452    /// use std::time::Duration;
453    /// 
454    /// 
455    /// #[tokio::main]
456    /// async fn main() {
457    ///     let timer = TimerWheel::with_defaults();
458    ///     
459    ///     // 步骤 1: 创建任务
460    ///     let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
461    ///         println!("Timer fired!");
462    ///     })));
463    ///     
464    ///     // 获取任务 ID
465    ///     let task_id = task.get_id();
466    ///     println!("Created task: {:?}", task_id);
467    ///     
468    ///     // 步骤 2: 注册任务
469    ///     let handle = timer.register(task);
470    /// }
471    /// ```
472    #[inline]
473    pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
474        crate::task::TimerTask::new(delay, callback)
475    }
476    
477    /// 批量创建定时器任务(申请阶段)
478    /// 
479    /// # 参数
480    /// - `callbacks`: (延迟时间, 回调) 的元组列表
481    /// 
482    /// # 返回
483    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
484    /// 
485    /// # 示例
486    /// ```no_run
487    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
488    /// use std::time::Duration;
489    /// use std::sync::Arc;
490    /// use std::sync::atomic::{AtomicU32, Ordering};
491    /// 
492    /// #[tokio::main]
493    /// async fn main() {
494    ///     let timer = TimerWheel::with_defaults();
495    ///     let counter = Arc::new(AtomicU32::new(0));
496    ///     
497    ///     // 步骤 1: 批量创建任务
498    ///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
499    ///         .map(|i| {
500    ///             let counter = Arc::clone(&counter);
501    ///             let delay = Duration::from_millis(100 + i * 100);
502    ///             let callback = Some(CallbackWrapper::new(move || {
503    ///                 let counter = Arc::clone(&counter);
504    ///                 async move {
505    ///                     counter.fetch_add(1, Ordering::SeqCst);
506    ///                 }
507    ///             }));
508    ///             (delay, callback)
509    ///         })
510    ///         .collect();
511    ///     
512    ///     let tasks = TimerWheel::create_batch(callbacks);
513    ///     println!("Created {} tasks", tasks.len());
514    ///     
515    ///     // 步骤 2: 批量注册任务
516    ///     let batch = timer.register_batch(tasks);
517    /// }
518    /// ```
519    #[inline]
520    pub fn create_batch(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask>
521    {
522        callbacks
523            .into_iter()
524            .map(|(delay, callback)| crate::task::TimerTask::new(delay, callback))
525            .collect()
526    }
527    
528    /// 注册定时器任务到时间轮(注册阶段)
529    /// 
530    /// # 参数
531    /// - `task`: 通过 `create_task()` 创建的任务
532    /// 
533    /// # 返回
534    /// 返回定时器句柄,可用于取消定时器和接收完成通知
535    /// 
536    /// # 示例
537    /// ```no_run
538    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
539    /// 
540    /// use std::time::Duration;
541    /// 
542    /// #[tokio::main]
543    /// async fn main() {
544    ///     let timer = TimerWheel::with_defaults();
545    ///     
546    ///     let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
547    ///         println!("Timer fired!");
548    ///     })));
549    ///     let task_id = task.get_id();
550    ///     
551    ///     let handle = timer.register(task);
552    ///     
553    ///     // 等待定时器完成
554    ///     handle.into_completion_receiver().0.await.ok();
555    /// }
556    /// ```
557    #[inline]
558    pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
559        let (completion_tx, completion_rx) = oneshot::channel();
560        let notifier = crate::task::CompletionNotifier(completion_tx);
561        
562        let delay = task.delay;
563        let task_id = task.id;
564        
565        // 单次加锁完成所有操作
566        {
567            let mut wheel_guard = self.wheel.lock();
568            wheel_guard.insert(delay, task, notifier);
569        }
570        
571        TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
572    }
573    
574    /// 批量注册定时器任务到时间轮(注册阶段)
575    /// 
576    /// # 参数
577    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
578    /// 
579    /// # 返回
580    /// 返回批量定时器句柄
581    /// 
582    /// # 示例
583    /// ```no_run
584    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
585    /// use std::time::Duration;
586    /// 
587    /// #[tokio::main]
588    /// async fn main() {
589    ///     let timer = TimerWheel::with_defaults();
590    ///     
591    ///     let callbacks: Vec<_> = (0..3)
592    ///         .map(|_| (Duration::from_secs(1), None))
593    ///         .collect();
594    ///     let tasks = TimerWheel::create_batch(callbacks);
595    ///     
596    ///     let batch = timer.register_batch(tasks);
597    ///     println!("Registered {} timers", batch.len());
598    /// }
599    /// ```
600    #[inline]
601    pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
602        let task_count = tasks.len();
603        let mut completion_rxs = Vec::with_capacity(task_count);
604        let mut task_ids = Vec::with_capacity(task_count);
605        let mut prepared_tasks = Vec::with_capacity(task_count);
606        
607        // 步骤1: 准备所有 channels 和 notifiers(无锁)
608        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
609        for task in tasks {
610            let (completion_tx, completion_rx) = oneshot::channel();
611            let notifier = crate::task::CompletionNotifier(completion_tx);
612            
613            task_ids.push(task.id);
614            completion_rxs.push(completion_rx);
615            prepared_tasks.push((task.delay, task, notifier));
616        }
617        
618        // 步骤2: 单次加锁,批量插入
619        {
620            let mut wheel_guard = self.wheel.lock();
621            wheel_guard.insert_batch(prepared_tasks);
622        }
623        
624        BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
625    }
626
627    /// 取消定时器
628    ///
629    /// # 参数
630    /// - `task_id`: 任务 ID
631    ///
632    /// # 返回
633    /// 如果任务存在且成功取消返回 true,否则返回 false
634    /// 
635    /// # 示例
636    /// ```no_run
637    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
638    /// 
639    /// use std::time::Duration;
640    ///
641    /// #[tokio::main]
642    /// async fn main() {
643    ///     let timer = TimerWheel::with_defaults();
644    ///     
645    ///     let task = TimerWheel::create_task(Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
646    ///         println!("Timer fired!");
647    ///     })));
648    ///     let task_id = task.get_id();
649    ///     let _handle = timer.register(task);
650    ///     
651    ///     // 使用任务 ID 取消
652    ///     let cancelled = timer.cancel(task_id);
653    ///     println!("取消成功: {}", cancelled);
654    /// }
655    /// ```
656    #[inline]
657    pub fn cancel(&self, task_id: TaskId) -> bool {
658        let mut wheel = self.wheel.lock();
659        wheel.cancel(task_id)
660    }
661
662    /// 批量取消定时器
663    ///
664    /// # 参数
665    /// - `task_ids`: 要取消的任务 ID 列表
666    ///
667    /// # 返回
668    /// 成功取消的任务数量
669    ///
670    /// # 性能优势
671    /// - 批量处理减少锁竞争
672    /// - 内部优化批量取消操作
673    ///
674    /// # 示例
675    /// ```no_run
676    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
677    /// use std::time::Duration;
678    ///
679    /// #[tokio::main]
680    /// async fn main() {
681    ///     let timer = TimerWheel::with_defaults();
682    ///     
683    ///     // 创建多个定时器
684    ///     let task1 = TimerWheel::create_task(Duration::from_secs(10), None);
685    ///     let task2 = TimerWheel::create_task(Duration::from_secs(10), None);
686    ///     let task3 = TimerWheel::create_task(Duration::from_secs(10), None);
687    ///     
688    ///     let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
689    ///     
690    ///     let _h1 = timer.register(task1);
691    ///     let _h2 = timer.register(task2);
692    ///     let _h3 = timer.register(task3);
693    ///     
694    ///     // 批量取消
695    ///     let cancelled = timer.cancel_batch(&task_ids);
696    ///     println!("已取消 {} 个定时器", cancelled);
697    /// }
698    /// ```
699    #[inline]
700    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
701        let mut wheel = self.wheel.lock();
702        wheel.cancel_batch(task_ids)
703    }
704
705    /// 推迟定时器
706    ///
707    /// # 参数
708    /// - `task_id`: 要推迟的任务 ID
709    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
710    /// - `callback`: 新的回调函数,传入 `None` 保持原回调不变,传入 `Some` 替换为新回调
711    ///
712    /// # 返回
713    /// 如果任务存在且成功推迟返回 true,否则返回 false
714    ///
715    /// # 注意
716    /// - 推迟后任务 ID 保持不变
717    /// - 原有的 completion_receiver 仍然有效
718    ///
719    /// # 示例
720    ///
721    /// ## 保持原回调
722    /// ```no_run
723    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
724    /// use std::time::Duration;
725    /// 
726    ///
727    /// #[tokio::main]
728    /// async fn main() {
729    ///     let timer = TimerWheel::with_defaults();
730    ///     
731    ///     let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
732    ///         println!("Timer fired!");
733    ///     })));
734    ///     let task_id = task.get_id();
735    ///     let _handle = timer.register(task);
736    ///     
737    ///     // 推迟到 10 秒后触发(保持原回调)
738    ///     let success = timer.postpone(task_id, Duration::from_secs(10), None);
739    ///     println!("推迟成功: {}", success);
740    /// }
741    /// ```
742    ///
743    /// ## 替换为新回调
744    /// ```no_run
745    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
746    /// use std::time::Duration;
747    ///
748    /// #[tokio::main]
749    /// async fn main() {
750    ///     let timer = TimerWheel::with_defaults();
751    ///     
752    ///     let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
753    ///         println!("Original callback!");
754    ///     })));
755    ///     let task_id = task.get_id();
756    ///     let _handle = timer.register(task);
757    ///     
758    ///     // 推迟到 10 秒后触发(并替换为新回调)
759    ///     let success = timer.postpone(task_id, Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
760    ///         println!("New callback!");
761    ///     })));
762    ///     println!("推迟成功: {}", success);
763    /// }
764    /// ```
765    #[inline]
766    pub fn postpone(
767        &self,
768        task_id: TaskId,
769        new_delay: Duration,
770        callback: Option<CallbackWrapper>,
771    ) -> bool {
772        let mut wheel = self.wheel.lock();
773        wheel.postpone(task_id, new_delay, callback)
774    }
775
776    /// 批量推迟定时器(保持原回调)
777    ///
778    /// # 参数
779    /// - `updates`: (任务ID, 新延迟) 的元组列表
780    ///
781    /// # 返回
782    /// 成功推迟的任务数量
783    ///
784    /// # 注意
785    /// - 此方法会保持所有任务的原回调不变
786    /// - 如需替换回调,请使用 `postpone_batch_with_callbacks`
787    ///
788    /// # 性能优势
789    /// - 批量处理减少锁竞争
790    /// - 内部优化批量推迟操作
791    ///
792    /// # 示例
793    /// ```no_run
794    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
795    /// use std::time::Duration;
796    ///
797    /// #[tokio::main]
798    /// async fn main() {
799    ///     let timer = TimerWheel::with_defaults();
800    ///     
801    ///     // 创建多个带回调的定时器
802    ///     let task1 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
803    ///         println!("Task 1 fired!");
804    ///     })));
805    ///     let task2 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
806    ///         println!("Task 2 fired!");
807    ///     })));
808    ///     let task3 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
809    ///         println!("Task 3 fired!");
810    ///     })));
811    ///     
812    ///     let task_ids = vec![
813    ///         (task1.get_id(), Duration::from_secs(10)),
814    ///         (task2.get_id(), Duration::from_secs(15)),
815    ///         (task3.get_id(), Duration::from_secs(20)),
816    ///     ];
817    ///     
818    ///     timer.register(task1);
819    ///     timer.register(task2);
820    ///     timer.register(task3);
821    ///     
822    ///     // 批量推迟(保持原回调)
823    ///     let postponed = timer.postpone_batch(&task_ids);
824    ///     println!("已推迟 {} 个定时器", postponed);
825    /// }
826    /// ```
827    #[inline]
828    pub fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
829        let mut wheel = self.wheel.lock();
830        wheel.postpone_batch(updates.to_vec())
831    }
832
833    /// 批量推迟定时器(替换回调)
834    ///
835    /// # 参数
836    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
837    ///
838    /// # 返回
839    /// 成功推迟的任务数量
840    ///
841    /// # 性能优势
842    /// - 批量处理减少锁竞争
843    /// - 内部优化批量推迟操作
844    ///
845    /// # 示例
846    /// ```no_run
847    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
848    /// use std::time::Duration;
849    /// use std::sync::Arc;
850    /// use std::sync::atomic::{AtomicU32, Ordering};
851    ///
852    /// #[tokio::main]
853    /// async fn main() {
854    ///     let timer = TimerWheel::with_defaults();
855    ///     let counter = Arc::new(AtomicU32::new(0));
856    ///     
857    ///     // 创建多个定时器
858    ///     let task1 = TimerWheel::create_task(Duration::from_secs(5), None);
859    ///     let task2 = TimerWheel::create_task(Duration::from_secs(5), None);
860    ///     
861    ///     let id1 = task1.get_id();
862    ///     let id2 = task2.get_id();
863    ///     
864    ///     timer.register(task1);
865    ///     timer.register(task2);
866    ///     
867    ///     // 批量推迟并替换回调
868    ///     let updates: Vec<_> = vec![id1, id2]
869    ///         .into_iter()
870    ///         .map(|id| {
871    ///             let counter = Arc::clone(&counter);
872    ///             (id, Duration::from_secs(10), Some(CallbackWrapper::new(move || {
873    ///                 let counter = Arc::clone(&counter);
874    ///                 async move { counter.fetch_add(1, Ordering::SeqCst); }
875    ///             })))
876    ///         })
877    ///         .collect();
878    ///     let postponed = timer.postpone_batch_with_callbacks(updates);
879    ///     println!("已推迟 {} 个定时器", postponed);
880    /// }
881    /// ```
882    #[inline]
883    pub fn postpone_batch_with_callbacks(
884        &self,
885        updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
886    ) -> usize {
887        let mut wheel = self.wheel.lock();
888        wheel.postpone_batch_with_callbacks(updates.to_vec())
889    }
890    
891    /// 核心 tick 循环
892    async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
893        let mut interval = tokio::time::interval(tick_duration);
894        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
895
896        loop {
897            interval.tick().await;
898
899            // 推进时间轮并获取到期任务
900            let expired_tasks = {
901                let mut wheel_guard = wheel.lock();
902                wheel_guard.advance()
903            };
904
905            // 执行到期任务
906            for task in expired_tasks {
907                let callback = task.get_callback();
908                
909                // 移动task的所有权来获取completion_notifier
910                let notifier = task.completion_notifier;
911                
912                // 只有注册过的任务才有 notifier
913                if let Some(notifier) = notifier {
914                    // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
915                    if let Some(callback) = callback {
916                        tokio::spawn(async move {
917                            // 执行回调
918                            let future = callback.call();
919                            future.await;
920                            
921                            // 回调执行完成后发送通知
922                            let _ = notifier.0.send(TaskCompletionReason::Expired);
923                        });
924                    } else {
925                        // 如果没有回调,立即发送完成通知
926                        let _ = notifier.0.send(TaskCompletionReason::Expired);
927                    }
928                }
929            }
930        }
931    }
932
933    /// 停止定时器管理器
934    pub async fn shutdown(mut self) {
935        if let Some(handle) = self.tick_handle.take() {
936            handle.abort();
937            let _ = handle.await;
938        }
939    }
940}
941
942impl Drop for TimerWheel {
943    fn drop(&mut self) {
944        if let Some(handle) = self.tick_handle.take() {
945            handle.abort();
946        }
947    }
948}
949
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    // `Arc` is also reachable through `super::*`; importing it once here
    // replaces the per-test `use std::sync::Arc;` lines.
    use std::sync::Arc;

    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    #[tokio::test]
    async fn test_schedule_once() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let _handle = timer.register(task);

        // Wait past the 50ms deadline so the timer fires.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let handle = timer.register(task);

        // Cancel right away, well before the 100ms deadline.
        let cancel_result = handle.cancel();
        assert!(cancel_result);

        // Wait well past the original deadline to make sure the callback never runs.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_cancel_immediate() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let handle = timer.register(task);

        // Cancel immediately: the first call removes the pending task.
        assert!(handle.cancel());
        // Unlike `test_cancel_timer`, also check that cancellation is not
        // repeatable: `cancel` returns false once the task no longer exists.
        assert!(
            !handle.cancel(),
            "cancelling an already-cancelled timer should return false"
        );

        // Wait well past the original deadline to make sure the callback never runs.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_postpone_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task to 150ms from now.
        let postponed = timer.postpone(task_id, Duration::from_millis(150), None);
        assert!(postponed);

        // Past the original 50ms deadline the task must not have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Roughly 80ms of the postponed delay remain; allow generous slack.
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        // `Ok(Err(_))` would mean the notifier was dropped without sending, so
        // require an actual completion value, not merely "no timeout".
        assert!(
            matches!(result, Ok(Ok(_))),
            "expected a completion notification for the postponed timer"
        );

        // Give the spawned callback time to run.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_postpone_with_callback() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone1 = Arc::clone(&counter);
        let counter_clone2 = Arc::clone(&counter);

        // Create a task whose original callback adds 1.
        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone1);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone and replace the callback with one that adds 10.
        let postponed = timer.postpone(
            task_id,
            Duration::from_millis(100),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone2);
                async move {
                    counter.fetch_add(10, Ordering::SeqCst);
                }
            })),
        );
        assert!(postponed);

        // The postponed task fires after ~100ms; allow generous slack.
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        // Require a real completion value; a dropped notifier yields `Ok(Err(_))`.
        assert!(
            matches!(result, Ok(Ok(_))),
            "expected a completion notification for the postponed timer"
        );

        // Give the spawned callback time to run.
        tokio::time::sleep(Duration::from_millis(20)).await;

        // Only the replacement callback ran (+10, not +1).
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_timer() {
        let timer = TimerWheel::with_defaults();

        // Build a task but never register it, so its id is unknown to the wheel.
        let fake_task = TimerWheel::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();

        let postponed = timer.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed);
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create 3 tasks, each adding 1 when it fires.
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            let task = TimerWheel::create_task(
                Duration::from_millis(50),
                Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })),
            );
            task_ids.push((task.get_id(), Duration::from_millis(150)));
            timer.register(task);
        }

        // Postpone all of them in one call.
        let postponed = timer.postpone_batch(&task_ids);
        assert_eq!(postponed, 3);

        // Past the original 50ms deadline none of the tasks may have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Roughly 80ms of the postponed delay remain; wait well past that.
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Extra slack for the spawned callbacks to finish.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create 3 tasks without callbacks.
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerWheel::create_task(Duration::from_millis(50), None);
            task_ids.push(task.get_id());
            timer.register(task);
        }

        // Postpone all of them and attach a counting callback to each.
        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| {
                let counter_clone = Arc::clone(&counter);
                (
                    id,
                    Duration::from_millis(150),
                    Some(CallbackWrapper::new(move || {
                        let counter = Arc::clone(&counter_clone);
                        async move {
                            counter.fetch_add(1, Ordering::SeqCst);
                        }
                    })),
                )
            })
            .collect();

        let postponed = timer.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3);

        // Past the original 50ms deadline none of the tasks may have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Roughly 80ms of the postponed delay remain; wait well past that.
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Extra slack for the spawned callbacks to finish.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_keeps_completion_receiver_valid() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task; a freshly registered task must be found.
        assert!(timer.postpone(task_id, Duration::from_millis(100), None));

        // The receiver obtained before postponing must still receive the
        // completion (~100ms delay plus slack). Checking only that the timeout
        // did not elapse would also accept a dropped sender (`Ok(Err(_))`).
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(
            matches!(result, Ok(Ok(_))),
            "Completion receiver should still work after postpone"
        );

        // Give the spawned callback time to run.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
1256