kestrel_protocol_timer/
service.rs

1use crate::task::{CallbackWrapper, TaskId, TimerCallback};
2use crate::timer::{BatchHandle, TimerHandle};
3use crate::error::TimerError;
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use rustc_hash::FxHashSet;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13
14/// TimerService 命令类型
15enum ServiceCommand {
16    /// 添加批量定时器句柄
17    AddBatchHandle(BatchHandle),
18    /// 添加单个定时器句柄
19    AddTimerHandle(TimerHandle),
20    /// 批量从活跃任务集合中移除任务(用于直接取消后的清理)
21    RemoveTasks {
22        task_ids: Vec<TaskId>,
23    },
24    /// 关闭 Service
25    Shutdown,
26}
27
28/// TimerService - 基于 Actor 模式的定时器服务
29///
30/// 管理多个定时器句柄,监听所有超时事件,并将 TaskId 聚合转发给用户。
31///
32/// # 特性
33/// - 自动监听所有添加的定时器句柄的超时事件
34/// - 超时后自动从内部管理中移除该任务
35/// - 将超时的 TaskId 转发到统一的通道供用户接收
36/// - 支持动态添加 BatchHandle 和 TimerHandle
37///
38/// # 示例
39/// ```no_run
40/// use kestrel_protocol_timer::{TimerWheel, TimerService};
41/// use std::time::Duration;
42///
43/// #[tokio::main]
44/// async fn main() {
45///     let timer = TimerWheel::with_defaults().unwrap();
46///     let mut service = timer.create_service();
47///     
48///     // 直接通过 service 批量调度定时器
49///     let callbacks: Vec<_> = (0..5)
50///         .map(|_| (Duration::from_millis(100), || async {}))
51///         .collect();
52///     service.schedule_once_batch(callbacks).await.unwrap();
53///     
54///     // 接收超时通知
55///     let mut rx = service.take_receiver().unwrap();
56///     while let Some(task_id) = rx.recv().await {
57///         println!("Task {:?} completed", task_id);
58///     }
59/// }
60/// ```
61pub struct TimerService {
62    /// 命令发送端
63    command_tx: mpsc::Sender<ServiceCommand>,
64    /// 超时接收端
65    timeout_rx: Option<mpsc::Receiver<TaskId>>,
66    /// Actor 任务句柄
67    actor_handle: Option<JoinHandle<()>>,
68    /// 时间轮引用(用于直接调度定时器)
69    wheel: Arc<Mutex<Wheel>>,
70}
71
72impl TimerService {
73    /// 创建新的 TimerService
74    ///
75    /// # 参数
76    /// - `wheel`: 时间轮引用
77    ///
78    /// # 注意
79    /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
80    ///
81    /// # 示例
82    /// ```no_run
83    /// use kestrel_protocol_timer::TimerWheel;
84    ///
85    /// #[tokio::main]
86    /// async fn main() {
87    ///     let timer = TimerWheel::with_defaults().unwrap();
88    ///     let mut service = timer.create_service();
89    /// }
90    /// ```
91    pub(crate) fn new(wheel: Arc<Mutex<Wheel>>) -> Self {
92        // 优化:增加命令通道容量以减少背压,提升添加操作的吞吐量
93        let (command_tx, command_rx) = mpsc::channel(512);
94        let (timeout_tx, timeout_rx) = mpsc::channel(1000);
95
96        let actor = ServiceActor::new(command_rx, timeout_tx);
97        let actor_handle = tokio::spawn(async move {
98            actor.run().await;
99        });
100
101        Self {
102            command_tx,
103            timeout_rx: Some(timeout_rx),
104            actor_handle: Some(actor_handle),
105            wheel,
106        }
107    }
108
109    /// 添加批量定时器句柄(内部方法)
110    async fn add_batch_handle(&self, batch: BatchHandle) -> Result<(), TimerError> {
111        self.command_tx
112            .send(ServiceCommand::AddBatchHandle(batch))
113            .await
114            .map_err(|_| TimerError::ChannelClosed)
115    }
116
117    /// 添加单个定时器句柄(内部方法)
118    async fn add_timer_handle(&self, handle: TimerHandle) -> Result<(), TimerError> {
119        self.command_tx
120            .send(ServiceCommand::AddTimerHandle(handle))
121            .await
122            .map_err(|_| TimerError::ChannelClosed)
123    }
124
125    /// 获取超时接收器(转移所有权)
126    ///
127    /// # 返回
128    /// 超时通知接收器,如果已经被取走则返回 None
129    ///
130    /// # 注意
131    /// 此方法只能调用一次,因为它会转移接收器的所有权
132    ///
133    /// # 示例
134    /// ```no_run
135    /// # use kestrel_protocol_timer::TimerWheel;
136    /// # use std::time::Duration;
137    /// # #[tokio::main]
138    /// # async fn main() {
139    /// let timer = TimerWheel::with_defaults().unwrap();
140    /// let mut service = timer.create_service();
141    /// 
142    /// let mut rx = service.take_receiver().unwrap();
143    /// while let Some(task_id) = rx.recv().await {
144    ///     println!("Task {:?} timed out", task_id);
145    /// }
146    /// # }
147    /// ```
148    pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
149        self.timeout_rx.take()
150    }
151
152    /// 取消指定的任务
153    ///
154    /// # 参数
155    /// - `task_id`: 要取消的任务 ID
156    ///
157    /// # 返回
158    /// - `Ok(true)`: 任务存在且成功取消
159    /// - `Ok(false)`: 任务不存在或取消失败
160    /// - `Err(String)`: 发送命令失败
161    ///
162    /// # 性能说明
163    /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
164    ///
165    /// # 示例
166    /// ```no_run
167    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
168    /// # use std::time::Duration;
169    /// # #[tokio::main]
170    /// # async fn main() {
171    /// let timer = TimerWheel::with_defaults().unwrap();
172    /// let service = timer.create_service();
173    /// 
174    /// // 直接通过 service 调度定时器
175    /// let task_id = service.schedule_once(Duration::from_secs(10), || async {}).await.unwrap();
176    /// 
177    /// // 取消任务
178    /// let cancelled = service.cancel_task(task_id).await.unwrap();
179    /// println!("Task cancelled: {}", cancelled);
180    /// # }
181    /// ```
182    pub async fn cancel_task(&self, task_id: TaskId) -> Result<bool, String> {
183        // 优化:直接取消任务,避免通过 Actor 的异步往返
184        // 这将延迟从 "2次异步通信" 减少到 "0次等待"
185        let success = {
186            let mut wheel = self.wheel.lock();
187            wheel.cancel(task_id)
188        };
189        
190        // 异步通知 Actor 清理 active_tasks(无需等待结果)
191        if success {
192            let _ = self.command_tx
193                .send(ServiceCommand::RemoveTasks { 
194                    task_ids: vec![task_id] 
195                })
196                .await;
197        }
198        
199        Ok(success)
200    }
201
202    /// 批量取消任务
203    ///
204    /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
205    ///
206    /// # 参数
207    /// - `task_ids`: 要取消的任务 ID 列表
208    ///
209    /// # 返回
210    /// 成功取消的任务数量
211    ///
212    /// # 示例
213    /// ```no_run
214    /// # use kestrel_protocol_timer::{TimerWheel, TimerService};
215    /// # use std::time::Duration;
216    /// # #[tokio::main]
217    /// # async fn main() {
218    /// let timer = TimerWheel::with_defaults().unwrap();
219    /// let service = timer.create_service();
220    /// 
221    /// let callbacks: Vec<_> = (0..10)
222    ///     .map(|_| (Duration::from_secs(10), || async {}))
223    ///     .collect();
224    /// let task_ids = service.schedule_once_batch(callbacks).await.unwrap();
225    /// 
226    /// // 批量取消
227    /// let cancelled = service.cancel_batch(&task_ids).await;
228    /// println!("成功取消 {} 个任务", cancelled);
229    /// # }
230    /// ```
231    pub async fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
232        if task_ids.is_empty() {
233            return 0;
234        }
235
236        // 直接使用底层的批量取消
237        let cancelled_count = {
238            let mut wheel = self.wheel.lock();
239            wheel.cancel_batch(task_ids)
240        };
241
242        // 使用批量移除命令,一次性发送所有需要移除的任务ID
243        let _ = self.command_tx
244            .send(ServiceCommand::RemoveTasks { 
245                task_ids: task_ids.to_vec() 
246            })
247            .await;
248
249        cancelled_count
250    }
251
252    /// 调度一次性定时器
253    ///
254    /// 创建定时器并自动添加到服务管理中,无需手动调用 add_timer_handle
255    ///
256    /// # 参数
257    /// - `delay`: 延迟时间
258    /// - `callback`: 实现了 TimerCallback trait 的回调对象
259    ///
260    /// # 返回
261    /// - `Ok(TaskId)`: 成功调度,返回任务ID
262    /// - `Err(TimerError)`: 调度失败
263    ///
264    /// # 示例
265    /// ```no_run
266    /// # use kestrel_protocol_timer::TimerWheel;
267    /// # use std::time::Duration;
268    /// # #[tokio::main]
269    /// # async fn main() {
270    /// let timer = TimerWheel::with_defaults().unwrap();
271    /// let mut service = timer.create_service();
272    /// 
273    /// let task_id = service.schedule_once(Duration::from_millis(100), || async {
274    ///     println!("Timer fired!");
275    /// }).await.unwrap();
276    /// 
277    /// println!("Scheduled task: {:?}", task_id);
278    /// # }
279    /// ```
280    pub async fn schedule_once<C>(&self, delay: Duration, callback: C) -> Result<TaskId, TimerError>
281    where
282        C: TimerCallback,
283    {
284        // 创建任务并获取句柄
285        let handle = self.create_timer_handle(delay, Some(Arc::new(callback)))?;
286        let task_id = handle.task_id();
287        
288        // 自动添加到服务管理
289        self.add_timer_handle(handle).await?;
290        
291        Ok(task_id)
292    }
293
294    /// 批量调度一次性定时器
295    ///
296    /// 批量创建定时器并自动添加到服务管理中
297    ///
298    /// # 参数
299    /// - `callbacks`: (延迟时间, 回调) 的元组列表
300    ///
301    /// # 返回
302    /// - `Ok(Vec<TaskId>)`: 成功调度,返回所有任务ID
303    /// - `Err(TimerError)`: 调度失败
304    ///
305    /// # 示例
306    /// ```no_run
307    /// # use kestrel_protocol_timer::TimerWheel;
308    /// # use std::time::Duration;
309    /// # #[tokio::main]
310    /// # async fn main() {
311    /// let timer = TimerWheel::with_defaults().unwrap();
312    /// let mut service = timer.create_service();
313    /// 
314    /// let callbacks: Vec<_> = (0..3)
315    ///     .map(|i| (Duration::from_millis(100 * (i + 1)), move || async move {
316    ///         println!("Timer {} fired!", i);
317    ///     }))
318    ///     .collect();
319    /// 
320    /// let task_ids = service.schedule_once_batch(callbacks).await.unwrap();
321    /// println!("Scheduled {} tasks", task_ids.len());
322    /// # }
323    /// ```
324    pub async fn schedule_once_batch<C>(&self, callbacks: Vec<(Duration, C)>) -> Result<Vec<TaskId>, TimerError>
325    where
326        C: TimerCallback,
327    {
328        // 创建批量任务并获取句柄
329        let batch_handle = self.create_batch_handle(callbacks)?;
330        let task_ids = batch_handle.task_ids().to_vec();
331        
332        // 自动添加到服务管理
333        self.add_batch_handle(batch_handle).await?;
334        
335        Ok(task_ids)
336    }
337
338    /// 调度一次性通知定时器(无回调,仅通知)
339    ///
340    /// 创建仅通知的定时器并自动添加到服务管理中
341    ///
342    /// # 参数
343    /// - `delay`: 延迟时间
344    ///
345    /// # 返回
346    /// - `Ok(TaskId)`: 成功调度,返回任务ID
347    /// - `Err(TimerError)`: 调度失败
348    ///
349    /// # 示例
350    /// ```no_run
351    /// # use kestrel_protocol_timer::TimerWheel;
352    /// # use std::time::Duration;
353    /// # #[tokio::main]
354    /// # async fn main() {
355    /// let timer = TimerWheel::with_defaults().unwrap();
356    /// let mut service = timer.create_service();
357    /// 
358    /// let task_id = service.schedule_once_notify(Duration::from_millis(100)).await.unwrap();
359    /// println!("Scheduled notify task: {:?}", task_id);
360    /// 
361    /// // 可以通过 timeout_receiver 接收超时通知
362    /// # }
363    /// ```
364    pub async fn schedule_once_notify(&self, delay: Duration) -> Result<TaskId, TimerError> {
365        // 创建无回调任务并获取句柄
366        let handle = self.create_timer_handle(delay, None)?;
367        let task_id = handle.task_id();
368        
369        // 自动添加到服务管理
370        self.add_timer_handle(handle).await?;
371        
372        Ok(task_id)
373    }
374
375    /// 内部方法:创建定时器句柄
376    fn create_timer_handle(
377        &self,
378        delay: Duration,
379        callback: Option<CallbackWrapper>,
380    ) -> Result<TimerHandle, TimerError> {
381        crate::timer::TimerWheel::create_timer_handle_internal(
382            &self.wheel,
383            delay,
384            callback
385        )
386    }
387
388    /// 内部方法:创建批量定时器句柄
389    fn create_batch_handle<C>(
390        &self,
391        callbacks: Vec<(Duration, C)>,
392    ) -> Result<BatchHandle, TimerError>
393    where
394        C: TimerCallback,
395    {
396        crate::timer::TimerWheel::create_batch_handle_internal(
397            &self.wheel,
398            callbacks
399        )
400    }
401
402    /// 优雅关闭 TimerService
403    ///
404    /// # 示例
405    /// ```no_run
406    /// # use kestrel_protocol_timer::TimerWheel;
407    /// # #[tokio::main]
408    /// # async fn main() {
409    /// let timer = TimerWheel::with_defaults().unwrap();
410    /// let mut service = timer.create_service();
411    /// 
412    /// // 使用 service...
413    /// 
414    /// service.shutdown().await;
415    /// # }
416    /// ```
417    pub async fn shutdown(mut self) {
418        let _ = self.command_tx.send(ServiceCommand::Shutdown).await;
419        if let Some(handle) = self.actor_handle.take() {
420            let _ = handle.await;
421        }
422    }
423}
424
425
426impl Drop for TimerService {
427    fn drop(&mut self) {
428        if let Some(handle) = self.actor_handle.take() {
429            handle.abort();
430        }
431    }
432}
433
434/// ServiceActor - 内部 Actor 实现
435struct ServiceActor {
436    /// 命令接收端
437    command_rx: mpsc::Receiver<ServiceCommand>,
438    /// 超时发送端
439    timeout_tx: mpsc::Sender<TaskId>,
440    /// 活跃任务ID集合(使用 FxHashSet 提升性能)
441    active_tasks: FxHashSet<TaskId>,
442}
443
444impl ServiceActor {
445    fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>) -> Self {
446        Self {
447            command_rx,
448            timeout_tx,
449            active_tasks: FxHashSet::default(),
450        }
451    }
452
453    async fn run(mut self) {
454        // 使用 FuturesUnordered 来监听所有的 completion_rxs
455        // 每个 future 返回 (TaskId, Result)
456        let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
457
458        loop {
459            tokio::select! {
460                // 监听超时事件
461                Some((task_id, _result)) = futures.next() => {
462                    // 任务超时,转发 TaskId
463                    let _ = self.timeout_tx.send(task_id).await;
464                    // 从活跃任务集合中移除该任务
465                    self.active_tasks.remove(&task_id);
466                    // 任务会自动从 FuturesUnordered 中移除
467                }
468                
469                // 监听命令
470                Some(cmd) = self.command_rx.recv() => {
471                    match cmd {
472                        ServiceCommand::AddBatchHandle(batch) => {
473                            let BatchHandle {
474                                task_ids,
475                                completion_rxs,
476                                ..
477                            } = batch;
478                            
479                            // 将所有任务添加到 futures 和 active_tasks 中
480                            for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
481                                // 记录到活跃任务集合
482                                self.active_tasks.insert(task_id);
483                                
484                                let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
485                                    (task_id, rx.await)
486                                });
487                                futures.push(future);
488                            }
489                        }
490                        ServiceCommand::AddTimerHandle(handle) => {
491                            let TimerHandle{
492                                task_id,
493                                completion_rx,
494                                ..
495                            } = handle;
496                            
497                            // 记录到活跃任务集合
498                            self.active_tasks.insert(task_id);
499                            
500                            // 添加到 futures 中
501                            let future: BoxFuture<'static, (TaskId, Result<(), tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
502                                (task_id, completion_rx.0.await)
503                            });
504                            futures.push(future);
505                        }
506                        ServiceCommand::RemoveTasks { task_ids } => {
507                            // 批量从活跃任务集合中移除任务
508                            // 用于直接取消后的清理工作
509                            for task_id in task_ids {
510                                self.active_tasks.remove(&task_id);
511                            }
512                        }
513                        ServiceCommand::Shutdown => {
514                            break;
515                        }
516                    }
517                }
518                
519                // 如果没有任何 future 且命令通道已关闭,退出循环
520                else => {
521                    break;
522                }
523            }
524        }
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TimerWheel;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a single notification.
    const NOTIFY_WAIT: Duration = Duration::from_millis(200);

    /// Awaits the next timeout notification on `rx`.
    ///
    /// Panics if nothing arrives within `NOTIFY_WAIT` or the channel is closed.
    async fn recv_notification(rx: &mut mpsc::Receiver<TaskId>) -> TaskId {
        tokio::time::timeout(NOTIFY_WAIT, rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value")
    }

    #[tokio::test]
    async fn test_service_creation() {
        let timer = TimerWheel::with_defaults().unwrap();
        let _service = timer.create_service();
    }

    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults().unwrap();
        let mut service = timer.create_service();

        // Schedule on the wheel directly, then hand the handle to the service.
        let handle = timer.schedule_once(Duration::from_millis(50), || async {}).await.unwrap();
        let task_id = handle.task_id();
        service.add_timer_handle(handle).await.unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_notification(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let timer = TimerWheel::with_defaults().unwrap();
        let service = timer.create_service();

        // Register timers that are nowhere near expiring...
        let _task_id1 = service.schedule_once(Duration::from_secs(10), || async {}).await.unwrap();
        let _task_id2 = service.schedule_once(Duration::from_secs(10), || async {}).await.unwrap();

        // ...and shut down without waiting for them.
        service.shutdown().await;
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let timer = TimerWheel::with_defaults().unwrap();
        let service = timer.create_service();

        // A long-running timer that cannot fire during the test.
        let handle = timer.schedule_once(Duration::from_secs(10), || async {}).await.unwrap();
        let task_id = handle.task_id();
        service.add_timer_handle(handle).await.unwrap();

        let cancelled = service.cancel_task(task_id).await.unwrap();
        assert!(cancelled, "Task should be cancelled successfully");

        // A second cancellation of the same task must report `false`.
        let cancelled_again = service.cancel_task(task_id).await.unwrap();
        assert!(!cancelled_again, "Task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let timer = TimerWheel::with_defaults().unwrap();
        let service = timer.create_service();

        // Give the service a real task so the lookup is not trivially empty.
        let handle = timer.schedule_once(Duration::from_millis(50), || async {}).await.unwrap();
        service.add_timer_handle(handle).await.unwrap();

        // An id the wheel has never seen.
        let fake_task_id = TaskId::new();
        let cancelled = service.cancel_task(fake_task_id).await.unwrap();
        assert!(!cancelled, "Nonexistent task should not be cancelled");
    }

    #[tokio::test]
    async fn test_cancel_after_timeout_returns_false() {
        let timer = TimerWheel::with_defaults().unwrap();
        let mut service = timer.create_service();

        // A short timer that fires during the test.
        let handle = timer.schedule_once(Duration::from_millis(50), || async {}).await.unwrap();
        let task_id = handle.task_id();
        service.add_timer_handle(handle).await.unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_notification(&mut rx).await, task_id);

        // Leave the actor a moment to finish its bookkeeping.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // A task that already fired can no longer be cancelled.
        let cancelled = service.cancel_task(task_id).await.unwrap();
        assert!(!cancelled, "Timed out task should not exist anymore");
    }

    #[tokio::test]
    async fn test_cancel_task_prevents_callback_execution() {
        let timer = TimerWheel::with_defaults().unwrap();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        let handle = timer.schedule_once(
            Duration::from_secs(10),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        ).await.unwrap();
        let task_id = handle.task_id();
        service.add_timer_handle(handle).await.unwrap();

        // `cancel_task` cancels on the wheel directly, without waiting for the actor.
        let cancelled = service.cancel_task(task_id).await.unwrap();
        assert!(cancelled, "Task should be cancelled successfully");

        // Give the runtime ample opportunity to (wrongly) run the callback.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");

        // The wheel no longer knows the task, so a second cancel fails.
        let cancelled_again = service.cancel_task(task_id).await.unwrap();
        assert!(!cancelled_again, "Task should no longer be pending in the wheel");
    }

    #[tokio::test]
    async fn test_schedule_once_direct() {
        let timer = TimerWheel::with_defaults().unwrap();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        // Schedule through the service itself.
        let counter_clone = Arc::clone(&counter);
        let task_id = service.schedule_once(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        ).await.unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_notification(&mut rx).await, task_id);

        // The callback runs independently of the notification; give it time.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_schedule_once_batch_direct() {
        let timer = TimerWheel::with_defaults().unwrap();
        let mut service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..3)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_millis(50), move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let task_ids = service.schedule_once_batch(callbacks).await.unwrap();
        assert_eq!(task_ids.len(), 3);

        // Every scheduled task must be reported exactly once.
        let mut rx = service.take_receiver().unwrap();
        let mut received = HashSet::new();
        for _ in 0..task_ids.len() {
            received.insert(recv_notification(&mut rx).await);
        }
        let expected: HashSet<_> = task_ids.iter().copied().collect();
        assert_eq!(received, expected);

        // Give the callbacks time to run.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_schedule_once_notify_direct() {
        let timer = TimerWheel::with_defaults().unwrap();
        let mut service = timer.create_service();

        // Notify-only timer: no callback, just the notification.
        let task_id = service.schedule_once_notify(Duration::from_millis(50)).await.unwrap();

        let mut rx = service.take_receiver().unwrap();
        assert_eq!(recv_notification(&mut rx).await, task_id);
    }

    #[tokio::test]
    async fn test_schedule_and_cancel_direct() {
        let timer = TimerWheel::with_defaults().unwrap();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let counter_clone = Arc::clone(&counter);
        let task_id = service.schedule_once(
            Duration::from_secs(10),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        ).await.unwrap();

        // Cancel right away.
        let cancelled = service.cancel_task(task_id).await.unwrap();
        assert!(cancelled, "Task should be cancelled successfully");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_direct() {
        let timer = TimerWheel::with_defaults().unwrap();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_secs(10), move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let task_ids = service.schedule_once_batch(callbacks).await.unwrap();
        assert_eq!(task_ids.len(), 10);

        let cancelled = service.cancel_batch(&task_ids).await;
        assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
    }

    #[tokio::test]
    async fn test_cancel_batch_partial() {
        let timer = TimerWheel::with_defaults().unwrap();
        let service = timer.create_service();
        let counter = Arc::new(AtomicU32::new(0));

        let callbacks: Vec<_> = (0..10)
            .map(|_| {
                let counter = Arc::clone(&counter);
                (Duration::from_secs(10), move || {
                    let counter = Arc::clone(&counter);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let task_ids = service.schedule_once_batch(callbacks).await.unwrap();
        let (first_half, second_half) = task_ids.split_at(5);

        // Cancel only the first five.
        let cancelled = service.cancel_batch(first_half).await;
        assert_eq!(cancelled, 5, "5 tasks should be cancelled");

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");

        // The untouched half must still be pending in the wheel.
        let cancelled_rest = service.cancel_batch(second_half).await;
        assert_eq!(cancelled_rest, 5, "The other 5 tasks should still be pending");
    }

    #[tokio::test]
    async fn test_cancel_batch_empty() {
        let timer = TimerWheel::with_defaults().unwrap();
        let service = timer.create_service();

        // An empty request cancels nothing.
        let empty: Vec<TaskId> = vec![];
        let cancelled = service.cancel_batch(&empty).await;
        assert_eq!(cancelled, 0, "No tasks should be cancelled");
    }
}
872