kestrel_protocol_timer/
service.rs

1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{TaskCompletionReason, TaskId, TimerCallback};
4use crate::timer::{BatchHandle, TimerHandle};
5use crate::wheel::Wheel;
6use futures::stream::{FuturesUnordered, StreamExt};
7use futures::future::BoxFuture;
8use parking_lot::Mutex;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13
14/// TimerService 命令类型
15enum ServiceCommand {
16    /// 添加批量定时器句柄
17    AddBatchHandle(BatchHandle),
18    /// 添加单个定时器句柄
19    AddTimerHandle(TimerHandle),
20    /// 关闭 Service
21    Shutdown,
22}
23
24/// TimerService - 基于 Actor 模式的定时器服务
25///
26/// 管理多个定时器句柄,监听所有超时事件,并将 TaskId 聚合转发给用户。
27///
28/// # 特性
29/// - 自动监听所有添加的定时器句柄的超时事件
30/// - 超时后自动从内部管理中移除该任务
31/// - 将超时的 TaskId 转发到统一的通道供用户接收
32/// - 支持动态添加 BatchHandle 和 TimerHandle
33///
34/// # 示例
35/// ```no_run
36/// use kestrel_protocol_timer::{TimerWheel, TimerService};
37/// use std::time::Duration;
38///
39/// #[tokio::main]
40/// async fn main() {
41///     let timer = TimerWheel::with_defaults();
42///     let mut service = timer.create_service();
43///     
44///     // 使用两步式 API 通过 service 批量调度定时器
45///     let callbacks: Vec<_> = (0..5)
46///         .map(|_| (Duration::from_millis(100), || async {}))
47///         .collect();
48///     let tasks = TimerService::create_batch(callbacks);
49///     service.register_batch(tasks).unwrap();
50///     
51///     // 接收超时通知
52///     let mut rx = service.take_receiver().unwrap();
53///     while let Some(task_id) = rx.recv().await {
54///         println!("Task {:?} completed", task_id);
55///     }
56/// }
57/// ```
58pub struct TimerService {
59    /// 命令发送端
60    command_tx: mpsc::Sender<ServiceCommand>,
61    /// 超时接收端
62    timeout_rx: Option<mpsc::Receiver<TaskId>>,
63    /// Actor 任务句柄
64    actor_handle: Option<JoinHandle<()>>,
65    /// 时间轮引用(用于直接调度定时器)
66    wheel: Arc<Mutex<Wheel>>,
67}
68
69impl TimerService {
70    /// 创建新的 TimerService
71    ///
72    /// # 参数
73    /// - `wheel`: 时间轮引用
74    /// - `config`: 服务配置
75    ///
76    /// # 注意
77    /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
78    ///
79    /// # 示例
80    /// ```no_run
81    /// use kestrel_protocol_timer::TimerWheel;
82    ///
83    /// #[tokio::main]
84    /// async fn main() {
85    ///     let timer = TimerWheel::with_defaults();
86    ///     let mut service = timer.create_service();
87    /// }
88    /// ```
89    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
90        let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
91        let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
92
93        let actor = ServiceActor::new(command_rx, timeout_tx);
94        let actor_handle = tokio::spawn(async move {
95            actor.run().await;
96        });
97
98        Self {
99            command_tx,
100            timeout_rx: Some(timeout_rx),
101            actor_handle: Some(actor_handle),
102            wheel,
103        }
104    }
105
106    /// 获取超时接收器(转移所有权)
107    ///
108    /// # 返回
109    /// 超时通知接收器,如果已经被取走则返回 None
110    ///
111    /// # 注意
112    /// 此方法只能调用一次,因为它会转移接收器的所有权
113    ///
114    /// # 示例
115    /// ```no_run
116    /// # use kestrel_protocol_timer::TimerWheel;
117    /// # use std::time::Duration;
118    /// # #[tokio::main]
119    /// # async fn main() {
120    /// let timer = TimerWheel::with_defaults();
121    /// let mut service = timer.create_service();
122    /// 
123    /// let mut rx = service.take_receiver().unwrap();
124    /// while let Some(task_id) = rx.recv().await {
125    ///     println!("Task {:?} timed out", task_id);
126    /// }
127    /// # }
128    /// ```
129    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
130        self.timeout_rx.take()
131    }
132
133    /// 取消指定的任务
134    ///
135    /// # 参数
136    /// - `task_id`: 要取消的任务 ID
137    ///
138    /// # 返回
139    /// - `true`: 任务存在且成功取消
140    /// - `false`: 任务不存在或取消失败
141    ///
142    /// # 性能说明
143    /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
144    ///
145    /// # 示例
146    /// ```no_run
147    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
148    /// # use std::time::Duration;
149    /// # #[tokio::main]
150    /// # async fn main() {
151    /// let timer = TimerWheel::with_defaults();
152    /// let service = timer.create_service();
153    /// 
154    /// // 使用两步式 API 调度定时器
155    /// let task = TimerService::create_task(Duration::from_secs(10), || async {});
156    /// let task_id = task.get_id();
157    /// service.register(task).unwrap();
158    /// 
159    /// // 取消任务
160    /// let cancelled = service.cancel_task(task_id);
161    /// println!("Task cancelled: {}", cancelled);
162    /// # }
163    /// ```
164    #[inline]
165    pub fn cancel_task(&self, task_id: TaskId) -> bool {
166        // 优化:直接取消任务,无需通知 Actor
167        // FuturesUnordered 会在任务被取消时自动清理
168        let mut wheel = self.wheel.lock();
169        wheel.cancel(task_id)
170    }
171
172    /// 批量取消任务
173    ///
174    /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
175    ///
176    /// # 参数
177    /// - `task_ids`: 要取消的任务 ID 列表
178    ///
179    /// # 返回
180    /// 成功取消的任务数量
181    ///
182    /// # 示例
183    /// ```no_run
184    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
185    /// # use std::time::Duration;
186    /// # #[tokio::main]
187    /// # async fn main() {
188    /// let timer = TimerWheel::with_defaults();
189    /// let service = timer.create_service();
190    /// 
191    /// let callbacks: Vec<_> = (0..10)
192    ///     .map(|_| (Duration::from_secs(10), || async {}))
193    ///     .collect();
194    /// let tasks = TimerService::create_batch(callbacks);
195    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
196    /// service.register_batch(tasks).unwrap();
197    /// 
198    /// // 批量取消
199    /// let cancelled = service.cancel_batch(&task_ids);
200    /// println!("成功取消 {} 个任务", cancelled);
201    /// # }
202    /// ```
203    #[inline]
204    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
205        if task_ids.is_empty() {
206            return 0;
207        }
208
209        // 优化:直接使用底层的批量取消,无需通知 Actor
210        // FuturesUnordered 会在任务被取消时自动清理
211        let mut wheel = self.wheel.lock();
212        wheel.cancel_batch(task_ids)
213    }
214
215    /// 推迟任务(保持原回调)
216    ///
217    /// # 参数
218    /// - `task_id`: 要推迟的任务 ID
219    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
220    ///
221    /// # 返回
222    /// - `true`: 任务存在且成功推迟
223    /// - `false`: 任务不存在或推迟失败
224    ///
225    /// # 性能说明
226    /// 此方法使用直接推迟优化,不需要等待 Actor 处理,大幅降低延迟
227    ///
228    /// # 注意
229    /// - 推迟后任务 ID 保持不变
230    /// - 原有的超时通知仍然有效
231    /// - 保持原回调函数不变
232    ///
233    /// # 示例
234    /// ```no_run
235    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
236    /// # use std::time::Duration;
237    /// # #[tokio::main]
238    /// # async fn main() {
239    /// let timer = TimerWheel::with_defaults();
240    /// let service = timer.create_service();
241    /// 
242    /// let task = TimerService::create_task(Duration::from_secs(5), || async {});
243    /// let task_id = task.get_id();
244    /// service.register(task).unwrap();
245    /// 
246    /// // 推迟到 10 秒后触发
247    /// let success = service.postpone_task(task_id, Duration::from_secs(10));
248    /// println!("推迟成功: {}", success);
249    /// # }
250    /// ```
251    #[inline]
252    pub fn postpone_task(&self, task_id: TaskId, new_delay: Duration) -> bool {
253        // 优化:直接推迟任务,无需通知 Actor
254        // FuturesUnordered 会继续监听原有的 completion_receiver
255        let mut wheel = self.wheel.lock();
256        wheel.postpone(task_id, new_delay, None)
257    }
258
259    /// 推迟任务(替换回调)
260    ///
261    /// # 参数
262    /// - `task_id`: 要推迟的任务 ID
263    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
264    /// - `callback`: 新的回调函数
265    ///
266    /// # 返回
267    /// - `true`: 任务存在且成功推迟
268    /// - `false`: 任务不存在或推迟失败
269    ///
270    /// # 注意
271    /// - 推迟后任务 ID 保持不变
272    /// - 原有的超时通知仍然有效
273    /// - 回调函数会被替换为新的回调
274    ///
275    /// # 示例
276    /// ```no_run
277    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
278    /// # use std::time::Duration;
279    /// # #[tokio::main]
280    /// # async fn main() {
281    /// let timer = TimerWheel::with_defaults();
282    /// let service = timer.create_service();
283    /// 
284    /// let task = TimerService::create_task(Duration::from_secs(5), || async {
285    ///     println!("Original callback");
286    /// });
287    /// let task_id = task.get_id();
288    /// service.register(task).unwrap();
289    /// 
290    /// // 推迟并替换回调
291    /// let success = service.postpone_task_with_callback(
292    ///     task_id,
293    ///     Duration::from_secs(10),
294    ///     || async { println!("New callback!"); }
295    /// );
296    /// println!("推迟成功: {}", success);
297    /// # }
298    /// ```
299    #[inline]
300    pub fn postpone_task_with_callback<C>(
301        &self,
302        task_id: TaskId,
303        new_delay: Duration,
304        callback: C,
305    ) -> bool
306    where
307        C: TimerCallback,
308    {
309        use std::sync::Arc;
310        let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
311        let mut wheel = self.wheel.lock();
312        wheel.postpone(task_id, new_delay, Some(callback_wrapper))
313    }
314
315    /// 批量推迟任务(保持原回调)
316    ///
317    /// # 参数
318    /// - `updates`: (任务ID, 新延迟) 的元组列表
319    ///
320    /// # 返回
321    /// 成功推迟的任务数量
322    ///
323    /// # 示例
324    /// ```no_run
325    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
326    /// # use std::time::Duration;
327    /// # #[tokio::main]
328    /// # async fn main() {
329    /// let timer = TimerWheel::with_defaults();
330    /// let service = timer.create_service();
331    /// 
332    /// let callbacks: Vec<_> = (0..3)
333    ///     .map(|_| (Duration::from_secs(5), || async {}))
334    ///     .collect();
335    /// let tasks = TimerService::create_batch(callbacks);
336    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
337    /// service.register_batch(tasks).unwrap();
338    /// 
339    /// // 批量推迟
340    /// let updates: Vec<_> = task_ids
341    ///     .into_iter()
342    ///     .map(|id| (id, Duration::from_secs(10)))
343    ///     .collect();
344    /// let postponed = service.postpone_batch(&updates);
345    /// println!("成功推迟 {} 个任务", postponed);
346    /// # }
347    /// ```
348    #[inline]
349    pub fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
350        if updates.is_empty() {
351            return 0;
352        }
353
354        let updates_vec: Vec<_> = updates
355            .iter()
356            .map(|(task_id, delay)| (*task_id, *delay, None))
357            .collect();
358        let mut wheel = self.wheel.lock();
359        wheel.postpone_batch(updates_vec)
360    }
361
362    /// 批量推迟任务(替换回调)
363    ///
364    /// # 参数
365    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
366    ///
367    /// # 返回
368    /// 成功推迟的任务数量
369    ///
370    /// # 示例
371    /// ```no_run
372    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
373    /// # use std::time::Duration;
374    /// # #[tokio::main]
375    /// # async fn main() {
376    /// let timer = TimerWheel::with_defaults();
377    /// let service = timer.create_service();
378    /// 
379    /// let callbacks: Vec<_> = (0..3)
380    ///     .map(|_| (Duration::from_secs(5), || async {}))
381    ///     .collect();
382    /// let tasks = TimerService::create_batch(callbacks);
383    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
384    /// service.register_batch(tasks).unwrap();
385    /// 
386    /// // 批量推迟并替换回调
387    /// let updates: Vec<_> = task_ids
388    ///     .into_iter()
389    ///     .enumerate()
390    ///     .map(|(i, id)| {
391    ///         (id, Duration::from_secs(10), move || async move {
392    ///             println!("New callback {}", i);
393    ///         })
394    ///     })
395    ///     .collect();
396    /// let postponed = service.postpone_batch_with_callbacks(updates);
397    /// println!("成功推迟 {} 个任务", postponed);
398    /// # }
399    /// ```
400    #[inline]
401    pub fn postpone_batch_with_callbacks<C>(
402        &self,
403        updates: Vec<(TaskId, Duration, C)>,
404    ) -> usize
405    where
406        C: TimerCallback,
407    {
408        if updates.is_empty() {
409            return 0;
410        }
411
412        use std::sync::Arc;
413        let updates_vec: Vec<_> = updates
414            .into_iter()
415            .map(|(task_id, delay, callback)| {
416                let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
417                (task_id, delay, Some(callback_wrapper))
418            })
419            .collect();
420        let mut wheel = self.wheel.lock();
421        wheel.postpone_batch(updates_vec)
422    }
423
424    /// 创建定时器任务(静态方法,申请阶段)
425    /// 
426    /// # 参数
427    /// - `delay`: 延迟时间
428    /// - `callback`: 实现了 TimerCallback trait 的回调对象
429    /// 
430    /// # 返回
431    /// 返回 TimerTask,需要通过 `register()` 注册
432    /// 
433    /// # 示例
434    /// ```no_run
435    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
436    /// # use std::time::Duration;
437    /// # #[tokio::main]
438    /// # async fn main() {
439    /// let timer = TimerWheel::with_defaults();
440    /// let service = timer.create_service();
441    /// 
442    /// // 步骤 1: 创建任务
443    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
444    ///     println!("Timer fired!");
445    /// });
446    /// 
447    /// let task_id = task.get_id();
448    /// println!("Created task: {:?}", task_id);
449    /// 
450    /// // 步骤 2: 注册任务
451    /// service.register(task).unwrap();
452    /// # }
453    /// ```
454    pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
455    where
456        C: TimerCallback,
457    {
458        crate::timer::TimerWheel::create_task(delay, callback)
459    }
460    
461    /// 批量创建定时器任务(静态方法,申请阶段)
462    /// 
463    /// # 参数
464    /// - `callbacks`: (延迟时间, 回调) 的元组列表
465    /// 
466    /// # 返回
467    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
468    /// 
469    /// # 示例
470    /// ```no_run
471    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
472    /// # use std::time::Duration;
473    /// # #[tokio::main]
474    /// # async fn main() {
475    /// let timer = TimerWheel::with_defaults();
476    /// let service = timer.create_service();
477    /// 
478    /// // 步骤 1: 批量创建任务
479    /// let callbacks: Vec<_> = (0..3)
480    ///     .map(|i| (Duration::from_millis(100 * (i + 1)), move || async move {
481    ///         println!("Timer {} fired!", i);
482    ///     }))
483    ///     .collect();
484    /// 
485    /// let tasks = TimerService::create_batch(callbacks);
486    /// println!("Created {} tasks", tasks.len());
487    /// 
488    /// // 步骤 2: 批量注册任务
489    /// service.register_batch(tasks).unwrap();
490    /// # }
491    /// ```
492    pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
493    where
494        C: TimerCallback,
495    {
496        crate::timer::TimerWheel::create_batch(callbacks)
497    }
498    
499    /// 注册定时器任务到服务(注册阶段)
500    /// 
501    /// # 参数
502    /// - `task`: 通过 `create_task()` 创建的任务
503    /// 
504    /// # 返回
505    /// - `Ok(())`: 注册成功
506    /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
507    /// 
508    /// # 示例
509    /// ```no_run
510    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
511    /// # use std::time::Duration;
512    /// # #[tokio::main]
513    /// # async fn main() {
514    /// let timer = TimerWheel::with_defaults();
515    /// let service = timer.create_service();
516    /// 
517    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
518    ///     println!("Timer fired!");
519    /// });
520    /// let task_id = task.get_id();
521    /// 
522    /// service.register(task).unwrap();
523    /// # }
524    /// ```
525    #[inline]
526    pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
527        let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
528        let notifier = crate::task::CompletionNotifier(completion_tx);
529        
530        let delay = task.delay;
531        let task_id = task.id;
532        
533        // 单次加锁完成所有操作
534        {
535            let mut wheel_guard = self.wheel.lock();
536            wheel_guard.insert(delay, task, notifier);
537        }
538        
539        // 创建句柄并添加到服务管理
540        let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
541        self.command_tx
542            .try_send(ServiceCommand::AddTimerHandle(handle))
543            .map_err(|_| TimerError::RegisterFailed)?;
544        
545        Ok(())
546    }
547    
548    /// 批量注册定时器任务到服务(注册阶段)
549    /// 
550    /// # 参数
551    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
552    /// 
553    /// # 返回
554    /// - `Ok(())`: 注册成功
555    /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
556    /// 
557    /// # 示例
558    /// ```no_run
559    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
560    /// # use std::time::Duration;
561    /// # #[tokio::main]
562    /// # async fn main() {
563    /// let timer = TimerWheel::with_defaults();
564    /// let service = timer.create_service();
565    /// 
566    /// let callbacks: Vec<_> = (0..3)
567    ///     .map(|_| (Duration::from_secs(1), || async {}))
568    ///     .collect();
569    /// let tasks = TimerService::create_batch(callbacks);
570    /// 
571    /// service.register_batch(tasks).unwrap();
572    /// # }
573    /// ```
574    #[inline]
575    pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
576        let task_count = tasks.len();
577        let mut completion_rxs = Vec::with_capacity(task_count);
578        let mut task_ids = Vec::with_capacity(task_count);
579        let mut prepared_tasks = Vec::with_capacity(task_count);
580        
581        // 步骤1: 准备所有 channels 和 notifiers(无锁)
582        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
583        for task in tasks {
584            let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
585            let notifier = crate::task::CompletionNotifier(completion_tx);
586            
587            task_ids.push(task.id);
588            completion_rxs.push(completion_rx);
589            prepared_tasks.push((task.delay, task, notifier));
590        }
591        
592        // 步骤2: 单次加锁,批量插入
593        {
594            let mut wheel_guard = self.wheel.lock();
595            wheel_guard.insert_batch(prepared_tasks);
596        }
597        
598        // 创建批量句柄并添加到服务管理
599        let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
600        self.command_tx
601            .try_send(ServiceCommand::AddBatchHandle(batch_handle))
602            .map_err(|_| TimerError::RegisterFailed)?;
603        
604        Ok(())
605    }
606
607    /// 优雅关闭 TimerService
608    ///
609    /// # 示例
610    /// ```no_run
611    /// # use kestrel_protocol_timer::TimerWheel;
612    /// # #[tokio::main]
613    /// # async fn main() {
614    /// let timer = TimerWheel::with_defaults();
615    /// let mut service = timer.create_service();
616    /// 
617    /// // 使用 service...
618    /// 
619    /// service.shutdown().await;
620    /// # }
621    /// ```
622    pub async fn shutdown(mut self) {
623        let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
624        if let Some(handle) = self.actor_handle.take() {
625            let _ = handle.await;
626        }
627    }
628}
629
630
631impl Drop for TimerService {
632    fn drop(&mut self) {
633        if let Some(handle) = self.actor_handle.take() {
634            handle.abort();
635        }
636    }
637}
638
639/// ServiceActor - 内部 Actor 实现
640struct ServiceActor {
641    /// 命令接收端
642    command_rx: mpsc::Receiver<ServiceCommand>,
643    /// 超时发送端
644    timeout_tx: mpsc::Sender<TaskId>,
645}
646
647impl ServiceActor {
648    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
649        Self {
650            command_rx,
651            timeout_tx,
652        }
653    }
654
655    async fn run(mut self) {
656        // 使用 FuturesUnordered 来监听所有的 completion_rxs
657        // 每个 future 返回 (TaskId, Result)
658        let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
659
660        loop {
661            tokio::select! {
662                // 监听超时事件
663                Some((task_id, result)) = futures.next() => {
664                    // 检查完成原因,只转发超时(Expired)事件,不转发取消(Cancelled)事件
665                    if let Ok(TaskCompletionReason::Expired) = result {
666                        let _ = self.timeout_tx.send(task_id).await;
667                    }
668                    // 任务会自动从 FuturesUnordered 中移除
669                }
670                
671                // 监听命令
672                Some(cmd) = self.command_rx.recv() => {
673                    match cmd {
674                        ServiceCommand::AddBatchHandle(batch) => {
675                            let BatchHandle {
676                                task_ids,
677                                completion_rxs,
678                                ..
679                            } = batch;
680                            
681                            // 将所有任务添加到 futures 中
682                            for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
683                                let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
684                                    (task_id, rx.await)
685                                });
686                                futures.push(future);
687                            }
688                        }
689                        ServiceCommand::AddTimerHandle(handle) => {
690                            let TimerHandle{
691                                task_id,
692                                completion_rx,
693                                ..
694                            } = handle;
695                            
696                            // 添加到 futures 中
697                            let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
698                                (task_id, completion_rx.0.await)
699                            });
700                            futures.push(future);
701                        }
702                        ServiceCommand::Shutdown => {
703                            break;
704                        }
705                    }
706                }
707                
708                // 如果没有任何 future 且命令通道已关闭,退出循环
709                else => {
710                    break;
711                }
712            }
713        }
714    }
715}
716
717#[cfg(test)]
718mod tests {
719    use super::*;
720    use crate::TimerWheel;
721    use std::sync::atomic::{AtomicU32, Ordering};
722    use std::sync::Arc;
723    use std::time::Duration;
724
725    #[tokio::test]
726    async fn test_service_creation() {
727        let timer = TimerWheel::with_defaults();
728        let _service = timer.create_service();
729    }
730
731
732    #[tokio::test]
733    async fn test_add_timer_handle_and_receive_timeout() {
734        let timer = TimerWheel::with_defaults();
735        let mut service = timer.create_service();
736
737        // 创建单个定时器
738        let task = TimerService::create_task(Duration::from_millis(50), || async {});
739        let task_id = task.get_id();
740        
741        // 注册到 service
742        service.register(task).unwrap();
743
744        // 接收超时通知
745        let mut rx = service.take_receiver().unwrap();
746        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
747            .await
748            .expect("Should receive timeout notification")
749            .expect("Should receive Some value");
750
751        assert_eq!(received_task_id, task_id);
752    }
753
754
755    #[tokio::test]
756    async fn test_shutdown() {
757        let timer = TimerWheel::with_defaults();
758        let service = timer.create_service();
759
760        // 添加一些定时器
761        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
762        let task2 = TimerService::create_task(Duration::from_secs(10), || async {});
763        service.register(task1).unwrap();
764        service.register(task2).unwrap();
765
766        // 立即关闭(不等待定时器触发)
767        service.shutdown().await;
768    }
769
770
771
772    #[tokio::test]
773    async fn test_cancel_task() {
774        let timer = TimerWheel::with_defaults();
775        let service = timer.create_service();
776
777        // 添加一个长时间的定时器
778        let task = TimerService::create_task(Duration::from_secs(10), || async {});
779        let task_id = task.get_id();
780        
781        service.register(task).unwrap();
782
783        // 取消任务
784        let cancelled = service.cancel_task(task_id);
785        assert!(cancelled, "Task should be cancelled successfully");
786
787        // 尝试再次取消同一个任务,应该返回 false
788        let cancelled_again = service.cancel_task(task_id);
789        assert!(!cancelled_again, "Task should not exist anymore");
790    }
791
792    #[tokio::test]
793    async fn test_cancel_nonexistent_task() {
794        let timer = TimerWheel::with_defaults();
795        let service = timer.create_service();
796
797        // 添加一个定时器以初始化 service
798        let task = TimerService::create_task(Duration::from_millis(50), || async {});
799        service.register(task).unwrap();
800
801        // 尝试取消一个不存在的任务(创建一个不会实际注册的任务ID)
802        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
803        let fake_task_id = fake_task.get_id();
804        // 不注册 fake_task
805        let cancelled = service.cancel_task(fake_task_id);
806        assert!(!cancelled, "Nonexistent task should not be cancelled");
807    }
808
809
810    #[tokio::test]
811    async fn test_task_timeout_cleans_up_task_sender() {
812        let timer = TimerWheel::with_defaults();
813        let mut service = timer.create_service();
814
815        // 添加一个短时间的定时器
816        let task = TimerService::create_task(Duration::from_millis(50), || async {});
817        let task_id = task.get_id();
818        
819        service.register(task).unwrap();
820
821        // 等待任务超时
822        let mut rx = service.take_receiver().unwrap();
823        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
824            .await
825            .expect("Should receive timeout notification")
826            .expect("Should receive Some value");
827        
828        assert_eq!(received_task_id, task_id);
829
830        // 等待一下确保内部清理完成
831        tokio::time::sleep(Duration::from_millis(10)).await;
832
833        // 尝试取消已经超时的任务,应该返回 false
834        let cancelled = service.cancel_task(task_id);
835        assert!(!cancelled, "Timed out task should not exist anymore");
836    }
837
838    #[tokio::test]
839    async fn test_cancel_task_spawns_background_task() {
840        let timer = TimerWheel::with_defaults();
841        let service = timer.create_service();
842        let counter = Arc::new(AtomicU32::new(0));
843
844        // 创建一个定时器
845        let counter_clone = Arc::clone(&counter);
846        let task = TimerService::create_task(
847            Duration::from_secs(10),
848            move || {
849                let counter = Arc::clone(&counter_clone);
850                async move {
851                    counter.fetch_add(1, Ordering::SeqCst);
852                }
853            },
854        );
855        let task_id = task.get_id();
856        
857        service.register(task).unwrap();
858
859        // 使用 cancel_task(会等待结果,但在后台协程中处理)
860        let cancelled = service.cancel_task(task_id);
861        assert!(cancelled, "Task should be cancelled successfully");
862
863        // 等待足够长时间确保回调不会被执行
864        tokio::time::sleep(Duration::from_millis(100)).await;
865        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
866
867        // 验证任务已从 active_tasks 中移除
868        let cancelled_again = service.cancel_task(task_id);
869        assert!(!cancelled_again, "Task should have been removed from active_tasks");
870    }
871
872    #[tokio::test]
873    async fn test_schedule_once_direct() {
874        let timer = TimerWheel::with_defaults();
875        let mut service = timer.create_service();
876        let counter = Arc::new(AtomicU32::new(0));
877
878        // 直接通过 service 调度定时器
879        let counter_clone = Arc::clone(&counter);
880        let task = TimerService::create_task(
881            Duration::from_millis(50),
882            move || {
883                let counter = Arc::clone(&counter_clone);
884                async move {
885                    counter.fetch_add(1, Ordering::SeqCst);
886                }
887            },
888        );
889        let task_id = task.get_id();
890        service.register(task).unwrap();
891
892        // 等待定时器触发
893        let mut rx = service.take_receiver().unwrap();
894        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
895            .await
896            .expect("Should receive timeout notification")
897            .expect("Should receive Some value");
898
899        assert_eq!(received_task_id, task_id);
900        
901        // 等待回调执行
902        tokio::time::sleep(Duration::from_millis(50)).await;
903        assert_eq!(counter.load(Ordering::SeqCst), 1);
904    }
905
906    #[tokio::test]
907    async fn test_schedule_once_batch_direct() {
908        let timer = TimerWheel::with_defaults();
909        let mut service = timer.create_service();
910        let counter = Arc::new(AtomicU32::new(0));
911
912        // 直接通过 service 批量调度定时器
913        let callbacks: Vec<_> = (0..3)
914            .map(|_| {
915                let counter = Arc::clone(&counter);
916                (Duration::from_millis(50), move || {
917                    let counter = Arc::clone(&counter);
918                    async move {
919                        counter.fetch_add(1, Ordering::SeqCst);
920                    }
921                })
922            })
923            .collect();
924
925        let tasks = TimerService::create_batch(callbacks);
926        assert_eq!(tasks.len(), 3);
927        service.register_batch(tasks).unwrap();
928
929        // 接收所有超时通知
930        let mut received_count = 0;
931        let mut rx = service.take_receiver().unwrap();
932        
933        while received_count < 3 {
934            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
935                Ok(Some(_task_id)) => {
936                    received_count += 1;
937                }
938                Ok(None) => break,
939                Err(_) => break,
940            }
941        }
942
943        assert_eq!(received_count, 3);
944        
945        // 等待回调执行
946        tokio::time::sleep(Duration::from_millis(50)).await;
947        assert_eq!(counter.load(Ordering::SeqCst), 3);
948    }
949
950    #[tokio::test]
951    async fn test_schedule_once_notify_direct() {
952        let timer = TimerWheel::with_defaults();
953        let mut service = timer.create_service();
954
955        // 直接通过 service 调度仅通知的定时器(无回调)
956        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
957        let task_id = task.get_id();
958        service.register(task).unwrap();
959
960        // 接收超时通知
961        let mut rx = service.take_receiver().unwrap();
962        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
963            .await
964            .expect("Should receive timeout notification")
965            .expect("Should receive Some value");
966
967        assert_eq!(received_task_id, task_id);
968    }
969
970    #[tokio::test]
971    async fn test_schedule_and_cancel_direct() {
972        let timer = TimerWheel::with_defaults();
973        let service = timer.create_service();
974        let counter = Arc::new(AtomicU32::new(0));
975
976        // 直接调度定时器
977        let counter_clone = Arc::clone(&counter);
978        let task = TimerService::create_task(
979            Duration::from_secs(10),
980            move || {
981                let counter = Arc::clone(&counter_clone);
982                async move {
983                    counter.fetch_add(1, Ordering::SeqCst);
984                }
985            },
986        );
987        let task_id = task.get_id();
988        service.register(task).unwrap();
989
990        // 立即取消
991        let cancelled = service.cancel_task(task_id);
992        assert!(cancelled, "Task should be cancelled successfully");
993
994        // 等待确保回调不会执行
995        tokio::time::sleep(Duration::from_millis(100)).await;
996        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
997    }
998
999    #[tokio::test]
1000    async fn test_cancel_batch_direct() {
1001        let timer = TimerWheel::with_defaults();
1002        let service = timer.create_service();
1003        let counter = Arc::new(AtomicU32::new(0));
1004
1005        // 批量调度定时器
1006        let callbacks: Vec<_> = (0..10)
1007            .map(|_| {
1008                let counter = Arc::clone(&counter);
1009                (Duration::from_secs(10), move || {
1010                    let counter = Arc::clone(&counter);
1011                    async move {
1012                        counter.fetch_add(1, Ordering::SeqCst);
1013                    }
1014                })
1015            })
1016            .collect();
1017
1018        let tasks = TimerService::create_batch(callbacks);
1019        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1020        assert_eq!(task_ids.len(), 10);
1021        service.register_batch(tasks).unwrap();
1022
1023        // 批量取消所有任务
1024        let cancelled = service.cancel_batch(&task_ids);
1025        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1026
1027        // 等待确保回调不会执行
1028        tokio::time::sleep(Duration::from_millis(100)).await;
1029        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1030    }
1031
1032    #[tokio::test]
1033    async fn test_cancel_batch_partial() {
1034        let timer = TimerWheel::with_defaults();
1035        let service = timer.create_service();
1036        let counter = Arc::new(AtomicU32::new(0));
1037
1038        // 批量调度定时器
1039        let callbacks: Vec<_> = (0..10)
1040            .map(|_| {
1041                let counter = Arc::clone(&counter);
1042                (Duration::from_secs(10), move || {
1043                    let counter = Arc::clone(&counter);
1044                    async move {
1045                        counter.fetch_add(1, Ordering::SeqCst);
1046                    }
1047                })
1048            })
1049            .collect();
1050
1051        let tasks = TimerService::create_batch(callbacks);
1052        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1053        service.register_batch(tasks).unwrap();
1054
1055        // 只取消前5个任务
1056        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1057        let cancelled = service.cancel_batch(&to_cancel);
1058        assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1059
1060        // 等待确保前5个回调不会执行
1061        tokio::time::sleep(Duration::from_millis(100)).await;
1062        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1063    }
1064
1065    #[tokio::test]
1066    async fn test_cancel_batch_empty() {
1067        let timer = TimerWheel::with_defaults();
1068        let service = timer.create_service();
1069
1070        // 取消空列表
1071        let empty: Vec<TaskId> = vec![];
1072        let cancelled = service.cancel_batch(&empty);
1073        assert_eq!(cancelled, 0, "No tasks should be cancelled");
1074    }
1075
1076    #[tokio::test]
1077    async fn test_postpone_task() {
1078        let timer = TimerWheel::with_defaults();
1079        let mut service = timer.create_service();
1080        let counter = Arc::new(AtomicU32::new(0));
1081
1082        // 注册一个任务,延迟 50ms
1083        let counter_clone = Arc::clone(&counter);
1084        let task = TimerService::create_task(
1085            Duration::from_millis(50),
1086            move || {
1087                let counter = Arc::clone(&counter_clone);
1088                async move {
1089                    counter.fetch_add(1, Ordering::SeqCst);
1090                }
1091            },
1092        );
1093        let task_id = task.get_id();
1094        service.register(task).unwrap();
1095
1096        // 推迟任务到 150ms
1097        let postponed = service.postpone_task(task_id, Duration::from_millis(150));
1098        assert!(postponed, "Task should be postponed successfully");
1099
1100        // 等待原定时间 50ms,任务不应该触发
1101        tokio::time::sleep(Duration::from_millis(70)).await;
1102        assert_eq!(counter.load(Ordering::SeqCst), 0);
1103
1104        // 接收超时通知(从推迟开始算,还需要等待约 150ms)
1105        let mut rx = service.take_receiver().unwrap();
1106        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1107            .await
1108            .expect("Should receive timeout notification")
1109            .expect("Should receive Some value");
1110
1111        assert_eq!(received_task_id, task_id);
1112        
1113        // 等待回调执行
1114        tokio::time::sleep(Duration::from_millis(20)).await;
1115        assert_eq!(counter.load(Ordering::SeqCst), 1);
1116    }
1117
1118    #[tokio::test]
1119    async fn test_postpone_task_with_callback() {
1120        let timer = TimerWheel::with_defaults();
1121        let mut service = timer.create_service();
1122        let counter = Arc::new(AtomicU32::new(0));
1123
1124        // 注册一个任务,原始回调增加 1
1125        let counter_clone1 = Arc::clone(&counter);
1126        let task = TimerService::create_task(
1127            Duration::from_millis(50),
1128            move || {
1129                let counter = Arc::clone(&counter_clone1);
1130                async move {
1131                    counter.fetch_add(1, Ordering::SeqCst);
1132                }
1133            },
1134        );
1135        let task_id = task.get_id();
1136        service.register(task).unwrap();
1137
1138        // 推迟任务并替换回调,新回调增加 10
1139        let counter_clone2 = Arc::clone(&counter);
1140        let postponed = service.postpone_task_with_callback(
1141            task_id,
1142            Duration::from_millis(100),
1143            move || {
1144                let counter = Arc::clone(&counter_clone2);
1145                async move {
1146                    counter.fetch_add(10, Ordering::SeqCst);
1147                }
1148            }
1149        );
1150        assert!(postponed, "Task should be postponed successfully");
1151
1152        // 接收超时通知(推迟后需要等待100ms,加上余量)
1153        let mut rx = service.take_receiver().unwrap();
1154        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1155            .await
1156            .expect("Should receive timeout notification")
1157            .expect("Should receive Some value");
1158
1159        assert_eq!(received_task_id, task_id);
1160        
1161        // 等待回调执行
1162        tokio::time::sleep(Duration::from_millis(20)).await;
1163        
1164        // 验证新回调被执行(增加了 10 而不是 1)
1165        assert_eq!(counter.load(Ordering::SeqCst), 10);
1166    }
1167
1168    #[tokio::test]
1169    async fn test_postpone_nonexistent_task() {
1170        let timer = TimerWheel::with_defaults();
1171        let service = timer.create_service();
1172
1173        // 尝试推迟一个不存在的任务
1174        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
1175        let fake_task_id = fake_task.get_id();
1176        // 不注册这个任务
1177        
1178        let postponed = service.postpone_task(fake_task_id, Duration::from_millis(100));
1179        assert!(!postponed, "Nonexistent task should not be postponed");
1180    }
1181
1182    #[tokio::test]
1183    async fn test_postpone_batch() {
1184        let timer = TimerWheel::with_defaults();
1185        let mut service = timer.create_service();
1186        let counter = Arc::new(AtomicU32::new(0));
1187
1188        // 注册 3 个任务
1189        let mut task_ids = Vec::new();
1190        for _ in 0..3 {
1191            let counter_clone = Arc::clone(&counter);
1192            let task = TimerService::create_task(
1193                Duration::from_millis(50),
1194                move || {
1195                    let counter = Arc::clone(&counter_clone);
1196                    async move {
1197                        counter.fetch_add(1, Ordering::SeqCst);
1198                    }
1199                },
1200            );
1201            task_ids.push((task.get_id(), Duration::from_millis(150)));
1202            service.register(task).unwrap();
1203        }
1204
1205        // 批量推迟
1206        let postponed = service.postpone_batch(&task_ids);
1207        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1208
1209        // 等待原定时间 50ms,任务不应该触发
1210        tokio::time::sleep(Duration::from_millis(70)).await;
1211        assert_eq!(counter.load(Ordering::SeqCst), 0);
1212
1213        // 接收所有超时通知
1214        let mut received_count = 0;
1215        let mut rx = service.take_receiver().unwrap();
1216        
1217        while received_count < 3 {
1218            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1219                Ok(Some(_task_id)) => {
1220                    received_count += 1;
1221                }
1222                Ok(None) => break,
1223                Err(_) => break,
1224            }
1225        }
1226
1227        assert_eq!(received_count, 3);
1228        
1229        // 等待回调执行
1230        tokio::time::sleep(Duration::from_millis(20)).await;
1231        assert_eq!(counter.load(Ordering::SeqCst), 3);
1232    }
1233
1234    #[tokio::test]
1235    async fn test_postpone_batch_with_callbacks() {
1236        let timer = TimerWheel::with_defaults();
1237        let mut service = timer.create_service();
1238        let counter = Arc::new(AtomicU32::new(0));
1239
1240        // 注册 3 个任务
1241        let mut task_ids = Vec::new();
1242        for _ in 0..3 {
1243            let task = TimerService::create_task(
1244                Duration::from_millis(50),
1245                || async {},
1246            );
1247            task_ids.push(task.get_id());
1248            service.register(task).unwrap();
1249        }
1250
1251        // 批量推迟并替换回调
1252        let updates: Vec<_> = task_ids
1253            .into_iter()
1254            .map(|id| {
1255                let counter_clone = Arc::clone(&counter);
1256                (id, Duration::from_millis(150), move || {
1257                    let counter = Arc::clone(&counter_clone);
1258                    async move {
1259                        counter.fetch_add(1, Ordering::SeqCst);
1260                    }
1261                })
1262            })
1263            .collect();
1264
1265        let postponed = service.postpone_batch_with_callbacks(updates);
1266        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1267
1268        // 等待原定时间 50ms,任务不应该触发
1269        tokio::time::sleep(Duration::from_millis(70)).await;
1270        assert_eq!(counter.load(Ordering::SeqCst), 0);
1271
1272        // 接收所有超时通知
1273        let mut received_count = 0;
1274        let mut rx = service.take_receiver().unwrap();
1275        
1276        while received_count < 3 {
1277            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1278                Ok(Some(_task_id)) => {
1279                    received_count += 1;
1280                }
1281                Ok(None) => break,
1282                Err(_) => break,
1283            }
1284        }
1285
1286        assert_eq!(received_count, 3);
1287        
1288        // 等待回调执行
1289        tokio::time::sleep(Duration::from_millis(20)).await;
1290        assert_eq!(counter.load(Ordering::SeqCst), 3);
1291    }
1292
1293    #[tokio::test]
1294    async fn test_postpone_batch_empty() {
1295        let timer = TimerWheel::with_defaults();
1296        let service = timer.create_service();
1297
1298        // 推迟空列表
1299        let empty: Vec<(TaskId, Duration)> = vec![];
1300        let postponed = service.postpone_batch(&empty);
1301        assert_eq!(postponed, 0, "No tasks should be postponed");
1302    }
1303
1304    #[tokio::test]
1305    async fn test_postpone_keeps_timeout_notification_valid() {
1306        let timer = TimerWheel::with_defaults();
1307        let mut service = timer.create_service();
1308        let counter = Arc::new(AtomicU32::new(0));
1309
1310        // 注册一个任务
1311        let counter_clone = Arc::clone(&counter);
1312        let task = TimerService::create_task(
1313            Duration::from_millis(50),
1314            move || {
1315                let counter = Arc::clone(&counter_clone);
1316                async move {
1317                    counter.fetch_add(1, Ordering::SeqCst);
1318                }
1319            },
1320        );
1321        let task_id = task.get_id();
1322        service.register(task).unwrap();
1323
1324        // 推迟任务
1325        service.postpone_task(task_id, Duration::from_millis(100));
1326
1327        // 验证超时通知仍然有效(推迟后需要等待100ms,加上余量)
1328        let mut rx = service.take_receiver().unwrap();
1329        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1330            .await
1331            .expect("Should receive timeout notification")
1332            .expect("Should receive Some value");
1333
1334        assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1335        
1336        // 等待回调执行
1337        tokio::time::sleep(Duration::from_millis(20)).await;
1338        assert_eq!(counter.load(Ordering::SeqCst), 1);
1339    }
1340
1341    #[tokio::test]
1342    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1343        let timer = TimerWheel::with_defaults();
1344        let mut service = timer.create_service();
1345
1346        // 注册两个任务:一个会被取消,一个会正常到期
1347        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
1348        let task1_id = task1.get_id();
1349        service.register(task1).unwrap();
1350
1351        let task2 = TimerService::create_task(Duration::from_millis(50), || async {});
1352        let task2_id = task2.get_id();
1353        service.register(task2).unwrap();
1354
1355        // 取消第一个任务
1356        let cancelled = service.cancel_task(task1_id);
1357        assert!(cancelled, "Task should be cancelled");
1358
1359        // 等待第二个任务到期
1360        let mut rx = service.take_receiver().unwrap();
1361        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1362            .await
1363            .expect("Should receive timeout notification")
1364            .expect("Should receive Some value");
1365
1366        // 应该只收到第二个任务(到期的)的通知,不应该收到第一个任务(取消的)的通知
1367        assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");
1368
1369        // 验证没有其他通知(特别是被取消的任务不应该有通知)
1370        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1371        assert!(no_more.is_err(), "Should not receive any more notifications");
1372    }
1373}
1374