kestrel_protocol_timer/
service.rs

1use crate::config::ServiceConfig;
2use crate::task::{TaskId, TimerCallback};
3use crate::timer::{BatchHandle, TimerHandle};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
/// Commands accepted by the `TimerService` actor.
enum ServiceCommand {
    /// Hand a batch timer handle over to the actor.
    AddBatchHandle(BatchHandle),
    /// Hand a single timer handle over to the actor.
    AddTimerHandle(TimerHandle),
    /// Stop the actor loop.
    Shutdown,
}
22
/// TimerService - an actor-based timer service.
///
/// Manages multiple timer handles, listens for their timeout events and
/// forwards the corresponding `TaskId`s to the user over a single channel.
///
/// # Features
/// - Listens for the timeout event of every timer handle added to it
/// - A task is removed from the internal bookkeeping once its timeout is observed
/// - Forwards the `TaskId` of timed-out tasks to one unified channel for the user
/// - Supports adding `BatchHandle` and `TimerHandle` dynamically
///
/// # Examples
/// ```no_run
/// use kestrel_protocol_timer::{TimerWheel, TimerService};
/// use std::time::Duration;
///
/// #[tokio::main]
/// async fn main() {
///     let timer = TimerWheel::with_defaults();
///     let mut service = timer.create_service();
///
///     // Schedule a batch of timers through the service using the two-step API
///     let callbacks: Vec<_> = (0..5)
///         .map(|_| (Duration::from_millis(100), || async {}))
///         .collect();
///     let tasks = TimerService::create_batch(callbacks);
///     service.register_batch(tasks).await;
///
///     // Receive timeout notifications
///     let mut rx = service.take_receiver().unwrap();
///     while let Some(task_id) = rx.recv().await {
///         println!("Task {:?} completed", task_id);
///     }
/// }
/// ```
pub struct TimerService {
    /// Sending half of the command channel to the actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiving half of the timeout channel; `None` once it has been taken
    /// via `take_receiver()`.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor task; `None` once it has been taken
    /// by `shutdown()` or `Drop`.
    actor_handle: Option<JoinHandle<()>>,
    /// Shared timing wheel, used to schedule, cancel and postpone timers directly.
    wheel: Arc<Mutex<Wheel>>,
}
67
68impl TimerService {
69    /// 创建新的 TimerService
70    ///
71    /// # 参数
72    /// - `wheel`: 时间轮引用
73    /// - `config`: 服务配置
74    ///
75    /// # 注意
76    /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
77    ///
78    /// # 示例
79    /// ```no_run
80    /// use kestrel_protocol_timer::TimerWheel;
81    ///
82    /// #[tokio::main]
83    /// async fn main() {
84    ///     let timer = TimerWheel::with_defaults();
85    ///     let mut service = timer.create_service();
86    /// }
87    /// ```
88    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
89        let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
90        let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
91
92        let actor = ServiceActor::new(command_rx, timeout_tx);
93        let actor_handle = tokio::spawn(async move {
94            actor.run().await;
95        });
96
97        Self {
98            command_tx,
99            timeout_rx: Some(timeout_rx),
100            actor_handle: Some(actor_handle),
101            wheel,
102        }
103    }
104
105    /// 添加批量定时器句柄(内部方法)
106    async fn add_batch_handle(&self, batch: BatchHandle) {
107        let _ = self.command_tx
108            .send(ServiceCommand::AddBatchHandle(batch))
109            .await;
110    }
111
112    /// 添加单个定时器句柄(内部方法)
113    async fn add_timer_handle(&self, handle: TimerHandle) {
114        let _ = self.command_tx
115            .send(ServiceCommand::AddTimerHandle(handle))
116            .await;
117    }
118
    /// Takes the timeout receiver, transferring ownership to the caller.
    ///
    /// # Returns
    /// The receiver of timeout notifications, or `None` if it has already
    /// been taken.
    ///
    /// # Note
    /// Only the first call yields the receiver, because ownership of it is
    /// moved out of the service.
    ///
    /// # Examples
    /// ```no_run
    /// # use kestrel_protocol_timer::TimerWheel;
    /// # use std::time::Duration;
    /// # #[tokio::main]
    /// # async fn main() {
    /// let timer = TimerWheel::with_defaults();
    /// let mut service = timer.create_service();
    ///
    /// let mut rx = service.take_receiver().unwrap();
    /// while let Some(task_id) = rx.recv().await {
    ///     println!("Task {:?} timed out", task_id);
    /// }
    /// # }
    /// ```
    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
        self.timeout_rx.take()
    }
145
146    /// 取消指定的任务
147    ///
148    /// # 参数
149    /// - `task_id`: 要取消的任务 ID
150    ///
151    /// # 返回
152    /// - `Ok(true)`: 任务存在且成功取消
153    /// - `Ok(false)`: 任务不存在或取消失败
154    /// - `Err(String)`: 发送命令失败
155    ///
156    /// # 性能说明
157    /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
158    ///
159    /// # 示例
160    /// ```no_run
161    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
162    /// # use std::time::Duration;
163    /// # #[tokio::main]
164    /// # async fn main() {
165    /// let timer = TimerWheel::with_defaults();
166    /// let service = timer.create_service();
167    /// 
168    /// // 使用两步式 API 调度定时器
169    /// let task = TimerService::create_task(Duration::from_secs(10), || async {});
170    /// let task_id = task.get_id();
171    /// service.register(task).await;
172    /// 
173    /// // 取消任务
174    /// let cancelled = service.cancel_task(task_id).await;
175    /// println!("Task cancelled: {}", cancelled);
176    /// # }
177    /// ```
178    #[inline]
179    pub async fn cancel_task(&self, task_id: TaskId) -> bool {
180        // 优化:直接取消任务,无需通知 Actor
181        // FuturesUnordered 会在任务被取消时自动清理
182        let mut wheel = self.wheel.lock();
183        wheel.cancel(task_id)
184    }
185
186    /// 批量取消任务
187    ///
188    /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
189    ///
190    /// # 参数
191    /// - `task_ids`: 要取消的任务 ID 列表
192    ///
193    /// # 返回
194    /// 成功取消的任务数量
195    ///
196    /// # 示例
197    /// ```no_run
198    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
199    /// # use std::time::Duration;
200    /// # #[tokio::main]
201    /// # async fn main() {
202    /// let timer = TimerWheel::with_defaults();
203    /// let service = timer.create_service();
204    /// 
205    /// let callbacks: Vec<_> = (0..10)
206    ///     .map(|_| (Duration::from_secs(10), || async {}))
207    ///     .collect();
208    /// let tasks = TimerService::create_batch(callbacks);
209    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
210    /// service.register_batch(tasks).await;
211    /// 
212    /// // 批量取消
213    /// let cancelled = service.cancel_batch(&task_ids).await;
214    /// println!("成功取消 {} 个任务", cancelled);
215    /// # }
216    /// ```
217    #[inline]
218    pub async fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
219        if task_ids.is_empty() {
220            return 0;
221        }
222
223        // 优化:直接使用底层的批量取消,无需通知 Actor
224        // FuturesUnordered 会在任务被取消时自动清理
225        let mut wheel = self.wheel.lock();
226        wheel.cancel_batch(task_ids)
227    }
228
229    /// 推迟任务(保持原回调)
230    ///
231    /// # 参数
232    /// - `task_id`: 要推迟的任务 ID
233    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
234    ///
235    /// # 返回
236    /// - `true`: 任务存在且成功推迟
237    /// - `false`: 任务不存在或推迟失败
238    ///
239    /// # 性能说明
240    /// 此方法使用直接推迟优化,不需要等待 Actor 处理,大幅降低延迟
241    ///
242    /// # 注意
243    /// - 推迟后任务 ID 保持不变
244    /// - 原有的超时通知仍然有效
245    /// - 保持原回调函数不变
246    ///
247    /// # 示例
248    /// ```no_run
249    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
250    /// # use std::time::Duration;
251    /// # #[tokio::main]
252    /// # async fn main() {
253    /// let timer = TimerWheel::with_defaults();
254    /// let service = timer.create_service();
255    /// 
256    /// let task = TimerService::create_task(Duration::from_secs(5), || async {});
257    /// let task_id = task.get_id();
258    /// service.register(task).await;
259    /// 
260    /// // 推迟到 10 秒后触发
261    /// let success = service.postpone_task(task_id, Duration::from_secs(10)).await;
262    /// println!("推迟成功: {}", success);
263    /// # }
264    /// ```
265    #[inline]
266    pub async fn postpone_task(&self, task_id: TaskId, new_delay: Duration) -> bool {
267        // 优化:直接推迟任务,无需通知 Actor
268        // FuturesUnordered 会继续监听原有的 completion_receiver
269        let mut wheel = self.wheel.lock();
270        wheel.postpone(task_id, new_delay, None)
271    }
272
273    /// 推迟任务(替换回调)
274    ///
275    /// # 参数
276    /// - `task_id`: 要推迟的任务 ID
277    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
278    /// - `callback`: 新的回调函数
279    ///
280    /// # 返回
281    /// - `true`: 任务存在且成功推迟
282    /// - `false`: 任务不存在或推迟失败
283    ///
284    /// # 注意
285    /// - 推迟后任务 ID 保持不变
286    /// - 原有的超时通知仍然有效
287    /// - 回调函数会被替换为新的回调
288    ///
289    /// # 示例
290    /// ```no_run
291    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
292    /// # use std::time::Duration;
293    /// # #[tokio::main]
294    /// # async fn main() {
295    /// let timer = TimerWheel::with_defaults();
296    /// let service = timer.create_service();
297    /// 
298    /// let task = TimerService::create_task(Duration::from_secs(5), || async {
299    ///     println!("Original callback");
300    /// });
301    /// let task_id = task.get_id();
302    /// service.register(task).await;
303    /// 
304    /// // 推迟并替换回调
305    /// let success = service.postpone_task_with_callback(
306    ///     task_id,
307    ///     Duration::from_secs(10),
308    ///     || async { println!("New callback!"); }
309    /// ).await;
310    /// println!("推迟成功: {}", success);
311    /// # }
312    /// ```
313    #[inline]
314    pub async fn postpone_task_with_callback<C>(
315        &self,
316        task_id: TaskId,
317        new_delay: Duration,
318        callback: C,
319    ) -> bool
320    where
321        C: TimerCallback,
322    {
323        use std::sync::Arc;
324        let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
325        let mut wheel = self.wheel.lock();
326        wheel.postpone(task_id, new_delay, Some(callback_wrapper))
327    }
328
329    /// 批量推迟任务(保持原回调)
330    ///
331    /// # 参数
332    /// - `updates`: (任务ID, 新延迟) 的元组列表
333    ///
334    /// # 返回
335    /// 成功推迟的任务数量
336    ///
337    /// # 示例
338    /// ```no_run
339    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
340    /// # use std::time::Duration;
341    /// # #[tokio::main]
342    /// # async fn main() {
343    /// let timer = TimerWheel::with_defaults();
344    /// let service = timer.create_service();
345    /// 
346    /// let callbacks: Vec<_> = (0..3)
347    ///     .map(|_| (Duration::from_secs(5), || async {}))
348    ///     .collect();
349    /// let tasks = TimerService::create_batch(callbacks);
350    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
351    /// service.register_batch(tasks).await;
352    /// 
353    /// // 批量推迟
354    /// let updates: Vec<_> = task_ids
355    ///     .into_iter()
356    ///     .map(|id| (id, Duration::from_secs(10)))
357    ///     .collect();
358    /// let postponed = service.postpone_batch(&updates).await;
359    /// println!("成功推迟 {} 个任务", postponed);
360    /// # }
361    /// ```
362    #[inline]
363    pub async fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
364        if updates.is_empty() {
365            return 0;
366        }
367
368        let updates_vec: Vec<_> = updates
369            .iter()
370            .map(|(task_id, delay)| (*task_id, *delay, None))
371            .collect();
372        let mut wheel = self.wheel.lock();
373        wheel.postpone_batch(updates_vec)
374    }
375
376    /// 批量推迟任务(替换回调)
377    ///
378    /// # 参数
379    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
380    ///
381    /// # 返回
382    /// 成功推迟的任务数量
383    ///
384    /// # 示例
385    /// ```no_run
386    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
387    /// # use std::time::Duration;
388    /// # #[tokio::main]
389    /// # async fn main() {
390    /// let timer = TimerWheel::with_defaults();
391    /// let service = timer.create_service();
392    /// 
393    /// let callbacks: Vec<_> = (0..3)
394    ///     .map(|_| (Duration::from_secs(5), || async {}))
395    ///     .collect();
396    /// let tasks = TimerService::create_batch(callbacks);
397    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
398    /// service.register_batch(tasks).await;
399    /// 
400    /// // 批量推迟并替换回调
401    /// let updates: Vec<_> = task_ids
402    ///     .into_iter()
403    ///     .enumerate()
404    ///     .map(|(i, id)| {
405    ///         (id, Duration::from_secs(10), move || async move {
406    ///             println!("New callback {}", i);
407    ///         })
408    ///     })
409    ///     .collect();
410    /// let postponed = service.postpone_batch_with_callbacks(updates).await;
411    /// println!("成功推迟 {} 个任务", postponed);
412    /// # }
413    /// ```
414    #[inline]
415    pub async fn postpone_batch_with_callbacks<C>(
416        &self,
417        updates: Vec<(TaskId, Duration, C)>,
418    ) -> usize
419    where
420        C: TimerCallback,
421    {
422        if updates.is_empty() {
423            return 0;
424        }
425
426        use std::sync::Arc;
427        let updates_vec: Vec<_> = updates
428            .into_iter()
429            .map(|(task_id, delay, callback)| {
430                let callback_wrapper = Arc::new(callback) as Arc<dyn TimerCallback>;
431                (task_id, delay, Some(callback_wrapper))
432            })
433            .collect();
434        let mut wheel = self.wheel.lock();
435        wheel.postpone_batch(updates_vec)
436    }
437
    /// Creates a timer task (associated function; the "create" step of the
    /// two-step API).
    ///
    /// Delegates to `TimerWheel::create_task`.
    ///
    /// # Parameters
    /// - `delay`: the delay
    /// - `callback`: a callback object implementing the `TimerCallback` trait
    ///
    /// # Returns
    /// A `TimerTask`, which must then be registered through `register()`.
    ///
    /// # Examples
    /// ```no_run
    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
    /// # use std::time::Duration;
    /// # #[tokio::main]
    /// # async fn main() {
    /// let timer = TimerWheel::with_defaults();
    /// let service = timer.create_service();
    ///
    /// // Step 1: create the task
    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
    ///     println!("Timer fired!");
    /// });
    ///
    /// let task_id = task.get_id();
    /// println!("Created task: {:?}", task_id);
    ///
    /// // Step 2: register the task
    /// service.register(task).await;
    /// # }
    /// ```
    pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
    where
        C: TimerCallback,
    {
        crate::timer::TimerWheel::create_task(delay, callback)
    }
474    
    /// Creates several timer tasks at once (associated function; the
    /// "create" step of the two-step API).
    ///
    /// Delegates to `TimerWheel::create_batch`.
    ///
    /// # Parameters
    /// - `callbacks`: list of `(delay, callback)` tuples
    ///
    /// # Returns
    /// A list of `TimerTask`s, which must then be registered through
    /// `register_batch()`.
    ///
    /// # Examples
    /// ```no_run
    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
    /// # use std::time::Duration;
    /// # #[tokio::main]
    /// # async fn main() {
    /// let timer = TimerWheel::with_defaults();
    /// let service = timer.create_service();
    ///
    /// // Step 1: create the tasks
    /// let callbacks: Vec<_> = (0..3)
    ///     .map(|i| (Duration::from_millis(100 * (i + 1)), move || async move {
    ///         println!("Timer {} fired!", i);
    ///     }))
    ///     .collect();
    ///
    /// let tasks = TimerService::create_batch(callbacks);
    /// println!("Created {} tasks", tasks.len());
    ///
    /// // Step 2: register the tasks
    /// service.register_batch(tasks).await;
    /// # }
    /// ```
    pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
    where
        C: TimerCallback,
    {
        crate::timer::TimerWheel::create_batch(callbacks)
    }
512    
513    /// 注册定时器任务到服务(注册阶段)
514    /// 
515    /// # 参数
516    /// - `task`: 通过 `create_task()` 创建的任务
517    /// 
518    /// # 示例
519    /// ```no_run
520    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
521    /// # use std::time::Duration;
522    /// # #[tokio::main]
523    /// # async fn main() {
524    /// let timer = TimerWheel::with_defaults();
525    /// let service = timer.create_service();
526    /// 
527    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
528    ///     println!("Timer fired!");
529    /// });
530    /// let task_id = task.get_id();
531    /// 
532    /// service.register(task).await;
533    /// # }
534    /// ```
535    #[inline]
536    pub async fn register(&self, task: crate::task::TimerTask) {
537        let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
538        let notifier = crate::task::CompletionNotifier(completion_tx);
539        
540        let delay = task.delay;
541        let task_id = task.id;
542        
543        // 单次加锁完成所有操作
544        {
545            let mut wheel_guard = self.wheel.lock();
546            wheel_guard.insert(delay, task, notifier);
547        }
548        
549        // 创建句柄并添加到服务管理
550        let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
551        self.add_timer_handle(handle).await;
552    }
553    
554    /// 批量注册定时器任务到服务(注册阶段)
555    /// 
556    /// # 参数
557    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
558    /// 
559    /// # 示例
560    /// ```no_run
561    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
562    /// # use std::time::Duration;
563    /// # #[tokio::main]
564    /// # async fn main() {
565    /// let timer = TimerWheel::with_defaults();
566    /// let service = timer.create_service();
567    /// 
568    /// let callbacks: Vec<_> = (0..3)
569    ///     .map(|_| (Duration::from_secs(1), || async {}))
570    ///     .collect();
571    /// let tasks = TimerService::create_batch(callbacks);
572    /// 
573    /// service.register_batch(tasks).await;
574    /// # }
575    /// ```
576    #[inline]
577    pub async fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) {
578        let task_count = tasks.len();
579        let mut completion_rxs = Vec::with_capacity(task_count);
580        let mut task_ids = Vec::with_capacity(task_count);
581        let mut prepared_tasks = Vec::with_capacity(task_count);
582        
583        // 步骤1: 准备所有 channels 和 notifiers(无锁)
584        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
585        for task in tasks {
586            let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
587            let notifier = crate::task::CompletionNotifier(completion_tx);
588            
589            task_ids.push(task.id);
590            completion_rxs.push(completion_rx);
591            prepared_tasks.push((task.delay, task, notifier));
592        }
593        
594        // 步骤2: 单次加锁,批量插入
595        {
596            let mut wheel_guard = self.wheel.lock();
597            wheel_guard.insert_batch(prepared_tasks);
598        }
599        
600        // 创建批量句柄并添加到服务管理
601        let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
602        self.add_batch_handle(batch_handle).await;
603    }
604
605    /// 优雅关闭 TimerService
606    ///
607    /// # 示例
608    /// ```no_run
609    /// # use kestrel_protocol_timer::TimerWheel;
610    /// # #[tokio::main]
611    /// # async fn main() {
612    /// let timer = TimerWheel::with_defaults();
613    /// let mut service = timer.create_service();
614    /// 
615    /// // 使用 service...
616    /// 
617    /// service.shutdown().await;
618    /// # }
619    /// ```
620    pub async fn shutdown(mut self) {
621        let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
622        if let Some(handle) = self.actor_handle.take() {
623            let _ = handle.await;
624        }
625    }
626}
627
628
629impl Drop for TimerService {
630    fn drop(&mut self) {
631        if let Some(handle) = self.actor_handle.take() {
632            handle.abort();
633        }
634    }
635}
636
/// ServiceActor - the internal actor implementation.
struct ServiceActor {
    /// Receiving half of the command channel.
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Sending half of the timeout channel.
    timeout_tx: mpsc::Sender<TaskId>,
}
644
645impl ServiceActor {
646    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
647        Self {
648            command_rx,
649            timeout_tx,
650        }
651    }
652
653    async fn run(mut self) {
654        // 使用 FuturesUnordered 来监听所有的 completion_rxs
655        // 每个 future 返回 (TaskId, Result)
656        let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
657
658        loop {
659            tokio::select! {
660                // 监听超时事件
661                Some((task_id, _result)) = futures.next() => {
662                    // 任务超时,转发 TaskId
663                    let _ = self.timeout_tx.send(task_id).await;
664                    // 任务会自动从 FuturesUnordered 中移除
665                }
666                
667                // 监听命令
668                Some(cmd) = self.command_rx.recv() => {
669                    match cmd {
670                        ServiceCommand::AddBatchHandle(batch) => {
671                            let BatchHandle {
672                                task_ids,
673                                completion_rxs,
674                                ..
675                            } = batch;
676                            
677                            // 将所有任务添加到 futures 中
678                            for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
679                                let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
680                                    (task_id, rx.await)
681                                });
682                                futures.push(future);
683                            }
684                        }
685                        ServiceCommand::AddTimerHandle(handle) => {
686                            let TimerHandle{
687                                task_id,
688                                completion_rx,
689                                ..
690                            } = handle;
691                            
692                            // 添加到 futures 中
693                            let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
694                                (task_id, completion_rx.0.await)
695                            });
696                            futures.push(future);
697                        }
698                        ServiceCommand::Shutdown => {
699                            break;
700                        }
701                    }
702                }
703                
704                // 如果没有任何 future 且命令通道已关闭,退出循环
705                else => {
706                    break;
707                }
708            }
709        }
710    }
711}
712
713#[cfg(test)]
714mod tests {
715    use super::*;
716    use crate::TimerWheel;
717    use std::sync::atomic::{AtomicU32, Ordering};
718    use std::sync::Arc;
719    use std::time::Duration;
720
721    #[tokio::test]
722    async fn test_service_creation() {
723        let timer = TimerWheel::with_defaults();
724        let _service = timer.create_service();
725    }
726
727
728    #[tokio::test]
729    async fn test_add_timer_handle_and_receive_timeout() {
730        let timer = TimerWheel::with_defaults();
731        let mut service = timer.create_service();
732
733        // 创建单个定时器
734        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
735        let task_id = task.get_id();
736        let handle = timer.register(task);
737
738        // 添加到 service
739        service.add_timer_handle(handle).await;
740
741        // 接收超时通知
742        let mut rx = service.take_receiver().unwrap();
743        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
744            .await
745            .expect("Should receive timeout notification")
746            .expect("Should receive Some value");
747
748        assert_eq!(received_task_id, task_id);
749    }
750
751
752    #[tokio::test]
753    async fn test_shutdown() {
754        let timer = TimerWheel::with_defaults();
755        let service = timer.create_service();
756
757        // 添加一些定时器
758        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
759        let task2 = TimerService::create_task(Duration::from_secs(10), || async {});
760        service.register(task1).await;
761        service.register(task2).await;
762
763        // 立即关闭(不等待定时器触发)
764        service.shutdown().await;
765    }
766
767
768
769    #[tokio::test]
770    async fn test_cancel_task() {
771        let timer = TimerWheel::with_defaults();
772        let service = timer.create_service();
773
774        // 添加一个长时间的定时器
775        let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
776        let task_id = task.get_id();
777        let handle = timer.register(task);
778        
779        service.add_timer_handle(handle).await;
780
781        // 取消任务
782        let cancelled = service.cancel_task(task_id).await;
783        assert!(cancelled, "Task should be cancelled successfully");
784
785        // 尝试再次取消同一个任务,应该返回 false
786        let cancelled_again = service.cancel_task(task_id).await;
787        assert!(!cancelled_again, "Task should not exist anymore");
788    }
789
790    #[tokio::test]
791    async fn test_cancel_nonexistent_task() {
792        let timer = TimerWheel::with_defaults();
793        let service = timer.create_service();
794
795        // 添加一个定时器以初始化 service
796        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
797        let handle = timer.register(task);
798        service.add_timer_handle(handle).await;
799
800        // 尝试取消一个不存在的任务(创建一个不会实际注册的任务ID)
801        let fake_task = TimerWheel::create_task(Duration::from_millis(50), || async {});
802        let fake_task_id = fake_task.get_id();
803        // 不注册 fake_task
804        let cancelled = service.cancel_task(fake_task_id).await;
805        assert!(!cancelled, "Nonexistent task should not be cancelled");
806    }
807
808
809    #[tokio::test]
810    async fn test_task_timeout_cleans_up_task_sender() {
811        let timer = TimerWheel::with_defaults();
812        let mut service = timer.create_service();
813
814        // 添加一个短时间的定时器
815        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
816        let task_id = task.get_id();
817        let handle = timer.register(task);
818        
819        service.add_timer_handle(handle).await;
820
821        // 等待任务超时
822        let mut rx = service.take_receiver().unwrap();
823        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
824            .await
825            .expect("Should receive timeout notification")
826            .expect("Should receive Some value");
827        
828        assert_eq!(received_task_id, task_id);
829
830        // 等待一下确保内部清理完成
831        tokio::time::sleep(Duration::from_millis(10)).await;
832
833        // 尝试取消已经超时的任务,应该返回 false
834        let cancelled = service.cancel_task(task_id).await;
835        assert!(!cancelled, "Timed out task should not exist anymore");
836    }
837
838    #[tokio::test]
839    async fn test_cancel_task_spawns_background_task() {
840        let timer = TimerWheel::with_defaults();
841        let service = timer.create_service();
842        let counter = Arc::new(AtomicU32::new(0));
843
844        // 创建一个定时器
845        let counter_clone = Arc::clone(&counter);
846        let task = TimerWheel::create_task(
847            Duration::from_secs(10),
848            move || {
849                let counter = Arc::clone(&counter_clone);
850                async move {
851                    counter.fetch_add(1, Ordering::SeqCst);
852                }
853            },
854        );
855        let task_id = task.get_id();
856        let handle = timer.register(task);
857        
858        service.add_timer_handle(handle).await;
859
860        // 使用 cancel_task(会等待结果,但在后台协程中处理)
861        let cancelled = service.cancel_task(task_id).await;
862        assert!(cancelled, "Task should be cancelled successfully");
863
864        // 等待足够长时间确保回调不会被执行
865        tokio::time::sleep(Duration::from_millis(100)).await;
866        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
867
868        // 验证任务已从 active_tasks 中移除
869        let cancelled_again = service.cancel_task(task_id).await;
870        assert!(!cancelled_again, "Task should have been removed from active_tasks");
871    }
872
873    #[tokio::test]
874    async fn test_schedule_once_direct() {
875        let timer = TimerWheel::with_defaults();
876        let mut service = timer.create_service();
877        let counter = Arc::new(AtomicU32::new(0));
878
879        // 直接通过 service 调度定时器
880        let counter_clone = Arc::clone(&counter);
881        let task = TimerService::create_task(
882            Duration::from_millis(50),
883            move || {
884                let counter = Arc::clone(&counter_clone);
885                async move {
886                    counter.fetch_add(1, Ordering::SeqCst);
887                }
888            },
889        );
890        let task_id = task.get_id();
891        service.register(task).await;
892
893        // 等待定时器触发
894        let mut rx = service.take_receiver().unwrap();
895        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
896            .await
897            .expect("Should receive timeout notification")
898            .expect("Should receive Some value");
899
900        assert_eq!(received_task_id, task_id);
901        
902        // 等待回调执行
903        tokio::time::sleep(Duration::from_millis(50)).await;
904        assert_eq!(counter.load(Ordering::SeqCst), 1);
905    }
906
907    #[tokio::test]
908    async fn test_schedule_once_batch_direct() {
909        let timer = TimerWheel::with_defaults();
910        let mut service = timer.create_service();
911        let counter = Arc::new(AtomicU32::new(0));
912
913        // 直接通过 service 批量调度定时器
914        let callbacks: Vec<_> = (0..3)
915            .map(|_| {
916                let counter = Arc::clone(&counter);
917                (Duration::from_millis(50), move || {
918                    let counter = Arc::clone(&counter);
919                    async move {
920                        counter.fetch_add(1, Ordering::SeqCst);
921                    }
922                })
923            })
924            .collect();
925
926        let tasks = TimerService::create_batch(callbacks);
927        assert_eq!(tasks.len(), 3);
928        service.register_batch(tasks).await;
929
930        // 接收所有超时通知
931        let mut received_count = 0;
932        let mut rx = service.take_receiver().unwrap();
933        
934        while received_count < 3 {
935            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
936                Ok(Some(_task_id)) => {
937                    received_count += 1;
938                }
939                Ok(None) => break,
940                Err(_) => break,
941            }
942        }
943
944        assert_eq!(received_count, 3);
945        
946        // 等待回调执行
947        tokio::time::sleep(Duration::from_millis(50)).await;
948        assert_eq!(counter.load(Ordering::SeqCst), 3);
949    }
950
951    #[tokio::test]
952    async fn test_schedule_once_notify_direct() {
953        let timer = TimerWheel::with_defaults();
954        let mut service = timer.create_service();
955
956        // 直接通过 service 调度仅通知的定时器(无回调)
957        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
958        let task_id = task.get_id();
959        service.register(task).await;
960
961        // 接收超时通知
962        let mut rx = service.take_receiver().unwrap();
963        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
964            .await
965            .expect("Should receive timeout notification")
966            .expect("Should receive Some value");
967
968        assert_eq!(received_task_id, task_id);
969    }
970
971    #[tokio::test]
972    async fn test_schedule_and_cancel_direct() {
973        let timer = TimerWheel::with_defaults();
974        let service = timer.create_service();
975        let counter = Arc::new(AtomicU32::new(0));
976
977        // 直接调度定时器
978        let counter_clone = Arc::clone(&counter);
979        let task = TimerService::create_task(
980            Duration::from_secs(10),
981            move || {
982                let counter = Arc::clone(&counter_clone);
983                async move {
984                    counter.fetch_add(1, Ordering::SeqCst);
985                }
986            },
987        );
988        let task_id = task.get_id();
989        service.register(task).await;
990
991        // 立即取消
992        let cancelled = service.cancel_task(task_id).await;
993        assert!(cancelled, "Task should be cancelled successfully");
994
995        // 等待确保回调不会执行
996        tokio::time::sleep(Duration::from_millis(100)).await;
997        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
998    }
999
1000    #[tokio::test]
1001    async fn test_cancel_batch_direct() {
1002        let timer = TimerWheel::with_defaults();
1003        let service = timer.create_service();
1004        let counter = Arc::new(AtomicU32::new(0));
1005
1006        // 批量调度定时器
1007        let callbacks: Vec<_> = (0..10)
1008            .map(|_| {
1009                let counter = Arc::clone(&counter);
1010                (Duration::from_secs(10), move || {
1011                    let counter = Arc::clone(&counter);
1012                    async move {
1013                        counter.fetch_add(1, Ordering::SeqCst);
1014                    }
1015                })
1016            })
1017            .collect();
1018
1019        let tasks = TimerService::create_batch(callbacks);
1020        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1021        assert_eq!(task_ids.len(), 10);
1022        service.register_batch(tasks).await;
1023
1024        // 批量取消所有任务
1025        let cancelled = service.cancel_batch(&task_ids).await;
1026        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1027
1028        // 等待确保回调不会执行
1029        tokio::time::sleep(Duration::from_millis(100)).await;
1030        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1031    }
1032
1033    #[tokio::test]
1034    async fn test_cancel_batch_partial() {
1035        let timer = TimerWheel::with_defaults();
1036        let service = timer.create_service();
1037        let counter = Arc::new(AtomicU32::new(0));
1038
1039        // 批量调度定时器
1040        let callbacks: Vec<_> = (0..10)
1041            .map(|_| {
1042                let counter = Arc::clone(&counter);
1043                (Duration::from_secs(10), move || {
1044                    let counter = Arc::clone(&counter);
1045                    async move {
1046                        counter.fetch_add(1, Ordering::SeqCst);
1047                    }
1048                })
1049            })
1050            .collect();
1051
1052        let tasks = TimerService::create_batch(callbacks);
1053        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1054        service.register_batch(tasks).await;
1055
1056        // 只取消前5个任务
1057        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1058        let cancelled = service.cancel_batch(&to_cancel).await;
1059        assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1060
1061        // 等待确保前5个回调不会执行
1062        tokio::time::sleep(Duration::from_millis(100)).await;
1063        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1064    }
1065
1066    #[tokio::test]
1067    async fn test_cancel_batch_empty() {
1068        let timer = TimerWheel::with_defaults();
1069        let service = timer.create_service();
1070
1071        // 取消空列表
1072        let empty: Vec<TaskId> = vec![];
1073        let cancelled = service.cancel_batch(&empty).await;
1074        assert_eq!(cancelled, 0, "No tasks should be cancelled");
1075    }
1076
1077    #[tokio::test]
1078    async fn test_postpone_task() {
1079        let timer = TimerWheel::with_defaults();
1080        let mut service = timer.create_service();
1081        let counter = Arc::new(AtomicU32::new(0));
1082
1083        // 注册一个任务,延迟 50ms
1084        let counter_clone = Arc::clone(&counter);
1085        let task = TimerService::create_task(
1086            Duration::from_millis(50),
1087            move || {
1088                let counter = Arc::clone(&counter_clone);
1089                async move {
1090                    counter.fetch_add(1, Ordering::SeqCst);
1091                }
1092            },
1093        );
1094        let task_id = task.get_id();
1095        service.register(task).await;
1096
1097        // 推迟任务到 150ms
1098        let postponed = service.postpone_task(task_id, Duration::from_millis(150)).await;
1099        assert!(postponed, "Task should be postponed successfully");
1100
1101        // 等待原定时间 50ms,任务不应该触发
1102        tokio::time::sleep(Duration::from_millis(70)).await;
1103        assert_eq!(counter.load(Ordering::SeqCst), 0);
1104
1105        // 接收超时通知(从推迟开始算,还需要等待约 150ms)
1106        let mut rx = service.take_receiver().unwrap();
1107        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1108            .await
1109            .expect("Should receive timeout notification")
1110            .expect("Should receive Some value");
1111
1112        assert_eq!(received_task_id, task_id);
1113        
1114        // 等待回调执行
1115        tokio::time::sleep(Duration::from_millis(20)).await;
1116        assert_eq!(counter.load(Ordering::SeqCst), 1);
1117    }
1118
1119    #[tokio::test]
1120    async fn test_postpone_task_with_callback() {
1121        let timer = TimerWheel::with_defaults();
1122        let mut service = timer.create_service();
1123        let counter = Arc::new(AtomicU32::new(0));
1124
1125        // 注册一个任务,原始回调增加 1
1126        let counter_clone1 = Arc::clone(&counter);
1127        let task = TimerService::create_task(
1128            Duration::from_millis(50),
1129            move || {
1130                let counter = Arc::clone(&counter_clone1);
1131                async move {
1132                    counter.fetch_add(1, Ordering::SeqCst);
1133                }
1134            },
1135        );
1136        let task_id = task.get_id();
1137        service.register(task).await;
1138
1139        // 推迟任务并替换回调,新回调增加 10
1140        let counter_clone2 = Arc::clone(&counter);
1141        let postponed = service.postpone_task_with_callback(
1142            task_id,
1143            Duration::from_millis(100),
1144            move || {
1145                let counter = Arc::clone(&counter_clone2);
1146                async move {
1147                    counter.fetch_add(10, Ordering::SeqCst);
1148                }
1149            }
1150        ).await;
1151        assert!(postponed, "Task should be postponed successfully");
1152
1153        // 接收超时通知(推迟后需要等待100ms,加上余量)
1154        let mut rx = service.take_receiver().unwrap();
1155        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1156            .await
1157            .expect("Should receive timeout notification")
1158            .expect("Should receive Some value");
1159
1160        assert_eq!(received_task_id, task_id);
1161        
1162        // 等待回调执行
1163        tokio::time::sleep(Duration::from_millis(20)).await;
1164        
1165        // 验证新回调被执行(增加了 10 而不是 1)
1166        assert_eq!(counter.load(Ordering::SeqCst), 10);
1167    }
1168
1169    #[tokio::test]
1170    async fn test_postpone_nonexistent_task() {
1171        let timer = TimerWheel::with_defaults();
1172        let service = timer.create_service();
1173
1174        // 尝试推迟一个不存在的任务
1175        let fake_task = TimerService::create_task(Duration::from_millis(50), || async {});
1176        let fake_task_id = fake_task.get_id();
1177        // 不注册这个任务
1178        
1179        let postponed = service.postpone_task(fake_task_id, Duration::from_millis(100)).await;
1180        assert!(!postponed, "Nonexistent task should not be postponed");
1181    }
1182
1183    #[tokio::test]
1184    async fn test_postpone_batch() {
1185        let timer = TimerWheel::with_defaults();
1186        let mut service = timer.create_service();
1187        let counter = Arc::new(AtomicU32::new(0));
1188
1189        // 注册 3 个任务
1190        let mut task_ids = Vec::new();
1191        for _ in 0..3 {
1192            let counter_clone = Arc::clone(&counter);
1193            let task = TimerService::create_task(
1194                Duration::from_millis(50),
1195                move || {
1196                    let counter = Arc::clone(&counter_clone);
1197                    async move {
1198                        counter.fetch_add(1, Ordering::SeqCst);
1199                    }
1200                },
1201            );
1202            task_ids.push((task.get_id(), Duration::from_millis(150)));
1203            service.register(task).await;
1204        }
1205
1206        // 批量推迟
1207        let postponed = service.postpone_batch(&task_ids).await;
1208        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1209
1210        // 等待原定时间 50ms,任务不应该触发
1211        tokio::time::sleep(Duration::from_millis(70)).await;
1212        assert_eq!(counter.load(Ordering::SeqCst), 0);
1213
1214        // 接收所有超时通知
1215        let mut received_count = 0;
1216        let mut rx = service.take_receiver().unwrap();
1217        
1218        while received_count < 3 {
1219            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1220                Ok(Some(_task_id)) => {
1221                    received_count += 1;
1222                }
1223                Ok(None) => break,
1224                Err(_) => break,
1225            }
1226        }
1227
1228        assert_eq!(received_count, 3);
1229        
1230        // 等待回调执行
1231        tokio::time::sleep(Duration::from_millis(20)).await;
1232        assert_eq!(counter.load(Ordering::SeqCst), 3);
1233    }
1234
1235    #[tokio::test]
1236    async fn test_postpone_batch_with_callbacks() {
1237        let timer = TimerWheel::with_defaults();
1238        let mut service = timer.create_service();
1239        let counter = Arc::new(AtomicU32::new(0));
1240
1241        // 注册 3 个任务
1242        let mut task_ids = Vec::new();
1243        for _ in 0..3 {
1244            let task = TimerService::create_task(
1245                Duration::from_millis(50),
1246                || async {},
1247            );
1248            task_ids.push(task.get_id());
1249            service.register(task).await;
1250        }
1251
1252        // 批量推迟并替换回调
1253        let updates: Vec<_> = task_ids
1254            .into_iter()
1255            .map(|id| {
1256                let counter_clone = Arc::clone(&counter);
1257                (id, Duration::from_millis(150), move || {
1258                    let counter = Arc::clone(&counter_clone);
1259                    async move {
1260                        counter.fetch_add(1, Ordering::SeqCst);
1261                    }
1262                })
1263            })
1264            .collect();
1265
1266        let postponed = service.postpone_batch_with_callbacks(updates).await;
1267        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1268
1269        // 等待原定时间 50ms,任务不应该触发
1270        tokio::time::sleep(Duration::from_millis(70)).await;
1271        assert_eq!(counter.load(Ordering::SeqCst), 0);
1272
1273        // 接收所有超时通知
1274        let mut received_count = 0;
1275        let mut rx = service.take_receiver().unwrap();
1276        
1277        while received_count < 3 {
1278            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1279                Ok(Some(_task_id)) => {
1280                    received_count += 1;
1281                }
1282                Ok(None) => break,
1283                Err(_) => break,
1284            }
1285        }
1286
1287        assert_eq!(received_count, 3);
1288        
1289        // 等待回调执行
1290        tokio::time::sleep(Duration::from_millis(20)).await;
1291        assert_eq!(counter.load(Ordering::SeqCst), 3);
1292    }
1293
1294    #[tokio::test]
1295    async fn test_postpone_batch_empty() {
1296        let timer = TimerWheel::with_defaults();
1297        let service = timer.create_service();
1298
1299        // 推迟空列表
1300        let empty: Vec<(TaskId, Duration)> = vec![];
1301        let postponed = service.postpone_batch(&empty).await;
1302        assert_eq!(postponed, 0, "No tasks should be postponed");
1303    }
1304
1305    #[tokio::test]
1306    async fn test_postpone_keeps_timeout_notification_valid() {
1307        let timer = TimerWheel::with_defaults();
1308        let mut service = timer.create_service();
1309        let counter = Arc::new(AtomicU32::new(0));
1310
1311        // 注册一个任务
1312        let counter_clone = Arc::clone(&counter);
1313        let task = TimerService::create_task(
1314            Duration::from_millis(50),
1315            move || {
1316                let counter = Arc::clone(&counter_clone);
1317                async move {
1318                    counter.fetch_add(1, Ordering::SeqCst);
1319                }
1320            },
1321        );
1322        let task_id = task.get_id();
1323        service.register(task).await;
1324
1325        // 推迟任务
1326        service.postpone_task(task_id, Duration::from_millis(100)).await;
1327
1328        // 验证超时通知仍然有效(推迟后需要等待100ms,加上余量)
1329        let mut rx = service.take_receiver().unwrap();
1330        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1331            .await
1332            .expect("Should receive timeout notification")
1333            .expect("Should receive Some value");
1334
1335        assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1336        
1337        // 等待回调执行
1338        tokio::time::sleep(Duration::from_millis(20)).await;
1339        assert_eq!(counter.load(Ordering::SeqCst), 1);
1340    }
1341}
1342