kestrel_protocol_timer/
service.rs

1use crate::config::ServiceConfig;
2use crate::task::{TaskCompletionReason, TaskId, TimerCallback};
3use crate::timer::{BatchHandle, TimerHandle};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13/// TimerService 命令类型
14enum ServiceCommand {
15    /// 添加批量定时器句柄
16    AddBatchHandle(BatchHandle),
17    /// 添加单个定时器句柄
18    AddTimerHandle(TimerHandle),
19    /// 关闭 Service
20    Shutdown,
21}
22
23/// TimerService - 基于 Actor 模式的定时器服务
24///
25/// 管理多个定时器句柄,监听所有超时事件,并将 TaskId 聚合转发给用户。
26///
27/// # 特性
28/// - 自动监听所有添加的定时器句柄的超时事件
29/// - 超时后自动从内部管理中移除该任务
30/// - 将超时的 TaskId 转发到统一的通道供用户接收
31/// - 支持动态添加 BatchHandle 和 TimerHandle
32///
33/// # 示例
34/// ```no_run
35/// use kestrel_protocol_timer::{TimerWheel, TimerService};
36/// use std::time::Duration;
37///
38/// #[tokio::main]
39/// async fn main() {
40///     let timer = TimerWheel::with_defaults();
41///     let mut service = timer.create_service();
42///     
43///     // 使用两步式 API 通过 service 批量调度定时器
44///     let callbacks: Vec<_> = (0..5)
45///         .map(|_| (Duration::from_millis(100), || async {}))
46///         .collect();
47///     let tasks = TimerService::create_batch(callbacks);
48///     service.register_batch(tasks).await;
49///     
50///     // 接收超时通知
51///     let mut rx = service.take_receiver().unwrap();
52///     while let Some(task_id) = rx.recv().await {
53///         println!("Task {:?} completed", task_id);
54///     }
55/// }
56/// ```
57pub struct TimerService {
58    /// 命令发送端
59    command_tx: mpsc::Sender<ServiceCommand>,
60    /// 超时接收端
61    timeout_rx: Option<mpsc::Receiver<TaskId>>,
62    /// Actor 任务句柄
63    actor_handle: Option<JoinHandle<()>>,
64    /// 时间轮引用(用于直接调度定时器)
65    wheel: Arc<Mutex<Wheel>>,
66}
67
68impl TimerService {
69    /// 创建新的 TimerService
70    ///
71    /// # 参数
72    /// - `wheel`: 时间轮引用
73    /// - `config`: 服务配置
74    ///
75    /// # 注意
76    /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
77    ///
78    /// # 示例
79    /// ```no_run
80    /// use kestrel_protocol_timer::TimerWheel;
81    ///
82    /// #[tokio::main]
83    /// async fn main() {
84    ///     let timer = TimerWheel::with_defaults();
85    ///     let mut service = timer.create_service();
86    /// }
87    /// ```
88    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
89        let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
90        let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
91
92        let actor = ServiceActor::new(command_rx, timeout_tx);
93        let actor_handle = tokio::spawn(async move {
94            actor.run().await;
95        });
96
97        Self {
98            command_tx,
99            timeout_rx: Some(timeout_rx),
100            actor_handle: Some(actor_handle),
101            wheel,
102        }
103    }
104
105    /// 添加批量定时器句柄(内部方法)
106    async fn add_batch_handle(&self, batch: BatchHandle) {
107        let _ = self.command_tx
108            .send(ServiceCommand::AddBatchHandle(batch))
109            .await;
110    }
111
112    /// 添加单个定时器句柄(内部方法)
113    async fn add_timer_handle(&self, handle: TimerHandle) {
114        let _ = self.command_tx
115            .send(ServiceCommand::AddTimerHandle(handle))
116            .await;
117    }
118
119    /// 获取超时接收器(转移所有权)
120    ///
121    /// # 返回
122    /// 超时通知接收器,如果已经被取走则返回 None
123    ///
124    /// # 注意
125    /// 此方法只能调用一次,因为它会转移接收器的所有权
126    ///
127    /// # 示例
128    /// ```no_run
129    /// # use kestrel_protocol_timer::TimerWheel;
130    /// # use std::time::Duration;
131    /// # #[tokio::main]
132    /// # async fn main() {
133    /// let timer = TimerWheel::with_defaults();
134    /// let mut service = timer.create_service();
135    /// 
136    /// let mut rx = service.take_receiver().unwrap();
137    /// while let Some(task_id) = rx.recv().await {
138    ///     println!("Task {:?} timed out", task_id);
139    /// }
140    /// # }
141    /// ```
142    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
143        self.timeout_rx.take()
144    }
145
146    /// 取消指定的任务
147    ///
148    /// # 参数
149    /// - `task_id`: 要取消的任务 ID
150    ///
151    /// # 返回
152    /// - `Ok(true)`: 任务存在且成功取消
153    /// - `Ok(false)`: 任务不存在或取消失败
154    /// - `Err(String)`: 发送命令失败
155    ///
156    /// # 性能说明
157    /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
158    ///
159    /// # 示例
160    /// ```no_run
161    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
162    /// # use std::time::Duration;
163    /// # #[tokio::main]
164    /// # async fn main() {
165    /// let timer = TimerWheel::with_defaults();
166    /// let service = timer.create_service();
167    /// 
168    /// // 使用两步式 API 调度定时器
169    /// let task = TimerService::create_task(Duration::from_secs(10), || async {});
170    /// let task_id = task.get_id();
171    /// service.register(task).await;
172    /// 
173    /// // 取消任务
174    /// let cancelled = service.cancel_task(task_id).await;
175    /// println!("Task cancelled: {}", cancelled);
176    /// # }
177    /// ```
178    #[inline]
179    pub async fn cancel_task(&self, task_id: TaskId) -> bool {
180        // 优化:直接取消任务,无需通知 Actor
181        // FuturesUnordered 会在任务被取消时自动清理
182        let mut wheel = self.wheel.lock();
183        wheel.cancel(task_id)
184    }
185
186    /// 批量取消任务
187    ///
188    /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
189    ///
190    /// # 参数
191    /// - `task_ids`: 要取消的任务 ID 列表
192    ///
193    /// # 返回
194    /// 成功取消的任务数量
195    ///
196    /// # 示例
197    /// ```no_run
198    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
199    /// # use std::time::Duration;
200    /// # #[tokio::main]
201    /// # async fn main() {
202    /// let timer = TimerWheel::with_defaults();
203    /// let service = timer.create_service();
204    /// 
205    /// let callbacks: Vec<_> = (0..10)
206    ///     .map(|_| (Duration::from_secs(10), || async {}))
207    ///     .collect();
208    /// let tasks = TimerService::create_batch(callbacks);
209    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
210    /// service.register_batch(tasks).await;
211    /// 
212    /// // 批量取消
213    /// let cancelled = service.cancel_batch(&task_ids).await;
214    /// println!("成功取消 {} 个任务", cancelled);
215    /// # }
216    /// ```
217    #[inline]
218    pub async fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
219        if task_ids.is_empty() {
220            return 0;
221        }
222
223        // 优化:直接使用底层的批量取消,无需通知 Actor
224        // FuturesUnordered 会在任务被取消时自动清理
225        let mut wheel = self.wheel.lock();
226        wheel.cancel_batch(task_ids)
227    }
228
229    /// 推迟任务(保持原回调)
230    ///
231    /// # 参数
232    /// - `task_id`: 要推迟的任务 ID
233    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
234    ///
235    /// # 返回
236    /// - `true`: 任务存在且成功推迟
237    /// - `false`: 任务不存在或推迟失败
238    ///
239    /// # 性能说明
240    /// 此方法使用直接推迟优化,不需要等待 Actor 处理,大幅降低延迟
241    ///
242    /// # 注意
243    /// - 推迟后任务 ID 保持不变
244    /// - 原有的超时通知仍然有效
245    /// - 保持原回调函数不变
246    ///
247    /// # 示例
248    /// ```no_run
249    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
250    /// # use std::time::Duration;
251    /// # #[tokio::main]
252    /// # async fn main() {
253    /// let timer = TimerWheel::with_defaults();
254    /// let service = timer.create_service();
255    /// 
256    /// let task = TimerService::create_task(Duration::from_secs(5), || async {});
257    /// let task_id = task.get_id();
258    /// service.register(task).await;
259    /// 
260    /// // 推迟到 10 秒后触发
261    /// let success = service.postpone_task(task_id, Duration::from_secs(10)).await;
262    /// println!("推迟成功: {}", success);
263    /// # }
264    /// ```
265    #[inline]
266    pub async fn postpone_task(&self, task_id: TaskId, new_delay: Duration) -> bool {
267        // 优化:直接推迟任务,无需通知 Actor
268        // FuturesUnordered 会继续监听原有的 completion_receiver
269        let mut wheel = self.wheel.lock();
270        wheel.postpone(task_id, new_delay, None)
271    }
272
273    /// 推迟任务(替换回调)
274    ///
275    /// # 参数
276    /// - `task_id`: 要推迟的任务 ID
277    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
278    /// - `callback`: 新的回调函数
279    ///
280    /// # 返回
281    /// - `true`: 任务存在且成功推迟
282    /// - `false`: 任务不存在或推迟失败
283    ///
284    /// # 注意
285    /// - 推迟后任务 ID 保持不变
286    /// - 原有的超时通知仍然有效
287    /// - 回调函数会被替换为新的回调
288    ///
289    /// # 示例
290    /// ```no_run
291    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
292    /// # use std::time::Duration;
293    /// # #[tokio::main]
294    /// # async fn main() {
295    /// let timer = TimerWheel::with_defaults();
296    /// let service = timer.create_service();
297    /// 
298    /// let task = TimerService::create_task(Duration::from_secs(5), || async {
299    ///     println!("Original callback");
300    /// });
301    /// let task_id = task.get_id();
302    /// service.register(task).await;
303    /// 
304    /// // 推迟并替换回调
305    /// let success = service.postpone_task_with_callback(
306    ///     task_id,
307    ///     Duration::from_secs(10),
308    ///     || async { println!("New callback!"); }
309    /// ).await;
310    /// println!("推迟成功: {}", success);
311    /// # }
312    /// ```
313    #[inline]
314    pub async fn postpone_task_with_callback<C>(
315        &self,
316        task_id: TaskId,
317        new_delay: Duration,
318        callback: C,
319    ) -> bool
320    where
321        C: TimerCallback,
322    {
323        use std::sync::Arc;
324        let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
325        let mut wheel = self.wheel.lock();
326        wheel.postpone(task_id, new_delay, Some(callback_wrapper))
327    }
328
329    /// 批量推迟任务(保持原回调)
330    ///
331    /// # 参数
332    /// - `updates`: (任务ID, 新延迟) 的元组列表
333    ///
334    /// # 返回
335    /// 成功推迟的任务数量
336    ///
337    /// # 示例
338    /// ```no_run
339    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
340    /// # use std::time::Duration;
341    /// # #[tokio::main]
342    /// # async fn main() {
343    /// let timer = TimerWheel::with_defaults();
344    /// let service = timer.create_service();
345    /// 
346    /// let callbacks: Vec<_> = (0..3)
347    ///     .map(|_| (Duration::from_secs(5), || async {}))
348    ///     .collect();
349    /// let tasks = TimerService::create_batch(callbacks);
350    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
351    /// service.register_batch(tasks).await;
352    /// 
353    /// // 批量推迟
354    /// let updates: Vec<_> = task_ids
355    ///     .into_iter()
356    ///     .map(|id| (id, Duration::from_secs(10)))
357    ///     .collect();
358    /// let postponed = service.postpone_batch(&updates).await;
359    /// println!("成功推迟 {} 个任务", postponed);
360    /// # }
361    /// ```
362    #[inline]
363    pub async fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
364        if updates.is_empty() {
365            return 0;
366        }
367
368        let updates_vec: Vec<_> = updates
369            .iter()
370            .map(|(task_id, delay)| (*task_id, *delay, None))
371            .collect();
372        let mut wheel = self.wheel.lock();
373        wheel.postpone_batch(updates_vec)
374    }
375
376    /// 批量推迟任务(替换回调)
377    ///
378    /// # 参数
379    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
380    ///
381    /// # 返回
382    /// 成功推迟的任务数量
383    ///
384    /// # 示例
385    /// ```no_run
386    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
387    /// # use std::time::Duration;
388    /// # #[tokio::main]
389    /// # async fn main() {
390    /// let timer = TimerWheel::with_defaults();
391    /// let service = timer.create_service();
392    /// 
393    /// let callbacks: Vec<_> = (0..3)
394    ///     .map(|_| (Duration::from_secs(5), || async {}))
395    ///     .collect();
396    /// let tasks = TimerService::create_batch(callbacks);
397    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
398    /// service.register_batch(tasks).await;
399    /// 
400    /// // 批量推迟并替换回调
401    /// let updates: Vec<_> = task_ids
402    ///     .into_iter()
403    ///     .enumerate()
404    ///     .map(|(i, id)| {
405    ///         (id, Duration::from_secs(10), move || async move {
406    ///             println!("New callback {}", i);
407    ///         })
408    ///     })
409    ///     .collect();
410    /// let postponed = service.postpone_batch_with_callbacks(updates).await;
411    /// println!("成功推迟 {} 个任务", postponed);
412    /// # }
413    /// ```
414    #[inline]
415    pub async fn postpone_batch_with_callbacks<C>(
416        &self,
417        updates: Vec<(TaskId, Duration, C)>,
418    ) -> usize
419    where
420        C: TimerCallback,
421    {
422        if updates.is_empty() {
423            return 0;
424        }
425
426        use std::sync::Arc;
427        let updates_vec: Vec<_> = updates
428            .into_iter()
429            .map(|(task_id, delay, callback)| {
430                let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
431                (task_id, delay, Some(callback_wrapper))
432            })
433            .collect();
434        let mut wheel = self.wheel.lock();
435        wheel.postpone_batch(updates_vec)
436    }
437
438    /// 创建定时器任务(静态方法,申请阶段)
439    /// 
440    /// # 参数
441    /// - `delay`: 延迟时间
442    /// - `callback`: 实现了 TimerCallback trait 的回调对象
443    /// 
444    /// # 返回
445    /// 返回 TimerTask,需要通过 `register()` 注册
446    /// 
447    /// # 示例
448    /// ```no_run
449    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
450    /// # use std::time::Duration;
451    /// # #[tokio::main]
452    /// # async fn main() {
453    /// let timer = TimerWheel::with_defaults();
454    /// let service = timer.create_service();
455    /// 
456    /// // 步骤 1: 创建任务
457    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
458    ///     println!("Timer fired!");
459    /// });
460    /// 
461    /// let task_id = task.get_id();
462    /// println!("Created task: {:?}", task_id);
463    /// 
464    /// // 步骤 2: 注册任务
465    /// service.register(task).await;
466    /// # }
467    /// ```
468    pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
469    where
470        C: TimerCallback,
471    {
472        crate::timer::TimerWheel::create_task(delay, callback)
473    }
474    
475    /// 批量创建定时器任务(静态方法,申请阶段)
476    /// 
477    /// # 参数
478    /// - `callbacks`: (延迟时间, 回调) 的元组列表
479    /// 
480    /// # 返回
481    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
482    /// 
483    /// # 示例
484    /// ```no_run
485    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
486    /// # use std::time::Duration;
487    /// # #[tokio::main]
488    /// # async fn main() {
489    /// let timer = TimerWheel::with_defaults();
490    /// let service = timer.create_service();
491    /// 
492    /// // 步骤 1: 批量创建任务
493    /// let callbacks: Vec<_> = (0..3)
494    ///     .map(|i| (Duration::from_millis(100 * (i + 1)), move || async move {
495    ///         println!("Timer {} fired!", i);
496    ///     }))
497    ///     .collect();
498    /// 
499    /// let tasks = TimerService::create_batch(callbacks);
500    /// println!("Created {} tasks", tasks.len());
501    /// 
502    /// // 步骤 2: 批量注册任务
503    /// service.register_batch(tasks).await;
504    /// # }
505    /// ```
506    pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
507    where
508        C: TimerCallback,
509    {
510        crate::timer::TimerWheel::create_batch(callbacks)
511    }
512    
513    /// 注册定时器任务到服务(注册阶段)
514    /// 
515    /// # 参数
516    /// - `task`: 通过 `create_task()` 创建的任务
517    /// 
518    /// # 示例
519    /// ```no_run
520    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
521    /// # use std::time::Duration;
522    /// # #[tokio::main]
523    /// # async fn main() {
524    /// let timer = TimerWheel::with_defaults();
525    /// let service = timer.create_service();
526    /// 
527    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
528    ///     println!("Timer fired!");
529    /// });
530    /// let task_id = task.get_id();
531    /// 
532    /// service.register(task).await;
533    /// # }
534    /// ```
535    #[inline]
536    pub async fn register(&self, task: crate::task::TimerTask) {
537        let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
538        let notifier = crate::task::CompletionNotifier(completion_tx);
539        
540        let delay = task.delay;
541        let task_id = task.id;
542        
543        // 单次加锁完成所有操作
544        {
545            let mut wheel_guard = self.wheel.lock();
546            wheel_guard.insert(delay, task, notifier);
547        }
548        
549        // 创建句柄并添加到服务管理
550        let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
551        self.add_timer_handle(handle).await;
552    }
553    
554    /// 批量注册定时器任务到服务(注册阶段)
555    /// 
556    /// # 参数
557    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
558    /// 
559    /// # 示例
560    /// ```no_run
561    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
562    /// # use std::time::Duration;
563    /// # #[tokio::main]
564    /// # async fn main() {
565    /// let timer = TimerWheel::with_defaults();
566    /// let service = timer.create_service();
567    /// 
568    /// let callbacks: Vec<_> = (0..3)
569    ///     .map(|_| (Duration::from_secs(1), || async {}))
570    ///     .collect();
571    /// let tasks = TimerService::create_batch(callbacks);
572    /// 
573    /// service.register_batch(tasks).await;
574    /// # }
575    /// ```
576    #[inline]
577    pub async fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) {
578        let task_count = tasks.len();
579        let mut completion_rxs = Vec::with_capacity(task_count);
580        let mut task_ids = Vec::with_capacity(task_count);
581        let mut prepared_tasks = Vec::with_capacity(task_count);
582        
583        // 步骤1: 准备所有 channels 和 notifiers(无锁)
584        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
585        for task in tasks {
586            let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
587            let notifier = crate::task::CompletionNotifier(completion_tx);
588            
589            task_ids.push(task.id);
590            completion_rxs.push(completion_rx);
591            prepared_tasks.push((task.delay, task, notifier));
592        }
593        
594        // 步骤2: 单次加锁,批量插入
595        {
596            let mut wheel_guard = self.wheel.lock();
597            wheel_guard.insert_batch(prepared_tasks);
598        }
599        
600        // 创建批量句柄并添加到服务管理
601        let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
602        self.add_batch_handle(batch_handle).await;
603    }
604
605    /// 优雅关闭 TimerService
606    ///
607    /// # 示例
608    /// ```no_run
609    /// # use kestrel_protocol_timer::TimerWheel;
610    /// # #[tokio::main]
611    /// # async fn main() {
612    /// let timer = TimerWheel::with_defaults();
613    /// let mut service = timer.create_service();
614    /// 
615    /// // 使用 service...
616    /// 
617    /// service.shutdown().await;
618    /// # }
619    /// ```
620    pub async fn shutdown(mut self) {
621        let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
622        if let Some(handle) = self.actor_handle.take() {
623            let _ = handle.await;
624        }
625    }
626}
627
628
629impl Drop for TimerService {
630    fn drop(&mut self) {
631        if let Some(handle) = self.actor_handle.take() {
632            handle.abort();
633        }
634    }
635}
636
637/// ServiceActor - 内部 Actor 实现
638struct ServiceActor {
639    /// 命令接收端
640    command_rx: mpsc::Receiver<ServiceCommand>,
641    /// 超时发送端
642    timeout_tx: mpsc::Sender<TaskId>,
643}
644
645impl ServiceActor {
646    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
647        Self {
648            command_rx,
649            timeout_tx,
650        }
651    }
652
653    async fn run(mut self) {
654        // 使用 FuturesUnordered 来监听所有的 completion_rxs
655        // 每个 future 返回 (TaskId, Result)
656        let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
657
658        loop {
659            tokio::select! {
660                // 监听超时事件
661                Some((task_id, result)) = futures.next() => {
662                    // 检查完成原因,只转发超时(Expired)事件,不转发取消(Cancelled)事件
663                    if let Ok(TaskCompletionReason::Expired) = result {
664                        let _ = self.timeout_tx.send(task_id).await;
665                    }
666                    // 任务会自动从 FuturesUnordered 中移除
667                }
668                
669                // 监听命令
670                Some(cmd) = self.command_rx.recv() => {
671                    match cmd {
672                        ServiceCommand::AddBatchHandle(batch) => {
673                            let BatchHandle {
674                                task_ids,
675                                completion_rxs,
676                                ..
677                            } = batch;
678                            
679                            // 将所有任务添加到 futures 中
680                            for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
681                                let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
682                                    (task_id, rx.await)
683                                });
684                                futures.push(future);
685                            }
686                        }
687                        ServiceCommand::AddTimerHandle(handle) => {
688                            let TimerHandle{
689                                task_id,
690                                completion_rx,
691                                ..
692                            } = handle;
693                            
694                            // 添加到 futures 中
695                            let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
696                                (task_id, completion_rx.0.await)
697                            });
698                            futures.push(future);
699                        }
700                        ServiceCommand::Shutdown => {
701                            break;
702                        }
703                    }
704                }
705                
706                // 如果没有任何 future 且命令通道已关闭,退出循环
707                else => {
708                    break;
709                }
710            }
711        }
712    }
713}
714
715#[cfg(test)]
716mod tests {
717    use super::*;
718    use crate::TimerWheel;
719    use std::sync::atomic::{AtomicU32, Ordering};
720    use std::sync::Arc;
721    use std::time::Duration;
722
723    #[tokio::test]
724    async fn test_service_creation() {
725        let timer = TimerWheel::with_defaults();
726        let _service = timer.create_service();
727    }
728
729
730    #[tokio::test]
731    async fn test_add_timer_handle_and_receive_timeout() {
732        let timer = TimerWheel::with_defaults();
733        let mut service = timer.create_service();
734
735        // 创建单个定时器
736        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
737        let task_id = task.get_id();
738        let handle = timer.register(task);
739
740        // 添加到 service
741        service.add_timer_handle(handle).await;
742
743        // 接收超时通知
744        let mut rx = service.take_receiver().unwrap();
745        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
746            .await
747            .expect("Should receive timeout notification")
748            .expect("Should receive Some value");
749
750        assert_eq!(received_task_id, task_id);
751    }
752
753
754    #[tokio::test]
755    async fn test_shutdown() {
756        let timer = TimerWheel::with_defaults();
757        let service = timer.create_service();
758
759        // 添加一些定时器
760        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
761        let task2 = TimerService::create_task(Duration::from_secs(10), || async {});
762        service.register(task1).await;
763        service.register(task2).await;
764
765        // 立即关闭(不等待定时器触发)
766        service.shutdown().await;
767    }
768
769
770
771    #[tokio::test]
772    async fn test_cancel_task() {
773        let timer = TimerWheel::with_defaults();
774        let service = timer.create_service();
775
776        // 添加一个长时间的定时器
777        let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
778        let task_id = task.get_id();
779        let handle = timer.register(task);
780        
781        service.add_timer_handle(handle).await;
782
783        // 取消任务
784        let cancelled = service.cancel_task(task_id).await;
785        assert!(cancelled, "Task should be cancelled successfully");
786
787        // 尝试再次取消同一个任务,应该返回 false
788        let cancelled_again = service.cancel_task(task_id).await;
789        assert!(!cancelled_again, "Task should not exist anymore");
790    }
791
792    #[tokio::test]
793    async fn test_cancel_nonexistent_task() {
794        let timer = TimerWheel::with_defaults();
795        let service = timer.create_service();
796
797        // 添加一个定时器以初始化 service
798        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
799        let handle = timer.register(task);
800        service.add_timer_handle(handle).await;
801
802        // 尝试取消一个不存在的任务(创建一个不会实际注册的任务ID)
803        let fake_task = TimerWheel::create_task(Duration::from_millis(50), || async {});
804        let fake_task_id = fake_task.get_id();
805        // 不注册 fake_task
806        let cancelled = service.cancel_task(fake_task_id).await;
807        assert!(!cancelled, "Nonexistent task should not be cancelled");
808    }
809
810
811    #[tokio::test]
812    async fn test_task_timeout_cleans_up_task_sender() {
813        let timer = TimerWheel::with_defaults();
814        let mut service = timer.create_service();
815
816        // 添加一个短时间的定时器
817        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
818        let task_id = task.get_id();
819        let handle = timer.register(task);
820        
821        service.add_timer_handle(handle).await;
822
823        // 等待任务超时
824        let mut rx = service.take_receiver().unwrap();
825        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
826            .await
827            .expect("Should receive timeout notification")
828            .expect("Should receive Some value");
829        
830        assert_eq!(received_task_id, task_id);
831
832        // 等待一下确保内部清理完成
833        tokio::time::sleep(Duration::from_millis(10)).await;
834
835        // 尝试取消已经超时的任务,应该返回 false
836        let cancelled = service.cancel_task(task_id).await;
837        assert!(!cancelled, "Timed out task should not exist anymore");
838    }
839
840    #[tokio::test]
841    async fn test_cancel_task_spawns_background_task() {
842        let timer = TimerWheel::with_defaults();
843        let service = timer.create_service();
844        let counter = Arc::new(AtomicU32::new(0));
845
846        // 创建一个定时器
847        let counter_clone = Arc::clone(&counter);
848        let task = TimerWheel::create_task(
849            Duration::from_secs(10),
850            move || {
851                let counter = Arc::clone(&counter_clone);
852                async move {
853                    counter.fetch_add(1, Ordering::SeqCst);
854                }
855            },
856        );
857        let task_id = task.get_id();
858        let handle = timer.register(task);
859        
860        service.add_timer_handle(handle).await;
861
862        // 使用 cancel_task(会等待结果,但在后台协程中处理)
863        let cancelled = service.cancel_task(task_id).await;
864        assert!(cancelled, "Task should be cancelled successfully");
865
866        // 等待足够长时间确保回调不会被执行
867        tokio::time::sleep(Duration::from_millis(100)).await;
868        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
869
870        // 验证任务已从 active_tasks 中移除
871        let cancelled_again = service.cancel_task(task_id).await;
872        assert!(!cancelled_again, "Task should have been removed from active_tasks");
873    }
874
875    #[tokio::test]
876    async fn test_schedule_once_direct() {
877        let timer = TimerWheel::with_defaults();
878        let mut service = timer.create_service();
879        let counter = Arc::new(AtomicU32::new(0));
880
881        // 直接通过 service 调度定时器
882        let counter_clone = Arc::clone(&counter);
883        let task = TimerService::create_task(
884            Duration::from_millis(50),
885            move || {
886                let counter = Arc::clone(&counter_clone);
887                async move {
888                    counter.fetch_add(1, Ordering::SeqCst);
889                }
890            },
891        );
892        let task_id = task.get_id();
893        service.register(task).await;
894
895        // 等待定时器触发
896        let mut rx = service.take_receiver().unwrap();
897        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
898            .await
899            .expect("Should receive timeout notification")
900            .expect("Should receive Some value");
901
902        assert_eq!(received_task_id, task_id);
903        
904        // 等待回调执行
905        tokio::time::sleep(Duration::from_millis(50)).await;
906        assert_eq!(counter.load(Ordering::SeqCst), 1);
907    }
908
909    #[tokio::test]
910    async fn test_schedule_once_batch_direct() {
911        let timer = TimerWheel::with_defaults();
912        let mut service = timer.create_service();
913        let counter = Arc::new(AtomicU32::new(0));
914
915        // 直接通过 service 批量调度定时器
916        let callbacks: Vec<_> = (0..3)
917            .map(|_| {
918                let counter = Arc::clone(&counter);
919                (Duration::from_millis(50), move || {
920                    let counter = Arc::clone(&counter);
921                    async move {
922                        counter.fetch_add(1, Ordering::SeqCst);
923                    }
924                })
925            })
926            .collect();
927
928        let tasks = TimerService::create_batch(callbacks);
929        assert_eq!(tasks.len(), 3);
930        service.register_batch(tasks).await;
931
932        // 接收所有超时通知
933        let mut received_count = 0;
934        let mut rx = service.take_receiver().unwrap();
935        
936        while received_count < 3 {
937            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
938                Ok(Some(_task_id)) => {
939                    received_count += 1;
940                }
941                Ok(None) => break,
942                Err(_) => break,
943            }
944        }
945
946        assert_eq!(received_count, 3);
947        
948        // 等待回调执行
949        tokio::time::sleep(Duration::from_millis(50)).await;
950        assert_eq!(counter.load(Ordering::SeqCst), 3);
951    }
952
953    #[tokio::test]
954    async fn test_schedule_once_notify_direct() {
955        let timer = TimerWheel::with_defaults();
956        let mut service = timer.create_service();
957
958        // 直接通过 service 调度仅通知的定时器(无回调)
959        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
960        let task_id = task.get_id();
961        service.register(task).await;
962
963        // 接收超时通知
964        let mut rx = service.take_receiver().unwrap();
965        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
966            .await
967            .expect("Should receive timeout notification")
968            .expect("Should receive Some value");
969
970        assert_eq!(received_task_id, task_id);
971    }
972
973    #[tokio::test]
974    async fn test_schedule_and_cancel_direct() {
975        let timer = TimerWheel::with_defaults();
976        let service = timer.create_service();
977        let counter = Arc::new(AtomicU32::new(0));
978
979        // 直接调度定时器
980        let counter_clone = Arc::clone(&counter);
981        let task = TimerService::create_task(
982            Duration::from_secs(10),
983            move || {
984                let counter = Arc::clone(&counter_clone);
985                async move {
986                    counter.fetch_add(1, Ordering::SeqCst);
987                }
988            },
989        );
990        let task_id = task.get_id();
991        service.register(task).await;
992
993        // 立即取消
994        let cancelled = service.cancel_task(task_id).await;
995        assert!(cancelled, "Task should be cancelled successfully");
996
997        // 等待确保回调不会执行
998        tokio::time::sleep(Duration::from_millis(100)).await;
999        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
1000    }
1001
1002    #[tokio::test]
1003    async fn test_cancel_batch_direct() {
1004        let timer = TimerWheel::with_defaults();
1005        let service = timer.create_service();
1006        let counter = Arc::new(AtomicU32::new(0));
1007
1008        // 批量调度定时器
1009        let callbacks: Vec<_> = (0..10)
1010            .map(|_| {
1011                let counter = Arc::clone(&counter);
1012                (Duration::from_secs(10), move || {
1013                    let counter = Arc::clone(&counter);
1014                    async move {
1015                        counter.fetch_add(1, Ordering::SeqCst);
1016                    }
1017                })
1018            })
1019            .collect();
1020
1021        let tasks = TimerService::create_batch(callbacks);
1022        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1023        assert_eq!(task_ids.len(), 10);
1024        service.register_batch(tasks).await;
1025
1026        // 批量取消所有任务
1027        let cancelled = service.cancel_batch(&task_ids).await;
1028        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1029
1030        // 等待确保回调不会执行
1031        tokio::time::sleep(Duration::from_millis(100)).await;
1032        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1033    }
1034
1035    #[tokio::test]
1036    async fn test_cancel_batch_partial() {
1037        let timer = TimerWheel::with_defaults();
1038        let service = timer.create_service();
1039        let counter = Arc::new(AtomicU32::new(0));
1040
1041        // 批量调度定时器
1042        let callbacks: Vec<_> = (0..10)
1043            .map(|_| {
1044                let counter = Arc::clone(&counter);
1045                (Duration::from_secs(10), move || {
1046                    let counter = Arc::clone(&counter);
1047                    async move {
1048                        counter.fetch_add(1, Ordering::SeqCst);
1049                    }
1050                })
1051            })
1052            .collect();
1053
1054        let tasks = TimerService::create_batch(callbacks);
1055        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1056        service.register_batch(tasks).await;
1057
1058        // 只取消前5个任务
1059        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1060        let cancelled = service.cancel_batch(&to_cancel).await;
1061        assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1062
1063        // 等待确保前5个回调不会执行
1064        tokio::time::sleep(Duration::from_millis(100)).await;
1065        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1066    }
1067
1068    #[tokio::test]
1069    async fn test_cancel_batch_empty() {
1070        let timer = TimerWheel::with_defaults();
1071        let service = timer.create_service();
1072
1073        // 取消空列表
1074        let empty: Vec<TaskId> = vec![];
1075        let cancelled = service.cancel_batch(&empty).await;
1076        assert_eq!(cancelled, 0, "No tasks should be cancelled");
1077    }
1078
1079    #[tokio::test]
1080    async fn test_postpone_task() {
1081        let timer = TimerWheel::with_defaults();
1082        let mut service = timer.create_service();
1083        let counter = Arc::new(AtomicU32::new(0));
1084
1085        // 注册一个任务,延迟 50ms
1086        let counter_clone = Arc::clone(&counter);
1087        let task = TimerService::create_task(
1088            Duration::from_millis(50),
1089            move || {
1090                let counter = Arc::clone(&counter_clone);
1091                async move {
1092                    counter.fetch_add(1, Ordering::SeqCst);
1093                }
1094            },
1095        );
1096        let task_id = task.get_id();
1097        service.register(task).await;
1098
1099        // 推迟任务到 150ms
1100        let postponed = service.postpone_task(task_id, Duration::from_millis(150)).await;
1101        assert!(postponed, "Task should be postponed successfully");
1102
1103        // 等待原定时间 50ms,任务不应该触发
1104        tokio::time::sleep(Duration::from_millis(70)).await;
1105        assert_eq!(counter.load(Ordering::SeqCst), 0);
1106
1107        // 接收超时通知(从推迟开始算,还需要等待约 150ms)
1108        let mut rx = service.take_receiver().unwrap();
1109        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1110            .await
1111            .expect("Should receive timeout notification")
1112            .expect("Should receive Some value");
1113
1114        assert_eq!(received_task_id, task_id);
1115        
1116        // 等待回调执行
1117        tokio::time::sleep(Duration::from_millis(20)).await;
1118        assert_eq!(counter.load(Ordering::SeqCst), 1);
1119    }
1120
1121    #[tokio::test]
1122    async fn test_postpone_task_with_callback() {
1123        let timer = TimerWheel::with_defaults();
1124        let mut service = timer.create_service();
1125        let counter = Arc::new(AtomicU32::new(0));
1126
1127        // 注册一个任务,原始回调增加 1
1128        let counter_clone1 = Arc::clone(&counter);
1129        let task = TimerService::create_task(
1130            Duration::from_millis(50),
1131            move || {
1132                let counter = Arc::clone(&counter_clone1);
1133                async move {
1134                    counter.fetch_add(1, Ordering::SeqCst);
1135                }
1136            },
1137        );
1138        let task_id = task.get_id();
1139        service.register(task).await;
1140
1141        // 推迟任务并替换回调,新回调增加 10
1142        let counter_clone2 = Arc::clone(&counter);
1143        let postponed = service.postpone_task_with_callback(
1144            task_id,
1145            Duration::from_millis(100),
1146            move || {
1147                let counter = Arc::clone(&counter_clone2);
1148                async move {
1149                    counter.fetch_add(10, Ordering::SeqCst);
1150                }
1151            }
1152        ).await;
1153        assert!(postponed, "Task should be postponed successfully");
1154
1155        // 接收超时通知(推迟后需要等待100ms,加上余量)
1156        let mut rx = service.take_receiver().unwrap();
1157        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1158            .await
1159            .expect("Should receive timeout notification")
1160            .expect("Should receive Some value");
1161
1162        assert_eq!(received_task_id, task_id);
1163        
1164        // 等待回调执行
1165        tokio::time::sleep(Duration::from_millis(20)).await;
1166        
1167        // 验证新回调被执行(增加了 10 而不是 1)
1168        assert_eq!(counter.load(Ordering::SeqCst), 10);
1169    }
1170
1171    #[tokio::test]
1172    async fn test_postpone_nonexistent_task() {
1173        let timer = TimerWheel::with_defaults();
1174        let service = timer.create_service();
1175
1176        // 尝试推迟一个不存在的任务
1177        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
1178        let fake_task_id = fake_task.get_id();
1179        // 不注册这个任务
1180        
1181        let postponed = service.postpone_task(fake_task_id, Duration::from_millis(100)).await;
1182        assert!(!postponed, "Nonexistent task should not be postponed");
1183    }
1184
1185    #[tokio::test]
1186    async fn test_postpone_batch() {
1187        let timer = TimerWheel::with_defaults();
1188        let mut service = timer.create_service();
1189        let counter = Arc::new(AtomicU32::new(0));
1190
1191        // 注册 3 个任务
1192        let mut task_ids = Vec::new();
1193        for _ in 0..3 {
1194            let counter_clone = Arc::clone(&counter);
1195            let task = TimerService::create_task(
1196                Duration::from_millis(50),
1197                move || {
1198                    let counter = Arc::clone(&counter_clone);
1199                    async move {
1200                        counter.fetch_add(1, Ordering::SeqCst);
1201                    }
1202                },
1203            );
1204            task_ids.push((task.get_id(), Duration::from_millis(150)));
1205            service.register(task).await;
1206        }
1207
1208        // 批量推迟
1209        let postponed = service.postpone_batch(&task_ids).await;
1210        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1211
1212        // 等待原定时间 50ms,任务不应该触发
1213        tokio::time::sleep(Duration::from_millis(70)).await;
1214        assert_eq!(counter.load(Ordering::SeqCst), 0);
1215
1216        // 接收所有超时通知
1217        let mut received_count = 0;
1218        let mut rx = service.take_receiver().unwrap();
1219        
1220        while received_count < 3 {
1221            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1222                Ok(Some(_task_id)) => {
1223                    received_count += 1;
1224                }
1225                Ok(None) => break,
1226                Err(_) => break,
1227            }
1228        }
1229
1230        assert_eq!(received_count, 3);
1231        
1232        // 等待回调执行
1233        tokio::time::sleep(Duration::from_millis(20)).await;
1234        assert_eq!(counter.load(Ordering::SeqCst), 3);
1235    }
1236
1237    #[tokio::test]
1238    async fn test_postpone_batch_with_callbacks() {
1239        let timer = TimerWheel::with_defaults();
1240        let mut service = timer.create_service();
1241        let counter = Arc::new(AtomicU32::new(0));
1242
1243        // 注册 3 个任务
1244        let mut task_ids = Vec::new();
1245        for _ in 0..3 {
1246            let task = TimerService::create_task(
1247                Duration::from_millis(50),
1248                || async {},
1249            );
1250            task_ids.push(task.get_id());
1251            service.register(task).await;
1252        }
1253
1254        // 批量推迟并替换回调
1255        let updates: Vec<_> = task_ids
1256            .into_iter()
1257            .map(|id| {
1258                let counter_clone = Arc::clone(&counter);
1259                (id, Duration::from_millis(150), move || {
1260                    let counter = Arc::clone(&counter_clone);
1261                    async move {
1262                        counter.fetch_add(1, Ordering::SeqCst);
1263                    }
1264                })
1265            })
1266            .collect();
1267
1268        let postponed = service.postpone_batch_with_callbacks(updates).await;
1269        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1270
1271        // 等待原定时间 50ms,任务不应该触发
1272        tokio::time::sleep(Duration::from_millis(70)).await;
1273        assert_eq!(counter.load(Ordering::SeqCst), 0);
1274
1275        // 接收所有超时通知
1276        let mut received_count = 0;
1277        let mut rx = service.take_receiver().unwrap();
1278        
1279        while received_count < 3 {
1280            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1281                Ok(Some(_task_id)) => {
1282                    received_count += 1;
1283                }
1284                Ok(None) => break,
1285                Err(_) => break,
1286            }
1287        }
1288
1289        assert_eq!(received_count, 3);
1290        
1291        // 等待回调执行
1292        tokio::time::sleep(Duration::from_millis(20)).await;
1293        assert_eq!(counter.load(Ordering::SeqCst), 3);
1294    }
1295
1296    #[tokio::test]
1297    async fn test_postpone_batch_empty() {
1298        let timer = TimerWheel::with_defaults();
1299        let service = timer.create_service();
1300
1301        // 推迟空列表
1302        let empty: Vec<(TaskId, Duration)> = vec![];
1303        let postponed = service.postpone_batch(&empty).await;
1304        assert_eq!(postponed, 0, "No tasks should be postponed");
1305    }
1306
1307    #[tokio::test]
1308    async fn test_postpone_keeps_timeout_notification_valid() {
1309        let timer = TimerWheel::with_defaults();
1310        let mut service = timer.create_service();
1311        let counter = Arc::new(AtomicU32::new(0));
1312
1313        // 注册一个任务
1314        let counter_clone = Arc::clone(&counter);
1315        let task = TimerService::create_task(
1316            Duration::from_millis(50),
1317            move || {
1318                let counter = Arc::clone(&counter_clone);
1319                async move {
1320                    counter.fetch_add(1, Ordering::SeqCst);
1321                }
1322            },
1323        );
1324        let task_id = task.get_id();
1325        service.register(task).await;
1326
1327        // 推迟任务
1328        service.postpone_task(task_id, Duration::from_millis(100)).await;
1329
1330        // 验证超时通知仍然有效(推迟后需要等待100ms,加上余量)
1331        let mut rx = service.take_receiver().unwrap();
1332        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1333            .await
1334            .expect("Should receive timeout notification")
1335            .expect("Should receive Some value");
1336
1337        assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1338        
1339        // 等待回调执行
1340        tokio::time::sleep(Duration::from_millis(20)).await;
1341        assert_eq!(counter.load(Ordering::SeqCst), 1);
1342    }
1343
1344    #[tokio::test]
1345    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1346        let timer = TimerWheel::with_defaults();
1347        let mut service = timer.create_service();
1348
1349        // 注册两个任务:一个会被取消,一个会正常到期
1350        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
1351        let task1_id = task1.get_id();
1352        service.register(task1).await;
1353
1354        let task2 = TimerService::create_task(Duration::from_millis(50), || async {});
1355        let task2_id = task2.get_id();
1356        service.register(task2).await;
1357
1358        // 取消第一个任务
1359        let cancelled = service.cancel_task(task1_id).await;
1360        assert!(cancelled, "Task should be cancelled");
1361
1362        // 等待第二个任务到期
1363        let mut rx = service.take_receiver().unwrap();
1364        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1365            .await
1366            .expect("Should receive timeout notification")
1367            .expect("Should receive Some value");
1368
1369        // 应该只收到第二个任务(到期的)的通知,不应该收到第一个任务(取消的)的通知
1370        assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");
1371
1372        // 验证没有其他通知(特别是被取消的任务不应该有通知)
1373        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1374        assert!(no_more.is_err(), "Should not receive any more notifications");
1375    }
1376}
1377