kestrel_protocol_timer/
service.rs

1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13/// TimerService 命令类型
14enum ServiceCommand {
15    /// 添加批量定时器句柄(仅包含必要数据:task_ids 和 completion_rxs)
16    AddBatchHandle {
17        task_ids: Vec<TaskId>,
18        completion_rxs: Vec<tokio::sync::oneshot::Receiver<TaskCompletionReason>>,
19    },
20    /// 添加单个定时器句柄(仅包含必要数据:task_id 和 completion_rx)
21    AddTimerHandle {
22        task_id: TaskId,
23        completion_rx: tokio::sync::oneshot::Receiver<TaskCompletionReason>,
24    },
25}
26
27/// TimerService - 基于 Actor 模式的定时器服务
28///
29/// 管理多个定时器句柄,监听所有超时事件,并将 TaskId 聚合转发给用户。
30///
31/// # 特性
32/// - 自动监听所有添加的定时器句柄的超时事件
33/// - 超时后自动从内部管理中移除该任务
34/// - 将超时的 TaskId 转发到统一的通道供用户接收
35/// - 支持动态添加 BatchHandle 和 TimerHandle
36///
37/// # 示例
38/// ```no_run
39/// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
40/// use std::time::Duration;
41///
42/// #[tokio::main]
43/// async fn main() {
44///     let timer = TimerWheel::with_defaults();
45///     let mut service = timer.create_service(ServiceConfig::default());
46///     
47///     // 使用两步式 API 通过 service 批量调度定时器
48///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
49///         .map(|i| {
50///             let callback = Some(CallbackWrapper::new(move || async move {
51///                 println!("Timer {} fired!", i);
52///             }));
53///             (Duration::from_millis(100), callback)
54///         })
55///         .collect();
56///     let tasks = TimerService::create_batch_with_callbacks(callbacks);
57///     service.register_batch(tasks).unwrap();
58///     
59///     // 接收超时通知
60///     let mut rx = service.take_receiver().unwrap();
61///     while let Some(task_id) = rx.recv().await {
62///         println!("Task {:?} completed", task_id);
63///     }
64/// }
65/// ```
66pub struct TimerService {
67    /// 命令发送端
68    command_tx: mpsc::Sender<ServiceCommand>,
69    /// 超时接收端
70    timeout_rx: Option<mpsc::Receiver<TaskId>>,
71    /// Actor 任务句柄
72    actor_handle: Option<JoinHandle<()>>,
73    /// 时间轮引用(用于直接调度定时器)
74    wheel: Arc<Mutex<Wheel>>,
75    /// Actor 关闭信号发送端
76    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
77}
78
79impl TimerService {
80    /// 创建新的 TimerService
81    ///
82    /// # 参数
83    /// - `wheel`: 时间轮引用
84    /// - `config`: 服务配置
85    ///
86    /// # 注意
87    /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
88    ///
89    /// # 示例
90    /// ```no_run
91    /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
92    ///
93    /// #[tokio::main]
94    /// async fn main() {
95    ///     let timer = TimerWheel::with_defaults();
96    ///     let mut service = timer.create_service(ServiceConfig::default());
97    /// }
98    /// ```
99    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
100        let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
101        let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
102
103        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
104        let actor = ServiceActor::new(command_rx, timeout_tx, shutdown_rx);
105        let actor_handle = tokio::spawn(async move {
106            actor.run().await;
107        });
108
109        Self {
110            command_tx,
111            timeout_rx: Some(timeout_rx),
112            actor_handle: Some(actor_handle),
113            wheel,
114            shutdown_tx: Some(shutdown_tx),
115        }
116    }
117
118    /// 获取超时接收器(转移所有权)
119    ///
120    /// # 返回
121    /// 超时通知接收器,如果已经被取走则返回 None
122    ///
123    /// # 注意
124    /// 此方法只能调用一次,因为它会转移接收器的所有权
125    ///
126    /// # 示例
127    /// ```no_run
128    /// # use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
129    /// # use std::time::Duration;
130    /// # #[tokio::main]
131    /// # async fn main() {
132    /// let timer = TimerWheel::with_defaults();
133    /// let mut service = timer.create_service(ServiceConfig::default());
134    /// 
135    /// let mut rx = service.take_receiver().unwrap();
136    /// while let Some(task_id) = rx.recv().await {
137    ///     println!("Task {:?} timed out", task_id);
138    /// }
139    /// # }
140    /// ```
141    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
142        self.timeout_rx.take()
143    }
144
145    /// 取消指定的任务
146    ///
147    /// # 参数
148    /// - `task_id`: 要取消的任务 ID
149    ///
150    /// # 返回
151    /// - `true`: 任务存在且成功取消
152    /// - `false`: 任务不存在或取消失败
153    ///
154    /// # 性能说明
155    /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
156    ///
157    /// # 示例
158    /// ```no_run
159    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
160    /// # use std::time::Duration;
161    /// # 
162    /// # #[tokio::main]
163    /// # async fn main() {
164    /// let timer = TimerWheel::with_defaults();
165    /// let service = timer.create_service(ServiceConfig::default());
166    /// 
167    /// // 使用两步式 API 调度定时器
168    /// let callback = Some(CallbackWrapper::new(|| async move {
169    ///     println!("Timer fired!");
170    /// }));
171    /// let task = TimerService::create_task(Duration::from_secs(10), callback);
172    /// let task_id = task.get_id();
173    /// service.register(task).unwrap();
174    /// 
175    /// // 取消任务
176    /// let cancelled = service.cancel_task(task_id);
177    /// println!("Task cancelled: {}", cancelled);
178    /// # }
179    /// ```
180    #[inline]
181    pub fn cancel_task(&self, task_id: TaskId) -> bool {
182        // 优化:直接取消任务,无需通知 Actor
183        // FuturesUnordered 会在任务被取消时自动清理
184        let mut wheel = self.wheel.lock();
185        wheel.cancel(task_id)
186    }
187
188    /// 批量取消任务
189    ///
190    /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
191    ///
192    /// # 参数
193    /// - `task_ids`: 要取消的任务 ID 列表
194    ///
195    /// # 返回
196    /// 成功取消的任务数量
197    ///
198    /// # 示例
199    /// ```no_run
200    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
201    /// # use std::time::Duration;
202    /// # 
203    /// # #[tokio::main]
204    /// # async fn main() {
205    /// let timer = TimerWheel::with_defaults();
206    /// let service = timer.create_service(ServiceConfig::default());
207    /// 
208    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..10)
209    ///     .map(|i| {
210    ///         let callback = Some(CallbackWrapper::new(move || async move {
211    ///             println!("Timer {} fired!", i);
212    ///         }));
213    ///         (Duration::from_secs(10), callback)
214    ///     })
215    ///     .collect();
216    /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
217    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
218    /// service.register_batch(tasks).unwrap();
219    /// 
220    /// // 批量取消
221    /// let cancelled = service.cancel_batch(&task_ids);
222    /// println!("成功取消 {} 个任务", cancelled);
223    /// # }
224    /// ```
225    #[inline]
226    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
227        if task_ids.is_empty() {
228            return 0;
229        }
230
231        // 优化:直接使用底层的批量取消,无需通知 Actor
232        // FuturesUnordered 会在任务被取消时自动清理
233        let mut wheel = self.wheel.lock();
234        wheel.cancel_batch(task_ids)
235    }
236
237    /// 推迟任务(替换回调)
238    ///
239    /// # 参数
240    /// - `task_id`: 要推迟的任务 ID
241    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
242    /// - `callback`: 新的回调函数
243    ///
244    /// # 返回
245    /// - `true`: 任务存在且成功推迟
246    /// - `false`: 任务不存在或推迟失败
247    ///
248    /// # 注意
249    /// - 推迟后任务 ID 保持不变
250    /// - 原有的超时通知仍然有效
251    /// - 回调函数会被替换为新的回调
252    ///
253    /// # 示例
254    /// ```no_run
255    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
256    /// # use std::time::Duration;
257    /// # 
258    /// # #[tokio::main]
259    /// # async fn main() {
260    /// let timer = TimerWheel::with_defaults();
261    /// let service = timer.create_service(ServiceConfig::default());
262    /// 
263    /// let callback = Some(CallbackWrapper::new(|| async {
264    ///     println!("Original callback");
265    /// }));
266    /// let task = TimerService::create_task(Duration::from_secs(5), callback);
267    /// let task_id = task.get_id();
268    /// service.register(task).unwrap();
269    /// 
270    /// // 推迟并替换回调
271    /// let new_callback = Some(CallbackWrapper::new(|| async { println!("New callback!"); }));
272    /// let success = service.postpone(
273    ///     task_id,
274    ///     Duration::from_secs(10),
275    ///     new_callback
276    /// );
277    /// println!("推迟成功: {}", success);
278    /// # }
279    /// ```
280    #[inline]
281    pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
282        let mut wheel = self.wheel.lock();
283        wheel.postpone(task_id, new_delay, callback)
284    }
285
286    /// 批量推迟任务(保持原回调)
287    ///
288    /// # 参数
289    /// - `updates`: (任务ID, 新延迟) 的元组列表
290    ///
291    /// # 返回
292    /// 成功推迟的任务数量
293    ///
294    /// # 示例
295    /// ```no_run
296    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
297    /// # use std::time::Duration;
298    /// # 
299    /// # #[tokio::main]
300    /// # async fn main() {
301    /// let timer = TimerWheel::with_defaults();
302    /// let service = timer.create_service(ServiceConfig::default());
303    /// 
304    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
305    ///     .map(|i| {
306    ///         let callback = Some(CallbackWrapper::new(move || async move {
307    ///             println!("Timer {} fired!", i);
308    ///         }));
309    ///         (Duration::from_secs(5), callback)
310    ///     })
311    ///     .collect();
312    /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
313    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
314    /// service.register_batch(tasks).unwrap();
315    /// 
316    /// // 批量推迟(保持原回调)
317    /// let updates: Vec<_> = task_ids
318    ///     .into_iter()
319    ///     .map(|id| (id, Duration::from_secs(10)))
320    ///     .collect();
321    /// let postponed = service.postpone_batch(updates);
322    /// println!("成功推迟 {} 个任务", postponed);
323    /// # }
324    /// ```
325    #[inline]
326    pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
327        if updates.is_empty() {
328            return 0;
329        }
330
331        let mut wheel = self.wheel.lock();
332        wheel.postpone_batch(updates)
333    }
334
335    /// 批量推迟任务(替换回调)
336    ///
337    /// # 参数
338    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
339    ///
340    /// # 返回
341    /// 成功推迟的任务数量
342    ///
343    /// # 示例
344    /// ```no_run
345    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
346    /// # use std::time::Duration;
347    /// # 
348    /// # #[tokio::main]
349    /// # async fn main() {
350    /// let timer = TimerWheel::with_defaults();
351    /// let service = timer.create_service(ServiceConfig::default());
352    /// 
353    /// // 创建 3 个任务,初始没有回调
354    /// let delays: Vec<Duration> = (0..3)
355    ///     .map(|_| Duration::from_secs(5))
356    ///     .collect();
357    /// let tasks = TimerService::create_batch(delays);
358    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
359    /// service.register_batch(tasks).unwrap();
360    /// 
361    /// // 批量推迟并添加新回调
362    /// let updates: Vec<_> = task_ids
363    ///     .into_iter()
364    ///     .enumerate()
365    ///     .map(|(i, id)| {
366    ///         let callback = Some(CallbackWrapper::new(move || async move {
367    ///             println!("New callback {}", i);
368    ///         }));
369    ///         (id, Duration::from_secs(10), callback)
370    ///     })
371    ///     .collect();
372    /// let postponed = service.postpone_batch_with_callbacks(updates);
373    /// println!("成功推迟 {} 个任务", postponed);
374    /// # }
375    /// ```
376    #[inline]
377    pub fn postpone_batch_with_callbacks(
378        &self,
379        updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
380    ) -> usize {
381        if updates.is_empty() {
382            return 0;
383        }
384
385        let mut wheel = self.wheel.lock();
386        wheel.postpone_batch_with_callbacks(updates)
387    }
388
389    /// 创建定时器任务(静态方法,申请阶段)
390    /// 
391    /// # 参数
392    /// - `delay`: 延迟时间
393    /// - `callback`: 实现了 TimerCallback trait 的回调对象
394    /// 
395    /// # 返回
396    /// 返回 TimerTask,需要通过 `register()` 注册
397    /// 
398    /// # 示例
399    /// ```no_run
400    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
401    /// # use std::time::Duration;
402    /// # 
403    /// # #[tokio::main]
404    /// # async fn main() {
405    /// let timer = TimerWheel::with_defaults();
406    /// let service = timer.create_service(ServiceConfig::default());
407    /// 
408    /// // 步骤 1: 创建任务
409    /// let callback = Some(CallbackWrapper::new(|| async {
410    ///     println!("Timer fired!");
411    /// }));
412    /// let task = TimerService::create_task(Duration::from_millis(100), callback);
413    /// 
414    /// let task_id = task.get_id();
415    /// println!("Created task: {:?}", task_id);
416    /// 
417    /// // 步骤 2: 注册任务
418    /// service.register(task).unwrap();
419    /// # }
420    /// ```
421    #[inline]
422    pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
423        crate::timer::TimerWheel::create_task(delay, callback)
424    }
425
426    /// 批量创建定时器任务(静态方法,申请阶段,不带回调)
427    /// 
428    /// # 参数
429    /// - `delays`: 延迟时间列表
430    /// 
431    /// # 返回
432    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
433    /// 
434    /// # 示例
435    /// ```no_run
436    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
437    /// # use std::time::Duration;
438    /// # 
439    /// # #[tokio::main]
440    /// # async fn main() {
441    /// let timer = TimerWheel::with_defaults();
442    /// let service = timer.create_service(ServiceConfig::default());
443    /// 
444    /// // 步骤 1: 批量创建任务
445    /// let delays: Vec<Duration> = (0..3)
446    ///     .map(|i| Duration::from_millis(100 * (i + 1)))
447    ///     .collect();
448    /// 
449    /// let tasks = TimerService::create_batch(delays);
450    /// println!("Created {} tasks", tasks.len());
451    /// 
452    /// // 步骤 2: 批量注册任务
453    /// service.register_batch(tasks).unwrap();
454    /// # }
455    /// ```
456    #[inline]
457    pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask> {
458        crate::timer::TimerWheel::create_batch(delays)
459    }
460    
461    /// 批量创建定时器任务(静态方法,申请阶段,带回调)
462    /// 
463    /// # 参数
464    /// - `callbacks`: (延迟时间, 回调) 的元组列表
465    /// 
466    /// # 返回
467    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
468    /// 
469    /// # 示例
470    /// ```no_run
471    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
472    /// # use std::time::Duration;
473    /// # 
474    /// # #[tokio::main]
475    /// # async fn main() {
476    /// let timer = TimerWheel::with_defaults();
477    /// let service = timer.create_service(ServiceConfig::default());
478    /// 
479    /// // 步骤 1: 批量创建任务
480    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
481    ///     .map(|i| {
482    ///         let callback = Some(CallbackWrapper::new(move || async move {
483    ///             println!("Timer {} fired!", i);
484    ///         }));
485    ///         (Duration::from_millis(100 * (i + 1)), callback)
486    ///     })
487    ///     .collect();
488    /// 
489    /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
490    /// println!("Created {} tasks", tasks.len());
491    /// 
492    /// // 步骤 2: 批量注册任务
493    /// service.register_batch(tasks).unwrap();
494    /// # }
495    /// ```
496    #[inline]
497    pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
498        crate::timer::TimerWheel::create_batch_with_callbacks(callbacks)
499    }
500    
501    /// 注册定时器任务到服务(注册阶段)
502    /// 
503    /// # 参数
504    /// - `task`: 通过 `create_task()` 创建的任务
505    /// 
506    /// # 返回
507    /// - `Ok(())`: 注册成功
508    /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
509    /// 
510    /// # 示例
511    /// ```no_run
512    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
513    /// # use std::time::Duration;
514    /// # 
515    /// # #[tokio::main]
516    /// # async fn main() {
517    /// let timer = TimerWheel::with_defaults();
518    /// let service = timer.create_service(ServiceConfig::default());
519    /// 
520    /// let callback = Some(CallbackWrapper::new(|| async move {
521    ///     println!("Timer fired!");
522    /// }));
523    /// let task = TimerService::create_task(Duration::from_millis(100), callback);
524    /// let task_id = task.get_id();
525    /// 
526    /// service.register(task).unwrap();
527    /// # }
528    /// ```
529    #[inline]
530    pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
531        let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
532        let notifier = crate::task::CompletionNotifier(completion_tx);
533        
534        let task_id = task.id;
535        
536        // 单次加锁完成所有操作
537        {
538            let mut wheel_guard = self.wheel.lock();
539            wheel_guard.insert(task, notifier);
540        }
541        
542        // 添加到服务管理(只发送必要数据)
543        self.command_tx
544            .try_send(ServiceCommand::AddTimerHandle {
545                task_id,
546                completion_rx,
547            })
548            .map_err(|_| TimerError::RegisterFailed)?;
549        
550        Ok(())
551    }
552    
553    /// 批量注册定时器任务到服务(注册阶段)
554    /// 
555    /// # 参数
556    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
557    /// 
558    /// # 返回
559    /// - `Ok(())`: 注册成功
560    /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
561    /// 
562    /// # 示例
563    /// ```no_run
564    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
565    /// # use std::time::Duration;
566    /// # 
567    /// # #[tokio::main]
568    /// # async fn main() {
569    /// let timer = TimerWheel::with_defaults();
570    /// let service = timer.create_service(ServiceConfig::default());
571    /// 
572    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
573    ///     .map(|i| {
574    ///         let callback = Some(CallbackWrapper::new(move || async move {
575    ///             println!("Timer {} fired!", i);
576    ///         }));
577    ///         (Duration::from_secs(1), callback)
578    ///     })
579    ///     .collect();
580    /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
581    /// 
582    /// service.register_batch(tasks).unwrap();
583    /// # }
584    /// ```
585    #[inline]
586    pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
587        let task_count = tasks.len();
588        let mut completion_rxs = Vec::with_capacity(task_count);
589        let mut task_ids = Vec::with_capacity(task_count);
590        let mut prepared_tasks = Vec::with_capacity(task_count);
591        
592        // 步骤1: 准备所有 channels 和 notifiers(无锁)
593        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
594        for task in tasks {
595            let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
596            let notifier = crate::task::CompletionNotifier(completion_tx);
597            
598            task_ids.push(task.id);
599            completion_rxs.push(completion_rx);
600            prepared_tasks.push((task, notifier));
601        }
602        
603        // 步骤2: 单次加锁,批量插入
604        {
605            let mut wheel_guard = self.wheel.lock();
606            wheel_guard.insert_batch(prepared_tasks);
607        }
608        
609        // 添加到服务管理(只发送必要数据)
610        self.command_tx
611            .try_send(ServiceCommand::AddBatchHandle {
612                task_ids,
613                completion_rxs,
614            })
615            .map_err(|_| TimerError::RegisterFailed)?;
616        
617        Ok(())
618    }
619
620    /// 优雅关闭 TimerService
621    ///
622    /// # 示例
623    /// ```no_run
624    /// # use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
625    /// # #[tokio::main]
626    /// # async fn main() {
627    /// let timer = TimerWheel::with_defaults();
628    /// let mut service = timer.create_service(ServiceConfig::default());
629    /// 
630    /// // 使用 service...
631    /// 
632    /// service.shutdown().await;
633    /// # }
634    /// ```
635    pub async fn shutdown(mut self) {
636        if let Some(shutdown_tx) = self.shutdown_tx.take() {
637            let _ = shutdown_tx.send(());
638        }
639        if let Some(handle) = self.actor_handle.take() {
640            let _ = handle.await;
641        }
642    }
643}
644
645
646impl Drop for TimerService {
647    fn drop(&mut self) {
648        if let Some(handle) = self.actor_handle.take() {
649            handle.abort();
650        }
651    }
652}
653
654/// ServiceActor - 内部 Actor 实现
655struct ServiceActor {
656    /// 命令接收端
657    command_rx: mpsc::Receiver<ServiceCommand>,
658    /// 超时发送端
659    timeout_tx: mpsc::Sender<TaskId>,
660    /// Actor 关闭信号接收端
661    shutdown_rx: tokio::sync::oneshot::Receiver<()>,
662}
663
664impl ServiceActor {
665    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>, shutdown_rx: tokio::sync::oneshot::Receiver<()>) -> Self {
666        Self {
667            command_rx,
668            timeout_tx,
669            shutdown_rx,
670        }
671    }
672
673    async fn run(mut self) {
674        // 使用 FuturesUnordered 来监听所有的 completion_rxs
675        // 每个 future 返回 (TaskId, Result)
676        let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
677
678        // 将 shutdown_rx 移出 self,以便在 select! 中使用 &mut
679        let mut shutdown_rx = self.shutdown_rx;
680
681        loop {
682            tokio::select! {
683                // 监听高优先级的关闭信号
684                _ = &mut shutdown_rx => {
685                    // 收到关闭信号,立即退出循环
686                    break;
687                }
688
689                // 监听超时事件
690                Some((task_id, result)) = futures.next() => {
691                    // 检查完成原因,只转发超时(Expired)事件,不转发取消(Cancelled)事件
692                    if let Ok(TaskCompletionReason::Expired) = result {
693                        let _ = self.timeout_tx.send(task_id).await;
694                    }
695                    // 任务会自动从 FuturesUnordered 中移除
696                }
697                
698                // 监听命令
699                Some(cmd) = self.command_rx.recv() => {
700                    match cmd {
701                        ServiceCommand::AddBatchHandle { task_ids, completion_rxs } => {
702                            // 将所有任务添加到 futures 中
703                            for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
704                                let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
705                                    (task_id, rx.await)
706                                });
707                                futures.push(future);
708                            }
709                        }
710                        ServiceCommand::AddTimerHandle { task_id, completion_rx } => {
711                            
712                            // 添加到 futures 中
713                            let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
714                                (task_id, completion_rx.await)
715                            });
716                            futures.push(future);
717                        }
718                    }
719                }
720                
721                // 如果没有任何 future 且命令通道已关闭,退出循环
722                else => {
723                    break;
724                }
725            }
726        }
727    }
728}
729
730#[cfg(test)]
731mod tests {
732    use super::*;
733    use crate::TimerWheel;
734    use std::sync::atomic::{AtomicU32, Ordering};
735    use std::sync::Arc;
736    use std::time::Duration;
737
738    #[tokio::test]
739    async fn test_service_creation() {
740        let timer = TimerWheel::with_defaults();
741        let _service = timer.create_service(ServiceConfig::default());
742    }
743
744
745    #[tokio::test]
746    async fn test_add_timer_handle_and_receive_timeout() {
747        let timer = TimerWheel::with_defaults();
748        let mut service = timer.create_service(ServiceConfig::default());
749
750        // 创建单个定时器
751        let task = TimerService::create_task(Duration::from_millis(50), Some(CallbackWrapper::new(|| async {})));
752        let task_id = task.get_id();
753        
754        // 注册到 service
755        service.register(task).unwrap();
756
757        // 接收超时通知
758        let mut rx = service.take_receiver().unwrap();
759        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
760            .await
761            .expect("Should receive timeout notification")
762            .expect("Should receive Some value");
763
764        assert_eq!(received_task_id, task_id);
765    }
766
767
768    #[tokio::test]
769    async fn test_shutdown() {
770        let timer = TimerWheel::with_defaults();
771        let service = timer.create_service(ServiceConfig::default());
772
773        // 添加一些定时器
774        let task1 = TimerService::create_task(Duration::from_secs(10), None);
775        let task2 = TimerService::create_task(Duration::from_secs(10), None);
776        service.register(task1).unwrap();
777        service.register(task2).unwrap();
778
779        // 立即关闭(不等待定时器触发)
780        service.shutdown().await;
781    }
782
783
784
785    #[tokio::test]
786    async fn test_cancel_task() {
787        let timer = TimerWheel::with_defaults();
788        let service = timer.create_service(ServiceConfig::default());
789
790        // 添加一个长时间的定时器
791        let task = TimerService::create_task(Duration::from_secs(10), None);
792        let task_id = task.get_id();
793        
794        service.register(task).unwrap();
795
796        // 取消任务
797        let cancelled = service.cancel_task(task_id);
798        assert!(cancelled, "Task should be cancelled successfully");
799
800        // 尝试再次取消同一个任务,应该返回 false
801        let cancelled_again = service.cancel_task(task_id);
802        assert!(!cancelled_again, "Task should not exist anymore");
803    }
804
805    #[tokio::test]
806    async fn test_cancel_nonexistent_task() {
807        let timer = TimerWheel::with_defaults();
808        let service = timer.create_service(ServiceConfig::default());
809
810        // 添加一个定时器以初始化 service
811        let task = TimerService::create_task(Duration::from_millis(50), None);
812        service.register(task).unwrap();
813
814        // 尝试取消一个不存在的任务(创建一个不会实际注册的任务ID)
815        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
816        let fake_task_id = fake_task.get_id();
817        // 不注册 fake_task
818        let cancelled = service.cancel_task(fake_task_id);
819        assert!(!cancelled, "Nonexistent task should not be cancelled");
820    }
821
822
823    #[tokio::test]
824    async fn test_task_timeout_cleans_up_task_sender() {
825        let timer = TimerWheel::with_defaults();
826        let mut service = timer.create_service(ServiceConfig::default());
827
828        // 添加一个短时间的定时器
829        let task = TimerService::create_task(Duration::from_millis(50), None);
830        let task_id = task.get_id();
831        
832        service.register(task).unwrap();
833
834        // 等待任务超时
835        let mut rx = service.take_receiver().unwrap();
836        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
837            .await
838            .expect("Should receive timeout notification")
839            .expect("Should receive Some value");
840        
841        assert_eq!(received_task_id, task_id);
842
843        // 等待一下确保内部清理完成
844        tokio::time::sleep(Duration::from_millis(10)).await;
845
846        // 尝试取消已经超时的任务,应该返回 false
847        let cancelled = service.cancel_task(task_id);
848        assert!(!cancelled, "Timed out task should not exist anymore");
849    }
850
851    #[tokio::test]
852    async fn test_cancel_task_spawns_background_task() {
853        let timer = TimerWheel::with_defaults();
854        let service = timer.create_service(ServiceConfig::default());
855        let counter = Arc::new(AtomicU32::new(0));
856
857        // 创建一个定时器
858        let counter_clone = Arc::clone(&counter);
859        let task = TimerService::create_task(
860            Duration::from_secs(10),
861            Some(CallbackWrapper::new(move || {
862                let counter = Arc::clone(&counter_clone);
863                async move {
864                    counter.fetch_add(1, Ordering::SeqCst);
865                }
866            })),
867        );
868        let task_id = task.get_id();
869        
870        service.register(task).unwrap();
871
872        // 使用 cancel_task(会等待结果,但在后台协程中处理)
873        let cancelled = service.cancel_task(task_id);
874        assert!(cancelled, "Task should be cancelled successfully");
875
876        // 等待足够长时间确保回调不会被执行
877        tokio::time::sleep(Duration::from_millis(100)).await;
878        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
879
880        // 验证任务已从 active_tasks 中移除
881        let cancelled_again = service.cancel_task(task_id);
882        assert!(!cancelled_again, "Task should have been removed from active_tasks");
883    }
884
885    #[tokio::test]
886    async fn test_schedule_once_direct() {
887        let timer = TimerWheel::with_defaults();
888        let mut service = timer.create_service(ServiceConfig::default());
889        let counter = Arc::new(AtomicU32::new(0));
890
891        // 直接通过 service 调度定时器
892        let counter_clone = Arc::clone(&counter);
893        let task = TimerService::create_task(
894            Duration::from_millis(50),
895            Some(CallbackWrapper::new(move || {
896                let counter = Arc::clone(&counter_clone);
897                async move {
898                    counter.fetch_add(1, Ordering::SeqCst);
899                }
900            })),
901        );
902        let task_id = task.get_id();
903        service.register(task).unwrap();
904
905        // 等待定时器触发
906        let mut rx = service.take_receiver().unwrap();
907        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
908            .await
909            .expect("Should receive timeout notification")
910            .expect("Should receive Some value");
911
912        assert_eq!(received_task_id, task_id);
913        
914        // 等待回调执行
915        tokio::time::sleep(Duration::from_millis(50)).await;
916        assert_eq!(counter.load(Ordering::SeqCst), 1);
917    }
918
919    #[tokio::test]
920    async fn test_schedule_once_batch_direct() {
921        let timer = TimerWheel::with_defaults();
922        let mut service = timer.create_service(ServiceConfig::default());
923        let counter = Arc::new(AtomicU32::new(0));
924
925        // 直接通过 service 批量调度定时器
926        let callbacks: Vec<_> = (0..3)
927            .map(|_| {
928                let counter = Arc::clone(&counter);
929                (Duration::from_millis(50), Some(CallbackWrapper::new(move || {
930                    let counter = Arc::clone(&counter);
931                    async move {
932                        counter.fetch_add(1, Ordering::SeqCst);
933                    }
934                    })))
935            })
936            .collect();
937
938        let tasks = TimerService::create_batch_with_callbacks(callbacks);
939        assert_eq!(tasks.len(), 3);
940        service.register_batch(tasks).unwrap();
941
942        // 接收所有超时通知
943        let mut received_count = 0;
944        let mut rx = service.take_receiver().unwrap();
945        
946        while received_count < 3 {
947            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
948                Ok(Some(_task_id)) => {
949                    received_count += 1;
950                }
951                Ok(None) => break,
952                Err(_) => break,
953            }
954        }
955
956        assert_eq!(received_count, 3);
957        
958        // 等待回调执行
959        tokio::time::sleep(Duration::from_millis(50)).await;
960        assert_eq!(counter.load(Ordering::SeqCst), 3);
961    }
962
963    #[tokio::test]
964    async fn test_schedule_once_notify_direct() {
965        let timer = TimerWheel::with_defaults();
966        let mut service = timer.create_service(ServiceConfig::default());
967
968        // 直接通过 service 调度仅通知的定时器(无回调)
969        let task = TimerService::create_task(Duration::from_millis(50), None);
970        let task_id = task.get_id();
971        service.register(task).unwrap();
972
973        // 接收超时通知
974        let mut rx = service.take_receiver().unwrap();
975        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
976            .await
977            .expect("Should receive timeout notification")
978            .expect("Should receive Some value");
979
980        assert_eq!(received_task_id, task_id);
981    }
982
983    #[tokio::test]
984    async fn test_schedule_and_cancel_direct() {
985        let timer = TimerWheel::with_defaults();
986        let service = timer.create_service(ServiceConfig::default());
987        let counter = Arc::new(AtomicU32::new(0));
988
989        // 直接调度定时器
990        let counter_clone = Arc::clone(&counter);
991        let task = TimerService::create_task(
992            Duration::from_secs(10),
993            Some(CallbackWrapper::new(move || {
994                let counter = Arc::clone(&counter_clone);
995                async move {
996                    counter.fetch_add(1, Ordering::SeqCst);
997                }
998            })),
999        );
1000        let task_id = task.get_id();
1001        service.register(task).unwrap();
1002
1003        // 立即取消
1004        let cancelled = service.cancel_task(task_id);
1005        assert!(cancelled, "Task should be cancelled successfully");
1006
1007        // 等待确保回调不会执行
1008        tokio::time::sleep(Duration::from_millis(100)).await;
1009        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
1010    }
1011
1012    #[tokio::test]
1013    async fn test_cancel_batch_direct() {
1014        let timer = TimerWheel::with_defaults();
1015        let service = timer.create_service(ServiceConfig::default());
1016        let counter = Arc::new(AtomicU32::new(0));
1017
1018        // 批量调度定时器
1019        let callbacks: Vec<_> = (0..10)
1020            .map(|_| {
1021                let counter = Arc::clone(&counter);
1022                (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1023                    let counter = Arc::clone(&counter);
1024                    async move {
1025                        counter.fetch_add(1, Ordering::SeqCst);
1026                    }
1027                    })))
1028            })
1029            .collect();
1030
1031        let tasks = TimerService::create_batch_with_callbacks(callbacks);
1032        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1033        assert_eq!(task_ids.len(), 10);
1034        service.register_batch(tasks).unwrap();
1035
1036        // 批量取消所有任务
1037        let cancelled = service.cancel_batch(&task_ids);
1038        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1039
1040        // 等待确保回调不会执行
1041        tokio::time::sleep(Duration::from_millis(100)).await;
1042        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1043    }
1044
1045    #[tokio::test]
1046    async fn test_cancel_batch_partial() {
1047        let timer = TimerWheel::with_defaults();
1048        let service = timer.create_service(ServiceConfig::default());
1049        let counter = Arc::new(AtomicU32::new(0));
1050
1051        // 批量调度定时器
1052        let callbacks: Vec<_> = (0..10)
1053            .map(|_| {
1054                let counter = Arc::clone(&counter);
1055                (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1056                    let counter = Arc::clone(&counter);
1057                    async move {
1058                        counter.fetch_add(1, Ordering::SeqCst);
1059                    }
1060                    })))
1061            })
1062            .collect();
1063
1064        let tasks = TimerService::create_batch_with_callbacks(callbacks);
1065        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1066        service.register_batch(tasks).unwrap();
1067
1068        // 只取消前5个任务
1069        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1070        let cancelled = service.cancel_batch(&to_cancel);
1071        assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1072
1073        // 等待确保前5个回调不会执行
1074        tokio::time::sleep(Duration::from_millis(100)).await;
1075        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1076    }
1077
1078    #[tokio::test]
1079    async fn test_cancel_batch_empty() {
1080        let timer = TimerWheel::with_defaults();
1081        let service = timer.create_service(ServiceConfig::default());
1082
1083        // 取消空列表
1084        let empty: Vec<TaskId> = vec![];
1085        let cancelled = service.cancel_batch(&empty);
1086        assert_eq!(cancelled, 0, "No tasks should be cancelled");
1087    }
1088
1089    #[tokio::test]
1090    async fn test_postpone() {
1091        let timer = TimerWheel::with_defaults();
1092        let mut service = timer.create_service(ServiceConfig::default());
1093        let counter = Arc::new(AtomicU32::new(0));
1094
1095        // 注册一个任务,原始回调增加 1
1096        let counter_clone1 = Arc::clone(&counter);
1097        let task = TimerService::create_task(
1098            Duration::from_millis(50),
1099            Some(CallbackWrapper::new(move || {
1100                let counter = Arc::clone(&counter_clone1);
1101                async move {
1102                    counter.fetch_add(1, Ordering::SeqCst);
1103                }
1104                })),
1105        );
1106        let task_id = task.get_id();
1107        service.register(task).unwrap();
1108
1109        // 推迟任务并替换回调,新回调增加 10
1110        let counter_clone2 = Arc::clone(&counter);
1111        let postponed = service.postpone(
1112            task_id,
1113            Duration::from_millis(100),
1114            Some(CallbackWrapper::new(move || {
1115                let counter = Arc::clone(&counter_clone2);
1116                async move {
1117                    counter.fetch_add(10, Ordering::SeqCst);
1118                }
1119            }))
1120        );
1121        assert!(postponed, "Task should be postponed successfully");
1122
1123        // 接收超时通知(推迟后需要等待100ms,加上余量)
1124        let mut rx = service.take_receiver().unwrap();
1125        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1126            .await
1127            .expect("Should receive timeout notification")
1128            .expect("Should receive Some value");
1129
1130        assert_eq!(received_task_id, task_id);
1131        
1132        // 等待回调执行
1133        tokio::time::sleep(Duration::from_millis(20)).await;
1134        
1135        // 验证新回调被执行(增加了 10 而不是 1)
1136        assert_eq!(counter.load(Ordering::SeqCst), 10);
1137    }
1138
1139    #[tokio::test]
1140    async fn test_postpone_nonexistent_task() {
1141        let timer = TimerWheel::with_defaults();
1142        let service = timer.create_service(ServiceConfig::default());
1143
1144        // 尝试推迟一个不存在的任务
1145        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
1146        let fake_task_id = fake_task.get_id();
1147        // 不注册这个任务
1148        
1149        let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
1150        assert!(!postponed, "Nonexistent task should not be postponed");
1151    }
1152
1153    #[tokio::test]
1154    async fn test_postpone_batch() {
1155        let timer = TimerWheel::with_defaults();
1156        let mut service = timer.create_service(ServiceConfig::default());
1157        let counter = Arc::new(AtomicU32::new(0));
1158
1159        // 注册 3 个任务
1160        let mut task_ids = Vec::new();
1161        for _ in 0..3 {
1162            let counter_clone = Arc::clone(&counter);
1163            let task = TimerService::create_task(
1164                Duration::from_millis(50),
1165            Some(CallbackWrapper::new(move || {
1166                let counter = Arc::clone(&counter_clone);
1167                async move {
1168                    counter.fetch_add(1, Ordering::SeqCst);
1169                }
1170            })),
1171        );
1172            task_ids.push((task.get_id(), Duration::from_millis(150), None));
1173            service.register(task).unwrap();
1174        }
1175
1176        // 批量推迟
1177        let postponed = service.postpone_batch_with_callbacks(task_ids);
1178        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1179
1180        // 等待原定时间 50ms,任务不应该触发
1181        tokio::time::sleep(Duration::from_millis(70)).await;
1182        assert_eq!(counter.load(Ordering::SeqCst), 0);
1183
1184        // 接收所有超时通知
1185        let mut received_count = 0;
1186        let mut rx = service.take_receiver().unwrap();
1187        
1188        while received_count < 3 {
1189            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1190                Ok(Some(_task_id)) => {
1191                    received_count += 1;
1192                }
1193                Ok(None) => break,
1194                Err(_) => break,
1195            }
1196        }
1197
1198        assert_eq!(received_count, 3);
1199        
1200        // 等待回调执行
1201        tokio::time::sleep(Duration::from_millis(20)).await;
1202        assert_eq!(counter.load(Ordering::SeqCst), 3);
1203    }
1204
1205    #[tokio::test]
1206    async fn test_postpone_batch_with_callbacks() {
1207        let timer = TimerWheel::with_defaults();
1208        let mut service = timer.create_service(ServiceConfig::default());
1209        let counter = Arc::new(AtomicU32::new(0));
1210
1211        // 注册 3 个任务
1212        let mut task_ids = Vec::new();
1213        for _ in 0..3 {
1214            let task = TimerService::create_task(
1215                Duration::from_millis(50),
1216                None,
1217            );
1218            task_ids.push(task.get_id());
1219            service.register(task).unwrap();
1220        }
1221
1222        // 批量推迟并替换回调
1223        let updates: Vec<_> = task_ids
1224            .into_iter()
1225            .map(|id| {
1226                let counter_clone = Arc::clone(&counter);
1227                (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
1228                    let counter = Arc::clone(&counter_clone);
1229                    async move {
1230                        counter.fetch_add(1, Ordering::SeqCst);
1231                    }
1232                    })))
1233            })
1234            .collect();
1235
1236        let postponed = service.postpone_batch_with_callbacks(updates);
1237        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1238
1239        // 等待原定时间 50ms,任务不应该触发
1240        tokio::time::sleep(Duration::from_millis(70)).await;
1241        assert_eq!(counter.load(Ordering::SeqCst), 0);
1242
1243        // 接收所有超时通知
1244        let mut received_count = 0;
1245        let mut rx = service.take_receiver().unwrap();
1246        
1247        while received_count < 3 {
1248            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1249                Ok(Some(_task_id)) => {
1250                    received_count += 1;
1251                }
1252                Ok(None) => break,
1253                Err(_) => break,
1254            }
1255        }
1256
1257        assert_eq!(received_count, 3);
1258        
1259        // 等待回调执行
1260        tokio::time::sleep(Duration::from_millis(20)).await;
1261        assert_eq!(counter.load(Ordering::SeqCst), 3);
1262    }
1263
1264    #[tokio::test]
1265    async fn test_postpone_batch_empty() {
1266        let timer = TimerWheel::with_defaults();
1267        let service = timer.create_service(ServiceConfig::default());
1268
1269        // 推迟空列表
1270        let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
1271        let postponed = service.postpone_batch_with_callbacks(empty);
1272        assert_eq!(postponed, 0, "No tasks should be postponed");
1273    }
1274
1275    #[tokio::test]
1276    async fn test_postpone_keeps_timeout_notification_valid() {
1277        let timer = TimerWheel::with_defaults();
1278        let mut service = timer.create_service(ServiceConfig::default());
1279        let counter = Arc::new(AtomicU32::new(0));
1280
1281        // 注册一个任务
1282        let counter_clone = Arc::clone(&counter);
1283        let task = TimerService::create_task(
1284            Duration::from_millis(50),
1285            Some(CallbackWrapper::new(move || {
1286                let counter = Arc::clone(&counter_clone);
1287                async move {
1288                    counter.fetch_add(1, Ordering::SeqCst);
1289                }
1290            })),
1291        );
1292        let task_id = task.get_id();
1293        service.register(task).unwrap();
1294
1295        // 推迟任务
1296        service.postpone(task_id, Duration::from_millis(100), None);
1297
1298        // 验证超时通知仍然有效(推迟后需要等待100ms,加上余量)
1299        let mut rx = service.take_receiver().unwrap();
1300        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1301            .await
1302            .expect("Should receive timeout notification")
1303            .expect("Should receive Some value");
1304
1305        assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1306        
1307        // 等待回调执行
1308        tokio::time::sleep(Duration::from_millis(20)).await;
1309        assert_eq!(counter.load(Ordering::SeqCst), 1);
1310    }
1311
1312    #[tokio::test]
1313    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1314        let timer = TimerWheel::with_defaults();
1315        let mut service = timer.create_service(ServiceConfig::default());
1316
1317        // 注册两个任务:一个会被取消,一个会正常到期
1318        let task1 = TimerService::create_task(Duration::from_secs(10), None);
1319        let task1_id = task1.get_id();
1320        service.register(task1).unwrap();
1321
1322        let task2 = TimerService::create_task(Duration::from_millis(50), None);
1323        let task2_id = task2.get_id();
1324        service.register(task2).unwrap();
1325
1326        // 取消第一个任务
1327        let cancelled = service.cancel_task(task1_id);
1328        assert!(cancelled, "Task should be cancelled");
1329
1330        // 等待第二个任务到期
1331        let mut rx = service.take_receiver().unwrap();
1332        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1333            .await
1334            .expect("Should receive timeout notification")
1335            .expect("Should receive Some value");
1336
1337        // 应该只收到第二个任务(到期的)的通知,不应该收到第一个任务(取消的)的通知
1338        assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");
1339
1340        // 验证没有其他通知(特别是被取消的任务不应该有通知)
1341        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1342        assert!(no_more.is_err(), "Should not receive any more notifications");
1343    }
1344
1345    #[tokio::test]
1346    async fn test_take_receiver_twice() {
1347        let timer = TimerWheel::with_defaults();
1348        let mut service = timer.create_service(ServiceConfig::default());
1349
1350        // 第一次调用应该返回 Some
1351        let rx1 = service.take_receiver();
1352        assert!(rx1.is_some(), "First take_receiver should return Some");
1353
1354        // 第二次调用应该返回 None
1355        let rx2 = service.take_receiver();
1356        assert!(rx2.is_none(), "Second take_receiver should return None");
1357    }
1358
1359    #[tokio::test]
1360    async fn test_postpone_batch_without_callbacks() {
1361        let timer = TimerWheel::with_defaults();
1362        let mut service = timer.create_service(ServiceConfig::default());
1363        let counter = Arc::new(AtomicU32::new(0));
1364
1365        // 注册 3 个任务,有原始回调
1366        let mut task_ids = Vec::new();
1367        for _ in 0..3 {
1368            let counter_clone = Arc::clone(&counter);
1369            let task = TimerService::create_task(
1370                Duration::from_millis(50),
1371                Some(CallbackWrapper::new(move || {
1372                    let counter = Arc::clone(&counter_clone);
1373                    async move {
1374                        counter.fetch_add(1, Ordering::SeqCst);
1375                    }
1376                })),
1377            );
1378            task_ids.push(task.get_id());
1379            service.register(task).unwrap();
1380        }
1381
1382        // 批量推迟,不替换回调
1383        let updates: Vec<_> = task_ids
1384            .iter()
1385            .map(|&id| (id, Duration::from_millis(150)))
1386            .collect();
1387        let postponed = service.postpone_batch(updates);
1388        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1389
1390        // 等待原定时间 50ms,任务不应该触发
1391        tokio::time::sleep(Duration::from_millis(70)).await;
1392        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callbacks should not fire yet");
1393
1394        // 接收所有超时通知
1395        let mut received_count = 0;
1396        let mut rx = service.take_receiver().unwrap();
1397        
1398        while received_count < 3 {
1399            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1400                Ok(Some(_task_id)) => {
1401                    received_count += 1;
1402                }
1403                Ok(None) => break,
1404                Err(_) => break,
1405            }
1406        }
1407
1408        assert_eq!(received_count, 3, "Should receive 3 timeout notifications");
1409        
1410        // 等待回调执行
1411        tokio::time::sleep(Duration::from_millis(20)).await;
1412        assert_eq!(counter.load(Ordering::SeqCst), 3, "All callbacks should execute");
1413    }
1414}
1415