kestrel_timer/service.rs
1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13/// Service command type
14///
15/// 服务命令类型
16enum ServiceCommand {
17 /// Add batch timer handle, only contains necessary data: task_ids and completion_rxs
18 ///
19 /// 添加批量定时器句柄,仅包含必要数据: task_ids 和 completion_rxs
20 AddBatchHandle {
21 task_ids: Vec<TaskId>,
22 completion_rxs: Vec<tokio::sync::oneshot::Receiver<TaskCompletionReason>>,
23 },
24 /// Add single timer handle, only contains necessary data: task_id and completion_rx
25 ///
26 /// 添加单个定时器句柄,仅包含必要数据: task_id 和 completion_rx
27 AddTimerHandle {
28 task_id: TaskId,
29 completion_rx: tokio::sync::oneshot::Receiver<TaskCompletionReason>,
30 },
31}
32
33/// TimerService - timer service based on Actor pattern
34/// Manages multiple timer handles, listens to all timeout events, and aggregates TaskId to be forwarded to the user.
35/// # Features
36/// - Automatically listens to all added timer handles' timeout events
37/// - Automatically removes the task from internal management after timeout
38/// - Aggregates TaskId to be forwarded to the user's unified channel
39/// - Supports dynamic addition of BatchHandle and TimerHandle
40///
41///
42/// # 定时器服务,基于 Actor 模式管理多个定时器句柄,监听所有超时事件,并将 TaskId 聚合转发给用户。
43/// - 自动监听所有添加的定时器句柄的超时事件
44/// - 自动在超时后从内部管理中移除任务
45/// - 将 TaskId 聚合转发给用户
46/// - 支持动态添加 BatchHandle 和 TimerHandle
47///
48/// # Examples (示例)
49/// ```no_run
50/// use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
51/// use std::time::Duration;
52///
53/// #[tokio::main]
54/// async fn main() {
55/// let timer = TimerWheel::with_defaults();
56/// let mut service = timer.create_service(ServiceConfig::default());
57///
58/// // Use two-step API to batch schedule timers through service
59/// // (使用两步 API 通过服务批量调度定时器)
60/// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
61/// .map(|i| {
62/// let callback = Some(CallbackWrapper::new(move || async move {
63/// println!("Timer {} fired!", i);
64/// }));
65/// (Duration::from_millis(100), callback)
66/// })
67/// .collect();
68/// let tasks = TimerService::create_batch_with_callbacks(callbacks);
69/// service.register_batch(tasks).unwrap();
70///
71/// // Receive timeout notifications (接收超时通知)
72/// let mut rx = service.take_receiver().unwrap();
73/// while let Some(task_id) = rx.recv().await {
74/// // Receive timeout notification (接收超时通知)
75/// println!("Task {:?} completed", task_id);
76/// }
77/// }
78/// ```
79pub struct TimerService {
80 /// Command sender
81 ///
82 /// 命令发送器
83 command_tx: mpsc::Sender<ServiceCommand>,
84 /// Timeout receiver
85 ///
86 /// 超时接收器
87 timeout_rx: Option<mpsc::Receiver<TaskId>>,
88 /// Actor task handle
89 ///
90 /// Actor 任务句柄
91 actor_handle: Option<JoinHandle<()>>,
92 /// Timing wheel reference (for direct scheduling of timers)
93 ///
94 /// 时间轮引用 (用于直接调度定时器)
95 wheel: Arc<Mutex<Wheel>>,
96 /// Actor shutdown signal sender
97 ///
98 /// Actor 关闭信号发送器
99 shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
100}
101
102impl TimerService {
103 /// Create new TimerService
104 ///
105 /// # Parameters
106 /// - `wheel`: Timing wheel reference
107 /// - `config`: Service configuration
108 ///
109 /// # Notes
110 /// Typically not called directly, but used to create through `TimerWheel::create_service()`
111 ///
112 /// 创建新的 TimerService
113 ///
114 /// # 参数
115 /// - `wheel`: 时间轮引用
116 /// - `config`: 服务配置
117 ///
118 /// # 注意
119 /// 通常不直接调用,而是通过 `TimerWheel::create_service()` 创建
120 ///
121 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
122 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
123 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
124
125 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
126 let actor = ServiceActor::new(command_rx, timeout_tx, shutdown_rx);
127 let actor_handle = tokio::spawn(async move {
128 actor.run().await;
129 });
130
131 Self {
132 command_tx,
133 timeout_rx: Some(timeout_rx),
134 actor_handle: Some(actor_handle),
135 wheel,
136 shutdown_tx: Some(shutdown_tx),
137 }
138 }
139
140 /// Get timeout receiver (transfer ownership)
141 ///
142 /// # Returns
143 /// Timeout notification receiver, if already taken, returns None
144 ///
145 /// # Notes
146 /// This method can only be called once, because it transfers ownership of the receiver
147 ///
148 /// 获取超时通知接收器 (转移所有权)
149 ///
150 /// # 返回值
151 /// 超时通知接收器,如果已经取走,返回 None
152 ///
153 /// # 注意
154 /// 此方法只能调用一次,因为它转移了接收器的所有权
155 ///
156 /// # Examples (示例)
157 /// ```no_run
158 /// # use kestrel_timer::{TimerWheel, ServiceConfig};
159 /// # #[tokio::main]
160 /// # async fn main() {
161 /// let timer = TimerWheel::with_defaults();
162 /// let mut service = timer.create_service(ServiceConfig::default());
163 ///
164 /// let mut rx = service.take_receiver().unwrap();
165 /// while let Some(task_id) = rx.recv().await {
166 /// println!("Task {:?} timed out", task_id);
167 /// }
168 /// # }
169 /// ```
170 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
171 self.timeout_rx.take()
172 }
173
174 /// Cancel specified task
175 ///
176 /// # Parameters
177 /// - `task_id`: Task ID to cancel
178 ///
179 /// # Returns
180 /// - `true`: Task exists and cancellation is successful
181 /// - `false`: Task does not exist or cancellation fails
182 ///
183 /// 取消指定任务
184 ///
185 /// # 参数
186 /// - `task_id`: 任务 ID
187 ///
188 /// # 返回值
189 /// - `true`: 任务存在且取消成功
190 /// - `false`: 任务不存在或取消失败
191 ///
192 /// # Examples (示例)
193 /// ```no_run
194 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
195 /// # use std::time::Duration;
196 /// #
197 /// # #[tokio::main]
198 /// # async fn main() {
199 /// let timer = TimerWheel::with_defaults();
200 /// let service = timer.create_service(ServiceConfig::default());
201 ///
202 /// // Use two-step API to schedule timers
203 /// let callback = Some(CallbackWrapper::new(|| async move {
204 /// println!("Timer fired!"); // 定时器触发
205 /// }));
206 /// let task = TimerService::create_task(Duration::from_secs(10), callback);
207 /// let task_id = task.get_id();
208 /// service.register(task).unwrap(); // 注册定时器
209 ///
210 /// // Cancel task
211 /// let cancelled = service.cancel_task(task_id);
212 /// println!("Task cancelled: {}", cancelled); // 任务取消
213 /// # }
214 /// ```
215 #[inline]
216 pub fn cancel_task(&self, task_id: TaskId) -> bool {
217 // Direct cancellation, no need to notify Actor
218 // FuturesUnordered will automatically clean up when tasks are cancelled
219 // 直接取消,无需通知 Actor
220 // FuturesUnordered 将在任务取消时自动清理
221 let mut wheel = self.wheel.lock();
222 wheel.cancel(task_id)
223 }
224
225 /// Batch cancel tasks
226 ///
227 /// Use underlying batch cancellation operation to cancel multiple tasks at once, performance is better than calling cancel_task repeatedly.
228 ///
229 /// # Parameters
230 /// - `task_ids`: List of task IDs to cancel
231 ///
232 /// # Returns
233 /// Number of successfully cancelled tasks
234 ///
235 /// 批量取消任务
236 ///
237 /// # 参数
238 /// - `task_ids`: 任务 ID 列表
239 ///
240 /// # 返回值
241 /// 成功取消的任务数量
242 ///
243 /// # Examples (示例)
244 /// ```no_run
245 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
246 /// # use std::time::Duration;
247 /// #
248 /// # #[tokio::main]
249 /// # async fn main() {
250 /// let timer = TimerWheel::with_defaults();
251 /// let service = timer.create_service(ServiceConfig::default());
252 ///
253 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..10)
254 /// .map(|i| {
255 /// let callback = Some(CallbackWrapper::new(move || async move {
256 /// println!("Timer {} fired!", i); // 定时器触发
257 /// }));
258 /// (Duration::from_secs(10), callback)
259 /// })
260 /// .collect();
261 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
262 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
263 /// service.register_batch(tasks).unwrap(); // 注册定时器
264 ///
265 /// // Batch cancel
266 /// let cancelled = service.cancel_batch(&task_ids);
267 /// println!("Cancelled {} tasks", cancelled); // 任务取消
268 /// # }
269 /// ```
270 #[inline]
271 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
272 if task_ids.is_empty() {
273 return 0;
274 }
275
276 // Direct batch cancellation, no need to notify Actor
277 // FuturesUnordered will automatically clean up when tasks are cancelled
278 // 直接批量取消,无需通知 Actor
279 // FuturesUnordered 将在任务取消时自动清理
280 let mut wheel = self.wheel.lock();
281 wheel.cancel_batch(task_ids)
282 }
283
284 /// Postpone task (replace callback)
285 ///
286 /// # Parameters
287 /// - `task_id`: Task ID to postpone
288 /// - `new_delay`: New delay time (recalculated from current time point)
289 /// - `callback`: New callback function
290 ///
291 /// # Returns
292 /// - `true`: Task exists and is successfully postponed
293 /// - `false`: Task does not exist or postponement fails
294 ///
295 /// # Notes
296 /// - Task ID remains unchanged after postponement
297 /// - Original timeout notification remains valid
298 /// - Callback function will be replaced with new callback
299 ///
300 /// 推迟任务 (替换回调)
301 ///
302 /// # 参数
303 /// - `task_id`: 任务 ID
304 /// - `new_delay`: 新的延迟时间 (从当前时间点重新计算)
305 /// - `callback`: 新的回调函数
306 ///
307 /// # 返回值
308 /// - `true`: 任务存在且延期成功
309 /// - `false`: 任务不存在或延期失败
310 ///
311 /// # 注意
312 /// - 任务 ID 在延期后保持不变
313 /// - 原始超时通知保持有效
314 /// - 回调函数将被新的回调函数替换
315 ///
316 /// # Examples (示例)
317 /// ```no_run
318 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
319 /// # use std::time::Duration;
320 /// #
321 /// # #[tokio::main]
322 /// # async fn main() {
323 /// let timer = TimerWheel::with_defaults();
324 /// let service = timer.create_service(ServiceConfig::default());
325 ///
326 /// let callback = Some(CallbackWrapper::new(|| async {
327 /// println!("Original callback"); // 原始回调
328 /// }));
329 /// let task = TimerService::create_task(Duration::from_secs(5), callback);
330 /// let task_id = task.get_id();
331 /// service.register(task).unwrap(); // 注册定时器
332 ///
333 /// // Postpone and replace callback (延期并替换回调)
334 /// let new_callback = Some(CallbackWrapper::new(|| async { println!("New callback!"); }));
335 /// let success = service.postpone(
336 /// task_id,
337 /// Duration::from_secs(10),
338 /// new_callback
339 /// );
340 /// println!("Postponed successfully: {}", success);
341 /// # }
342 /// ```
343 #[inline]
344 pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
345 let mut wheel = self.wheel.lock();
346 wheel.postpone(task_id, new_delay, callback)
347 }
348
349 /// Batch postpone tasks (keep original callbacks)
350 ///
351 /// # Parameters
352 /// - `updates`: List of tuples of (task ID, new delay)
353 ///
354 /// # Returns
355 /// Number of successfully postponed tasks
356 ///
357 /// 批量延期任务 (保持原始回调)
358 ///
359 /// # 参数
360 /// - `updates`: (任务 ID, 新延迟) 元组列表
361 ///
362 /// # 返回值
363 /// 成功延期的任务数量
364 ///
365 /// # Examples (示例)
366 /// ```no_run
367 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
368 /// # use std::time::Duration;
369 /// #
370 /// # #[tokio::main]
371 /// # async fn main() {
372 /// let timer = TimerWheel::with_defaults();
373 /// let service = timer.create_service(ServiceConfig::default());
374 ///
375 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
376 /// .map(|i| {
377 /// let callback = Some(CallbackWrapper::new(move || async move {
378 /// println!("Timer {} fired!", i);
379 /// }));
380 /// (Duration::from_secs(5), callback)
381 /// })
382 /// .collect();
383 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
384 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
385 /// service.register_batch(tasks).unwrap();
386 ///
387 /// // Batch postpone (keep original callbacks)
388 /// // 批量延期任务 (保持原始回调)
389 /// let updates: Vec<_> = task_ids
390 /// .into_iter()
391 /// .map(|id| (id, Duration::from_secs(10)))
392 /// .collect();
393 /// let postponed = service.postpone_batch(updates);
394 /// println!("Postponed {} tasks", postponed);
395 /// # }
396 /// ```
397 #[inline]
398 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
399 if updates.is_empty() {
400 return 0;
401 }
402
403 let mut wheel = self.wheel.lock();
404 wheel.postpone_batch(updates)
405 }
406
407 /// Batch postpone tasks (replace callbacks)
408 ///
409 /// # Parameters
410 /// - `updates`: List of tuples of (task ID, new delay, new callback)
411 ///
412 /// # Returns
413 /// Number of successfully postponed tasks
414 ///
415 /// 批量延期任务 (替换回调)
416 ///
417 /// # 参数
418 /// - `updates`: (任务 ID, 新延迟, 新回调) 元组列表
419 ///
420 /// # 返回值
421 /// 成功延期的任务数量
422 ///
423 /// # Examples (示例)
424 /// ```no_run
425 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
426 /// # use std::time::Duration;
427 /// #
428 /// # #[tokio::main]
429 /// # async fn main() {
430 /// let timer = TimerWheel::with_defaults();
431 /// let service = timer.create_service(ServiceConfig::default());
432 ///
433 /// // Create 3 tasks, initially no callbacks
434 /// // 创建 3 个任务,最初没有回调
435 /// let delays: Vec<Duration> = (0..3)
436 /// .map(|_| Duration::from_secs(5))
437 /// .collect();
438 /// let tasks = TimerService::create_batch(delays);
439 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
440 /// service.register_batch(tasks).unwrap();
441 ///
442 /// // Batch postpone and add new callbacks
443 /// // 批量延期并添加新的回调
444 /// let updates: Vec<_> = task_ids
445 /// .into_iter()
446 /// .enumerate()
447 /// .map(|(i, id)| {
448 /// let callback = Some(CallbackWrapper::new(move || async move {
449 /// println!("New callback {}", i);
450 /// }));
451 /// (id, Duration::from_secs(10), callback)
452 /// })
453 /// .collect();
454 /// let postponed = service.postpone_batch_with_callbacks(updates);
455 /// println!("Postponed {} tasks", postponed);
456 /// # }
457 /// ```
458 #[inline]
459 pub fn postpone_batch_with_callbacks(
460 &self,
461 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
462 ) -> usize {
463 if updates.is_empty() {
464 return 0;
465 }
466
467 let mut wheel = self.wheel.lock();
468 wheel.postpone_batch_with_callbacks(updates)
469 }
470
471 /// Create timer task (static method, apply stage)
472 ///
473 /// # Parameters
474 /// - `delay`: Delay time
475 /// - `callback`: Callback object implementing TimerCallback trait
476 ///
477 /// # Returns
478 /// Return TimerTask, needs to be registered through `register()`
479 ///
480 /// 创建定时器任务 (静态方法,应用阶段)
481 ///
482 /// # 参数
483 /// - `delay`: 延迟时间
484 /// - `callback`: 回调对象,实现 TimerCallback 特质
485 ///
486 /// # 返回值
487 /// 返回 TimerTask,需要通过 `register()` 注册
488 ///
489 /// # Examples (示例)
490 /// ```no_run
491 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
492 /// # use std::time::Duration;
493 /// #
494 /// # #[tokio::main]
495 /// # async fn main() {
496 /// let timer = TimerWheel::with_defaults();
497 /// let service = timer.create_service(ServiceConfig::default());
498 ///
499 /// // Step 1: create task
500 /// // 创建任务
501 /// let callback = Some(CallbackWrapper::new(|| async {
502 /// println!("Timer fired!");
503 /// }));
504 /// let task = TimerService::create_task(Duration::from_millis(100), callback);
505 ///
506 /// let task_id = task.get_id();
507 /// println!("Created task: {:?}", task_id);
508 ///
509 /// // Step 2: register task
510 /// // 注册任务
511 /// service.register(task).unwrap();
512 /// # }
513 /// ```
514 #[inline]
515 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
516 crate::timer::TimerWheel::create_task(delay, callback)
517 }
518
519 /// Create batch of timer tasks (static method, apply stage, no callbacks)
520 ///
521 /// # Parameters
522 /// - `delays`: List of delay times
523 ///
524 /// # Returns
525 /// Return TimerTask list, needs to be registered through `register_batch()`
526 ///
527 /// 创建定时器任务 (静态方法,应用阶段,没有回调)
528 /// # 参数
529 /// - `delays`: 延迟时间列表
530 ///
531 /// # 返回值
532 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
533 ///
534 /// # Examples (示例)
535 /// ```no_run
536 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
537 /// # use std::time::Duration;
538 /// #
539 /// # #[tokio::main]
540 /// # async fn main() {
541 /// let timer = TimerWheel::with_defaults();
542 /// let service = timer.create_service(ServiceConfig::default());
543 ///
544 /// // Step 1: create batch of tasks
545 /// // 创建批量任务
546 /// let delays: Vec<Duration> = (0..3)
547 /// .map(|i| Duration::from_millis(100 * (i + 1)))
548 /// .collect();
549 ///
550 /// let tasks = TimerService::create_batch(delays);
551 /// println!("Created {} tasks", tasks.len());
552 ///
553 /// // Step 2: register batch of tasks
554 /// // 注册批量任务
555 /// service.register_batch(tasks).unwrap();
556 /// # }
557 /// ```
558 #[inline]
559 pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask> {
560 crate::timer::TimerWheel::create_batch(delays)
561 }
562
563 /// Create batch of timer tasks (static method, apply stage, with callbacks)
564 ///
565 /// # Parameters
566 /// - `callbacks`: List of tuples of (delay time, callback)
567 ///
568 /// # Returns
569 /// Return TimerTask list, needs to be registered through `register_batch()`
570 ///
571 /// 创建定时器任务 (静态方法,应用阶段,有回调)
572 /// # 参数
573 /// - `callbacks`: (延迟时间, 回调) 元组列表
574 ///
575 /// # 返回值
576 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
577 ///
578 /// # Examples (示例)
579 /// ```no_run
580 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
581 /// # use std::time::Duration;
582 /// #
583 /// # #[tokio::main]
584 /// # async fn main() {
585 /// let timer = TimerWheel::with_defaults();
586 /// let service = timer.create_service(ServiceConfig::default());
587 ///
588 /// // Step 1: create batch of tasks with callbacks
589 /// // 创建批量任务,带有回调
590 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
591 /// .map(|i| {
592 /// let callback = Some(CallbackWrapper::new(move || async move {
593 /// println!("Timer {} fired!", i);
594 /// }));
595 /// (Duration::from_millis(100 * (i + 1)), callback)
596 /// })
597 /// .collect();
598 ///
599 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
600 /// println!("Created {} tasks", tasks.len());
601 ///
602 /// // Step 2: register batch of tasks with callbacks
603 /// // 注册批量任务,带有回调
604 /// service.register_batch(tasks).unwrap();
605 /// # }
606 /// ```
607 #[inline]
608 pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
609 crate::timer::TimerWheel::create_batch_with_callbacks(callbacks)
610 }
611
612 /// Register timer task to service (registration phase)
613 ///
614 /// # Parameters
615 /// - `task`: Task created via `create_task()`
616 ///
617 /// # Returns
618 /// - `Ok(())`: Register successfully
619 /// - `Err(TimerError::RegisterFailed)`: Register failed (internal channel is full or closed)
620 ///
621 /// 注册定时器任务到服务 (注册阶段)
622 /// # 参数
623 /// - `task`: 通过 `create_task()` 创建的任务
624 ///
625 /// # 返回值
626 /// - `Ok(())`: 注册成功
627 /// - `Err(TimerError::RegisterFailed)`: 注册失败 (内部通道已满或关闭)
628 ///
629 /// # Examples (示例)
630 /// ```no_run
631 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
632 /// # use std::time::Duration;
633 /// #
634 /// # #[tokio::main]
635 /// # async fn main() {
636 /// let timer = TimerWheel::with_defaults();
637 /// let service = timer.create_service(ServiceConfig::default());
638 ///
639 /// // Step 1: create task
640 /// // 创建任务
641 /// let callback = Some(CallbackWrapper::new(|| async move {
642 /// println!("Timer fired!");
643 /// }));
644 /// let task = TimerService::create_task(Duration::from_millis(100), callback);
645 /// let task_id = task.get_id();
646 ///
647 /// // Step 2: register task
648 /// // 注册任务
649 /// service.register(task).unwrap();
650 /// # }
651 /// ```
652 #[inline]
653 pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
654 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
655 let notifier = crate::task::CompletionNotifier(completion_tx);
656
657 let task_id = task.id;
658
659 // Single lock, complete all operations
660 // 单次锁定,完成所有操作
661 {
662 let mut wheel_guard = self.wheel.lock();
663 wheel_guard.insert(task, notifier);
664 }
665
666 // Add to service management (only send necessary data)
667 // 添加到服务管理(只发送必要数据)
668 self.command_tx
669 .try_send(ServiceCommand::AddTimerHandle {
670 task_id,
671 completion_rx,
672 })
673 .map_err(|_| TimerError::RegisterFailed)?;
674
675 Ok(())
676 }
677
678 /// Batch register timer tasks to service (registration phase)
679 ///
680 /// # Parameters
681 /// - `tasks`: List of tasks created via `create_batch()`
682 ///
683 /// # Returns
684 /// - `Ok(())`: Register successfully
685 /// - `Err(TimerError::RegisterFailed)`: Register failed (internal channel is full or closed)
686 ///
687 /// 批量注册定时器任务到服务 (注册阶段)
688 /// # 参数
689 /// - `tasks`: 通过 `create_batch()` 创建的任务列表
690 ///
691 /// # 返回值
692 /// - `Ok(())`: 注册成功
693 /// - `Err(TimerError::RegisterFailed)`: 注册失败 (内部通道已满或关闭)
694 ///
695 /// # Examples (示例)
696 /// ```no_run
697 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
698 /// # use std::time::Duration;
699 /// #
700 /// # #[tokio::main]
701 /// # async fn main() {
702 /// let timer = TimerWheel::with_defaults();
703 /// let service = timer.create_service(ServiceConfig::default());
704 ///
705 /// // Step 1: create batch of tasks with callbacks
706 /// // 创建批量任务,带有回调
707 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
708 /// .map(|i| {
709 /// let callback = Some(CallbackWrapper::new(move || async move {
710 /// println!("Timer {} fired!", i);
711 /// }));
712 /// (Duration::from_secs(1), callback)
713 /// })
714 /// .collect();
715 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
716 ///
717 /// // Step 2: register batch of tasks with callbacks
718 /// // 注册批量任务,带有回调
719 /// service.register_batch(tasks).unwrap();
720 /// # }
721 /// ```
722 #[inline]
723 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
724 let task_count = tasks.len();
725 let mut completion_rxs = Vec::with_capacity(task_count);
726 let mut task_ids = Vec::with_capacity(task_count);
727 let mut prepared_tasks = Vec::with_capacity(task_count);
728
729 // Step 1: prepare all channels and notifiers (no lock)
730 // 步骤 1: 准备所有通道和通知器(无锁)
731 for task in tasks {
732 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
733 let notifier = crate::task::CompletionNotifier(completion_tx);
734
735 task_ids.push(task.id);
736 completion_rxs.push(completion_rx);
737 prepared_tasks.push((task, notifier));
738 }
739
740 // Step 2: single lock, batch insert
741 // 步骤 2: 单次锁定,批量插入
742 {
743 let mut wheel_guard = self.wheel.lock();
744 wheel_guard.insert_batch(prepared_tasks);
745 }
746
747 // Add to service management (only send necessary data)
748 // 添加到服务管理(只发送必要数据)
749 self.command_tx
750 .try_send(ServiceCommand::AddBatchHandle {
751 task_ids,
752 completion_rxs,
753 })
754 .map_err(|_| TimerError::RegisterFailed)?;
755
756 Ok(())
757 }
758
759 /// Graceful shutdown of TimerService
760 ///
761 /// 优雅关闭 TimerService
762 ///
763 /// # Examples (示例)
764 /// ```no_run
765 /// # use kestrel_timer::{TimerWheel, ServiceConfig};
766 /// # #[tokio::main]
767 /// # async fn main() {
768 /// let timer = TimerWheel::with_defaults();
769 /// let mut service = timer.create_service(ServiceConfig::default());
770 ///
771 /// // Use service... (使用服务...)
772 ///
773 /// service.shutdown().await;
774 /// # }
775 /// ```
776 pub async fn shutdown(mut self) {
777 if let Some(shutdown_tx) = self.shutdown_tx.take() {
778 let _ = shutdown_tx.send(());
779 }
780 if let Some(handle) = self.actor_handle.take() {
781 let _ = handle.await;
782 }
783 }
784}
785
786
787impl Drop for TimerService {
788 fn drop(&mut self) {
789 if let Some(handle) = self.actor_handle.take() {
790 handle.abort();
791 }
792 }
793}
794
795/// ServiceActor - internal Actor implementation
796///
797/// ServiceActor - 内部 Actor 实现
798struct ServiceActor {
799 /// Command receiver
800 ///
801 /// 命令接收器
802 command_rx: mpsc::Receiver<ServiceCommand>,
803 /// Timeout sender
804 ///
805 /// 超时发送器
806 timeout_tx: mpsc::Sender<TaskId>,
807 /// Actor shutdown signal receiver
808 ///
809 /// Actor 关闭信号接收器
810 shutdown_rx: tokio::sync::oneshot::Receiver<()>,
811}
812
813impl ServiceActor {
814 /// Create new ServiceActor
815 ///
816 /// 创建新的 ServiceActor
817 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>, shutdown_rx: tokio::sync::oneshot::Receiver<()>) -> Self {
818 Self {
819 command_rx,
820 timeout_tx,
821 shutdown_rx,
822 }
823 }
824
825 /// Run Actor event loop
826 ///
827 /// 运行 Actor 事件循环
828 async fn run(mut self) {
829 // Use FuturesUnordered to listen to all completion_rxs
830 // Each future returns (TaskId, Result)
831 // 使用 FuturesUnordered 监听所有 completion_rxs
832 // 每个 future 返回 (TaskId, Result)
833 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
834
835 // Move shutdown_rx out of self, so it can be used in select! with &mut
836 // 将 shutdown_rx 从 self 中移出,以便在 select! 中使用 &mut
837 let mut shutdown_rx = self.shutdown_rx;
838
839 loop {
840 tokio::select! {
841 // Listen to high-priority shutdown signal
842 // 监听高优先级关闭信号
843 _ = &mut shutdown_rx => {
844 // Receive shutdown signal, exit loop immediately
845 // 接收到关闭信号,立即退出循环
846 break;
847 }
848
849 // Listen to timeout events
850 // 监听超时事件
851 Some((task_id, result)) = futures.next() => {
852 // Check completion reason, only forward expired (Expired) events, do not forward cancelled (Cancelled) events
853 // 检查完成原因,只转发过期 (Expired) 事件,不转发取消 (Cancelled) 事件
854 if let Ok(TaskCompletionReason::Expired) = result {
855 let _ = self.timeout_tx.send(task_id).await;
856 }
857 // Task will be automatically removed from FuturesUnordered
858 // 任务将自动从 FuturesUnordered 中移除
859 }
860
861 // Listen to commands
862 // 监听命令
863 Some(cmd) = self.command_rx.recv() => {
864 match cmd {
865 ServiceCommand::AddBatchHandle { task_ids, completion_rxs } => {
866 // Add all tasks to futures
867 // 将所有任务添加到 futures
868 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
869 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
870 (task_id, rx.await)
871 });
872 futures.push(future);
873 }
874 }
875 ServiceCommand::AddTimerHandle { task_id, completion_rx } => {
876 // Add to futures
877 // 添加到 futures
878 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
879 (task_id, completion_rx.await)
880 });
881 futures.push(future);
882 }
883 }
884 }
885
886 // If no futures and command channel is closed, exit loop
887 // 如果没有 futures 且命令通道关闭,退出循环
888 else => {
889 break;
890 }
891 }
892 }
893 }
894}
895
896#[cfg(test)]
897mod tests {
898 use super::*;
899 use crate::TimerWheel;
900 use std::sync::atomic::{AtomicU32, Ordering};
901 use std::sync::Arc;
902 use std::time::Duration;
903
904 #[tokio::test]
905 async fn test_service_creation() {
906 let timer = TimerWheel::with_defaults();
907 let _service = timer.create_service(ServiceConfig::default());
908 }
909
910
911 #[tokio::test]
912 async fn test_add_timer_handle_and_receive_timeout() {
913 let timer = TimerWheel::with_defaults();
914 let mut service = timer.create_service(ServiceConfig::default());
915
916 // Create single timer (创建单个定时器)
917 let task = TimerService::create_task(Duration::from_millis(50), Some(CallbackWrapper::new(|| async {})));
918 let task_id = task.get_id();
919
920 // Register to service (注册到服务)
921 service.register(task).unwrap();
922
923 // Receive timeout notification (接收超时通知)
924 let mut rx = service.take_receiver().unwrap();
925 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
926 .await
927 .expect("Should receive timeout notification")
928 .expect("Should receive Some value");
929
930 assert_eq!(received_task_id, task_id);
931 }
932
933
934 #[tokio::test]
935 async fn test_shutdown() {
936 let timer = TimerWheel::with_defaults();
937 let service = timer.create_service(ServiceConfig::default());
938
939 // Add some timers (添加一些定时器)
940 let task1 = TimerService::create_task(Duration::from_secs(10), None);
941 let task2 = TimerService::create_task(Duration::from_secs(10), None);
942 service.register(task1).unwrap();
943 service.register(task2).unwrap();
944
945 // Immediately shutdown (without waiting for timers to trigger) (立即关闭(不等待定时器触发))
946 service.shutdown().await;
947 }
948
949
950
951 #[tokio::test]
952 async fn test_cancel_task() {
953 let timer = TimerWheel::with_defaults();
954 let service = timer.create_service(ServiceConfig::default());
955
956 // Add a long-term timer (添加一个长期定时器)
957 let task = TimerService::create_task(Duration::from_secs(10), None);
958 let task_id = task.get_id();
959
960 service.register(task).unwrap();
961
962 // Cancel task (取消任务)
963 let cancelled = service.cancel_task(task_id);
964 assert!(cancelled, "Task should be cancelled successfully");
965
966 // Try to cancel the same task again, should return false (再次尝试取消同一任务,应返回 false)
967 let cancelled_again = service.cancel_task(task_id);
968 assert!(!cancelled_again, "Task should not exist anymore");
969 }
970
971 #[tokio::test]
972 async fn test_cancel_nonexistent_task() {
973 let timer = TimerWheel::with_defaults();
974 let service = timer.create_service(ServiceConfig::default());
975
976 // Add a timer to initialize service (添加定时器以初始化服务)
977 let task = TimerService::create_task(Duration::from_millis(50), None);
978 service.register(task).unwrap();
979
980 // Try to cancel a nonexistent task (create a task ID that will not actually be registered)
981 // 尝试取消不存在的任务(创建一个实际不会注册的任务 ID)
982 let fake_task = TimerService::create_task(Duration::from_millis(50), None);
983 let fake_task_id = fake_task.get_id();
984 // Do not register fake_task (不注册 fake_task)
985 let cancelled = service.cancel_task(fake_task_id);
986 assert!(!cancelled, "Nonexistent task should not be cancelled");
987 }
988
989
990 #[tokio::test]
991 async fn test_task_timeout_cleans_up_task_sender() {
992 let timer = TimerWheel::with_defaults();
993 let mut service = timer.create_service(ServiceConfig::default());
994
995 // Add a short-term timer (添加短期定时器)
996 let task = TimerService::create_task(Duration::from_millis(50), None);
997 let task_id = task.get_id();
998
999 service.register(task).unwrap();
1000
1001 // Wait for task timeout (等待任务超时)
1002 let mut rx = service.take_receiver().unwrap();
1003 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1004 .await
1005 .expect("Should receive timeout notification")
1006 .expect("Should receive Some value");
1007
1008 assert_eq!(received_task_id, task_id);
1009
1010 // Wait a moment to ensure internal cleanup is complete (等待片刻以确保内部清理完成)
1011 tokio::time::sleep(Duration::from_millis(10)).await;
1012
1013 // Try to cancel the timed-out task, should return false (尝试取消超时任务,应返回 false)
1014 let cancelled = service.cancel_task(task_id);
1015 assert!(!cancelled, "Timed out task should not exist anymore");
1016 }
1017
1018 #[tokio::test]
1019 async fn test_cancel_task_spawns_background_task() {
1020 let timer = TimerWheel::with_defaults();
1021 let service = timer.create_service(ServiceConfig::default());
1022 let counter = Arc::new(AtomicU32::new(0));
1023
1024 // Create a timer (创建定时器)
1025 let counter_clone = Arc::clone(&counter);
1026 let task = TimerService::create_task(
1027 Duration::from_secs(10),
1028 Some(CallbackWrapper::new(move || {
1029 let counter = Arc::clone(&counter_clone);
1030 async move {
1031 counter.fetch_add(1, Ordering::SeqCst);
1032 }
1033 })),
1034 );
1035 let task_id = task.get_id();
1036
1037 service.register(task).unwrap();
1038
1039 // Use cancel_task (will wait for result, but processed in background coroutine)
1040 // 使用 cancel_task(将等待结果,但在后台协程中处理)
1041 let cancelled = service.cancel_task(task_id);
1042 assert!(cancelled, "Task should be cancelled successfully");
1043
1044 // Wait long enough to ensure callback is not executed (等待足够长时间以确保回调未执行)
1045 tokio::time::sleep(Duration::from_millis(100)).await;
1046 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
1047
1048 // Verify task has been removed from active_tasks (验证任务已从 active_tasks 中移除)
1049 let cancelled_again = service.cancel_task(task_id);
1050 assert!(!cancelled_again, "Task should have been removed from active_tasks");
1051 }
1052
1053 #[tokio::test]
1054 async fn test_schedule_once_direct() {
1055 let timer = TimerWheel::with_defaults();
1056 let mut service = timer.create_service(ServiceConfig::default());
1057 let counter = Arc::new(AtomicU32::new(0));
1058
1059 // Schedule timer directly through service
1060 // 直接通过服务调度定时器
1061 let counter_clone = Arc::clone(&counter);
1062 let task = TimerService::create_task(
1063 Duration::from_millis(50),
1064 Some(CallbackWrapper::new(move || {
1065 let counter = Arc::clone(&counter_clone);
1066 async move {
1067 counter.fetch_add(1, Ordering::SeqCst);
1068 }
1069 })),
1070 );
1071 let task_id = task.get_id();
1072 service.register(task).unwrap();
1073
1074 // Wait for timer to trigger
1075 // 等待定时器触发
1076 let mut rx = service.take_receiver().unwrap();
1077 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1078 .await
1079 .expect("Should receive timeout notification")
1080 .expect("Should receive Some value");
1081
1082 assert_eq!(received_task_id, task_id);
1083
1084 // Wait for callback to execute
1085 // 等待回调执行
1086 tokio::time::sleep(Duration::from_millis(50)).await;
1087 assert_eq!(counter.load(Ordering::SeqCst), 1);
1088 }
1089
1090 #[tokio::test]
1091 async fn test_schedule_once_batch_direct() {
1092 let timer = TimerWheel::with_defaults();
1093 let mut service = timer.create_service(ServiceConfig::default());
1094 let counter = Arc::new(AtomicU32::new(0));
1095
1096 // Schedule timers directly through service
1097 // 直接通过服务调度定时器
1098 let callbacks: Vec<_> = (0..3)
1099 .map(|_| {
1100 let counter = Arc::clone(&counter);
1101 (Duration::from_millis(50), Some(CallbackWrapper::new(move || {
1102 let counter = Arc::clone(&counter);
1103 async move {
1104 counter.fetch_add(1, Ordering::SeqCst);
1105 }
1106 })))
1107 })
1108 .collect();
1109
1110 let tasks = TimerService::create_batch_with_callbacks(callbacks);
1111 assert_eq!(tasks.len(), 3);
1112 service.register_batch(tasks).unwrap();
1113
1114 // Receive all timeout notifications
1115 // 接收所有超时通知
1116 let mut received_count = 0;
1117 let mut rx = service.take_receiver().unwrap();
1118
1119 while received_count < 3 {
1120 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1121 Ok(Some(_task_id)) => {
1122 received_count += 1;
1123 }
1124 Ok(None) => break,
1125 Err(_) => break,
1126 }
1127 }
1128
1129 assert_eq!(received_count, 3);
1130
1131 // Wait for callback to execute
1132 // 等待回调执行
1133 tokio::time::sleep(Duration::from_millis(50)).await;
1134 assert_eq!(counter.load(Ordering::SeqCst), 3);
1135 }
1136
1137 #[tokio::test]
1138 async fn test_schedule_once_notify_direct() {
1139 let timer = TimerWheel::with_defaults();
1140 let mut service = timer.create_service(ServiceConfig::default());
1141
1142 // Schedule only notification timer directly through service (no callback)
1143 // 直接通过服务调度通知定时器(没有回调)
1144 let task = TimerService::create_task(Duration::from_millis(50), None);
1145 let task_id = task.get_id();
1146 service.register(task).unwrap();
1147
1148 // Receive timeout notification
1149 // 接收超时通知
1150 let mut rx = service.take_receiver().unwrap();
1151 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1152 .await
1153 .expect("Should receive timeout notification")
1154 .expect("Should receive Some value");
1155
1156 assert_eq!(received_task_id, task_id);
1157 }
1158
1159 #[tokio::test]
1160 async fn test_schedule_and_cancel_direct() {
1161 let timer = TimerWheel::with_defaults();
1162 let service = timer.create_service(ServiceConfig::default());
1163 let counter = Arc::new(AtomicU32::new(0));
1164
1165 // Schedule timer directly
1166 // 直接调度定时器
1167 let counter_clone = Arc::clone(&counter);
1168 let task = TimerService::create_task(
1169 Duration::from_secs(10),
1170 Some(CallbackWrapper::new(move || {
1171 let counter = Arc::clone(&counter_clone);
1172 async move {
1173 counter.fetch_add(1, Ordering::SeqCst);
1174 }
1175 })),
1176 );
1177 let task_id = task.get_id();
1178 service.register(task).unwrap();
1179
1180 // Immediately cancel
1181 // 立即取消
1182 let cancelled = service.cancel_task(task_id);
1183 assert!(cancelled, "Task should be cancelled successfully");
1184
1185 // Wait to ensure callback is not executed
1186 // 等待确保回调未执行
1187 tokio::time::sleep(Duration::from_millis(100)).await;
1188 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
1189 }
1190
1191 #[tokio::test]
1192 async fn test_cancel_batch_direct() {
1193 let timer = TimerWheel::with_defaults();
1194 let service = timer.create_service(ServiceConfig::default());
1195 let counter = Arc::new(AtomicU32::new(0));
1196
1197 // Batch schedule timers
1198 // 批量调度定时器
1199 let callbacks: Vec<_> = (0..10)
1200 .map(|_| {
1201 let counter = Arc::clone(&counter);
1202 (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1203 let counter = Arc::clone(&counter);
1204 async move {
1205 counter.fetch_add(1, Ordering::SeqCst);
1206 }
1207 })))
1208 })
1209 .collect();
1210
1211 let tasks = TimerService::create_batch_with_callbacks(callbacks);
1212 let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1213 assert_eq!(task_ids.len(), 10);
1214 service.register_batch(tasks).unwrap();
1215
1216 // Batch cancel all tasks
1217 // 批量取消所有任务
1218 let cancelled = service.cancel_batch(&task_ids);
1219 assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1220
1221 // Wait to ensure callback is not executed
1222 // 等待确保回调未执行
1223 tokio::time::sleep(Duration::from_millis(100)).await;
1224 assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1225 }
1226
1227 #[tokio::test]
1228 async fn test_cancel_batch_partial() {
1229 let timer = TimerWheel::with_defaults();
1230 let service = timer.create_service(ServiceConfig::default());
1231 let counter = Arc::new(AtomicU32::new(0));
1232
1233 // Batch schedule timers
1234 // 批量调度定时器
1235 let callbacks: Vec<_> = (0..10)
1236 .map(|_| {
1237 let counter = Arc::clone(&counter);
1238 (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1239 let counter = Arc::clone(&counter);
1240 async move {
1241 counter.fetch_add(1, Ordering::SeqCst);
1242 }
1243 })))
1244 })
1245 .collect();
1246
1247 let tasks = TimerService::create_batch_with_callbacks(callbacks);
1248 let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1249 service.register_batch(tasks).unwrap();
1250
1251 // Only cancel first 5 tasks
1252 // 只取消前 5 个任务
1253 let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1254 let cancelled = service.cancel_batch(&to_cancel);
1255 assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1256
1257 // Wait to ensure first 5 callbacks are not executed
1258 // 等待确保前 5 个回调未执行
1259 tokio::time::sleep(Duration::from_millis(100)).await;
1260 assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1261 }
1262
1263 #[tokio::test]
1264 async fn test_cancel_batch_empty() {
1265 let timer = TimerWheel::with_defaults();
1266 let service = timer.create_service(ServiceConfig::default());
1267
1268 // Cancel empty list
1269 // 取消空列表
1270 let empty: Vec<TaskId> = vec![];
1271 let cancelled = service.cancel_batch(&empty);
1272 assert_eq!(cancelled, 0, "No tasks should be cancelled");
1273 }
1274
1275 #[tokio::test]
1276 async fn test_postpone() {
1277 let timer = TimerWheel::with_defaults();
1278 let mut service = timer.create_service(ServiceConfig::default());
1279 let counter = Arc::new(AtomicU32::new(0));
1280
1281 // Register a task, original callback increases 1
1282 // 注册一个任务,原始回调增加 1
1283 let counter_clone1 = Arc::clone(&counter);
1284 let task = TimerService::create_task(
1285 Duration::from_millis(50),
1286 Some(CallbackWrapper::new(move || {
1287 let counter = Arc::clone(&counter_clone1);
1288 async move {
1289 counter.fetch_add(1, Ordering::SeqCst);
1290 }
1291 })),
1292 );
1293 let task_id = task.get_id();
1294 service.register(task).unwrap();
1295
1296 // Postpone task and replace callback, new callback increases 10
1297 // 延期任务并替换回调,新回调增加 10
1298 let counter_clone2 = Arc::clone(&counter);
1299 let postponed = service.postpone(
1300 task_id,
1301 Duration::from_millis(100),
1302 Some(CallbackWrapper::new(move || {
1303 let counter = Arc::clone(&counter_clone2);
1304 async move {
1305 counter.fetch_add(10, Ordering::SeqCst);
1306 }
1307 }))
1308 );
1309 assert!(postponed, "Task should be postponed successfully");
1310
1311 // Receive timeout notification (after postponing, need to wait 100ms, plus margin)
1312 // 接收超时通知(延期后,需要等待 100ms,加上余量)
1313 let mut rx = service.take_receiver().unwrap();
1314 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1315 .await
1316 .expect("Should receive timeout notification")
1317 .expect("Should receive Some value");
1318
1319 assert_eq!(received_task_id, task_id);
1320
1321 // Wait for callback to execute
1322 // 等待回调执行
1323 tokio::time::sleep(Duration::from_millis(20)).await;
1324
1325 // Verify new callback is executed (increased 10 instead of 1)
1326 // 验证新回调已执行(增加 10 而不是 1)
1327 assert_eq!(counter.load(Ordering::SeqCst), 10);
1328 }
1329
1330 #[tokio::test]
1331 async fn test_postpone_nonexistent_task() {
1332 let timer = TimerWheel::with_defaults();
1333 let service = timer.create_service(ServiceConfig::default());
1334
1335 // Try to postpone a nonexistent task
1336 // 尝试延期一个不存在的任务
1337 let fake_task = TimerService::create_task(Duration::from_millis(50), None);
1338 let fake_task_id = fake_task.get_id();
1339 // Do not register this task
1340 // 不注册这个任务
1341 let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
1342 assert!(!postponed, "Nonexistent task should not be postponed");
1343 }
1344
1345 #[tokio::test]
1346 async fn test_postpone_batch() {
1347 let timer = TimerWheel::with_defaults();
1348 let mut service = timer.create_service(ServiceConfig::default());
1349 let counter = Arc::new(AtomicU32::new(0));
1350
1351 // Register 3 tasks
1352 // 注册 3 个任务
1353 let mut task_ids = Vec::new();
1354 for _ in 0..3 {
1355 let counter_clone = Arc::clone(&counter);
1356 let task = TimerService::create_task(
1357 Duration::from_millis(50),
1358 Some(CallbackWrapper::new(move || {
1359 let counter = Arc::clone(&counter_clone);
1360 async move {
1361 counter.fetch_add(1, Ordering::SeqCst);
1362 }
1363 })),
1364 );
1365 task_ids.push((task.get_id(), Duration::from_millis(150), None));
1366 service.register(task).unwrap();
1367 }
1368
1369 // Batch postpone
1370 // 批量延期
1371 let postponed = service.postpone_batch_with_callbacks(task_ids);
1372 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1373
1374 // Wait for original time 50ms, task should not trigger
1375 // 等待原始时间 50ms,任务应该不触发
1376 tokio::time::sleep(Duration::from_millis(70)).await;
1377 assert_eq!(counter.load(Ordering::SeqCst), 0);
1378
1379 // Receive all timeout notifications
1380 // 接收所有超时通知
1381 let mut received_count = 0;
1382 let mut rx = service.take_receiver().unwrap();
1383
1384 while received_count < 3 {
1385 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1386 Ok(Some(_task_id)) => {
1387 received_count += 1;
1388 }
1389 Ok(None) => break,
1390 Err(_) => break,
1391 }
1392 }
1393
1394 assert_eq!(received_count, 3);
1395
1396 // Wait for callback to execute
1397 // 等待回调执行
1398 tokio::time::sleep(Duration::from_millis(20)).await;
1399 assert_eq!(counter.load(Ordering::SeqCst), 3);
1400 }
1401
1402 #[tokio::test]
1403 async fn test_postpone_batch_with_callbacks() {
1404 let timer = TimerWheel::with_defaults();
1405 let mut service = timer.create_service(ServiceConfig::default());
1406 let counter = Arc::new(AtomicU32::new(0));
1407
1408 // Register 3 tasks
1409 // 注册 3 个任务
1410 let mut task_ids = Vec::new();
1411 for _ in 0..3 {
1412 let task = TimerService::create_task(
1413 Duration::from_millis(50),
1414 None,
1415 );
1416 task_ids.push(task.get_id());
1417 service.register(task).unwrap();
1418 }
1419
1420 // Batch postpone and replace callback
1421 // 批量延期并替换回调
1422 let updates: Vec<_> = task_ids
1423 .into_iter()
1424 .map(|id| {
1425 let counter_clone = Arc::clone(&counter);
1426 (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
1427 let counter = Arc::clone(&counter_clone);
1428 async move {
1429 counter.fetch_add(1, Ordering::SeqCst);
1430 }
1431 })))
1432 })
1433 .collect();
1434
1435 let postponed = service.postpone_batch_with_callbacks(updates);
1436 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1437
1438 // Wait for original time 50ms, task should not trigger
1439 // 等待原始时间 50ms,任务应该不触发
1440 tokio::time::sleep(Duration::from_millis(70)).await;
1441 assert_eq!(counter.load(Ordering::SeqCst), 0);
1442
1443 // Receive all timeout notifications
1444 // 接收所有超时通知
1445 let mut received_count = 0;
1446 let mut rx = service.take_receiver().unwrap();
1447
1448 while received_count < 3 {
1449 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1450 Ok(Some(_task_id)) => {
1451 received_count += 1;
1452 }
1453 Ok(None) => break,
1454 Err(_) => break,
1455 }
1456 }
1457
1458 assert_eq!(received_count, 3);
1459
1460 // Wait for callback to execute
1461 // 等待回调执行
1462 tokio::time::sleep(Duration::from_millis(20)).await;
1463 assert_eq!(counter.load(Ordering::SeqCst), 3);
1464 }
1465
1466 #[tokio::test]
1467 async fn test_postpone_batch_empty() {
1468 let timer = TimerWheel::with_defaults();
1469 let service = timer.create_service(ServiceConfig::default());
1470
1471 // Postpone empty list
1472 let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
1473 let postponed = service.postpone_batch_with_callbacks(empty);
1474 assert_eq!(postponed, 0, "No tasks should be postponed");
1475 }
1476
1477 #[tokio::test]
1478 async fn test_postpone_keeps_timeout_notification_valid() {
1479 let timer = TimerWheel::with_defaults();
1480 let mut service = timer.create_service(ServiceConfig::default());
1481 let counter = Arc::new(AtomicU32::new(0));
1482
1483 // Register a task
1484 // 注册一个任务
1485 let counter_clone = Arc::clone(&counter);
1486 let task = TimerService::create_task(
1487 Duration::from_millis(50),
1488 Some(CallbackWrapper::new(move || {
1489 let counter = Arc::clone(&counter_clone);
1490 async move {
1491 counter.fetch_add(1, Ordering::SeqCst);
1492 }
1493 })),
1494 );
1495 let task_id = task.get_id();
1496 service.register(task).unwrap();
1497
1498 // Postpone task
1499 // 延期任务
1500 service.postpone(task_id, Duration::from_millis(100), None);
1501
1502 // Verify timeout notification is still valid (after postponing, need to wait 100ms, plus margin)
1503 // 验证超时通知是否仍然有效(延期后,需要等待 100ms,加上余量)
1504 let mut rx = service.take_receiver().unwrap();
1505 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1506 .await
1507 .expect("Should receive timeout notification")
1508 .expect("Should receive Some value");
1509
1510 assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1511
1512 // Wait for callback to execute
1513 // 等待回调执行
1514 tokio::time::sleep(Duration::from_millis(20)).await;
1515 assert_eq!(counter.load(Ordering::SeqCst), 1);
1516 }
1517
1518 #[tokio::test]
1519 async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1520 let timer = TimerWheel::with_defaults();
1521 let mut service = timer.create_service(ServiceConfig::default());
1522
1523 // Register two tasks: one will be cancelled, one will expire normally
1524 // 注册两个任务:一个将被取消,一个将正常过期
1525 let task1 = TimerService::create_task(Duration::from_secs(10), None);
1526 let task1_id = task1.get_id();
1527 service.register(task1).unwrap();
1528
1529 let task2 = TimerService::create_task(Duration::from_millis(50), None);
1530 let task2_id = task2.get_id();
1531 service.register(task2).unwrap();
1532
1533 // Cancel first task
1534 // 取消第一个任务
1535 let cancelled = service.cancel_task(task1_id);
1536 assert!(cancelled, "Task should be cancelled");
1537
1538 // Wait for second task to expire
1539 // 等待第二个任务过期
1540 let mut rx = service.take_receiver().unwrap();
1541 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1542 .await
1543 .expect("Should receive timeout notification")
1544 .expect("Should receive Some value");
1545
1546 // Should only receive notification for second task (expired), not for first task (cancelled)
1547 assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");
1548
1549 // Verify no other notifications (especially cancelled tasks should not have notifications)
1550 let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1551 assert!(no_more.is_err(), "Should not receive any more notifications");
1552 }
1553
1554 #[tokio::test]
1555 async fn test_take_receiver_twice() {
1556 let timer = TimerWheel::with_defaults();
1557 let mut service = timer.create_service(ServiceConfig::default());
1558
1559 // First call should return Some
1560 // 第一次调用应该返回 Some
1561 let rx1 = service.take_receiver();
1562 assert!(rx1.is_some(), "First take_receiver should return Some");
1563
1564 // Second call should return None
1565 // 第二次调用应该返回 None
1566 let rx2 = service.take_receiver();
1567 assert!(rx2.is_none(), "Second take_receiver should return None");
1568 }
1569
1570 #[tokio::test]
1571 async fn test_postpone_batch_without_callbacks() {
1572 let timer = TimerWheel::with_defaults();
1573 let mut service = timer.create_service(ServiceConfig::default());
1574 let counter = Arc::new(AtomicU32::new(0));
1575
1576 // Register 3 tasks, with original callback
1577 // 注册 3 个任务,带有原始回调
1578 let mut task_ids = Vec::new();
1579 for _ in 0..3 {
1580 let counter_clone = Arc::clone(&counter);
1581 let task = TimerService::create_task(
1582 Duration::from_millis(50),
1583 Some(CallbackWrapper::new(move || {
1584 let counter = Arc::clone(&counter_clone);
1585 async move {
1586 counter.fetch_add(1, Ordering::SeqCst);
1587 }
1588 })),
1589 );
1590 task_ids.push(task.get_id());
1591 service.register(task).unwrap();
1592 }
1593
1594 // Batch postpone, without replacing callback
1595 // 批量延期,不替换回调
1596 let updates: Vec<_> = task_ids
1597 .iter()
1598 .map(|&id| (id, Duration::from_millis(150)))
1599 .collect();
1600 let postponed = service.postpone_batch(updates);
1601 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1602
1603 // Wait for original time 50ms, task should not trigger
1604 // 等待原始时间 50ms,任务应该不触发
1605 tokio::time::sleep(Duration::from_millis(70)).await;
1606 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callbacks should not fire yet");
1607
1608 // Receive all timeout notifications
1609 // 接收所有超时通知
1610 let mut received_count = 0;
1611 let mut rx = service.take_receiver().unwrap();
1612
1613 while received_count < 3 {
1614 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1615 Ok(Some(_task_id)) => {
1616 received_count += 1;
1617 }
1618 Ok(None) => break,
1619 Err(_) => break,
1620 }
1621 }
1622
1623 assert_eq!(received_count, 3, "Should receive 3 timeout notifications");
1624
1625 // Wait for callback to execute
1626 // 等待回调执行
1627 tokio::time::sleep(Duration::from_millis(20)).await;
1628 assert_eq!(counter.load(Ordering::SeqCst), 3, "All callbacks should execute");
1629 }
1630}
1631