kestrel_protocol_timer/
timer.rs

1use crate::config::{ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TimerCallback};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
10/// 完成通知接收器,用于接收定时器完成通知
11pub struct CompletionReceiver(pub oneshot::Receiver<()>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14/// 
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18    pub(crate) task_id: TaskId,
19    pub(crate) wheel: Arc<Mutex<Wheel>>,
20    pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24    pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<()>) -> Self {
25        Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26    }
27
28    /// 取消定时器
29    ///
30    /// # 返回
31    /// 如果任务存在且成功取消返回 true,否则返回 false
32    ///
33    /// # 示例
34    /// ```no_run
35    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
36    /// # use std::time::Duration;
37    /// # #[tokio::main]
38    /// # async fn main() {
39    /// let timer = TimerWheel::with_defaults();
40    /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {});
41    /// let handle = timer.register(task);
42    /// 
43    /// // 取消定时器
44    /// let success = handle.cancel();
45    /// println!("取消成功: {}", success);
46    /// # }
47    /// ```
48    pub fn cancel(&self) -> bool {
49        let mut wheel = self.wheel.lock();
50        wheel.cancel(self.task_id)
51    }
52
53    /// 获取完成通知接收器的可变引用
54    ///
55    /// # 示例
56    /// ```no_run
57    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
58    /// # use std::time::Duration;
59    /// # #[tokio::main]
60    /// # async fn main() {
61    /// let timer = TimerWheel::with_defaults();
62    /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
63    ///     println!("Timer fired!");
64    /// });
65    /// let handle = timer.register(task);
66    /// 
67    /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
68    /// handle.into_completion_receiver().0.await.ok();
69    /// println!("Timer completed!");
70    /// # }
71    /// ```
72    pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
73        &mut self.completion_rx
74    }
75
76    /// 消耗句柄,返回完成通知接收器
77    ///
78    /// # 示例
79    /// ```no_run
80    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
81    /// # use std::time::Duration;
82    /// # #[tokio::main]
83    /// # async fn main() {
84    /// let timer = TimerWheel::with_defaults();
85    /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
86    ///     println!("Timer fired!");
87    /// });
88    /// let handle = timer.register(task);
89    /// 
90    /// // 等待定时器完成
91    /// handle.into_completion_receiver().0.await.ok();
92    /// println!("Timer completed!");
93    /// # }
94    /// ```
95    pub fn into_completion_receiver(self) -> CompletionReceiver {
96        self.completion_rx
97    }
98}
99
100/// 批量定时器句柄,用于管理批量调度的定时器
101/// 
102/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
103/// 
104/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
105/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
106pub struct BatchHandle {
107    pub(crate) task_ids: Vec<TaskId>,
108    pub(crate) wheel: Arc<Mutex<Wheel>>,
109    pub(crate) completion_rxs: Vec<oneshot::Receiver<()>>,
110}
111
112impl BatchHandle {
113    pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<()>>) -> Self {
114        Self { task_ids, wheel, completion_rxs }
115    }
116
117    /// 批量取消所有定时器
118    ///
119    /// # 返回
120    /// 成功取消的任务数量
121    ///
122    /// # 示例
123    /// ```no_run
124    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
125    /// # use std::time::Duration;
126    /// # #[tokio::main]
127    /// # async fn main() {
128    /// let timer = TimerWheel::with_defaults();
129    /// let tasks = TimerWheel::create_batch(
130    ///     (0..10).map(|_| (Duration::from_secs(1), || async {})).collect()
131    /// );
132    /// let batch = timer.register_batch(tasks);
133    /// 
134    /// let cancelled = batch.cancel_all();
135    /// println!("取消了 {} 个定时器", cancelled);
136    /// # }
137    /// ```
138    pub fn cancel_all(self) -> usize {
139        let mut wheel = self.wheel.lock();
140        wheel.cancel_batch(&self.task_ids)
141    }
142
143    /// 将批量句柄转换为单个定时器句柄的 Vec
144    ///
145    /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
146    ///
147    /// # 示例
148    /// ```no_run
149    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
150    /// # use std::time::Duration;
151    /// # #[tokio::main]
152    /// # async fn main() {
153    /// let timer = TimerWheel::with_defaults();
154    /// let tasks = TimerWheel::create_batch(
155    ///     (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
156    /// );
157    /// let batch = timer.register_batch(tasks);
158    /// 
159    /// // 转换为独立的句柄
160    /// let handles = batch.into_handles();
161    /// for handle in handles {
162    ///     // 可以单独操作每个句柄
163    /// }
164    /// # }
165    /// ```
166    pub fn into_handles(self) -> Vec<TimerHandle> {
167        self.task_ids
168            .into_iter()
169            .zip(self.completion_rxs.into_iter())
170            .map(|(task_id, rx)| {
171                TimerHandle::new(task_id, self.wheel.clone(), rx)
172            })
173            .collect()
174    }
175
176    /// 获取批量任务的数量
177    pub fn len(&self) -> usize {
178        self.task_ids.len()
179    }
180
181    /// 检查批量任务是否为空
182    pub fn is_empty(&self) -> bool {
183        self.task_ids.is_empty()
184    }
185
186    /// 获取所有任务 ID 的引用
187    pub fn task_ids(&self) -> &[TaskId] {
188        &self.task_ids
189    }
190
191    /// 获取所有完成通知接收器的引用
192    ///
193    /// # 返回
194    /// 所有任务的完成通知接收器列表引用
195    pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<()>> {
196        &mut self.completion_rxs
197    }
198
199    /// 消耗句柄,返回所有完成通知接收器
200    ///
201    /// # 返回
202    /// 所有任务的完成通知接收器列表
203    ///
204    /// # 示例
205    /// ```no_run
206    /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
207    /// # use std::time::Duration;
208    /// # #[tokio::main]
209    /// # async fn main() {
210    /// let timer = TimerWheel::with_defaults();
211    /// let tasks = TimerWheel::create_batch(
212    ///     (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
213    /// );
214    /// let batch = timer.register_batch(tasks);
215    /// 
216    /// // 获取所有完成通知接收器
217    /// let receivers = batch.into_completion_receivers();
218    /// for rx in receivers {
219    ///     tokio::spawn(async move {
220    ///         if rx.await.is_ok() {
221    ///             println!("A timer completed!");
222    ///         }
223    ///     });
224    /// }
225    /// # }
226    /// ```
227    pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<()>> {
228        self.completion_rxs
229    }
230}
231
232/// 实现 IntoIterator,允许直接迭代 BatchHandle
233/// 
234/// # 示例
235/// ```no_run
236/// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
237/// # use std::time::Duration;
238/// # #[tokio::main]
239/// # async fn main() {
240/// let timer = TimerWheel::with_defaults();
241/// let tasks = TimerWheel::create_batch(
242///     (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
243/// );
244/// let batch = timer.register_batch(tasks);
245/// 
246/// // 直接迭代,每个元素都是独立的 TimerHandle
247/// for handle in batch {
248///     // 可以单独操作每个句柄
249/// }
250/// # }
251/// ```
252impl IntoIterator for BatchHandle {
253    type Item = TimerHandle;
254    type IntoIter = BatchHandleIter;
255
256    fn into_iter(self) -> Self::IntoIter {
257        BatchHandleIter {
258            task_ids: self.task_ids.into_iter(),
259            completion_rxs: self.completion_rxs.into_iter(),
260            wheel: self.wheel,
261        }
262    }
263}
264
265/// BatchHandle 的迭代器
266pub struct BatchHandleIter {
267    task_ids: std::vec::IntoIter<TaskId>,
268    completion_rxs: std::vec::IntoIter<oneshot::Receiver<()>>,
269    wheel: Arc<Mutex<Wheel>>,
270}
271
272impl Iterator for BatchHandleIter {
273    type Item = TimerHandle;
274
275    fn next(&mut self) -> Option<Self::Item> {
276        match (self.task_ids.next(), self.completion_rxs.next()) {
277            (Some(task_id), Some(rx)) => {
278                Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
279            }
280            _ => None,
281        }
282    }
283
284    fn size_hint(&self) -> (usize, Option<usize>) {
285        self.task_ids.size_hint()
286    }
287}
288
289impl ExactSizeIterator for BatchHandleIter {
290    fn len(&self) -> usize {
291        self.task_ids.len()
292    }
293}
294
295/// 时间轮定时器管理器
296pub struct TimerWheel {
297    /// 时间轮唯一标识符
298    
299    /// 时间轮实例(使用 Arc<Mutex> 包装以支持多线程访问)
300    wheel: Arc<Mutex<Wheel>>,
301    
302    /// 后台 tick 循环任务句柄
303    tick_handle: Option<JoinHandle<()>>,
304}
305
306impl TimerWheel {
307    /// 创建新的定时器管理器
308    ///
309    /// # 参数
310    /// - `config`: 时间轮配置(已经过验证)
311    ///
312    /// # 示例
313    /// ```no_run
314    /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask};
315    /// use std::time::Duration;
316    ///
317    /// #[tokio::main]
318    /// async fn main() {
319    ///     let config = WheelConfig::builder()
320    ///         .tick_duration(Duration::from_millis(10))
321    ///         .slot_count(512)
322    ///         .build()
323    ///         .unwrap();
324    ///     let timer = TimerWheel::new(config);
325    ///     
326    ///     // 使用两步式 API
327    ///     let task = TimerWheel::create_task(Duration::from_secs(1), || async {});
328    ///     let handle = timer.register(task);
329    /// }
330    /// ```
331    pub fn new(config: WheelConfig) -> Self {
332        let tick_duration = config.tick_duration;
333        let wheel = Wheel::new(config);
334        let wheel = Arc::new(Mutex::new(wheel));
335        let wheel_clone = wheel.clone();
336
337        // 启动后台 tick 循环
338        let tick_handle = tokio::spawn(async move {
339            Self::tick_loop(wheel_clone, tick_duration).await;
340        });
341
342        Self {
343            wheel,
344            tick_handle: Some(tick_handle),
345        }
346    }
347
348    /// 创建带默认配置的定时器管理器
349    /// - tick 时长: 10ms
350    /// - 槽位数量: 512
351    ///
352    /// # 示例
353    /// ```no_run
354    /// use kestrel_protocol_timer::TimerWheel;
355    ///
356    /// #[tokio::main]
357    /// async fn main() {
358    ///     let timer = TimerWheel::with_defaults();
359    /// }
360    /// ```
361    pub fn with_defaults() -> Self {
362        Self::new(WheelConfig::default())
363    }
364
365    /// 创建与此时间轮绑定的 TimerService(使用默认配置)
366    ///
367    /// # 返回
368    /// 绑定到此时间轮的 TimerService 实例
369    ///
370    /// # 示例
371    /// ```no_run
372    /// use kestrel_protocol_timer::{TimerWheel, TimerService};
373    /// use std::time::Duration;
374    ///
375    /// #[tokio::main]
376    /// async fn main() {
377    ///     let timer = TimerWheel::with_defaults();
378    ///     let mut service = timer.create_service();
379    ///     
380    ///     // 使用两步式 API 通过 service 批量调度定时器
381    ///     let tasks = TimerService::create_batch(
382    ///         (0..5).map(|_| (Duration::from_millis(100), || async {})).collect()
383    ///     );
384    ///     service.register_batch(tasks).await;
385    ///     
386    ///     // 接收超时通知
387    ///     let mut rx = service.take_receiver().unwrap();
388    ///     while let Some(task_id) = rx.recv().await {
389    ///         println!("Task {:?} completed", task_id);
390    ///     }
391    /// }
392    /// ```
393    pub fn create_service(&self) -> crate::service::TimerService {
394        crate::service::TimerService::new(self.wheel.clone(), ServiceConfig::default())
395    }
396    
397    /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
398    ///
399    /// # 参数
400    /// - `config`: 服务配置
401    ///
402    /// # 返回
403    /// 绑定到此时间轮的 TimerService 实例
404    ///
405    /// # 示例
406    /// ```no_run
407    /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
408    ///
409    /// #[tokio::main]
410    /// async fn main() {
411    ///     let timer = TimerWheel::with_defaults();
412    ///     let config = ServiceConfig::builder()
413    ///         .command_channel_capacity(1024)
414    ///         .timeout_channel_capacity(2000)
415    ///         .build()
416    ///         .unwrap();
417    ///     let service = timer.create_service_with_config(config);
418    /// }
419    /// ```
420    pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
421        crate::service::TimerService::new(self.wheel.clone(), config)
422    }
423
424    /// 创建定时器任务(申请阶段)
425    /// 
426    /// # 参数
427    /// - `delay`: 延迟时间
428    /// - `callback`: 实现了 TimerCallback trait 的回调对象
429    /// 
430    /// # 返回
431    /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
432    /// 
433    /// # 示例
434    /// ```no_run
435    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
436    /// use std::time::Duration;
437    /// 
438    /// #[tokio::main]
439    /// async fn main() {
440    ///     let timer = TimerWheel::with_defaults();
441    ///     
442    ///     // 步骤 1: 创建任务
443    ///     let task = TimerWheel::create_task(Duration::from_secs(1), || async {
444    ///         println!("Timer fired!");
445    ///     });
446    ///     
447    ///     // 获取任务 ID
448    ///     let task_id = task.get_id();
449    ///     println!("Created task: {:?}", task_id);
450    ///     
451    ///     // 步骤 2: 注册任务
452    ///     let handle = timer.register(task);
453    /// }
454    /// ```
455    pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
456    where
457        C: TimerCallback,
458    {
459        use std::sync::Arc;
460        let callback_wrapper = Arc::new(callback) as CallbackWrapper;
461        crate::task::TimerTask::new(delay, Some(callback_wrapper))
462    }
463    
464    /// 批量创建定时器任务(申请阶段)
465    /// 
466    /// # 参数
467    /// - `callbacks`: (延迟时间, 回调) 的元组列表
468    /// 
469    /// # 返回
470    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
471    /// 
472    /// # 示例
473    /// ```no_run
474    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
475    /// use std::time::Duration;
476    /// use std::sync::Arc;
477    /// use std::sync::atomic::{AtomicU32, Ordering};
478    /// 
479    /// #[tokio::main]
480    /// async fn main() {
481    ///     let timer = TimerWheel::with_defaults();
482    ///     let counter = Arc::new(AtomicU32::new(0));
483    ///     
484    ///     // 步骤 1: 批量创建任务
485    ///     let callbacks: Vec<_> = (0..3)
486    ///         .map(|i| {
487    ///             let counter = Arc::clone(&counter);
488    ///             let delay = Duration::from_millis(100 + i * 100);
489    ///             let callback = move || {
490    ///                 let counter = Arc::clone(&counter);
491    ///                 async move {
492    ///                     counter.fetch_add(1, Ordering::SeqCst);
493    ///                 }
494    ///             };
495    ///             (delay, callback)
496    ///         })
497    ///         .collect();
498    ///     
499    ///     let tasks = TimerWheel::create_batch(callbacks);
500    ///     println!("Created {} tasks", tasks.len());
501    ///     
502    ///     // 步骤 2: 批量注册任务
503    ///     let batch = timer.register_batch(tasks);
504    /// }
505    /// ```
506    pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
507    where
508        C: TimerCallback,
509    {
510        use std::sync::Arc;
511        callbacks
512            .into_iter()
513            .map(|(delay, callback)| {
514                let callback_wrapper = Arc::new(callback) as CallbackWrapper;
515                crate::task::TimerTask::new(delay, Some(callback_wrapper))
516            })
517            .collect()
518    }
519    
520    /// 注册定时器任务到时间轮(注册阶段)
521    /// 
522    /// # 参数
523    /// - `task`: 通过 `create_task()` 创建的任务
524    /// 
525    /// # 返回
526    /// 返回定时器句柄,可用于取消定时器和接收完成通知
527    /// 
528    /// # 示例
529    /// ```no_run
530    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
531    /// use std::time::Duration;
532    /// 
533    /// #[tokio::main]
534    /// async fn main() {
535    ///     let timer = TimerWheel::with_defaults();
536    ///     
537    ///     let task = TimerWheel::create_task(Duration::from_secs(1), || async {
538    ///         println!("Timer fired!");
539    ///     });
540    ///     let task_id = task.get_id();
541    ///     
542    ///     let handle = timer.register(task);
543    ///     
544    ///     // 等待定时器完成
545    ///     handle.into_completion_receiver().0.await.ok();
546    /// }
547    /// ```
548    #[inline]
549    pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
550        let (completion_tx, completion_rx) = oneshot::channel();
551        let notifier = crate::task::CompletionNotifier(completion_tx);
552        
553        let delay = task.delay;
554        let task_id = task.id;
555        
556        // 单次加锁完成所有操作
557        {
558            let mut wheel_guard = self.wheel.lock();
559            wheel_guard.insert(delay, task, notifier);
560        }
561        
562        TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
563    }
564    
565    /// 批量注册定时器任务到时间轮(注册阶段)
566    /// 
567    /// # 参数
568    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
569    /// 
570    /// # 返回
571    /// 返回批量定时器句柄
572    /// 
573    /// # 示例
574    /// ```no_run
575    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
576    /// use std::time::Duration;
577    /// 
578    /// #[tokio::main]
579    /// async fn main() {
580    ///     let timer = TimerWheel::with_defaults();
581    ///     
582    ///     let callbacks: Vec<_> = (0..3)
583    ///         .map(|_| (Duration::from_secs(1), || async {}))
584    ///         .collect();
585    ///     let tasks = TimerWheel::create_batch(callbacks);
586    ///     
587    ///     let batch = timer.register_batch(tasks);
588    ///     println!("Registered {} timers", batch.len());
589    /// }
590    /// ```
591    #[inline]
592    pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
593        let task_count = tasks.len();
594        let mut completion_rxs = Vec::with_capacity(task_count);
595        let mut task_ids = Vec::with_capacity(task_count);
596        let mut prepared_tasks = Vec::with_capacity(task_count);
597        
598        // 步骤1: 准备所有 channels 和 notifiers(无锁)
599        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
600        for task in tasks {
601            let (completion_tx, completion_rx) = oneshot::channel();
602            let notifier = crate::task::CompletionNotifier(completion_tx);
603            
604            task_ids.push(task.id);
605            completion_rxs.push(completion_rx);
606            prepared_tasks.push((task.delay, task, notifier));
607        }
608        
609        // 步骤2: 单次加锁,批量插入
610        {
611            let mut wheel_guard = self.wheel.lock();
612            wheel_guard.insert_batch(prepared_tasks);
613        }
614        
615        BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
616    }
617
618    /// 取消定时器
619    ///
620    /// # 参数
621    /// - `task_id`: 任务 ID
622    ///
623    /// # 返回
624    /// 如果任务存在且成功取消返回 true,否则返回 false
625    /// 
626    /// # 示例
627    /// ```no_run
628    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
629    /// use std::time::Duration;
630    ///
631    /// #[tokio::main]
632    /// async fn main() {
633    ///     let timer = TimerWheel::with_defaults();
634    ///     
635    ///     let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
636    ///     let task_id = task.get_id();
637    ///     let _handle = timer.register(task);
638    ///     
639    ///     // 使用任务 ID 取消
640    ///     let cancelled = timer.cancel(task_id);
641    ///     println!("取消成功: {}", cancelled);
642    /// }
643    /// ```
644    #[inline]
645    pub fn cancel(&self, task_id: TaskId) -> bool {
646        let mut wheel = self.wheel.lock();
647        wheel.cancel(task_id)
648    }
649
650    /// 批量取消定时器
651    ///
652    /// # 参数
653    /// - `task_ids`: 要取消的任务 ID 列表
654    ///
655    /// # 返回
656    /// 成功取消的任务数量
657    ///
658    /// # 性能优势
659    /// - 批量处理减少锁竞争
660    /// - 内部优化批量取消操作
661    ///
662    /// # 示例
663    /// ```no_run
664    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
665    /// use std::time::Duration;
666    ///
667    /// #[tokio::main]
668    /// async fn main() {
669    ///     let timer = TimerWheel::with_defaults();
670    ///     
671    ///     // 创建多个定时器
672    ///     let task1 = TimerWheel::create_task(Duration::from_secs(10), || async {});
673    ///     let task2 = TimerWheel::create_task(Duration::from_secs(10), || async {});
674    ///     let task3 = TimerWheel::create_task(Duration::from_secs(10), || async {});
675    ///     
676    ///     let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
677    ///     
678    ///     let _h1 = timer.register(task1);
679    ///     let _h2 = timer.register(task2);
680    ///     let _h3 = timer.register(task3);
681    ///     
682    ///     // 批量取消
683    ///     let cancelled = timer.cancel_batch(&task_ids);
684    ///     println!("已取消 {} 个定时器", cancelled);
685    /// }
686    /// ```
687    #[inline]
688    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
689        let mut wheel = self.wheel.lock();
690        wheel.cancel_batch(task_ids)
691    }
692
693    /// 推迟定时器(保持原回调)
694    ///
695    /// # 参数
696    /// - `task_id`: 要推迟的任务 ID
697    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
698    ///
699    /// # 返回
700    /// 如果任务存在且成功推迟返回 true,否则返回 false
701    ///
702    /// # 注意
703    /// - 推迟后任务 ID 保持不变
704    /// - 原有的 completion_receiver 仍然有效
705    /// - 保持原回调函数不变
706    ///
707    /// # 示例
708    /// ```no_run
709    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
710    /// use std::time::Duration;
711    ///
712    /// #[tokio::main]
713    /// async fn main() {
714    ///     let timer = TimerWheel::with_defaults();
715    ///     
716    ///     let task = TimerWheel::create_task(Duration::from_secs(5), || async {
717    ///         println!("Timer fired!");
718    ///     });
719    ///     let task_id = task.get_id();
720    ///     let _handle = timer.register(task);
721    ///     
722    ///     // 推迟到 10 秒后触发
723    ///     let success = timer.postpone(task_id, Duration::from_secs(10));
724    ///     println!("推迟成功: {}", success);
725    /// }
726    /// ```
727    #[inline]
728    pub fn postpone(&self, task_id: TaskId, new_delay: Duration) -> bool {
729        let mut wheel = self.wheel.lock();
730        wheel.postpone(task_id, new_delay, None)
731    }
732
733    /// 推迟定时器(替换回调)
734    ///
735    /// # 参数
736    /// - `task_id`: 要推迟的任务 ID
737    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
738    /// - `callback`: 新的回调函数
739    ///
740    /// # 返回
741    /// 如果任务存在且成功推迟返回 true,否则返回 false
742    ///
743    /// # 注意
744    /// - 推迟后任务 ID 保持不变
745    /// - 原有的 completion_receiver 仍然有效
746    /// - 回调函数会被替换为新的回调
747    ///
748    /// # 示例
749    /// ```no_run
750    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
751    /// use std::time::Duration;
752    ///
753    /// #[tokio::main]
754    /// async fn main() {
755    ///     let timer = TimerWheel::with_defaults();
756    ///     
757    ///     let task = TimerWheel::create_task(Duration::from_secs(5), || async {
758    ///         println!("Original callback");
759    ///     });
760    ///     let task_id = task.get_id();
761    ///     let _handle = timer.register(task);
762    ///     
763    ///     // 推迟并替换回调
764    ///     let success = timer.postpone_with_callback(
765    ///         task_id,
766    ///         Duration::from_secs(10),
767    ///         || async {
768    ///             println!("New callback!");
769    ///         }
770    ///     );
771    ///     println!("推迟成功: {}", success);
772    /// }
773    /// ```
774    #[inline]
775    pub fn postpone_with_callback<C>(
776        &self,
777        task_id: TaskId,
778        new_delay: Duration,
779        callback: C,
780    ) -> bool
781    where
782        C: TimerCallback,
783    {
784        use std::sync::Arc;
785        let callback_wrapper = Arc::new(callback) as CallbackWrapper;
786        let mut wheel = self.wheel.lock();
787        wheel.postpone(task_id, new_delay, Some(callback_wrapper))
788    }
789
790    /// 批量推迟定时器(保持原回调)
791    ///
792    /// # 参数
793    /// - `updates`: (任务ID, 新延迟) 的元组列表
794    ///
795    /// # 返回
796    /// 成功推迟的任务数量
797    ///
798    /// # 性能优势
799    /// - 批量处理减少锁竞争
800    /// - 内部优化批量推迟操作
801    ///
802    /// # 示例
803    /// ```no_run
804    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
805    /// use std::time::Duration;
806    ///
807    /// #[tokio::main]
808    /// async fn main() {
809    ///     let timer = TimerWheel::with_defaults();
810    ///     
811    ///     // 创建多个定时器
812    ///     let task1 = TimerWheel::create_task(Duration::from_secs(5), || async {});
813    ///     let task2 = TimerWheel::create_task(Duration::from_secs(5), || async {});
814    ///     let task3 = TimerWheel::create_task(Duration::from_secs(5), || async {});
815    ///     
816    ///     let task_ids = vec![
817    ///         (task1.get_id(), Duration::from_secs(10)),
818    ///         (task2.get_id(), Duration::from_secs(15)),
819    ///         (task3.get_id(), Duration::from_secs(20)),
820    ///     ];
821    ///     
822    ///     timer.register(task1);
823    ///     timer.register(task2);
824    ///     timer.register(task3);
825    ///     
826    ///     // 批量推迟
827    ///     let postponed = timer.postpone_batch(&task_ids);
828    ///     println!("已推迟 {} 个定时器", postponed);
829    /// }
830    /// ```
831    #[inline]
832    pub fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
833        let updates_vec: Vec<_> = updates
834            .iter()
835            .map(|(task_id, delay)| (*task_id, *delay, None))
836            .collect();
837        let mut wheel = self.wheel.lock();
838        wheel.postpone_batch(updates_vec)
839    }
840
841    /// 批量推迟定时器(替换回调)
842    ///
843    /// # 参数
844    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
845    ///
846    /// # 返回
847    /// 成功推迟的任务数量
848    ///
849    /// # 性能优势
850    /// - 批量处理减少锁竞争
851    /// - 内部优化批量推迟操作
852    ///
853    /// # 示例
854    /// ```no_run
855    /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
856    /// use std::time::Duration;
857    /// use std::sync::Arc;
858    /// use std::sync::atomic::{AtomicU32, Ordering};
859    ///
860    /// #[tokio::main]
861    /// async fn main() {
862    ///     let timer = TimerWheel::with_defaults();
863    ///     let counter = Arc::new(AtomicU32::new(0));
864    ///     
865    ///     // 创建多个定时器
866    ///     let task1 = TimerWheel::create_task(Duration::from_secs(5), || async {});
867    ///     let task2 = TimerWheel::create_task(Duration::from_secs(5), || async {});
868    ///     
869    ///     let id1 = task1.get_id();
870    ///     let id2 = task2.get_id();
871    ///     
872    ///     timer.register(task1);
873    ///     timer.register(task2);
874    ///     
875    ///     // 批量推迟并替换回调
876    ///     let updates: Vec<_> = vec![id1, id2]
877    ///         .into_iter()
878    ///         .map(|id| {
879    ///             let counter = Arc::clone(&counter);
880    ///             (id, Duration::from_secs(10), move || {
881    ///                 let counter = Arc::clone(&counter);
882    ///                 async move { counter.fetch_add(1, Ordering::SeqCst); }
883    ///             })
884    ///         })
885    ///         .collect();
886    ///     let postponed = timer.postpone_batch_with_callbacks(updates);
887    ///     println!("已推迟 {} 个定时器", postponed);
888    /// }
889    /// ```
890    #[inline]
891    pub fn postpone_batch_with_callbacks<C>(
892        &self,
893        updates: Vec<(TaskId, Duration, C)>,
894    ) -> usize
895    where
896        C: TimerCallback,
897    {
898        use std::sync::Arc;
899        let updates_vec: Vec<_> = updates
900            .into_iter()
901            .map(|(task_id, delay, callback)| {
902                let callback_wrapper = Arc::new(callback) as CallbackWrapper;
903                (task_id, delay, Some(callback_wrapper))
904            })
905            .collect();
906        let mut wheel = self.wheel.lock();
907        wheel.postpone_batch(updates_vec)
908    }
909    
910    /// 核心 tick 循环
911    async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
912        let mut interval = tokio::time::interval(tick_duration);
913        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
914
915        loop {
916            interval.tick().await;
917
918            // 推进时间轮并获取到期任务
919            let expired_tasks = {
920                let mut wheel_guard = wheel.lock();
921                wheel_guard.advance()
922            };
923
924            // 执行到期任务
925            for task in expired_tasks {
926                let callback = task.get_callback();
927                
928                // 移动task的所有权来获取completion_notifier
929                let notifier = task.completion_notifier;
930                
931                // 只有注册过的任务才有 notifier
932                if let Some(notifier) = notifier {
933                    // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
934                    if let Some(callback) = callback {
935                        tokio::spawn(async move {
936                            // 执行回调
937                            let future = callback.call();
938                            future.await;
939                            
940                            // 回调执行完成后发送通知
941                            let _ = notifier.0.send(());
942                        });
943                    } else {
944                        // 如果没有回调,立即发送完成通知
945                        let _ = notifier.0.send(());
946                    }
947                }
948            }
949        }
950    }
951
952    /// 停止定时器管理器
953    pub async fn shutdown(mut self) {
954        if let Some(handle) = self.tick_handle.take() {
955            handle.abort();
956            let _ = handle.await;
957        }
958    }
959}
960
961impl Drop for TimerWheel {
962    fn drop(&mut self) {
963        if let Some(handle) = self.tick_handle.take() {
964            handle.abort();
965        }
966    }
967}
968
#[cfg(test)]
mod tests {
    // `Arc` and `Duration` come from the parent module's imports via this glob.
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    #[tokio::test]
    async fn test_schedule_once() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let _handle = timer.register(task);

        // Wait past the 50ms deadline so the timer has a chance to fire.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        // Cancels a still-pending timer after part of its delay has already
        // elapsed (unlike `test_cancel_immediate`, which cancels right after
        // registration).
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let handle = timer.register(task);

        // Let part of the 100ms delay elapse, but stay well before the deadline.
        tokio::time::sleep(Duration::from_millis(30)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let cancel_result = handle.cancel();
        assert!(cancel_result, "a still-pending timer should be cancellable");

        // Wait well past the original deadline to make sure the callback never runs.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_cancel_immediate() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let handle = timer.register(task);

        // Cancel right away, before any time has elapsed.
        let cancel_result = handle.cancel();
        assert!(cancel_result);

        // Wait long enough to be sure the timer does not fire.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_postpone_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Push the deadline out to 150ms, measured from the postpone call.
        let postponed = timer.postpone(task_id, Duration::from_millis(150));
        assert!(postponed);

        // Past the original 50ms deadline the task must not have fired yet.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Roughly 80ms remain until the new deadline; allow generous slack.
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        // `Ok(Ok(()))` means no timeout AND a completion was actually sent;
        // `Ok(Err(_))` would mean the sender was dropped without notifying.
        assert!(
            matches!(result, Ok(Ok(()))),
            "postponed timer should complete and notify"
        );

        // Small grace period before checking the callback's effect.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_postpone_with_callback() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone1 = Arc::clone(&counter);
        let counter_clone2 = Arc::clone(&counter);

        // Original callback adds 1.
        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone1);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone and replace the callback; the new one adds 10.
        let postponed = timer.postpone_with_callback(
            task_id,
            Duration::from_millis(100),
            move || {
                let counter = Arc::clone(&counter_clone2);
                async move {
                    counter.fetch_add(10, Ordering::SeqCst);
                }
            },
        );
        assert!(postponed);

        // New deadline is 100ms after the postpone call; allow slack.
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(
            matches!(result, Ok(Ok(()))),
            "postponed timer should complete and notify"
        );

        // Small grace period before checking the callback's effect.
        tokio::time::sleep(Duration::from_millis(20)).await;

        // Only the replacement callback ran (+10, not +1).
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_timer() {
        let timer = TimerWheel::with_defaults();

        // Build a task but never register it, so its id is unknown to the wheel.
        let fake_task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let fake_task_id = fake_task.get_id();

        let postponed = timer.postpone(fake_task_id, Duration::from_millis(100));
        assert!(!postponed);
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Register 3 tasks that would fire at 50ms.
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            let task = TimerWheel::create_task(
                Duration::from_millis(50),
                move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                },
            );
            task_ids.push((task.get_id(), Duration::from_millis(150)));
            timer.register(task);
        }

        // Postpone all of them to 150ms in one call.
        let postponed = timer.postpone_batch(&task_ids);
        assert_eq!(postponed, 3);

        // Past the original 50ms deadline none of them may have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Roughly 80ms remain until the new deadline; wait with generous slack.
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Small grace period before checking the callbacks' effect.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Register 3 tasks with no-op callbacks.
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerWheel::create_task(
                Duration::from_millis(50),
                || async {},
            );
            task_ids.push(task.get_id());
            timer.register(task);
        }

        // Postpone all of them and swap in counting callbacks.
        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| {
                let counter_clone = Arc::clone(&counter);
                (id, Duration::from_millis(150), move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let postponed = timer.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3);

        // Past the original 50ms deadline none of them may have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Roughly 80ms remain until the new deadline; wait with generous slack.
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Small grace period before checking the callbacks' effect.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_keeps_completion_receiver_valid() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // The postpone must really take effect, otherwise the task would simply
        // fire at its original deadline and this test would prove nothing.
        let postponed = timer.postpone(task_id, Duration::from_millis(100));
        assert!(postponed, "postpone of a registered task should succeed");

        // The receiver obtained before postponing must still get the notification.
        // `Ok(Err(_))` (sender dropped) is exactly the regression this guards against.
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(
            matches!(result, Ok(Ok(()))),
            "Completion receiver should still work after postpone"
        );

        // Small grace period before checking the callback's effect.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
1275