kestrel_protocol_timer/
service.rs

1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
4use crate::timer::{BatchHandle, TimerHandle};
5use crate::wheel::Wheel;
6use futures::stream::{FuturesUnordered, StreamExt};
7use futures::future::BoxFuture;
8use parking_lot::Mutex;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13
/// Commands sent from a `TimerService` to its background actor.
enum ServiceCommand {
    /// Start watching every completion receiver carried by a batch handle.
    AddBatchHandle(BatchHandle),
    /// Start watching the completion receiver of a single timer handle.
    AddTimerHandle(TimerHandle),
    /// Ask the actor to leave its event loop.
    Shutdown,
}
23
/// TimerService - an actor-based timer service.
///
/// Manages many timer handles, listens for their completion events and
/// forwards the `TaskId` of every expired task to a single channel.
///
/// # Features
/// - Watches the completion of every registered timer handle
/// - A finished task is dropped from the actor's internal set automatically
/// - Only expirations are forwarded to the user; cancellations are not
/// - `BatchHandle`s and `TimerHandle`s can be added while the service runs
///
/// # Examples
/// ```no_run
/// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
/// use std::time::Duration;
///
///
/// #[tokio::main]
/// async fn main() {
///     let timer = TimerWheel::with_defaults();
///     let mut service = timer.create_service();
///
///     // Two-step API: create the tasks, then register them with the service
///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
///         .map(|_| (Duration::from_millis(100), Some(CallbackWrapper::new(|| async {}))))
///         .collect();
///     let tasks = TimerService::create_batch(callbacks);
///     service.register_batch(tasks).unwrap();
///
///     // Receive timeout notifications
///     let mut rx = service.take_receiver().unwrap();
///     while let Some(task_id) = rx.recv().await {
///         println!("Task {:?} completed", task_id);
///     }
/// }
/// ```
pub struct TimerService {
    /// Sending half of the command channel to the actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiving half of the timeout channel; `None` once `take_receiver` handed it out.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor task; `None` once shutdown or drop consumed it.
    actor_handle: Option<JoinHandle<()>>,
    /// Shared timer wheel, used to schedule, cancel and postpone tasks directly.
    wheel: Arc<Mutex<Wheel>>,
}
69
70impl TimerService {
71    /// 创建新的 TimerService
72    ///
73    /// # 参数
74    /// - `wheel`: 时间轮引用
75    /// - `config`: 服务配置
76    ///
77    /// # 注意
78    /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
79    ///
80    /// # 示例
81    /// ```no_run
82    /// use kestrel_protocol_timer::TimerWheel;
83    ///
84    /// #[tokio::main]
85    /// async fn main() {
86    ///     let timer = TimerWheel::with_defaults();
87    ///     let mut service = timer.create_service();
88    /// }
89    /// ```
90    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
91        let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
92        let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
93
94        let actor = ServiceActor::new(command_rx, timeout_tx);
95        let actor_handle = tokio::spawn(async move {
96            actor.run().await;
97        });
98
99        Self {
100            command_tx,
101            timeout_rx: Some(timeout_rx),
102            actor_handle: Some(actor_handle),
103            wheel,
104        }
105    }
106
107    /// 获取超时接收器(转移所有权)
108    ///
109    /// # 返回
110    /// 超时通知接收器,如果已经被取走则返回 None
111    ///
112    /// # 注意
113    /// 此方法只能调用一次,因为它会转移接收器的所有权
114    ///
115    /// # 示例
116    /// ```no_run
117    /// # use kestrel_protocol_timer::TimerWheel;
118    /// # use std::time::Duration;
119    /// # #[tokio::main]
120    /// # async fn main() {
121    /// let timer = TimerWheel::with_defaults();
122    /// let mut service = timer.create_service();
123    /// 
124    /// let mut rx = service.take_receiver().unwrap();
125    /// while let Some(task_id) = rx.recv().await {
126    ///     println!("Task {:?} timed out", task_id);
127    /// }
128    /// # }
129    /// ```
130    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
131        self.timeout_rx.take()
132    }
133
134    /// 取消指定的任务
135    ///
136    /// # 参数
137    /// - `task_id`: 要取消的任务 ID
138    ///
139    /// # 返回
140    /// - `true`: 任务存在且成功取消
141    /// - `false`: 任务不存在或取消失败
142    ///
143    /// # 性能说明
144    /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
145    ///
146    /// # 示例
147    /// ```no_run
148    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
149    /// # use std::time::Duration;
150    /// # 
151    /// # #[tokio::main]
152    /// # async fn main() {
153    /// let timer = TimerWheel::with_defaults();
154    /// let service = timer.create_service();
155    /// 
156    /// // 使用两步式 API 调度定时器
157    /// let callback = Some(CallbackWrapper::new(|| async {}));
158    /// let task = TimerService::create_task(Duration::from_secs(10), callback);
159    /// let task_id = task.get_id();
160    /// service.register(task).unwrap();
161    /// 
162    /// // 取消任务
163    /// let cancelled = service.cancel_task(task_id);
164    /// println!("Task cancelled: {}", cancelled);
165    /// # }
166    /// ```
167    #[inline]
168    pub fn cancel_task(&self, task_id: TaskId) -> bool {
169        // 优化:直接取消任务,无需通知 Actor
170        // FuturesUnordered 会在任务被取消时自动清理
171        let mut wheel = self.wheel.lock();
172        wheel.cancel(task_id)
173    }
174
175    /// 批量取消任务
176    ///
177    /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
178    ///
179    /// # 参数
180    /// - `task_ids`: 要取消的任务 ID 列表
181    ///
182    /// # 返回
183    /// 成功取消的任务数量
184    ///
185    /// # 示例
186    /// ```no_run
187    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
188    /// # use std::time::Duration;
189    /// # 
190    /// # #[tokio::main]
191    /// # async fn main() {
192    /// let timer = TimerWheel::with_defaults();
193    /// let service = timer.create_service();
194    /// 
195    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..10)
196    ///     .map(|_| (Duration::from_secs(10), Some(CallbackWrapper::new(|| async {}))))
197    ///     .collect();
198    /// let tasks = TimerService::create_batch(callbacks);
199    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
200    /// service.register_batch(tasks).unwrap();
201    /// 
202    /// // 批量取消
203    /// let cancelled = service.cancel_batch(&task_ids);
204    /// println!("成功取消 {} 个任务", cancelled);
205    /// # }
206    /// ```
207    #[inline]
208    pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
209        if task_ids.is_empty() {
210            return 0;
211        }
212
213        // 优化:直接使用底层的批量取消,无需通知 Actor
214        // FuturesUnordered 会在任务被取消时自动清理
215        let mut wheel = self.wheel.lock();
216        wheel.cancel_batch(task_ids)
217    }
218
219    /// 推迟任务(替换回调)
220    ///
221    /// # 参数
222    /// - `task_id`: 要推迟的任务 ID
223    /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
224    /// - `callback`: 新的回调函数
225    ///
226    /// # 返回
227    /// - `true`: 任务存在且成功推迟
228    /// - `false`: 任务不存在或推迟失败
229    ///
230    /// # 注意
231    /// - 推迟后任务 ID 保持不变
232    /// - 原有的超时通知仍然有效
233    /// - 回调函数会被替换为新的回调
234    ///
235    /// # 示例
236    /// ```no_run
237    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
238    /// # use std::time::Duration;
239    /// # 
240    /// # #[tokio::main]
241    /// # async fn main() {
242    /// let timer = TimerWheel::with_defaults();
243    /// let service = timer.create_service();
244    /// 
245    /// let callback = Some(CallbackWrapper::new(|| async {
246    ///     println!("Original callback");
247    /// }));
248    /// let task = TimerService::create_task(Duration::from_secs(5), callback);
249    /// let task_id = task.get_id();
250    /// service.register(task).unwrap();
251    /// 
252    /// // 推迟并替换回调
253    /// let new_callback = Some(CallbackWrapper::new(|| async { println!("New callback!"); }));
254    /// let success = service.postpone(
255    ///     task_id,
256    ///     Duration::from_secs(10),
257    ///     new_callback
258    /// );
259    /// println!("推迟成功: {}", success);
260    /// # }
261    /// ```
262    #[inline]
263    pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
264        let mut wheel = self.wheel.lock();
265        wheel.postpone(task_id, new_delay, callback)
266    }
267
268    /// 批量推迟任务(保持原回调)
269    ///
270    /// # 参数
271    /// - `updates`: (任务ID, 新延迟) 的元组列表
272    ///
273    /// # 返回
274    /// 成功推迟的任务数量
275    ///
276    /// # 示例
277    /// ```no_run
278    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
279    /// # use std::time::Duration;
280    /// # 
281    /// # #[tokio::main]
282    /// # async fn main() {
283    /// let timer = TimerWheel::with_defaults();
284    /// let service = timer.create_service();
285    /// 
286    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
287    ///     .map(|_| (Duration::from_secs(5), Some(CallbackWrapper::new(|| async {}))))
288    ///     .collect();
289    /// let tasks = TimerService::create_batch(callbacks);
290    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
291    /// service.register_batch(tasks).unwrap();
292    /// 
293    /// // 批量推迟
294    /// let updates: Vec<_> = task_ids
295    ///     .into_iter()
296    ///     .map(|id| (id, Duration::from_secs(10)))
297    ///     .collect();
298    /// let postponed = service.postpone_batch(updates);
299    /// println!("成功推迟 {} 个任务", postponed);
300    /// # }
301    /// ```
302    #[inline]
303    pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
304        if updates.is_empty() {
305            return 0;
306        }
307
308        let mut wheel = self.wheel.lock();
309        wheel.postpone_batch(updates)
310    }
311
312    /// 批量推迟任务(替换回调)
313    ///
314    /// # 参数
315    /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
316    ///
317    /// # 返回
318    /// 成功推迟的任务数量
319    ///
320    /// # 示例
321    /// ```no_run
322    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
323    /// # use std::time::Duration;
324    /// # 
325    /// # #[tokio::main]
326    /// # async fn main() {
327    /// let timer = TimerWheel::with_defaults();
328    /// let service = timer.create_service();
329    /// 
330    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
331    ///     .map(|_| (Duration::from_secs(5), None))
332    ///     .collect();
333    /// let tasks = TimerService::create_batch(callbacks);
334    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
335    /// service.register_batch(tasks).unwrap();
336    /// 
337    /// // 批量推迟并替换回调
338    /// let updates: Vec<_> = task_ids
339    ///     .into_iter()
340    ///     .enumerate()
341    ///     .map(|(i, id)| {
342    ///         let callback = Some(CallbackWrapper::new(move || async move {
343    ///             println!("New callback {}", i);
344    ///         }));
345    ///         (id, Duration::from_secs(10), callback)
346    ///     })
347    ///     .collect();
348    /// let postponed = service.postpone_batch_with_callbacks(updates);
349    /// println!("成功推迟 {} 个任务", postponed);
350    /// # }
351    /// ```
352    #[inline]
353    pub fn postpone_batch_with_callbacks(
354        &self,
355        updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
356    ) -> usize {
357        if updates.is_empty() {
358            return 0;
359        }
360
361        let mut wheel = self.wheel.lock();
362        wheel.postpone_batch_with_callbacks(updates)
363    }
364
365    /// 创建定时器任务(静态方法,申请阶段)
366    /// 
367    /// # 参数
368    /// - `delay`: 延迟时间
369    /// - `callback`: 实现了 TimerCallback trait 的回调对象
370    /// 
371    /// # 返回
372    /// 返回 TimerTask,需要通过 `register()` 注册
373    /// 
374    /// # 示例
375    /// ```no_run
376    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
377    /// # use std::time::Duration;
378    /// # 
379    /// # #[tokio::main]
380    /// # async fn main() {
381    /// let timer = TimerWheel::with_defaults();
382    /// let service = timer.create_service();
383    /// 
384    /// // 步骤 1: 创建任务
385    /// let callback = Some(CallbackWrapper::new(|| async {
386    ///     println!("Timer fired!");
387    /// }));
388    /// let task = TimerService::create_task(Duration::from_millis(100), callback);
389    /// 
390    /// let task_id = task.get_id();
391    /// println!("Created task: {:?}", task_id);
392    /// 
393    /// // 步骤 2: 注册任务
394    /// service.register(task).unwrap();
395    /// # }
396    /// ```
397    #[inline]
398    pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
399        crate::timer::TimerWheel::create_task(delay, callback)
400    }
401    
402    /// 批量创建定时器任务(静态方法,申请阶段)
403    /// 
404    /// # 参数
405    /// - `callbacks`: (延迟时间, 回调) 的元组列表
406    /// 
407    /// # 返回
408    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
409    /// 
410    /// # 示例
411    /// ```no_run
412    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
413    /// # use std::time::Duration;
414    /// # 
415    /// # #[tokio::main]
416    /// # async fn main() {
417    /// let timer = TimerWheel::with_defaults();
418    /// let service = timer.create_service();
419    /// 
420    /// // 步骤 1: 批量创建任务
421    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
422    ///     .map(|i| {
423    ///         let callback = Some(CallbackWrapper::new(move || async move {
424    ///             println!("Timer {} fired!", i);
425    ///         }));
426    ///         (Duration::from_millis(100 * (i + 1)), callback)
427    ///     })
428    ///     .collect();
429    /// 
430    /// let tasks = TimerService::create_batch(callbacks);
431    /// println!("Created {} tasks", tasks.len());
432    /// 
433    /// // 步骤 2: 批量注册任务
434    /// service.register_batch(tasks).unwrap();
435    /// # }
436    /// ```
437    #[inline]
438    pub fn create_batch(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
439        crate::timer::TimerWheel::create_batch(callbacks)
440    }   
441    
442    /// 注册定时器任务到服务(注册阶段)
443    /// 
444    /// # 参数
445    /// - `task`: 通过 `create_task()` 创建的任务
446    /// 
447    /// # 返回
448    /// - `Ok(())`: 注册成功
449    /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
450    /// 
451    /// # 示例
452    /// ```no_run
453    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
454    /// # use std::time::Duration;
455    /// # 
456    /// # #[tokio::main]
457    /// # async fn main() {
458    /// let timer = TimerWheel::with_defaults();
459    /// let service = timer.create_service();
460    /// 
461    /// let callback = Some(CallbackWrapper::new(|| async {
462    ///     println!("Timer fired!");
463    /// }));
464    /// let task = TimerService::create_task(Duration::from_millis(100), callback);
465    /// let task_id = task.get_id();
466    /// 
467    /// service.register(task).unwrap();
468    /// # }
469    /// ```
470    #[inline]
471    pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
472        let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
473        let notifier = crate::task::CompletionNotifier(completion_tx);
474        
475        let delay = task.delay;
476        let task_id = task.id;
477        
478        // 单次加锁完成所有操作
479        {
480            let mut wheel_guard = self.wheel.lock();
481            wheel_guard.insert(delay, task, notifier);
482        }
483        
484        // 创建句柄并添加到服务管理
485        let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
486        self.command_tx
487            .try_send(ServiceCommand::AddTimerHandle(handle))
488            .map_err(|_| TimerError::RegisterFailed)?;
489        
490        Ok(())
491    }
492    
493    /// 批量注册定时器任务到服务(注册阶段)
494    /// 
495    /// # 参数
496    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
497    /// 
498    /// # 返回
499    /// - `Ok(())`: 注册成功
500    /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
501    /// 
502    /// # 示例
503    /// ```no_run
504    /// # use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
505    /// # use std::time::Duration;
506    /// # 
507    /// # #[tokio::main]
508    /// # async fn main() {
509    /// let timer = TimerWheel::with_defaults();
510    /// let service = timer.create_service();
511    /// 
512    /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
513    ///     .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
514    ///     .collect();
515    /// let tasks = TimerService::create_batch(callbacks);
516    /// 
517    /// service.register_batch(tasks).unwrap();
518    /// # }
519    /// ```
520    #[inline]
521    pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
522        let task_count = tasks.len();
523        let mut completion_rxs = Vec::with_capacity(task_count);
524        let mut task_ids = Vec::with_capacity(task_count);
525        let mut prepared_tasks = Vec::with_capacity(task_count);
526        
527        // 步骤1: 准备所有 channels 和 notifiers(无锁)
528        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
529        for task in tasks {
530            let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
531            let notifier = crate::task::CompletionNotifier(completion_tx);
532            
533            task_ids.push(task.id);
534            completion_rxs.push(completion_rx);
535            prepared_tasks.push((task.delay, task, notifier));
536        }
537        
538        // 步骤2: 单次加锁,批量插入
539        {
540            let mut wheel_guard = self.wheel.lock();
541            wheel_guard.insert_batch(prepared_tasks);
542        }
543        
544        // 创建批量句柄并添加到服务管理
545        let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
546        self.command_tx
547            .try_send(ServiceCommand::AddBatchHandle(batch_handle))
548            .map_err(|_| TimerError::RegisterFailed)?;
549        
550        Ok(())
551    }
552
553    /// 优雅关闭 TimerService
554    ///
555    /// # 示例
556    /// ```no_run
557    /// # use kestrel_protocol_timer::TimerWheel;
558    /// # #[tokio::main]
559    /// # async fn main() {
560    /// let timer = TimerWheel::with_defaults();
561    /// let mut service = timer.create_service();
562    /// 
563    /// // 使用 service...
564    /// 
565    /// service.shutdown().await;
566    /// # }
567    /// ```
568    pub async fn shutdown(mut self) {
569        let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
570        if let Some(handle) = self.actor_handle.take() {
571            let _ = handle.await;
572        }
573    }
574}
575
576
577impl Drop for TimerService {
578    fn drop(&mut self) {
579        if let Some(handle) = self.actor_handle.take() {
580            handle.abort();
581        }
582    }
583}
584
/// ServiceActor - the internal actor behind `TimerService`.
struct ServiceActor {
    /// Receiving half of the command channel.
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Sending half of the timeout channel; expired task IDs are pushed here.
    timeout_tx: mpsc::Sender<TaskId>,
}
592
593impl ServiceActor {
594    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
595        Self {
596            command_rx,
597            timeout_tx,
598        }
599    }
600
601    async fn run(mut self) {
602        // 使用 FuturesUnordered 来监听所有的 completion_rxs
603        // 每个 future 返回 (TaskId, Result)
604        let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
605
606        loop {
607            tokio::select! {
608                // 监听超时事件
609                Some((task_id, result)) = futures.next() => {
610                    // 检查完成原因,只转发超时(Expired)事件,不转发取消(Cancelled)事件
611                    if let Ok(TaskCompletionReason::Expired) = result {
612                        let _ = self.timeout_tx.send(task_id).await;
613                    }
614                    // 任务会自动从 FuturesUnordered 中移除
615                }
616                
617                // 监听命令
618                Some(cmd) = self.command_rx.recv() => {
619                    match cmd {
620                        ServiceCommand::AddBatchHandle(batch) => {
621                            let BatchHandle {
622                                task_ids,
623                                completion_rxs,
624                                ..
625                            } = batch;
626                            
627                            // 将所有任务添加到 futures 中
628                            for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
629                                let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
630                                    (task_id, rx.await)
631                                });
632                                futures.push(future);
633                            }
634                        }
635                        ServiceCommand::AddTimerHandle(handle) => {
636                            let TimerHandle{
637                                task_id,
638                                completion_rx,
639                                ..
640                            } = handle;
641                            
642                            // 添加到 futures 中
643                            let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
644                                (task_id, completion_rx.0.await)
645                            });
646                            futures.push(future);
647                        }
648                        ServiceCommand::Shutdown => {
649                            break;
650                        }
651                    }
652                }
653                
654                // 如果没有任何 future 且命令通道已关闭,退出循环
655                else => {
656                    break;
657                }
658            }
659        }
660    }
661}
662
663#[cfg(test)]
664mod tests {
665    use super::*;
666    use crate::TimerWheel;
667    use std::sync::atomic::{AtomicU32, Ordering};
668    use std::sync::Arc;
669    use std::time::Duration;
670
671    #[tokio::test]
672    async fn test_service_creation() {
673        let timer = TimerWheel::with_defaults();
674        let _service = timer.create_service();
675    }
676
677
678    #[tokio::test]
679    async fn test_add_timer_handle_and_receive_timeout() {
680        let timer = TimerWheel::with_defaults();
681        let mut service = timer.create_service();
682
683        // 创建单个定时器
684        let task = TimerService::create_task(Duration::from_millis(50), Some(CallbackWrapper::new(|| async {})));
685        let task_id = task.get_id();
686        
687        // 注册到 service
688        service.register(task).unwrap();
689
690        // 接收超时通知
691        let mut rx = service.take_receiver().unwrap();
692        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
693            .await
694            .expect("Should receive timeout notification")
695            .expect("Should receive Some value");
696
697        assert_eq!(received_task_id, task_id);
698    }
699
700
701    #[tokio::test]
702    async fn test_shutdown() {
703        let timer = TimerWheel::with_defaults();
704        let service = timer.create_service();
705
706        // 添加一些定时器
707        let task1 = TimerService::create_task(Duration::from_secs(10), None);
708        let task2 = TimerService::create_task(Duration::from_secs(10), None);
709        service.register(task1).unwrap();
710        service.register(task2).unwrap();
711
712        // 立即关闭(不等待定时器触发)
713        service.shutdown().await;
714    }
715
716
717
718    #[tokio::test]
719    async fn test_cancel_task() {
720        let timer = TimerWheel::with_defaults();
721        let service = timer.create_service();
722
723        // 添加一个长时间的定时器
724        let task = TimerService::create_task(Duration::from_secs(10), None);
725        let task_id = task.get_id();
726        
727        service.register(task).unwrap();
728
729        // 取消任务
730        let cancelled = service.cancel_task(task_id);
731        assert!(cancelled, "Task should be cancelled successfully");
732
733        // 尝试再次取消同一个任务,应该返回 false
734        let cancelled_again = service.cancel_task(task_id);
735        assert!(!cancelled_again, "Task should not exist anymore");
736    }
737
738    #[tokio::test]
739    async fn test_cancel_nonexistent_task() {
740        let timer = TimerWheel::with_defaults();
741        let service = timer.create_service();
742
743        // 添加一个定时器以初始化 service
744        let task = TimerService::create_task(Duration::from_millis(50), None);
745        service.register(task).unwrap();
746
747        // 尝试取消一个不存在的任务(创建一个不会实际注册的任务ID)
748        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
749        let fake_task_id = fake_task.get_id();
750        // 不注册 fake_task
751        let cancelled = service.cancel_task(fake_task_id);
752        assert!(!cancelled, "Nonexistent task should not be cancelled");
753    }
754
755
756    #[tokio::test]
757    async fn test_task_timeout_cleans_up_task_sender() {
758        let timer = TimerWheel::with_defaults();
759        let mut service = timer.create_service();
760
761        // 添加一个短时间的定时器
762        let task = TimerService::create_task(Duration::from_millis(50), None);
763        let task_id = task.get_id();
764        
765        service.register(task).unwrap();
766
767        // 等待任务超时
768        let mut rx = service.take_receiver().unwrap();
769        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
770            .await
771            .expect("Should receive timeout notification")
772            .expect("Should receive Some value");
773        
774        assert_eq!(received_task_id, task_id);
775
776        // 等待一下确保内部清理完成
777        tokio::time::sleep(Duration::from_millis(10)).await;
778
779        // 尝试取消已经超时的任务,应该返回 false
780        let cancelled = service.cancel_task(task_id);
781        assert!(!cancelled, "Timed out task should not exist anymore");
782    }
783
784    #[tokio::test]
785    async fn test_cancel_task_spawns_background_task() {
786        let timer = TimerWheel::with_defaults();
787        let service = timer.create_service();
788        let counter = Arc::new(AtomicU32::new(0));
789
790        // 创建一个定时器
791        let counter_clone = Arc::clone(&counter);
792        let task = TimerService::create_task(
793            Duration::from_secs(10),
794            Some(CallbackWrapper::new(move || {
795                let counter = Arc::clone(&counter_clone);
796                async move {
797                    counter.fetch_add(1, Ordering::SeqCst);
798                }
799            })),
800        );
801        let task_id = task.get_id();
802        
803        service.register(task).unwrap();
804
805        // 使用 cancel_task(会等待结果,但在后台协程中处理)
806        let cancelled = service.cancel_task(task_id);
807        assert!(cancelled, "Task should be cancelled successfully");
808
809        // 等待足够长时间确保回调不会被执行
810        tokio::time::sleep(Duration::from_millis(100)).await;
811        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
812
813        // 验证任务已从 active_tasks 中移除
814        let cancelled_again = service.cancel_task(task_id);
815        assert!(!cancelled_again, "Task should have been removed from active_tasks");
816    }
817
818    #[tokio::test]
819    async fn test_schedule_once_direct() {
820        let timer = TimerWheel::with_defaults();
821        let mut service = timer.create_service();
822        let counter = Arc::new(AtomicU32::new(0));
823
824        // 直接通过 service 调度定时器
825        let counter_clone = Arc::clone(&counter);
826        let task = TimerService::create_task(
827            Duration::from_millis(50),
828            Some(CallbackWrapper::new(move || {
829                let counter = Arc::clone(&counter_clone);
830                async move {
831                    counter.fetch_add(1, Ordering::SeqCst);
832                }
833            })),
834        );
835        let task_id = task.get_id();
836        service.register(task).unwrap();
837
838        // 等待定时器触发
839        let mut rx = service.take_receiver().unwrap();
840        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
841            .await
842            .expect("Should receive timeout notification")
843            .expect("Should receive Some value");
844
845        assert_eq!(received_task_id, task_id);
846        
847        // 等待回调执行
848        tokio::time::sleep(Duration::from_millis(50)).await;
849        assert_eq!(counter.load(Ordering::SeqCst), 1);
850    }
851
852    #[tokio::test]
853    async fn test_schedule_once_batch_direct() {
854        let timer = TimerWheel::with_defaults();
855        let mut service = timer.create_service();
856        let counter = Arc::new(AtomicU32::new(0));
857
858        // 直接通过 service 批量调度定时器
859        let callbacks: Vec<_> = (0..3)
860            .map(|_| {
861                let counter = Arc::clone(&counter);
862                (Duration::from_millis(50), Some(CallbackWrapper::new(move || {
863                    let counter = Arc::clone(&counter);
864                    async move {
865                        counter.fetch_add(1, Ordering::SeqCst);
866                    }
867                    })))
868            })
869            .collect();
870
871        let tasks = TimerService::create_batch(callbacks);
872        assert_eq!(tasks.len(), 3);
873        service.register_batch(tasks).unwrap();
874
875        // 接收所有超时通知
876        let mut received_count = 0;
877        let mut rx = service.take_receiver().unwrap();
878        
879        while received_count < 3 {
880            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
881                Ok(Some(_task_id)) => {
882                    received_count += 1;
883                }
884                Ok(None) => break,
885                Err(_) => break,
886            }
887        }
888
889        assert_eq!(received_count, 3);
890        
891        // 等待回调执行
892        tokio::time::sleep(Duration::from_millis(50)).await;
893        assert_eq!(counter.load(Ordering::SeqCst), 3);
894    }
895
896    #[tokio::test]
897    async fn test_schedule_once_notify_direct() {
898        let timer = TimerWheel::with_defaults();
899        let mut service = timer.create_service();
900
901        // 直接通过 service 调度仅通知的定时器(无回调)
902        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
903        let task_id = task.get_id();
904        service.register(task).unwrap();
905
906        // 接收超时通知
907        let mut rx = service.take_receiver().unwrap();
908        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
909            .await
910            .expect("Should receive timeout notification")
911            .expect("Should receive Some value");
912
913        assert_eq!(received_task_id, task_id);
914    }
915
916    #[tokio::test]
917    async fn test_schedule_and_cancel_direct() {
918        let timer = TimerWheel::with_defaults();
919        let service = timer.create_service();
920        let counter = Arc::new(AtomicU32::new(0));
921
922        // 直接调度定时器
923        let counter_clone = Arc::clone(&counter);
924        let task = TimerService::create_task(
925            Duration::from_secs(10),
926            Some(CallbackWrapper::new(move || {
927                let counter = Arc::clone(&counter_clone);
928                async move {
929                    counter.fetch_add(1, Ordering::SeqCst);
930                }
931            })),
932        );
933        let task_id = task.get_id();
934        service.register(task).unwrap();
935
936        // 立即取消
937        let cancelled = service.cancel_task(task_id);
938        assert!(cancelled, "Task should be cancelled successfully");
939
940        // 等待确保回调不会执行
941        tokio::time::sleep(Duration::from_millis(100)).await;
942        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
943    }
944
945    #[tokio::test]
946    async fn test_cancel_batch_direct() {
947        let timer = TimerWheel::with_defaults();
948        let service = timer.create_service();
949        let counter = Arc::new(AtomicU32::new(0));
950
951        // 批量调度定时器
952        let callbacks: Vec<_> = (0..10)
953            .map(|_| {
954                let counter = Arc::clone(&counter);
955                (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
956                    let counter = Arc::clone(&counter);
957                    async move {
958                        counter.fetch_add(1, Ordering::SeqCst);
959                    }
960                    })))
961            })
962            .collect();
963
964        let tasks = TimerService::create_batch(callbacks);
965        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
966        assert_eq!(task_ids.len(), 10);
967        service.register_batch(tasks).unwrap();
968
969        // 批量取消所有任务
970        let cancelled = service.cancel_batch(&task_ids);
971        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
972
973        // 等待确保回调不会执行
974        tokio::time::sleep(Duration::from_millis(100)).await;
975        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
976    }
977
978    #[tokio::test]
979    async fn test_cancel_batch_partial() {
980        let timer = TimerWheel::with_defaults();
981        let service = timer.create_service();
982        let counter = Arc::new(AtomicU32::new(0));
983
984        // 批量调度定时器
985        let callbacks: Vec<_> = (0..10)
986            .map(|_| {
987                let counter = Arc::clone(&counter);
988                (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
989                    let counter = Arc::clone(&counter);
990                    async move {
991                        counter.fetch_add(1, Ordering::SeqCst);
992                    }
993                    })))
994            })
995            .collect();
996
997        let tasks = TimerService::create_batch(callbacks);
998        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
999        service.register_batch(tasks).unwrap();
1000
1001        // 只取消前5个任务
1002        let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1003        let cancelled = service.cancel_batch(&to_cancel);
1004        assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1005
1006        // 等待确保前5个回调不会执行
1007        tokio::time::sleep(Duration::from_millis(100)).await;
1008        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1009    }
1010
1011    #[tokio::test]
1012    async fn test_cancel_batch_empty() {
1013        let timer = TimerWheel::with_defaults();
1014        let service = timer.create_service();
1015
1016        // 取消空列表
1017        let empty: Vec<TaskId> = vec![];
1018        let cancelled = service.cancel_batch(&empty);
1019        assert_eq!(cancelled, 0, "No tasks should be cancelled");
1020    }
1021
1022    #[tokio::test]
1023    async fn test_postpone() {
1024        let timer = TimerWheel::with_defaults();
1025        let mut service = timer.create_service();
1026        let counter = Arc::new(AtomicU32::new(0));
1027
1028        // 注册一个任务,原始回调增加 1
1029        let counter_clone1 = Arc::clone(&counter);
1030        let task = TimerService::create_task(
1031            Duration::from_millis(50),
1032            Some(CallbackWrapper::new(move || {
1033                let counter = Arc::clone(&counter_clone1);
1034                async move {
1035                    counter.fetch_add(1, Ordering::SeqCst);
1036                }
1037                })),
1038        );
1039        let task_id = task.get_id();
1040        service.register(task).unwrap();
1041
1042        // 推迟任务并替换回调,新回调增加 10
1043        let counter_clone2 = Arc::clone(&counter);
1044        let postponed = service.postpone(
1045            task_id,
1046            Duration::from_millis(100),
1047            Some(CallbackWrapper::new(move || {
1048                let counter = Arc::clone(&counter_clone2);
1049                async move {
1050                    counter.fetch_add(10, Ordering::SeqCst);
1051                }
1052            }))
1053        );
1054        assert!(postponed, "Task should be postponed successfully");
1055
1056        // 接收超时通知(推迟后需要等待100ms,加上余量)
1057        let mut rx = service.take_receiver().unwrap();
1058        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1059            .await
1060            .expect("Should receive timeout notification")
1061            .expect("Should receive Some value");
1062
1063        assert_eq!(received_task_id, task_id);
1064        
1065        // 等待回调执行
1066        tokio::time::sleep(Duration::from_millis(20)).await;
1067        
1068        // 验证新回调被执行(增加了 10 而不是 1)
1069        assert_eq!(counter.load(Ordering::SeqCst), 10);
1070    }
1071
1072    #[tokio::test]
1073    async fn test_postpone_nonexistent_task() {
1074        let timer = TimerWheel::with_defaults();
1075        let service = timer.create_service();
1076
1077        // 尝试推迟一个不存在的任务
1078        let fake_task = TimerService::create_task(Duration::from_millis(50), None);
1079        let fake_task_id = fake_task.get_id();
1080        // 不注册这个任务
1081        
1082        let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
1083        assert!(!postponed, "Nonexistent task should not be postponed");
1084    }
1085
1086    #[tokio::test]
1087    async fn test_postpone_batch() {
1088        let timer = TimerWheel::with_defaults();
1089        let mut service = timer.create_service();
1090        let counter = Arc::new(AtomicU32::new(0));
1091
1092        // 注册 3 个任务
1093        let mut task_ids = Vec::new();
1094        for _ in 0..3 {
1095            let counter_clone = Arc::clone(&counter);
1096            let task = TimerService::create_task(
1097                Duration::from_millis(50),
1098            Some(CallbackWrapper::new(move || {
1099                let counter = Arc::clone(&counter_clone);
1100                async move {
1101                    counter.fetch_add(1, Ordering::SeqCst);
1102                }
1103            })),
1104        );
1105            task_ids.push((task.get_id(), Duration::from_millis(150), None));
1106            service.register(task).unwrap();
1107        }
1108
1109        // 批量推迟
1110        let postponed = service.postpone_batch_with_callbacks(task_ids);
1111        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1112
1113        // 等待原定时间 50ms,任务不应该触发
1114        tokio::time::sleep(Duration::from_millis(70)).await;
1115        assert_eq!(counter.load(Ordering::SeqCst), 0);
1116
1117        // 接收所有超时通知
1118        let mut received_count = 0;
1119        let mut rx = service.take_receiver().unwrap();
1120        
1121        while received_count < 3 {
1122            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1123                Ok(Some(_task_id)) => {
1124                    received_count += 1;
1125                }
1126                Ok(None) => break,
1127                Err(_) => break,
1128            }
1129        }
1130
1131        assert_eq!(received_count, 3);
1132        
1133        // 等待回调执行
1134        tokio::time::sleep(Duration::from_millis(20)).await;
1135        assert_eq!(counter.load(Ordering::SeqCst), 3);
1136    }
1137
1138    #[tokio::test]
1139    async fn test_postpone_batch_with_callbacks() {
1140        let timer = TimerWheel::with_defaults();
1141        let mut service = timer.create_service();
1142        let counter = Arc::new(AtomicU32::new(0));
1143
1144        // 注册 3 个任务
1145        let mut task_ids = Vec::new();
1146        for _ in 0..3 {
1147            let task = TimerService::create_task(
1148                Duration::from_millis(50),
1149                None,
1150            );
1151            task_ids.push(task.get_id());
1152            service.register(task).unwrap();
1153        }
1154
1155        // 批量推迟并替换回调
1156        let updates: Vec<_> = task_ids
1157            .into_iter()
1158            .map(|id| {
1159                let counter_clone = Arc::clone(&counter);
1160                (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
1161                    let counter = Arc::clone(&counter_clone);
1162                    async move {
1163                        counter.fetch_add(1, Ordering::SeqCst);
1164                    }
1165                    })))
1166            })
1167            .collect();
1168
1169        let postponed = service.postpone_batch_with_callbacks(updates);
1170        assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1171
1172        // 等待原定时间 50ms,任务不应该触发
1173        tokio::time::sleep(Duration::from_millis(70)).await;
1174        assert_eq!(counter.load(Ordering::SeqCst), 0);
1175
1176        // 接收所有超时通知
1177        let mut received_count = 0;
1178        let mut rx = service.take_receiver().unwrap();
1179        
1180        while received_count < 3 {
1181            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1182                Ok(Some(_task_id)) => {
1183                    received_count += 1;
1184                }
1185                Ok(None) => break,
1186                Err(_) => break,
1187            }
1188        }
1189
1190        assert_eq!(received_count, 3);
1191        
1192        // 等待回调执行
1193        tokio::time::sleep(Duration::from_millis(20)).await;
1194        assert_eq!(counter.load(Ordering::SeqCst), 3);
1195    }
1196
1197    #[tokio::test]
1198    async fn test_postpone_batch_empty() {
1199        let timer = TimerWheel::with_defaults();
1200        let service = timer.create_service();
1201
1202        // 推迟空列表
1203        let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
1204        let postponed = service.postpone_batch_with_callbacks(empty);
1205        assert_eq!(postponed, 0, "No tasks should be postponed");
1206    }
1207
1208    #[tokio::test]
1209    async fn test_postpone_keeps_timeout_notification_valid() {
1210        let timer = TimerWheel::with_defaults();
1211        let mut service = timer.create_service();
1212        let counter = Arc::new(AtomicU32::new(0));
1213
1214        // 注册一个任务
1215        let counter_clone = Arc::clone(&counter);
1216        let task = TimerService::create_task(
1217            Duration::from_millis(50),
1218            Some(CallbackWrapper::new(move || {
1219                let counter = Arc::clone(&counter_clone);
1220                async move {
1221                    counter.fetch_add(1, Ordering::SeqCst);
1222                }
1223            })),
1224        );
1225        let task_id = task.get_id();
1226        service.register(task).unwrap();
1227
1228        // 推迟任务
1229        service.postpone(task_id, Duration::from_millis(100), None);
1230
1231        // 验证超时通知仍然有效(推迟后需要等待100ms,加上余量)
1232        let mut rx = service.take_receiver().unwrap();
1233        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1234            .await
1235            .expect("Should receive timeout notification")
1236            .expect("Should receive Some value");
1237
1238        assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1239        
1240        // 等待回调执行
1241        tokio::time::sleep(Duration::from_millis(20)).await;
1242        assert_eq!(counter.load(Ordering::SeqCst), 1);
1243    }
1244
1245    #[tokio::test]
1246    async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1247        let timer = TimerWheel::with_defaults();
1248        let mut service = timer.create_service();
1249
1250        // 注册两个任务:一个会被取消,一个会正常到期
1251        let task1 = TimerService::create_task(Duration::from_secs(10), None);
1252        let task1_id = task1.get_id();
1253        service.register(task1).unwrap();
1254
1255        let task2 = TimerService::create_task(Duration::from_millis(50), None);
1256        let task2_id = task2.get_id();
1257        service.register(task2).unwrap();
1258
1259        // 取消第一个任务
1260        let cancelled = service.cancel_task(task1_id);
1261        assert!(cancelled, "Task should be cancelled");
1262
1263        // 等待第二个任务到期
1264        let mut rx = service.take_receiver().unwrap();
1265        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1266            .await
1267            .expect("Should receive timeout notification")
1268            .expect("Should receive Some value");
1269
1270        // 应该只收到第二个任务(到期的)的通知,不应该收到第一个任务(取消的)的通知
1271        assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");
1272
1273        // 验证没有其他通知(特别是被取消的任务不应该有通知)
1274        let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1275        assert!(no_more.is_err(), "Should not receive any more notifications");
1276    }
1277}
1278