kestrel_protocol_timer/
service.rs

1use crate::config::ServiceConfig;
2use crate::task::{TaskCompletionReason, TaskId, TimerCallback};
3use crate::timer::{BatchHandle, TimerHandle};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13/// TimerService 命令类型
14enum ServiceCommand {
15    /// 添加批量定时器句柄
16    AddBatchHandle(BatchHandle),
17    /// 添加单个定时器句柄
18    AddTimerHandle(TimerHandle),
19    /// 关闭 Service
20    Shutdown,
21}
22
23/// TimerService - 基于 Actor 模式的定时器服务
24///
25/// 管理多个定时器句柄,监听所有超时事件,并将 TaskId 聚合转发给用户。
26///
27/// # 特性
28/// - 自动监听所有添加的定时器句柄的超时事件
29/// - 超时后自动从内部管理中移除该任务
30/// - 将超时的 TaskId 转发到统一的通道供用户接收
31/// - 支持动态添加 BatchHandle 和 TimerHandle
32///
33/// # 示例
34/// ```no_run
35/// use kestrel_protocol_timer::{TimerWheel, TimerService};
36/// use std::time::Duration;
37///
38/// #[tokio::main]
39/// async fn main() {
40///     let timer = TimerWheel::with_defaults();
41///     let mut service = timer.create_service();
42///     
43///     // 使用两步式 API 通过 service 批量调度定时器
44///     let callbacks: Vec<_> = (0..5)
45///         .map(|_| (Duration::from_millis(100), || async {}))
46///         .collect();
47///     let tasks = TimerService::create_batch(callbacks);
48///     service.register_batch(tasks).await;
49///     
50///     // 接收超时通知
51///     let mut rx = service.take_receiver().unwrap();
52///     while let Some(task_id) = rx.recv().await {
53///         println!("Task {:?} completed", task_id);
54///     }
55/// }
56/// ```
57pub struct TimerService {
58    /// 命令发送端
59    command_tx: mpsc::Sender<ServiceCommand>,
60    /// 超时接收端
61    timeout_rx: Option<mpsc::Receiver<TaskId>>,
62    /// Actor 任务句柄
63    actor_handle: Option<JoinHandle<()>>,
64    /// 时间轮引用(用于直接调度定时器)
65    wheel: Arc<Mutex<Wheel>>,
66}
67
68impl TimerService {
69    /// 创建新的 TimerService
70    ///
71    /// # 参数
72    /// - `wheel`: 时间轮引用
73    /// - `config`: 服务配置
74    ///
75    /// # 注意
76    /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
77    ///
78    /// # 示例
79    /// ```no_run
80    /// use kestrel_protocol_timer::TimerWheel;
81    ///
82    /// #[tokio::main]
83    /// async fn main() {
84    ///     let timer = TimerWheel::with_defaults();
85    ///     let mut service = timer.create_service();
86    /// }
87    /// ```
88    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
89        let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
90        let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
91
92        let actor = ServiceActor::new(command_rx, timeout_tx);
93        let actor_handle = tokio::spawn(async move {
94            actor.run().await;
95        });
96
97        Self {
98            command_tx,
99            timeout_rx: Some(timeout_rx),
100            actor_handle: Some(actor_handle),
101            wheel,
102        }
103    }
104
105    /// 添加批量定时器句柄(内部方法)
106    async fn add_batch_handle(&self, batch: BatchHandle) {
107        let _ = self.command_tx
108            .send(ServiceCommand::AddBatchHandle(batch))
109            .await;
110    }
111
112    /// 添加单个定时器句柄(内部方法)
113    async fn add_timer_handle(&self, handle: TimerHandle) {
114        let _ = self.command_tx
115            .send(ServiceCommand::AddTimerHandle(handle))
116            .await;
117    }
118
119    /// 获取超时接收器(转移所有权)
120    ///
121    /// # 返回
122    /// 超时通知接收器,如果已经被取走则返回 None
123    ///
124    /// # 注意
125    /// 此方法只能调用一次,因为它会转移接收器的所有权
126    ///
127    /// # 示例
128    /// ```no_run
129    /// # use kestrel_protocol_timer::TimerWheel;
130    /// # use std::time::Duration;
131    /// # #[tokio::main]
132    /// # async fn main() {
133    /// let timer = TimerWheel::with_defaults();
134    /// let mut service = timer.create_service();
135    /// 
136    /// let mut rx = service.take_receiver().unwrap();
137    /// while let Some(task_id) = rx.recv().await {
138    ///     println!("Task {:?} timed out", task_id);
139    /// }
140    /// # }
141    /// ```
142    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
143        self.timeout_rx.take()
144    }
145
146    /// 取消指定的任务
147    ///
148    /// # 参数
149    /// - `task_id`: 要取消的任务 ID
150    ///
151    /// # 返回
152    /// - `true`: 任务存在且成功取消
153    /// - `false`: 任务不存在或取消失败
154    ///
155    /// # 性能说明
156    /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
157    ///
158    /// # 示例
159    /// ```no_run
160    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
161    /// # use std::time::Duration;
162    /// # #[tokio::main]
163    /// # async fn main() {
164    /// let timer = TimerWheel::with_defaults();
165    /// let service = timer.create_service();
166    /// 
167    /// // 使用两步式 API 调度定时器
168    /// let task = TimerService::create_task(Duration::from_secs(10), || async {});
169    /// let task_id = task.get_id();
170    /// service.register(task).await;
171    /// 
172    /// // 取消任务
173    /// let cancelled = service.cancel_task(task_id);
174    /// println!("Task cancelled: {}", cancelled);
175    /// # }
176    /// ```
177    #[inline]
178    pub fn cancel_task(&self, task_id: TaskId) -> bool {
179        // 优化:直接取消任务,无需通知 Actor
180        // FuturesUnordered 会在任务被取消时自动清理
181        let mut wheel = self.wheel.lock();
182        wheel.cancel(task_id)
183    }
184
185    /// 批量取消任务
186    ///
187    /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
188    ///
189    /// # 参数
190    /// - `task_ids`: 要取消的任务 ID 列表
191    ///
192    /// # 返回
193    /// 成功取消的任务数量
194    ///
195    /// # 示例
196    /// ```no_run
197    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
198    /// # use std::time::Duration;
199    /// # #[tokio::main]
200    /// # async fn main() {
201    /// let timer = TimerWheel::with_defaults();
202    /// let service = timer.create_service();
203    /// 
204    /// let callbacks: Vec<_> = (0..10)
205    ///     .map(|_| (Duration::from_secs(10), || async {}))
206    ///     .collect();
207    /// let tasks = TimerService::create_batch(callbacks);
208    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
209    /// service.register_batch(tasks).await;
210    /// 
211    /// // 批量取消
212    /// let cancelled = service.cancel_batch(&task_ids);
213    /// println!("成功取消 {} 个任务", cancelled);
214    /// # }
215    /// ```
216    #[inline]
217    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
218        if task_ids.is_empty() {
219            return 0;
220        }
221
222        // 优化:直接使用底层的批量取消,无需通知 Actor
223        // FuturesUnordered 会在任务被取消时自动清理
224        let mut wheel = self.wheel.lock();
225        wheel.cancel_batch(task_ids)
226    }
227
228    /// 推迟任务(保持原回调)
229    ///
230    /// # 参数
231    /// - `task_id`: 要推迟的任务 ID
232    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
233    ///
234    /// # 返回
235    /// - `true`: 任务存在且成功推迟
236    /// - `false`: 任务不存在或推迟失败
237    ///
238    /// # 性能说明
239    /// 此方法使用直接推迟优化,不需要等待 Actor 处理,大幅降低延迟
240    ///
241    /// # 注意
242    /// - 推迟后任务 ID 保持不变
243    /// - 原有的超时通知仍然有效
244    /// - 保持原回调函数不变
245    ///
246    /// # 示例
247    /// ```no_run
248    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
249    /// # use std::time::Duration;
250    /// # #[tokio::main]
251    /// # async fn main() {
252    /// let timer = TimerWheel::with_defaults();
253    /// let service = timer.create_service();
254    /// 
255    /// let task = TimerService::create_task(Duration::from_secs(5), || async {});
256    /// let task_id = task.get_id();
257    /// service.register(task).await;
258    /// 
259    /// // 推迟到 10 秒后触发
260    /// let success = service.postpone_task(task_id, Duration::from_secs(10));
261    /// println!("推迟成功: {}", success);
262    /// # }
263    /// ```
264    #[inline]
265    pub fn postpone_task(&self, task_id: TaskId, new_delay: Duration) -> bool {
266        // 优化:直接推迟任务,无需通知 Actor
267        // FuturesUnordered 会继续监听原有的 completion_receiver
268        let mut wheel = self.wheel.lock();
269        wheel.postpone(task_id, new_delay, None)
270    }
271
272    /// 推迟任务(替换回调)
273    ///
274    /// # 参数
275    /// - `task_id`: 要推迟的任务 ID
276    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
277    /// - `callback`: 新的回调函数
278    ///
279    /// # 返回
280    /// - `true`: 任务存在且成功推迟
281    /// - `false`: 任务不存在或推迟失败
282    ///
283    /// # 注意
284    /// - 推迟后任务 ID 保持不变
285    /// - 原有的超时通知仍然有效
286    /// - 回调函数会被替换为新的回调
287    ///
288    /// # 示例
289    /// ```no_run
290    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
291    /// # use std::time::Duration;
292    /// # #[tokio::main]
293    /// # async fn main() {
294    /// let timer = TimerWheel::with_defaults();
295    /// let service = timer.create_service();
296    /// 
297    /// let task = TimerService::create_task(Duration::from_secs(5), || async {
298    ///     println!("Original callback");
299    /// });
300    /// let task_id = task.get_id();
301    /// service.register(task).await;
302    /// 
303    /// // 推迟并替换回调
304    /// let success = service.postpone_task_with_callback(
305    ///     task_id,
306    ///     Duration::from_secs(10),
307    ///     || async { println!("New callback!"); }
308    /// );
309    /// println!("推迟成功: {}", success);
310    /// # }
311    /// ```
312    #[inline]
313    pub fn postpone_task_with_callback<C>(
314        &self,
315        task_id: TaskId,
316        new_delay: Duration,
317        callback: C,
318    ) -> bool
319    where
320        C: TimerCallback,
321    {
322        use std::sync::Arc;
323        let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
324        let mut wheel = self.wheel.lock();
325        wheel.postpone(task_id, new_delay, Some(callback_wrapper))
326    }
327
328    /// 批量推迟任务(保持原回调)
329    ///
330    /// # 参数
331    /// - `updates`: (任务ID, 新延迟) 的元组列表
332    ///
333    /// # 返回
334    /// 成功推迟的任务数量
335    ///
336    /// # 示例
337    /// ```no_run
338    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
339    /// # use std::time::Duration;
340    /// # #[tokio::main]
341    /// # async fn main() {
342    /// let timer = TimerWheel::with_defaults();
343    /// let service = timer.create_service();
344    /// 
345    /// let callbacks: Vec<_> = (0..3)
346    ///     .map(|_| (Duration::from_secs(5), || async {}))
347    ///     .collect();
348    /// let tasks = TimerService::create_batch(callbacks);
349    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
350    /// service.register_batch(tasks).await;
351    /// 
352    /// // 批量推迟
353    /// let updates: Vec<_> = task_ids
354    ///     .into_iter()
355    ///     .map(|id| (id, Duration::from_secs(10)))
356    ///     .collect();
357    /// let postponed = service.postpone_batch(&updates);
358    /// println!("成功推迟 {} 个任务", postponed);
359    /// # }
360    /// ```
361    #[inline]
362    pub fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
363        if updates.is_empty() {
364            return 0;
365        }
366
367        let updates_vec: Vec<_> = updates
368            .iter()
369            .map(|(task_id, delay)| (*task_id, *delay, None))
370            .collect();
371        let mut wheel = self.wheel.lock();
372        wheel.postpone_batch(updates_vec)
373    }
374
375    /// 批量推迟任务(替换回调)
376    ///
377    /// # 参数
378    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
379    ///
380    /// # 返回
381    /// 成功推迟的任务数量
382    ///
383    /// # 示例
384    /// ```no_run
385    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
386    /// # use std::time::Duration;
387    /// # #[tokio::main]
388    /// # async fn main() {
389    /// let timer = TimerWheel::with_defaults();
390    /// let service = timer.create_service();
391    /// 
392    /// let callbacks: Vec<_> = (0..3)
393    ///     .map(|_| (Duration::from_secs(5), || async {}))
394    ///     .collect();
395    /// let tasks = TimerService::create_batch(callbacks);
396    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
397    /// service.register_batch(tasks).await;
398    /// 
399    /// // 批量推迟并替换回调
400    /// let updates: Vec<_> = task_ids
401    ///     .into_iter()
402    ///     .enumerate()
403    ///     .map(|(i, id)| {
404    ///         (id, Duration::from_secs(10), move || async move {
405    ///             println!("New callback {}", i);
406    ///         })
407    ///     })
408    ///     .collect();
409    /// let postponed = service.postpone_batch_with_callbacks(updates);
410    /// println!("成功推迟 {} 个任务", postponed);
411    /// # }
412    /// ```
413    #[inline]
414    pub fn postpone_batch_with_callbacks<C>(
415        &self,
416        updates: Vec<(TaskId, Duration, C)>,
417    ) -> usize
418    where
419        C: TimerCallback,
420    {
421        if updates.is_empty() {
422            return 0;
423        }
424
425        use std::sync::Arc;
426        let updates_vec: Vec<_> = updates
427            .into_iter()
428            .map(|(task_id, delay, callback)| {
429                let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
430                (task_id, delay, Some(callback_wrapper))
431            })
432            .collect();
433        let mut wheel = self.wheel.lock();
434        wheel.postpone_batch(updates_vec)
435    }
436
437    /// 创建定时器任务(静态方法,申请阶段)
438    /// 
439    /// # 参数
440    /// - `delay`: 延迟时间
441    /// - `callback`: 实现了 TimerCallback trait 的回调对象
442    /// 
443    /// # 返回
444    /// 返回 TimerTask,需要通过 `register()` 注册
445    /// 
446    /// # 示例
447    /// ```no_run
448    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
449    /// # use std::time::Duration;
450    /// # #[tokio::main]
451    /// # async fn main() {
452    /// let timer = TimerWheel::with_defaults();
453    /// let service = timer.create_service();
454    /// 
455    /// // 步骤 1: 创建任务
456    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
457    ///     println!("Timer fired!");
458    /// });
459    /// 
460    /// let task_id = task.get_id();
461    /// println!("Created task: {:?}", task_id);
462    /// 
463    /// // 步骤 2: 注册任务
464    /// service.register(task).await;
465    /// # }
466    /// ```
467    pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
468    where
469        C: TimerCallback,
470    {
471        crate::timer::TimerWheel::create_task(delay, callback)
472    }
473    
474    /// 批量创建定时器任务(静态方法,申请阶段)
475    /// 
476    /// # 参数
477    /// - `callbacks`: (延迟时间, 回调) 的元组列表
478    /// 
479    /// # 返回
480    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
481    /// 
482    /// # 示例
483    /// ```no_run
484    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
485    /// # use std::time::Duration;
486    /// # #[tokio::main]
487    /// # async fn main() {
488    /// let timer = TimerWheel::with_defaults();
489    /// let service = timer.create_service();
490    /// 
491    /// // 步骤 1: 批量创建任务
492    /// let callbacks: Vec<_> = (0..3)
493    ///     .map(|i| (Duration::from_millis(100 * (i + 1)), move || async move {
494    ///         println!("Timer {} fired!", i);
495    ///     }))
496    ///     .collect();
497    /// 
498    /// let tasks = TimerService::create_batch(callbacks);
499    /// println!("Created {} tasks", tasks.len());
500    /// 
501    /// // 步骤 2: 批量注册任务
502    /// service.register_batch(tasks).await;
503    /// # }
504    /// ```
505    pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
506    where
507        C: TimerCallback,
508    {
509        crate::timer::TimerWheel::create_batch(callbacks)
510    }
511    
512    /// 注册定时器任务到服务(注册阶段)
513    /// 
514    /// # 参数
515    /// - `task`: 通过 `create_task()` 创建的任务
516    /// 
517    /// # 示例
518    /// ```no_run
519    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
520    /// # use std::time::Duration;
521    /// # #[tokio::main]
522    /// # async fn main() {
523    /// let timer = TimerWheel::with_defaults();
524    /// let service = timer.create_service();
525    /// 
526    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
527    ///     println!("Timer fired!");
528    /// });
529    /// let task_id = task.get_id();
530    /// 
531    /// service.register(task).await;
532    /// # }
533    /// ```
534    #[inline]
535    pub async fn register(&self, task: crate::task::TimerTask) {
536        let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
537        let notifier = crate::task::CompletionNotifier(completion_tx);
538        
539        let delay = task.delay;
540        let task_id = task.id;
541        
542        // 单次加锁完成所有操作
543        {
544            let mut wheel_guard = self.wheel.lock();
545            wheel_guard.insert(delay, task, notifier);
546        }
547        
548        // 创建句柄并添加到服务管理
549        let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
550        self.add_timer_handle(handle).await;
551    }
552    
553    /// 批量注册定时器任务到服务(注册阶段)
554    /// 
555    /// # 参数
556    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
557    /// 
558    /// # 示例
559    /// ```no_run
560    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
561    /// # use std::time::Duration;
562    /// # #[tokio::main]
563    /// # async fn main() {
564    /// let timer = TimerWheel::with_defaults();
565    /// let service = timer.create_service();
566    /// 
567    /// let callbacks: Vec<_> = (0..3)
568    ///     .map(|_| (Duration::from_secs(1), || async {}))
569    ///     .collect();
570    /// let tasks = TimerService::create_batch(callbacks);
571    /// 
572    /// service.register_batch(tasks).await;
573    /// # }
574    /// ```
575    #[inline]
576    pub async fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) {
577        let task_count = tasks.len();
578        let mut completion_rxs = Vec::with_capacity(task_count);
579        let mut task_ids = Vec::with_capacity(task_count);
580        let mut prepared_tasks = Vec::with_capacity(task_count);
581        
582        // 步骤1: 准备所有 channels 和 notifiers(无锁)
583        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
584        for task in tasks {
585            let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
586            let notifier = crate::task::CompletionNotifier(completion_tx);
587            
588            task_ids.push(task.id);
589            completion_rxs.push(completion_rx);
590            prepared_tasks.push((task.delay, task, notifier));
591        }
592        
593        // 步骤2: 单次加锁,批量插入
594        {
595            let mut wheel_guard = self.wheel.lock();
596            wheel_guard.insert_batch(prepared_tasks);
597        }
598        
599        // 创建批量句柄并添加到服务管理
600        let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
601        self.add_batch_handle(batch_handle).await;
602    }
603
604    /// 优雅关闭 TimerService
605    ///
606    /// # 示例
607    /// ```no_run
608    /// # use kestrel_protocol_timer::TimerWheel;
609    /// # #[tokio::main]
610    /// # async fn main() {
611    /// let timer = TimerWheel::with_defaults();
612    /// let mut service = timer.create_service();
613    /// 
614    /// // 使用 service...
615    /// 
616    /// service.shutdown().await;
617    /// # }
618    /// ```
619    pub async fn shutdown(mut self) {
620        let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
621        if let Some(handle) = self.actor_handle.take() {
622            let _ = handle.await;
623        }
624    }
625}
626
627
628impl Drop for TimerService {
629    fn drop(&mut self) {
630        if let Some(handle) = self.actor_handle.take() {
631            handle.abort();
632        }
633    }
634}
635
636/// ServiceActor - 内部 Actor 实现
637struct ServiceActor {
638    /// 命令接收端
639    command_rx: mpsc::Receiver<ServiceCommand>,
640    /// 超时发送端
641    timeout_tx: mpsc::Sender<TaskId>,
642}
643
644impl ServiceActor {
645    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
646        Self {
647            command_rx,
648            timeout_tx,
649        }
650    }
651
652    async fn run(mut self) {
653        // 使用 FuturesUnordered 来监听所有的 completion_rxs
654        // 每个 future 返回 (TaskId, Result)
655        let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
656
657        loop {
658            tokio::select! {
659                // 监听超时事件
660                Some((task_id, result)) = futures.next() => {
661                    // 检查完成原因,只转发超时(Expired)事件,不转发取消(Cancelled)事件
662                    if let Ok(TaskCompletionReason::Expired) = result {
663                        let _ = self.timeout_tx.send(task_id).await;
664                    }
665                    // 任务会自动从 FuturesUnordered 中移除
666                }
667                
668                // 监听命令
669                Some(cmd) = self.command_rx.recv() => {
670                    match cmd {
671                        ServiceCommand::AddBatchHandle(batch) => {
672                            let BatchHandle {
673                                task_ids,
674                                completion_rxs,
675                                ..
676                            } = batch;
677                            
678                            // 将所有任务添加到 futures 中
679                            for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
680                                let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
681                                    (task_id, rx.await)
682                                });
683                                futures.push(future);
684                            }
685                        }
686                        ServiceCommand::AddTimerHandle(handle) => {
687                            let TimerHandle{
688                                task_id,
689                                completion_rx,
690                                ..
691                            } = handle;
692                            
693                            // 添加到 futures 中
694                            let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
695                                (task_id, completion_rx.0.await)
696                            });
697                            futures.push(future);
698                        }
699                        ServiceCommand::Shutdown => {
700                            break;
701                        }
702                    }
703                }
704                
705                // 如果没有任何 future 且命令通道已关闭,退出循环
706                else => {
707                    break;
708                }
709            }
710        }
711    }
712}
713
714#[cfg(test)]
715mod tests {
716    use super::*;
717    use crate::TimerWheel;
718    use std::sync::atomic::{AtomicU32, Ordering};
719    use std::sync::Arc;
720    use std::time::Duration;
721
722    #[tokio::test]
723    async fn test_service_creation() {
724        let timer = TimerWheel::with_defaults();
725        let _service = timer.create_service();
726    }
727
728
729    #[tokio::test]
730    async fn test_add_timer_handle_and_receive_timeout() {
731        let timer = TimerWheel::with_defaults();
732        let mut service = timer.create_service();
733
734        // 创建单个定时器
735        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
736        let task_id = task.get_id();
737        let handle = timer.register(task);
738
739        // 添加到 service
740        service.add_timer_handle(handle).await;
741
742        // 接收超时通知
743        let mut rx = service.take_receiver().unwrap();
744        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
745            .await
746            .expect("Should receive timeout notification")
747            .expect("Should receive Some value");
748
749        assert_eq!(received_task_id, task_id);
750    }
751
752
753    #[tokio::test]
754    async fn test_shutdown() {
755        let timer = TimerWheel::with_defaults();
756        let service = timer.create_service();
757
758        // 添加一些定时器
759        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
760        let task2 = TimerService::create_task(Duration::from_secs(10), || async {});
761        service.register(task1).await;
762        service.register(task2).await;
763
764        // 立即关闭(不等待定时器触发)
765        service.shutdown().await;
766    }
767
768
769
770    #[tokio::test]
771    async fn test_cancel_task() {
772        let timer = TimerWheel::with_defaults();
773        let service = timer.create_service();
774
775        // 添加一个长时间的定时器
776        let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
777        let task_id = task.get_id();
778        let handle = timer.register(task);
779        
780        service.add_timer_handle(handle).await;
781
782        // 取消任务
783        let cancelled = service.cancel_task(task_id);
784        assert!(cancelled, "Task should be cancelled successfully");
785
786        // 尝试再次取消同一个任务,应该返回 false
787        let cancelled_again = service.cancel_task(task_id);
788        assert!(!cancelled_again, "Task should not exist anymore");
789    }
790
791    #[tokio::test]
792    async fn test_cancel_nonexistent_task() {
793        let timer = TimerWheel::with_defaults();
794        let service = timer.create_service();
795
796        // 添加一个定时器以初始化 service
797        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
798        let handle = timer.register(task);
799        service.add_timer_handle(handle).await;
800
801        // 尝试取消一个不存在的任务(创建一个不会实际注册的任务ID)
802        let fake_task = TimerWheel::create_task(Duration::from_millis(50), || async {});
803        let fake_task_id = fake_task.get_id();
804        // 不注册 fake_task
805        let cancelled = service.cancel_task(fake_task_id);
806        assert!(!cancelled, "Nonexistent task should not be cancelled");
807    }
808
809
810    #[tokio::test]
811    async fn test_task_timeout_cleans_up_task_sender() {
812        let timer = TimerWheel::with_defaults();
813        let mut service = timer.create_service();
814
815        // 添加一个短时间的定时器
816        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
817        let task_id = task.get_id();
818        let handle = timer.register(task);
819        
820        service.add_timer_handle(handle).await;
821
822        // 等待任务超时
823        let mut rx = service.take_receiver().unwrap();
824        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
825            .await
826            .expect("Should receive timeout notification")
827            .expect("Should receive Some value");
828        
829        assert_eq!(received_task_id, task_id);
830
831        // 等待一下确保内部清理完成
832        tokio::time::sleep(Duration::from_millis(10)).await;
833
834        // 尝试取消已经超时的任务,应该返回 false
835        let cancelled = service.cancel_task(task_id);
836        assert!(!cancelled, "Timed out task should not exist anymore");
837    }
838
839    #[tokio::test]
840    async fn test_cancel_task_spawns_background_task() {
841        let timer = TimerWheel::with_defaults();
842        let service = timer.create_service();
843        let counter = Arc::new(AtomicU32::new(0));
844
845        // 创建一个定时器
846        let counter_clone = Arc::clone(&counter);
847        let task = TimerWheel::create_task(
848            Duration::from_secs(10),
849            move || {
850                let counter = Arc::clone(&counter_clone);
851                async move {
852                    counter.fetch_add(1, Ordering::SeqCst);
853                }
854            },
855        );
856        let task_id = task.get_id();
857        let handle = timer.register(task);
858        
859        service.add_timer_handle(handle).await;
860
861        // 使用 cancel_task(会等待结果,但在后台协程中处理)
862        let cancelled = service.cancel_task(task_id);
863        assert!(cancelled, "Task should be cancelled successfully");
864
865        // 等待足够长时间确保回调不会被执行
866        tokio::time::sleep(Duration::from_millis(100)).await;
867        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
868
869        // 验证任务已从 active_tasks 中移除
870        let cancelled_again = service.cancel_task(task_id);
871        assert!(!cancelled_again, "Task should have been removed from active_tasks");
872    }
873
874    #[tokio::test]
875    async fn test_schedule_once_direct() {
876        let timer = TimerWheel::with_defaults();
877        let mut service = timer.create_service();
878        let counter = Arc::new(AtomicU32::new(0));
879
880        // 直接通过 service 调度定时器
881        let counter_clone = Arc::clone(&counter);
882        let task = TimerService::create_task(
883            Duration::from_millis(50),
884            move || {
885                let counter = Arc::clone(&counter_clone);
886                async move {
887                    counter.fetch_add(1, Ordering::SeqCst);
888                }
889            },
890        );
891        let task_id = task.get_id();
892        service.register(task).await;
893
894        // 等待定时器触发
895        let mut rx = service.take_receiver().unwrap();
896        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
897            .await
898            .expect("Should receive timeout notification")
899            .expect("Should receive Some value");
900
901        assert_eq!(received_task_id, task_id);
902        
903        // 等待回调执行
904        tokio::time::sleep(Duration::from_millis(50)).await;
905        assert_eq!(counter.load(Ordering::SeqCst), 1);
906    }
907
908    #[tokio::test]
909    async fn test_schedule_once_batch_direct() {
910        let timer = TimerWheel::with_defaults();
911        let mut service = timer.create_service();
912        let counter = Arc::new(AtomicU32::new(0));
913
914        // 直接通过 service 批量调度定时器
915        let callbacks: Vec<_> = (0..3)
916            .map(|_| {
917                let counter = Arc::clone(&counter);
918                (Duration::from_millis(50), move || {
919                    let counter = Arc::clone(&counter);
920                    async move {
921                        counter.fetch_add(1, Ordering::SeqCst);
922                    }
923                })
924            })
925            .collect();
926
927        let tasks = TimerService::create_batch(callbacks);
928        assert_eq!(tasks.len(), 3);
929        service.register_batch(tasks).await;
930
931        // 接收所有超时通知
932        let mut received_count = 0;
933        let mut rx = service.take_receiver().unwrap();
934        
935        while received_count < 3 {
936            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
937                Ok(Some(_task_id)) => {
938                    received_count += 1;
939                }
940                Ok(None) => break,
941                Err(_) => break,
942            }
943        }
944
945        assert_eq!(received_count, 3);
946        
947        // 等待回调执行
948        tokio::time::sleep(Duration::from_millis(50)).await;
949        assert_eq!(counter.load(Ordering::SeqCst), 3);
950    }
951
952    #[tokio::test]
953    async fn test_schedule_once_notify_direct() {
954        let timer = TimerWheel::with_defaults();
955        let mut service = timer.create_service();
956
957        // 直接通过 service 调度仅通知的定时器(无回调)
958        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
959        let task_id = task.get_id();
960        service.register(task).await;
961
962        // 接收超时通知
963        let mut rx = service.take_receiver().unwrap();
964        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
965            .await
966            .expect("Should receive timeout notification")
967            .expect("Should receive Some value");
968
969        assert_eq!(received_task_id, task_id);
970    }
971
972    #[tokio::test]
973    async fn test_schedule_and_cancel_direct() {
974        let timer = TimerWheel::with_defaults();
975        let service = timer.create_service();
976        let counter = Arc::new(AtomicU32::new(0));
977
978        // 直接调度定时器
979        let counter_clone = Arc::clone(&counter);
980        let task = TimerService::create_task(
981            Duration::from_secs(10),
982            move || {
983                let counter = Arc::clone(&counter_clone);
984                async move {
985                    counter.fetch_add(1, Ordering::SeqCst);
986                }
987            },
988        );
989        let task_id = task.get_id();
990        service.register(task).await;
991
992        // 立即取消
993        let cancelled = service.cancel_task(task_id);
994        assert!(cancelled, "Task should be cancelled successfully");
995
996        // 等待确保回调不会执行
997        tokio::time::sleep(Duration::from_millis(100)).await;
998        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
999    }
1000
1001    #[tokio::test]
1002    async fn test_cancel_batch_direct() {
1003        let timer = TimerWheel::with_defaults();
1004        let service = timer.create_service();
1005        let counter = Arc::new(AtomicU32::new(0));
1006
1007        // 批量调度定时器
1008        let callbacks: Vec<_> = (0..10)
1009            .map(|_| {
1010                let counter = Arc::clone(&counter);
1011                (Duration::from_secs(10), move || {
1012                    let counter = Arc::clone(&counter);
1013                    async move {
1014                        counter.fetch_add(1, Ordering::SeqCst);
1015                    }
1016                })
1017            })
1018            .collect();
1019
1020        let tasks = TimerService::create_batch(callbacks);
1021        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1022        assert_eq!(task_ids.len(), 10);
1023        service.register_batch(tasks).await;
1024
1025        // 批量取消所有任务
1026        let cancelled = service.cancel_batch(&task_ids);
1027        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1028
1029        // 等待确保回调不会执行
1030        tokio::time::sleep(Duration::from_millis(100)).await;
1031        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1032    }
1033
1034    #[tokio::test]
1035    async fn test_cancel_batch_partial() {
1036        let timer = TimerWheel::with_defaults();
1037        let service = timer.create_service();
1038        let counter = Arc::new(AtomicU32::new(0));
1039
1040        // 批量调度定时器
1041        let callbacks: Vec<_> = (0..10)
1042            .map(|_| {
1043                let counter = Arc::clone(&counter);
1044                (Duration::from_secs(10), move || {
1045                    let counter = Arc::clone(&counter);
1046                    async move {
1047                        counter.fetch_add(1, Ordering::SeqCst);
1048                    }
1049                })
1050            })
1051            .collect();
1052
1053        let tasks = TimerService::create_batch(callbacks);
1054        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1055        service.register_batch(tasks).await;
1056
1057        // 只取消前5个任务
1058        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1059        let cancelled = service.cancel_batch(&to_cancel);
1060        assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1061
1062        // 等待确保前5个回调不会执行
1063        tokio::time::sleep(Duration::from_millis(100)).await;
1064        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1065    }
1066
1067    #[tokio::test]
1068    async fn test_cancel_batch_empty() {
1069        let timer = TimerWheel::with_defaults();
1070        let service = timer.create_service();
1071
1072        // 取消空列表
1073        let empty: Vec<TaskId> = vec![];
1074        let cancelled = service.cancel_batch(&empty);
1075        assert_eq!(cancelled, 0, "No tasks should be cancelled");
1076    }
1077
1078    #[tokio::test]
1079    async fn test_postpone_task() {
1080        let timer = TimerWheel::with_defaults();
1081        let mut service = timer.create_service();
1082        let counter = Arc::new(AtomicU32::new(0));
1083
1084        // 注册一个任务,延迟 50ms
1085        let counter_clone = Arc::clone(&counter);
1086        let task = TimerService::create_task(
1087            Duration::from_millis(50),
1088            move || {
1089                let counter = Arc::clone(&counter_clone);
1090                async move {
1091                    counter.fetch_add(1, Ordering::SeqCst);
1092                }
1093            },
1094        );
1095        let task_id = task.get_id();
1096        service.register(task).await;
1097
1098        // 推迟任务到 150ms
1099        let postponed = service.postpone_task(task_id, Duration::from_millis(150));
1100        assert!(postponed, "Task should be postponed successfully");
1101
1102        // 等待原定时间 50ms,任务不应该触发
1103        tokio::time::sleep(Duration::from_millis(70)).await;
1104        assert_eq!(counter.load(Ordering::SeqCst), 0);
1105
1106        // 接收超时通知(从推迟开始算,还需要等待约 150ms)
1107        let mut rx = service.take_receiver().unwrap();
1108        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1109            .await
1110            .expect("Should receive timeout notification")
1111            .expect("Should receive Some value");
1112
1113        assert_eq!(received_task_id, task_id);
1114        
1115        // 等待回调执行
1116        tokio::time::sleep(Duration::from_millis(20)).await;
1117        assert_eq!(counter.load(Ordering::SeqCst), 1);
1118    }
1119
1120    #[tokio::test]
1121    async fn test_postpone_task_with_callback() {
1122        let timer = TimerWheel::with_defaults();
1123        let mut service = timer.create_service();
1124        let counter = Arc::new(AtomicU32::new(0));
1125
1126        // 注册一个任务,原始回调增加 1
1127        let counter_clone1 = Arc::clone(&counter);
1128        let task = TimerService::create_task(
1129            Duration::from_millis(50),
1130            move || {
1131                let counter = Arc::clone(&counter_clone1);
1132                async move {
1133                    counter.fetch_add(1, Ordering::SeqCst);
1134                }
1135            },
1136        );
1137        let task_id = task.get_id();
1138        service.register(task).await;
1139
1140        // 推迟任务并替换回调,新回调增加 10
1141        let counter_clone2 = Arc::clone(&counter);
1142        let postponed = service.postpone_task_with_callback(
1143            task_id,
1144            Duration::from_millis(100),
1145            move || {
1146                let counter = Arc::clone(&counter_clone2);
1147                async move {
1148                    counter.fetch_add(10, Ordering::SeqCst);
1149                }
1150            }
1151        );
1152        assert!(postponed, "Task should be postponed successfully");
1153
1154        // 接收超时通知(推迟后需要等待100ms,加上余量)
1155        let mut rx = service.take_receiver().unwrap();
1156        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1157            .await
1158            .expect("Should receive timeout notification")
1159            .expect("Should receive Some value");
1160
1161        assert_eq!(received_task_id, task_id);
1162        
1163        // 等待回调执行
1164        tokio::time::sleep(Duration::from_millis(20)).await;
1165        
1166        // 验证新回调被执行(增加了 10 而不是 1)
1167        assert_eq!(counter.load(Ordering::SeqCst), 10);
1168    }
1169
1170    #[tokio::test]
1171    async fn test_postpone_nonexistent_task() {
1172        let timer = TimerWheel::with_defaults();
1173        let service = timer.create_service();
1174
1175        // 尝试推迟一个不存在的任务
1176        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
1177        let fake_task_id = fake_task.get_id();
1178        // 不注册这个任务
1179        
1180        let postponed = service.postpone_task(fake_task_id, Duration::from_millis(100));
1181        assert!(!postponed, "Nonexistent task should not be postponed");
1182    }
1183
1184    #[tokio::test]
1185    async fn test_postpone_batch() {
1186        let timer = TimerWheel::with_defaults();
1187        let mut service = timer.create_service();
1188        let counter = Arc::new(AtomicU32::new(0));
1189
1190        // 注册 3 个任务
1191        let mut task_ids = Vec::new();
1192        for _ in 0..3 {
1193            let counter_clone = Arc::clone(&counter);
1194            let task = TimerService::create_task(
1195                Duration::from_millis(50),
1196                move || {
1197                    let counter = Arc::clone(&counter_clone);
1198                    async move {
1199                        counter.fetch_add(1, Ordering::SeqCst);
1200                    }
1201                },
1202            );
1203            task_ids.push((task.get_id(), Duration::from_millis(150)));
1204            service.register(task).await;
1205        }
1206
1207        // 批量推迟
1208        let postponed = service.postpone_batch(&task_ids);
1209        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1210
1211        // 等待原定时间 50ms,任务不应该触发
1212        tokio::time::sleep(Duration::from_millis(70)).await;
1213        assert_eq!(counter.load(Ordering::SeqCst), 0);
1214
1215        // 接收所有超时通知
1216        let mut received_count = 0;
1217        let mut rx = service.take_receiver().unwrap();
1218        
1219        while received_count < 3 {
1220            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1221                Ok(Some(_task_id)) => {
1222                    received_count += 1;
1223                }
1224                Ok(None) => break,
1225                Err(_) => break,
1226            }
1227        }
1228
1229        assert_eq!(received_count, 3);
1230        
1231        // 等待回调执行
1232        tokio::time::sleep(Duration::from_millis(20)).await;
1233        assert_eq!(counter.load(Ordering::SeqCst), 3);
1234    }
1235
1236    #[tokio::test]
1237    async fn test_postpone_batch_with_callbacks() {
1238        let timer = TimerWheel::with_defaults();
1239        let mut service = timer.create_service();
1240        let counter = Arc::new(AtomicU32::new(0));
1241
1242        // 注册 3 个任务
1243        let mut task_ids = Vec::new();
1244        for _ in 0..3 {
1245            let task = TimerService::create_task(
1246                Duration::from_millis(50),
1247                || async {},
1248            );
1249            task_ids.push(task.get_id());
1250            service.register(task).await;
1251        }
1252
1253        // 批量推迟并替换回调
1254        let updates: Vec<_> = task_ids
1255            .into_iter()
1256            .map(|id| {
1257                let counter_clone = Arc::clone(&counter);
1258                (id, Duration::from_millis(150), move || {
1259                    let counter = Arc::clone(&counter_clone);
1260                    async move {
1261                        counter.fetch_add(1, Ordering::SeqCst);
1262                    }
1263                })
1264            })
1265            .collect();
1266
1267        let postponed = service.postpone_batch_with_callbacks(updates);
1268        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1269
1270        // 等待原定时间 50ms,任务不应该触发
1271        tokio::time::sleep(Duration::from_millis(70)).await;
1272        assert_eq!(counter.load(Ordering::SeqCst), 0);
1273
1274        // 接收所有超时通知
1275        let mut received_count = 0;
1276        let mut rx = service.take_receiver().unwrap();
1277        
1278        while received_count < 3 {
1279            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1280                Ok(Some(_task_id)) => {
1281                    received_count += 1;
1282                }
1283                Ok(None) => break,
1284                Err(_) => break,
1285            }
1286        }
1287
1288        assert_eq!(received_count, 3);
1289        
1290        // 等待回调执行
1291        tokio::time::sleep(Duration::from_millis(20)).await;
1292        assert_eq!(counter.load(Ordering::SeqCst), 3);
1293    }
1294
1295    #[tokio::test]
1296    async fn test_postpone_batch_empty() {
1297        let timer = TimerWheel::with_defaults();
1298        let service = timer.create_service();
1299
1300        // 推迟空列表
1301        let empty: Vec<(TaskId, Duration)> = vec![];
1302        let postponed = service.postpone_batch(&empty);
1303        assert_eq!(postponed, 0, "No tasks should be postponed");
1304    }
1305
1306    #[tokio::test]
1307    async fn test_postpone_keeps_timeout_notification_valid() {
1308        let timer = TimerWheel::with_defaults();
1309        let mut service = timer.create_service();
1310        let counter = Arc::new(AtomicU32::new(0));
1311
1312        // 注册一个任务
1313        let counter_clone = Arc::clone(&counter);
1314        let task = TimerService::create_task(
1315            Duration::from_millis(50),
1316            move || {
1317                let counter = Arc::clone(&counter_clone);
1318                async move {
1319                    counter.fetch_add(1, Ordering::SeqCst);
1320                }
1321            },
1322        );
1323        let task_id = task.get_id();
1324        service.register(task).await;
1325
1326        // 推迟任务
1327        service.postpone_task(task_id, Duration::from_millis(100));
1328
1329        // 验证超时通知仍然有效(推迟后需要等待100ms,加上余量)
1330        let mut rx = service.take_receiver().unwrap();
1331        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1332            .await
1333            .expect("Should receive timeout notification")
1334            .expect("Should receive Some value");
1335
1336        assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1337        
1338        // 等待回调执行
1339        tokio::time::sleep(Duration::from_millis(20)).await;
1340        assert_eq!(counter.load(Ordering::SeqCst), 1);
1341    }
1342
1343    #[tokio::test]
1344    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1345        let timer = TimerWheel::with_defaults();
1346        let mut service = timer.create_service();
1347
1348        // 注册两个任务:一个会被取消,一个会正常到期
1349        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
1350        let task1_id = task1.get_id();
1351        service.register(task1).await;
1352
1353        let task2 = TimerService::create_task(Duration::from_millis(50), || async {});
1354        let task2_id = task2.get_id();
1355        service.register(task2).await;
1356
1357        // 取消第一个任务
1358        let cancelled = service.cancel_task(task1_id);
1359        assert!(cancelled, "Task should be cancelled");
1360
1361        // 等待第二个任务到期
1362        let mut rx = service.take_receiver().unwrap();
1363        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1364            .await
1365            .expect("Should receive timeout notification")
1366            .expect("Should receive Some value");
1367
1368        // 应该只收到第二个任务(到期的)的通知,不应该收到第一个任务(取消的)的通知
1369        assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");
1370
1371        // 验证没有其他通知(特别是被取消的任务不应该有通知)
1372        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1373        assert!(no_more.is_err(), "Should not receive any more notifications");
1374    }
1375}
1376