kestrel_protocol_timer/
timer.rs

1use crate::config::{BatchConfig, ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TaskCompletionReason};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
10/// 完成通知接收器,用于接收定时器完成通知
11pub struct CompletionReceiver(pub oneshot::Receiver<TaskCompletionReason>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14/// 
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18    pub(crate) task_id: TaskId,
19    pub(crate) wheel: Arc<Mutex<Wheel>>,
20    pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24    pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<TaskCompletionReason>) -> Self {
25        Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26    }
27
28    /// 取消定时器
29    ///
30    /// # 返回
31    /// 如果任务存在且成功取消返回 true,否则返回 false
32    ///
33    /// # 示例
34    /// ```no_run
35    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
36    /// # use std::time::Duration;
37    /// # 
38    /// # #[tokio::main]
39    /// # async fn main() {
40    /// let timer = TimerWheel::with_defaults();
41    /// let callback = Some(CallbackWrapper::new(|| async {}));
42    /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
43    /// let handle = timer.register(task);
44    /// 
45    /// // 取消定时器
46    /// let success = handle.cancel();
47    /// println!("取消成功: {}", success);
48    /// # }
49    /// ```
50    pub fn cancel(&self) -> bool {
51        let mut wheel = self.wheel.lock();
52        wheel.cancel(self.task_id)
53    }
54
55    /// 获取完成通知接收器的可变引用
56    ///
57    /// # 示例
58    /// ```no_run
59    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
60    /// # use std::time::Duration;
61    /// # 
62    /// # #[tokio::main]
63    /// # async fn main() {
64    /// let timer = TimerWheel::with_defaults();
65    /// let callback = Some(CallbackWrapper::new(|| async {
66    ///     println!("Timer fired!");
67    /// }));
68    /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
69    /// let handle = timer.register(task);
70    /// 
71    /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
72    /// handle.into_completion_receiver().0.await.ok();
73    /// println!("Timer completed!");
74    /// # }
75    /// ```
76    pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
77        &mut self.completion_rx
78    }
79
80    /// 消耗句柄,返回完成通知接收器
81    ///
82    /// # 示例
83    /// ```no_run
84    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
85    /// # use std::time::Duration;
86    /// # 
87    /// # #[tokio::main]
88    /// # async fn main() {
89    /// let timer = TimerWheel::with_defaults();
90    /// let callback = Some(CallbackWrapper::new(|| async {
91    ///     println!("Timer fired!");
92    /// }));
93    /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
94    /// let handle = timer.register(task);
95    /// 
96    /// // 等待定时器完成
97    /// handle.into_completion_receiver().0.await.ok();
98    /// println!("Timer completed!");
99    /// # }
100    /// ```
101    pub fn into_completion_receiver(self) -> CompletionReceiver {
102        self.completion_rx
103    }
104}
105
106/// 批量定时器句柄,用于管理批量调度的定时器
107/// 
108/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
109/// 
110/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
111/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
112pub struct BatchHandle {
113    pub(crate) task_ids: Vec<TaskId>,
114    pub(crate) wheel: Arc<Mutex<Wheel>>,
115    pub(crate) completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>,
116}
117
118impl BatchHandle {
119    pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>) -> Self {
120        Self { task_ids, wheel, completion_rxs }
121    }
122
123    /// 批量取消所有定时器
124    ///
125    /// # 返回
126    /// 成功取消的任务数量
127    ///
128    /// # 示例
129    /// ```no_run
130    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
131    /// # use std::time::Duration;
132    /// # 
133    /// # #[tokio::main]
134    /// # async fn main() {
135    /// let timer = TimerWheel::with_defaults();
136    /// let delays: Vec<Duration> = (0..10)
137    ///     .map(|_| Duration::from_secs(1))
138    ///     .collect();
139    /// let tasks = TimerWheel::create_batch(delays);
140    /// let batch = timer.register_batch(tasks);
141    /// 
142    /// let cancelled = batch.cancel_all();
143    /// println!("取消了 {} 个定时器", cancelled);
144    /// # }
145    /// ```
146    pub fn cancel_all(self) -> usize {
147        let mut wheel = self.wheel.lock();
148        wheel.cancel_batch(&self.task_ids)
149    }
150
151    /// 将批量句柄转换为单个定时器句柄的 Vec
152    ///
153    /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
154    ///
155    /// # 示例
156    /// ```no_run
157    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
158    /// # use std::time::Duration;
159    /// # 
160    /// # #[tokio::main]
161    /// # async fn main() {
162    /// let timer = TimerWheel::with_defaults();
163    /// let delays: Vec<Duration> = (0..3)
164    ///     .map(|_| Duration::from_secs(1))
165    ///     .collect();
166    /// let tasks = TimerWheel::create_batch(delays);
167    /// let batch = timer.register_batch(tasks);
168    /// 
169    /// // 转换为独立的句柄
170    /// let handles = batch.into_handles();
171    /// for handle in handles {
172    ///     // 可以单独操作每个句柄
173    /// }
174    /// # }
175    /// ```
176    pub fn into_handles(self) -> Vec<TimerHandle> {
177        self.task_ids
178            .into_iter()
179            .zip(self.completion_rxs.into_iter())
180            .map(|(task_id, rx)| {
181                TimerHandle::new(task_id, self.wheel.clone(), rx)
182            })
183            .collect()
184    }
185
186    /// 获取批量任务的数量
187    pub fn len(&self) -> usize {
188        self.task_ids.len()
189    }
190
191    /// 检查批量任务是否为空
192    pub fn is_empty(&self) -> bool {
193        self.task_ids.is_empty()
194    }
195
196    /// 获取所有任务 ID 的引用
197    pub fn task_ids(&self) -> &[TaskId] {
198        &self.task_ids
199    }
200
201    /// 获取所有完成通知接收器的引用
202    ///
203    /// # 返回
204    /// 所有任务的完成通知接收器列表引用
205    pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<TaskCompletionReason>> {
206        &mut self.completion_rxs
207    }
208
209    /// 消耗句柄,返回所有完成通知接收器
210    ///
211    /// # 返回
212    /// 所有任务的完成通知接收器列表
213    ///
214    /// # 示例
215    /// ```no_run
216    /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
217    /// # use std::time::Duration;
218    /// # 
219    /// # #[tokio::main]
220    /// # async fn main() {
221    /// let timer = TimerWheel::with_defaults();
222    /// let delays: Vec<Duration> = (0..3)
223    ///     .map(|_| Duration::from_secs(1))
224    ///     .collect();
225    /// let tasks = TimerWheel::create_batch(delays);
226    /// let batch = timer.register_batch(tasks);
227    /// 
228    /// // 获取所有完成通知接收器
229    /// let receivers = batch.into_completion_receivers();
230    /// for rx in receivers {
231    ///     tokio::spawn(async move {
232    ///         if rx.await.is_ok() {
233    ///             println!("A timer completed!");
234    ///         }
235    ///     });
236    /// }
237    /// # }
238    /// ```
239    pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<TaskCompletionReason>> {
240        self.completion_rxs
241    }
242}
243
244/// 实现 IntoIterator,允许直接迭代 BatchHandle
245/// 
246/// # 示例
247/// ```no_run
248/// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
249/// # use std::time::Duration;
250/// # 
251/// # #[tokio::main]
252/// # async fn main() {
253/// let timer = TimerWheel::with_defaults();
254/// let delays: Vec<Duration> = (0..3)
255///     .map(|_| Duration::from_secs(1))
256///     .collect();
257/// let tasks = TimerWheel::create_batch(delays);
258/// let batch = timer.register_batch(tasks);
259/// 
260/// // 直接迭代,每个元素都是独立的 TimerHandle
261/// for handle in batch {
262///     // 可以单独操作每个句柄
263/// }
264/// # }
265/// ```
266impl IntoIterator for BatchHandle {
267    type Item = TimerHandle;
268    type IntoIter = BatchHandleIter;
269
270    fn into_iter(self) -> Self::IntoIter {
271        BatchHandleIter {
272            task_ids: self.task_ids.into_iter(),
273            completion_rxs: self.completion_rxs.into_iter(),
274            wheel: self.wheel,
275        }
276    }
277}
278
279/// BatchHandle 的迭代器
280pub struct BatchHandleIter {
281    task_ids: std::vec::IntoIter<TaskId>,
282    completion_rxs: std::vec::IntoIter<oneshot::Receiver<TaskCompletionReason>>,
283    wheel: Arc<Mutex<Wheel>>,
284}
285
286impl Iterator for BatchHandleIter {
287    type Item = TimerHandle;
288
289    fn next(&mut self) -> Option<Self::Item> {
290        match (self.task_ids.next(), self.completion_rxs.next()) {
291            (Some(task_id), Some(rx)) => {
292                Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
293            }
294            _ => None,
295        }
296    }
297
298    fn size_hint(&self) -> (usize, Option<usize>) {
299        self.task_ids.size_hint()
300    }
301}
302
303impl ExactSizeIterator for BatchHandleIter {
304    fn len(&self) -> usize {
305        self.task_ids.len()
306    }
307}
308
309/// 时间轮定时器管理器
310pub struct TimerWheel {
311    /// 时间轮唯一标识符
312    
313    /// 时间轮实例(使用 Arc<Mutex> 包装以支持多线程访问)
314    wheel: Arc<Mutex<Wheel>>,
315    
316    /// 后台 tick 循环任务句柄
317    tick_handle: Option<JoinHandle<()>>,
318}
319
320impl TimerWheel {
321    /// 创建新的定时器管理器
322    ///
323    /// # 参数
324    /// - `config`: 时间轮配置(已经过验证)
325    ///
326    /// # 示例
327    /// ```no_run
328    /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask, BatchConfig};
329    /// use std::time::Duration;
330    ///
331    /// #[tokio::main]
332    /// async fn main() {
333    ///     let config = WheelConfig::builder()
334    ///         .l0_tick_duration(Duration::from_millis(10))
335    ///         .l0_slot_count(512)
336    ///         .l1_tick_duration(Duration::from_secs(1))
337    ///         .l1_slot_count(64)
338    ///         .build()
339    ///         .unwrap();
340    ///     let timer = TimerWheel::new(config, BatchConfig::default());
341    ///     
342    ///     // 使用两步式 API
343    ///     let task = TimerWheel::create_task(Duration::from_secs(1), None);
344    ///     let handle = timer.register(task);
345    /// }
346    /// ```
347    pub fn new(config: WheelConfig, batch_config: BatchConfig) -> Self {
348        let tick_duration = config.hierarchical.l0_tick_duration;
349        let wheel = Wheel::new(config, batch_config);
350        let wheel = Arc::new(Mutex::new(wheel));
351        let wheel_clone = wheel.clone();
352
353        // 启动后台 tick 循环
354        let tick_handle = tokio::spawn(async move {
355            Self::tick_loop(wheel_clone, tick_duration).await;
356        });
357
358        Self {
359            wheel,
360            tick_handle: Some(tick_handle),
361        }
362    }
363
364    /// 创建带默认配置的定时器管理器(分层模式)
365    /// - L0 层 tick 时长: 10ms, 槽位数量: 512
366    /// - L1 层 tick 时长: 1s, 槽位数量: 64
367    ///
368    /// # 示例
369    /// ```no_run
370    /// use kestrel_protocol_timer::TimerWheel;
371    ///
372    /// #[tokio::main]
373    /// async fn main() {
374    ///     let timer = TimerWheel::with_defaults();
375    /// }
376    /// ```
377    pub fn with_defaults() -> Self {
378        Self::new(WheelConfig::default(), BatchConfig::default())
379    }
380
381    /// 创建与此时间轮绑定的 TimerService(使用默认配置)
382    ///
383    /// # 返回
384    /// 绑定到此时间轮的 TimerService 实例
385    ///
386    /// # 参数
387    /// - `service_config`: 服务配置
388    ///
389    /// # 示例
390    /// ```no_run
391    /// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
392    /// use std::time::Duration;
393    /// 
394    ///
395    /// #[tokio::main]
396    /// async fn main() {
397    ///     let timer = TimerWheel::with_defaults();
398    ///     let mut service = timer.create_service(ServiceConfig::default());
399    ///     
400    ///     // 使用两步式 API 通过 service 批量调度定时器
401    ///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
402    ///         .map(|_| (Duration::from_millis(100), Some(CallbackWrapper::new(|| async {}))))
403    ///         .collect();
404    ///     let tasks = TimerService::create_batch_with_callbacks(callbacks);
405    ///     service.register_batch(tasks).unwrap();
406    ///     
407    ///     // 接收超时通知
408    ///     let mut rx = service.take_receiver().unwrap();
409    ///     while let Some(task_id) = rx.recv().await {
410    ///         println!("Task {:?} completed", task_id);
411    ///     }
412    /// }
413    /// ```
414    pub fn create_service(&self, service_config: ServiceConfig) -> crate::service::TimerService {
415        crate::service::TimerService::new(self.wheel.clone(), service_config)
416    }
417    
418    /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
419    ///
420    /// # 参数
421    /// - `config`: 服务配置
422    ///
423    /// # 返回
424    /// 绑定到此时间轮的 TimerService 实例
425    ///
426    /// # 示例
427    /// ```no_run
428    /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
429    ///
430    /// #[tokio::main]
431    /// async fn main() {
432    ///     let timer = TimerWheel::with_defaults();
433    ///     let config = ServiceConfig::builder()
434    ///         .command_channel_capacity(1024)
435    ///         .timeout_channel_capacity(2000)
436    ///         .build()
437    ///         .unwrap();
438    ///     let service = timer.create_service_with_config(config);
439    /// }
440    /// ```
441    pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
442        crate::service::TimerService::new(self.wheel.clone(), config)
443    }
444
445    /// 创建定时器任务(申请阶段)
446    /// 
447    /// # 参数
448    /// - `delay`: 延迟时间
449    /// - `callback`: 实现了 TimerCallback trait 的回调对象
450    /// 
451    /// # 返回
452    /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
453    /// 
454    /// # 示例
455    /// ```no_run
456    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
457    /// use std::time::Duration;
458    /// 
459    /// 
460    /// #[tokio::main]
461    /// async fn main() {
462    ///     let timer = TimerWheel::with_defaults();
463    ///     
464    ///     // 步骤 1: 创建任务
465    ///     let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
466    ///         println!("Timer fired!");
467    ///     })));
468    ///     
469    ///     // 获取任务 ID
470    ///     let task_id = task.get_id();
471    ///     println!("Created task: {:?}", task_id);
472    ///     
473    ///     // 步骤 2: 注册任务
474    ///     let handle = timer.register(task);
475    /// }
476    /// ```
477    #[inline]
478    pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
479        crate::task::TimerTask::new(delay, callback)
480    }
481    
482    /// 批量创建定时器任务(申请阶段)
483    /// 
484    /// # 参数
485    /// - `callbacks`: (延迟时间, 回调) 的元组列表
486    /// 
487    /// # 返回
488    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
489    /// 
490    /// # 示例
491    /// ```no_run
492    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
493    /// use std::time::Duration;
494    /// use std::sync::Arc;
495    /// use std::sync::atomic::{AtomicU32, Ordering};
496    /// 
497    /// #[tokio::main]
498    /// async fn main() {
499    ///     let timer = TimerWheel::with_defaults();
500    ///     let counter = Arc::new(AtomicU32::new(0));
501    ///     
502    ///     // 步骤 1: 批量创建任务
503    ///     let delays: Vec<Duration> = (0..3)
504    ///         .map(|_| Duration::from_millis(100))
505    ///         .collect();
506    ///     
507    ///     let tasks = TimerWheel::create_batch(delays);
508    ///     println!("Created {} tasks", tasks.len());
509    ///     
510    ///     // 步骤 2: 批量注册任务
511    ///     let batch = timer.register_batch(tasks);
512    /// }
513    /// ```
514    #[inline]
515    pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask>
516    {
517        delays
518            .into_iter()
519            .map(|delay| crate::task::TimerTask::new(delay, None))
520            .collect()
521    }
522
523    /// 批量创建定时器任务(申请阶段)
524    /// 
525    /// # 参数
526    /// - `callbacks`: (延迟时间, 回调) 的元组列表
527    /// 
528    /// # 返回
529    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
530    /// 
531    /// # 示例
532    /// ```no_run
533    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
534    /// use std::time::Duration;
535    /// use std::sync::Arc;
536    /// use std::sync::atomic::{AtomicU32, Ordering};
537    /// 
538    /// #[tokio::main]
539    /// async fn main() {
540    ///     let timer = TimerWheel::with_defaults();
541    ///     let counter = Arc::new(AtomicU32::new(0));
542    ///     
543    ///     // 步骤 1: 批量创建任务
544    ///     let delays: Vec<Duration> = (0..3)
545    ///         .map(|_| Duration::from_millis(100))
546    ///         .collect();
547    ///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = delays
548    ///         .into_iter()
549    ///         .map(|delay| {
550    ///             let counter = Arc::clone(&counter);
551    ///             let callback = Some(CallbackWrapper::new(move || {
552    ///                 let counter = Arc::clone(&counter);
553    ///                 async move {
554    ///                     counter.fetch_add(1, Ordering::SeqCst);
555    ///                 }
556    ///             }));
557    ///             (delay, callback)
558    ///         })
559    ///         .collect();
560    ///     
561    ///     let tasks = TimerWheel::create_batch_with_callbacks(callbacks);
562    ///     println!("Created {} tasks", tasks.len());
563    ///     
564    ///     // 步骤 2: 批量注册任务
565    ///     let batch = timer.register_batch(tasks);
566    /// }
567    /// ```
568    #[inline]
569    pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask>
570    {
571        callbacks
572            .into_iter()
573            .map(|(delay, callback)| crate::task::TimerTask::new(delay, callback))
574            .collect()
575    }
576    
577    /// 注册定时器任务到时间轮(注册阶段)
578    /// 
579    /// # 参数
580    /// - `task`: 通过 `create_task()` 创建的任务
581    /// 
582    /// # 返回
583    /// 返回定时器句柄,可用于取消定时器和接收完成通知
584    /// 
585    /// # 示例
586    /// ```no_run
587    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
588    /// 
589    /// use std::time::Duration;
590    /// 
591    /// #[tokio::main]
592    /// async fn main() {
593    ///     let timer = TimerWheel::with_defaults();
594    ///     
595    ///     let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
596    ///         println!("Timer fired!");
597    ///     })));
598    ///     let task_id = task.get_id();
599    ///     
600    ///     let handle = timer.register(task);
601    ///     
602    ///     // 等待定时器完成
603    ///     handle.into_completion_receiver().0.await.ok();
604    /// }
605    /// ```
606    #[inline]
607    pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
608        let (completion_tx, completion_rx) = oneshot::channel();
609        let notifier = crate::task::CompletionNotifier(completion_tx);
610        
611        let task_id = task.id;
612        
613        // 单次加锁完成所有操作
614        {
615            let mut wheel_guard = self.wheel.lock();
616            wheel_guard.insert(task, notifier);
617        }
618        
619        TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
620    }
621    
622    /// 批量注册定时器任务到时间轮(注册阶段)
623    /// 
624    /// # 参数
625    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
626    /// 
627    /// # 返回
628    /// 返回批量定时器句柄
629    /// 
630    /// # 示例
631    /// ```no_run
632    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
633    /// use std::time::Duration;
634    /// 
635    /// #[tokio::main]
636    /// async fn main() {
637    ///     let timer = TimerWheel::with_defaults();
638    ///     
639    ///     let delays: Vec<Duration> = (0..3)
640    ///         .map(|_| Duration::from_secs(1))
641    ///         .collect();
642    ///     let tasks = TimerWheel::create_batch(delays);
643    ///     
644    ///     let batch = timer.register_batch(tasks);
645    ///     println!("Registered {} timers", batch.len());
646    /// }
647    /// ```
648    #[inline]
649    pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
650        let task_count = tasks.len();
651        let mut completion_rxs = Vec::with_capacity(task_count);
652        let mut task_ids = Vec::with_capacity(task_count);
653        let mut prepared_tasks = Vec::with_capacity(task_count);
654        
655        // 步骤1: 准备所有 channels 和 notifiers(无锁)
656        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
657        for task in tasks {
658            let (completion_tx, completion_rx) = oneshot::channel();
659            let notifier = crate::task::CompletionNotifier(completion_tx);
660            
661            task_ids.push(task.id);
662            completion_rxs.push(completion_rx);
663            prepared_tasks.push((task, notifier));
664        }
665        
666        // 步骤2: 单次加锁,批量插入
667        {
668            let mut wheel_guard = self.wheel.lock();
669            wheel_guard.insert_batch(prepared_tasks);
670        }
671        
672        BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
673    }
674
675    /// 取消定时器
676    ///
677    /// # 参数
678    /// - `task_id`: 任务 ID
679    ///
680    /// # 返回
681    /// 如果任务存在且成功取消返回 true,否则返回 false
682    /// 
683    /// # 示例
684    /// ```no_run
685    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
686    /// 
687    /// use std::time::Duration;
688    ///
689    /// #[tokio::main]
690    /// async fn main() {
691    ///     let timer = TimerWheel::with_defaults();
692    ///     
693    ///     let task = TimerWheel::create_task(Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
694    ///         println!("Timer fired!");
695    ///     })));
696    ///     let task_id = task.get_id();
697    ///     let _handle = timer.register(task);
698    ///     
699    ///     // 使用任务 ID 取消
700    ///     let cancelled = timer.cancel(task_id);
701    ///     println!("取消成功: {}", cancelled);
702    /// }
703    /// ```
704    #[inline]
705    pub fn cancel(&self, task_id: TaskId) -> bool {
706        let mut wheel = self.wheel.lock();
707        wheel.cancel(task_id)
708    }
709
710    /// 批量取消定时器
711    ///
712    /// # 参数
713    /// - `task_ids`: 要取消的任务 ID 列表
714    ///
715    /// # 返回
716    /// 成功取消的任务数量
717    ///
718    /// # 性能优势
719    /// - 批量处理减少锁竞争
720    /// - 内部优化批量取消操作
721    ///
722    /// # 示例
723    /// ```no_run
724    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
725    /// use std::time::Duration;
726    ///
727    /// #[tokio::main]
728    /// async fn main() {
729    ///     let timer = TimerWheel::with_defaults();
730    ///     
731    ///     // 创建多个定时器
732    ///     let task1 = TimerWheel::create_task(Duration::from_secs(10), None);
733    ///     let task2 = TimerWheel::create_task(Duration::from_secs(10), None);
734    ///     let task3 = TimerWheel::create_task(Duration::from_secs(10), None);
735    ///     
736    ///     let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
737    ///     
738    ///     let _h1 = timer.register(task1);
739    ///     let _h2 = timer.register(task2);
740    ///     let _h3 = timer.register(task3);
741    ///     
742    ///     // 批量取消
743    ///     let cancelled = timer.cancel_batch(&task_ids);
744    ///     println!("已取消 {} 个定时器", cancelled);
745    /// }
746    /// ```
747    #[inline]
748    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
749        let mut wheel = self.wheel.lock();
750        wheel.cancel_batch(task_ids)
751    }
752
753    /// 推迟定时器
754    ///
755    /// # 参数
756    /// - `task_id`: 要推迟的任务 ID
757    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
758    /// - `callback`: 新的回调函数,传入 `None` 保持原回调不变,传入 `Some` 替换为新回调
759    ///
760    /// # 返回
761    /// 如果任务存在且成功推迟返回 true,否则返回 false
762    ///
763    /// # 注意
764    /// - 推迟后任务 ID 保持不变
765    /// - 原有的 completion_receiver 仍然有效
766    ///
767    /// # 示例
768    ///
769    /// ## 保持原回调
770    /// ```no_run
771    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
772    /// use std::time::Duration;
773    /// 
774    ///
775    /// #[tokio::main]
776    /// async fn main() {
777    ///     let timer = TimerWheel::with_defaults();
778    ///     
779    ///     let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
780    ///         println!("Timer fired!");
781    ///     })));
782    ///     let task_id = task.get_id();
783    ///     let _handle = timer.register(task);
784    ///     
785    ///     // 推迟到 10 秒后触发(保持原回调)
786    ///     let success = timer.postpone(task_id, Duration::from_secs(10), None);
787    ///     println!("推迟成功: {}", success);
788    /// }
789    /// ```
790    ///
791    /// ## 替换为新回调
792    /// ```no_run
793    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
794    /// use std::time::Duration;
795    ///
796    /// #[tokio::main]
797    /// async fn main() {
798    ///     let timer = TimerWheel::with_defaults();
799    ///     
800    ///     let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
801    ///         println!("Original callback!");
802    ///     })));
803    ///     let task_id = task.get_id();
804    ///     let _handle = timer.register(task);
805    ///     
806    ///     // 推迟到 10 秒后触发(并替换为新回调)
807    ///     let success = timer.postpone(task_id, Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
808    ///         println!("New callback!");
809    ///     })));
810    ///     println!("推迟成功: {}", success);
811    /// }
812    /// ```
813    #[inline]
814    pub fn postpone(
815        &self,
816        task_id: TaskId,
817        new_delay: Duration,
818        callback: Option<CallbackWrapper>,
819    ) -> bool {
820        let mut wheel = self.wheel.lock();
821        wheel.postpone(task_id, new_delay, callback)
822    }
823
824    /// 批量推迟定时器(保持原回调)
825    ///
826    /// # 参数
827    /// - `updates`: (任务ID, 新延迟) 的元组列表
828    ///
829    /// # 返回
830    /// 成功推迟的任务数量
831    ///
832    /// # 注意
833    /// - 此方法会保持所有任务的原回调不变
834    /// - 如需替换回调,请使用 `postpone_batch_with_callbacks`
835    ///
836    /// # 性能优势
837    /// - 批量处理减少锁竞争
838    /// - 内部优化批量推迟操作
839    ///
840    /// # 示例
841    /// ```no_run
842    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
843    /// use std::time::Duration;
844    ///
845    /// #[tokio::main]
846    /// async fn main() {
847    ///     let timer = TimerWheel::with_defaults();
848    ///     
849    ///     // 创建多个带回调的定时器
850    ///     let task1 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
851    ///         println!("Task 1 fired!");
852    ///     })));
853    ///     let task2 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
854    ///         println!("Task 2 fired!");
855    ///     })));
856    ///     let task3 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
857    ///         println!("Task 3 fired!");
858    ///     })));
859    ///     
860    ///     let task_ids = vec![
861    ///         (task1.get_id(), Duration::from_secs(10)),
862    ///         (task2.get_id(), Duration::from_secs(15)),
863    ///         (task3.get_id(), Duration::from_secs(20)),
864    ///     ];
865    ///     
866    ///     timer.register(task1);
867    ///     timer.register(task2);
868    ///     timer.register(task3);
869    ///     
870    ///     // 批量推迟(保持原回调)
871    ///     let postponed = timer.postpone_batch(task_ids);
872    ///     println!("已推迟 {} 个定时器", postponed);
873    /// }
874    /// ```
875    #[inline]
876    pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
877        let mut wheel = self.wheel.lock();
878        wheel.postpone_batch(updates)
879    }
880
881    /// 批量推迟定时器(替换回调)
882    ///
883    /// # 参数
884    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
885    ///
886    /// # 返回
887    /// 成功推迟的任务数量
888    ///
889    /// # 性能优势
890    /// - 批量处理减少锁竞争
891    /// - 内部优化批量推迟操作
892    ///
893    /// # 示例
894    /// ```no_run
895    /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
896    /// use std::time::Duration;
897    /// use std::sync::Arc;
898    /// use std::sync::atomic::{AtomicU32, Ordering};
899    ///
900    /// #[tokio::main]
901    /// async fn main() {
902    ///     let timer = TimerWheel::with_defaults();
903    ///     let counter = Arc::new(AtomicU32::new(0));
904    ///     
905    ///     // 创建多个定时器
906    ///     let task1 = TimerWheel::create_task(Duration::from_secs(5), None);
907    ///     let task2 = TimerWheel::create_task(Duration::from_secs(5), None);
908    ///     
909    ///     let id1 = task1.get_id();
910    ///     let id2 = task2.get_id();
911    ///     
912    ///     timer.register(task1);
913    ///     timer.register(task2);
914    ///     
915    ///     // 批量推迟并替换回调
916    ///     let updates: Vec<_> = vec![id1, id2]
917    ///         .into_iter()
918    ///         .map(|id| {
919    ///             let counter = Arc::clone(&counter);
920    ///             (id, Duration::from_secs(10), Some(CallbackWrapper::new(move || {
921    ///                 let counter = Arc::clone(&counter);
922    ///                 async move { counter.fetch_add(1, Ordering::SeqCst); }
923    ///             })))
924    ///         })
925    ///         .collect();
926    ///     let postponed = timer.postpone_batch_with_callbacks(updates);
927    ///     println!("已推迟 {} 个定时器", postponed);
928    /// }
929    /// ```
930    #[inline]
931    pub fn postpone_batch_with_callbacks(
932        &self,
933        updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
934    ) -> usize {
935        let mut wheel = self.wheel.lock();
936        wheel.postpone_batch_with_callbacks(updates.to_vec())
937    }
938    
939    /// 核心 tick 循环
940    async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
941        let mut interval = tokio::time::interval(tick_duration);
942        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
943
944        loop {
945            interval.tick().await;
946
947            // 推进时间轮并获取到期任务
948            let expired_tasks = {
949                let mut wheel_guard = wheel.lock();
950                wheel_guard.advance()
951            };
952
953            // 执行到期任务
954            for task in expired_tasks {
955                let callback = task.get_callback();
956                
957                // 移动task的所有权来获取completion_notifier
958                let notifier = task.completion_notifier;
959                
960                // 只有注册过的任务才有 notifier
961                if let Some(notifier) = notifier {
962                    // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
963                    if let Some(callback) = callback {
964                        tokio::spawn(async move {
965                            // 执行回调
966                            let future = callback.call();
967                            future.await;
968                            
969                            // 回调执行完成后发送通知
970                            let _ = notifier.0.send(TaskCompletionReason::Expired);
971                        });
972                    } else {
973                        // 如果没有回调,立即发送完成通知
974                        let _ = notifier.0.send(TaskCompletionReason::Expired);
975                    }
976                }
977            }
978        }
979    }
980
981    /// 停止定时器管理器
982    pub async fn shutdown(mut self) {
983        if let Some(handle) = self.tick_handle.take() {
984            handle.abort();
985            let _ = handle.await;
986        }
987    }
988}
989
990impl Drop for TimerWheel {
991    fn drop(&mut self) {
992        if let Some(handle) = self.tick_handle.take() {
993            handle.abort();
994        }
995    }
996}
997
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a callback that adds `delta` to `counter` each time it fires.
    ///
    /// Shared by the tests below so each one states only what it counts, without repeating
    /// the `Arc` cloning needed to move the counter into the callback future.
    fn counting_callback(counter: &Arc<AtomicU32>, delta: u32) -> Option<CallbackWrapper> {
        let counter = Arc::clone(counter);
        Some(CallbackWrapper::new(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(delta, Ordering::SeqCst);
            }
        }))
    }

    /// Waits, for at most `limit`, until the timer behind `handle` reports completion, and
    /// asserts that it expired.
    ///
    /// The tick loop sends `TaskCompletionReason::Expired` only after the callback future has
    /// been awaited. Once this returns, the callback's side effects are therefore visible,
    /// with no extra sleep needed.
    async fn expect_expired(handle: TimerHandle, limit: Duration) {
        let result = tokio::time::timeout(limit, handle.into_completion_receiver().0).await;
        assert!(
            matches!(result, Ok(Ok(TaskCompletionReason::Expired))),
            "timer did not report expiry within {:?}",
            limit
        );
    }

    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    #[tokio::test]
    async fn test_schedule_once() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
        let handle = timer.register(task);

        // Wait for the completion notification instead of sleeping a fixed amount.
        expect_expired(handle, Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(100), counting_callback(&counter, 1));
        let handle = timer.register(task);

        // Let some time pass before cancelling, well before the 100ms deadline. This covers
        // cancelling a timer that has been pending for a while; `test_cancel_immediate`
        // covers cancelling right after registration.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(handle.cancel());

        // Wait well past the original deadline to make sure the timer never fires.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_cancel_immediate() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(100), counting_callback(&counter, 1));
        let handle = timer.register(task);

        // Cancel immediately after registration.
        assert!(handle.cancel());

        // Wait well past the original deadline to make sure the timer never fires.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_postpone_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task to 150ms.
        assert!(timer.postpone(task_id, Duration::from_millis(150), None));

        // Past the original 50ms deadline, the task must not have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // It should fire roughly 150ms after the postpone call.
        expect_expired(handle, Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_postpone_with_callback() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // The original callback adds 1.
        let task =
            TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone and replace the callback with one that adds 10.
        let postponed = timer.postpone(
            task_id,
            Duration::from_millis(100),
            counting_callback(&counter, 10),
        );
        assert!(postponed);

        // The task fires about 100ms after postponing; allow some slack.
        expect_expired(handle, Duration::from_millis(200)).await;

        // Only the replacement callback ran: +10, not +1.
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_timer() {
        let timer = TimerWheel::with_defaults();

        // Build a task but deliberately never register it.
        let fake_task = TimerWheel::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();

        let postponed = timer.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed);
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create and register 3 tasks, keeping their handles to await completion later.
        let mut updates = Vec::with_capacity(3);
        let mut handles = Vec::with_capacity(3);
        for _ in 0..3 {
            let task =
                TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
            updates.push((task.get_id(), Duration::from_millis(150)));
            handles.push(timer.register(task));
        }

        // Postpone all of them at once.
        let postponed = timer.postpone_batch(updates);
        assert_eq!(postponed, 3);

        // Past the original 50ms deadline, nothing must have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Each task should fire roughly 150ms after the postpone call.
        for handle in handles {
            expect_expired(handle, Duration::from_millis(300)).await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create and register 3 tasks without callbacks, keeping their handles.
        let mut task_ids = Vec::with_capacity(3);
        let mut handles = Vec::with_capacity(3);
        for _ in 0..3 {
            let task = TimerWheel::create_task(Duration::from_millis(50), None);
            task_ids.push(task.get_id());
            handles.push(timer.register(task));
        }

        // Postpone all of them and install counting callbacks.
        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| (id, Duration::from_millis(150), counting_callback(&counter, 1)))
            .collect();

        let postponed = timer.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3);

        // Past the original 50ms deadline, nothing must have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Each task should fire roughly 150ms after the postpone call.
        for handle in handles {
            expect_expired(handle, Duration::from_millis(300)).await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_keeps_completion_receiver_valid() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task =
            TimerWheel::create_task(Duration::from_millis(50), counting_callback(&counter, 1));
        let task_id = task.get_id();
        let handle = timer.register(task);

        assert!(timer.postpone(task_id, Duration::from_millis(100), None));

        // The receiver obtained before postponing must still observe the expiry
        // (about 100ms after postponing, plus slack).
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(
            matches!(result, Ok(Ok(TaskCompletionReason::Expired))),
            "Completion receiver should still work after postpone"
        );

        // The notification is sent only after the callback has run.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
1304