kestrel_protocol_timer/timer.rs
1use crate::config::{ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TimerCallback};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
10/// 完成通知接收器,用于接收定时器完成通知
11pub struct CompletionReceiver(pub oneshot::Receiver<()>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14///
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18 pub(crate) task_id: TaskId,
19 pub(crate) wheel: Arc<Mutex<Wheel>>,
20 pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24 pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<()>) -> Self {
25 Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26 }
27
28 /// 取消定时器
29 ///
30 /// # 返回
31 /// 如果任务存在且成功取消返回 true,否则返回 false
32 ///
33 /// # 示例
34 /// ```no_run
35 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
36 /// # use std::time::Duration;
37 /// # #[tokio::main]
38 /// # async fn main() {
39 /// let timer = TimerWheel::with_defaults();
40 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {});
41 /// let handle = timer.register(task);
42 ///
43 /// // 取消定时器
44 /// let success = handle.cancel();
45 /// println!("取消成功: {}", success);
46 /// # }
47 /// ```
48 pub fn cancel(&self) -> bool {
49 let mut wheel = self.wheel.lock();
50 wheel.cancel(self.task_id)
51 }
52
53 /// 获取完成通知接收器的可变引用
54 ///
55 /// # 示例
56 /// ```no_run
57 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
58 /// # use std::time::Duration;
59 /// # #[tokio::main]
60 /// # async fn main() {
61 /// let timer = TimerWheel::with_defaults();
62 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
63 /// println!("Timer fired!");
64 /// });
65 /// let handle = timer.register(task);
66 ///
67 /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
68 /// handle.into_completion_receiver().0.await.ok();
69 /// println!("Timer completed!");
70 /// # }
71 /// ```
72 pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
73 &mut self.completion_rx
74 }
75
76 /// 消耗句柄,返回完成通知接收器
77 ///
78 /// # 示例
79 /// ```no_run
80 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
81 /// # use std::time::Duration;
82 /// # #[tokio::main]
83 /// # async fn main() {
84 /// let timer = TimerWheel::with_defaults();
85 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
86 /// println!("Timer fired!");
87 /// });
88 /// let handle = timer.register(task);
89 ///
90 /// // 等待定时器完成
91 /// handle.into_completion_receiver().0.await.ok();
92 /// println!("Timer completed!");
93 /// # }
94 /// ```
95 pub fn into_completion_receiver(self) -> CompletionReceiver {
96 self.completion_rx
97 }
98}
99
100/// 批量定时器句柄,用于管理批量调度的定时器
101///
102/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
103///
104/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
105/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
106pub struct BatchHandle {
107 pub(crate) task_ids: Vec<TaskId>,
108 pub(crate) wheel: Arc<Mutex<Wheel>>,
109 pub(crate) completion_rxs: Vec<oneshot::Receiver<()>>,
110}
111
112impl BatchHandle {
113 pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<()>>) -> Self {
114 Self { task_ids, wheel, completion_rxs }
115 }
116
117 /// 批量取消所有定时器
118 ///
119 /// # 返回
120 /// 成功取消的任务数量
121 ///
122 /// # 示例
123 /// ```no_run
124 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
125 /// # use std::time::Duration;
126 /// # #[tokio::main]
127 /// # async fn main() {
128 /// let timer = TimerWheel::with_defaults();
129 /// let tasks = TimerWheel::create_batch(
130 /// (0..10).map(|_| (Duration::from_secs(1), || async {})).collect()
131 /// );
132 /// let batch = timer.register_batch(tasks);
133 ///
134 /// let cancelled = batch.cancel_all();
135 /// println!("取消了 {} 个定时器", cancelled);
136 /// # }
137 /// ```
138 pub fn cancel_all(self) -> usize {
139 let mut wheel = self.wheel.lock();
140 wheel.cancel_batch(&self.task_ids)
141 }
142
143 /// 将批量句柄转换为单个定时器句柄的 Vec
144 ///
145 /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
146 ///
147 /// # 示例
148 /// ```no_run
149 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
150 /// # use std::time::Duration;
151 /// # #[tokio::main]
152 /// # async fn main() {
153 /// let timer = TimerWheel::with_defaults();
154 /// let tasks = TimerWheel::create_batch(
155 /// (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
156 /// );
157 /// let batch = timer.register_batch(tasks);
158 ///
159 /// // 转换为独立的句柄
160 /// let handles = batch.into_handles();
161 /// for handle in handles {
162 /// // 可以单独操作每个句柄
163 /// }
164 /// # }
165 /// ```
166 pub fn into_handles(self) -> Vec<TimerHandle> {
167 self.task_ids
168 .into_iter()
169 .zip(self.completion_rxs.into_iter())
170 .map(|(task_id, rx)| {
171 TimerHandle::new(task_id, self.wheel.clone(), rx)
172 })
173 .collect()
174 }
175
176 /// 获取批量任务的数量
177 pub fn len(&self) -> usize {
178 self.task_ids.len()
179 }
180
181 /// 检查批量任务是否为空
182 pub fn is_empty(&self) -> bool {
183 self.task_ids.is_empty()
184 }
185
186 /// 获取所有任务 ID 的引用
187 pub fn task_ids(&self) -> &[TaskId] {
188 &self.task_ids
189 }
190
191 /// 获取所有完成通知接收器的引用
192 ///
193 /// # 返回
194 /// 所有任务的完成通知接收器列表引用
195 pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<()>> {
196 &mut self.completion_rxs
197 }
198
199 /// 消耗句柄,返回所有完成通知接收器
200 ///
201 /// # 返回
202 /// 所有任务的完成通知接收器列表
203 ///
204 /// # 示例
205 /// ```no_run
206 /// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
207 /// # use std::time::Duration;
208 /// # #[tokio::main]
209 /// # async fn main() {
210 /// let timer = TimerWheel::with_defaults();
211 /// let tasks = TimerWheel::create_batch(
212 /// (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
213 /// );
214 /// let batch = timer.register_batch(tasks);
215 ///
216 /// // 获取所有完成通知接收器
217 /// let receivers = batch.into_completion_receivers();
218 /// for rx in receivers {
219 /// tokio::spawn(async move {
220 /// if rx.await.is_ok() {
221 /// println!("A timer completed!");
222 /// }
223 /// });
224 /// }
225 /// # }
226 /// ```
227 pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<()>> {
228 self.completion_rxs
229 }
230}
231
232/// 实现 IntoIterator,允许直接迭代 BatchHandle
233///
234/// # 示例
235/// ```no_run
236/// # use kestrel_protocol_timer::{TimerWheel, TimerTask};
237/// # use std::time::Duration;
238/// # #[tokio::main]
239/// # async fn main() {
240/// let timer = TimerWheel::with_defaults();
241/// let tasks = TimerWheel::create_batch(
242/// (0..3).map(|_| (Duration::from_secs(1), || async {})).collect()
243/// );
244/// let batch = timer.register_batch(tasks);
245///
246/// // 直接迭代,每个元素都是独立的 TimerHandle
247/// for handle in batch {
248/// // 可以单独操作每个句柄
249/// }
250/// # }
251/// ```
252impl IntoIterator for BatchHandle {
253 type Item = TimerHandle;
254 type IntoIter = BatchHandleIter;
255
256 fn into_iter(self) -> Self::IntoIter {
257 BatchHandleIter {
258 task_ids: self.task_ids.into_iter(),
259 completion_rxs: self.completion_rxs.into_iter(),
260 wheel: self.wheel,
261 }
262 }
263}
264
265/// BatchHandle 的迭代器
266pub struct BatchHandleIter {
267 task_ids: std::vec::IntoIter<TaskId>,
268 completion_rxs: std::vec::IntoIter<oneshot::Receiver<()>>,
269 wheel: Arc<Mutex<Wheel>>,
270}
271
272impl Iterator for BatchHandleIter {
273 type Item = TimerHandle;
274
275 fn next(&mut self) -> Option<Self::Item> {
276 match (self.task_ids.next(), self.completion_rxs.next()) {
277 (Some(task_id), Some(rx)) => {
278 Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
279 }
280 _ => None,
281 }
282 }
283
284 fn size_hint(&self) -> (usize, Option<usize>) {
285 self.task_ids.size_hint()
286 }
287}
288
289impl ExactSizeIterator for BatchHandleIter {
290 fn len(&self) -> usize {
291 self.task_ids.len()
292 }
293}
294
295/// 时间轮定时器管理器
296pub struct TimerWheel {
297 /// 时间轮唯一标识符
298
299 /// 时间轮实例(使用 Arc<Mutex> 包装以支持多线程访问)
300 wheel: Arc<Mutex<Wheel>>,
301
302 /// 后台 tick 循环任务句柄
303 tick_handle: Option<JoinHandle<()>>,
304}
305
306impl TimerWheel {
307 /// 创建新的定时器管理器
308 ///
309 /// # 参数
310 /// - `config`: 时间轮配置(已经过验证)
311 ///
312 /// # 示例
313 /// ```no_run
314 /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask};
315 /// use std::time::Duration;
316 ///
317 /// #[tokio::main]
318 /// async fn main() {
319 /// let config = WheelConfig::builder()
320 /// .tick_duration(Duration::from_millis(10))
321 /// .slot_count(512)
322 /// .build()
323 /// .unwrap();
324 /// let timer = TimerWheel::new(config);
325 ///
326 /// // 使用两步式 API
327 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {});
328 /// let handle = timer.register(task);
329 /// }
330 /// ```
331 pub fn new(config: WheelConfig) -> Self {
332 let tick_duration = config.tick_duration;
333 let wheel = Wheel::new(config);
334 let wheel = Arc::new(Mutex::new(wheel));
335 let wheel_clone = wheel.clone();
336
337 // 启动后台 tick 循环
338 let tick_handle = tokio::spawn(async move {
339 Self::tick_loop(wheel_clone, tick_duration).await;
340 });
341
342 Self {
343 wheel,
344 tick_handle: Some(tick_handle),
345 }
346 }
347
348 /// 创建带默认配置的定时器管理器
349 /// - tick 时长: 10ms
350 /// - 槽位数量: 512
351 ///
352 /// # 示例
353 /// ```no_run
354 /// use kestrel_protocol_timer::TimerWheel;
355 ///
356 /// #[tokio::main]
357 /// async fn main() {
358 /// let timer = TimerWheel::with_defaults();
359 /// }
360 /// ```
361 pub fn with_defaults() -> Self {
362 Self::new(WheelConfig::default())
363 }
364
365 /// 创建与此时间轮绑定的 TimerService(使用默认配置)
366 ///
367 /// # 返回
368 /// 绑定到此时间轮的 TimerService 实例
369 ///
370 /// # 示例
371 /// ```no_run
372 /// use kestrel_protocol_timer::{TimerWheel, TimerService};
373 /// use std::time::Duration;
374 ///
375 /// #[tokio::main]
376 /// async fn main() {
377 /// let timer = TimerWheel::with_defaults();
378 /// let mut service = timer.create_service();
379 ///
380 /// // 使用两步式 API 通过 service 批量调度定时器
381 /// let tasks = TimerService::create_batch(
382 /// (0..5).map(|_| (Duration::from_millis(100), || async {})).collect()
383 /// );
384 /// service.register_batch(tasks).await;
385 ///
386 /// // 接收超时通知
387 /// let mut rx = service.take_receiver().unwrap();
388 /// while let Some(task_id) = rx.recv().await {
389 /// println!("Task {:?} completed", task_id);
390 /// }
391 /// }
392 /// ```
393 pub fn create_service(&self) -> crate::service::TimerService {
394 crate::service::TimerService::new(self.wheel.clone(), ServiceConfig::default())
395 }
396
397 /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
398 ///
399 /// # 参数
400 /// - `config`: 服务配置
401 ///
402 /// # 返回
403 /// 绑定到此时间轮的 TimerService 实例
404 ///
405 /// # 示例
406 /// ```no_run
407 /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
408 ///
409 /// #[tokio::main]
410 /// async fn main() {
411 /// let timer = TimerWheel::with_defaults();
412 /// let config = ServiceConfig::builder()
413 /// .command_channel_capacity(1024)
414 /// .timeout_channel_capacity(2000)
415 /// .build()
416 /// .unwrap();
417 /// let service = timer.create_service_with_config(config);
418 /// }
419 /// ```
420 pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
421 crate::service::TimerService::new(self.wheel.clone(), config)
422 }
423
424 /// 创建定时器任务(申请阶段)
425 ///
426 /// # 参数
427 /// - `delay`: 延迟时间
428 /// - `callback`: 实现了 TimerCallback trait 的回调对象
429 ///
430 /// # 返回
431 /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
432 ///
433 /// # 示例
434 /// ```no_run
435 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
436 /// use std::time::Duration;
437 ///
438 /// #[tokio::main]
439 /// async fn main() {
440 /// let timer = TimerWheel::with_defaults();
441 ///
442 /// // 步骤 1: 创建任务
443 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
444 /// println!("Timer fired!");
445 /// });
446 ///
447 /// // 获取任务 ID
448 /// let task_id = task.get_id();
449 /// println!("Created task: {:?}", task_id);
450 ///
451 /// // 步骤 2: 注册任务
452 /// let handle = timer.register(task);
453 /// }
454 /// ```
455 pub fn create_task<C>(delay: Duration, callback: C) -> crate::task::TimerTask
456 where
457 C: TimerCallback,
458 {
459 use std::sync::Arc;
460 let callback_wrapper = Arc::new(callback) as CallbackWrapper;
461 crate::task::TimerTask::new(delay, Some(callback_wrapper))
462 }
463
464 /// 批量创建定时器任务(申请阶段)
465 ///
466 /// # 参数
467 /// - `callbacks`: (延迟时间, 回调) 的元组列表
468 ///
469 /// # 返回
470 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
471 ///
472 /// # 示例
473 /// ```no_run
474 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
475 /// use std::time::Duration;
476 /// use std::sync::Arc;
477 /// use std::sync::atomic::{AtomicU32, Ordering};
478 ///
479 /// #[tokio::main]
480 /// async fn main() {
481 /// let timer = TimerWheel::with_defaults();
482 /// let counter = Arc::new(AtomicU32::new(0));
483 ///
484 /// // 步骤 1: 批量创建任务
485 /// let callbacks: Vec<_> = (0..3)
486 /// .map(|i| {
487 /// let counter = Arc::clone(&counter);
488 /// let delay = Duration::from_millis(100 + i * 100);
489 /// let callback = move || {
490 /// let counter = Arc::clone(&counter);
491 /// async move {
492 /// counter.fetch_add(1, Ordering::SeqCst);
493 /// }
494 /// };
495 /// (delay, callback)
496 /// })
497 /// .collect();
498 ///
499 /// let tasks = TimerWheel::create_batch(callbacks);
500 /// println!("Created {} tasks", tasks.len());
501 ///
502 /// // 步骤 2: 批量注册任务
503 /// let batch = timer.register_batch(tasks);
504 /// }
505 /// ```
506 pub fn create_batch<C>(callbacks: Vec<(Duration, C)>) -> Vec<crate::task::TimerTask>
507 where
508 C: TimerCallback,
509 {
510 use std::sync::Arc;
511 callbacks
512 .into_iter()
513 .map(|(delay, callback)| {
514 let callback_wrapper = Arc::new(callback) as CallbackWrapper;
515 crate::task::TimerTask::new(delay, Some(callback_wrapper))
516 })
517 .collect()
518 }
519
520 /// 注册定时器任务到时间轮(注册阶段)
521 ///
522 /// # 参数
523 /// - `task`: 通过 `create_task()` 创建的任务
524 ///
525 /// # 返回
526 /// 返回定时器句柄,可用于取消定时器和接收完成通知
527 ///
528 /// # 示例
529 /// ```no_run
530 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
531 /// use std::time::Duration;
532 ///
533 /// #[tokio::main]
534 /// async fn main() {
535 /// let timer = TimerWheel::with_defaults();
536 ///
537 /// let task = TimerWheel::create_task(Duration::from_secs(1), || async {
538 /// println!("Timer fired!");
539 /// });
540 /// let task_id = task.get_id();
541 ///
542 /// let handle = timer.register(task);
543 ///
544 /// // 等待定时器完成
545 /// handle.into_completion_receiver().0.await.ok();
546 /// }
547 /// ```
548 #[inline]
549 pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
550 let (completion_tx, completion_rx) = oneshot::channel();
551 let notifier = crate::task::CompletionNotifier(completion_tx);
552
553 let delay = task.delay;
554 let task_id = task.id;
555
556 // 单次加锁完成所有操作
557 {
558 let mut wheel_guard = self.wheel.lock();
559 wheel_guard.insert(delay, task, notifier);
560 }
561
562 TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
563 }
564
565 /// 批量注册定时器任务到时间轮(注册阶段)
566 ///
567 /// # 参数
568 /// - `tasks`: 通过 `create_batch()` 创建的任务列表
569 ///
570 /// # 返回
571 /// 返回批量定时器句柄
572 ///
573 /// # 示例
574 /// ```no_run
575 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
576 /// use std::time::Duration;
577 ///
578 /// #[tokio::main]
579 /// async fn main() {
580 /// let timer = TimerWheel::with_defaults();
581 ///
582 /// let callbacks: Vec<_> = (0..3)
583 /// .map(|_| (Duration::from_secs(1), || async {}))
584 /// .collect();
585 /// let tasks = TimerWheel::create_batch(callbacks);
586 ///
587 /// let batch = timer.register_batch(tasks);
588 /// println!("Registered {} timers", batch.len());
589 /// }
590 /// ```
591 #[inline]
592 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
593 let task_count = tasks.len();
594 let mut completion_rxs = Vec::with_capacity(task_count);
595 let mut task_ids = Vec::with_capacity(task_count);
596 let mut prepared_tasks = Vec::with_capacity(task_count);
597
598 // 步骤1: 准备所有 channels 和 notifiers(无锁)
599 // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
600 for task in tasks {
601 let (completion_tx, completion_rx) = oneshot::channel();
602 let notifier = crate::task::CompletionNotifier(completion_tx);
603
604 task_ids.push(task.id);
605 completion_rxs.push(completion_rx);
606 prepared_tasks.push((task.delay, task, notifier));
607 }
608
609 // 步骤2: 单次加锁,批量插入
610 {
611 let mut wheel_guard = self.wheel.lock();
612 wheel_guard.insert_batch(prepared_tasks);
613 }
614
615 BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
616 }
617
618 /// 取消定时器
619 ///
620 /// # 参数
621 /// - `task_id`: 任务 ID
622 ///
623 /// # 返回
624 /// 如果任务存在且成功取消返回 true,否则返回 false
625 ///
626 /// # 示例
627 /// ```no_run
628 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
629 /// use std::time::Duration;
630 ///
631 /// #[tokio::main]
632 /// async fn main() {
633 /// let timer = TimerWheel::with_defaults();
634 ///
635 /// let task = TimerWheel::create_task(Duration::from_secs(10), || async {});
636 /// let task_id = task.get_id();
637 /// let _handle = timer.register(task);
638 ///
639 /// // 使用任务 ID 取消
640 /// let cancelled = timer.cancel(task_id);
641 /// println!("取消成功: {}", cancelled);
642 /// }
643 /// ```
644 #[inline]
645 pub fn cancel(&self, task_id: TaskId) -> bool {
646 let mut wheel = self.wheel.lock();
647 wheel.cancel(task_id)
648 }
649
650 /// 批量取消定时器
651 ///
652 /// # 参数
653 /// - `task_ids`: 要取消的任务 ID 列表
654 ///
655 /// # 返回
656 /// 成功取消的任务数量
657 ///
658 /// # 性能优势
659 /// - 批量处理减少锁竞争
660 /// - 内部优化批量取消操作
661 ///
662 /// # 示例
663 /// ```no_run
664 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
665 /// use std::time::Duration;
666 ///
667 /// #[tokio::main]
668 /// async fn main() {
669 /// let timer = TimerWheel::with_defaults();
670 ///
671 /// // 创建多个定时器
672 /// let task1 = TimerWheel::create_task(Duration::from_secs(10), || async {});
673 /// let task2 = TimerWheel::create_task(Duration::from_secs(10), || async {});
674 /// let task3 = TimerWheel::create_task(Duration::from_secs(10), || async {});
675 ///
676 /// let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
677 ///
678 /// let _h1 = timer.register(task1);
679 /// let _h2 = timer.register(task2);
680 /// let _h3 = timer.register(task3);
681 ///
682 /// // 批量取消
683 /// let cancelled = timer.cancel_batch(&task_ids);
684 /// println!("已取消 {} 个定时器", cancelled);
685 /// }
686 /// ```
687 #[inline]
688 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
689 let mut wheel = self.wheel.lock();
690 wheel.cancel_batch(task_ids)
691 }
692
693 /// 推迟定时器(保持原回调)
694 ///
695 /// # 参数
696 /// - `task_id`: 要推迟的任务 ID
697 /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
698 ///
699 /// # 返回
700 /// 如果任务存在且成功推迟返回 true,否则返回 false
701 ///
702 /// # 注意
703 /// - 推迟后任务 ID 保持不变
704 /// - 原有的 completion_receiver 仍然有效
705 /// - 保持原回调函数不变
706 ///
707 /// # 示例
708 /// ```no_run
709 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
710 /// use std::time::Duration;
711 ///
712 /// #[tokio::main]
713 /// async fn main() {
714 /// let timer = TimerWheel::with_defaults();
715 ///
716 /// let task = TimerWheel::create_task(Duration::from_secs(5), || async {
717 /// println!("Timer fired!");
718 /// });
719 /// let task_id = task.get_id();
720 /// let _handle = timer.register(task);
721 ///
722 /// // 推迟到 10 秒后触发
723 /// let success = timer.postpone(task_id, Duration::from_secs(10));
724 /// println!("推迟成功: {}", success);
725 /// }
726 /// ```
727 #[inline]
728 pub fn postpone(&self, task_id: TaskId, new_delay: Duration) -> bool {
729 let mut wheel = self.wheel.lock();
730 wheel.postpone(task_id, new_delay, None)
731 }
732
733 /// 推迟定时器(替换回调)
734 ///
735 /// # 参数
736 /// - `task_id`: 要推迟的任务 ID
737 /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
738 /// - `callback`: 新的回调函数
739 ///
740 /// # 返回
741 /// 如果任务存在且成功推迟返回 true,否则返回 false
742 ///
743 /// # 注意
744 /// - 推迟后任务 ID 保持不变
745 /// - 原有的 completion_receiver 仍然有效
746 /// - 回调函数会被替换为新的回调
747 ///
748 /// # 示例
749 /// ```no_run
750 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
751 /// use std::time::Duration;
752 ///
753 /// #[tokio::main]
754 /// async fn main() {
755 /// let timer = TimerWheel::with_defaults();
756 ///
757 /// let task = TimerWheel::create_task(Duration::from_secs(5), || async {
758 /// println!("Original callback");
759 /// });
760 /// let task_id = task.get_id();
761 /// let _handle = timer.register(task);
762 ///
763 /// // 推迟并替换回调
764 /// let success = timer.postpone_with_callback(
765 /// task_id,
766 /// Duration::from_secs(10),
767 /// || async {
768 /// println!("New callback!");
769 /// }
770 /// );
771 /// println!("推迟成功: {}", success);
772 /// }
773 /// ```
774 #[inline]
775 pub fn postpone_with_callback<C>(
776 &self,
777 task_id: TaskId,
778 new_delay: Duration,
779 callback: C,
780 ) -> bool
781 where
782 C: TimerCallback,
783 {
784 use std::sync::Arc;
785 let callback_wrapper = Arc::new(callback) as CallbackWrapper;
786 let mut wheel = self.wheel.lock();
787 wheel.postpone(task_id, new_delay, Some(callback_wrapper))
788 }
789
790 /// 批量推迟定时器(保持原回调)
791 ///
792 /// # 参数
793 /// - `updates`: (任务ID, 新延迟) 的元组列表
794 ///
795 /// # 返回
796 /// 成功推迟的任务数量
797 ///
798 /// # 性能优势
799 /// - 批量处理减少锁竞争
800 /// - 内部优化批量推迟操作
801 ///
802 /// # 示例
803 /// ```no_run
804 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
805 /// use std::time::Duration;
806 ///
807 /// #[tokio::main]
808 /// async fn main() {
809 /// let timer = TimerWheel::with_defaults();
810 ///
811 /// // 创建多个定时器
812 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), || async {});
813 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), || async {});
814 /// let task3 = TimerWheel::create_task(Duration::from_secs(5), || async {});
815 ///
816 /// let task_ids = vec![
817 /// (task1.get_id(), Duration::from_secs(10)),
818 /// (task2.get_id(), Duration::from_secs(15)),
819 /// (task3.get_id(), Duration::from_secs(20)),
820 /// ];
821 ///
822 /// timer.register(task1);
823 /// timer.register(task2);
824 /// timer.register(task3);
825 ///
826 /// // 批量推迟
827 /// let postponed = timer.postpone_batch(&task_ids);
828 /// println!("已推迟 {} 个定时器", postponed);
829 /// }
830 /// ```
831 #[inline]
832 pub fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
833 let updates_vec: Vec<_> = updates
834 .iter()
835 .map(|(task_id, delay)| (*task_id, *delay, None))
836 .collect();
837 let mut wheel = self.wheel.lock();
838 wheel.postpone_batch(updates_vec)
839 }
840
841 /// 批量推迟定时器(替换回调)
842 ///
843 /// # 参数
844 /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
845 ///
846 /// # 返回
847 /// 成功推迟的任务数量
848 ///
849 /// # 性能优势
850 /// - 批量处理减少锁竞争
851 /// - 内部优化批量推迟操作
852 ///
853 /// # 示例
854 /// ```no_run
855 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
856 /// use std::time::Duration;
857 /// use std::sync::Arc;
858 /// use std::sync::atomic::{AtomicU32, Ordering};
859 ///
860 /// #[tokio::main]
861 /// async fn main() {
862 /// let timer = TimerWheel::with_defaults();
863 /// let counter = Arc::new(AtomicU32::new(0));
864 ///
865 /// // 创建多个定时器
866 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), || async {});
867 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), || async {});
868 ///
869 /// let id1 = task1.get_id();
870 /// let id2 = task2.get_id();
871 ///
872 /// timer.register(task1);
873 /// timer.register(task2);
874 ///
875 /// // 批量推迟并替换回调
876 /// let updates: Vec<_> = vec![id1, id2]
877 /// .into_iter()
878 /// .map(|id| {
879 /// let counter = Arc::clone(&counter);
880 /// (id, Duration::from_secs(10), move || {
881 /// let counter = Arc::clone(&counter);
882 /// async move { counter.fetch_add(1, Ordering::SeqCst); }
883 /// })
884 /// })
885 /// .collect();
886 /// let postponed = timer.postpone_batch_with_callbacks(updates);
887 /// println!("已推迟 {} 个定时器", postponed);
888 /// }
889 /// ```
890 #[inline]
891 pub fn postpone_batch_with_callbacks<C>(
892 &self,
893 updates: Vec<(TaskId, Duration, C)>,
894 ) -> usize
895 where
896 C: TimerCallback,
897 {
898 use std::sync::Arc;
899 let updates_vec: Vec<_> = updates
900 .into_iter()
901 .map(|(task_id, delay, callback)| {
902 let callback_wrapper = Arc::new(callback) as CallbackWrapper;
903 (task_id, delay, Some(callback_wrapper))
904 })
905 .collect();
906 let mut wheel = self.wheel.lock();
907 wheel.postpone_batch(updates_vec)
908 }
909
910 /// 核心 tick 循环
911 async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
912 let mut interval = tokio::time::interval(tick_duration);
913 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
914
915 loop {
916 interval.tick().await;
917
918 // 推进时间轮并获取到期任务
919 let expired_tasks = {
920 let mut wheel_guard = wheel.lock();
921 wheel_guard.advance()
922 };
923
924 // 执行到期任务
925 for task in expired_tasks {
926 let callback = task.get_callback();
927
928 // 移动task的所有权来获取completion_notifier
929 let notifier = task.completion_notifier;
930
931 // 只有注册过的任务才有 notifier
932 if let Some(notifier) = notifier {
933 // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
934 if let Some(callback) = callback {
935 tokio::spawn(async move {
936 // 执行回调
937 let future = callback.call();
938 future.await;
939
940 // 回调执行完成后发送通知
941 let _ = notifier.0.send(());
942 });
943 } else {
944 // 如果没有回调,立即发送完成通知
945 let _ = notifier.0.send(());
946 }
947 }
948 }
949 }
950 }
951
952 /// 停止定时器管理器
953 pub async fn shutdown(mut self) {
954 if let Some(handle) = self.tick_handle.take() {
955 handle.abort();
956 let _ = handle.await;
957 }
958 }
959}
960
961impl Drop for TimerWheel {
962 fn drop(&mut self) {
963 if let Some(handle) = self.tick_handle.take() {
964 handle.abort();
965 }
966 }
967}
968
#[cfg(test)]
mod tests {
    // `Arc` and `Duration` reach this module through the parent's imports via `use super::*`.
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    #[tokio::test]
    async fn test_schedule_once() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let _handle = timer.register(task);

        // Wait for the timer to fire
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let handle = timer.register(task);

        // Cancel right away through the handle
        let cancel_result = handle.cancel();
        assert!(cancel_result);

        // Wait long enough to be sure the timer does not fire
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_cancel_immediate() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        let _handle = timer.register(task);

        // Cancel right away through the wheel by task ID (the handle path is covered by
        // `test_cancel_timer`)
        let cancel_result = timer.cancel(task_id);
        assert!(cancel_result);

        // Wait long enough to be sure the timer does not fire
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_postpone_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task to 150ms
        let postponed = timer.postpone(task_id, Duration::from_millis(150));
        assert!(postponed);

        // Past the original 50ms deadline the task must not have fired yet
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the new deadline (about 150ms after the postpone call)
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        // `Ok(Ok(()))`: no timeout AND the notification was actually sent, not just dropped
        assert!(matches!(result, Ok(Ok(()))));

        // Completion is signalled only after the callback has finished
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_postpone_with_callback() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone1 = Arc::clone(&counter);
        let counter_clone2 = Arc::clone(&counter);

        // Create the task; the original callback adds 1
        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone1);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task and replace the callback; the new callback adds 10
        let postponed = timer.postpone_with_callback(
            task_id,
            Duration::from_millis(100),
            move || {
                let counter = Arc::clone(&counter_clone2);
                async move {
                    counter.fetch_add(10, Ordering::SeqCst);
                }
            },
        );
        assert!(postponed);

        // Wait for the task to fire (100ms after postponing, plus some margin)
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(matches!(result, Ok(Ok(()))));

        // Only the new callback ran (added 10, not 1); completion is signalled after it
        // has finished, so the counter is already up to date
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_timer() {
        let timer = TimerWheel::with_defaults();

        // Try to postpone a task that does not exist
        let fake_task = TimerWheel::create_task(Duration::from_millis(50), || async {});
        let fake_task_id = fake_task.get_id();
        // Deliberately not registered

        let postponed = timer.postpone(fake_task_id, Duration::from_millis(100));
        assert!(!postponed);
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create 3 tasks
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            let task = TimerWheel::create_task(
                Duration::from_millis(50),
                move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                },
            );
            task_ids.push((task.get_id(), Duration::from_millis(150)));
            timer.register(task);
        }

        // Postpone all of them
        let postponed = timer.postpone_batch(&task_ids);
        assert_eq!(postponed, 3);

        // Past the original 50ms deadline no task must have fired yet
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the new deadline (about 150ms after the postpone call)
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Give the spawned callbacks time to run
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create 3 tasks
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerWheel::create_task(
                Duration::from_millis(50),
                || async {},
            );
            task_ids.push(task.get_id());
            timer.register(task);
        }

        // Postpone all of them and replace their callbacks
        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| {
                let counter_clone = Arc::clone(&counter);
                (id, Duration::from_millis(150), move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        let postponed = timer.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3);

        // Past the original 50ms deadline no task must have fired yet
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the new deadline (about 150ms after the postpone call)
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Give the spawned callbacks time to run
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_keeps_completion_receiver_valid() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            },
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task
        timer.postpone(task_id, Duration::from_millis(100));

        // The original completion receiver must still deliver the notification
        // (100ms after postponing, plus some margin). A dropped sender would yield
        // `Ok(Err(RecvError))`, which must not count as success.
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(
            matches!(result, Ok(Ok(()))),
            "Completion receiver should still work after postpone"
        );

        // Completion is signalled only after the callback has finished
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
1275