kestrel_timer/service.rs
1use crate::config::ServiceConfig;
2use crate::error::TimerError;
3use crate::task::{CallbackWrapper, TaskCompletionReason, TaskId};
4use crate::wheel::Wheel;
5use futures::stream::{FuturesUnordered, StreamExt};
6use futures::future::BoxFuture;
7use parking_lot::Mutex;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc;
11use tokio::task::JoinHandle;
12
13/// TimerService 命令类型
14///
15/// (Service command type)
16enum ServiceCommand {
17 /// 添加批量定时器句柄(仅包含必要数据:task_ids 和 completion_rxs)
18 /// (Add batch timer handle, only contains necessary data: task_ids and completion_rxs)
19 AddBatchHandle {
20 task_ids: Vec<TaskId>,
21 completion_rxs: Vec<tokio::sync::oneshot::Receiver<TaskCompletionReason>>,
22 },
23 /// 添加单个定时器句柄(仅包含必要数据:task_id 和 completion_rx)
24 /// (Add single timer handle, only contains necessary data: task_id and completion_rx)
25 AddTimerHandle {
26 task_id: TaskId,
27 completion_rx: tokio::sync::oneshot::Receiver<TaskCompletionReason>,
28 },
29}
30
/// TimerService - a timer service built on the actor pattern.
///
/// Manages many timer handles, listens for their timeout events, and forwards
/// the `TaskId` of every expired task to a single user-facing channel.
///
/// # Features
/// - Automatically listens for timeout events of every registered timer handle
/// - Stops tracking a task once it has completed
/// - Forwards expired `TaskId`s to one unified channel (see `take_receiver`)
/// - Supports registering single tasks and batches dynamically
///
/// # Examples
/// ```no_run
/// use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
/// use std::time::Duration;
///
/// #[tokio::main]
/// async fn main() {
///     let timer = TimerWheel::with_defaults();
///     let mut service = timer.create_service(ServiceConfig::default());
///
///     // Use the two-step API to batch schedule timers through the service.
///     let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
///         .map(|i| {
///             let callback = Some(CallbackWrapper::new(move || async move {
///                 println!("Timer {} fired!", i);
///             }));
///             (Duration::from_millis(100), callback)
///         })
///         .collect();
///     // Step 1: create a batch of tasks with callbacks.
///     let tasks = TimerService::create_batch_with_callbacks(callbacks);
///     // Step 2: register the batch.
///     service.register_batch(tasks).unwrap();
///
///     // Take the receiver and wait for timeout notifications.
///     let mut rx = service.take_receiver().unwrap();
///     while let Some(task_id) = rx.recv().await {
///         println!("Task {:?} completed", task_id);
///     }
/// }
/// ```
pub struct TimerService {
    /// Sending end of the command channel consumed by the background actor.
    command_tx: mpsc::Sender<ServiceCommand>,
    /// Receiving end of the timeout channel; `None` once `take_receiver` has
    /// handed it out.
    timeout_rx: Option<mpsc::Receiver<TaskId>>,
    /// Join handle of the spawned actor task; `None` after `shutdown` took it.
    actor_handle: Option<JoinHandle<()>>,
    /// Shared timing wheel, used to schedule/cancel/postpone timers directly
    /// without going through the actor.
    wheel: Arc<Mutex<Wheel>>,
    /// Sending end of the actor's shutdown signal; `None` after `shutdown` used it.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
93
94impl TimerService {
95 /// 创建新的 TimerService
96 /// Create new TimerService
97 ///
98 /// # 参数 (Parameters)
99 /// - `wheel`: 时间轮引用
100 /// (Timing wheel reference)
101 /// - `config`: 服务配置
102 /// (Service configuration)
103 ///
104 /// # 注意 (Notes)
105 /// 通常不直接调用此方法,而是使用 `TimerWheel::create_service()` 来创建。
106 /// (Typically not called directly, but used to create through `TimerWheel::create_service()`)
107 ///
108 pub(crate) fn new(wheel: Arc<Mutex<Wheel>>, config: ServiceConfig) -> Self {
109 let (command_tx, command_rx) = mpsc::channel(config.command_channel_capacity);
110 let (timeout_tx, timeout_rx) = mpsc::channel(config.timeout_channel_capacity);
111
112 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
113 let actor = ServiceActor::new(command_rx, timeout_tx, shutdown_rx);
114 let actor_handle = tokio::spawn(async move {
115 actor.run().await;
116 });
117
118 Self {
119 command_tx,
120 timeout_rx: Some(timeout_rx),
121 actor_handle: Some(actor_handle),
122 wheel,
123 shutdown_tx: Some(shutdown_tx),
124 }
125 }
126
127 /// 获取超时接收器(转移所有权)
128 /// (Get timeout receiver, transfer ownership)
129 ///
130 /// # 返回 (Returns)
131 /// 超时通知接收器,如果已经被取走则返回 None
132 /// (Timeout notification receiver, if already taken, returns None)
133 ///
134 /// # 注意 (Notes)
135 /// 此方法只能调用一次,因为它会转移接收器的所有权
136 /// (This method can only be called once, because it transfers ownership of the receiver)
137 ///
138 /// # 示例 (Examples)s
139 /// ```no_run
140 /// # use kestrel_timer::{TimerWheel, ServiceConfig};
141 /// # #[tokio::main]
142 /// # async fn main() {
143 /// let timer = TimerWheel::with_defaults();
144 /// let mut service = timer.create_service(ServiceConfig::default());
145 ///
146 /// let mut rx = service.take_receiver().unwrap();
147 /// while let Some(task_id) = rx.recv().await {
148 /// println!("Task {:?} timed out", task_id);
149 /// }
150 /// # }
151 /// ```
152 pub fn take_receiver(&mut self) -> Option<mpsc::Receiver<TaskId>> {
153 self.timeout_rx.take()
154 }
155
156 /// 取消指定的任务
157 /// (Cancel specified task)
158 ///
159 /// # 参数 (Parameters)
160 /// - `task_id`: 要取消的任务 ID
161 /// (Task ID to cancel)
162 ///
163 /// # 返回 (Returns)
164 /// - `true`: 任务存在且成功取消
165 /// (Task exists and cancellation is successful)
166 /// - `false`: 任务不存在或取消失败
167 /// (Task does not exist or cancellation fails)
168 ///
169 /// # 性能说明 (Performance Notes)
170 /// 此方法使用直接取消优化,不需要等待 Actor 处理,大幅降低延迟
171 /// (This method uses direct cancellation optimization, does not need to wait for Actor processing, greatly reducing latency)
172 ///
173 /// # 示例 (Examples)
174 /// ```no_run
175 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
176 /// # use std::time::Duration;
177 /// #
178 /// # #[tokio::main]
179 /// # async fn main() {
180 /// let timer = TimerWheel::with_defaults();
181 /// let service = timer.create_service(ServiceConfig::default());
182 ///
183 /// // 使用两步式 API 调度定时器
184 /// // (Use two-step API to schedule timers)
185 /// let callback = Some(CallbackWrapper::new(|| async move {
186 /// println!("Timer fired!");
187 /// }));
188 /// let task = TimerService::create_task(Duration::from_secs(10), callback);
189 /// let task_id = task.get_id();
190 /// service.register(task).unwrap();
191 ///
192 /// // 取消任务
193 /// // (Cancel task)
194 /// let cancelled = service.cancel_task(task_id);
195 /// println!("Task cancelled: {}", cancelled);
196 /// # }
197 /// ```
198 #[inline]
199 pub fn cancel_task(&self, task_id: TaskId) -> bool {
200 // 直接取消任务,无需通知 Actor
201 // FuturesUnordered 会在任务被取消时自动清理
202 //
203 // Direct cancellation, no need to notify Actor
204 // FuturesUnordered will automatically clean up when tasks are cancelled
205 let mut wheel = self.wheel.lock();
206 wheel.cancel(task_id)
207 }
208
209 /// 批量取消任务
210 /// (Batch cancel tasks)
211 ///
212 /// 使用底层的批量取消操作一次性取消多个任务,性能优于循环调用 cancel_task。
213 /// (Use underlying batch cancellation operation to cancel multiple tasks at once, performance is better than calling cancel_task repeatedly.)
214 ///
215 /// # 参数 (Parameters)
216 /// - `task_ids`: 要取消的任务 ID 列表
217 /// (List of task IDs to cancel)
218 ///
219 /// # 返回 (Returns)
220 /// 成功取消的任务数量
221 /// (Number of successfully cancelled tasks)
222 ///
223 /// # 示例 (Examples)
224 /// ```no_run
225 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
226 /// # use std::time::Duration;
227 /// #
228 /// # #[tokio::main]
229 /// # async fn main() {
230 /// let timer = TimerWheel::with_defaults();
231 /// let service = timer.create_service(ServiceConfig::default());
232 ///
233 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..10)
234 /// .map(|i| {
235 /// let callback = Some(CallbackWrapper::new(move || async move {
236 /// println!("Timer {} fired!", i);
237 /// }));
238 /// (Duration::from_secs(10), callback)
239 /// })
240 /// .collect();
241 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
242 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
243 /// service.register_batch(tasks).unwrap();
244 ///
245 /// // 批量取消
246 /// let cancelled = service.cancel_batch(&task_ids);
247 /// println!("Cancelled {} tasks", cancelled);
248 /// # }
249 /// ```
250 #[inline]
251 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
252 if task_ids.is_empty() {
253 return 0;
254 }
255
256 // 优化:直接使用底层的批量取消,无需通知 Actor
257 // FuturesUnordered 会在任务被取消时自动清理
258 let mut wheel = self.wheel.lock();
259 wheel.cancel_batch(task_ids)
260 }
261
262 /// 推迟任务(替换回调)
263 /// (Postpone task (replace callback))
264 ///
265 /// # 参数 (Parameters)
266 /// - `task_id`: 要推迟的任务 ID
267 /// (Task ID to postpone)
268 /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
269 /// (New delay time (recalculated from current time point))
270 /// - `callback`: 新的回调函数
271 /// (New callback function)
272 ///
273 /// # 返回 (Returns)
274 /// - `true`: 任务存在且成功推迟
275 /// (Task exists and is successfully postponed)
276 /// - `false`: 任务不存在或推迟失败
277 /// (Task does not exist or postponement fails)
278 ///
279 /// # 注意 (Notes)
280 /// - 推迟后任务 ID 保持不变
281 /// (Task ID remains unchanged after postponement)
282 /// - 原有的超时通知仍然有效
283 /// (Original timeout notification remains valid)
284 /// - 回调函数会被替换为新的回调
285 /// (Callback function will be replaced with new callback)
286 ///
287 /// # 示例 (Examples)
288 /// ```no_run
289 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
290 /// # use std::time::Duration;
291 /// #
292 /// # #[tokio::main]
293 /// # async fn main() {
294 /// let timer = TimerWheel::with_defaults();
295 /// let service = timer.create_service(ServiceConfig::default());
296 ///
297 /// let callback = Some(CallbackWrapper::new(|| async {
298 /// println!("Original callback");
299 /// }));
300 /// let task = TimerService::create_task(Duration::from_secs(5), callback);
301 /// let task_id = task.get_id();
302 /// service.register(task).unwrap();
303 ///
304 /// // 推迟并替换回调
305 /// // (Postpone and replace callback)
306 /// let new_callback = Some(CallbackWrapper::new(|| async { println!("New callback!"); }));
307 /// let success = service.postpone(
308 /// task_id,
309 /// Duration::from_secs(10),
310 /// new_callback
311 /// );
312 /// println!("Postponed successfully: {}", success);
313 /// # }
314 /// ```
315 #[inline]
316 pub fn postpone(&self, task_id: TaskId, new_delay: Duration, callback: Option<CallbackWrapper>) -> bool {
317 let mut wheel = self.wheel.lock();
318 wheel.postpone(task_id, new_delay, callback)
319 }
320
321 /// 批量推迟任务(保持原回调)
322 /// (Batch postpone tasks, keep original callbacks)
323 ///
324 /// # 参数 (Parameters)
325 /// - `updates`: (任务ID, 新延迟) 的元组列表
326 /// (List of tuples of (task ID, new delay))
327 ///
328 /// # 返回 (Returns)
329 /// 成功推迟的任务数量
330 /// (Number of successfully postponed tasks)
331 ///
332 /// # 示例 (Examples)
333 /// ```no_run
334 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
335 /// # use std::time::Duration;
336 /// #
337 /// # #[tokio::main]
338 /// # async fn main() {
339 /// let timer = TimerWheel::with_defaults();
340 /// let service = timer.create_service(ServiceConfig::default());
341 ///
342 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
343 /// .map(|i| {
344 /// let callback = Some(CallbackWrapper::new(move || async move {
345 /// println!("Timer {} fired!", i);
346 /// }));
347 /// (Duration::from_secs(5), callback)
348 /// })
349 /// .collect();
350 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
351 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
352 /// service.register_batch(tasks).unwrap();
353 ///
354 /// // 批量推迟(保持原回调)
355 /// // (Batch postpone, keep original callbacks)
356 /// let updates: Vec<_> = task_ids
357 /// .into_iter()
358 /// .map(|id| (id, Duration::from_secs(10)))
359 /// .collect();
360 /// let postponed = service.postpone_batch(updates);
361 /// println!("Postponed {} tasks", postponed);
362 /// # }
363 /// ```
364 #[inline]
365 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
366 if updates.is_empty() {
367 return 0;
368 }
369
370 let mut wheel = self.wheel.lock();
371 wheel.postpone_batch(updates)
372 }
373
374 /// 批量推迟任务(替换回调)
375 /// (Batch postpone tasks, replace callbacks)
376 ///
377 /// # 参数 (Parameters)
378 /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
379 /// (List of tuples of (task ID, new delay, new callback))
380 ///
381 /// # 返回 (Returns)
382 /// 成功推迟的任务数量
383 /// (Number of successfully postponed tasks)
384 ///
385 /// # 示例 (Examples)
386 /// ```no_run
387 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
388 /// # use std::time::Duration;
389 /// #
390 /// # #[tokio::main]
391 /// # async fn main() {
392 /// let timer = TimerWheel::with_defaults();
393 /// let service = timer.create_service(ServiceConfig::default());
394 ///
395 /// // 创建 3 个任务,初始没有回调
396 /// // (Create 3 tasks, initially no callbacks)
397 /// let delays: Vec<Duration> = (0..3)
398 /// .map(|_| Duration::from_secs(5))
399 /// .collect();
400 /// let tasks = TimerService::create_batch(delays);
401 /// let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
402 /// service.register_batch(tasks).unwrap();
403 ///
404 /// // 批量推迟并添加新回调
405 /// // (Batch postpone and add new callbacks)
406 /// let updates: Vec<_> = task_ids
407 /// .into_iter()
408 /// .enumerate()
409 /// .map(|(i, id)| {
410 /// let callback = Some(CallbackWrapper::new(move || async move {
411 /// println!("New callback {}", i);
412 /// }));
413 /// (id, Duration::from_secs(10), callback)
414 /// })
415 /// .collect();
416 /// let postponed = service.postpone_batch_with_callbacks(updates);
417 /// println!("Postponed {} tasks", postponed);
418 /// # }
419 /// ```
420 #[inline]
421 pub fn postpone_batch_with_callbacks(
422 &self,
423 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
424 ) -> usize {
425 if updates.is_empty() {
426 return 0;
427 }
428
429 let mut wheel = self.wheel.lock();
430 wheel.postpone_batch_with_callbacks(updates)
431 }
432
433 /// 创建定时器任务(静态方法,申请阶段)
434 /// (Create timer task (static method, apply stage))
435 ///
436 /// # 参数 (Parameters)
437 /// - `delay`: 延迟时间
438 /// (Delay time)
439 /// - `callback`: 实现了 TimerCallback trait 的回调对象
440 /// (Callback object implementing TimerCallback trait)
441 ///
442 /// # 返回 (Returns)
443 /// 返回 TimerTask,需要通过 `register()` 注册
444 /// (Return TimerTask, needs to be registered through `register()`)
445 ///
446 /// # 示例 (Examples)
447 /// ```no_run
448 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
449 /// # use std::time::Duration;
450 /// #
451 /// # #[tokio::main]
452 /// # async fn main() {
453 /// let timer = TimerWheel::with_defaults();
454 /// let service = timer.create_service(ServiceConfig::default());
455 ///
456 /// // 步骤 1: 创建任务
457 /// // (Step 1: create task)
458 /// let callback = Some(CallbackWrapper::new(|| async {
459 /// println!("Timer fired!");
460 /// }));
461 /// let task = TimerService::create_task(Duration::from_millis(100), callback);
462 ///
463 /// let task_id = task.get_id();
464 /// println!("Created task: {:?}", task_id);
465 ///
466 /// // 步骤 2: 注册任务
467 /// // (Step 2: register task)
468 /// service.register(task).unwrap();
469 /// # }
470 /// ```
471 #[inline]
472 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
473 crate::timer::TimerWheel::create_task(delay, callback)
474 }
475
476 /// 批量创建定时器任务(静态方法,申请阶段,不带回调)
477 /// (Create batch of timer tasks (static method, apply stage, no callbacks))
478 ///
479 /// # 参数 (Parameters)
480 /// - `delays`: 延迟时间列表
481 /// (List of delay times)
482 ///
483 /// # 返回 (Returns)
484 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
485 /// (Return TimerTask list, needs to be registered through `register_batch()`)
486 ///
487 /// # 示例 (Examples)
488 /// ```no_run
489 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
490 /// # use std::time::Duration;
491 /// #
492 /// # #[tokio::main]
493 /// # async fn main() {
494 /// let timer = TimerWheel::with_defaults();
495 /// let service = timer.create_service(ServiceConfig::default());
496 ///
497 /// // 步骤 1: 批量创建任务
498 /// // (Step 1: create batch of tasks)
499 /// let delays: Vec<Duration> = (0..3)
500 /// .map(|i| Duration::from_millis(100 * (i + 1)))
501 /// .collect();
502 ///
503 /// let tasks = TimerService::create_batch(delays);
504 /// println!("Created {} tasks", tasks.len());
505 ///
506 /// // 步骤 2: 批量注册任务
507 /// // (Step 2: register batch of tasks)
508 /// service.register_batch(tasks).unwrap();
509 /// # }
510 /// ```
511 #[inline]
512 pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask> {
513 crate::timer::TimerWheel::create_batch(delays)
514 }
515
516 /// 批量创建定时器任务(静态方法,申请阶段,带回调)
517 ///
518 /// # 参数 (Parameters)
519 /// - `callbacks`: (延迟时间, 回调) 的元组列表
520 /// (List of tuples of (delay time, callback))
521 ///
522 /// # 返回 (Returns)
523 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册
524 /// (Return TimerTask list, needs to be registered through `register_batch()`)
525 ///
526 /// # 示例 (Examples)
527 /// ```no_run
528 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
529 /// # use std::time::Duration;
530 /// #
531 /// # #[tokio::main]
532 /// # async fn main() {
533 /// let timer = TimerWheel::with_defaults();
534 /// let service = timer.create_service(ServiceConfig::default());
535 ///
536 /// // 步骤 1: 批量创建任务
537 /// // (Step 1: create batch of tasks with callbacks)
538 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
539 /// .map(|i| {
540 /// let callback = Some(CallbackWrapper::new(move || async move {
541 /// println!("Timer {} fired!", i);
542 /// }));
543 /// (Duration::from_millis(100 * (i + 1)), callback)
544 /// })
545 /// .collect();
546 ///
547 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
548 /// println!("Created {} tasks", tasks.len());
549 ///
550 /// // 步骤 2: 批量注册任务
551 /// // (Step 2: register batch of tasks with callbacks)
552 /// service.register_batch(tasks).unwrap();
553 /// # }
554 /// ```
555 #[inline]
556 pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask> {
557 crate::timer::TimerWheel::create_batch_with_callbacks(callbacks)
558 }
559
560 /// 注册定时器任务到服务(注册阶段)
561 /// (Register timer task to service (registration phase))
562 ///
563 /// # 参数 (Parameters)
564 /// - `task`: 通过 `create_task()` 创建的任务
565 /// (Task created via `create_task()`)
566 ///
567 /// # 返回 (Returns)
568 /// (Task created via `create_task()`)
569 /// - `Ok(())`: 注册成功
570 /// (Register successfully)
571 /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
572 /// (Register failed (internal channel is full or closed))
573 ///
574 /// # 示例 (Examples)
575 /// ```no_run
576 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
577 /// # use std::time::Duration;
578 /// #
579 /// # #[tokio::main]
580 /// # async fn main() {
581 /// let timer = TimerWheel::with_defaults();
582 /// let service = timer.create_service(ServiceConfig::default());
583 ///
584 /// let callback = Some(CallbackWrapper::new(|| async move {
585 /// println!("Timer fired!");
586 /// }));
587 /// let task = TimerService::create_task(Duration::from_millis(100), callback);
588 /// let task_id = task.get_id();
589 ///
590 /// service.register(task).unwrap();
591 /// # }
592 /// ```
593 #[inline]
594 pub fn register(&self, task: crate::task::TimerTask) -> Result<(), TimerError> {
595 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
596 let notifier = crate::task::CompletionNotifier(completion_tx);
597
598 let task_id = task.id;
599
600 // 单次加锁完成所有操作
601 // (Single lock, complete all operations)
602 {
603 let mut wheel_guard = self.wheel.lock();
604 wheel_guard.insert(task, notifier);
605 }
606
607 // 添加到服务管理(只发送必要数据)
608 // (Add to service management, only send necessary data)
609 self.command_tx
610 .try_send(ServiceCommand::AddTimerHandle {
611 task_id,
612 completion_rx,
613 })
614 .map_err(|_| TimerError::RegisterFailed)?;
615
616 Ok(())
617 }
618
619 /// 批量注册定时器任务到服务(注册阶段)
620 /// (Batch register timer tasks to service (registration phase))
621 ///
622 /// # 参数 (Parameters)
623 /// - `tasks`: 通过 `create_batch()` 创建的任务列表
624 /// (List of tasks created via `create_batch()`)
625 ///
626 /// # 返回 (Returns)
627 /// (List of tasks created via `create_batch()`)
628 /// - `Ok(())`: 注册成功
629 /// (Register successfully)
630 /// - `Err(TimerError::RegisterFailed)`: 注册失败(内部通道已满或已关闭)
631 /// (Register failed (internal channel is full or closed))
632 ///
633 /// # 示例 (Examples)
634 /// ```no_run
635 /// # use kestrel_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
636 /// # use std::time::Duration;
637 /// #
638 /// # #[tokio::main]
639 /// # async fn main() {
640 /// let timer = TimerWheel::with_defaults();
641 /// let service = timer.create_service(ServiceConfig::default());
642 ///
643 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
644 /// .map(|i| {
645 /// let callback = Some(CallbackWrapper::new(move || async move {
646 /// println!("Timer {} fired!", i);
647 /// }));
648 /// (Duration::from_secs(1), callback)
649 /// })
650 /// .collect();
651 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
652 ///
653 /// service.register_batch(tasks).unwrap();
654 /// # }
655 /// ```
656 #[inline]
657 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> Result<(), TimerError> {
658 let task_count = tasks.len();
659 let mut completion_rxs = Vec::with_capacity(task_count);
660 let mut task_ids = Vec::with_capacity(task_count);
661 let mut prepared_tasks = Vec::with_capacity(task_count);
662
663 // 步骤1: 准备所有 channels 和 notifiers(无锁)
664 // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
665 // (Step 1: Prepare all channels and notifiers)
666 // (Optimize: use for loop instead of map + collect, avoid closure capture overhead)
667 for task in tasks {
668 let (completion_tx, completion_rx) = tokio::sync::oneshot::channel();
669 let notifier = crate::task::CompletionNotifier(completion_tx);
670
671 task_ids.push(task.id);
672 completion_rxs.push(completion_rx);
673 prepared_tasks.push((task, notifier));
674 }
675
676 // 步骤2: 单次加锁,批量插入
677 // (Step 2: Single lock, batch insert)
678 {
679 let mut wheel_guard = self.wheel.lock();
680 wheel_guard.insert_batch(prepared_tasks);
681 }
682
683 // 添加到服务管理(只发送必要数据)
684 // (Add to service management, only send necessary data)
685 self.command_tx
686 .try_send(ServiceCommand::AddBatchHandle {
687 task_ids,
688 completion_rxs,
689 })
690 .map_err(|_| TimerError::RegisterFailed)?;
691
692 Ok(())
693 }
694
695 /// 优雅关闭 TimerService
696 /// (Graceful shutdown of TimerService)
697 ///
698 /// # 示例 (Examples)
699 /// ```no_run
700 /// # use kestrel_timer::{TimerWheel, ServiceConfig};
701 /// # #[tokio::main]
702 /// # async fn main() {
703 /// let timer = TimerWheel::with_defaults();
704 /// let mut service = timer.create_service(ServiceConfig::default());
705 ///
706 /// // 使用 service...
707 /// // (Use service...)
708 ///
709 /// service.shutdown().await;
710 /// # }
711 /// ```
712 pub async fn shutdown(mut self) {
713 // (Take shutdown_tx and send shutdown signal)
714 if let Some(shutdown_tx) = self.shutdown_tx.take() {
715 let _ = shutdown_tx.send(());
716 }
717 // (Take actor_handle and await completion)
718 if let Some(handle) = self.actor_handle.take() {
719 let _ = handle.await;
720 }
721 }
722}
723
724
725impl Drop for TimerService {
726 fn drop(&mut self) {
727 if let Some(handle) = self.actor_handle.take() {
728 handle.abort();
729 }
730 }
731}
732
/// ServiceActor - the internal actor that backs a `TimerService`.
///
/// Spawned by `TimerService::new`. It receives `ServiceCommand`s, watches each
/// registered task's completion channel, and forwards the ids of expired tasks
/// to the user-facing timeout channel.
struct ServiceActor {
    /// Receiving end of the command channel fed by `TimerService::register*`.
    command_rx: mpsc::Receiver<ServiceCommand>,
    /// Sending end of the timeout channel; the user reads the other end via
    /// `TimerService::take_receiver`.
    timeout_tx: mpsc::Sender<TaskId>,
    /// Shutdown signal; resolves when `TimerService::shutdown` sends on it or
    /// when its sender is dropped.
    shutdown_rx: tokio::sync::oneshot::Receiver<()>,
}
743
744impl ServiceActor {
745 fn new(command_rx: mpsc::Receiver<ServiceCommand>, timeout_tx: mpsc::Sender<TaskId>, shutdown_rx: tokio::sync::oneshot::Receiver<()>) -> Self {
746 Self {
747 command_rx,
748 timeout_tx,
749 shutdown_rx,
750 }
751 }
752
753 async fn run(mut self) {
754 // 使用 FuturesUnordered 来监听所有的 completion_rxs
755 // 每个 future 返回 (TaskId, Result)
756 //
757 // (Use FuturesUnordered to listen to all completion_rxs)
758 // (Each future returns (TaskId, Result))
759 let mut futures: FuturesUnordered<BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)>> = FuturesUnordered::new();
760
761 // 将 shutdown_rx 移出 self,以便在 select! 中使用 &mut
762 // (Move shutdown_rx out of self, so it can be used in select! with &mut)
763 let mut shutdown_rx = self.shutdown_rx;
764
765 loop {
766 tokio::select! {
767 // 监听高优先级的关闭信号
768 // (Listen to high-priority shutdown signal)
769 _ = &mut shutdown_rx => {
770 // 收到关闭信号,立即退出循环
771 // (Receive shutdown signal, exit loop immediately)
772 break;
773 }
774
775 // 监听超时事件
776 // (Listen to timeout events)
777 Some((task_id, result)) = futures.next() => {
778 // 检查完成原因,只转发超时(Expired)事件,不转发取消(Cancelled)事件
779 // (Check completion reason, only forward expired (Expired) events, do not forward cancelled (Cancelled) events)
780 if let Ok(TaskCompletionReason::Expired) = result {
781 let _ = self.timeout_tx.send(task_id).await;
782 }
783 // 任务会自动从 FuturesUnordered 中移除
784 // (Task will be automatically removed from FuturesUnordered)
785 }
786
787 // 监听命令
788 // (Listen to commands)
789 Some(cmd) = self.command_rx.recv() => {
790 match cmd {
791 ServiceCommand::AddBatchHandle { task_ids, completion_rxs } => {
792 // 将所有任务添加到 futures 中
793 // (Add all tasks to futures)
794 for (task_id, rx) in task_ids.into_iter().zip(completion_rxs.into_iter()) {
795 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
796 (task_id, rx.await)
797 });
798 futures.push(future);
799 }
800 }
801 ServiceCommand::AddTimerHandle { task_id, completion_rx } => {
802 // 添加到 futures 中
803 // (Add to futures)
804 let future: BoxFuture<'static, (TaskId, Result<TaskCompletionReason, tokio::sync::oneshot::error::RecvError>)> = Box::pin(async move {
805 (task_id, completion_rx.await)
806 });
807 futures.push(future);
808 }
809 }
810 }
811
812 // 如果没有任何 future 且命令通道已关闭,退出循环
813 // (If no futures and command channel is closed, exit loop)
814 else => {
815 break;
816 }
817 }
818 }
819 }
820}
821
822#[cfg(test)]
823mod tests {
824 use super::*;
825 use crate::TimerWheel;
826 use std::sync::atomic::{AtomicU32, Ordering};
827 use std::sync::Arc;
828 use std::time::Duration;
829
830 #[tokio::test]
831 async fn test_service_creation() {
832 let timer = TimerWheel::with_defaults();
833 let _service = timer.create_service(ServiceConfig::default());
834 }
835
836
837 #[tokio::test]
838 async fn test_add_timer_handle_and_receive_timeout() {
839 let timer = TimerWheel::with_defaults();
840 let mut service = timer.create_service(ServiceConfig::default());
841
842 // 创建单个定时器
843 let task = TimerService::create_task(Duration::from_millis(50), Some(CallbackWrapper::new(|| async {})));
844 let task_id = task.get_id();
845
846 // 注册到 service
847 service.register(task).unwrap();
848
849 // 接收超时通知
850 let mut rx = service.take_receiver().unwrap();
851 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
852 .await
853 .expect("Should receive timeout notification")
854 .expect("Should receive Some value");
855
856 assert_eq!(received_task_id, task_id);
857 }
858
859
860 #[tokio::test]
861 async fn test_shutdown() {
862 let timer = TimerWheel::with_defaults();
863 let service = timer.create_service(ServiceConfig::default());
864
865 // 添加一些定时器
866 let task1 = TimerService::create_task(Duration::from_secs(10), None);
867 let task2 = TimerService::create_task(Duration::from_secs(10), None);
868 service.register(task1).unwrap();
869 service.register(task2).unwrap();
870
871 // 立即关闭(不等待定时器触发)
872 service.shutdown().await;
873 }
874
875
876
877 #[tokio::test]
878 async fn test_cancel_task() {
879 let timer = TimerWheel::with_defaults();
880 let service = timer.create_service(ServiceConfig::default());
881
882 // 添加一个长时间的定时器
883 let task = TimerService::create_task(Duration::from_secs(10), None);
884 let task_id = task.get_id();
885
886 service.register(task).unwrap();
887
888 // 取消任务
889 let cancelled = service.cancel_task(task_id);
890 assert!(cancelled, "Task should be cancelled successfully");
891
892 // 尝试再次取消同一个任务,应该返回 false
893 let cancelled_again = service.cancel_task(task_id);
894 assert!(!cancelled_again, "Task should not exist anymore");
895 }
896
897 #[tokio::test]
898 async fn test_cancel_nonexistent_task() {
899 let timer = TimerWheel::with_defaults();
900 let service = timer.create_service(ServiceConfig::default());
901
902 // 添加一个定时器以初始化 service
903 let task = TimerService::create_task(Duration::from_millis(50), None);
904 service.register(task).unwrap();
905
906 // 尝试取消一个不存在的任务(创建一个不会实际注册的任务ID)
907 let fake_task = TimerService::create_task(Duration::from_millis(50), None);
908 let fake_task_id = fake_task.get_id();
909 // 不注册 fake_task
910 let cancelled = service.cancel_task(fake_task_id);
911 assert!(!cancelled, "Nonexistent task should not be cancelled");
912 }
913
914
915 #[tokio::test]
916 async fn test_task_timeout_cleans_up_task_sender() {
917 let timer = TimerWheel::with_defaults();
918 let mut service = timer.create_service(ServiceConfig::default());
919
920 // 添加一个短时间的定时器
921 let task = TimerService::create_task(Duration::from_millis(50), None);
922 let task_id = task.get_id();
923
924 service.register(task).unwrap();
925
926 // 等待任务超时
927 let mut rx = service.take_receiver().unwrap();
928 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
929 .await
930 .expect("Should receive timeout notification")
931 .expect("Should receive Some value");
932
933 assert_eq!(received_task_id, task_id);
934
935 // 等待一下确保内部清理完成
936 tokio::time::sleep(Duration::from_millis(10)).await;
937
938 // 尝试取消已经超时的任务,应该返回 false
939 let cancelled = service.cancel_task(task_id);
940 assert!(!cancelled, "Timed out task should not exist anymore");
941 }
942
943 #[tokio::test]
944 async fn test_cancel_task_spawns_background_task() {
945 let timer = TimerWheel::with_defaults();
946 let service = timer.create_service(ServiceConfig::default());
947 let counter = Arc::new(AtomicU32::new(0));
948
949 // 创建一个定时器
950 let counter_clone = Arc::clone(&counter);
951 let task = TimerService::create_task(
952 Duration::from_secs(10),
953 Some(CallbackWrapper::new(move || {
954 let counter = Arc::clone(&counter_clone);
955 async move {
956 counter.fetch_add(1, Ordering::SeqCst);
957 }
958 })),
959 );
960 let task_id = task.get_id();
961
962 service.register(task).unwrap();
963
964 // 使用 cancel_task(会等待结果,但在后台协程中处理)
965 let cancelled = service.cancel_task(task_id);
966 assert!(cancelled, "Task should be cancelled successfully");
967
968 // 等待足够长时间确保回调不会被执行
969 tokio::time::sleep(Duration::from_millis(100)).await;
970 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
971
972 // 验证任务已从 active_tasks 中移除
973 let cancelled_again = service.cancel_task(task_id);
974 assert!(!cancelled_again, "Task should have been removed from active_tasks");
975 }
976
977 #[tokio::test]
978 async fn test_schedule_once_direct() {
979 let timer = TimerWheel::with_defaults();
980 let mut service = timer.create_service(ServiceConfig::default());
981 let counter = Arc::new(AtomicU32::new(0));
982
983 // 直接通过 service 调度定时器
984 let counter_clone = Arc::clone(&counter);
985 let task = TimerService::create_task(
986 Duration::from_millis(50),
987 Some(CallbackWrapper::new(move || {
988 let counter = Arc::clone(&counter_clone);
989 async move {
990 counter.fetch_add(1, Ordering::SeqCst);
991 }
992 })),
993 );
994 let task_id = task.get_id();
995 service.register(task).unwrap();
996
997 // 等待定时器触发
998 let mut rx = service.take_receiver().unwrap();
999 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1000 .await
1001 .expect("Should receive timeout notification")
1002 .expect("Should receive Some value");
1003
1004 assert_eq!(received_task_id, task_id);
1005
1006 // 等待回调执行
1007 tokio::time::sleep(Duration::from_millis(50)).await;
1008 assert_eq!(counter.load(Ordering::SeqCst), 1);
1009 }
1010
1011 #[tokio::test]
1012 async fn test_schedule_once_batch_direct() {
1013 let timer = TimerWheel::with_defaults();
1014 let mut service = timer.create_service(ServiceConfig::default());
1015 let counter = Arc::new(AtomicU32::new(0));
1016
1017 // 直接通过 service 批量调度定时器
1018 let callbacks: Vec<_> = (0..3)
1019 .map(|_| {
1020 let counter = Arc::clone(&counter);
1021 (Duration::from_millis(50), Some(CallbackWrapper::new(move || {
1022 let counter = Arc::clone(&counter);
1023 async move {
1024 counter.fetch_add(1, Ordering::SeqCst);
1025 }
1026 })))
1027 })
1028 .collect();
1029
1030 let tasks = TimerService::create_batch_with_callbacks(callbacks);
1031 assert_eq!(tasks.len(), 3);
1032 service.register_batch(tasks).unwrap();
1033
1034 // 接收所有超时通知
1035 let mut received_count = 0;
1036 let mut rx = service.take_receiver().unwrap();
1037
1038 while received_count < 3 {
1039 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1040 Ok(Some(_task_id)) => {
1041 received_count += 1;
1042 }
1043 Ok(None) => break,
1044 Err(_) => break,
1045 }
1046 }
1047
1048 assert_eq!(received_count, 3);
1049
1050 // 等待回调执行
1051 tokio::time::sleep(Duration::from_millis(50)).await;
1052 assert_eq!(counter.load(Ordering::SeqCst), 3);
1053 }
1054
1055 #[tokio::test]
1056 async fn test_schedule_once_notify_direct() {
1057 let timer = TimerWheel::with_defaults();
1058 let mut service = timer.create_service(ServiceConfig::default());
1059
1060 // 直接通过 service 调度仅通知的定时器(无回调)
1061 let task = TimerService::create_task(Duration::from_millis(50), None);
1062 let task_id = task.get_id();
1063 service.register(task).unwrap();
1064
1065 // 接收超时通知
1066 let mut rx = service.take_receiver().unwrap();
1067 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1068 .await
1069 .expect("Should receive timeout notification")
1070 .expect("Should receive Some value");
1071
1072 assert_eq!(received_task_id, task_id);
1073 }
1074
1075 #[tokio::test]
1076 async fn test_schedule_and_cancel_direct() {
1077 let timer = TimerWheel::with_defaults();
1078 let service = timer.create_service(ServiceConfig::default());
1079 let counter = Arc::new(AtomicU32::new(0));
1080
1081 // 直接调度定时器
1082 let counter_clone = Arc::clone(&counter);
1083 let task = TimerService::create_task(
1084 Duration::from_secs(10),
1085 Some(CallbackWrapper::new(move || {
1086 let counter = Arc::clone(&counter_clone);
1087 async move {
1088 counter.fetch_add(1, Ordering::SeqCst);
1089 }
1090 })),
1091 );
1092 let task_id = task.get_id();
1093 service.register(task).unwrap();
1094
1095 // 立即取消
1096 let cancelled = service.cancel_task(task_id);
1097 assert!(cancelled, "Task should be cancelled successfully");
1098
1099 // 等待确保回调不会执行
1100 tokio::time::sleep(Duration::from_millis(100)).await;
1101 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callback should not have been executed");
1102 }
1103
1104 #[tokio::test]
1105 async fn test_cancel_batch_direct() {
1106 let timer = TimerWheel::with_defaults();
1107 let service = timer.create_service(ServiceConfig::default());
1108 let counter = Arc::new(AtomicU32::new(0));
1109
1110 // 批量调度定时器
1111 let callbacks: Vec<_> = (0..10)
1112 .map(|_| {
1113 let counter = Arc::clone(&counter);
1114 (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1115 let counter = Arc::clone(&counter);
1116 async move {
1117 counter.fetch_add(1, Ordering::SeqCst);
1118 }
1119 })))
1120 })
1121 .collect();
1122
1123 let tasks = TimerService::create_batch_with_callbacks(callbacks);
1124 let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1125 assert_eq!(task_ids.len(), 10);
1126 service.register_batch(tasks).unwrap();
1127
1128 // 批量取消所有任务
1129 let cancelled = service.cancel_batch(&task_ids);
1130 assert_eq!(cancelled, 10, "All 10 tasks should be cancelled");
1131
1132 // 等待确保回调不会执行
1133 tokio::time::sleep(Duration::from_millis(100)).await;
1134 assert_eq!(counter.load(Ordering::SeqCst), 0, "No callbacks should have been executed");
1135 }
1136
1137 #[tokio::test]
1138 async fn test_cancel_batch_partial() {
1139 let timer = TimerWheel::with_defaults();
1140 let service = timer.create_service(ServiceConfig::default());
1141 let counter = Arc::new(AtomicU32::new(0));
1142
1143 // 批量调度定时器
1144 let callbacks: Vec<_> = (0..10)
1145 .map(|_| {
1146 let counter = Arc::clone(&counter);
1147 (Duration::from_secs(10), Some(CallbackWrapper::new(move || {
1148 let counter = Arc::clone(&counter);
1149 async move {
1150 counter.fetch_add(1, Ordering::SeqCst);
1151 }
1152 })))
1153 })
1154 .collect();
1155
1156 let tasks = TimerService::create_batch_with_callbacks(callbacks);
1157 let task_ids: Vec<_> = tasks.iter().map(|t| t.get_id()).collect();
1158 service.register_batch(tasks).unwrap();
1159
1160 // 只取消前5个任务
1161 let to_cancel: Vec<_> = task_ids.iter().take(5).copied().collect();
1162 let cancelled = service.cancel_batch(&to_cancel);
1163 assert_eq!(cancelled, 5, "5 tasks should be cancelled");
1164
1165 // 等待确保前5个回调不会执行
1166 tokio::time::sleep(Duration::from_millis(100)).await;
1167 assert_eq!(counter.load(Ordering::SeqCst), 0, "Cancelled tasks should not execute");
1168 }
1169
1170 #[tokio::test]
1171 async fn test_cancel_batch_empty() {
1172 let timer = TimerWheel::with_defaults();
1173 let service = timer.create_service(ServiceConfig::default());
1174
1175 // 取消空列表
1176 let empty: Vec<TaskId> = vec![];
1177 let cancelled = service.cancel_batch(&empty);
1178 assert_eq!(cancelled, 0, "No tasks should be cancelled");
1179 }
1180
1181 #[tokio::test]
1182 async fn test_postpone() {
1183 let timer = TimerWheel::with_defaults();
1184 let mut service = timer.create_service(ServiceConfig::default());
1185 let counter = Arc::new(AtomicU32::new(0));
1186
1187 // 注册一个任务,原始回调增加 1
1188 let counter_clone1 = Arc::clone(&counter);
1189 let task = TimerService::create_task(
1190 Duration::from_millis(50),
1191 Some(CallbackWrapper::new(move || {
1192 let counter = Arc::clone(&counter_clone1);
1193 async move {
1194 counter.fetch_add(1, Ordering::SeqCst);
1195 }
1196 })),
1197 );
1198 let task_id = task.get_id();
1199 service.register(task).unwrap();
1200
1201 // 推迟任务并替换回调,新回调增加 10
1202 let counter_clone2 = Arc::clone(&counter);
1203 let postponed = service.postpone(
1204 task_id,
1205 Duration::from_millis(100),
1206 Some(CallbackWrapper::new(move || {
1207 let counter = Arc::clone(&counter_clone2);
1208 async move {
1209 counter.fetch_add(10, Ordering::SeqCst);
1210 }
1211 }))
1212 );
1213 assert!(postponed, "Task should be postponed successfully");
1214
1215 // 接收超时通知(推迟后需要等待100ms,加上余量)
1216 let mut rx = service.take_receiver().unwrap();
1217 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1218 .await
1219 .expect("Should receive timeout notification")
1220 .expect("Should receive Some value");
1221
1222 assert_eq!(received_task_id, task_id);
1223
1224 // 等待回调执行
1225 tokio::time::sleep(Duration::from_millis(20)).await;
1226
1227 // 验证新回调被执行(增加了 10 而不是 1)
1228 assert_eq!(counter.load(Ordering::SeqCst), 10);
1229 }
1230
1231 #[tokio::test]
1232 async fn test_postpone_nonexistent_task() {
1233 let timer = TimerWheel::with_defaults();
1234 let service = timer.create_service(ServiceConfig::default());
1235
1236 // 尝试推迟一个不存在的任务
1237 let fake_task = TimerService::create_task(Duration::from_millis(50), None);
1238 let fake_task_id = fake_task.get_id();
1239 // 不注册这个任务
1240
1241 let postponed = service.postpone(fake_task_id, Duration::from_millis(100), None);
1242 assert!(!postponed, "Nonexistent task should not be postponed");
1243 }
1244
1245 #[tokio::test]
1246 async fn test_postpone_batch() {
1247 let timer = TimerWheel::with_defaults();
1248 let mut service = timer.create_service(ServiceConfig::default());
1249 let counter = Arc::new(AtomicU32::new(0));
1250
1251 // 注册 3 个任务
1252 let mut task_ids = Vec::new();
1253 for _ in 0..3 {
1254 let counter_clone = Arc::clone(&counter);
1255 let task = TimerService::create_task(
1256 Duration::from_millis(50),
1257 Some(CallbackWrapper::new(move || {
1258 let counter = Arc::clone(&counter_clone);
1259 async move {
1260 counter.fetch_add(1, Ordering::SeqCst);
1261 }
1262 })),
1263 );
1264 task_ids.push((task.get_id(), Duration::from_millis(150), None));
1265 service.register(task).unwrap();
1266 }
1267
1268 // 批量推迟
1269 let postponed = service.postpone_batch_with_callbacks(task_ids);
1270 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1271
1272 // 等待原定时间 50ms,任务不应该触发
1273 tokio::time::sleep(Duration::from_millis(70)).await;
1274 assert_eq!(counter.load(Ordering::SeqCst), 0);
1275
1276 // 接收所有超时通知
1277 let mut received_count = 0;
1278 let mut rx = service.take_receiver().unwrap();
1279
1280 while received_count < 3 {
1281 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1282 Ok(Some(_task_id)) => {
1283 received_count += 1;
1284 }
1285 Ok(None) => break,
1286 Err(_) => break,
1287 }
1288 }
1289
1290 assert_eq!(received_count, 3);
1291
1292 // 等待回调执行
1293 tokio::time::sleep(Duration::from_millis(20)).await;
1294 assert_eq!(counter.load(Ordering::SeqCst), 3);
1295 }
1296
1297 #[tokio::test]
1298 async fn test_postpone_batch_with_callbacks() {
1299 let timer = TimerWheel::with_defaults();
1300 let mut service = timer.create_service(ServiceConfig::default());
1301 let counter = Arc::new(AtomicU32::new(0));
1302
1303 // 注册 3 个任务
1304 let mut task_ids = Vec::new();
1305 for _ in 0..3 {
1306 let task = TimerService::create_task(
1307 Duration::from_millis(50),
1308 None,
1309 );
1310 task_ids.push(task.get_id());
1311 service.register(task).unwrap();
1312 }
1313
1314 // 批量推迟并替换回调
1315 let updates: Vec<_> = task_ids
1316 .into_iter()
1317 .map(|id| {
1318 let counter_clone = Arc::clone(&counter);
1319 (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
1320 let counter = Arc::clone(&counter_clone);
1321 async move {
1322 counter.fetch_add(1, Ordering::SeqCst);
1323 }
1324 })))
1325 })
1326 .collect();
1327
1328 let postponed = service.postpone_batch_with_callbacks(updates);
1329 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1330
1331 // 等待原定时间 50ms,任务不应该触发
1332 tokio::time::sleep(Duration::from_millis(70)).await;
1333 assert_eq!(counter.load(Ordering::SeqCst), 0);
1334
1335 // 接收所有超时通知
1336 let mut received_count = 0;
1337 let mut rx = service.take_receiver().unwrap();
1338
1339 while received_count < 3 {
1340 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1341 Ok(Some(_task_id)) => {
1342 received_count += 1;
1343 }
1344 Ok(None) => break,
1345 Err(_) => break,
1346 }
1347 }
1348
1349 assert_eq!(received_count, 3);
1350
1351 // 等待回调执行
1352 tokio::time::sleep(Duration::from_millis(20)).await;
1353 assert_eq!(counter.load(Ordering::SeqCst), 3);
1354 }
1355
1356 #[tokio::test]
1357 async fn test_postpone_batch_empty() {
1358 let timer = TimerWheel::with_defaults();
1359 let service = timer.create_service(ServiceConfig::default());
1360
1361 // 推迟空列表
1362 let empty: Vec<(TaskId, Duration, Option<CallbackWrapper>)> = vec![];
1363 let postponed = service.postpone_batch_with_callbacks(empty);
1364 assert_eq!(postponed, 0, "No tasks should be postponed");
1365 }
1366
1367 #[tokio::test]
1368 async fn test_postpone_keeps_timeout_notification_valid() {
1369 let timer = TimerWheel::with_defaults();
1370 let mut service = timer.create_service(ServiceConfig::default());
1371 let counter = Arc::new(AtomicU32::new(0));
1372
1373 // 注册一个任务
1374 let counter_clone = Arc::clone(&counter);
1375 let task = TimerService::create_task(
1376 Duration::from_millis(50),
1377 Some(CallbackWrapper::new(move || {
1378 let counter = Arc::clone(&counter_clone);
1379 async move {
1380 counter.fetch_add(1, Ordering::SeqCst);
1381 }
1382 })),
1383 );
1384 let task_id = task.get_id();
1385 service.register(task).unwrap();
1386
1387 // 推迟任务
1388 service.postpone(task_id, Duration::from_millis(100), None);
1389
1390 // 验证超时通知仍然有效(推迟后需要等待100ms,加上余量)
1391 let mut rx = service.take_receiver().unwrap();
1392 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1393 .await
1394 .expect("Should receive timeout notification")
1395 .expect("Should receive Some value");
1396
1397 assert_eq!(received_task_id, task_id, "Timeout notification should still work after postpone");
1398
1399 // 等待回调执行
1400 tokio::time::sleep(Duration::from_millis(20)).await;
1401 assert_eq!(counter.load(Ordering::SeqCst), 1);
1402 }
1403
1404 #[tokio::test]
1405 async fn test_cancelled_task_not_forwarded_to_timeout_rx() {
1406 let timer = TimerWheel::with_defaults();
1407 let mut service = timer.create_service(ServiceConfig::default());
1408
1409 // 注册两个任务:一个会被取消,一个会正常到期
1410 let task1 = TimerService::create_task(Duration::from_secs(10), None);
1411 let task1_id = task1.get_id();
1412 service.register(task1).unwrap();
1413
1414 let task2 = TimerService::create_task(Duration::from_millis(50), None);
1415 let task2_id = task2.get_id();
1416 service.register(task2).unwrap();
1417
1418 // 取消第一个任务
1419 let cancelled = service.cancel_task(task1_id);
1420 assert!(cancelled, "Task should be cancelled");
1421
1422 // 等待第二个任务到期
1423 let mut rx = service.take_receiver().unwrap();
1424 let received_task_id = tokio::time::timeout(Duration::from_millis(200), rx.recv())
1425 .await
1426 .expect("Should receive timeout notification")
1427 .expect("Should receive Some value");
1428
1429 // 应该只收到第二个任务(到期的)的通知,不应该收到第一个任务(取消的)的通知
1430 assert_eq!(received_task_id, task2_id, "Should only receive expired task notification");
1431
1432 // 验证没有其他通知(特别是被取消的任务不应该有通知)
1433 let no_more = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
1434 assert!(no_more.is_err(), "Should not receive any more notifications");
1435 }
1436
1437 #[tokio::test]
1438 async fn test_take_receiver_twice() {
1439 let timer = TimerWheel::with_defaults();
1440 let mut service = timer.create_service(ServiceConfig::default());
1441
1442 // 第一次调用应该返回 Some
1443 let rx1 = service.take_receiver();
1444 assert!(rx1.is_some(), "First take_receiver should return Some");
1445
1446 // 第二次调用应该返回 None
1447 let rx2 = service.take_receiver();
1448 assert!(rx2.is_none(), "Second take_receiver should return None");
1449 }
1450
1451 #[tokio::test]
1452 async fn test_postpone_batch_without_callbacks() {
1453 let timer = TimerWheel::with_defaults();
1454 let mut service = timer.create_service(ServiceConfig::default());
1455 let counter = Arc::new(AtomicU32::new(0));
1456
1457 // 注册 3 个任务,有原始回调
1458 let mut task_ids = Vec::new();
1459 for _ in 0..3 {
1460 let counter_clone = Arc::clone(&counter);
1461 let task = TimerService::create_task(
1462 Duration::from_millis(50),
1463 Some(CallbackWrapper::new(move || {
1464 let counter = Arc::clone(&counter_clone);
1465 async move {
1466 counter.fetch_add(1, Ordering::SeqCst);
1467 }
1468 })),
1469 );
1470 task_ids.push(task.get_id());
1471 service.register(task).unwrap();
1472 }
1473
1474 // 批量推迟,不替换回调
1475 let updates: Vec<_> = task_ids
1476 .iter()
1477 .map(|&id| (id, Duration::from_millis(150)))
1478 .collect();
1479 let postponed = service.postpone_batch(updates);
1480 assert_eq!(postponed, 3, "All 3 tasks should be postponed");
1481
1482 // 等待原定时间 50ms,任务不应该触发
1483 tokio::time::sleep(Duration::from_millis(70)).await;
1484 assert_eq!(counter.load(Ordering::SeqCst), 0, "Callbacks should not fire yet");
1485
1486 // 接收所有超时通知
1487 let mut received_count = 0;
1488 let mut rx = service.take_receiver().unwrap();
1489
1490 while received_count < 3 {
1491 match tokio::time::timeout(Duration::from_millis(200), rx.recv()).await {
1492 Ok(Some(_task_id)) => {
1493 received_count += 1;
1494 }
1495 Ok(None) => break,
1496 Err(_) => break,
1497 }
1498 }
1499
1500 assert_eq!(received_count, 3, "Should receive 3 timeout notifications");
1501
1502 // 等待回调执行
1503 tokio::time::sleep(Duration::from_millis(20)).await;
1504 assert_eq!(counter.load(Ordering::SeqCst), 3, "All callbacks should execute");
1505 }
1506}
1507