kestrel_protocol_timer/
service.rs

1use crate::config::ServiceConfig;
2use crate::task::{TaskId, TimerCallback};
3use crate::timer::{BatchHandle, TimerHandle};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
/// Commands sent from a `TimerService` to its background `ServiceActor`.
enum ServiceCommand {
    /// Start tracking every task contained in a batch handle.
    AddBatchHandle(BatchHandle),
    /// Start tracking a single timer handle.
    AddTimerHandle(TimerHandle),
    /// Stop the actor's event loop.
    Shutdown,
}
22
/// TimerService - an actor-based timer service.
///
/// Manages multiple timer handles, listens for all of their timeout events, and
/// forwards the resulting `TaskId`s to the user through a single channel.
///
/// # Features
/// - Automatically listens for the timeout events of every added timer handle
/// - Stops tracking a task internally once its completion has been observed
/// - Forwards timed-out `TaskId`s to one unified channel for the user to receive
/// - Supports adding `BatchHandle`s and `TimerHandle`s dynamically
///
/// # Examples
/// ```no_run
/// use kestrel_protocol_timer::{TimerWheel, TimerService};
/// use std::time::Duration;
///
/// #[tokio::main]
/// async fn main() {
///     let timer = TimerWheel::with_defaults();
///     let mut service = timer.create_service();
///     
///     // Schedule a batch of timers through the service using the two-step API
///     let callbacks: Vec<_> = (0..5)
///         .map(|_| (Duration::from_millis(100), || async {}))
///         .collect();
///     let tasks = TimerService::create_batch(callbacks);
///     service.register_batch(tasks).await;
///     
///     // Receive timeout notifications
///     let mut rx = service.take_receiver().unwrap();
///     while let Some(task_id) = rx.recv().await {
///         println!("Task {:?} completed", task_id);
///     }
/// }
/// ```
pub struct TimerService {
    /// Sending half of the command channel to the actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiving half of the timeout channel; `None` once `take_receiver` has been called.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor task; `None` once taken by `shutdown` or `drop`.
    actor_handle: Option<JoinHandle<()>>,
    /// Shared timing wheel, used to schedule and cancel timers directly.
    wheel: Arc<Mutex<Wheel>>,
}
67
68impl TimerService {
69    /// 创建新的 TimerService
70    ///
71    /// # 参数
72    /// - `wheel`: 时间轮引用
73    /// - `config`: 服务配置
74    ///
75    /// # 注意
76    /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
77    ///
78    /// # 示例
79    /// ```no_run
80    /// use kestrel_protocol_timer::TimerWheel;
81    ///
82    /// #[tokio::main]
83    /// async fn main() {
84    ///     let timer = TimerWheel::with_defaults();
85    ///     let mut service = timer.create_service();
86    /// }
87    /// ```
88    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
89        let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
90        let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
91
92        let actor = ServiceActor::new(command_rx, timeout_tx);
93        let actor_handle = tokio::spawn(async move {
94            actor.run().await;
95        });
96
97        Self {
98            command_tx,
99            timeout_rx: Some(timeout_rx),
100            actor_handle: Some(actor_handle),
101            wheel,
102        }
103    }
104
105    /// 添加批量定时器句柄(内部方法)
106    async fn add_batch_handle(&self, batch: BatchHandle) {
107        let _ = self.command_tx
108            .send(ServiceCommand::AddBatchHandle(batch))
109            .await;
110    }
111
112    /// 添加单个定时器句柄(内部方法)
113    async fn add_timer_handle(&self, handle: TimerHandle) {
114        let _ = self.command_tx
115            .send(ServiceCommand::AddTimerHandle(handle))
116            .await;
117    }
118
119    /// 获取超时接收器(转移所有权)
120    ///
121    /// # 返回
122    /// 超时通知接收器,如果已经被取走则返回 None
123    ///
124    /// # 注意
125    /// 此方法只能调用一次,因为它会转移接收器的所有权
126    ///
127    /// # 示例
128    /// ```no_run
129    /// # use kestrel_protocol_timer::TimerWheel;
130    /// # use std::time::Duration;
131    /// # #[tokio::main]
132    /// # async fn main() {
133    /// let timer = TimerWheel::with_defaults();
134    /// let mut service = timer.create_service();
135    /// 
136    /// let mut rx = service.take_receiver().unwrap();
137    /// while let Some(task_id) = rx.recv().await {
138    ///     println!("Task {:?} timed out", task_id);
139    /// }
140    /// # }
141    /// ```
142    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
143        self.timeout_rx.take()
144    }
145
146    /// 取消指定的任务
147    ///
148    /// # 参数
149    /// - `task_id`: 要取消的任务 ID
150    ///
151    /// # 返回
152    /// - `Ok(true)`: 任务存在且成功取消
153    /// - `Ok(false)`: 任务不存在或取消失败
154    /// - `Err(String)`: 发送命令失败
155    ///
156    /// # 性能说明
157    /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
158    ///
159    /// # 示例
160    /// ```no_run
161    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
162    /// # use std::time::Duration;
163    /// # #[tokio::main]
164    /// # async fn main() {
165    /// let timer = TimerWheel::with_defaults();
166    /// let service = timer.create_service();
167    /// 
168    /// // 使用两步式 API 调度定时器
169    /// let task = TimerService::create_task(Duration::from_secs(10), || async {});
170    /// let task_id = task.get_id();
171    /// service.register(task).await;
172    /// 
173    /// // 取消任务
174    /// let cancelled = service.cancel_task(task_id).await;
175    /// println!("Task cancelled: {}", cancelled);
176    /// # }
177    /// ```
178    #[inline]
179    pub async fn cancel_task(&self, task_id: TaskId) -> bool {
180        // 优化:直接取消任务,无需通知 Actor
181        // FuturesUnordered 会在任务被取消时自动清理
182        let mut wheel = self.wheel.lock();
183        wheel.cancel(task_id)
184    }
185
186    /// 批量取消任务
187    ///
188    /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
189    ///
190    /// # 参数
191    /// - `task_ids`: 要取消的任务 ID 列表
192    ///
193    /// # 返回
194    /// 成功取消的任务数量
195    ///
196    /// # 示例
197    /// ```no_run
198    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
199    /// # use std::time::Duration;
200    /// # #[tokio::main]
201    /// # async fn main() {
202    /// let timer = TimerWheel::with_defaults();
203    /// let service = timer.create_service();
204    /// 
205    /// let callbacks: Vec<_> = (0..10)
206    ///     .map(|_| (Duration::from_secs(10), || async {}))
207    ///     .collect();
208    /// let tasks = TimerService::create_batch(callbacks);
209    /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
210    /// service.register_batch(tasks).await;
211    /// 
212    /// // 批量取消
213    /// let cancelled = service.cancel_batch(&task_ids).await;
214    /// println!("成功取消 {} 个任务", cancelled);
215    /// # }
216    /// ```
217    #[inline]
218    pub async fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
219        if task_ids.is_empty() {
220            return 0;
221        }
222
223        // 优化:直接使用底层的批量取消,无需通知 Actor
224        // FuturesUnordered 会在任务被取消时自动清理
225        let mut wheel = self.wheel.lock();
226        wheel.cancel_batch(task_ids)
227    }
228
229    /// 创建定时器任务(静态方法,申请阶段)
230    /// 
231    /// # 参数
232    /// - `delay`: 延迟时间
233    /// - `callback`: 实现了 TimerCallback trait 的回调对象
234    /// 
235    /// # 返回
236    /// 返回 TimerTask,需要通过 `register()` 注册
237    /// 
238    /// # 示例
239    /// ```no_run
240    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
241    /// # use std::time::Duration;
242    /// # #[tokio::main]
243    /// # async fn main() {
244    /// let timer = TimerWheel::with_defaults();
245    /// let service = timer.create_service();
246    /// 
247    /// // 步骤 1: 创建任务
248    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
249    ///     println!("Timer fired!");
250    /// });
251    /// 
252    /// let task_id = task.get_id();
253    /// println!("Created task: {:?}", task_id);
254    /// 
255    /// // 步骤 2: 注册任务
256    /// service.register(task).await;
257    /// # }
258    /// ```
259    pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
260    where
261        C: TimerCallback,
262    {
263        crate::timer::TimerWheel::create_task(delay, callback)
264    }
265    
266    /// 批量创建定时器任务(静态方法,申请阶段)
267    /// 
268    /// # 参数
269    /// - `callbacks`: (延迟时间, 回调) 的元组列表
270    /// 
271    /// # 返回
272    /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
273    /// 
274    /// # 示例
275    /// ```no_run
276    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
277    /// # use std::time::Duration;
278    /// # #[tokio::main]
279    /// # async fn main() {
280    /// let timer = TimerWheel::with_defaults();
281    /// let service = timer.create_service();
282    /// 
283    /// // 步骤 1: 批量创建任务
284    /// let callbacks: Vec<_> = (0..3)
285    ///     .map(|i| (Duration::from_millis(100 * (i + 1)), move || async move {
286    ///         println!("Timer {} fired!", i);
287    ///     }))
288    ///     .collect();
289    /// 
290    /// let tasks = TimerService::create_batch(callbacks);
291    /// println!("Created {} tasks", tasks.len());
292    /// 
293    /// // 步骤 2: 批量注册任务
294    /// service.register_batch(tasks).await;
295    /// # }
296    /// ```
297    pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
298    where
299        C: TimerCallback,
300    {
301        crate::timer::TimerWheel::create_batch(callbacks)
302    }
303    
304    /// 注册定时器任务到服务(注册阶段)
305    /// 
306    /// # 参数
307    /// - `task`: 通过 `create_task()` 创建的任务
308    /// 
309    /// # 示例
310    /// ```no_run
311    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
312    /// # use std::time::Duration;
313    /// # #[tokio::main]
314    /// # async fn main() {
315    /// let timer = TimerWheel::with_defaults();
316    /// let service = timer.create_service();
317    /// 
318    /// let task = TimerService::create_task(Duration::from_millis(100), || async {
319    ///     println!("Timer fired!");
320    /// });
321    /// let task_id = task.get_id();
322    /// 
323    /// service.register(task).await;
324    /// # }
325    /// ```
326    #[inline]
327    pub async fn register(&self, task: crate::task::TimerTask) {
328        let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
329        let notifier = crate::task::CompletionNotifier(completion_tx);
330        
331        let delay = task.delay;
332        let task_id = task.id;
333        
334        // 单次加锁完成所有操作
335        {
336            let mut wheel_guard = self.wheel.lock();
337            wheel_guard.insert(delay, task, notifier);
338        }
339        
340        // 创建句柄并添加到服务管理
341        let handle = TimerHandle::new(task_id, self.wheel.clone(), completion_rx);
342        self.add_timer_handle(handle).await;
343    }
344    
345    /// 批量注册定时器任务到服务(注册阶段)
346    /// 
347    /// # 参数
348    /// - `tasks`: 通过 `create_batch()` 创建的任务列表
349    /// 
350    /// # 示例
351    /// ```no_run
352    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
353    /// # use std::time::Duration;
354    /// # #[tokio::main]
355    /// # async fn main() {
356    /// let timer = TimerWheel::with_defaults();
357    /// let service = timer.create_service();
358    /// 
359    /// let callbacks: Vec<_> = (0..3)
360    ///     .map(|_| (Duration::from_secs(1), || async {}))
361    ///     .collect();
362    /// let tasks = TimerService::create_batch(callbacks);
363    /// 
364    /// service.register_batch(tasks).await;
365    /// # }
366    /// ```
367    #[inline]
368    pub async fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) {
369        let task_count = tasks.len();
370        let mut completion_rxs = Vec::with_capacity(task_count);
371        let mut task_ids = Vec::with_capacity(task_count);
372        let mut prepared_tasks = Vec::with_capacity(task_count);
373        
374        // 步骤1: 准备所有 channels 和 notifiers(无锁)
375        // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
376        for task in tasks {
377            let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
378            let notifier = crate::task::CompletionNotifier(completion_tx);
379            
380            task_ids.push(task.id);
381            completion_rxs.push(completion_rx);
382            prepared_tasks.push((task.delay, task, notifier));
383        }
384        
385        // 步骤2: 单次加锁,批量插入
386        {
387            let mut wheel_guard = self.wheel.lock();
388            wheel_guard.insert_batch(prepared_tasks);
389        }
390        
391        // 创建批量句柄并添加到服务管理
392        let batch_handle = BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs);
393        self.add_batch_handle(batch_handle).await;
394    }
395
396    /// 优雅关闭 TimerService
397    ///
398    /// # 示例
399    /// ```no_run
400    /// # use kestrel_protocol_timer::TimerWheel;
401    /// # #[tokio::main]
402    /// # async fn main() {
403    /// let timer = TimerWheel::with_defaults();
404    /// let mut service = timer.create_service();
405    /// 
406    /// // 使用 service...
407    /// 
408    /// service.shutdown().await;
409    /// # }
410    /// ```
411    pub async fn shutdown(mut self) {
412        let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
413        if let Some(handle) = self.actor_handle.take() {
414            let _ = handle.await;
415        }
416    }
417}
418
419
420impl Drop for TimerService {
421    fn drop(&mut self) {
422        if let Some(handle) = self.actor_handle.take() {
423            handle.abort();
424        }
425    }
426}
427
/// ServiceActor - the internal actor behind `TimerService`.
///
/// Runs inside the task spawned by `TimerService::new`.
struct ServiceActor {
    /// Receiving half of the command channel (commands from `TimerService`).
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Sending half of the timeout channel (task IDs forwarded to the user).
    timeout_tx: mpsc::Sender<TaskId>,
}
435
436impl ServiceActor {
437    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
438        Self {
439            command_rx,
440            timeout_tx,
441        }
442    }
443
444    async fn run(mut self) {
445        // 使用 FuturesUnordered 来监听所有的 completion_rxs
446        // 每个 future 返回 (TaskId, Result)
447        let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
448
449        loop {
450            tokio::select! {
451                // 监听超时事件
452                Some((task_id, _result)) = futures.next() => {
453                    // 任务超时,转发 TaskId
454                    let _ = self.timeout_tx.send(task_id).await;
455                    // 任务会自动从 FuturesUnordered 中移除
456                }
457                
458                // 监听命令
459                Some(cmd) = self.command_rx.recv() => {
460                    match cmd {
461                        ServiceCommand::AddBatchHandle(batch) => {
462                            let BatchHandle {
463                                task_ids,
464                                completion_rxs,
465                                ..
466                            } = batch;
467                            
468                            // 将所有任务添加到 futures 中
469                            for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
470                                let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
471                                    (task_id, rx.await)
472                                });
473                                futures.push(future);
474                            }
475                        }
476                        ServiceCommand::AddTimerHandle(handle) => {
477                            let TimerHandle{
478                                task_id,
479                                completion_rx,
480                                ..
481                            } = handle;
482                            
483                            // 添加到 futures 中
484                            let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
485                                (task_id, completion_rx.0.await)
486                            });
487                            futures.push(future);
488                        }
489                        ServiceCommand::Shutdown => {
490                            break;
491                        }
492                    }
493                }
494                
495                // 如果没有任何 future 且命令通道已关闭,退出循环
496                else => {
497                    break;
498                }
499            }
500        }
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Waits up to 200 ms for the next timeout notification and returns its task ID.
    ///
    /// Panics if nothing arrives in time or if the channel has been closed.
    async fn recv_notification(rx: &mut mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults();
        let _service = timer.create_service();
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        // Create a single timer through the wheel.
        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Hand it to the service.
        service.add_timer_handle(handle).await;

        // Receive the timeout notification.
        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_notification(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        // Register a couple of long-running timers.
        let task1 = TimerService::create_task(Duration::from_secs(10), || async {});
        let task2 = TimerService::create_task(Duration::from_secs(10), || async {});
        service.register(task1).await;
        service.register(task2).await;

        // Shut down immediately, without waiting for the timers to fire.
        service.shutdown().await;
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        // Add a long-running timer.
        let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
        let task_id = task.get_id();
        let handle = timer.register(task);

        service.add_timer_handle(handle).await;

        // Cancel it.
        let cancelled = service.cancel_task(task_id).await;
        assert!(cancelled, "Task should be cancelled successfully");

        // Cancelling the same task a second time must report false.
        let cancelled_again = service.cancel_task(task_id).await;
        assert!(!cancelled_again, "Task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        // Register one timer so the service is tracking something.
        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let handle = timer.register(task);
        service.add_timer_handle(handle).await;

        // Try to cancel an ID that was created but never registered.
        let fake_task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let fake_task_id = fake_task.get_id();
        // `fake_task` is intentionally never registered.
        let cancelled = service.cancel_task(fake_task_id).await;
        assert!(!cancelled, "Nonexistent task should not be cancelled");
    }

    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        // Add a short timer.
        let task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let task_id = task.get_id();
        let handle = timer.register(task);

        service.add_timer_handle(handle).await;

        // Wait for it to time out.
        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_notification(&mut rx).await, task_id);

        // Give the service a moment to finish its internal cleanup.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // A task that has already fired can no longer be cancelled.
        let cancelled = service.cancel_task(task_id).await;
        assert!(!cancelled, "Timed out task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_task_prevents_callback() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        // Create a long-running timer whose callback bumps the counter.
        let counter_clone = Arc::clone(&counter);
        let task = TimerWheel::create_task(
            Duration::from_secs(10),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        service.add_timer_handle(handle).await;

        // `cancel_task` cancels directly on the wheel; no actor involvement.
        let cancelled = service.cancel_task(task_id).await;
        assert!(cancelled, "Task should be cancelled successfully");

        // Wait long enough to be confident the callback does not run.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        // A second cancel returns false, showing the task is gone from the wheel.
        let cancelled_again = service.cancel_task(task_id).await;
        assert!(!cancelled_again, "Task should have been removed from the wheel");
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        // Schedule a timer directly through the service.
        let counter_clone = Arc::clone(&counter);
        let task = TimerService::create_task(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        service.register(task).await;

        // Wait for the timer to fire.
        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_notification(&mut rx).await, task_id);

        // Give the callback time to run.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        // Schedule a batch of timers directly through the service.
        let callbacks: Vec<_> = (0..3)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_millis(50), move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        assert_eq!(tasks.len(), 3);
        service.register_batch(tasks).await;

        // Receive every timeout notification (stop early on timeout/closure).
        let mut received_count = 0;
        let mut rx = service.take_receiver().unwrap();

        while received_count < 3 {
            match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
                Ok(Some(_task_id)) => {
                    received_count += 1;
                }
                Ok(None) => break,
                Err(_) => break,
            }
        }

        assert_eq!(received_count, 3);

        // Give the callbacks time to run.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service();

        // Schedule a notification-only timer (no callback) through the service.
        let task = crate::task::TimerTask::new(Duration::from_millis(50), None);
        let task_id = task.get_id();
        service.register(task).await;

        // Receive the timeout notification.
        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_notification(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        // Schedule a timer directly.
        let counter_clone = Arc::clone(&counter);
        let task = TimerService::create_task(
            Duration::from_secs(10),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        service.register(task).await;

        // Cancel it immediately.
        let cancelled = service.cancel_task(task_id).await;
        assert!(cancelled, "Task should be cancelled successfully");

        // Make sure the callback does not run.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        // Schedule a batch of long-running timers.
        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_secs(10), move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        assert_eq!(task_ids.len(), 10);
        service.register_batch(tasks).await;

        // Cancel all of them in one call.
        let cancelled = service.cancel_batch(&task_ids).await;
        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");

        // Make sure no callback runs.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        // Schedule a batch of long-running timers.
        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_secs(10), move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let tasks = TimerService::create_batch(callbacks);
        let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
        service.register_batch(tasks).await;

        // Cancel only the first five tasks.
        let cancelled = service.cancel_batch(&task_ids[..5]).await;
        assert_eq!(cancelled, 5, "5 tasks should be cancelled");

        // The cancelled callbacks must not run.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");

        // The other five must have been left alone: they are still cancellable now.
        let remaining = service.cancel_batch(&task_ids[5..]).await;
        assert_eq!(remaining, 5, "The 5 uncancelled tasks should still be pending");
    }

    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults();
        let service = timer.create_service();

        // Cancelling an empty list is a no-op.
        let empty: Vec<TaskId> = vec![];
        let cancelled = service.cancel_batch(&empty).await;
        assert_eq!(cancelled, 0, "No tasks should be cancelled");
    }
}
868