kestrel_timer/service.rs
1use crate::{BatchHandle, TimerHandle};
2use crate::config::ServiceConfig;
3use crate::error::TimerError;
4use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
5use crate::wheel::Wheel;
6use futures::stream::{FuturesUnordered, StreamExt};
7use futures::future::BoxFuture;
8use parking_lot::Mutex;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::task::JoinHandle;
13
/// Commands sent from a `TimerService` to its background `ServiceActor`.
///
/// Each variant carries only what the actor needs in order to watch a task:
/// the task id(s) and the oneshot receiver(s) on which the completion reason
/// of that task is delivered.
enum ServiceCommand {
    /// Start watching a batch of timers.
    ///
    /// `task_ids` and `completion_rxs` are paired by position: the actor zips
    /// the two vectors together when it consumes this command.
    AddBatchHandle {
        task_ids: Vec<TaskId>,
        completion_rxs: Vec<tokio::sync::oneshot::Receiver<TaskCompletionReason>>,
    },
    /// Start watching a single timer identified by `task_id`; its completion
    /// reason arrives on `completion_rx`.
    AddTimerHandle {
        task_id: TaskId,
        completion_rx: tokio::sync::oneshot::Receiver<TaskCompletionReason>,
    },
}
33
/// TimerService - timer service based on the Actor pattern.
///
/// Manages multiple timer handles, listens to all of their completion events,
/// and forwards the `TaskId` of every expired timer to a single channel that
/// the user reads from.
///
/// # Features
/// - Automatically listens to the completion of every registered timer
/// - Stops tracking a task once it has completed
/// - Aggregates the `TaskId`s of expired timers into one unified channel
///   (cancelled timers are not forwarded)
/// - Supports registering single timers and batches at any time
///
/// # Examples
/// ```no_run
/// use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
/// use std::time::Duration;
///
/// #[tokio::main]
/// async fn main() {
///     let timer = TimerWheel::with_defaults();
///     let mut service = timer.create_service(ServiceConfig::default());
///
///     // Use the two-step API to batch schedule timers through the service
///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
///         .map(|i| {
///             let callback = Some(CallbackWrapper::new(move || async move {
///                 println!("Timer {} fired!", i);
///             }));
///             (Duration::from_millis(100), callback)
///         })
///         .collect();
///     let tasks = TimerService::create_batch_with_callbacks(callbacks);
///     service.register_batch(tasks).unwrap();
///
///     // Receive timeout notifications
///     let mut rx = service.take_receiver().unwrap();
///     while let Some(task_id) = rx.recv().await {
///         println!("Task {:?} completed", task_id);
///     }
/// }
/// ```
pub struct TimerService {
    /// Sends registration commands to the actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiving end of the expiration channel; `None` once it has been
    /// handed out by `take_receiver`.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor task; taken by `shutdown` or on drop.
    actor_handle: Option<JoinHandle<()>>,
    /// Shared timing wheel, used to schedule, cancel and postpone timers
    /// directly without going through the actor.
    wheel: Arc<Mutex<Wheel>>,
    /// Sends the shutdown signal to the actor; taken by `shutdown`.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
102
103impl TimerService {
104 /// Create new TimerService
105 ///
106 /// # Parameters
107 /// - `wheel`: Timing wheel reference
108 /// - `config`: Service configuration
109 ///
110 /// # Notes
111 /// Typically not called directly, but used to create through `TimerWheel::create_service()`
112 ///
113 /// 创建新的 TimerService
114 ///
115 /// # 参数
116 /// - `wheel`: 时间轮引用
117 /// - `config`: 服务配置
118 ///
119 /// # 注意
120 /// 通常不直接调用,而是通过 `TimerWheel::create_service()` 创建
121 ///
122 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
123 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
124 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
125
126 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
127 let actor = ServiceActor::new(command_rx, timeout_tx, shutdown_rx);
128 let actor_handle = tokio::spawn(async move {
129 actor.run().await;
130 });
131
132 Self {
133 command_tx,
134 timeout_rx: Some(timeout_rx),
135 actor_handle: Some(actor_handle),
136 wheel,
137 shutdown_tx: Some(shutdown_tx),
138 }
139 }
140
141 /// Get timeout receiver (transfer ownership)
142 ///
143 /// # Returns
144 /// Timeout notification receiver, if already taken, returns None
145 ///
146 /// # Notes
147 /// This method can only be called once, because it transfers ownership of the receiver
148 ///
149 /// 获取超时通知接收器 (转移所有权)
150 ///
151 /// # 返回值
152 /// 超时通知接收器,如果已经取走,返回 None
153 ///
154 /// # 注意
155 /// 此方法只能调用一次,因为它转移了接收器的所有权
156 ///
157 /// # Examples (示例)
158 /// ```no_run
159 /// # use kestrel_timer::{TimerWheel, config::ServiceConfig};
160 /// # #[tokio::main]
161 /// # async fn main() {
162 /// let timer = TimerWheel::with_defaults();
163 /// let mut service = timer.create_service(ServiceConfig::default());
164 ///
165 /// let mut rx = service.take_receiver().unwrap();
166 /// while let Some(task_id) = rx.recv().await {
167 /// println!("Task {:?} timed out", task_id);
168 /// }
169 /// # }
170 /// ```
171 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
172 self.timeout_rx.take()
173 }
174
175 /// Cancel specified task
176 ///
177 /// # Parameters
178 /// - `task_id`: Task ID to cancel
179 ///
180 /// # Returns
181 /// - `true`: Task exists and cancellation is successful
182 /// - `false`: Task does not exist or cancellation fails
183 ///
184 /// 取消指定任务
185 ///
186 /// # 参数
187 /// - `task_id`: 任务 ID
188 ///
189 /// # 返回值
190 /// - `true`: 任务存在且取消成功
191 /// - `false`: 任务不存在或取消失败
192 ///
193 /// # Examples (示例)
194 /// ```no_run
195 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
196 /// # use std::time::Duration;
197 /// #
198 /// # #[tokio::main]
199 /// # async fn main() {
200 /// let timer = TimerWheel::with_defaults();
201 /// let service = timer.create_service(ServiceConfig::default());
202 ///
203 /// // Use two-step API to schedule timers
204 /// let callback = Some(CallbackWrapper::new(|| async move {
205 /// println!("Timer fired!"); // 定时器触发
206 /// }));
207 /// let task = TimerService::create_task(Duration::from_secs(10), callback);
208 /// let task_id = task.get_id();
209 /// service.register(task).unwrap(); // 注册定时器
210 ///
211 /// // Cancel task
212 /// let cancelled = service.cancel_task(task_id);
213 /// println!("Task cancelled: {}", cancelled); // 任务取消
214 /// # }
215 /// ```
216 #[inline]
217 pub fn cancel_task(&self, task_id: TaskId) -> bool {
218 // Direct cancellation, no need to notify Actor
219 // FuturesUnordered will automatically clean up when tasks are cancelled
220 // 直接取消,无需通知 Actor
221 // FuturesUnordered 将在任务取消时自动清理
222 let mut wheel = self.wheel.lock();
223 wheel.cancel(task_id)
224 }
225
226 /// Batch cancel tasks
227 ///
228 /// Use underlying batch cancellation operation to cancel multiple tasks at once, performance is better than calling cancel_task repeatedly.
229 ///
230 /// # Parameters
231 /// - `task_ids`: List of task IDs to cancel
232 ///
233 /// # Returns
234 /// Number of successfully cancelled tasks
235 ///
236 /// 批量取消任务
237 ///
238 /// # 参数
239 /// - `task_ids`: 任务 ID 列表
240 ///
241 /// # 返回值
242 /// 成功取消的任务数量
243 ///
244 /// # Examples (示例)
245 /// ```no_run
246 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
247 /// # use std::time::Duration;
248 /// #
249 /// # #[tokio::main]
250 /// # async fn main() {
251 /// let timer = TimerWheel::with_defaults();
252 /// let service = timer.create_service(ServiceConfig::default());
253 ///
254 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..10)
255 /// .map(|i| {
256 /// let callback = Some(CallbackWrapper::new(move || async move {
257 /// println!("Timer {} fired!", i); // 定时器触发
258 /// }));
259 /// (Duration::from_secs(10), callback)
260 /// })
261 /// .collect();
262 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
263 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
264 /// service.register_batch(tasks).unwrap(); // 注册定时器
265 ///
266 /// // Batch cancel
267 /// let cancelled = service.cancel_batch(&task_ids);
268 /// println!("Cancelled {} tasks", cancelled); // 任务取消
269 /// # }
270 /// ```
271 #[inline]
272 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
273 if task_ids.is_empty() {
274 return 0;
275 }
276
277 // Direct batch cancellation, no need to notify Actor
278 // FuturesUnordered will automatically clean up when tasks are cancelled
279 // 直接批量取消,无需通知 Actor
280 // FuturesUnordered 将在任务取消时自动清理
281 let mut wheel = self.wheel.lock();
282 wheel.cancel_batch(task_ids)
283 }
284
285 /// Postpone task (replace callback)
286 ///
287 /// # Parameters
288 /// - `task_id`: Task ID to postpone
289 /// - `new_delay`: New delay time (recalculated from current time point)
290 /// - `callback`: New callback function
291 ///
292 /// # Returns
293 /// - `true`: Task exists and is successfully postponed
294 /// - `false`: Task does not exist or postponement fails
295 ///
296 /// # Notes
297 /// - Task ID remains unchanged after postponement
298 /// - Original timeout notification remains valid
299 /// - Callback function will be replaced with new callback
300 ///
301 /// 推迟任务 (替换回调)
302 ///
303 /// # 参数
304 /// - `task_id`: 任务 ID
305 /// - `new_delay`: 新的延迟时间 (从当前时间点重新计算)
306 /// - `callback`: 新的回调函数
307 ///
308 /// # 返回值
309 /// - `true`: 任务存在且延期成功
310 /// - `false`: 任务不存在或延期失败
311 ///
312 /// # 注意
313 /// - 任务 ID 在延期后保持不变
314 /// - 原始超时通知保持有效
315 /// - 回调函数将被新的回调函数替换
316 ///
317 /// # Examples (示例)
318 /// ```no_run
319 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
320 /// # use std::time::Duration;
321 /// #
322 /// # #[tokio::main]
323 /// # async fn main() {
324 /// let timer = TimerWheel::with_defaults();
325 /// let service = timer.create_service(ServiceConfig::default());
326 ///
327 /// let callback = Some(CallbackWrapper::new(|| async {
328 /// println!("Original callback"); // 原始回调
329 /// }));
330 /// let task = TimerService::create_task(Duration::from_secs(5), callback);
331 /// let task_id = task.get_id();
332 /// service.register(task).unwrap(); // 注册定时器
333 ///
334 /// // Postpone and replace callback (延期并替换回调)
335 /// let new_callback = Some(CallbackWrapper::new(|| async { println!("New callback!"); }));
336 /// let success = service.postpone(
337 /// task_id,
338 /// Duration::from_secs(10),
339 /// new_callback
340 /// );
341 /// println!("Postponed successfully: {}", success);
342 /// # }
343 /// ```
344 #[inline]
345 pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
346 let mut wheel = self.wheel.lock();
347 wheel.postpone(task_id, new_delay, callback)
348 }
349
350 /// Batch postpone tasks (keep original callbacks)
351 ///
352 /// # Parameters
353 /// - `updates`: List of tuples of (task ID, new delay)
354 ///
355 /// # Returns
356 /// Number of successfully postponed tasks
357 ///
358 /// 批量延期任务 (保持原始回调)
359 ///
360 /// # 参数
361 /// - `updates`: (任务 ID, 新延迟) 元组列表
362 ///
363 /// # 返回值
364 /// 成功延期的任务数量
365 ///
366 /// # Examples (示例)
367 /// ```no_run
368 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
369 /// # use std::time::Duration;
370 /// #
371 /// # #[tokio::main]
372 /// # async fn main() {
373 /// let timer = TimerWheel::with_defaults();
374 /// let service = timer.create_service(ServiceConfig::default());
375 ///
376 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
377 /// .map(|i| {
378 /// let callback = Some(CallbackWrapper::new(move || async move {
379 /// println!("Timer {} fired!", i);
380 /// }));
381 /// (Duration::from_secs(5), callback)
382 /// })
383 /// .collect();
384 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
385 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
386 /// service.register_batch(tasks).unwrap();
387 ///
388 /// // Batch postpone (keep original callbacks)
389 /// // 批量延期任务 (保持原始回调)
390 /// let updates: Vec<_> = task_ids
391 /// .into_iter()
392 /// .map(|id| (id, Duration::from_secs(10)))
393 /// .collect();
394 /// let postponed = service.postpone_batch(updates);
395 /// println!("Postponed {} tasks", postponed);
396 /// # }
397 /// ```
398 #[inline]
399 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
400 if updates.is_empty() {
401 return 0;
402 }
403
404 let mut wheel = self.wheel.lock();
405 wheel.postpone_batch(updates)
406 }
407
408 /// Batch postpone tasks (replace callbacks)
409 ///
410 /// # Parameters
411 /// - `updates`: List of tuples of (task ID, new delay, new callback)
412 ///
413 /// # Returns
414 /// Number of successfully postponed tasks
415 ///
416 /// 批量延期任务 (替换回调)
417 ///
418 /// # 参数
419 /// - `updates`: (任务 ID, 新延迟, 新回调) 元组列表
420 ///
421 /// # 返回值
422 /// 成功延期的任务数量
423 ///
424 /// # Examples (示例)
425 /// ```no_run
426 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
427 /// # use std::time::Duration;
428 /// #
429 /// # #[tokio::main]
430 /// # async fn main() {
431 /// let timer = TimerWheel::with_defaults();
432 /// let service = timer.create_service(ServiceConfig::default());
433 ///
434 /// // Create 3 tasks, initially no callbacks
435 /// // 创建 3 个任务,最初没有回调
436 /// let delays: Vec<Duration> = (0..3)
437 /// .map(|_| Duration::from_secs(5))
438 /// .collect();
439 /// let tasks = TimerService::create_batch(delays);
440 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
441 /// service.register_batch(tasks).unwrap();
442 ///
443 /// // Batch postpone and add new callbacks
444 /// // 批量延期并添加新的回调
445 /// let updates: Vec<_> = task_ids
446 /// .into_iter()
447 /// .enumerate()
448 /// .map(|(i, id)| {
449 /// let callback = Some(CallbackWrapper::new(move || async move {
450 /// println!("New callback {}", i);
451 /// }));
452 /// (id, Duration::from_secs(10), callback)
453 /// })
454 /// .collect();
455 /// let postponed = service.postpone_batch_with_callbacks(updates);
456 /// println!("Postponed {} tasks", postponed);
457 /// # }
458 /// ```
459 #[inline]
460 pub fn postpone_batch_with_callbacks(
461 &self,
462 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
463 ) -> usize {
464 if updates.is_empty() {
465 return 0;
466 }
467
468 let mut wheel = self.wheel.lock();
469 wheel.postpone_batch_with_callbacks(updates)
470 }
471
472 /// Create timer task (static method, apply stage)
473 ///
474 /// # Parameters
475 /// - `delay`: Delay time
476 /// - `callback`: Callback object implementing TimerCallback trait
477 ///
478 /// # Returns
479 /// Return TimerTask, needs to be registered through `register()`
480 ///
481 /// 创建定时器任务 (静态方法,应用阶段)
482 ///
483 /// # 参数
484 /// - `delay`: 延迟时间
485 /// - `callback`: 回调对象,实现 TimerCallback 特质
486 ///
487 /// # 返回值
488 /// 返回 TimerTask,需要通过 `register()` 注册
489 ///
490 /// # Examples (示例)
491 /// ```no_run
492 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
493 /// # use std::time::Duration;
494 /// #
495 /// # #[tokio::main]
496 /// # async fn main() {
497 /// let timer = TimerWheel::with_defaults();
498 /// let service = timer.create_service(ServiceConfig::default());
499 ///
500 /// // Step 1: create task
501 /// // 创建任务
502 /// let callback = Some(CallbackWrapper::new(|| async {
503 /// println!("Timer fired!");
504 /// }));
505 /// let task = TimerService::create_task(Duration::from_millis(100), callback);
506 ///
507 /// let task_id = task.get_id();
508 /// println!("Created task: {:?}", task_id);
509 ///
510 /// // Step 2: register task
511 /// // 注册任务
512 /// service.register(task).unwrap();
513 /// # }
514 /// ```
515 #[inline]
516 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
517 crate::timer::TimerWheel::create_task(delay, callback)
518 }
519
520 /// Create batch of timer tasks (static method, apply stage, no callbacks)
521 ///
522 /// # Parameters
523 /// - `delays`: List of delay times
524 ///
525 /// # Returns
526 /// Return TimerTask list, needs to be registered through `register_batch()`
527 ///
528 /// 创建定时器任务 (静态方法,应用阶段,没有回调)
529 /// # 参数
530 /// - `delays`: 延迟时间列表
531 ///
532 /// # 返回值
533 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
534 ///
535 /// # Examples (示例)
536 /// ```no_run
537 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
538 /// # use std::time::Duration;
539 /// #
540 /// # #[tokio::main]
541 /// # async fn main() {
542 /// let timer = TimerWheel::with_defaults();
543 /// let service = timer.create_service(ServiceConfig::default());
544 ///
545 /// // Step 1: create batch of tasks
546 /// // 创建批量任务
547 /// let delays: Vec<Duration> = (0..3)
548 /// .map(|i| Duration::from_millis(100 * (i + 1)))
549 /// .collect();
550 ///
551 /// let tasks = TimerService::create_batch(delays);
552 /// println!("Created {} tasks", tasks.len());
553 ///
554 /// // Step 2: register batch of tasks
555 /// // 注册批量任务
556 /// service.register_batch(tasks).unwrap();
557 /// # }
558 /// ```
559 #[inline]
560 pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask> {
561 crate::timer::TimerWheel::create_batch(delays)
562 }
563
564 /// Create batch of timer tasks (static method, apply stage, with callbacks)
565 ///
566 /// # Parameters
567 /// - `callbacks`: List of tuples of (delay time, callback)
568 ///
569 /// # Returns
570 /// Return TimerTask list, needs to be registered through `register_batch()`
571 ///
572 /// 创建定时器任务 (静态方法,应用阶段,有回调)
573 /// # 参数
574 /// - `callbacks`: (延迟时间, 回调) 元组列表
575 ///
576 /// # 返回值
577 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
578 ///
579 /// # Examples (示例)
580 /// ```no_run
581 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
582 /// # use std::time::Duration;
583 /// #
584 /// # #[tokio::main]
585 /// # async fn main() {
586 /// let timer = TimerWheel::with_defaults();
587 /// let service = timer.create_service(ServiceConfig::default());
588 ///
589 /// // Step 1: create batch of tasks with callbacks
590 /// // 创建批量任务,带有回调
591 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
592 /// .map(|i| {
593 /// let callback = Some(CallbackWrapper::new(move || async move {
594 /// println!("Timer {} fired!", i);
595 /// }));
596 /// (Duration::from_millis(100 * (i + 1)), callback)
597 /// })
598 /// .collect();
599 ///
600 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
601 /// println!("Created {} tasks", tasks.len());
602 ///
603 /// // Step 2: register batch of tasks with callbacks
604 /// // 注册批量任务,带有回调
605 /// service.register_batch(tasks).unwrap();
606 /// # }
607 /// ```
608 #[inline]
609 pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
610 crate::timer::TimerWheel::create_batch_with_callbacks(callbacks)
611 }
612
613 /// Register timer task to service (registration phase)
614 ///
615 /// # Parameters
616 /// - `task`: Task created via `create_task()`
617 ///
618 /// # Returns
619 /// - `Ok(TimerHandle)`: Register successfully
620 /// - `Err(TimerError::RegisterFailed)`: Register failed (internal channel is full or closed)
621 ///
622 /// 注册定时器任务到服务 (注册阶段)
623 /// # 参数
624 /// - `task`: 通过 `create_task()` 创建的任务
625 ///
626 /// # 返回值
627 /// - `Ok(TimerHandle)`: 注册成功
628 /// - `Err(TimerError::RegisterFailed)`: 注册失败 (内部通道已满或关闭)
629 ///
630 /// # Examples (示例)
631 /// ```no_run
632 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
633 /// # use std::time::Duration;
634 /// #
635 /// # #[tokio::main]
636 /// # async fn main() {
637 /// let timer = TimerWheel::with_defaults();
638 /// let service = timer.create_service(ServiceConfig::default());
639 ///
640 /// // Step 1: create task
641 /// // 创建任务
642 /// let callback = Some(CallbackWrapper::new(|| async move {
643 /// println!("Timer fired!");
644 /// }));
645 /// let task = TimerService::create_task(Duration::from_millis(100), callback);
646 /// let task_id = task.get_id();
647 ///
648 /// // Step 2: register task
649 /// // 注册任务
650 /// service.register(task).unwrap();
651 /// # }
652 /// ```
653 #[inline]
654 pub fn register(&self, task: crate::task::TimerTask) -> Result<TimerHandle, TimerError> {
655 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
656 let notifier = crate::task::CompletionNotifier(completion_tx);
657
658 let task_id = task.id;
659
660 // Single lock, complete all operations
661 // 单次锁定,完成所有操作
662 {
663 let mut wheel_guard = self.wheel.lock();
664 wheel_guard.insert(task, notifier);
665 }
666
667 // Add to service management (only send necessary data)
668 // 添加到服务管理(只发送必要数据)
669 self.command_tx
670 .try_send(ServiceCommand::AddTimerHandle {
671 task_id,
672 completion_rx,
673 })
674 .map_err(|_| TimerError::RegisterFailed)?;
675
676 Ok(TimerHandle::new(task_id, self.wheel.clone()))
677 }
678
679 /// Batch register timer tasks to service (registration phase)
680 ///
681 /// # Parameters
682 /// - `tasks`: List of tasks created via `create_batch()`
683 ///
684 /// # Returns
685 /// - `Ok(BatchHandle)`: Register successfully
686 /// - `Err(TimerError::RegisterFailed)`: Register failed (internal channel is full or closed)
687 ///
688 /// 批量注册定时器任务到服务 (注册阶段)
689 /// # 参数
690 /// - `tasks`: 通过 `create_batch()` 创建的任务列表
691 ///
692 /// # 返回值
693 /// - `Ok(BatchHandle)`: 注册成功
694 /// - `Err(TimerError::RegisterFailed)`: 注册失败 (内部通道已满或关闭)
695 ///
696 /// # Examples (示例)
697 /// ```no_run
698 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
699 /// # use std::time::Duration;
700 /// #
701 /// # #[tokio::main]
702 /// # async fn main() {
703 /// let timer = TimerWheel::with_defaults();
704 /// let service = timer.create_service(ServiceConfig::default());
705 ///
706 /// // Step 1: create batch of tasks with callbacks
707 /// // 创建批量任务,带有回调
708 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
709 /// .map(|i| {
710 /// let callback = Some(CallbackWrapper::new(move || async move {
711 /// println!("Timer {} fired!", i);
712 /// }));
713 /// (Duration::from_secs(1), callback)
714 /// })
715 /// .collect();
716 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
717 ///
718 /// // Step 2: register batch of tasks with callbacks
719 /// // 注册批量任务,带有回调
720 /// service.register_batch(tasks).unwrap();
721 /// # }
722 /// ```
723 #[inline]
724 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<BatchHandle, TimerError> {
725 let task_count = tasks.len();
726 let mut completion_rxs = Vec::with_capacity(task_count);
727 let mut task_ids = Vec::with_capacity(task_count);
728 let mut prepared_tasks = Vec::with_capacity(task_count);
729
730 // Step 1: prepare all channels and notifiers (no lock)
731 // 步骤 1: 准备所有通道和通知器(无锁)
732 for task in tasks {
733 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
734 let notifier = crate::task::CompletionNotifier(completion_tx);
735
736 task_ids.push(task.id);
737 completion_rxs.push(completion_rx);
738 prepared_tasks.push((task, notifier));
739 }
740
741 // Step 2: single lock, batch insert
742 // 步骤 2: 单次锁定,批量插入
743 {
744 let mut wheel_guard = self.wheel.lock();
745 wheel_guard.insert_batch(prepared_tasks);
746 }
747
748 // Add to service management (only send necessary data)
749 // 添加到服务管理(只发送必要数据)
750 self.command_tx
751 .try_send(ServiceCommand::AddBatchHandle {
752 task_ids: task_ids.clone(),
753 completion_rxs,
754 })
755 .map_err(|_| TimerError::RegisterFailed)?;
756
757 Ok(BatchHandle::new(task_ids, self.wheel.clone()))
758 }
759
760 /// Graceful shutdown of TimerService
761 ///
762 /// 优雅关闭 TimerService
763 ///
764 /// # Examples (示例)
765 /// ```no_run
766 /// # use kestrel_timer::{TimerWheel, config::ServiceConfig};
767 /// # #[tokio::main]
768 /// # async fn main() {
769 /// let timer = TimerWheel::with_defaults();
770 /// let mut service = timer.create_service(ServiceConfig::default());
771 ///
772 /// // Use service... (使用服务...)
773 ///
774 /// service.shutdown().await;
775 /// # }
776 /// ```
777 pub async fn shutdown(mut self) {
778 if let Some(shutdown_tx) = self.shutdown_tx.take() {
779 let _ = shutdown_tx.send(());
780 }
781 if let Some(handle) = self.actor_handle.take() {
782 let _ = handle.await;
783 }
784 }
785}
786
787
788impl Drop for TimerService {
789 fn drop(&mut self) {
790 if let Some(handle) = self.actor_handle.take() {
791 handle.abort();
792 }
793 }
794}
795
/// ServiceActor - internal actor behind `TimerService`.
///
/// Owns the receiving side of the command channel, the sending side of the
/// timeout notification channel and the shutdown signal receiver; all three
/// are consumed by `run`.
struct ServiceActor {
    /// Receives registration commands from the service.
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Sends the ids of expired tasks to the user-facing receiver.
    timeout_tx: mpsc::Sender<TaskId>,
    /// Receives the shutdown signal from the service.
    shutdown_rx: tokio::sync::oneshot::Receiver<()>,
}
813
814impl ServiceActor {
815 /// Create new ServiceActor
816 ///
817 /// 创建新的 ServiceActor
818 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>, shutdown_rx: tokio::sync::oneshot::Receiver<()>) -> Self {
819 Self {
820 command_rx,
821 timeout_tx,
822 shutdown_rx,
823 }
824 }
825
826 /// Run Actor event loop
827 ///
828 /// 运行 Actor 事件循环
829 async fn run(mut self) {
830 // Use FuturesUnordered to listen to all completion_rxs
831 // Each future returns (TaskId, Result)
832 // 使用 FuturesUnordered 监听所有 completion_rxs
833 // 每个 future 返回 (TaskId, Result)
834 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
835
836 // Move shutdown_rx out of self, so it can be used in select! with &mut
837 // 将 shutdown_rx 从 self 中移出,以便在 select! 中使用 &mut
838 let mut shutdown_rx = self.shutdown_rx;
839
840 loop {
841 tokio::select! {
842 // Listen to high-priority shutdown signal
843 // 监听高优先级关闭信号
844 _ = &mut shutdown_rx => {
845 // Receive shutdown signal, exit loop immediately
846 // 接收到关闭信号,立即退出循环
847 break;
848 }
849
850 // Listen to timeout events
851 // 监听超时事件
852 Some((task_id, result)) = futures.next() => {
853 // Check completion reason, only forward expired (Expired) events, do not forward cancelled (Cancelled) events
854 // 检查完成原因,只转发过期 (Expired) 事件,不转发取消 (Cancelled) 事件
855 if let Ok(TaskCompletionReason::Expired) = result {
856 let _ = self.timeout_tx.send(task_id).await;
857 }
858 // Task will be automatically removed from FuturesUnordered
859 // 任务将自动从 FuturesUnordered 中移除
860 }
861
862 // Listen to commands
863 // 监听命令
864 Some(cmd) = self.command_rx.recv() => {
865 match cmd {
866 ServiceCommand::AddBatchHandle { task_ids, completion_rxs } => {
867 // Add all tasks to futures
868 // 将所有任务添加到 futures
869 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
870 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
871 (task_id, rx.await)
872 });
873 futures.push(future);
874 }
875 }
876 ServiceCommand::AddTimerHandle { task_id, completion_rx } => {
877 // Add to futures
878 // 添加到 futures
879 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
880 (task_id, completion_rx.await)
881 });
882 futures.push(future);
883 }
884 }
885 }
886
887 // If no futures and command channel is closed, exit loop
888 // 如果没有 futures 且命令通道关闭,退出循环
889 else => {
890 break;
891 }
892 }
893 }
894 }
895}
896
897#[cfg(test)]
898mod tests {
899 use super::*;
900 use crate::TimerWheel;
901 use std::sync::atomic::{AtomicU32, Ordering};
902 use std::sync::Arc;
903 use std::time::Duration;
904
905 #[tokio::test]
906 async fn test_service_creation() {
907 let timer = TimerWheel::with_defaults();
908 let _service = timer.create_service(ServiceConfig::default());
909 }
910
911
    #[tokio::test]
    async fn test_add_timer_handle_and_receive_timeout() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        // Create a single 50 ms timer with a no-op callback.
        let task = TimerService::create_task(Duration::from_millis(50), Some(CallbackWrapper::new(|| async {})));
        let task_id = task.get_id();

        // Register it with the service so the actor watches its completion.
        service.register(task).unwrap();

        // The expiration must show up on the service receiver within 200 ms.
        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);
    }
933
934
935 #[tokio::test]
936 async fn test_shutdown() {
937 let timer = TimerWheel::with_defaults();
938 let service = timer.create_service(ServiceConfig::default());
939
940 // Add some timers (添加一些定时器)
941 let task1 = TimerService::create_task(Duration::from_secs(10), None);
942 let task2 = TimerService::create_task(Duration::from_secs(10), None);
943 service.register(task1).unwrap();
944 service.register(task2).unwrap();
945
946 // Immediately shutdown (without waiting for timers to trigger) (立即关闭(不等待定时器触发))
947 service.shutdown().await;
948 }
949
950
951
952 #[tokio::test]
953 async fn test_cancel_task() {
954 let timer = TimerWheel::with_defaults();
955 let service = timer.create_service(ServiceConfig::default());
956
957 // Add a long-term timer (添加一个长期定时器)
958 let task = TimerService::create_task(Duration::from_secs(10), None);
959 let task_id = task.get_id();
960
961 service.register(task).unwrap();
962
963 // Cancel task (取消任务)
964 let cancelled = service.cancel_task(task_id);
965 assert!(cancelled, "Task should be cancelled successfully");
966
967 // Try to cancel the same task again, should return false (再次尝试取消同一任务,应返回 false)
968 let cancelled_again = service.cancel_task(task_id);
969 assert!(!cancelled_again, "Task should not exist anymore");
970 }
971
972 #[tokio::test]
973 async fn test_cancel_nonexistent_task() {
974 let timer = TimerWheel::with_defaults();
975 let service = timer.create_service(ServiceConfig::default());
976
977 // Add a timer to initialize service (添加定时器以初始化服务)
978 let task = TimerService::create_task(Duration::from_millis(50), None);
979 service.register(task).unwrap();
980
981 // Try to cancel a nonexistent task (create a task ID that will not actually be registered)
982 // 尝试取消不存在的任务(创建一个实际不会注册的任务 ID)
983 let fake_task = TimerService::create_task(Duration::from_millis(50), None);
984 let fake_task_id = fake_task.get_id();
985 // Do not register fake_task (不注册 fake_task)
986 let cancelled = service.cancel_task(fake_task_id);
987 assert!(!cancelled, "Nonexistent task should not be cancelled");
988 }
989
990
    #[tokio::test]
    async fn test_task_timeout_cleans_up_task_sender() {
        let timer = TimerWheel::with_defaults();
        let mut service = timer.create_service(ServiceConfig::default());

        // Register a short (50 ms) timer without a callback.
        let task = TimerService::create_task(Duration::from_millis(50), None);
        let task_id = task.get_id();

        service.register(task).unwrap();

        // Wait for the expiration notification of that task.
        let mut rx = service.take_receiver().unwrap();
        let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("Should receive timeout notification")
            .expect("Should receive Some value");

        assert_eq!(received_task_id, task_id);

        // Give internal cleanup a moment to finish.
        tokio::time::sleep(Duration::from_millis(10)).await;

        // An expired task can no longer be cancelled.
        let cancelled = service.cancel_task(task_id);
        assert!(!cancelled, "Timed out task should not exist anymore");
    }
1018
1019 #[tokio::test]
1020 async fn test_cancel_task_spawns_background_task() {
1021 let timer = TimerWheel::with_defaults();
1022 let service = timer.create_service(ServiceConfig::default());
1023 let counter = Arc::new(AtomicU32::new(0));
1024
1025 // Create a timer (创建定时器)
1026 let counter_clone = Arc::clone(&counter);
1027 let task = TimerService::create_task(
1028 Duration::from_secs(10),
1029 Some(CallbackWrapper::new(move || {
1030 let counter = Arc::clone(&counter_clone);
1031 async move {
1032 counter.fetch_add(1, Ordering::SeqCst);
1033 }
1034 })),
1035 );
1036 let task_id = task.get_id();
1037
1038 service.register(task).unwrap();
1039
1040 // Use cancel_task (will wait for result, but processed in background coroutine)
1041 // 使用 cancel_task(将等待结果,但在后台协程中处理)
1042 let cancelled = service.cancel_task(task_id);
1043 assert!(cancelled, "Task should be cancelled successfully");
1044
1045 // Wait long enough to ensure callback is not executed (等待足够长时间以确保回调未执行)
1046 tokio::time::sleep(Duration::from_millis(100)).await;
1047 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
1048
1049 // Verify task has been removed from active_tasks (验证任务已从 active_tasks 中移除)
1050 let cancelled_again = service.cancel_task(task_id);
1051 assert!(!cancelled_again, "Task should have been removed from active_tasks");
1052 }
1053
1054 #[tokio::test]
1055 async fn test_schedule_once_direct() {
1056 let timer = TimerWheel::with_defaults();
1057 let mut service = timer.create_service(ServiceConfig::default());
1058 let counter = Arc::new(AtomicU32::new(0));
1059
1060 // Schedule timer directly through service
1061 // 直接通过服务调度定时器
1062 let counter_clone = Arc::clone(&counter);
1063 let task = TimerService::create_task(
1064 Duration::from_millis(50),
1065 Some(CallbackWrapper::new(move || {
1066 let counter = Arc::clone(&counter_clone);
1067 async move {
1068 counter.fetch_add(1, Ordering::SeqCst);
1069 }
1070 })),
1071 );
1072 let task_id = task.get_id();
1073 service.register(task).unwrap();
1074
1075 // Wait for timer to trigger
1076 // 等待定时器触发
1077 let mut rx = service.take_receiver().unwrap();
1078 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1079 .await
1080 .expect("Should receive timeout notification")
1081 .expect("Should receive Some value");
1082
1083 assert_eq!(received_task_id, task_id);
1084
1085 // Wait for callback to execute
1086 // 等待回调执行
1087 tokio::time::sleep(Duration::from_millis(50)).await;
1088 assert_eq!(counter.load(Ordering::SeqCst), 1);
1089 }
1090
1091 #[tokio::test]
1092 async fn test_schedule_once_batch_direct() {
1093 let timer = TimerWheel::with_defaults();
1094 let mut service = timer.create_service(ServiceConfig::default());
1095 let counter = Arc::new(AtomicU32::new(0));
1096
1097 // Schedule timers directly through service
1098 // 直接通过服务调度定时器
1099 let callbacks: Vec<_> = (0..3)
1100 .map(|_| {
1101 let counter = Arc::clone(&counter);
1102 (Duration::from_millis(50), Some(CallbackWrapper::new(move || {
1103 let counter = Arc::clone(&counter);
1104 async move {
1105 counter.fetch_add(1, Ordering::SeqCst);
1106 }
1107 })))
1108 })
1109 .collect();
1110
1111 let tasks = TimerService::create_batch_with_callbacks(callbacks);
1112 assert_eq!(tasks.len(), 3);
1113 service.register_batch(tasks).unwrap();
1114
1115 // Receive all timeout notifications
1116 // 接收所有超时通知
1117 let mut received_count = 0;
1118 let mut rx = service.take_receiver().unwrap();
1119
1120 while received_count < 3 {
1121 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1122 Ok(Some(_task_id)) => {
1123 received_count += 1;
1124 }
1125 Ok(None) => break,
1126 Err(_) => break,
1127 }
1128 }
1129
1130 assert_eq!(received_count, 3);
1131
1132 // Wait for callback to execute
1133 // 等待回调执行
1134 tokio::time::sleep(Duration::from_millis(50)).await;
1135 assert_eq!(counter.load(Ordering::SeqCst), 3);
1136 }
1137
1138 #[tokio::test]
1139 async fn test_schedule_once_notify_direct() {
1140 let timer = TimerWheel::with_defaults();
1141 let mut service = timer.create_service(ServiceConfig::default());
1142
1143 // Schedule only notification timer directly through service (no callback)
1144 // 直接通过服务调度通知定时器(没有回调)
1145 let task = TimerService::create_task(Duration::from_millis(50), None);
1146 let task_id = task.get_id();
1147 service.register(task).unwrap();
1148
1149 // Receive timeout notification
1150 // 接收超时通知
1151 let mut rx = service.take_receiver().unwrap();
1152 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1153 .await
1154 .expect("Should receive timeout notification")
1155 .expect("Should receive Some value");
1156
1157 assert_eq!(received_task_id, task_id);
1158 }
1159
1160 #[tokio::test]
1161 async fn test_schedule_and_cancel_direct() {
1162 let timer = TimerWheel::with_defaults();
1163 let service = timer.create_service(ServiceConfig::default());
1164 let counter = Arc::new(AtomicU32::new(0));
1165
1166 // Schedule timer directly
1167 // 直接调度定时器
1168 let counter_clone = Arc::clone(&counter);
1169 let task = TimerService::create_task(
1170 Duration::from_secs(10),
1171 Some(CallbackWrapper::new(move || {
1172 let counter = Arc::clone(&counter_clone);
1173 async move {
1174 counter.fetch_add(1, Ordering::SeqCst);
1175 }
1176 })),
1177 );
1178 let task_id = task.get_id();
1179 service.register(task).unwrap();
1180
1181 // Immediately cancel
1182 // 立即取消
1183 let cancelled = service.cancel_task(task_id);
1184 assert!(cancelled, "Task should be cancelled successfully");
1185
1186 // Wait to ensure callback is not executed
1187 // 等待确保回调未执行
1188 tokio::time::sleep(Duration::from_millis(100)).await;
1189 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
1190 }
1191
1192 #[tokio::test]
1193 async fn test_cancel_batch_direct() {
1194 let timer = TimerWheel::with_defaults();
1195 let service = timer.create_service(ServiceConfig::default());
1196 let counter = Arc::new(AtomicU32::new(0));
1197
1198 // Batch schedule timers
1199 // 批量调度定时器
1200 let callbacks: Vec<_> = (0..10)
1201 .map(|_| {
1202 let counter = Arc::clone(&counter);
1203 (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1204 let counter = Arc::clone(&counter);
1205 async move {
1206 counter.fetch_add(1, Ordering::SeqCst);
1207 }
1208 })))
1209 })
1210 .collect();
1211
1212 let tasks = TimerService::create_batch_with_callbacks(callbacks);
1213 let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1214 assert_eq!(task_ids.len(), 10);
1215 service.register_batch(tasks).unwrap();
1216
1217 // Batch cancel all tasks
1218 // 批量取消所有任务
1219 let cancelled = service.cancel_batch(&task_ids);
1220 assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1221
1222 // Wait to ensure callback is not executed
1223 // 等待确保回调未执行
1224 tokio::time::sleep(Duration::from_millis(100)).await;
1225 assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1226 }
1227
1228 #[tokio::test]
1229 async fn test_cancel_batch_partial() {
1230 let timer = TimerWheel::with_defaults();
1231 let service = timer.create_service(ServiceConfig::default());
1232 let counter = Arc::new(AtomicU32::new(0));
1233
1234 // Batch schedule timers
1235 // 批量调度定时器
1236 let callbacks: Vec<_> = (0..10)
1237 .map(|_| {
1238 let counter = Arc::clone(&counter);
1239 (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1240 let counter = Arc::clone(&counter);
1241 async move {
1242 counter.fetch_add(1, Ordering::SeqCst);
1243 }
1244 })))
1245 })
1246 .collect();
1247
1248 let tasks = TimerService::create_batch_with_callbacks(callbacks);
1249 let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1250 service.register_batch(tasks).unwrap();
1251
1252 // Only cancel first 5 tasks
1253 // 只取消前 5 个任务
1254 let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1255 let cancelled = service.cancel_batch(&to_cancel);
1256 assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1257
1258 // Wait to ensure first 5 callbacks are not executed
1259 // 等待确保前 5 个回调未执行
1260 tokio::time::sleep(Duration::from_millis(100)).await;
1261 assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1262 }
1263
1264 #[tokio::test]
1265 async fn test_cancel_batch_empty() {
1266 let timer = TimerWheel::with_defaults();
1267 let service = timer.create_service(ServiceConfig::default());
1268
1269 // Cancel empty list
1270 // 取消空列表
1271 let empty: Vec<TaskId> = vec![];
1272 let cancelled = service.cancel_batch(&empty);
1273 assert_eq!(cancelled, 0, "No tasks should be cancelled");
1274 }
1275
1276 #[tokio::test]
1277 async fn test_postpone() {
1278 let timer = TimerWheel::with_defaults();
1279 let mut service = timer.create_service(ServiceConfig::default());
1280 let counter = Arc::new(AtomicU32::new(0));
1281
1282 // Register a task, original callback increases 1
1283 // 注册一个任务,原始回调增加 1
1284 let counter_clone1 = Arc::clone(&counter);
1285 let task = TimerService::create_task(
1286 Duration::from_millis(50),
1287 Some(CallbackWrapper::new(move || {
1288 let counter = Arc::clone(&counter_clone1);
1289 async move {
1290 counter.fetch_add(1, Ordering::SeqCst);
1291 }
1292 })),
1293 );
1294 let task_id = task.get_id();
1295 service.register(task).unwrap();
1296
1297 // Postpone task and replace callback, new callback increases 10
1298 // 延期任务并替换回调,新回调增加 10
1299 let counter_clone2 = Arc::clone(&counter);
1300 let postponed = service.postpone(
1301 task_id,
1302 Duration::from_millis(100),
1303 Some(CallbackWrapper::new(move || {
1304 let counter = Arc::clone(&counter_clone2);
1305 async move {
1306 counter.fetch_add(10, Ordering::SeqCst);
1307 }
1308 }))
1309 );
1310 assert!(postponed, "Task should be postponed successfully");
1311
1312 // Receive timeout notification (after postponing, need to wait 100ms, plus margin)
1313 // 接收超时通知(延期后,需要等待 100ms,加上余量)
1314 let mut rx = service.take_receiver().unwrap();
1315 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1316 .await
1317 .expect("Should receive timeout notification")
1318 .expect("Should receive Some value");
1319
1320 assert_eq!(received_task_id, task_id);
1321
1322 // Wait for callback to execute
1323 // 等待回调执行
1324 tokio::time::sleep(Duration::from_millis(20)).await;
1325
1326 // Verify new callback is executed (increased 10 instead of 1)
1327 // 验证新回调已执行(增加 10 而不是 1)
1328 assert_eq!(counter.load(Ordering::SeqCst), 10);
1329 }
1330
1331 #[tokio::test]
1332 async fn test_postpone_nonexistent_task() {
1333 let timer = TimerWheel::with_defaults();
1334 let service = timer.create_service(ServiceConfig::default());
1335
1336 // Try to postpone a nonexistent task
1337 // 尝试延期一个不存在的任务
1338 let fake_task = TimerService::create_task(Duration::from_millis(50), None);
1339 let fake_task_id = fake_task.get_id();
1340 // Do not register this task
1341 // 不注册这个任务
1342 let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
1343 assert!(!postponed, "Nonexistent task should not be postponed");
1344 }
1345
1346 #[tokio::test]
1347 async fn test_postpone_batch() {
1348 let timer = TimerWheel::with_defaults();
1349 let mut service = timer.create_service(ServiceConfig::default());
1350 let counter = Arc::new(AtomicU32::new(0));
1351
1352 // Register 3 tasks
1353 // 注册 3 个任务
1354 let mut task_ids = Vec::new();
1355 for _ in 0..3 {
1356 let counter_clone = Arc::clone(&counter);
1357 let task = TimerService::create_task(
1358 Duration::from_millis(50),
1359 Some(CallbackWrapper::new(move || {
1360 let counter = Arc::clone(&counter_clone);
1361 async move {
1362 counter.fetch_add(1, Ordering::SeqCst);
1363 }
1364 })),
1365 );
1366 task_ids.push((task.get_id(), Duration::from_millis(150), None));
1367 service.register(task).unwrap();
1368 }
1369
1370 // Batch postpone
1371 // 批量延期
1372 let postponed = service.postpone_batch_with_callbacks(task_ids);
1373 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1374
1375 // Wait for original time 50ms, task should not trigger
1376 // 等待原始时间 50ms,任务应该不触发
1377 tokio::time::sleep(Duration::from_millis(70)).await;
1378 assert_eq!(counter.load(Ordering::SeqCst), 0);
1379
1380 // Receive all timeout notifications
1381 // 接收所有超时通知
1382 let mut received_count = 0;
1383 let mut rx = service.take_receiver().unwrap();
1384
1385 while received_count < 3 {
1386 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1387 Ok(Some(_task_id)) => {
1388 received_count += 1;
1389 }
1390 Ok(None) => break,
1391 Err(_) => break,
1392 }
1393 }
1394
1395 assert_eq!(received_count, 3);
1396
1397 // Wait for callback to execute
1398 // 等待回调执行
1399 tokio::time::sleep(Duration::from_millis(20)).await;
1400 assert_eq!(counter.load(Ordering::SeqCst), 3);
1401 }
1402
1403 #[tokio::test]
1404 async fn test_postpone_batch_with_callbacks() {
1405 let timer = TimerWheel::with_defaults();
1406 let mut service = timer.create_service(ServiceConfig::default());
1407 let counter = Arc::new(AtomicU32::new(0));
1408
1409 // Register 3 tasks
1410 // 注册 3 个任务
1411 let mut task_ids = Vec::new();
1412 for _ in 0..3 {
1413 let task = TimerService::create_task(
1414 Duration::from_millis(50),
1415 None,
1416 );
1417 task_ids.push(task.get_id());
1418 service.register(task).unwrap();
1419 }
1420
1421 // Batch postpone and replace callback
1422 // 批量延期并替换回调
1423 let updates: Vec<_> = task_ids
1424 .into_iter()
1425 .map(|id| {
1426 let counter_clone = Arc::clone(&counter);
1427 (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
1428 let counter = Arc::clone(&counter_clone);
1429 async move {
1430 counter.fetch_add(1, Ordering::SeqCst);
1431 }
1432 })))
1433 })
1434 .collect();
1435
1436 let postponed = service.postpone_batch_with_callbacks(updates);
1437 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1438
1439 // Wait for original time 50ms, task should not trigger
1440 // 等待原始时间 50ms,任务应该不触发
1441 tokio::time::sleep(Duration::from_millis(70)).await;
1442 assert_eq!(counter.load(Ordering::SeqCst), 0);
1443
1444 // Receive all timeout notifications
1445 // 接收所有超时通知
1446 let mut received_count = 0;
1447 let mut rx = service.take_receiver().unwrap();
1448
1449 while received_count < 3 {
1450 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1451 Ok(Some(_task_id)) => {
1452 received_count += 1;
1453 }
1454 Ok(None) => break,
1455 Err(_) => break,
1456 }
1457 }
1458
1459 assert_eq!(received_count, 3);
1460
1461 // Wait for callback to execute
1462 // 等待回调执行
1463 tokio::time::sleep(Duration::from_millis(20)).await;
1464 assert_eq!(counter.load(Ordering::SeqCst), 3);
1465 }
1466
1467 #[tokio::test]
1468 async fn test_postpone_batch_empty() {
1469 let timer = TimerWheel::with_defaults();
1470 let service = timer.create_service(ServiceConfig::default());
1471
1472 // Postpone empty list
1473 let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
1474 let postponed = service.postpone_batch_with_callbacks(empty);
1475 assert_eq!(postponed, 0, "No tasks should be postponed");
1476 }
1477
1478 #[tokio::test]
1479 async fn test_postpone_keeps_timeout_notification_valid() {
1480 let timer = TimerWheel::with_defaults();
1481 let mut service = timer.create_service(ServiceConfig::default());
1482 let counter = Arc::new(AtomicU32::new(0));
1483
1484 // Register a task
1485 // 注册一个任务
1486 let counter_clone = Arc::clone(&counter);
1487 let task = TimerService::create_task(
1488 Duration::from_millis(50),
1489 Some(CallbackWrapper::new(move || {
1490 let counter = Arc::clone(&counter_clone);
1491 async move {
1492 counter.fetch_add(1, Ordering::SeqCst);
1493 }
1494 })),
1495 );
1496 let task_id = task.get_id();
1497 service.register(task).unwrap();
1498
1499 // Postpone task
1500 // 延期任务
1501 service.postpone(task_id, Duration::from_millis(100), None);
1502
1503 // Verify timeout notification is still valid (after postponing, need to wait 100ms, plus margin)
1504 // 验证超时通知是否仍然有效(延期后,需要等待 100ms,加上余量)
1505 let mut rx = service.take_receiver().unwrap();
1506 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1507 .await
1508 .expect("Should receive timeout notification")
1509 .expect("Should receive Some value");
1510
1511 assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1512
1513 // Wait for callback to execute
1514 // 等待回调执行
1515 tokio::time::sleep(Duration::from_millis(20)).await;
1516 assert_eq!(counter.load(Ordering::SeqCst), 1);
1517 }
1518
1519 #[tokio::test]
1520 async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1521 let timer = TimerWheel::with_defaults();
1522 let mut service = timer.create_service(ServiceConfig::default());
1523
1524 // Register two tasks: one will be cancelled, one will expire normally
1525 // 注册两个任务:一个将被取消,一个将正常过期
1526 let task1 = TimerService::create_task(Duration::from_secs(10), None);
1527 let task1_id = task1.get_id();
1528 service.register(task1).unwrap();
1529
1530 let task2 = TimerService::create_task(Duration::from_millis(50), None);
1531 let task2_id = task2.get_id();
1532 service.register(task2).unwrap();
1533
1534 // Cancel first task
1535 // 取消第一个任务
1536 let cancelled = service.cancel_task(task1_id);
1537 assert!(cancelled, "Task should be cancelled");
1538
1539 // Wait for second task to expire
1540 // 等待第二个任务过期
1541 let mut rx = service.take_receiver().unwrap();
1542 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1543 .await
1544 .expect("Should receive timeout notification")
1545 .expect("Should receive Some value");
1546
1547 // Should only receive notification for second task (expired), not for first task (cancelled)
1548 assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");
1549
1550 // Verify no other notifications (especially cancelled tasks should not have notifications)
1551 let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1552 assert!(no_more.is_err(), "Should not receive any more notifications");
1553 }
1554
1555 #[tokio::test]
1556 async fn test_take_receiver_twice() {
1557 let timer = TimerWheel::with_defaults();
1558 let mut service = timer.create_service(ServiceConfig::default());
1559
1560 // First call should return Some
1561 // 第一次调用应该返回 Some
1562 let rx1 = service.take_receiver();
1563 assert!(rx1.is_some(), "First take_receiver should return Some");
1564
1565 // Second call should return None
1566 // 第二次调用应该返回 None
1567 let rx2 = service.take_receiver();
1568 assert!(rx2.is_none(), "Second take_receiver should return None");
1569 }
1570
1571 #[tokio::test]
1572 async fn test_postpone_batch_without_callbacks() {
1573 let timer = TimerWheel::with_defaults();
1574 let mut service = timer.create_service(ServiceConfig::default());
1575 let counter = Arc::new(AtomicU32::new(0));
1576
1577 // Register 3 tasks, with original callback
1578 // 注册 3 个任务,带有原始回调
1579 let mut task_ids = Vec::new();
1580 for _ in 0..3 {
1581 let counter_clone = Arc::clone(&counter);
1582 let task = TimerService::create_task(
1583 Duration::from_millis(50),
1584 Some(CallbackWrapper::new(move || {
1585 let counter = Arc::clone(&counter_clone);
1586 async move {
1587 counter.fetch_add(1, Ordering::SeqCst);
1588 }
1589 })),
1590 );
1591 task_ids.push(task.get_id());
1592 service.register(task).unwrap();
1593 }
1594
1595 // Batch postpone, without replacing callback
1596 // 批量延期,不替换回调
1597 let updates: Vec<_> = task_ids
1598 .iter()
1599 .map(|&id| (id, Duration::from_millis(150)))
1600 .collect();
1601 let postponed = service.postpone_batch(updates);
1602 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1603
1604 // Wait for original time 50ms, task should not trigger
1605 // 等待原始时间 50ms,任务应该不触发
1606 tokio::time::sleep(Duration::from_millis(70)).await;
1607 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callbacks should not fire yet");
1608
1609 // Receive all timeout notifications
1610 // 接收所有超时通知
1611 let mut received_count = 0;
1612 let mut rx = service.take_receiver().unwrap();
1613
1614 while received_count < 3 {
1615 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1616 Ok(Some(_task_id)) => {
1617 received_count += 1;
1618 }
1619 Ok(None) => break,
1620 Err(_) => break,
1621 }
1622 }
1623
1624 assert_eq!(received_count, 3, "Should receive 3 timeout notifications");
1625
1626 // Wait for callback to execute
1627 // 等待回调执行
1628 tokio::time::sleep(Duration::from_millis(20)).await;
1629 assert_eq!(counter.load(Ordering::SeqCst), 3, "All callbacks should execute");
1630 }
1631}
1632