kestrel_protocol_timer/timer.rs
1use crate::config::{BatchConfig, ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TaskCompletionReason};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
10/// 完成通知接收器,用于接收定时器完成通知
11pub struct CompletionReceiver(pub oneshot::Receiver<TaskCompletionReason>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14///
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18 pub(crate) task_id: TaskId,
19 pub(crate) wheel: Arc<Mutex<Wheel>>,
20 pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24 pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<TaskCompletionReason>) -> Self {
25 Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26 }
27
28 /// 取消定时器
29 ///
30 /// # 返回
31 /// 如果任务存在且成功取消返回 true,否则返回 false
32 ///
33 /// # 示例
34 /// ```no_run
35 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
36 /// # use std::time::Duration;
37 /// #
38 /// # #[tokio::main]
39 /// # async fn main() {
40 /// let timer = TimerWheel::with_defaults();
41 /// let callback = Some(CallbackWrapper::new(|| async {}));
42 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
43 /// let handle = timer.register(task);
44 ///
45 /// // 取消定时器
46 /// let success = handle.cancel();
47 /// println!("取消成功: {}", success);
48 /// # }
49 /// ```
50 pub fn cancel(&self) -> bool {
51 let mut wheel = self.wheel.lock();
52 wheel.cancel(self.task_id)
53 }
54
55 /// 获取完成通知接收器的可变引用
56 ///
57 /// # 示例
58 /// ```no_run
59 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
60 /// # use std::time::Duration;
61 /// #
62 /// # #[tokio::main]
63 /// # async fn main() {
64 /// let timer = TimerWheel::with_defaults();
65 /// let callback = Some(CallbackWrapper::new(|| async {
66 /// println!("Timer fired!");
67 /// }));
68 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
69 /// let handle = timer.register(task);
70 ///
71 /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
72 /// handle.into_completion_receiver().0.await.ok();
73 /// println!("Timer completed!");
74 /// # }
75 /// ```
76 pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
77 &mut self.completion_rx
78 }
79
80 /// 消耗句柄,返回完成通知接收器
81 ///
82 /// # 示例
83 /// ```no_run
84 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
85 /// # use std::time::Duration;
86 /// #
87 /// # #[tokio::main]
88 /// # async fn main() {
89 /// let timer = TimerWheel::with_defaults();
90 /// let callback = Some(CallbackWrapper::new(|| async {
91 /// println!("Timer fired!");
92 /// }));
93 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
94 /// let handle = timer.register(task);
95 ///
96 /// // 等待定时器完成
97 /// handle.into_completion_receiver().0.await.ok();
98 /// println!("Timer completed!");
99 /// # }
100 /// ```
101 pub fn into_completion_receiver(self) -> CompletionReceiver {
102 self.completion_rx
103 }
104}
105
106/// 批量定时器句柄,用于管理批量调度的定时器
107///
108/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
109///
110/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
111/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
112pub struct BatchHandle {
113 pub(crate) task_ids: Vec<TaskId>,
114 pub(crate) wheel: Arc<Mutex<Wheel>>,
115 pub(crate) completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>,
116}
117
118impl BatchHandle {
119 pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>) -> Self {
120 Self { task_ids, wheel, completion_rxs }
121 }
122
123 /// 批量取消所有定时器
124 ///
125 /// # 返回
126 /// 成功取消的任务数量
127 ///
128 /// # 示例
129 /// ```no_run
130 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
131 /// # use std::time::Duration;
132 /// #
133 /// # #[tokio::main]
134 /// # async fn main() {
135 /// let timer = TimerWheel::with_defaults();
136 /// let delays: Vec<Duration> = (0..10)
137 /// .map(|_| Duration::from_secs(1))
138 /// .collect();
139 /// let tasks = TimerWheel::create_batch(delays);
140 /// let batch = timer.register_batch(tasks);
141 ///
142 /// let cancelled = batch.cancel_all();
143 /// println!("取消了 {} 个定时器", cancelled);
144 /// # }
145 /// ```
146 pub fn cancel_all(self) -> usize {
147 let mut wheel = self.wheel.lock();
148 wheel.cancel_batch(&self.task_ids)
149 }
150
151 /// 将批量句柄转换为单个定时器句柄的 Vec
152 ///
153 /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
154 ///
155 /// # 示例
156 /// ```no_run
157 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
158 /// # use std::time::Duration;
159 /// #
160 /// # #[tokio::main]
161 /// # async fn main() {
162 /// let timer = TimerWheel::with_defaults();
163 /// let delays: Vec<Duration> = (0..3)
164 /// .map(|_| Duration::from_secs(1))
165 /// .collect();
166 /// let tasks = TimerWheel::create_batch(delays);
167 /// let batch = timer.register_batch(tasks);
168 ///
169 /// // 转换为独立的句柄
170 /// let handles = batch.into_handles();
171 /// for handle in handles {
172 /// // 可以单独操作每个句柄
173 /// }
174 /// # }
175 /// ```
176 pub fn into_handles(self) -> Vec<TimerHandle> {
177 self.task_ids
178 .into_iter()
179 .zip(self.completion_rxs.into_iter())
180 .map(|(task_id, rx)| {
181 TimerHandle::new(task_id, self.wheel.clone(), rx)
182 })
183 .collect()
184 }
185
186 /// 获取批量任务的数量
187 pub fn len(&self) -> usize {
188 self.task_ids.len()
189 }
190
191 /// 检查批量任务是否为空
192 pub fn is_empty(&self) -> bool {
193 self.task_ids.is_empty()
194 }
195
196 /// 获取所有任务 ID 的引用
197 pub fn task_ids(&self) -> &[TaskId] {
198 &self.task_ids
199 }
200
201 /// 获取所有完成通知接收器的引用
202 ///
203 /// # 返回
204 /// 所有任务的完成通知接收器列表引用
205 pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<TaskCompletionReason>> {
206 &mut self.completion_rxs
207 }
208
209 /// 消耗句柄,返回所有完成通知接收器
210 ///
211 /// # 返回
212 /// 所有任务的完成通知接收器列表
213 ///
214 /// # 示例
215 /// ```no_run
216 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
217 /// # use std::time::Duration;
218 /// #
219 /// # #[tokio::main]
220 /// # async fn main() {
221 /// let timer = TimerWheel::with_defaults();
222 /// let delays: Vec<Duration> = (0..3)
223 /// .map(|_| Duration::from_secs(1))
224 /// .collect();
225 /// let tasks = TimerWheel::create_batch(delays);
226 /// let batch = timer.register_batch(tasks);
227 ///
228 /// // 获取所有完成通知接收器
229 /// let receivers = batch.into_completion_receivers();
230 /// for rx in receivers {
231 /// tokio::spawn(async move {
232 /// if rx.await.is_ok() {
233 /// println!("A timer completed!");
234 /// }
235 /// });
236 /// }
237 /// # }
238 /// ```
239 pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<TaskCompletionReason>> {
240 self.completion_rxs
241 }
242}
243
244/// 实现 IntoIterator,允许直接迭代 BatchHandle
245///
246/// # 示例
247/// ```no_run
248/// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
249/// # use std::time::Duration;
250/// #
251/// # #[tokio::main]
252/// # async fn main() {
253/// let timer = TimerWheel::with_defaults();
254/// let delays: Vec<Duration> = (0..3)
255/// .map(|_| Duration::from_secs(1))
256/// .collect();
257/// let tasks = TimerWheel::create_batch(delays);
258/// let batch = timer.register_batch(tasks);
259///
260/// // 直接迭代,每个元素都是独立的 TimerHandle
261/// for handle in batch {
262/// // 可以单独操作每个句柄
263/// }
264/// # }
265/// ```
266impl IntoIterator for BatchHandle {
267 type Item = TimerHandle;
268 type IntoIter = BatchHandleIter;
269
270 fn into_iter(self) -> Self::IntoIter {
271 BatchHandleIter {
272 task_ids: self.task_ids.into_iter(),
273 completion_rxs: self.completion_rxs.into_iter(),
274 wheel: self.wheel,
275 }
276 }
277}
278
279/// BatchHandle 的迭代器
280pub struct BatchHandleIter {
281 task_ids: std::vec::IntoIter<TaskId>,
282 completion_rxs: std::vec::IntoIter<oneshot::Receiver<TaskCompletionReason>>,
283 wheel: Arc<Mutex<Wheel>>,
284}
285
286impl Iterator for BatchHandleIter {
287 type Item = TimerHandle;
288
289 fn next(&mut self) -> Option<Self::Item> {
290 match (self.task_ids.next(), self.completion_rxs.next()) {
291 (Some(task_id), Some(rx)) => {
292 Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
293 }
294 _ => None,
295 }
296 }
297
298 fn size_hint(&self) -> (usize, Option<usize>) {
299 self.task_ids.size_hint()
300 }
301}
302
303impl ExactSizeIterator for BatchHandleIter {
304 fn len(&self) -> usize {
305 self.task_ids.len()
306 }
307}
308
309/// 时间轮定时器管理器
310pub struct TimerWheel {
311 /// 时间轮唯一标识符
312
313 /// 时间轮实例(使用 Arc<Mutex> 包装以支持多线程访问)
314 wheel: Arc<Mutex<Wheel>>,
315
316 /// 后台 tick 循环任务句柄
317 tick_handle: Option<JoinHandle<()>>,
318}
319
320impl TimerWheel {
321 /// 创建新的定时器管理器
322 ///
323 /// # 参数
324 /// - `config`: 时间轮配置(已经过验证)
325 ///
326 /// # 示例
327 /// ```no_run
328 /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask, BatchConfig};
329 /// use std::time::Duration;
330 ///
331 /// #[tokio::main]
332 /// async fn main() {
333 /// let config = WheelConfig::builder()
334 /// .l0_tick_duration(Duration::from_millis(10))
335 /// .l0_slot_count(512)
336 /// .l1_tick_duration(Duration::from_secs(1))
337 /// .l1_slot_count(64)
338 /// .build()
339 /// .unwrap();
340 /// let timer = TimerWheel::new(config, BatchConfig::default());
341 ///
342 /// // 使用两步式 API
343 /// let task = TimerWheel::create_task(Duration::from_secs(1), None);
344 /// let handle = timer.register(task);
345 /// }
346 /// ```
347 pub fn new(config: WheelConfig, batch_config: BatchConfig) -> Self {
348 let tick_duration = config.hierarchical.l0_tick_duration;
349 let wheel = Wheel::new(config, batch_config);
350 let wheel = Arc::new(Mutex::new(wheel));
351 let wheel_clone = wheel.clone();
352
353 // 启动后台 tick 循环
354 let tick_handle = tokio::spawn(async move {
355 Self::tick_loop(wheel_clone, tick_duration).await;
356 });
357
358 Self {
359 wheel,
360 tick_handle: Some(tick_handle),
361 }
362 }
363
364 /// 创建带默认配置的定时器管理器(分层模式)
365 /// - L0 层 tick 时长: 10ms, 槽位数量: 512
366 /// - L1 层 tick 时长: 1s, 槽位数量: 64
367 ///
368 /// # 示例
369 /// ```no_run
370 /// use kestrel_protocol_timer::TimerWheel;
371 ///
372 /// #[tokio::main]
373 /// async fn main() {
374 /// let timer = TimerWheel::with_defaults();
375 /// }
376 /// ```
377 pub fn with_defaults() -> Self {
378 Self::new(WheelConfig::default(), BatchConfig::default())
379 }
380
381 /// 创建与此时间轮绑定的 TimerService(使用默认配置)
382 ///
383 /// # 返回
384 /// 绑定到此时间轮的 TimerService 实例
385 ///
386 /// # 参数
387 /// - `service_config`: 服务配置
388 ///
389 /// # 示例
390 /// ```no_run
391 /// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
392 /// use std::time::Duration;
393 ///
394 ///
395 /// #[tokio::main]
396 /// async fn main() {
397 /// let timer = TimerWheel::with_defaults();
398 /// let mut service = timer.create_service(ServiceConfig::default());
399 ///
400 /// // 使用两步式 API 通过 service 批量调度定时器
401 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
402 /// .map(|_| (Duration::from_millis(100), Some(CallbackWrapper::new(|| async {}))))
403 /// .collect();
404 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
405 /// service.register_batch(tasks).unwrap();
406 ///
407 /// // 接收超时通知
408 /// let mut rx = service.take_receiver().unwrap();
409 /// while let Some(task_id) = rx.recv().await {
410 /// println!("Task {:?} completed", task_id);
411 /// }
412 /// }
413 /// ```
414 pub fn create_service(&self, service_config: ServiceConfig) -> crate::service::TimerService {
415 crate::service::TimerService::new(self.wheel.clone(), service_config)
416 }
417
418 /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
419 ///
420 /// # 参数
421 /// - `config`: 服务配置
422 ///
423 /// # 返回
424 /// 绑定到此时间轮的 TimerService 实例
425 ///
426 /// # 示例
427 /// ```no_run
428 /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
429 ///
430 /// #[tokio::main]
431 /// async fn main() {
432 /// let timer = TimerWheel::with_defaults();
433 /// let config = ServiceConfig::builder()
434 /// .command_channel_capacity(1024)
435 /// .timeout_channel_capacity(2000)
436 /// .build()
437 /// .unwrap();
438 /// let service = timer.create_service_with_config(config);
439 /// }
440 /// ```
441 pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
442 crate::service::TimerService::new(self.wheel.clone(), config)
443 }
444
445 /// 创建定时器任务(申请阶段)
446 ///
447 /// # 参数
448 /// - `delay`: 延迟时间
449 /// - `callback`: 实现了 TimerCallback trait 的回调对象
450 ///
451 /// # 返回
452 /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
453 ///
454 /// # 示例
455 /// ```no_run
456 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
457 /// use std::time::Duration;
458 ///
459 ///
460 /// #[tokio::main]
461 /// async fn main() {
462 /// let timer = TimerWheel::with_defaults();
463 ///
464 /// // 步骤 1: 创建任务
465 /// let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
466 /// println!("Timer fired!");
467 /// })));
468 ///
469 /// // 获取任务 ID
470 /// let task_id = task.get_id();
471 /// println!("Created task: {:?}", task_id);
472 ///
473 /// // 步骤 2: 注册任务
474 /// let handle = timer.register(task);
475 /// }
476 /// ```
477 #[inline]
478 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
479 crate::task::TimerTask::new(delay, callback)
480 }
481
482 /// 批量创建定时器任务(申请阶段)
483 ///
484 /// # 参数
485 /// - `callbacks`: (延迟时间, 回调) 的元组列表
486 ///
487 /// # 返回
488 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
489 ///
490 /// # 示例
491 /// ```no_run
492 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
493 /// use std::time::Duration;
494 /// use std::sync::Arc;
495 /// use std::sync::atomic::{AtomicU32, Ordering};
496 ///
497 /// #[tokio::main]
498 /// async fn main() {
499 /// let timer = TimerWheel::with_defaults();
500 /// let counter = Arc::new(AtomicU32::new(0));
501 ///
502 /// // 步骤 1: 批量创建任务
503 /// let delays: Vec<Duration> = (0..3)
504 /// .map(|_| Duration::from_millis(100))
505 /// .collect();
506 ///
507 /// let tasks = TimerWheel::create_batch(delays);
508 /// println!("Created {} tasks", tasks.len());
509 ///
510 /// // 步骤 2: 批量注册任务
511 /// let batch = timer.register_batch(tasks);
512 /// }
513 /// ```
514 #[inline]
515 pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask>
516 {
517 delays
518 .into_iter()
519 .map(|delay| crate::task::TimerTask::new(delay, None))
520 .collect()
521 }
522
523 /// 批量创建定时器任务(申请阶段)
524 ///
525 /// # 参数
526 /// - `callbacks`: (延迟时间, 回调) 的元组列表
527 ///
528 /// # 返回
529 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
530 ///
531 /// # 示例
532 /// ```no_run
533 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
534 /// use std::time::Duration;
535 /// use std::sync::Arc;
536 /// use std::sync::atomic::{AtomicU32, Ordering};
537 ///
538 /// #[tokio::main]
539 /// async fn main() {
540 /// let timer = TimerWheel::with_defaults();
541 /// let counter = Arc::new(AtomicU32::new(0));
542 ///
543 /// // 步骤 1: 批量创建任务
544 /// let delays: Vec<Duration> = (0..3)
545 /// .map(|_| Duration::from_millis(100))
546 /// .collect();
547 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = delays
548 /// .into_iter()
549 /// .map(|delay| {
550 /// let counter = Arc::clone(&counter);
551 /// let callback = Some(CallbackWrapper::new(move || {
552 /// let counter = Arc::clone(&counter);
553 /// async move {
554 /// counter.fetch_add(1, Ordering::SeqCst);
555 /// }
556 /// }));
557 /// (delay, callback)
558 /// })
559 /// .collect();
560 ///
561 /// let tasks = TimerWheel::create_batch_with_callbacks(callbacks);
562 /// println!("Created {} tasks", tasks.len());
563 ///
564 /// // 步骤 2: 批量注册任务
565 /// let batch = timer.register_batch(tasks);
566 /// }
567 /// ```
568 #[inline]
569 pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask>
570 {
571 callbacks
572 .into_iter()
573 .map(|(delay, callback)| crate::task::TimerTask::new(delay, callback))
574 .collect()
575 }
576
577 /// 注册定时器任务到时间轮(注册阶段)
578 ///
579 /// # 参数
580 /// - `task`: 通过 `create_task()` 创建的任务
581 ///
582 /// # 返回
583 /// 返回定时器句柄,可用于取消定时器和接收完成通知
584 ///
585 /// # 示例
586 /// ```no_run
587 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
588 ///
589 /// use std::time::Duration;
590 ///
591 /// #[tokio::main]
592 /// async fn main() {
593 /// let timer = TimerWheel::with_defaults();
594 ///
595 /// let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
596 /// println!("Timer fired!");
597 /// })));
598 /// let task_id = task.get_id();
599 ///
600 /// let handle = timer.register(task);
601 ///
602 /// // 等待定时器完成
603 /// handle.into_completion_receiver().0.await.ok();
604 /// }
605 /// ```
606 #[inline]
607 pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
608 let (completion_tx, completion_rx) = oneshot::channel();
609 let notifier = crate::task::CompletionNotifier(completion_tx);
610
611 let task_id = task.id;
612
613 // 单次加锁完成所有操作
614 {
615 let mut wheel_guard = self.wheel.lock();
616 wheel_guard.insert(task, notifier);
617 }
618
619 TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
620 }
621
622 /// 批量注册定时器任务到时间轮(注册阶段)
623 ///
624 /// # 参数
625 /// - `tasks`: 通过 `create_batch()` 创建的任务列表
626 ///
627 /// # 返回
628 /// 返回批量定时器句柄
629 ///
630 /// # 示例
631 /// ```no_run
632 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
633 /// use std::time::Duration;
634 ///
635 /// #[tokio::main]
636 /// async fn main() {
637 /// let timer = TimerWheel::with_defaults();
638 ///
639 /// let delays: Vec<Duration> = (0..3)
640 /// .map(|_| Duration::from_secs(1))
641 /// .collect();
642 /// let tasks = TimerWheel::create_batch(delays);
643 ///
644 /// let batch = timer.register_batch(tasks);
645 /// println!("Registered {} timers", batch.len());
646 /// }
647 /// ```
648 #[inline]
649 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
650 let task_count = tasks.len();
651 let mut completion_rxs = Vec::with_capacity(task_count);
652 let mut task_ids = Vec::with_capacity(task_count);
653 let mut prepared_tasks = Vec::with_capacity(task_count);
654
655 // 步骤1: 准备所有 channels 和 notifiers(无锁)
656 // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
657 for task in tasks {
658 let (completion_tx, completion_rx) = oneshot::channel();
659 let notifier = crate::task::CompletionNotifier(completion_tx);
660
661 task_ids.push(task.id);
662 completion_rxs.push(completion_rx);
663 prepared_tasks.push((task, notifier));
664 }
665
666 // 步骤2: 单次加锁,批量插入
667 {
668 let mut wheel_guard = self.wheel.lock();
669 wheel_guard.insert_batch(prepared_tasks);
670 }
671
672 BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
673 }
674
675 /// 取消定时器
676 ///
677 /// # 参数
678 /// - `task_id`: 任务 ID
679 ///
680 /// # 返回
681 /// 如果任务存在且成功取消返回 true,否则返回 false
682 ///
683 /// # 示例
684 /// ```no_run
685 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
686 ///
687 /// use std::time::Duration;
688 ///
689 /// #[tokio::main]
690 /// async fn main() {
691 /// let timer = TimerWheel::with_defaults();
692 ///
693 /// let task = TimerWheel::create_task(Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
694 /// println!("Timer fired!");
695 /// })));
696 /// let task_id = task.get_id();
697 /// let _handle = timer.register(task);
698 ///
699 /// // 使用任务 ID 取消
700 /// let cancelled = timer.cancel(task_id);
701 /// println!("取消成功: {}", cancelled);
702 /// }
703 /// ```
704 #[inline]
705 pub fn cancel(&self, task_id: TaskId) -> bool {
706 let mut wheel = self.wheel.lock();
707 wheel.cancel(task_id)
708 }
709
710 /// 批量取消定时器
711 ///
712 /// # 参数
713 /// - `task_ids`: 要取消的任务 ID 列表
714 ///
715 /// # 返回
716 /// 成功取消的任务数量
717 ///
718 /// # 性能优势
719 /// - 批量处理减少锁竞争
720 /// - 内部优化批量取消操作
721 ///
722 /// # 示例
723 /// ```no_run
724 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
725 /// use std::time::Duration;
726 ///
727 /// #[tokio::main]
728 /// async fn main() {
729 /// let timer = TimerWheel::with_defaults();
730 ///
731 /// // 创建多个定时器
732 /// let task1 = TimerWheel::create_task(Duration::from_secs(10), None);
733 /// let task2 = TimerWheel::create_task(Duration::from_secs(10), None);
734 /// let task3 = TimerWheel::create_task(Duration::from_secs(10), None);
735 ///
736 /// let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
737 ///
738 /// let _h1 = timer.register(task1);
739 /// let _h2 = timer.register(task2);
740 /// let _h3 = timer.register(task3);
741 ///
742 /// // 批量取消
743 /// let cancelled = timer.cancel_batch(&task_ids);
744 /// println!("已取消 {} 个定时器", cancelled);
745 /// }
746 /// ```
747 #[inline]
748 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
749 let mut wheel = self.wheel.lock();
750 wheel.cancel_batch(task_ids)
751 }
752
753 /// 推迟定时器
754 ///
755 /// # 参数
756 /// - `task_id`: 要推迟的任务 ID
757 /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
758 /// - `callback`: 新的回调函数,传入 `None` 保持原回调不变,传入 `Some` 替换为新回调
759 ///
760 /// # 返回
761 /// 如果任务存在且成功推迟返回 true,否则返回 false
762 ///
763 /// # 注意
764 /// - 推迟后任务 ID 保持不变
765 /// - 原有的 completion_receiver 仍然有效
766 ///
767 /// # 示例
768 ///
769 /// ## 保持原回调
770 /// ```no_run
771 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
772 /// use std::time::Duration;
773 ///
774 ///
775 /// #[tokio::main]
776 /// async fn main() {
777 /// let timer = TimerWheel::with_defaults();
778 ///
779 /// let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
780 /// println!("Timer fired!");
781 /// })));
782 /// let task_id = task.get_id();
783 /// let _handle = timer.register(task);
784 ///
785 /// // 推迟到 10 秒后触发(保持原回调)
786 /// let success = timer.postpone(task_id, Duration::from_secs(10), None);
787 /// println!("推迟成功: {}", success);
788 /// }
789 /// ```
790 ///
791 /// ## 替换为新回调
792 /// ```no_run
793 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
794 /// use std::time::Duration;
795 ///
796 /// #[tokio::main]
797 /// async fn main() {
798 /// let timer = TimerWheel::with_defaults();
799 ///
800 /// let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
801 /// println!("Original callback!");
802 /// })));
803 /// let task_id = task.get_id();
804 /// let _handle = timer.register(task);
805 ///
806 /// // 推迟到 10 秒后触发(并替换为新回调)
807 /// let success = timer.postpone(task_id, Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
808 /// println!("New callback!");
809 /// })));
810 /// println!("推迟成功: {}", success);
811 /// }
812 /// ```
813 #[inline]
814 pub fn postpone(
815 &self,
816 task_id: TaskId,
817 new_delay: Duration,
818 callback: Option<CallbackWrapper>,
819 ) -> bool {
820 let mut wheel = self.wheel.lock();
821 wheel.postpone(task_id, new_delay, callback)
822 }
823
824 /// 批量推迟定时器(保持原回调)
825 ///
826 /// # 参数
827 /// - `updates`: (任务ID, 新延迟) 的元组列表
828 ///
829 /// # 返回
830 /// 成功推迟的任务数量
831 ///
832 /// # 注意
833 /// - 此方法会保持所有任务的原回调不变
834 /// - 如需替换回调,请使用 `postpone_batch_with_callbacks`
835 ///
836 /// # 性能优势
837 /// - 批量处理减少锁竞争
838 /// - 内部优化批量推迟操作
839 ///
840 /// # 示例
841 /// ```no_run
842 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
843 /// use std::time::Duration;
844 ///
845 /// #[tokio::main]
846 /// async fn main() {
847 /// let timer = TimerWheel::with_defaults();
848 ///
849 /// // 创建多个带回调的定时器
850 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
851 /// println!("Task 1 fired!");
852 /// })));
853 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
854 /// println!("Task 2 fired!");
855 /// })));
856 /// let task3 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
857 /// println!("Task 3 fired!");
858 /// })));
859 ///
860 /// let task_ids = vec![
861 /// (task1.get_id(), Duration::from_secs(10)),
862 /// (task2.get_id(), Duration::from_secs(15)),
863 /// (task3.get_id(), Duration::from_secs(20)),
864 /// ];
865 ///
866 /// timer.register(task1);
867 /// timer.register(task2);
868 /// timer.register(task3);
869 ///
870 /// // 批量推迟(保持原回调)
871 /// let postponed = timer.postpone_batch(task_ids);
872 /// println!("已推迟 {} 个定时器", postponed);
873 /// }
874 /// ```
875 #[inline]
876 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
877 let mut wheel = self.wheel.lock();
878 wheel.postpone_batch(updates)
879 }
880
881 /// 批量推迟定时器(替换回调)
882 ///
883 /// # 参数
884 /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
885 ///
886 /// # 返回
887 /// 成功推迟的任务数量
888 ///
889 /// # 性能优势
890 /// - 批量处理减少锁竞争
891 /// - 内部优化批量推迟操作
892 ///
893 /// # 示例
894 /// ```no_run
895 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
896 /// use std::time::Duration;
897 /// use std::sync::Arc;
898 /// use std::sync::atomic::{AtomicU32, Ordering};
899 ///
900 /// #[tokio::main]
901 /// async fn main() {
902 /// let timer = TimerWheel::with_defaults();
903 /// let counter = Arc::new(AtomicU32::new(0));
904 ///
905 /// // 创建多个定时器
906 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), None);
907 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), None);
908 ///
909 /// let id1 = task1.get_id();
910 /// let id2 = task2.get_id();
911 ///
912 /// timer.register(task1);
913 /// timer.register(task2);
914 ///
915 /// // 批量推迟并替换回调
916 /// let updates: Vec<_> = vec![id1, id2]
917 /// .into_iter()
918 /// .map(|id| {
919 /// let counter = Arc::clone(&counter);
920 /// (id, Duration::from_secs(10), Some(CallbackWrapper::new(move || {
921 /// let counter = Arc::clone(&counter);
922 /// async move { counter.fetch_add(1, Ordering::SeqCst); }
923 /// })))
924 /// })
925 /// .collect();
926 /// let postponed = timer.postpone_batch_with_callbacks(updates);
927 /// println!("已推迟 {} 个定时器", postponed);
928 /// }
929 /// ```
930 #[inline]
931 pub fn postpone_batch_with_callbacks(
932 &self,
933 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
934 ) -> usize {
935 let mut wheel = self.wheel.lock();
936 wheel.postpone_batch_with_callbacks(updates.to_vec())
937 }
938
939 /// 核心 tick 循环
940 async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
941 let mut interval = tokio::time::interval(tick_duration);
942 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
943
944 loop {
945 interval.tick().await;
946
947 // 推进时间轮并获取到期任务
948 let expired_tasks = {
949 let mut wheel_guard = wheel.lock();
950 wheel_guard.advance()
951 };
952
953 // 执行到期任务
954 for task in expired_tasks {
955 let callback = task.get_callback();
956
957 // 移动task的所有权来获取completion_notifier
958 let notifier = task.completion_notifier;
959
960 // 只有注册过的任务才有 notifier
961 if let Some(notifier) = notifier {
962 // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
963 if let Some(callback) = callback {
964 tokio::spawn(async move {
965 // 执行回调
966 let future = callback.call();
967 future.await;
968
969 // 回调执行完成后发送通知
970 let _ = notifier.0.send(TaskCompletionReason::Expired);
971 });
972 } else {
973 // 如果没有回调,立即发送完成通知
974 let _ = notifier.0.send(TaskCompletionReason::Expired);
975 }
976 }
977 }
978 }
979 }
980
981 /// 停止定时器管理器
982 pub async fn shutdown(mut self) {
983 if let Some(handle) = self.tick_handle.take() {
984 handle.abort();
985 let _ = handle.await;
986 }
987 }
988}
989
990impl Drop for TimerWheel {
991 fn drop(&mut self) {
992 if let Some(handle) = self.tick_handle.take() {
993 handle.abort();
994 }
995 }
996}
997
998#[cfg(test)]
999mod tests {
1000 use super::*;
1001 use std::sync::atomic::{AtomicU32, Ordering};
1002
1003 #[tokio::test]
1004 async fn test_timer_creation() {
1005 let _timer = TimerWheel::with_defaults();
1006 }
1007
1008 #[tokio::test]
1009 async fn test_schedule_once() {
1010 use std::sync::Arc;
1011 let timer = TimerWheel::with_defaults();
1012 let counter = Arc::new(AtomicU32::new(0));
1013 let counter_clone = Arc::clone(&counter);
1014
1015 let task = TimerWheel::create_task(
1016 Duration::from_millis(50),
1017 Some(CallbackWrapper::new(move || {
1018 let counter = Arc::clone(&counter_clone);
1019 async move {
1020 counter.fetch_add(1, Ordering::SeqCst);
1021 }
1022 })),
1023 );
1024 let _handle = timer.register(task);
1025
1026 // 等待定时器触发
1027 tokio::time::sleep(Duration::from_millis(100)).await;
1028 assert_eq!(counter.load(Ordering::SeqCst), 1);
1029 }
1030
1031 #[tokio::test]
1032 async fn test_cancel_timer() {
1033 use std::sync::Arc;
1034 let timer = TimerWheel::with_defaults();
1035 let counter = Arc::new(AtomicU32::new(0));
1036 let counter_clone = Arc::clone(&counter);
1037
1038 let task = TimerWheel::create_task(
1039 Duration::from_millis(100),
1040 Some(CallbackWrapper::new(move || {
1041 let counter = Arc::clone(&counter_clone);
1042 async move {
1043 counter.fetch_add(1, Ordering::SeqCst);
1044 }
1045 })),
1046 );
1047 let handle = timer.register(task);
1048
1049 // 立即取消
1050 let cancel_result = handle.cancel();
1051 assert!(cancel_result);
1052
1053 // 等待足够长时间确保定时器不会触发
1054 tokio::time::sleep(Duration::from_millis(200)).await;
1055 assert_eq!(counter.load(Ordering::SeqCst), 0);
1056 }
1057
1058 #[tokio::test]
1059 async fn test_cancel_immediate() {
1060 use std::sync::Arc;
1061 let timer = TimerWheel::with_defaults();
1062 let counter = Arc::new(AtomicU32::new(0));
1063 let counter_clone = Arc::clone(&counter);
1064
1065 let task = TimerWheel::create_task(
1066 Duration::from_millis(100),
1067 Some(CallbackWrapper::new(move || {
1068 let counter = Arc::clone(&counter_clone);
1069 async move {
1070 counter.fetch_add(1, Ordering::SeqCst);
1071 }
1072 })),
1073 );
1074 let handle = timer.register(task);
1075
1076 // 立即取消
1077 let cancel_result = handle.cancel();
1078 assert!(cancel_result);
1079
1080 // 等待足够长时间确保定时器不会触发
1081 tokio::time::sleep(Duration::from_millis(200)).await;
1082 assert_eq!(counter.load(Ordering::SeqCst), 0);
1083 }
1084
1085 #[tokio::test]
1086 async fn test_postpone_timer() {
1087 use std::sync::Arc;
1088 let timer = TimerWheel::with_defaults();
1089 let counter = Arc::new(AtomicU32::new(0));
1090 let counter_clone = Arc::clone(&counter);
1091
1092 let task = TimerWheel::create_task(
1093 Duration::from_millis(50),
1094 Some(CallbackWrapper::new(move || {
1095 let counter = Arc::clone(&counter_clone);
1096 async move {
1097 counter.fetch_add(1, Ordering::SeqCst);
1098 }
1099 })),
1100 );
1101 let task_id = task.get_id();
1102 let handle = timer.register(task);
1103
1104 // 推迟任务到 150ms
1105 let postponed = timer.postpone(task_id, Duration::from_millis(150), None);
1106 assert!(postponed);
1107
1108 // 等待原定时间 50ms,任务不应该触发
1109 tokio::time::sleep(Duration::from_millis(70)).await;
1110 assert_eq!(counter.load(Ordering::SeqCst), 0);
1111
1112 // 等待新的触发时间(从推迟开始算,还需要等待约 150ms)
1113 let result = tokio::time::timeout(
1114 Duration::from_millis(200),
1115 handle.into_completion_receiver().0
1116 ).await;
1117 assert!(result.is_ok());
1118
1119 // 等待回调执行
1120 tokio::time::sleep(Duration::from_millis(20)).await;
1121 assert_eq!(counter.load(Ordering::SeqCst), 1);
1122 }
1123
1124 #[tokio::test]
1125 async fn test_postpone_with_callback() {
1126 use std::sync::Arc;
1127 let timer = TimerWheel::with_defaults();
1128 let counter = Arc::new(AtomicU32::new(0));
1129 let counter_clone1 = Arc::clone(&counter);
1130 let counter_clone2 = Arc::clone(&counter);
1131
1132 // 创建任务,原始回调增加 1
1133 let task = TimerWheel::create_task(
1134 Duration::from_millis(50),
1135 Some(CallbackWrapper::new(move || {
1136 let counter = Arc::clone(&counter_clone1);
1137 async move {
1138 counter.fetch_add(1, Ordering::SeqCst);
1139 }
1140 })),
1141 );
1142 let task_id = task.get_id();
1143 let handle = timer.register(task);
1144
1145 // 推迟任务并替换回调,新回调增加 10
1146 let postponed = timer.postpone(
1147 task_id,
1148 Duration::from_millis(100),
1149 Some(CallbackWrapper::new(move || {
1150 let counter = Arc::clone(&counter_clone2);
1151 async move {
1152 counter.fetch_add(10, Ordering::SeqCst);
1153 }
1154 })),
1155 );
1156 assert!(postponed);
1157
1158 // 等待任务触发(推迟后需要等待100ms,加上余量)
1159 let result = tokio::time::timeout(
1160 Duration::from_millis(200),
1161 handle.into_completion_receiver().0
1162 ).await;
1163 assert!(result.is_ok());
1164
1165 // 等待回调执行
1166 tokio::time::sleep(Duration::from_millis(20)).await;
1167
1168 // 验证新回调被执行(增加了 10 而不是 1)
1169 assert_eq!(counter.load(Ordering::SeqCst), 10);
1170 }
1171
1172 #[tokio::test]
1173 async fn test_postpone_nonexistent_timer() {
1174 let timer = TimerWheel::with_defaults();
1175
1176 // 尝试推迟不存在的任务
1177 let fake_task = TimerWheel::create_task(Duration::from_millis(50), None);
1178 let fake_task_id = fake_task.get_id();
1179 // 不注册这个任务
1180
1181 let postponed = timer.postpone(fake_task_id, Duration::from_millis(100), None);
1182 assert!(!postponed);
1183 }
1184
1185 #[tokio::test]
1186 async fn test_postpone_batch() {
1187 use std::sync::Arc;
1188 let timer = TimerWheel::with_defaults();
1189 let counter = Arc::new(AtomicU32::new(0));
1190
1191 // 创建 3 个任务
1192 let mut task_ids = Vec::new();
1193 for _ in 0..3 {
1194 let counter_clone = Arc::clone(&counter);
1195 let task = TimerWheel::create_task(
1196 Duration::from_millis(50),
1197 Some(CallbackWrapper::new(move || {
1198 let counter = Arc::clone(&counter_clone);
1199 async move {
1200 counter.fetch_add(1, Ordering::SeqCst);
1201 }
1202 })),
1203 );
1204 task_ids.push((task.get_id(), Duration::from_millis(150)));
1205 timer.register(task);
1206 }
1207
1208 // 批量推迟
1209 let postponed = timer.postpone_batch(task_ids);
1210 assert_eq!(postponed, 3);
1211
1212 // 等待原定时间 50ms,任务不应该触发
1213 tokio::time::sleep(Duration::from_millis(70)).await;
1214 assert_eq!(counter.load(Ordering::SeqCst), 0);
1215
1216 // 等待新的触发时间(从推迟开始算,还需要等待约 150ms)
1217 tokio::time::sleep(Duration::from_millis(200)).await;
1218
1219 // 等待回调执行
1220 tokio::time::sleep(Duration::from_millis(20)).await;
1221 assert_eq!(counter.load(Ordering::SeqCst), 3);
1222 }
1223
1224 #[tokio::test]
1225 async fn test_postpone_batch_with_callbacks() {
1226 use std::sync::Arc;
1227 let timer = TimerWheel::with_defaults();
1228 let counter = Arc::new(AtomicU32::new(0));
1229
1230 // 创建 3 个任务
1231 let mut task_ids = Vec::new();
1232 for _ in 0..3 {
1233 let task = TimerWheel::create_task(
1234 Duration::from_millis(50),
1235 None
1236 );
1237 task_ids.push(task.get_id());
1238 timer.register(task);
1239 }
1240
1241 // 批量推迟并替换回调
1242 let updates: Vec<_> = task_ids
1243 .into_iter()
1244 .map(|id| {
1245 let counter_clone = Arc::clone(&counter);
1246 (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
1247 let counter = Arc::clone(&counter_clone);
1248 async move {
1249 counter.fetch_add(1, Ordering::SeqCst);
1250 }
1251 })))
1252 })
1253 .collect();
1254
1255 let postponed = timer.postpone_batch_with_callbacks(updates);
1256 assert_eq!(postponed, 3);
1257
1258 // 等待原定时间 50ms,任务不应该触发
1259 tokio::time::sleep(Duration::from_millis(70)).await;
1260 assert_eq!(counter.load(Ordering::SeqCst), 0);
1261
1262 // 等待新的触发时间(从推迟开始算,还需要等待约 150ms)
1263 tokio::time::sleep(Duration::from_millis(200)).await;
1264
1265 // 等待回调执行
1266 tokio::time::sleep(Duration::from_millis(20)).await;
1267 assert_eq!(counter.load(Ordering::SeqCst), 3);
1268 }
1269
1270 #[tokio::test]
1271 async fn test_postpone_keeps_completion_receiver_valid() {
1272 use std::sync::Arc;
1273 let timer = TimerWheel::with_defaults();
1274 let counter = Arc::new(AtomicU32::new(0));
1275 let counter_clone = Arc::clone(&counter);
1276
1277 let task = TimerWheel::create_task(
1278 Duration::from_millis(50),
1279 Some(CallbackWrapper::new(move || {
1280 let counter = Arc::clone(&counter_clone);
1281 async move {
1282 counter.fetch_add(1, Ordering::SeqCst);
1283 }
1284 })),
1285 );
1286 let task_id = task.get_id();
1287 let handle = timer.register(task);
1288
1289 // 推迟任务
1290 timer.postpone(task_id, Duration::from_millis(100), None);
1291
1292 // 验证原 completion_receiver 仍然有效(推迟后需要等待100ms,加上余量)
1293 let result = tokio::time::timeout(
1294 Duration::from_millis(200),
1295 handle.into_completion_receiver().0
1296 ).await;
1297 assert!(result.is_ok(), "Completion receiver should still work after postpone");
1298
1299 // 等待回调执行
1300 tokio::time::sleep(Duration::from_millis(20)).await;
1301 assert_eq!(counter.load(Ordering::SeqCst), 1);
1302 }
1303}
1304