// kestrel_protocol_timer/service.rs

use std::sync::Arc;
use std::time::Duration;

use futures::future::BoxFuture;
use futures::stream::{FuturesUnordered, StreamExt};
use parking_lot::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;

use crate::config::ServiceConfig;
use crate::error::TimerError;
use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
use crate::wheel::Wheel;

13/// TimerService 命令类型
14enum ServiceCommand {
15    /// 添加批量定时器句柄(仅包含必要数据:task_ids 和 completion_rxs)
16    AddBatchHandle {
17        task_ids: Vec<TaskId>,
18        completion_rxs: Vec<tokio::sync::oneshot::Receiver<TaskCompletionReason>>,
19    },
20    /// 添加单个定时器句柄(仅包含必要数据:task_id 和 completion_rx)
21    AddTimerHandle {
22        task_id: TaskId,
23        completion_rx: tokio::sync::oneshot::Receiver<TaskCompletionReason>,
24    },
25    /// 关闭 Service
26    Shutdown,
27}
28
29/// TimerService - 基于 Actor 模式的定时器服务
30///
31/// 管理多个定时器句柄,监听所有超时事件,并将 TaskId 聚合转发给用户。
32///
33/// # 特性
34/// - 自动监听所有添加的定时器句柄的超时事件
35/// - 超时后自动从内部管理中移除该任务
36/// - 将超时的 TaskId 转发到统一的通道供用户接收
37/// - 支持动态添加 BatchHandle 和 TimerHandle
38///
39/// # 示例
40/// ```no_run
41/// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
42/// use std::time::Duration;
43///
44/// #[tokio::main]
45/// async fn main() {
46///     let timer = TimerWheel::with_defaults();
47///     let mut service = timer.create_service(ServiceConfig::default());
48///     
49///     // 使用两步式 API 通过 service 批量调度定时器
50///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
51///         .map(|i| {
52///             let callback = Some(CallbackWrapper::new(move || async move {
53///                 println!("Timer {} fired!", i);
54///             }));
55///             (Duration::from_millis(100), callback)
56///         })
57///         .collect();
58///     let tasks = TimerService::create_batch_with_callbacks(callbacks);
59///     service.register_batch(tasks).unwrap();
60///     
61///     // 接收超时通知
62///     let mut rx = service.take_receiver().unwrap();
63///     while let Some(task_id) = rx.recv().await {
64///         println!("Task {:?} completed", task_id);
65///     }
66/// }
67/// ```
68pub struct TimerService {
69    /// 命令发送端
70    command_tx: mpsc::Sender<ServiceCommand>,
71    /// 超时接收端
72    timeout_rx: Option<mpsc::Receiver<TaskId>>,
73    /// Actor 任务句柄
74    actor_handle: Option<JoinHandle<()>>,
75    /// 时间轮引用(用于直接调度定时器)
76    wheel: Arc<Mutex<Wheel>>,
77}
78
79impl TimerService {
80    /// 创建新的 TimerService
81    ///
82    /// # 参数
83    /// - `wheel`: 时间轮引用
84    /// - `config`: 服务配置
85    ///
86    /// # 注意
87    /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
88    ///
89    /// # 示例
90    /// ```no_run
91    /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
92    ///
93    /// #[tokio::main]
94    /// async fn main() {
95    ///     let timer = TimerWheel::with_defaults();
96    ///     let mut service = timer.create_service(ServiceConfig::default());
97    /// }
98    /// ```
99    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
100        let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
101        let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
102
103        let actor = ServiceActor::new(command_rx, timeout_tx);
104        let actor_handle = tokio::spawn(async move {
105            actor.run().await;
106        });
107
108        Self {
109            command_tx,
110            timeout_rx: Some(timeout_rx),
111            actor_handle: Some(actor_handle),
112            wheel,
113        }
114    }
115
116    /// 获取超时接收器(转移所有权)
117    ///
118    /// # 返回
119    /// 超时通知接收器,如果已经被取走则返回 None
120    ///
121    /// # 注意
122    /// 此方法只能调用一次,因为它会转移接收器的所有权
123    ///
124    /// # 示例
125    /// ```no_run
126    /// # use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
127    /// # use std::time::Duration;
128    /// # #[tokio::main]
129    /// # async fn main() {
130    /// let timer = TimerWheel::with_defaults();
131    /// let mut service = timer.create_service(ServiceConfig::default());
132    /// 
133    /// let mut rx = service.take_receiver().unwrap();
134    /// while let Some(task_id) = rx.recv().await {
135    ///     println!("Task {:?} timed out", task_id);
136    /// }
137    /// # }
138    /// ```
139    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
140        self.timeout_rx.take()
141    }
142
143    /// 取消指定的任务
144    ///
145    /// # 参数
146    /// - `task_id`: 要取消的任务 ID
147    ///
148    /// # 返回
149    /// - `true`: 任务存在且成功取消
150    /// - `false`: 任务不存在或取消失败
151    ///
152    /// # 性能说明
153    /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
154    ///
155    /// # 示例
156    /// ```no_run
157    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
158    /// # use std::time::Duration;
159    /// # 
160    /// # #[tokio::main]
161    /// # async fn main() {
162    /// let timer = TimerWheel::with_defaults();
163    /// let service = timer.create_service(ServiceConfig::default());
164    /// 
165    /// // 使用两步式 API 调度定时器
166    /// let callback = Some(CallbackWrapper::new(|| async move {
167    ///     println!("Timer fired!");
168    /// }));
169    /// let task = TimerService::create_task(Duration::from_secs(10), callback);
170    /// let task_id = task.get_id();
171    /// service.register(task).unwrap();
172    /// 
173    /// // 取消任务
174    /// let cancelled = service.cancel_task(task_id);
175    /// println!("Task cancelled: {}", cancelled);
176    /// # }
177    /// ```
178    #[inline]
179    pub fn cancel_task(&self, task_id: TaskId) -> bool {
180        // 优化:直接取消任务,无需通知 Actor
181        // FuturesUnordered 会在任务被取消时自动清理
182        let mut wheel = self.wheel.lock();
183        wheel.cancel(task_id)
184    }
185
186    /// 批量取消任务
187    ///
188    /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
189    ///
190    /// # 参数
191    /// - `task_ids`: 要取消的任务 ID 列表
192    ///
193    /// # 返回
194    /// 成功取消的任务数量
195    ///
196    /// # 示例
197    /// ```no_run
198    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
199    /// # use std::time::Duration;
200    /// # 
201    /// # #[tokio::main]
202    /// # async fn main() {
203    /// let timer = TimerWheel::with_defaults();
204    /// let service = timer.create_service(ServiceConfig::default());
205    /// 
206    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..10)
207    ///     .map(|i| {
208    ///         let callback = Some(CallbackWrapper::new(move || async move {
209    ///             println!("Timer {} fired!", i);
210    ///         }));
211    ///         (Duration::from_secs(10), callback)
212    ///     })
213    ///     .collect();
214    /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
215    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
216    /// service.register_batch(tasks).unwrap();
217    /// 
218    /// // 批量取消
219    /// let cancelled = service.cancel_batch(&task_ids);
220    /// println!("成功取消 {} 个任务", cancelled);
221    /// # }
222    /// ```
223    #[inline]
224    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
225        if task_ids.is_empty() {
226            return 0;
227        }
228
229        // 优化:直接使用底层的批量取消,无需通知 Actor
230        // FuturesUnordered 会在任务被取消时自动清理
231        let mut wheel = self.wheel.lock();
232        wheel.cancel_batch(task_ids)
233    }
234
235    /// 推迟任务(替换回调)
236    ///
237    /// # 参数
238    /// - `task_id`: 要推迟的任务 ID
239    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
240    /// - `callback`: 新的回调函数
241    ///
242    /// # 返回
243    /// - `true`: 任务存在且成功推迟
244    /// - `false`: 任务不存在或推迟失败
245    ///
246    /// # 注意
247    /// - 推迟后任务 ID 保持不变
248    /// - 原有的超时通知仍然有效
249    /// - 回调函数会被替换为新的回调
250    ///
251    /// # 示例
252    /// ```no_run
253    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
254    /// # use std::time::Duration;
255    /// # 
256    /// # #[tokio::main]
257    /// # async fn main() {
258    /// let timer = TimerWheel::with_defaults();
259    /// let service = timer.create_service(ServiceConfig::default());
260    /// 
261    /// let callback = Some(CallbackWrapper::new(|| async {
262    ///     println!("Original callback");
263    /// }));
264    /// let task = TimerService::create_task(Duration::from_secs(5), callback);
265    /// let task_id = task.get_id();
266    /// service.register(task).unwrap();
267    /// 
268    /// // 推迟并替换回调
269    /// let new_callback = Some(CallbackWrapper::new(|| async { println!("New callback!"); }));
270    /// let success = service.postpone(
271    ///     task_id,
272    ///     Duration::from_secs(10),
273    ///     new_callback
274    /// );
275    /// println!("推迟成功: {}", success);
276    /// # }
277    /// ```
278    #[inline]
279    pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
280        let mut wheel = self.wheel.lock();
281        wheel.postpone(task_id, new_delay, callback)
282    }
283
284    /// 批量推迟任务(保持原回调)
285    ///
286    /// # 参数
287    /// - `updates`: (任务ID, 新延迟) 的元组列表
288    ///
289    /// # 返回
290    /// 成功推迟的任务数量
291    ///
292    /// # 示例
293    /// ```no_run
294    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
295    /// # use std::time::Duration;
296    /// # 
297    /// # #[tokio::main]
298    /// # async fn main() {
299    /// let timer = TimerWheel::with_defaults();
300    /// let service = timer.create_service(ServiceConfig::default());
301    /// 
302    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
303    ///     .map(|i| {
304    ///         let callback = Some(CallbackWrapper::new(move || async move {
305    ///             println!("Timer {} fired!", i);
306    ///         }));
307    ///         (Duration::from_secs(5), callback)
308    ///     })
309    ///     .collect();
310    /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
311    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
312    /// service.register_batch(tasks).unwrap();
313    /// 
314    /// // 批量推迟(保持原回调)
315    /// let updates: Vec<_> = task_ids
316    ///     .into_iter()
317    ///     .map(|id| (id, Duration::from_secs(10)))
318    ///     .collect();
319    /// let postponed = service.postpone_batch(updates);
320    /// println!("成功推迟 {} 个任务", postponed);
321    /// # }
322    /// ```
323    #[inline]
324    pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
325        if updates.is_empty() {
326            return 0;
327        }
328
329        let mut wheel = self.wheel.lock();
330        wheel.postpone_batch(updates)
331    }
332
333    /// 批量推迟任务(替换回调)
334    ///
335    /// # 参数
336    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
337    ///
338    /// # 返回
339    /// 成功推迟的任务数量
340    ///
341    /// # 示例
342    /// ```no_run
343    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
344    /// # use std::time::Duration;
345    /// # 
346    /// # #[tokio::main]
347    /// # async fn main() {
348    /// let timer = TimerWheel::with_defaults();
349    /// let service = timer.create_service(ServiceConfig::default());
350    /// 
351    /// // 创建 3 个任务,初始没有回调
352    /// let delays: Vec<Duration> = (0..3)
353    ///     .map(|_| Duration::from_secs(5))
354    ///     .collect();
355    /// let tasks = TimerService::create_batch(delays);
356    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
357    /// service.register_batch(tasks).unwrap();
358    /// 
359    /// // 批量推迟并添加新回调
360    /// let updates: Vec<_> = task_ids
361    ///     .into_iter()
362    ///     .enumerate()
363    ///     .map(|(i, id)| {
364    ///         let callback = Some(CallbackWrapper::new(move || async move {
365    ///             println!("New callback {}", i);
366    ///         }));
367    ///         (id, Duration::from_secs(10), callback)
368    ///     })
369    ///     .collect();
370    /// let postponed = service.postpone_batch_with_callbacks(updates);
371    /// println!("成功推迟 {} 个任务", postponed);
372    /// # }
373    /// ```
374    #[inline]
375    pub fn postpone_batch_with_callbacks(
376        &self,
377        updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
378    ) -> usize {
379        if updates.is_empty() {
380            return 0;
381        }
382
383        let mut wheel = self.wheel.lock();
384        wheel.postpone_batch_with_callbacks(updates)
385    }
386
387    /// 创建定时器任务(静态方法,申请阶段)
388    /// 
389    /// # 参数
390    /// - `delay`: 延迟时间
391    /// - `callback`: 实现了 TimerCallback trait 的回调对象
392    /// 
393    /// # 返回
394    /// 返回 TimerTask,需要通过 `register()` 注册
395    /// 
396    /// # 示例
397    /// ```no_run
398    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
399    /// # use std::time::Duration;
400    /// # 
401    /// # #[tokio::main]
402    /// # async fn main() {
403    /// let timer = TimerWheel::with_defaults();
404    /// let service = timer.create_service(ServiceConfig::default());
405    /// 
406    /// // 步骤 1: 创建任务
407    /// let callback = Some(CallbackWrapper::new(|| async {
408    ///     println!("Timer fired!");
409    /// }));
410    /// let task = TimerService::create_task(Duration::from_millis(100), callback);
411    /// 
412    /// let task_id = task.get_id();
413    /// println!("Created task: {:?}", task_id);
414    /// 
415    /// // 步骤 2: 注册任务
416    /// service.register(task).unwrap();
417    /// # }
418    /// ```
419    #[inline]
420    pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
421        crate::timer::TimerWheel::create_task(delay, callback)
422    }
423
424    /// 批量创建定时器任务(静态方法,申请阶段,不带回调)
425    /// 
426    /// # 参数
427    /// - `delays`: 延迟时间列表
428    /// 
429    /// # 返回
430    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
431    /// 
432    /// # 示例
433    /// ```no_run
434    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
435    /// # use std::time::Duration;
436    /// # 
437    /// # #[tokio::main]
438    /// # async fn main() {
439    /// let timer = TimerWheel::with_defaults();
440    /// let service = timer.create_service(ServiceConfig::default());
441    /// 
442    /// // 步骤 1: 批量创建任务
443    /// let delays: Vec<Duration> = (0..3)
444    ///     .map(|i| Duration::from_millis(100 * (i + 1)))
445    ///     .collect();
446    /// 
447    /// let tasks = TimerService::create_batch(delays);
448    /// println!("Created {} tasks", tasks.len());
449    /// 
450    /// // 步骤 2: 批量注册任务
451    /// service.register_batch(tasks).unwrap();
452    /// # }
453    /// ```
454    #[inline]
455    pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask> {
456        crate::timer::TimerWheel::create_batch(delays)
457    }
458    
459    /// 批量创建定时器任务(静态方法,申请阶段,带回调)
460    /// 
461    /// # 参数
462    /// - `callbacks`: (延迟时间, 回调) 的元组列表
463    /// 
464    /// # 返回
465    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
466    /// 
467    /// # 示例
468    /// ```no_run
469    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
470    /// # use std::time::Duration;
471    /// # 
472    /// # #[tokio::main]
473    /// # async fn main() {
474    /// let timer = TimerWheel::with_defaults();
475    /// let service = timer.create_service(ServiceConfig::default());
476    /// 
477    /// // 步骤 1: 批量创建任务
478    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
479    ///     .map(|i| {
480    ///         let callback = Some(CallbackWrapper::new(move || async move {
481    ///             println!("Timer {} fired!", i);
482    ///         }));
483    ///         (Duration::from_millis(100 * (i + 1)), callback)
484    ///     })
485    ///     .collect();
486    /// 
487    /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
488    /// println!("Created {} tasks", tasks.len());
489    /// 
490    /// // 步骤 2: 批量注册任务
491    /// service.register_batch(tasks).unwrap();
492    /// # }
493    /// ```
494    #[inline]
495    pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
496        crate::timer::TimerWheel::create_batch_with_callbacks(callbacks)
497    }
498    
499    /// 注册定时器任务到服务(注册阶段)
500    /// 
501    /// # 参数
502    /// - `task`: 通过 `create_task()` 创建的任务
503    /// 
504    /// # 返回
505    /// - `Ok(())`: 注册成功
506    /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
507    /// 
508    /// # 示例
509    /// ```no_run
510    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
511    /// # use std::time::Duration;
512    /// # 
513    /// # #[tokio::main]
514    /// # async fn main() {
515    /// let timer = TimerWheel::with_defaults();
516    /// let service = timer.create_service(ServiceConfig::default());
517    /// 
518    /// let callback = Some(CallbackWrapper::new(|| async move {
519    ///     println!("Timer fired!");
520    /// }));
521    /// let task = TimerService::create_task(Duration::from_millis(100), callback);
522    /// let task_id = task.get_id();
523    /// 
524    /// service.register(task).unwrap();
525    /// # }
526    /// ```
527    #[inline]
528    pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
529        let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
530        let notifier = crate::task::CompletionNotifier(completion_tx);
531        
532        let task_id = task.id;
533        
534        // 单次加锁完成所有操作
535        {
536            let mut wheel_guard = self.wheel.lock();
537            wheel_guard.insert(task, notifier);
538        }
539        
540        // 添加到服务管理(只发送必要数据)
541        self.command_tx
542            .try_send(ServiceCommand::AddTimerHandle {
543                task_id,
544                completion_rx,
545            })
546            .map_err(|_| TimerError::RegisterFailed)?;
547        
548        Ok(())
549    }
550    
551    /// 批量注册定时器任务到服务(注册阶段)
552    /// 
553    /// # 参数
554    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
555    /// 
556    /// # 返回
557    /// - `Ok(())`: 注册成功
558    /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
559    /// 
560    /// # 示例
561    /// ```no_run
562    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
563    /// # use std::time::Duration;
564    /// # 
565    /// # #[tokio::main]
566    /// # async fn main() {
567    /// let timer = TimerWheel::with_defaults();
568    /// let service = timer.create_service(ServiceConfig::default());
569    /// 
570    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
571    ///     .map(|i| {
572    ///         let callback = Some(CallbackWrapper::new(move || async move {
573    ///             println!("Timer {} fired!", i);
574    ///         }));
575    ///         (Duration::from_secs(1), callback)
576    ///     })
577    ///     .collect();
578    /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
579    /// 
580    /// service.register_batch(tasks).unwrap();
581    /// # }
582    /// ```
583    #[inline]
584    pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
585        let task_count = tasks.len();
586        let mut completion_rxs = Vec::with_capacity(task_count);
587        let mut task_ids = Vec::with_capacity(task_count);
588        let mut prepared_tasks = Vec::with_capacity(task_count);
589        
590        // 步骤1: 准备所有 channels 和 notifiers(无锁)
591        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
592        for task in tasks {
593            let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
594            let notifier = crate::task::CompletionNotifier(completion_tx);
595            
596            task_ids.push(task.id);
597            completion_rxs.push(completion_rx);
598            prepared_tasks.push((task, notifier));
599        }
600        
601        // 步骤2: 单次加锁,批量插入
602        {
603            let mut wheel_guard = self.wheel.lock();
604            wheel_guard.insert_batch(prepared_tasks);
605        }
606        
607        // 添加到服务管理(只发送必要数据)
608        self.command_tx
609            .try_send(ServiceCommand::AddBatchHandle {
610                task_ids,
611                completion_rxs,
612            })
613            .map_err(|_| TimerError::RegisterFailed)?;
614        
615        Ok(())
616    }
617
618    /// 优雅关闭 TimerService
619    ///
620    /// # 示例
621    /// ```no_run
622    /// # use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
623    /// # #[tokio::main]
624    /// # async fn main() {
625    /// let timer = TimerWheel::with_defaults();
626    /// let mut service = timer.create_service(ServiceConfig::default());
627    /// 
628    /// // 使用 service...
629    /// 
630    /// service.shutdown().await;
631    /// # }
632    /// ```
633    pub async fn shutdown(mut self) {
634        let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
635        if let Some(handle) = self.actor_handle.take() {
636            let _ = handle.await;
637        }
638    }
639}
640
641
642impl Drop for TimerService {
643    fn drop(&mut self) {
644        if let Some(handle) = self.actor_handle.take() {
645            handle.abort();
646        }
647    }
648}
649
650/// ServiceActor - 内部 Actor 实现
651struct ServiceActor {
652    /// 命令接收端
653    command_rx: mpsc::Receiver<ServiceCommand>,
654    /// 超时发送端
655    timeout_tx: mpsc::Sender<TaskId>,
656}
657
658impl ServiceActor {
659    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
660        Self {
661            command_rx,
662            timeout_tx,
663        }
664    }
665
666    async fn run(mut self) {
667        // 使用 FuturesUnordered 来监听所有的 completion_rxs
668        // 每个 future 返回 (TaskId, Result)
669        let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
670
671        loop {
672            tokio::select! {
673                // 监听超时事件
674                Some((task_id, result)) = futures.next() => {
675                    // 检查完成原因,只转发超时(Expired)事件,不转发取消(Cancelled)事件
676                    if let Ok(TaskCompletionReason::Expired) = result {
677                        let _ = self.timeout_tx.send(task_id).await;
678                    }
679                    // 任务会自动从 FuturesUnordered 中移除
680                }
681                
682                // 监听命令
683                Some(cmd) = self.command_rx.recv() => {
684                    match cmd {
685                        ServiceCommand::AddBatchHandle { task_ids, completion_rxs } => {
686                            // 将所有任务添加到 futures 中
687                            for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
688                                let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
689                                    (task_id, rx.await)
690                                });
691                                futures.push(future);
692                            }
693                        }
694                        ServiceCommand::AddTimerHandle { task_id, completion_rx } => {
695                            
696                            // 添加到 futures 中
697                            let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
698                                (task_id, completion_rx.await)
699                            });
700                            futures.push(future);
701                        }
702                        ServiceCommand::Shutdown => {
703                            break;
704                        }
705                    }
706                }
707                
708                // 如果没有任何 future 且命令通道已关闭,退出循环
709                else => {
710                    break;
711                }
712            }
713        }
714    }
715}
716
717#[cfg(test)]
718mod tests {
719    use super::*;
720    use crate::TimerWheel;
721    use std::sync::atomic::{AtomicU32, Ordering};
722    use std::sync::Arc;
723    use std::time::Duration;
724
725    #[tokio::test]
726    async fn test_service_creation() {
727        let timer = TimerWheel::with_defaults();
728        let _service = timer.create_service(ServiceConfig::default());
729    }
730
731
732    #[tokio::test]
733    async fn test_add_timer_handle_and_receive_timeout() {
734        let timer = TimerWheel::with_defaults();
735        let mut service = timer.create_service(ServiceConfig::default());
736
737        // 创建单个定时器
738        let task = TimerService::create_task(Duration::from_millis(50), Some(CallbackWrapper::new(|| async {})));
739        let task_id = task.get_id();
740        
741        // 注册到 service
742        service.register(task).unwrap();
743
744        // 接收超时通知
745        let mut rx = service.take_receiver().unwrap();
746        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
747            .await
748            .expect("Should receive timeout notification")
749            .expect("Should receive Some value");
750
751        assert_eq!(received_task_id, task_id);
752    }
753
754
755    #[tokio::test]
756    async fn test_shutdown() {
757        let timer = TimerWheel::with_defaults();
758        let service = timer.create_service(ServiceConfig::default());
759
760        // 添加一些定时器
761        let task1 = TimerService::create_task(Duration::from_secs(10), None);
762        let task2 = TimerService::create_task(Duration::from_secs(10), None);
763        service.register(task1).unwrap();
764        service.register(task2).unwrap();
765
766        // 立即关闭(不等待定时器触发)
767        service.shutdown().await;
768    }
769
770
771
772    #[tokio::test]
773    async fn test_cancel_task() {
774        let timer = TimerWheel::with_defaults();
775        let service = timer.create_service(ServiceConfig::default());
776
777        // 添加一个长时间的定时器
778        let task = TimerService::create_task(Duration::from_secs(10), None);
779        let task_id = task.get_id();
780        
781        service.register(task).unwrap();
782
783        // 取消任务
784        let cancelled = service.cancel_task(task_id);
785        assert!(cancelled, "Task should be cancelled successfully");
786
787        // 尝试再次取消同一个任务,应该返回 false
788        let cancelled_again = service.cancel_task(task_id);
789        assert!(!cancelled_again, "Task should not exist anymore");
790    }
791
792    #[tokio::test]
793    async fn test_cancel_nonexistent_task() {
794        let timer = TimerWheel::with_defaults();
795        let service = timer.create_service(ServiceConfig::default());
796
797        // 添加一个定时器以初始化 service
798        let task = TimerService::create_task(Duration::from_millis(50), None);
799        service.register(task).unwrap();
800
801        // 尝试取消一个不存在的任务(创建一个不会实际注册的任务ID)
802        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
803        let fake_task_id = fake_task.get_id();
804        // 不注册 fake_task
805        let cancelled = service.cancel_task(fake_task_id);
806        assert!(!cancelled, "Nonexistent task should not be cancelled");
807    }
808
809
810    #[tokio::test]
811    async fn test_task_timeout_cleans_up_task_sender() {
812        let timer = TimerWheel::with_defaults();
813        let mut service = timer.create_service(ServiceConfig::default());
814
815        // 添加一个短时间的定时器
816        let task = TimerService::create_task(Duration::from_millis(50), None);
817        let task_id = task.get_id();
818        
819        service.register(task).unwrap();
820
821        // 等待任务超时
822        let mut rx = service.take_receiver().unwrap();
823        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
824            .await
825            .expect("Should receive timeout notification")
826            .expect("Should receive Some value");
827        
828        assert_eq!(received_task_id, task_id);
829
830        // 等待一下确保内部清理完成
831        tokio::time::sleep(Duration::from_millis(10)).await;
832
833        // 尝试取消已经超时的任务,应该返回 false
834        let cancelled = service.cancel_task(task_id);
835        assert!(!cancelled, "Timed out task should not exist anymore");
836    }
837
838    #[tokio::test]
839    async fn test_cancel_task_spawns_background_task() {
840        let timer = TimerWheel::with_defaults();
841        let service = timer.create_service(ServiceConfig::default());
842        let counter = Arc::new(AtomicU32::new(0));
843
844        // 创建一个定时器
845        let counter_clone = Arc::clone(&counter);
846        let task = TimerService::create_task(
847            Duration::from_secs(10),
848            Some(CallbackWrapper::new(move || {
849                let counter = Arc::clone(&counter_clone);
850                async move {
851                    counter.fetch_add(1, Ordering::SeqCst);
852                }
853            })),
854        );
855        let task_id = task.get_id();
856        
857        service.register(task).unwrap();
858
859        // 使用 cancel_task(会等待结果,但在后台协程中处理)
860        let cancelled = service.cancel_task(task_id);
861        assert!(cancelled, "Task should be cancelled successfully");
862
863        // 等待足够长时间确保回调不会被执行
864        tokio::time::sleep(Duration::from_millis(100)).await;
865        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
866
867        // 验证任务已从 active_tasks 中移除
868        let cancelled_again = service.cancel_task(task_id);
869        assert!(!cancelled_again, "Task should have been removed from active_tasks");
870    }
871
872    #[tokio::test]
873    async fn test_schedule_once_direct() {
874        let timer = TimerWheel::with_defaults();
875        let mut service = timer.create_service(ServiceConfig::default());
876        let counter = Arc::new(AtomicU32::new(0));
877
878        // 直接通过 service 调度定时器
879        let counter_clone = Arc::clone(&counter);
880        let task = TimerService::create_task(
881            Duration::from_millis(50),
882            Some(CallbackWrapper::new(move || {
883                let counter = Arc::clone(&counter_clone);
884                async move {
885                    counter.fetch_add(1, Ordering::SeqCst);
886                }
887            })),
888        );
889        let task_id = task.get_id();
890        service.register(task).unwrap();
891
892        // 等待定时器触发
893        let mut rx = service.take_receiver().unwrap();
894        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
895            .await
896            .expect("Should receive timeout notification")
897            .expect("Should receive Some value");
898
899        assert_eq!(received_task_id, task_id);
900        
901        // 等待回调执行
902        tokio::time::sleep(Duration::from_millis(50)).await;
903        assert_eq!(counter.load(Ordering::SeqCst), 1);
904    }
905
906    #[tokio::test]
907    async fn test_schedule_once_batch_direct() {
908        let timer = TimerWheel::with_defaults();
909        let mut service = timer.create_service(ServiceConfig::default());
910        let counter = Arc::new(AtomicU32::new(0));
911
912        // 直接通过 service 批量调度定时器
913        let callbacks: Vec<_> = (0..3)
914            .map(|_| {
915                let counter = Arc::clone(&counter);
916                (Duration::from_millis(50), Some(CallbackWrapper::new(move || {
917                    let counter = Arc::clone(&counter);
918                    async move {
919                        counter.fetch_add(1, Ordering::SeqCst);
920                    }
921                    })))
922            })
923            .collect();
924
925        let tasks = TimerService::create_batch_with_callbacks(callbacks);
926        assert_eq!(tasks.len(), 3);
927        service.register_batch(tasks).unwrap();
928
929        // 接收所有超时通知
930        let mut received_count = 0;
931        let mut rx = service.take_receiver().unwrap();
932        
933        while received_count < 3 {
934            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
935                Ok(Some(_task_id)) => {
936                    received_count += 1;
937                }
938                Ok(None) => break,
939                Err(_) => break,
940            }
941        }
942
943        assert_eq!(received_count, 3);
944        
945        // 等待回调执行
946        tokio::time::sleep(Duration::from_millis(50)).await;
947        assert_eq!(counter.load(Ordering::SeqCst), 3);
948    }
949
950    #[tokio::test]
951    async fn test_schedule_once_notify_direct() {
952        let timer = TimerWheel::with_defaults();
953        let mut service = timer.create_service(ServiceConfig::default());
954
955        // 直接通过 service 调度仅通知的定时器(无回调)
956        let task = TimerService::create_task(Duration::from_millis(50), None);
957        let task_id = task.get_id();
958        service.register(task).unwrap();
959
960        // 接收超时通知
961        let mut rx = service.take_receiver().unwrap();
962        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
963            .await
964            .expect("Should receive timeout notification")
965            .expect("Should receive Some value");
966
967        assert_eq!(received_task_id, task_id);
968    }
969
970    #[tokio::test]
971    async fn test_schedule_and_cancel_direct() {
972        let timer = TimerWheel::with_defaults();
973        let service = timer.create_service(ServiceConfig::default());
974        let counter = Arc::new(AtomicU32::new(0));
975
976        // 直接调度定时器
977        let counter_clone = Arc::clone(&counter);
978        let task = TimerService::create_task(
979            Duration::from_secs(10),
980            Some(CallbackWrapper::new(move || {
981                let counter = Arc::clone(&counter_clone);
982                async move {
983                    counter.fetch_add(1, Ordering::SeqCst);
984                }
985            })),
986        );
987        let task_id = task.get_id();
988        service.register(task).unwrap();
989
990        // 立即取消
991        let cancelled = service.cancel_task(task_id);
992        assert!(cancelled, "Task should be cancelled successfully");
993
994        // 等待确保回调不会执行
995        tokio::time::sleep(Duration::from_millis(100)).await;
996        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
997    }
998
999    #[tokio::test]
1000    async fn test_cancel_batch_direct() {
1001        let timer = TimerWheel::with_defaults();
1002        let service = timer.create_service(ServiceConfig::default());
1003        let counter = Arc::new(AtomicU32::new(0));
1004
1005        // 批量调度定时器
1006        let callbacks: Vec<_> = (0..10)
1007            .map(|_| {
1008                let counter = Arc::clone(&counter);
1009                (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1010                    let counter = Arc::clone(&counter);
1011                    async move {
1012                        counter.fetch_add(1, Ordering::SeqCst);
1013                    }
1014                    })))
1015            })
1016            .collect();
1017
1018        let tasks = TimerService::create_batch_with_callbacks(callbacks);
1019        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1020        assert_eq!(task_ids.len(), 10);
1021        service.register_batch(tasks).unwrap();
1022
1023        // 批量取消所有任务
1024        let cancelled = service.cancel_batch(&task_ids);
1025        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1026
1027        // 等待确保回调不会执行
1028        tokio::time::sleep(Duration::from_millis(100)).await;
1029        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1030    }
1031
1032    #[tokio::test]
1033    async fn test_cancel_batch_partial() {
1034        let timer = TimerWheel::with_defaults();
1035        let service = timer.create_service(ServiceConfig::default());
1036        let counter = Arc::new(AtomicU32::new(0));
1037
1038        // 批量调度定时器
1039        let callbacks: Vec<_> = (0..10)
1040            .map(|_| {
1041                let counter = Arc::clone(&counter);
1042                (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1043                    let counter = Arc::clone(&counter);
1044                    async move {
1045                        counter.fetch_add(1, Ordering::SeqCst);
1046                    }
1047                    })))
1048            })
1049            .collect();
1050
1051        let tasks = TimerService::create_batch_with_callbacks(callbacks);
1052        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1053        service.register_batch(tasks).unwrap();
1054
1055        // 只取消前5个任务
1056        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1057        let cancelled = service.cancel_batch(&to_cancel);
1058        assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1059
1060        // 等待确保前5个回调不会执行
1061        tokio::time::sleep(Duration::from_millis(100)).await;
1062        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1063    }
1064
1065    #[tokio::test]
1066    async fn test_cancel_batch_empty() {
1067        let timer = TimerWheel::with_defaults();
1068        let service = timer.create_service(ServiceConfig::default());
1069
1070        // 取消空列表
1071        let empty: Vec<TaskId> = vec![];
1072        let cancelled = service.cancel_batch(&empty);
1073        assert_eq!(cancelled, 0, "No tasks should be cancelled");
1074    }
1075
1076    #[tokio::test]
1077    async fn test_postpone() {
1078        let timer = TimerWheel::with_defaults();
1079        let mut service = timer.create_service(ServiceConfig::default());
1080        let counter = Arc::new(AtomicU32::new(0));
1081
1082        // 注册一个任务,原始回调增加 1
1083        let counter_clone1 = Arc::clone(&counter);
1084        let task = TimerService::create_task(
1085            Duration::from_millis(50),
1086            Some(CallbackWrapper::new(move || {
1087                let counter = Arc::clone(&counter_clone1);
1088                async move {
1089                    counter.fetch_add(1, Ordering::SeqCst);
1090                }
1091                })),
1092        );
1093        let task_id = task.get_id();
1094        service.register(task).unwrap();
1095
1096        // 推迟任务并替换回调,新回调增加 10
1097        let counter_clone2 = Arc::clone(&counter);
1098        let postponed = service.postpone(
1099            task_id,
1100            Duration::from_millis(100),
1101            Some(CallbackWrapper::new(move || {
1102                let counter = Arc::clone(&counter_clone2);
1103                async move {
1104                    counter.fetch_add(10, Ordering::SeqCst);
1105                }
1106            }))
1107        );
1108        assert!(postponed, "Task should be postponed successfully");
1109
1110        // 接收超时通知(推迟后需要等待100ms,加上余量)
1111        let mut rx = service.take_receiver().unwrap();
1112        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1113            .await
1114            .expect("Should receive timeout notification")
1115            .expect("Should receive Some value");
1116
1117        assert_eq!(received_task_id, task_id);
1118        
1119        // 等待回调执行
1120        tokio::time::sleep(Duration::from_millis(20)).await;
1121        
1122        // 验证新回调被执行(增加了 10 而不是 1)
1123        assert_eq!(counter.load(Ordering::SeqCst), 10);
1124    }
1125
1126    #[tokio::test]
1127    async fn test_postpone_nonexistent_task() {
1128        let timer = TimerWheel::with_defaults();
1129        let service = timer.create_service(ServiceConfig::default());
1130
1131        // 尝试推迟一个不存在的任务
1132        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
1133        let fake_task_id = fake_task.get_id();
1134        // 不注册这个任务
1135        
1136        let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
1137        assert!(!postponed, "Nonexistent task should not be postponed");
1138    }
1139
1140    #[tokio::test]
1141    async fn test_postpone_batch() {
1142        let timer = TimerWheel::with_defaults();
1143        let mut service = timer.create_service(ServiceConfig::default());
1144        let counter = Arc::new(AtomicU32::new(0));
1145
1146        // 注册 3 个任务
1147        let mut task_ids = Vec::new();
1148        for _ in 0..3 {
1149            let counter_clone = Arc::clone(&counter);
1150            let task = TimerService::create_task(
1151                Duration::from_millis(50),
1152            Some(CallbackWrapper::new(move || {
1153                let counter = Arc::clone(&counter_clone);
1154                async move {
1155                    counter.fetch_add(1, Ordering::SeqCst);
1156                }
1157            })),
1158        );
1159            task_ids.push((task.get_id(), Duration::from_millis(150), None));
1160            service.register(task).unwrap();
1161        }
1162
1163        // 批量推迟
1164        let postponed = service.postpone_batch_with_callbacks(task_ids);
1165        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1166
1167        // 等待原定时间 50ms,任务不应该触发
1168        tokio::time::sleep(Duration::from_millis(70)).await;
1169        assert_eq!(counter.load(Ordering::SeqCst), 0);
1170
1171        // 接收所有超时通知
1172        let mut received_count = 0;
1173        let mut rx = service.take_receiver().unwrap();
1174        
1175        while received_count < 3 {
1176            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1177                Ok(Some(_task_id)) => {
1178                    received_count += 1;
1179                }
1180                Ok(None) => break,
1181                Err(_) => break,
1182            }
1183        }
1184
1185        assert_eq!(received_count, 3);
1186        
1187        // 等待回调执行
1188        tokio::time::sleep(Duration::from_millis(20)).await;
1189        assert_eq!(counter.load(Ordering::SeqCst), 3);
1190    }
1191
1192    #[tokio::test]
1193    async fn test_postpone_batch_with_callbacks() {
1194        let timer = TimerWheel::with_defaults();
1195        let mut service = timer.create_service(ServiceConfig::default());
1196        let counter = Arc::new(AtomicU32::new(0));
1197
1198        // 注册 3 个任务
1199        let mut task_ids = Vec::new();
1200        for _ in 0..3 {
1201            let task = TimerService::create_task(
1202                Duration::from_millis(50),
1203                None,
1204            );
1205            task_ids.push(task.get_id());
1206            service.register(task).unwrap();
1207        }
1208
1209        // 批量推迟并替换回调
1210        let updates: Vec<_> = task_ids
1211            .into_iter()
1212            .map(|id| {
1213                let counter_clone = Arc::clone(&counter);
1214                (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
1215                    let counter = Arc::clone(&counter_clone);
1216                    async move {
1217                        counter.fetch_add(1, Ordering::SeqCst);
1218                    }
1219                    })))
1220            })
1221            .collect();
1222
1223        let postponed = service.postpone_batch_with_callbacks(updates);
1224        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1225
1226        // 等待原定时间 50ms,任务不应该触发
1227        tokio::time::sleep(Duration::from_millis(70)).await;
1228        assert_eq!(counter.load(Ordering::SeqCst), 0);
1229
1230        // 接收所有超时通知
1231        let mut received_count = 0;
1232        let mut rx = service.take_receiver().unwrap();
1233        
1234        while received_count < 3 {
1235            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1236                Ok(Some(_task_id)) => {
1237                    received_count += 1;
1238                }
1239                Ok(None) => break,
1240                Err(_) => break,
1241            }
1242        }
1243
1244        assert_eq!(received_count, 3);
1245        
1246        // 等待回调执行
1247        tokio::time::sleep(Duration::from_millis(20)).await;
1248        assert_eq!(counter.load(Ordering::SeqCst), 3);
1249    }
1250
1251    #[tokio::test]
1252    async fn test_postpone_batch_empty() {
1253        let timer = TimerWheel::with_defaults();
1254        let service = timer.create_service(ServiceConfig::default());
1255
1256        // 推迟空列表
1257        let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
1258        let postponed = service.postpone_batch_with_callbacks(empty);
1259        assert_eq!(postponed, 0, "No tasks should be postponed");
1260    }
1261
1262    #[tokio::test]
1263    async fn test_postpone_keeps_timeout_notification_valid() {
1264        let timer = TimerWheel::with_defaults();
1265        let mut service = timer.create_service(ServiceConfig::default());
1266        let counter = Arc::new(AtomicU32::new(0));
1267
1268        // 注册一个任务
1269        let counter_clone = Arc::clone(&counter);
1270        let task = TimerService::create_task(
1271            Duration::from_millis(50),
1272            Some(CallbackWrapper::new(move || {
1273                let counter = Arc::clone(&counter_clone);
1274                async move {
1275                    counter.fetch_add(1, Ordering::SeqCst);
1276                }
1277            })),
1278        );
1279        let task_id = task.get_id();
1280        service.register(task).unwrap();
1281
1282        // 推迟任务
1283        service.postpone(task_id, Duration::from_millis(100), None);
1284
1285        // 验证超时通知仍然有效(推迟后需要等待100ms,加上余量)
1286        let mut rx = service.take_receiver().unwrap();
1287        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1288            .await
1289            .expect("Should receive timeout notification")
1290            .expect("Should receive Some value");
1291
1292        assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1293        
1294        // 等待回调执行
1295        tokio::time::sleep(Duration::from_millis(20)).await;
1296        assert_eq!(counter.load(Ordering::SeqCst), 1);
1297    }
1298
1299    #[tokio::test]
1300    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1301        let timer = TimerWheel::with_defaults();
1302        let mut service = timer.create_service(ServiceConfig::default());
1303
1304        // 注册两个任务:一个会被取消,一个会正常到期
1305        let task1 = TimerService::create_task(Duration::from_secs(10), None);
1306        let task1_id = task1.get_id();
1307        service.register(task1).unwrap();
1308
1309        let task2 = TimerService::create_task(Duration::from_millis(50), None);
1310        let task2_id = task2.get_id();
1311        service.register(task2).unwrap();
1312
1313        // 取消第一个任务
1314        let cancelled = service.cancel_task(task1_id);
1315        assert!(cancelled, "Task should be cancelled");
1316
1317        // 等待第二个任务到期
1318        let mut rx = service.take_receiver().unwrap();
1319        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1320            .await
1321            .expect("Should receive timeout notification")
1322            .expect("Should receive Some value");
1323
1324        // 应该只收到第二个任务(到期的)的通知,不应该收到第一个任务(取消的)的通知
1325        assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");
1326
1327        // 验证没有其他通知(特别是被取消的任务不应该有通知)
1328        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1329        assert!(no_more.is_err(), "Should not receive any more notifications");
1330    }
1331
1332    #[tokio::test]
1333    async fn test_take_receiver_twice() {
1334        let timer = TimerWheel::with_defaults();
1335        let mut service = timer.create_service(ServiceConfig::default());
1336
1337        // 第一次调用应该返回 Some
1338        let rx1 = service.take_receiver();
1339        assert!(rx1.is_some(), "First take_receiver should return Some");
1340
1341        // 第二次调用应该返回 None
1342        let rx2 = service.take_receiver();
1343        assert!(rx2.is_none(), "Second take_receiver should return None");
1344    }
1345
1346    #[tokio::test]
1347    async fn test_postpone_batch_without_callbacks() {
1348        let timer = TimerWheel::with_defaults();
1349        let mut service = timer.create_service(ServiceConfig::default());
1350        let counter = Arc::new(AtomicU32::new(0));
1351
1352        // 注册 3 个任务,有原始回调
1353        let mut task_ids = Vec::new();
1354        for _ in 0..3 {
1355            let counter_clone = Arc::clone(&counter);
1356            let task = TimerService::create_task(
1357                Duration::from_millis(50),
1358                Some(CallbackWrapper::new(move || {
1359                    let counter = Arc::clone(&counter_clone);
1360                    async move {
1361                        counter.fetch_add(1, Ordering::SeqCst);
1362                    }
1363                })),
1364            );
1365            task_ids.push(task.get_id());
1366            service.register(task).unwrap();
1367        }
1368
1369        // 批量推迟,不替换回调
1370        let updates: Vec<_> = task_ids
1371            .iter()
1372            .map(|&id| (id, Duration::from_millis(150)))
1373            .collect();
1374        let postponed = service.postpone_batch(updates);
1375        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1376
1377        // 等待原定时间 50ms,任务不应该触发
1378        tokio::time::sleep(Duration::from_millis(70)).await;
1379        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callbacks should not fire yet");
1380
1381        // 接收所有超时通知
1382        let mut received_count = 0;
1383        let mut rx = service.take_receiver().unwrap();
1384        
1385        while received_count < 3 {
1386            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1387                Ok(Some(_task_id)) => {
1388                    received_count += 1;
1389                }
1390                Ok(None) => break,
1391                Err(_) => break,
1392            }
1393        }
1394
1395        assert_eq!(received_count, 3, "Should receive 3 timeout notifications");
1396        
1397        // 等待回调执行
1398        tokio::time::sleep(Duration::from_millis(20)).await;
1399        assert_eq!(counter.load(Ordering::SeqCst), 3, "All callbacks should execute");
1400    }
1401}
1402