kestrel_timer/service.rs
1use lite_sync::{oneshot, spsc};
2use crate::{BatchHandle, TimerHandle};
3use crate::config::ServiceConfig;
4use crate::error::TimerError;
5use crate::task::{CallbackWrapper, CompletionReceiver, TaskCompletion, TaskId};
6use crate::wheel::Wheel;
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::future::BoxFuture;
9use parking_lot::Mutex;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::task::JoinHandle;
13
14/// Task notification type for distinguishing between one-shot and periodic tasks
15///
16/// 任务通知类型,用于区分一次性任务和周期性任务
17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub enum TaskNotification {
19 /// One-shot task expired notification
20 ///
21 /// 一次性任务过期通知
22 OneShot(TaskId),
23 /// Periodic task called notification
24 ///
25 /// 周期性任务被调用通知
26 Periodic(TaskId),
27}
28
29impl TaskNotification {
30 /// Get the task ID from the notification
31 ///
32 /// 从通知中获取任务 ID
33 pub fn task_id(&self) -> TaskId {
34 match self {
35 TaskNotification::OneShot(id) => *id,
36 TaskNotification::Periodic(id) => *id,
37 }
38 }
39
40 /// Check if this is a one-shot task notification
41 ///
42 /// 检查是否为一次性任务通知
43 pub fn is_oneshot(&self) -> bool {
44 matches!(self, TaskNotification::OneShot(_))
45 }
46
47 /// Check if this is a periodic task notification
48 ///
49 /// 检查是否为周期性任务通知
50 pub fn is_periodic(&self) -> bool {
51 matches!(self, TaskNotification::Periodic(_))
52 }
53}
54
55/// Service command type
56///
57/// 服务命令类型
58enum ServiceCommand {
59 /// Add batch timer handle, only contains necessary data: task_ids and completion_rxs
60 ///
61 /// 添加批量定时器句柄,仅包含必要数据: task_ids 和 completion_rxs
62 AddBatchHandle {
63 task_ids: Vec<TaskId>,
64 completion_rxs: Vec<CompletionReceiver>,
65 },
66 /// Add single timer handle, only contains necessary data: task_id and completion_rx
67 ///
68 /// 添加单个定时器句柄,仅包含必要数据: task_id 和 completion_rx
69 AddTimerHandle {
70 task_id: TaskId,
71 completion_rx: CompletionReceiver,
72 },
73}
74
75/// TimerService - timer service based on Actor pattern
76/// Manages multiple timer handles, listens to all timeout events, and aggregates notifications to be forwarded to the user.
77/// # Features
78/// - Automatically listens to all added timer handles' timeout events
79/// - Automatically removes one-shot tasks from internal management after timeout
80/// - Continuously monitors periodic tasks and forwards each invocation
81/// - Aggregates notifications (both one-shot and periodic) to be forwarded to the user's unified channel
82/// - Supports dynamic addition of BatchHandle and TimerHandle
83///
84///
85/// # 定时器服务,基于 Actor 模式管理多个定时器句柄,监听所有超时事件,并将通知聚合转发给用户。
86/// - 自动监听所有添加的定时器句柄的超时事件
87/// - 自动在一次性任务超时后从内部管理中移除任务
88/// - 持续监听周期性任务并转发每次调用通知
89/// - 将通知(一次性和周期性)聚合转发给用户
90/// - 支持动态添加 BatchHandle 和 TimerHandle
91///
92/// # Examples (示例)
93/// ```no_run
94/// use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, TaskNotification, config::ServiceConfig};
95/// use std::time::Duration;
96///
97/// #[tokio::main]
98/// async fn main() {
99/// let timer = TimerWheel::with_defaults();
100/// let mut service = timer.create_service(ServiceConfig::default());
101///
102/// // Register one-shot tasks (注册一次性任务)
103/// use kestrel_timer::TimerTask;
104/// let oneshot_tasks: Vec<_> = (0..3)
105/// .map(|i| {
106/// let callback = Some(CallbackWrapper::new(move || async move {
107/// println!("One-shot timer {} fired!", i);
108/// }));
109/// TimerTask::new_oneshot(Duration::from_millis(100), callback)
110/// })
111/// .collect();
112/// service.register_batch(oneshot_tasks).unwrap();
113///
114/// // Register periodic tasks (注册周期性任务)
115/// let periodic_task = TimerTask::new_periodic(
116/// Duration::from_millis(100),
117/// Duration::from_millis(50),
118/// Some(CallbackWrapper::new(|| async { println!("Periodic timer fired!"); })),
119/// None
120/// );
121/// service.register(periodic_task).unwrap();
122///
123/// // Receive notifications (接收通知)
124/// let rx = service.take_receiver().unwrap();
125/// while let Some(notification) = rx.recv().await {
126/// match notification {
127/// TaskNotification::OneShot(task_id) => {
128/// println!("One-shot task {:?} expired", task_id);
129/// }
130/// TaskNotification::Periodic(task_id) => {
131/// println!("Periodic task {:?} called", task_id);
132/// }
133/// }
134/// }
135/// }
136/// ```
137pub struct TimerService {
138 /// Command sender
139 ///
140 /// 命令发送器
141 command_tx: spsc::Sender<ServiceCommand, 32>,
142 /// Timeout receiver (supports both one-shot and periodic task notifications)
143 ///
144 /// 超时接收器(支持一次性和周期性任务通知)
145 timeout_rx: Option<spsc::Receiver<TaskNotification, 32>>,
146 /// Actor task handle
147 ///
148 /// Actor 任务句柄
149 actor_handle: Option<JoinHandle<()>>,
150 /// Timing wheel reference (for direct scheduling of timers)
151 ///
152 /// 时间轮引用 (用于直接调度定时器)
153 wheel: Arc<Mutex<Wheel>>,
154 /// Actor shutdown signal sender
155 ///
156 /// Actor 关闭信号发送器
157 shutdown_tx: Option<oneshot::Sender<()>>,
158}
159
160impl TimerService {
161 /// Create new TimerService
162 ///
163 /// # Parameters
164 /// - `wheel`: Timing wheel reference
165 /// - `config`: Service configuration
166 ///
167 /// # Notes
168 /// Typically not called directly, but used to create through `TimerWheel::create_service()`
169 ///
170 /// 创建新的 TimerService
171 ///
172 /// # 参数
173 /// - `wheel`: 时间轮引用
174 /// - `config`: 服务配置
175 ///
176 /// # 注意
177 /// 通常不直接调用,而是通过 `TimerWheel::create_service()` 创建
178 ///
179 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
180 let (command_tx, command_rx) = spsc::channel(config.command_channel_capacity);
181 let (timeout_tx, timeout_rx) = spsc::channel(config.timeout_channel_capacity);
182
183 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
184 let actor = ServiceActor::new(command_rx, timeout_tx, shutdown_rx);
185 let actor_handle = tokio::spawn(async move {
186 actor.run().await;
187 });
188
189 Self {
190 command_tx,
191 timeout_rx: Some(timeout_rx),
192 actor_handle: Some(actor_handle),
193 wheel,
194 shutdown_tx: Some(shutdown_tx),
195 }
196 }
197
198 /// Get timeout receiver (transfer ownership)
199 ///
200 /// # Returns
201 /// Timeout notification receiver, if already taken, returns None
202 ///
203 /// # Notes
204 /// This method can only be called once, because it transfers ownership of the receiver
205 /// The receiver will receive both one-shot task expired notifications and periodic task called notifications
206 ///
207 /// 获取超时通知接收器 (转移所有权)
208 ///
209 /// # 返回值
210 /// 超时通知接收器,如果已经取走,返回 None
211 ///
212 /// # 注意
213 /// 此方法只能调用一次,因为它转移了接收器的所有权
214 /// 接收器将接收一次性任务过期通知和周期性任务被调用通知
215 ///
216 /// # Examples (示例)
217 /// ```no_run
218 /// # use kestrel_timer::{TimerWheel, config::ServiceConfig, TaskNotification};
219 /// # #[tokio::main]
220 /// # async fn main() {
221 /// let timer = TimerWheel::with_defaults();
222 /// let mut service = timer.create_service(ServiceConfig::default());
223 ///
224 /// let rx = service.take_receiver().unwrap();
225 /// while let Some(notification) = rx.recv().await {
226 /// match notification {
227 /// TaskNotification::OneShot(task_id) => {
228 /// println!("One-shot task {:?} expired", task_id);
229 /// }
230 /// TaskNotification::Periodic(task_id) => {
231 /// println!("Periodic task {:?} called", task_id);
232 /// }
233 /// }
234 /// }
235 /// # }
236 /// ```
237 pub fn take_receiver(&mut self) -> Option<spsc::Receiver<TaskNotification, 32>> {
238 self.timeout_rx.take()
239 }
240
241 /// Cancel specified task
242 ///
243 /// # Parameters
244 /// - `task_id`: Task ID to cancel
245 ///
246 /// # Returns
247 /// - `true`: Task exists and cancellation is successful
248 /// - `false`: Task does not exist or cancellation fails
249 ///
250 /// 取消指定任务
251 ///
252 /// # 参数
253 /// - `task_id`: 任务 ID
254 ///
255 /// # 返回值
256 /// - `true`: 任务存在且取消成功
257 /// - `false`: 任务不存在或取消失败
258 ///
259 /// # Examples (示例)
260 /// ```no_run
261 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, TimerTask, config::ServiceConfig};
262 /// # use std::time::Duration;
263 /// #
264 /// # #[tokio::main]
265 /// # async fn main() {
266 /// let timer = TimerWheel::with_defaults();
267 /// let service = timer.create_service(ServiceConfig::default());
268 ///
269 /// // Use two-step API to schedule timers
270 /// let callback = Some(CallbackWrapper::new(|| async move {
271 /// println!("Timer fired!"); // 定时器触发
272 /// }));
273 /// let task = TimerTask::new_oneshot(Duration::from_secs(10), callback);
274 /// let task_id = task.get_id();
275 /// service.register(task).unwrap(); // 注册定时器
276 ///
277 /// // Cancel task
278 /// let cancelled = service.cancel_task(task_id);
279 /// println!("Task cancelled: {}", cancelled); // 任务取消
280 /// # }
281 /// ```
282 #[inline]
283 pub fn cancel_task(&self, task_id: TaskId) -> bool {
284 // Direct cancellation, no need to notify Actor
285 // FuturesUnordered will automatically clean up when tasks are cancelled
286 // 直接取消,无需通知 Actor
287 // FuturesUnordered 将在任务取消时自动清理
288 let mut wheel = self.wheel.lock();
289 wheel.cancel(task_id)
290 }
291
292 /// Batch cancel tasks
293 ///
294 /// Use underlying batch cancellation operation to cancel multiple tasks at once, performance is better than calling cancel_task repeatedly.
295 ///
296 /// # Parameters
297 /// - `task_ids`: List of task IDs to cancel
298 ///
299 /// # Returns
300 /// Number of successfully cancelled tasks
301 ///
302 /// 批量取消任务
303 ///
304 /// # 参数
305 /// - `task_ids`: 任务 ID 列表
306 ///
307 /// # 返回值
308 /// 成功取消的任务数量
309 ///
310 /// # Examples (示例)
311 /// ```no_run
312 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, TimerTask, config::ServiceConfig};
313 /// # use std::time::Duration;
314 /// #
315 /// # #[tokio::main]
316 /// # async fn main() {
317 /// let timer = TimerWheel::with_defaults();
318 /// let service = timer.create_service(ServiceConfig::default());
319 ///
320 /// let tasks: Vec<_> = (0..10)
321 /// .map(|i| {
322 /// let callback = Some(CallbackWrapper::new(move || async move {
323 /// println!("Timer {} fired!", i); // 定时器触发
324 /// }));
325 /// TimerTask::new_oneshot(Duration::from_secs(10), callback)
326 /// })
327 /// .collect();
328 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
329 /// service.register_batch(tasks).unwrap(); // 注册定时器
330 ///
331 /// // Batch cancel
332 /// let cancelled = service.cancel_batch(&task_ids);
333 /// println!("Cancelled {} tasks", cancelled); // 任务取消
334 /// # }
335 /// ```
336 #[inline]
337 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
338 if task_ids.is_empty() {
339 return 0;
340 }
341
342 // Direct batch cancellation, no need to notify Actor
343 // FuturesUnordered will automatically clean up when tasks are cancelled
344 // 直接批量取消,无需通知 Actor
345 // FuturesUnordered 将在任务取消时自动清理
346 let mut wheel = self.wheel.lock();
347 wheel.cancel_batch(task_ids)
348 }
349
350 /// Postpone task (replace callback)
351 ///
352 /// # Parameters
353 /// - `task_id`: Task ID to postpone
354 /// - `new_delay`: New delay time (recalculated from current time point)
355 /// - `callback`: New callback function
356 ///
357 /// # Returns
358 /// - `true`: Task exists and is successfully postponed
359 /// - `false`: Task does not exist or postponement fails
360 ///
361 /// # Notes
362 /// - Task ID remains unchanged after postponement
363 /// - Original timeout notification remains valid
364 /// - Callback function will be replaced with new callback
365 ///
366 /// 推迟任务 (替换回调)
367 ///
368 /// # 参数
369 /// - `task_id`: 任务 ID
370 /// - `new_delay`: 新的延迟时间 (从当前时间点重新计算)
371 /// - `callback`: 新的回调函数
372 ///
373 /// # 返回值
374 /// - `true`: 任务存在且延期成功
375 /// - `false`: 任务不存在或延期失败
376 ///
377 /// # 注意
378 /// - 任务 ID 在延期后保持不变
379 /// - 原始超时通知保持有效
380 /// - 回调函数将被新的回调函数替换
381 ///
382 /// # Examples (示例)
383 /// ```no_run
384 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, TimerTask, config::ServiceConfig};
385 /// # use std::time::Duration;
386 /// #
387 /// # #[tokio::main]
388 /// # async fn main() {
389 /// let timer = TimerWheel::with_defaults();
390 /// let service = timer.create_service(ServiceConfig::default());
391 ///
392 /// let callback = Some(CallbackWrapper::new(|| async {
393 /// println!("Original callback"); // 原始回调
394 /// }));
395 /// let task = TimerTask::new_oneshot(Duration::from_secs(5), callback);
396 /// let task_id = task.get_id();
397 /// service.register(task).unwrap(); // 注册定时器
398 ///
399 /// // Postpone and replace callback (延期并替换回调)
400 /// let new_callback = Some(CallbackWrapper::new(|| async { println!("New callback!"); }));
401 /// let success = service.postpone(
402 /// task_id,
403 /// Duration::from_secs(10),
404 /// new_callback
405 /// );
406 /// println!("Postponed successfully: {}", success);
407 /// # }
408 /// ```
409 #[inline]
410 pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
411 let mut wheel = self.wheel.lock();
412 wheel.postpone(task_id, new_delay, callback)
413 }
414
415 /// Batch postpone tasks (keep original callbacks)
416 ///
417 /// # Parameters
418 /// - `updates`: List of tuples of (task ID, new delay)
419 ///
420 /// # Returns
421 /// Number of successfully postponed tasks
422 ///
423 /// 批量延期任务 (保持原始回调)
424 ///
425 /// # 参数
426 /// - `updates`: (任务 ID, 新延迟) 元组列表
427 ///
428 /// # 返回值
429 /// 成功延期的任务数量
430 ///
431 /// # Examples (示例)
432 /// ```no_run
433 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, TimerTask, config::ServiceConfig};
434 /// # use std::time::Duration;
435 /// #
436 /// # #[tokio::main]
437 /// # async fn main() {
438 /// let timer = TimerWheel::with_defaults();
439 /// let service = timer.create_service(ServiceConfig::default());
440 ///
441 /// let tasks: Vec<_> = (0..3)
442 /// .map(|i| {
443 /// let callback = Some(CallbackWrapper::new(move || async move {
444 /// println!("Timer {} fired!", i);
445 /// }));
446 /// TimerTask::new_oneshot(Duration::from_secs(5), callback)
447 /// })
448 /// .collect();
449 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
450 /// service.register_batch(tasks).unwrap();
451 ///
452 /// // Batch postpone (keep original callbacks)
453 /// // 批量延期任务 (保持原始回调)
454 /// let updates: Vec<_> = task_ids
455 /// .into_iter()
456 /// .map(|id| (id, Duration::from_secs(10)))
457 /// .collect();
458 /// let postponed = service.postpone_batch(updates);
459 /// println!("Postponed {} tasks", postponed);
460 /// # }
461 /// ```
462 #[inline]
463 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
464 if updates.is_empty() {
465 return 0;
466 }
467
468 let mut wheel = self.wheel.lock();
469 wheel.postpone_batch(updates)
470 }
471
472 /// Batch postpone tasks (replace callbacks)
473 ///
474 /// # Parameters
475 /// - `updates`: List of tuples of (task ID, new delay, new callback)
476 ///
477 /// # Returns
478 /// Number of successfully postponed tasks
479 ///
480 /// 批量延期任务 (替换回调)
481 ///
482 /// # 参数
483 /// - `updates`: (任务 ID, 新延迟, 新回调) 元组列表
484 ///
485 /// # 返回值
486 /// 成功延期的任务数量
487 ///
488 /// # Examples (示例)
489 /// ```no_run
490 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig};
491 /// # use std::time::Duration;
492 /// #
493 /// # #[tokio::main]
494 /// # async fn main() {
495 /// # use kestrel_timer::TimerTask;
496 /// let timer = TimerWheel::with_defaults();
497 /// let service = timer.create_service(ServiceConfig::default());
498 ///
499 /// // Create 3 tasks, initially no callbacks
500 /// // 创建 3 个任务,最初没有回调
501 /// let tasks: Vec<_> = (0..3)
502 /// .map(|_| TimerTask::new_oneshot(Duration::from_secs(5), None))
503 /// .collect();
504 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
505 /// service.register_batch(tasks).unwrap();
506 ///
507 /// // Batch postpone and add new callbacks
508 /// // 批量延期并添加新的回调
509 /// let updates: Vec<_> = task_ids
510 /// .into_iter()
511 /// .enumerate()
512 /// .map(|(i, id)| {
513 /// let callback = Some(CallbackWrapper::new(move || async move {
514 /// println!("New callback {}", i);
515 /// }));
516 /// (id, Duration::from_secs(10), callback)
517 /// })
518 /// .collect();
519 /// let postponed = service.postpone_batch_with_callbacks(updates);
520 /// println!("Postponed {} tasks", postponed);
521 /// # }
522 /// ```
523 #[inline]
524 pub fn postpone_batch_with_callbacks(
525 &self,
526 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
527 ) -> usize {
528 if updates.is_empty() {
529 return 0;
530 }
531
532 let mut wheel = self.wheel.lock();
533 wheel.postpone_batch_with_callbacks(updates)
534 }
535
536 /// Register timer task to service (registration phase)
537 ///
538 /// # Parameters
539 /// - `task`: Task created via `TimerTask::new_oneshot()`
540 ///
541 /// # Returns
542 /// - `Ok(TimerHandle)`: Register successfully
543 /// - `Err(TimerError::RegisterFailed)`: Register failed (internal channel is full or closed)
544 ///
545 /// 注册定时器任务到服务 (注册阶段)
546 /// # 参数
547 /// - `task`: 通过 `TimerTask::new_oneshot()` 创建的任务
548 ///
549 /// # 返回值
550 /// - `Ok(TimerHandle)`: 注册成功
551 /// - `Err(TimerError::RegisterFailed)`: 注册失败 (内部通道已满或关闭)
552 ///
553 /// # Examples (示例)
554 /// ```no_run
555 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig, TimerTask};
556 /// # use std::time::Duration;
557 /// #
558 /// # #[tokio::main]
559 /// # async fn main() {
560 /// let timer = TimerWheel::with_defaults();
561 /// let service = timer.create_service(ServiceConfig::default());
562 ///
563 /// // Step 1: create task
564 /// // 创建任务
565 /// let callback = Some(CallbackWrapper::new(|| async move {
566 /// println!("Timer fired!");
567 /// }));
568 /// let task = TimerTask::new_oneshot(Duration::from_millis(100), callback);
569 /// let task_id = task.get_id();
570 ///
571 /// // Step 2: register task
572 /// // 注册任务
573 /// service.register(task).unwrap();
574 /// # }
575 /// ```
576 #[inline]
577 pub fn register(&self, task: crate::task::TimerTask) -> Result<TimerHandle, TimerError> {
578
579 let (task, completion_rx) = crate::task::TimerTaskWithCompletionNotifier::from_timer_task(task);
580
581 let task_id = task.id;
582
583 // Single lock, complete all operations
584 // 单次锁定,完成所有操作
585 {
586 let mut wheel_guard = self.wheel.lock();
587 wheel_guard.insert(task);
588 }
589
590 // Add to service management (only send necessary data)
591 // 添加到服务管理(只发送必要数据)
592 self.command_tx
593 .try_send(ServiceCommand::AddTimerHandle {
594 task_id,
595 completion_rx,
596 })
597 .map_err(|_| TimerError::RegisterFailed)?;
598
599 Ok(TimerHandle::new(task_id, self.wheel.clone()))
600 }
601
602 /// Batch register timer tasks to service (registration phase)
603 ///
604 /// # Parameters
605 /// - `tasks`: List of tasks created via `TimerTask::new_oneshot()`
606 ///
607 /// # Returns
608 /// - `Ok(BatchHandle)`: Register successfully
609 /// - `Err(TimerError::RegisterFailed)`: Register failed (internal channel is full or closed)
610 ///
611 /// 批量注册定时器任务到服务 (注册阶段)
612 /// # 参数
613 /// - `tasks`: 通过 `TimerTask::new_oneshot()` 创建的任务列表
614 ///
615 /// # 返回值
616 /// - `Ok(BatchHandle)`: 注册成功
617 /// - `Err(TimerError::RegisterFailed)`: 注册失败 (内部通道已满或关闭)
618 ///
619 /// # Examples (示例)
620 /// ```no_run
621 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, config::ServiceConfig, TimerTask};
622 /// # use std::time::Duration;
623 /// #
624 /// # #[tokio::main]
625 /// # async fn main() {
626 /// # use kestrel_timer::TimerTask;
627 /// let timer = TimerWheel::with_defaults();
628 /// let service = timer.create_service(ServiceConfig::default());
629 ///
630 /// // Step 1: create batch of tasks with callbacks
631 /// // 创建批量任务,带有回调
632 /// let tasks: Vec<TimerTask> = (0..3)
633 /// .map(|i| {
634 /// let callback = Some(CallbackWrapper::new(move || async move {
635 /// println!("Timer {} fired!", i);
636 /// }));
637 /// TimerTask::new_oneshot(Duration::from_secs(1), callback)
638 /// })
639 /// .collect();
640 ///
641 /// // Step 2: register batch of tasks with callbacks
642 /// // 注册批量任务,带有回调
643 /// service.register_batch(tasks).unwrap();
644 /// # }
645 /// ```
646 #[inline]
647 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<BatchHandle, TimerError> {
648 let task_count = tasks.len();
649 let mut completion_rxs = Vec::with_capacity(task_count);
650 let mut task_ids = Vec::with_capacity(task_count);
651 let mut prepared_tasks = Vec::with_capacity(task_count);
652
653 // Step 1: prepare all channels and notifiers (no lock)
654 // 步骤 1: 准备所有通道和通知器(无锁)
655 for task in tasks {
656 let (task, completion_rx) = crate::task::TimerTaskWithCompletionNotifier::from_timer_task(task);
657 task_ids.push(task.id);
658 completion_rxs.push(completion_rx);
659 prepared_tasks.push(task);
660 }
661
662 // Step 2: single lock, batch insert
663 // 步骤 2: 单次锁定,批量插入
664 {
665 let mut wheel_guard = self.wheel.lock();
666 wheel_guard.insert_batch(prepared_tasks);
667 }
668
669 // Add to service management (only send necessary data)
670 // 添加到服务管理(只发送必要数据)
671 self.command_tx
672 .try_send(ServiceCommand::AddBatchHandle {
673 task_ids: task_ids.clone(),
674 completion_rxs,
675 })
676 .map_err(|_| TimerError::RegisterFailed)?;
677
678 Ok(BatchHandle::new(task_ids, self.wheel.clone()))
679 }
680
681 /// Graceful shutdown of TimerService
682 ///
683 /// 优雅关闭 TimerService
684 ///
685 /// # Examples (示例)
686 /// ```no_run
687 /// # use kestrel_timer::{TimerWheel, config::ServiceConfig};
688 /// # #[tokio::main]
689 /// # async fn main() {
690 /// let timer = TimerWheel::with_defaults();
691 /// let mut service = timer.create_service(ServiceConfig::default());
692 ///
693 /// // Use service... (使用服务...)
694 ///
695 /// service.shutdown().await;
696 /// # }
697 /// ```
698 pub async fn shutdown(mut self) {
699 if let Some(shutdown_tx) = self.shutdown_tx.take() {
700 shutdown_tx.notify(());
701 }
702 if let Some(handle) = self.actor_handle.take() {
703 let _ = handle.await;
704 }
705 }
706}
707
708
709impl Drop for TimerService {
710 fn drop(&mut self) {
711 if let Some(handle) = self.actor_handle.take() {
712 handle.abort();
713 }
714 }
715}
716
717/// ServiceActor - internal Actor implementation
718///
719/// ServiceActor - 内部 Actor 实现
720struct ServiceActor {
721 /// Command receiver
722 ///
723 /// 命令接收器
724 command_rx: spsc::Receiver<ServiceCommand, 32>,
725 /// Timeout sender (supports both one-shot and periodic task notifications)
726 ///
727 /// 超时发送器(支持一次性和周期性任务通知)
728 timeout_tx: spsc::Sender<TaskNotification, 32>,
729 /// Actor shutdown signal receiver
730 ///
731 /// Actor 关闭信号接收器
732 shutdown_rx: oneshot::Receiver<()>,
733}
734
735impl ServiceActor {
736 /// Create new ServiceActor
737 ///
738 /// 创建新的 ServiceActor
739 fn new(command_rx: spsc::Receiver<ServiceCommand, 32>, timeout_tx: spsc::Sender<TaskNotification, 32>, shutdown_rx: oneshot::Receiver<()>) -> Self {
740 Self {
741 command_rx,
742 timeout_tx,
743 shutdown_rx,
744 }
745 }
746
747 /// Run Actor event loop
748 ///
749 /// 运行 Actor 事件循环
750 async fn run(self) {
751 // Use separate FuturesUnordered for one-shot and periodic tasks
752 // 为一次性任务和周期性任务使用独立的 FuturesUnordered
753
754 // One-shot futures: each future returns (TaskId, TaskCompletion)
755 // 一次性任务 futures:每个 future 返回 (TaskId, TaskCompletion)
756 let mut oneshot_futures: FuturesUnordered<BoxFuture<'static, (TaskId, TaskCompletion)>> = FuturesUnordered::new();
757
758 // Periodic futures: each future returns (TaskId, Option<PeriodicTaskCompletion>, mpsc::Receiver)
759 // The receiver is returned so we can continue listening for next event
760 // 周期性任务 futures:每个 future 返回 (TaskId, Option<PeriodicTaskCompletion>, mpsc::Receiver)
761 // 返回接收器以便我们可以继续监听下一个事件
762 type PeriodicFutureResult = (TaskId, Option<TaskCompletion>, crate::task::PeriodicCompletionReceiver);
763 let mut periodic_futures: FuturesUnordered<BoxFuture<'static, PeriodicFutureResult>> = FuturesUnordered::new();
764
765 // Move shutdown_rx out of self, so it can be used in select! with &mut
766 // 将 shutdown_rx 从 self 中移出,以便在 select! 中使用 &mut
767 let mut shutdown_rx = self.shutdown_rx;
768
769 loop {
770 tokio::select! {
771 // Listen to high-priority shutdown signal
772 // 监听高优先级关闭信号
773 _ = &mut shutdown_rx => {
774 // Receive shutdown signal, exit loop immediately
775 // 接收到关闭信号,立即退出循环
776 break;
777 }
778
779 // Listen to one-shot task timeout events
780 // 监听一次性任务超时事件
781 Some((task_id, completion)) = oneshot_futures.next() => {
782 // Check completion reason, only forward Called events, do not forward Cancelled events
783 // 检查完成原因,只转发 Called 事件,不转发 Cancelled 事件
784 if completion == TaskCompletion::Called {
785 let _ = self.timeout_tx.send(TaskNotification::OneShot(task_id)).await;
786 }
787 // Task will be automatically removed from FuturesUnordered
788 // 任务将自动从 FuturesUnordered 中移除
789 }
790
791 // Listen to periodic task events
792 // 监听周期性任务事件
793 Some((task_id, reason, mut receiver)) = periodic_futures.next() => {
794 // Check completion reason, only forward Called events, do not forward Cancelled events
795 // 检查完成原因,只转发 Called 事件,不转发 Cancelled 事件
796 if let Some(TaskCompletion::Called) = reason {
797 let _ = self.timeout_tx.send(TaskNotification::Periodic(task_id)).await;
798
799 // Re-add the receiver to continue listening for next periodic event
800 // 重新添加接收器以继续监听下一个周期性事件
801 let future: BoxFuture<'static, PeriodicFutureResult> = Box::pin(async move {
802 let reason = receiver.recv().await;
803 (task_id, reason, receiver)
804 });
805 periodic_futures.push(future);
806 }
807 // If Cancelled or None, do not re-add the future (task is done)
808 // 如果是 Cancelled 或 None,不重新添加 future(任务结束)
809 }
810
811 // Listen to commands
812 // 监听命令
813 Some(cmd) = self.command_rx.recv() => {
814 match cmd {
815 ServiceCommand::AddBatchHandle { task_ids, completion_rxs } => {
816 // Add all tasks to appropriate futures
817 // 将所有任务添加到相应的 futures
818 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
819 match rx {
820 crate::task::CompletionReceiver::OneShot(receiver) => {
821 let future: BoxFuture<'static, (TaskId, TaskCompletion)> = Box::pin(async move {
822 (task_id, receiver.wait().await)
823 });
824 oneshot_futures.push(future);
825 },
826 crate::task::CompletionReceiver::Periodic(mut receiver) => {
827 let future: BoxFuture<'static, PeriodicFutureResult> = Box::pin(async move {
828 let reason = receiver.recv().await;
829 (task_id, reason, receiver)
830 });
831 periodic_futures.push(future);
832 }
833 }
834 }
835 }
836 ServiceCommand::AddTimerHandle { task_id, completion_rx } => {
837 // Add to appropriate futures
838 // 添加到相应的 futures
839 match completion_rx {
840 crate::task::CompletionReceiver::OneShot(receiver) => {
841 let future: BoxFuture<'static, (TaskId, TaskCompletion)> = Box::pin(async move {
842 (task_id, receiver.wait().await)
843 });
844 oneshot_futures.push(future);
845 },
846 crate::task::CompletionReceiver::Periodic(mut receiver) => {
847 let future: BoxFuture<'static, PeriodicFutureResult> = Box::pin(async move {
848 let reason = receiver.recv().await;
849 (task_id, reason, receiver)
850 });
851 periodic_futures.push(future);
852 }
853 }
854 }
855 }
856 }
857
858 // If no futures and command channel is closed, exit loop
859 // 如果没有 futures 且命令通道关闭,退出循环
860 else => {
861 break;
862 }
863 }
864 }
865 }
866}
867
868#[cfg(test)]
869mod tests {
870 use super::*;
871 use crate::{TimerWheel, TimerTask};
872 use std::sync::atomic::{AtomicU32, Ordering};
873 use std::sync::Arc;
874 use std::time::Duration;
875
876 #[tokio::test]
877 async fn test_service_creation() {
878 let timer = TimerWheel::with_defaults();
879 let _service = timer.create_service(ServiceConfig::default());
880 }
881
882
883 #[tokio::test]
884 async fn test_add_timer_handle_and_receive_timeout() {
885 let timer = TimerWheel::with_defaults();
886 let mut service = timer.create_service(ServiceConfig::default());
887
888 // Create single timer (创建单个定时器)
889 let task = TimerTask::new_oneshot(Duration::from_millis(50), Some(CallbackWrapper::new(|| async {})));
890 let task_id = task.get_id();
891
892 // Register to service (注册到服务)
893 service.register(task).unwrap();
894
895 // Receive timeout notification (接收超时通知)
896 let rx = service.take_receiver().unwrap();
897 let received_notification = tokio::time::timeout(Duration::from_millis(200), rx.recv())
898 .await
899 .expect("Should receive timeout notification")
900 .expect("Should receive Some value");
901
902 assert_eq!(received_notification, TaskNotification::OneShot(task_id));
903 }
904
905
906 #[tokio::test]
907 async fn test_shutdown() {
908 let timer = TimerWheel::with_defaults();
909 let service = timer.create_service(ServiceConfig::default());
910
911 // Add some timers (添加一些定时器)
912 let task1 = TimerTask::new_oneshot(Duration::from_secs(10), None);
913 let task2 = TimerTask::new_oneshot(Duration::from_secs(10), None);
914 service.register(task1).unwrap();
915 service.register(task2).unwrap();
916
917 // Immediately shutdown (without waiting for timers to trigger) (立即关闭(不等待定时器触发))
918 service.shutdown().await;
919 }
920
921
922
923 #[tokio::test]
924 async fn test_cancel_task() {
925 let timer = TimerWheel::with_defaults();
926 let service = timer.create_service(ServiceConfig::default());
927
928 // Add a long-term timer (添加一个长期定时器)
929 let task = TimerTask::new_oneshot(Duration::from_secs(10), None);
930 let task_id = task.get_id();
931
932 service.register(task).unwrap();
933
934 // Cancel task (取消任务)
935 let cancelled = service.cancel_task(task_id);
936 assert!(cancelled, "Task should be cancelled successfully");
937
938 // Try to cancel the same task again, should return false (再次尝试取消同一任务,应返回 false)
939 let cancelled_again = service.cancel_task(task_id);
940 assert!(!cancelled_again, "Task should not exist anymore");
941 }
942
943 #[tokio::test]
944 async fn test_cancel_nonexistent_task() {
945 let timer = TimerWheel::with_defaults();
946 let service = timer.create_service(ServiceConfig::default());
947
948 // Add a timer to initialize service (添加定时器以初始化服务)
949 let task = TimerTask::new_oneshot(Duration::from_millis(50), None);
950 service.register(task).unwrap();
951
952 // Try to cancel a nonexistent task (create a task ID that will not actually be registered)
953 // 尝试取消不存在的任务(创建一个实际不会注册的任务 ID)
954 let fake_task = TimerTask::new_oneshot(Duration::from_millis(50), None);
955 let fake_task_id = fake_task.get_id();
956 // Do not register fake_task (不注册 fake_task)
957 let cancelled = service.cancel_task(fake_task_id);
958 assert!(!cancelled, "Nonexistent task should not be cancelled");
959 }
960
961
962 #[tokio::test]
963 async fn test_task_timeout_cleans_up_task_sender() {
964 let timer = TimerWheel::with_defaults();
965 let mut service = timer.create_service(ServiceConfig::default());
966
967 // Add a short-term timer (添加短期定时器)
968 let task = TimerTask::new_oneshot(Duration::from_millis(50), None);
969 let task_id = task.get_id();
970
971 service.register(task).unwrap();
972
973 // Wait for task timeout (等待任务超时)
974 let rx = service.take_receiver().unwrap();
975 let received_notification = tokio::time::timeout(Duration::from_millis(200), rx.recv())
976 .await
977 .expect("Should receive timeout notification")
978 .expect("Should receive Some value");
979
980 assert_eq!(received_notification, TaskNotification::OneShot(task_id));
981
982 // Wait a moment to ensure internal cleanup is complete (等待片刻以确保内部清理完成)
983 tokio::time::sleep(Duration::from_millis(10)).await;
984
985 // Try to cancel the timed-out task, should return false (尝试取消超时任务,应返回 false)
986 let cancelled = service.cancel_task(task_id);
987 assert!(!cancelled, "Timed out task should not exist anymore");
988 }
989
990 #[tokio::test]
991 async fn test_cancel_task_spawns_background_task() {
992 let timer = TimerWheel::with_defaults();
993 let service = timer.create_service(ServiceConfig::default());
994 let counter = Arc::new(AtomicU32::new(0));
995
996 // Create a timer (创建定时器)
997 let counter_clone = Arc::clone(&counter);
998 let task = TimerTask::new_oneshot(
999 Duration::from_secs(10),
1000 Some(CallbackWrapper::new(move || {
1001 let counter = Arc::clone(&counter_clone);
1002 async move {
1003 counter.fetch_add(1, Ordering::SeqCst);
1004 }
1005 })),
1006 );
1007 let task_id = task.get_id();
1008
1009 service.register(task).unwrap();
1010
1011 // Use cancel_task (will wait for result, but processed in background coroutine)
1012 // 使用 cancel_task(将等待结果,但在后台协程中处理)
1013 let cancelled = service.cancel_task(task_id);
1014 assert!(cancelled, "Task should be cancelled successfully");
1015
1016 // Wait long enough to ensure callback is not executed (等待足够长时间以确保回调未执行)
1017 tokio::time::sleep(Duration::from_millis(100)).await;
1018 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
1019
1020 // Verify task has been removed from active_tasks (验证任务已从 active_tasks 中移除)
1021 let cancelled_again = service.cancel_task(task_id);
1022 assert!(!cancelled_again, "Task should have been removed from active_tasks");
1023 }
1024
1025 #[tokio::test]
1026 async fn test_schedule_once_direct() {
1027 let timer = TimerWheel::with_defaults();
1028 let mut service = timer.create_service(ServiceConfig::default());
1029 let counter = Arc::new(AtomicU32::new(0));
1030
1031 // Schedule timer directly through service
1032 // 直接通过服务调度定时器
1033 let counter_clone = Arc::clone(&counter);
1034 let task = TimerTask::new_oneshot(
1035 Duration::from_millis(50),
1036 Some(CallbackWrapper::new(move || {
1037 let counter = Arc::clone(&counter_clone);
1038 async move {
1039 counter.fetch_add(1, Ordering::SeqCst);
1040 }
1041 })),
1042 );
1043 let task_id = task.get_id();
1044 service.register(task).unwrap();
1045
1046 // Wait for timer to trigger
1047 // 等待定时器触发
1048 let rx = service.take_receiver().unwrap();
1049 let received_notification = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1050 .await
1051 .expect("Should receive timeout notification")
1052 .expect("Should receive Some value");
1053
1054 assert_eq!(received_notification, TaskNotification::OneShot(task_id));
1055
1056 // Wait for callback to execute
1057 // 等待回调执行
1058 tokio::time::sleep(Duration::from_millis(50)).await;
1059 assert_eq!(counter.load(Ordering::SeqCst), 1);
1060 }
1061
1062 #[tokio::test]
1063 async fn test_schedule_once_batch_direct() {
1064 let timer = TimerWheel::with_defaults();
1065 let mut service = timer.create_service(ServiceConfig::default());
1066 let counter = Arc::new(AtomicU32::new(0));
1067
1068 // Schedule timers directly through service
1069 // 直接通过服务调度定时器
1070 let tasks: Vec<_> = (0..3)
1071 .map(|_| {
1072 let counter = Arc::clone(&counter);
1073 TimerTask::new_oneshot(Duration::from_millis(50), Some(CallbackWrapper::new(move || {
1074 let counter = Arc::clone(&counter);
1075 async move {
1076 counter.fetch_add(1, Ordering::SeqCst);
1077 }
1078 })))
1079 })
1080 .collect();
1081
1082 assert_eq!(tasks.len(), 3);
1083 service.register_batch(tasks).unwrap();
1084
1085 // Receive all timeout notifications
1086 // 接收所有超时通知
1087 let mut received_count = 0;
1088 let rx = service.take_receiver().unwrap();
1089
1090 while received_count < 3 {
1091 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1092 Ok(Some(_task_id)) => {
1093 received_count += 1;
1094 }
1095 Ok(None) => break,
1096 Err(_) => break,
1097 }
1098 }
1099
1100 assert_eq!(received_count, 3);
1101
1102 // Wait for callback to execute
1103 // 等待回调执行
1104 tokio::time::sleep(Duration::from_millis(50)).await;
1105 assert_eq!(counter.load(Ordering::SeqCst), 3);
1106 }
1107
1108 #[tokio::test]
1109 async fn test_schedule_once_notify_direct() {
1110 let timer = TimerWheel::with_defaults();
1111 let mut service = timer.create_service(ServiceConfig::default());
1112
1113 // Schedule only notification timer directly through service (no callback)
1114 // 直接通过服务调度通知定时器(没有回调)
1115 let task = TimerTask::new_oneshot(Duration::from_millis(50), None);
1116 let task_id = task.get_id();
1117 service.register(task).unwrap();
1118
1119 // Receive timeout notification
1120 // 接收超时通知
1121 let rx = service.take_receiver().unwrap();
1122 let received_notification = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1123 .await
1124 .expect("Should receive timeout notification")
1125 .expect("Should receive Some value");
1126
1127 assert_eq!(received_notification, TaskNotification::OneShot(task_id));
1128 }
1129
1130 #[tokio::test]
1131 async fn test_schedule_and_cancel_direct() {
1132 let timer = TimerWheel::with_defaults();
1133 let service = timer.create_service(ServiceConfig::default());
1134 let counter = Arc::new(AtomicU32::new(0));
1135
1136 // Schedule timer directly
1137 // 直接调度定时器
1138 let counter_clone = Arc::clone(&counter);
1139 let task = TimerTask::new_oneshot(
1140 Duration::from_secs(10),
1141 Some(CallbackWrapper::new(move || {
1142 let counter = Arc::clone(&counter_clone);
1143 async move {
1144 counter.fetch_add(1, Ordering::SeqCst);
1145 }
1146 })),
1147 );
1148 let task_id = task.get_id();
1149 service.register(task).unwrap();
1150
1151 // Immediately cancel
1152 // 立即取消
1153 let cancelled = service.cancel_task(task_id);
1154 assert!(cancelled, "Task should be cancelled successfully");
1155
1156 // Wait to ensure callback is not executed
1157 // 等待确保回调未执行
1158 tokio::time::sleep(Duration::from_millis(100)).await;
1159 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
1160 }
1161
1162 #[tokio::test]
1163 async fn test_cancel_batch_direct() {
1164 let timer = TimerWheel::with_defaults();
1165 let service = timer.create_service(ServiceConfig::default());
1166 let counter = Arc::new(AtomicU32::new(0));
1167
1168 // Batch schedule timers
1169 // 批量调度定时器
1170 let tasks: Vec<_> = (0..10)
1171 .map(|_| {
1172 let counter = Arc::clone(&counter);
1173 TimerTask::new_oneshot(Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1174 let counter = Arc::clone(&counter);
1175 async move {
1176 counter.fetch_add(1, Ordering::SeqCst);
1177 }
1178 })))
1179 })
1180 .collect();
1181
1182 let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1183 assert_eq!(task_ids.len(), 10);
1184 service.register_batch(tasks).unwrap();
1185
1186 // Batch cancel all tasks
1187 // 批量取消所有任务
1188 let cancelled = service.cancel_batch(&task_ids);
1189 assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1190
1191 // Wait to ensure callback is not executed
1192 // 等待确保回调未执行
1193 tokio::time::sleep(Duration::from_millis(100)).await;
1194 assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1195 }
1196
1197 #[tokio::test]
1198 async fn test_cancel_batch_partial() {
1199 let timer = TimerWheel::with_defaults();
1200 let service = timer.create_service(ServiceConfig::default());
1201 let counter = Arc::new(AtomicU32::new(0));
1202
1203 // Batch schedule timers
1204 // 批量调度定时器
1205 let tasks: Vec<_> = (0..10)
1206 .map(|_| {
1207 let counter = Arc::clone(&counter);
1208 TimerTask::new_oneshot(Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1209 let counter = Arc::clone(&counter);
1210 async move {
1211 counter.fetch_add(1, Ordering::SeqCst);
1212 }
1213 })))
1214 })
1215 .collect();
1216
1217 let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1218 service.register_batch(tasks).unwrap();
1219
1220 // Only cancel first 5 tasks
1221 // 只取消前 5 个任务
1222 let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1223 let cancelled = service.cancel_batch(&to_cancel);
1224 assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1225
1226 // Wait to ensure first 5 callbacks are not executed
1227 // 等待确保前 5 个回调未执行
1228 tokio::time::sleep(Duration::from_millis(100)).await;
1229 assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1230 }
1231
1232 #[tokio::test]
1233 async fn test_cancel_batch_empty() {
1234 let timer = TimerWheel::with_defaults();
1235 let service = timer.create_service(ServiceConfig::default());
1236
1237 // Cancel empty list
1238 // 取消空列表
1239 let empty: Vec<TaskId> = vec![];
1240 let cancelled = service.cancel_batch(&empty);
1241 assert_eq!(cancelled, 0, "No tasks should be cancelled");
1242 }
1243
1244 #[tokio::test]
1245 async fn test_postpone() {
1246 let timer = TimerWheel::with_defaults();
1247 let mut service = timer.create_service(ServiceConfig::default());
1248 let counter = Arc::new(AtomicU32::new(0));
1249
1250 // Register a task, original callback increases 1
1251 // 注册一个任务,原始回调增加 1
1252 let counter_clone1 = Arc::clone(&counter);
1253 let task = TimerTask::new_oneshot(
1254 Duration::from_millis(50),
1255 Some(CallbackWrapper::new(move || {
1256 let counter = Arc::clone(&counter_clone1);
1257 async move {
1258 counter.fetch_add(1, Ordering::SeqCst);
1259 }
1260 })),
1261 );
1262 let task_id = task.get_id();
1263 service.register(task).unwrap();
1264
1265 // Postpone task and replace callback, new callback increases 10
1266 // 延期任务并替换回调,新回调增加 10
1267 let counter_clone2 = Arc::clone(&counter);
1268 let postponed = service.postpone(
1269 task_id,
1270 Duration::from_millis(100),
1271 Some(CallbackWrapper::new(move || {
1272 let counter = Arc::clone(&counter_clone2);
1273 async move {
1274 counter.fetch_add(10, Ordering::SeqCst);
1275 }
1276 }))
1277 );
1278 assert!(postponed, "Task should be postponed successfully");
1279
1280 // Receive timeout notification (after postponing, need to wait 100ms, plus margin)
1281 // 接收超时通知(延期后,需要等待 100ms,加上余量)
1282 let rx = service.take_receiver().unwrap();
1283 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1284 .await
1285 .expect("Should receive timeout notification")
1286 .expect("Should receive Some value");
1287
1288 assert_eq!(received_task_id, TaskNotification::OneShot(task_id));
1289
1290 // Wait for callback to execute
1291 // 等待回调执行
1292 tokio::time::sleep(Duration::from_millis(20)).await;
1293
1294 // Verify new callback is executed (increased 10 instead of 1)
1295 // 验证新回调已执行(增加 10 而不是 1)
1296 assert_eq!(counter.load(Ordering::SeqCst), 10);
1297 }
1298
1299 #[tokio::test]
1300 async fn test_postpone_nonexistent_task() {
1301 let timer = TimerWheel::with_defaults();
1302 let service = timer.create_service(ServiceConfig::default());
1303
1304 // Try to postpone a nonexistent task
1305 // 尝试延期一个不存在的任务
1306 let fake_task = TimerTask::new_oneshot(Duration::from_millis(50), None);
1307 let fake_task_id = fake_task.get_id();
1308 // Do not register this task
1309 // 不注册这个任务
1310 let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
1311 assert!(!postponed, "Nonexistent task should not be postponed");
1312 }
1313
1314 #[tokio::test]
1315 async fn test_postpone_batch() {
1316 let timer = TimerWheel::with_defaults();
1317 let mut service = timer.create_service(ServiceConfig::default());
1318 let counter = Arc::new(AtomicU32::new(0));
1319
1320 // Register 3 tasks
1321 // 注册 3 个任务
1322 let mut task_ids = Vec::new();
1323 for _ in 0..3 {
1324 let counter_clone = Arc::clone(&counter);
1325 let task = TimerTask::new_oneshot(
1326 Duration::from_millis(50),
1327 Some(CallbackWrapper::new(move || {
1328 let counter = Arc::clone(&counter_clone);
1329 async move {
1330 counter.fetch_add(1, Ordering::SeqCst);
1331 }
1332 })),
1333 );
1334 task_ids.push((task.get_id(), Duration::from_millis(150), None));
1335 service.register(task).unwrap();
1336 }
1337
1338 // Batch postpone
1339 // 批量延期
1340 let postponed = service.postpone_batch_with_callbacks(task_ids);
1341 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1342
1343 // Wait for original time 50ms, task should not trigger
1344 // 等待原始时间 50ms,任务应该不触发
1345 tokio::time::sleep(Duration::from_millis(70)).await;
1346 assert_eq!(counter.load(Ordering::SeqCst), 0);
1347
1348 // Receive all timeout notifications
1349 // 接收所有超时通知
1350 let mut received_count = 0;
1351 let rx = service.take_receiver().unwrap();
1352
1353 while received_count < 3 {
1354 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1355 Ok(Some(_task_id)) => {
1356 received_count += 1;
1357 }
1358 Ok(None) => break,
1359 Err(_) => break,
1360 }
1361 }
1362
1363 assert_eq!(received_count, 3);
1364
1365 // Wait for callback to execute
1366 // 等待回调执行
1367 tokio::time::sleep(Duration::from_millis(20)).await;
1368 assert_eq!(counter.load(Ordering::SeqCst), 3);
1369 }
1370
1371 #[tokio::test]
1372 async fn test_postpone_batch_with_callbacks() {
1373 let timer = TimerWheel::with_defaults();
1374 let mut service = timer.create_service(ServiceConfig::default());
1375 let counter = Arc::new(AtomicU32::new(0));
1376
1377 // Register 3 tasks
1378 // 注册 3 个任务
1379 let mut task_ids = Vec::new();
1380 for _ in 0..3 {
1381 let task = TimerTask::new_oneshot(
1382 Duration::from_millis(50),
1383 None,
1384 );
1385 task_ids.push(task.get_id());
1386 service.register(task).unwrap();
1387 }
1388
1389 // Batch postpone and replace callback
1390 // 批量延期并替换回调
1391 let updates: Vec<_> = task_ids
1392 .into_iter()
1393 .map(|id| {
1394 let counter_clone = Arc::clone(&counter);
1395 (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
1396 let counter = Arc::clone(&counter_clone);
1397 async move {
1398 counter.fetch_add(1, Ordering::SeqCst);
1399 }
1400 })))
1401 })
1402 .collect();
1403
1404 let postponed = service.postpone_batch_with_callbacks(updates);
1405 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1406
1407 // Wait for original time 50ms, task should not trigger
1408 // 等待原始时间 50ms,任务应该不触发
1409 tokio::time::sleep(Duration::from_millis(70)).await;
1410 assert_eq!(counter.load(Ordering::SeqCst), 0);
1411
1412 // Receive all timeout notifications
1413 // 接收所有超时通知
1414 let mut received_count = 0;
1415 let rx = service.take_receiver().unwrap();
1416
1417 while received_count < 3 {
1418 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1419 Ok(Some(_task_id)) => {
1420 received_count += 1;
1421 }
1422 Ok(None) => break,
1423 Err(_) => break,
1424 }
1425 }
1426
1427 assert_eq!(received_count, 3);
1428
1429 // Wait for callback to execute
1430 // 等待回调执行
1431 tokio::time::sleep(Duration::from_millis(20)).await;
1432 assert_eq!(counter.load(Ordering::SeqCst), 3);
1433 }
1434
1435 #[tokio::test]
1436 async fn test_postpone_batch_empty() {
1437 let timer = TimerWheel::with_defaults();
1438 let service = timer.create_service(ServiceConfig::default());
1439
1440 // Postpone empty list
1441 let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
1442 let postponed = service.postpone_batch_with_callbacks(empty);
1443 assert_eq!(postponed, 0, "No tasks should be postponed");
1444 }
1445
1446 #[tokio::test]
1447 async fn test_postpone_keeps_timeout_notification_valid() {
1448 let timer = TimerWheel::with_defaults();
1449 let mut service = timer.create_service(ServiceConfig::default());
1450 let counter = Arc::new(AtomicU32::new(0));
1451
1452 // Register a task
1453 // 注册一个任务
1454 let counter_clone = Arc::clone(&counter);
1455 let task = TimerTask::new_oneshot(
1456 Duration::from_millis(50),
1457 Some(CallbackWrapper::new(move || {
1458 let counter = Arc::clone(&counter_clone);
1459 async move {
1460 counter.fetch_add(1, Ordering::SeqCst);
1461 }
1462 })),
1463 );
1464 let task_id = task.get_id();
1465 service.register(task).unwrap();
1466
1467 // Postpone task
1468 // 延期任务
1469 service.postpone(task_id, Duration::from_millis(100), None);
1470
1471 // Verify timeout notification is still valid (after postponing, need to wait 100ms, plus margin)
1472 // 验证超时通知是否仍然有效(延期后,需要等待 100ms,加上余量)
1473 let rx = service.take_receiver().unwrap();
1474 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1475 .await
1476 .expect("Should receive timeout notification")
1477 .expect("Should receive Some value");
1478
1479 assert_eq!(received_task_id, TaskNotification::OneShot(task_id), "Timeout notification should still work after postpone");
1480
1481 // Wait for callback to execute
1482 // 等待回调执行
1483 tokio::time::sleep(Duration::from_millis(20)).await;
1484 assert_eq!(counter.load(Ordering::SeqCst), 1);
1485 }
1486
1487 #[tokio::test]
1488 async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1489 let timer = TimerWheel::with_defaults();
1490 let mut service = timer.create_service(ServiceConfig::default());
1491
1492 // Register two tasks: one will be cancelled, one will expire normally
1493 // 注册两个任务:一个将被取消,一个将正常过期
1494 let task1 = TimerTask::new_oneshot(Duration::from_secs(10), None);
1495 let task1_id = task1.get_id();
1496 service.register(task1).unwrap();
1497
1498 let task2 = TimerTask::new_oneshot(Duration::from_millis(50), None);
1499 let task2_id = task2.get_id();
1500 service.register(task2).unwrap();
1501
1502 // Cancel first task
1503 // 取消第一个任务
1504 let cancelled = service.cancel_task(task1_id);
1505 assert!(cancelled, "Task should be cancelled");
1506
1507 // Wait for second task to expire
1508 // 等待第二个任务过期
1509 let rx = service.take_receiver().unwrap();
1510 let received_notification = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1511 .await
1512 .expect("Should receive timeout notification")
1513 .expect("Should receive Some value");
1514
1515 // Should only receive notification for second task (expired), not for first task (cancelled)
1516 assert_eq!(received_notification, TaskNotification::OneShot(task2_id), "Should only receive expired task notification");
1517
1518 // Verify no other notifications (especially cancelled tasks should not have notifications)
1519 let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1520 assert!(no_more.is_err(), "Should not receive any more notifications");
1521 }
1522
1523 #[tokio::test]
1524 async fn test_take_receiver_twice() {
1525 let timer = TimerWheel::with_defaults();
1526 let mut service = timer.create_service(ServiceConfig::default());
1527
1528 // First call should return Some
1529 // 第一次调用应该返回 Some
1530 let rx1 = service.take_receiver();
1531 assert!(rx1.is_some(), "First take_receiver should return Some");
1532
1533 // Second call should return None
1534 // 第二次调用应该返回 None
1535 let rx2 = service.take_receiver();
1536 assert!(rx2.is_none(), "Second take_receiver should return None");
1537 }
1538
1539 #[tokio::test]
1540 async fn test_postpone_batch_without_callbacks() {
1541 let timer = TimerWheel::with_defaults();
1542 let mut service = timer.create_service(ServiceConfig::default());
1543 let counter = Arc::new(AtomicU32::new(0));
1544
1545 // Register 3 tasks, with original callback
1546 // 注册 3 个任务,带有原始回调
1547 let mut task_ids = Vec::new();
1548 for _ in 0..3 {
1549 let counter_clone = Arc::clone(&counter);
1550 let task = TimerTask::new_oneshot(
1551 Duration::from_millis(50),
1552 Some(CallbackWrapper::new(move || {
1553 let counter = Arc::clone(&counter_clone);
1554 async move {
1555 counter.fetch_add(1, Ordering::SeqCst);
1556 }
1557 })),
1558 );
1559 task_ids.push(task.get_id());
1560 service.register(task).unwrap();
1561 }
1562
1563 // Batch postpone, without replacing callback
1564 // 批量延期,不替换回调
1565 let updates: Vec<_> = task_ids
1566 .iter()
1567 .map(|&id| (id, Duration::from_millis(150)))
1568 .collect();
1569 let postponed = service.postpone_batch(updates);
1570 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1571
1572 // Wait for original time 50ms, task should not trigger
1573 // 等待原始时间 50ms,任务应该不触发
1574 tokio::time::sleep(Duration::from_millis(70)).await;
1575 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callbacks should not fire yet");
1576
1577 // Receive all timeout notifications
1578 // 接收所有超时通知
1579 let mut received_count = 0;
1580 let rx = service.take_receiver().unwrap();
1581
1582 while received_count < 3 {
1583 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1584 Ok(Some(_task_id)) => {
1585 received_count += 1;
1586 }
1587 Ok(None) => break,
1588 Err(_) => break,
1589 }
1590 }
1591
1592 assert_eq!(received_count, 3, "Should receive 3 timeout notifications");
1593
1594 // Wait for callback to execute
1595 // 等待回调执行
1596 tokio::time::sleep(Duration::from_millis(20)).await;
1597 assert_eq!(counter.load(Ordering::SeqCst), 3, "All callbacks should execute");
1598 }
1599
1600 #[tokio::test]
1601 async fn test_periodic_task_basic() {
1602 let timer = TimerWheel::with_defaults();
1603 let mut service = timer.create_service(ServiceConfig::default());
1604 let counter = Arc::new(AtomicU32::new(0));
1605
1606 // Register a periodic task with 50ms interval (注册一个 50ms 间隔的周期性任务)
1607 let counter_clone = Arc::clone(&counter);
1608 let task = TimerTask::new_periodic(
1609 Duration::from_millis(30), // initial delay (初始延迟)
1610 Duration::from_millis(50), // interval (间隔)
1611 Some(CallbackWrapper::new(move || {
1612 let counter = Arc::clone(&counter_clone);
1613 async move {
1614 counter.fetch_add(1, Ordering::SeqCst);
1615 }
1616 })),
1617 None,
1618 );
1619 let task_id = task.get_id();
1620 service.register(task).unwrap();
1621
1622 // Receive periodic notifications (接收周期性通知)
1623 let rx = service.take_receiver().unwrap();
1624 let mut notification_count = 0;
1625
1626 // Receive first 3 notifications (接收前 3 个通知)
1627 while notification_count < 3 {
1628 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1629 Ok(Some(notification)) => {
1630 match notification {
1631 TaskNotification::Periodic(id) => {
1632 assert_eq!(id, task_id, "Should receive notification for correct task");
1633 notification_count += 1;
1634 }
1635 _ => panic!("Expected periodic notification"),
1636 }
1637 }
1638 Ok(None) => break,
1639 Err(_) => panic!("Timeout waiting for periodic notification"),
1640 }
1641 }
1642
1643 assert_eq!(notification_count, 3, "Should receive 3 periodic notifications");
1644
1645 // Wait for callback to execute
1646 // 等待回调执行
1647 tokio::time::sleep(Duration::from_millis(20)).await;
1648 assert_eq!(counter.load(Ordering::SeqCst), 3, "Callback should execute 3 times");
1649
1650 // Cancel the periodic task (取消周期性任务)
1651 let cancelled = service.cancel_task(task_id);
1652 assert!(cancelled, "Should be able to cancel periodic task");
1653 }
1654
1655 #[tokio::test]
1656 async fn test_periodic_task_cancel_no_notification() {
1657 let timer = TimerWheel::with_defaults();
1658 let mut service = timer.create_service(ServiceConfig::default());
1659
1660 // Register a periodic task (注册周期性任务)
1661 let task = TimerTask::new_periodic(
1662 Duration::from_millis(30),
1663 Duration::from_millis(50),
1664 None,
1665 None,
1666 );
1667 let task_id = task.get_id();
1668 service.register(task).unwrap();
1669
1670 // Wait for first notification (等待第一个通知)
1671 let rx = service.take_receiver().unwrap();
1672 let notification = tokio::time::timeout(Duration::from_millis(100), rx.recv())
1673 .await
1674 .expect("Should receive first notification")
1675 .expect("Should receive Some value");
1676
1677 assert_eq!(notification, TaskNotification::Periodic(task_id));
1678
1679 // Cancel the task (取消任务)
1680 let cancelled = service.cancel_task(task_id);
1681 assert!(cancelled, "Should be able to cancel task");
1682
1683 // Should not receive cancelled notification (不应该接收到取消通知)
1684 match tokio::time::timeout(Duration::from_millis(100), rx.recv()).await {
1685 Ok(Some(_)) => panic!("Should not receive cancelled notification"),
1686 Ok(None) | Err(_) => {} // Expected: timeout or channel closed
1687 }
1688 }
1689
1690 #[tokio::test]
1691 async fn test_mixed_oneshot_and_periodic_tasks() {
1692 let timer = TimerWheel::with_defaults();
1693 let mut service = timer.create_service(ServiceConfig::default());
1694
1695 // Register one-shot tasks (注册一次性任务)
1696 let oneshot_task = TimerTask::new_oneshot(Duration::from_millis(50), None);
1697 let oneshot_id = oneshot_task.get_id();
1698 service.register(oneshot_task).unwrap();
1699
1700 // Register periodic task (注册周期性任务)
1701 let periodic_task = TimerTask::new_periodic(
1702 Duration::from_millis(30),
1703 Duration::from_millis(40),
1704 None,
1705 None,
1706 );
1707 let periodic_id = periodic_task.get_id();
1708 service.register(periodic_task).unwrap();
1709
1710 // Receive notifications (接收通知)
1711 let rx = service.take_receiver().unwrap();
1712 let mut oneshot_received = false;
1713 let mut periodic_count = 0;
1714
1715 // Receive notifications for a while (接收一段时间的通知)
1716 let start = tokio::time::Instant::now();
1717 while start.elapsed() < Duration::from_millis(200) {
1718 match tokio::time::timeout(Duration::from_millis(100), rx.recv()).await {
1719 Ok(Some(notification)) => {
1720 match notification {
1721 TaskNotification::OneShot(id) => {
1722 assert_eq!(id, oneshot_id, "Should be one-shot task");
1723 oneshot_received = true;
1724 }
1725 TaskNotification::Periodic(id) => {
1726 assert_eq!(id, periodic_id, "Should be periodic task");
1727 periodic_count += 1;
1728 }
1729 }
1730 }
1731 Ok(None) => break,
1732 Err(_) => break,
1733 }
1734 }
1735
1736 assert!(oneshot_received, "Should receive one-shot notification");
1737 assert!(periodic_count >= 2, "Should receive at least 2 periodic notifications");
1738
1739 // Cancel periodic task (取消周期性任务)
1740 service.cancel_task(periodic_id);
1741 }
1742
1743 #[tokio::test]
1744 async fn test_periodic_task_batch_register() {
1745 let timer = TimerWheel::with_defaults();
1746 let mut service = timer.create_service(ServiceConfig::default());
1747 let counter = Arc::new(AtomicU32::new(0));
1748
1749 // Register multiple periodic tasks in batch (批量注册多个周期性任务)
1750 let tasks: Vec<_> = (0..3)
1751 .map(|_| {
1752 let counter = Arc::clone(&counter);
1753 TimerTask::new_periodic(
1754 Duration::from_millis(30),
1755 Duration::from_millis(50),
1756 Some(CallbackWrapper::new(move || {
1757 let counter = Arc::clone(&counter);
1758 async move {
1759 counter.fetch_add(1, Ordering::SeqCst);
1760 }
1761 })),
1762 None,
1763 )
1764 })
1765 .collect();
1766
1767 let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1768 service.register_batch(tasks).unwrap();
1769
1770 // Receive notifications (接收通知)
1771 let rx = service.take_receiver().unwrap();
1772 let mut notification_counts = std::collections::HashMap::new();
1773
1774 // Receive notifications for a while (接收一段时间的通知)
1775 let start = tokio::time::Instant::now();
1776 while start.elapsed() < Duration::from_millis(180) {
1777 match tokio::time::timeout(Duration::from_millis(100), rx.recv()).await {
1778 Ok(Some(TaskNotification::Periodic(id))) => {
1779 *notification_counts.entry(id).or_insert(0) += 1;
1780 }
1781 Ok(Some(_)) => panic!("Expected periodic notification"),
1782 Ok(None) => break,
1783 Err(_) => break,
1784 }
1785 }
1786
1787 // Each task should receive at least 2 notifications (每个任务应该至少收到 2 个通知)
1788 for task_id in &task_ids {
1789 let count = notification_counts.get(task_id).copied().unwrap_or(0);
1790 assert!(count >= 2, "Task {:?} should receive at least 2 notifications, got {}", task_id, count);
1791 }
1792
1793 // Wait for callbacks to execute
1794 // 等待回调执行
1795 tokio::time::sleep(Duration::from_millis(20)).await;
1796 let total_callbacks = counter.load(Ordering::SeqCst);
1797 assert!(total_callbacks >= 6, "Should have at least 6 callback executions (3 tasks * 2), got {}", total_callbacks);
1798
1799 // Cancel all periodic tasks (取消所有周期性任务)
1800 let cancelled = service.cancel_batch(&task_ids);
1801 assert_eq!(cancelled, 3, "Should cancel all 3 tasks");
1802 }
1803}
1804