kestrel_protocol_timer/timer.rs
1use crate::config::{ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TaskCompletionReason};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
/// Completion-notification receiver, used to receive a timer's completion notice.
///
/// Thin newtype around a [`oneshot::Receiver`] carrying a
/// [`TaskCompletionReason`]; the inner receiver is public and can be awaited
/// directly through `.0`.
pub struct CompletionReceiver(pub oneshot::Receiver<TaskCompletionReason>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14///
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18 pub(crate) task_id: TaskId,
19 pub(crate) wheel: Arc<Mutex<Wheel>>,
20 pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24 pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<TaskCompletionReason>) -> Self {
25 Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26 }
27
28 /// 取消定时器
29 ///
30 /// # 返回
31 /// 如果任务存在且成功取消返回 true,否则返回 false
32 ///
33 /// # 示例
34 /// ```no_run
35 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
36 /// # use std::time::Duration;
37 /// #
38 /// # #[tokio::main]
39 /// # async fn main() {
40 /// let timer = TimerWheel::with_defaults();
41 /// let callback = Some(CallbackWrapper::new(|| async {}));
42 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
43 /// let handle = timer.register(task);
44 ///
45 /// // 取消定时器
46 /// let success = handle.cancel();
47 /// println!("取消成功: {}", success);
48 /// # }
49 /// ```
50 pub fn cancel(&self) -> bool {
51 let mut wheel = self.wheel.lock();
52 wheel.cancel(self.task_id)
53 }
54
55 /// 获取完成通知接收器的可变引用
56 ///
57 /// # 示例
58 /// ```no_run
59 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
60 /// # use std::time::Duration;
61 /// #
62 /// # #[tokio::main]
63 /// # async fn main() {
64 /// let timer = TimerWheel::with_defaults();
65 /// let callback = Some(CallbackWrapper::new(|| async {
66 /// println!("Timer fired!");
67 /// }));
68 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
69 /// let handle = timer.register(task);
70 ///
71 /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
72 /// handle.into_completion_receiver().0.await.ok();
73 /// println!("Timer completed!");
74 /// # }
75 /// ```
76 pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
77 &mut self.completion_rx
78 }
79
80 /// 消耗句柄,返回完成通知接收器
81 ///
82 /// # 示例
83 /// ```no_run
84 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
85 /// # use std::time::Duration;
86 /// #
87 /// # #[tokio::main]
88 /// # async fn main() {
89 /// let timer = TimerWheel::with_defaults();
90 /// let callback = Some(CallbackWrapper::new(|| async {
91 /// println!("Timer fired!");
92 /// }));
93 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
94 /// let handle = timer.register(task);
95 ///
96 /// // 等待定时器完成
97 /// handle.into_completion_receiver().0.await.ok();
98 /// println!("Timer completed!");
99 /// # }
100 /// ```
101 pub fn into_completion_receiver(self) -> CompletionReceiver {
102 self.completion_rx
103 }
104}
105
106/// 批量定时器句柄,用于管理批量调度的定时器
107///
108/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
109///
110/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
111/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
112pub struct BatchHandle {
113 pub(crate) task_ids: Vec<TaskId>,
114 pub(crate) wheel: Arc<Mutex<Wheel>>,
115 pub(crate) completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>,
116}
117
118impl BatchHandle {
119 pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>) -> Self {
120 Self { task_ids, wheel, completion_rxs }
121 }
122
123 /// 批量取消所有定时器
124 ///
125 /// # 返回
126 /// 成功取消的任务数量
127 ///
128 /// # 示例
129 /// ```no_run
130 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
131 /// # use std::time::Duration;
132 /// #
133 /// # #[tokio::main]
134 /// # async fn main() {
135 /// let timer = TimerWheel::with_defaults();
136 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..10)
137 /// .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
138 /// .collect();
139 /// let tasks = TimerWheel::create_batch(callbacks);
140 /// let batch = timer.register_batch(tasks);
141 ///
142 /// let cancelled = batch.cancel_all();
143 /// println!("取消了 {} 个定时器", cancelled);
144 /// # }
145 /// ```
146 pub fn cancel_all(self) -> usize {
147 let mut wheel = self.wheel.lock();
148 wheel.cancel_batch(&self.task_ids)
149 }
150
151 /// 将批量句柄转换为单个定时器句柄的 Vec
152 ///
153 /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
154 ///
155 /// # 示例
156 /// ```no_run
157 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
158 /// # use std::time::Duration;
159 /// #
160 /// # #[tokio::main]
161 /// # async fn main() {
162 /// let timer = TimerWheel::with_defaults();
163 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
164 /// .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
165 /// .collect();
166 /// let tasks = TimerWheel::create_batch(callbacks);
167 /// let batch = timer.register_batch(tasks);
168 ///
169 /// // 转换为独立的句柄
170 /// let handles = batch.into_handles();
171 /// for handle in handles {
172 /// // 可以单独操作每个句柄
173 /// }
174 /// # }
175 /// ```
176 pub fn into_handles(self) -> Vec<TimerHandle> {
177 self.task_ids
178 .into_iter()
179 .zip(self.completion_rxs.into_iter())
180 .map(|(task_id, rx)| {
181 TimerHandle::new(task_id, self.wheel.clone(), rx)
182 })
183 .collect()
184 }
185
186 /// 获取批量任务的数量
187 pub fn len(&self) -> usize {
188 self.task_ids.len()
189 }
190
191 /// 检查批量任务是否为空
192 pub fn is_empty(&self) -> bool {
193 self.task_ids.is_empty()
194 }
195
196 /// 获取所有任务 ID 的引用
197 pub fn task_ids(&self) -> &[TaskId] {
198 &self.task_ids
199 }
200
201 /// 获取所有完成通知接收器的引用
202 ///
203 /// # 返回
204 /// 所有任务的完成通知接收器列表引用
205 pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<TaskCompletionReason>> {
206 &mut self.completion_rxs
207 }
208
209 /// 消耗句柄,返回所有完成通知接收器
210 ///
211 /// # 返回
212 /// 所有任务的完成通知接收器列表
213 ///
214 /// # 示例
215 /// ```no_run
216 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
217 /// # use std::time::Duration;
218 /// #
219 /// # #[tokio::main]
220 /// # async fn main() {
221 /// let timer = TimerWheel::with_defaults();
222 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
223 /// .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
224 /// .collect();
225 /// let tasks = TimerWheel::create_batch(callbacks);
226 /// let batch = timer.register_batch(tasks);
227 ///
228 /// // 获取所有完成通知接收器
229 /// let receivers = batch.into_completion_receivers();
230 /// for rx in receivers {
231 /// tokio::spawn(async move {
232 /// if rx.await.is_ok() {
233 /// println!("A timer completed!");
234 /// }
235 /// });
236 /// }
237 /// # }
238 /// ```
239 pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<TaskCompletionReason>> {
240 self.completion_rxs
241 }
242}
243
244/// 实现 IntoIterator,允许直接迭代 BatchHandle
245///
246/// # 示例
247/// ```no_run
248/// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
249/// # use std::time::Duration;
250/// #
251/// # #[tokio::main]
252/// # async fn main() {
253/// let timer = TimerWheel::with_defaults();
254/// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
255/// .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
256/// .collect();
257/// let tasks = TimerWheel::create_batch(callbacks);
258/// let batch = timer.register_batch(tasks);
259///
260/// // 直接迭代,每个元素都是独立的 TimerHandle
261/// for handle in batch {
262/// // 可以单独操作每个句柄
263/// }
264/// # }
265/// ```
266impl IntoIterator for BatchHandle {
267 type Item = TimerHandle;
268 type IntoIter = BatchHandleIter;
269
270 fn into_iter(self) -> Self::IntoIter {
271 BatchHandleIter {
272 task_ids: self.task_ids.into_iter(),
273 completion_rxs: self.completion_rxs.into_iter(),
274 wheel: self.wheel,
275 }
276 }
277}
278
279/// BatchHandle 的迭代器
280pub struct BatchHandleIter {
281 task_ids: std::vec::IntoIter<TaskId>,
282 completion_rxs: std::vec::IntoIter<oneshot::Receiver<TaskCompletionReason>>,
283 wheel: Arc<Mutex<Wheel>>,
284}
285
286impl Iterator for BatchHandleIter {
287 type Item = TimerHandle;
288
289 fn next(&mut self) -> Option<Self::Item> {
290 match (self.task_ids.next(), self.completion_rxs.next()) {
291 (Some(task_id), Some(rx)) => {
292 Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
293 }
294 _ => None,
295 }
296 }
297
298 fn size_hint(&self) -> (usize, Option<usize>) {
299 self.task_ids.size_hint()
300 }
301}
302
303impl ExactSizeIterator for BatchHandleIter {
304 fn len(&self) -> usize {
305 self.task_ids.len()
306 }
307}
308
/// Timing-wheel based timer manager.
pub struct TimerWheel {
    /// The wheel instance (wrapped in `Arc<Mutex<_>>` to support access from
    /// multiple threads).
    wheel: Arc<Mutex<Wheel>>,

    /// Join handle of the background tick-loop task.
    tick_handle: Option<JoinHandle<()>>,
}
319
320impl TimerWheel {
321 /// 创建新的定时器管理器
322 ///
323 /// # 参数
324 /// - `config`: 时间轮配置(已经过验证)
325 ///
326 /// # 示例
327 /// ```no_run
328 /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask};
329 /// use std::time::Duration;
330 ///
331 /// #[tokio::main]
332 /// async fn main() {
333 /// let config = WheelConfig::builder()
334 /// .tick_duration(Duration::from_millis(10))
335 /// .slot_count(512)
336 /// .build()
337 /// .unwrap();
338 /// let timer = TimerWheel::new(config);
339 ///
340 /// // 使用两步式 API
341 /// let task = TimerWheel::create_task(Duration::from_secs(1), None);
342 /// let handle = timer.register(task);
343 /// }
344 /// ```
345 pub fn new(config: WheelConfig) -> Self {
346 let tick_duration = config.tick_duration;
347 let wheel = Wheel::new(config);
348 let wheel = Arc::new(Mutex::new(wheel));
349 let wheel_clone = wheel.clone();
350
351 // 启动后台 tick 循环
352 let tick_handle = tokio::spawn(async move {
353 Self::tick_loop(wheel_clone, tick_duration).await;
354 });
355
356 Self {
357 wheel,
358 tick_handle: Some(tick_handle),
359 }
360 }
361
362 /// 创建带默认配置的定时器管理器
363 /// - tick 时长: 10ms
364 /// - 槽位数量: 512
365 ///
366 /// # 示例
367 /// ```no_run
368 /// use kestrel_protocol_timer::TimerWheel;
369 ///
370 /// #[tokio::main]
371 /// async fn main() {
372 /// let timer = TimerWheel::with_defaults();
373 /// }
374 /// ```
375 pub fn with_defaults() -> Self {
376 Self::new(WheelConfig::default())
377 }
378
379 /// 创建与此时间轮绑定的 TimerService(使用默认配置)
380 ///
381 /// # 返回
382 /// 绑定到此时间轮的 TimerService 实例
383 ///
384 /// # 示例
385 /// ```no_run
386 /// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
387 /// use std::time::Duration;
388 ///
389 ///
390 /// #[tokio::main]
391 /// async fn main() {
392 /// let timer = TimerWheel::with_defaults();
393 /// let mut service = timer.create_service();
394 ///
395 /// // 使用两步式 API 通过 service 批量调度定时器
396 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
397 /// .map(|_| (Duration::from_millis(100), Some(CallbackWrapper::new(|| async {}))))
398 /// .collect();
399 /// let tasks = TimerService::create_batch(callbacks);
400 /// service.register_batch(tasks).unwrap();
401 ///
402 /// // 接收超时通知
403 /// let mut rx = service.take_receiver().unwrap();
404 /// while let Some(task_id) = rx.recv().await {
405 /// println!("Task {:?} completed", task_id);
406 /// }
407 /// }
408 /// ```
409 pub fn create_service(&self) -> crate::service::TimerService {
410 crate::service::TimerService::new(self.wheel.clone(), ServiceConfig::default())
411 }
412
413 /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
414 ///
415 /// # 参数
416 /// - `config`: 服务配置
417 ///
418 /// # 返回
419 /// 绑定到此时间轮的 TimerService 实例
420 ///
421 /// # 示例
422 /// ```no_run
423 /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
424 ///
425 /// #[tokio::main]
426 /// async fn main() {
427 /// let timer = TimerWheel::with_defaults();
428 /// let config = ServiceConfig::builder()
429 /// .command_channel_capacity(1024)
430 /// .timeout_channel_capacity(2000)
431 /// .build()
432 /// .unwrap();
433 /// let service = timer.create_service_with_config(config);
434 /// }
435 /// ```
436 pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
437 crate::service::TimerService::new(self.wheel.clone(), config)
438 }
439
440 /// 创建定时器任务(申请阶段)
441 ///
442 /// # 参数
443 /// - `delay`: 延迟时间
444 /// - `callback`: 实现了 TimerCallback trait 的回调对象
445 ///
446 /// # 返回
447 /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
448 ///
449 /// # 示例
450 /// ```no_run
451 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
452 /// use std::time::Duration;
453 ///
454 ///
455 /// #[tokio::main]
456 /// async fn main() {
457 /// let timer = TimerWheel::with_defaults();
458 ///
459 /// // 步骤 1: 创建任务
460 /// let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
461 /// println!("Timer fired!");
462 /// })));
463 ///
464 /// // 获取任务 ID
465 /// let task_id = task.get_id();
466 /// println!("Created task: {:?}", task_id);
467 ///
468 /// // 步骤 2: 注册任务
469 /// let handle = timer.register(task);
470 /// }
471 /// ```
472 #[inline]
473 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
474 crate::task::TimerTask::new(delay, callback)
475 }
476
477 /// 批量创建定时器任务(申请阶段)
478 ///
479 /// # 参数
480 /// - `callbacks`: (延迟时间, 回调) 的元组列表
481 ///
482 /// # 返回
483 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
484 ///
485 /// # 示例
486 /// ```no_run
487 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
488 /// use std::time::Duration;
489 /// use std::sync::Arc;
490 /// use std::sync::atomic::{AtomicU32, Ordering};
491 ///
492 /// #[tokio::main]
493 /// async fn main() {
494 /// let timer = TimerWheel::with_defaults();
495 /// let counter = Arc::new(AtomicU32::new(0));
496 ///
497 /// // 步骤 1: 批量创建任务
498 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
499 /// .map(|i| {
500 /// let counter = Arc::clone(&counter);
501 /// let delay = Duration::from_millis(100 + i * 100);
502 /// let callback = Some(CallbackWrapper::new(move || {
503 /// let counter = Arc::clone(&counter);
504 /// async move {
505 /// counter.fetch_add(1, Ordering::SeqCst);
506 /// }
507 /// }));
508 /// (delay, callback)
509 /// })
510 /// .collect();
511 ///
512 /// let tasks = TimerWheel::create_batch(callbacks);
513 /// println!("Created {} tasks", tasks.len());
514 ///
515 /// // 步骤 2: 批量注册任务
516 /// let batch = timer.register_batch(tasks);
517 /// }
518 /// ```
519 #[inline]
520 pub fn create_batch(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask>
521 {
522 callbacks
523 .into_iter()
524 .map(|(delay, callback)| crate::task::TimerTask::new(delay, callback))
525 .collect()
526 }
527
528 /// 注册定时器任务到时间轮(注册阶段)
529 ///
530 /// # 参数
531 /// - `task`: 通过 `create_task()` 创建的任务
532 ///
533 /// # 返回
534 /// 返回定时器句柄,可用于取消定时器和接收完成通知
535 ///
536 /// # 示例
537 /// ```no_run
538 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
539 ///
540 /// use std::time::Duration;
541 ///
542 /// #[tokio::main]
543 /// async fn main() {
544 /// let timer = TimerWheel::with_defaults();
545 ///
546 /// let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
547 /// println!("Timer fired!");
548 /// })));
549 /// let task_id = task.get_id();
550 ///
551 /// let handle = timer.register(task);
552 ///
553 /// // 等待定时器完成
554 /// handle.into_completion_receiver().0.await.ok();
555 /// }
556 /// ```
557 #[inline]
558 pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
559 let (completion_tx, completion_rx) = oneshot::channel();
560 let notifier = crate::task::CompletionNotifier(completion_tx);
561
562 let delay = task.delay;
563 let task_id = task.id;
564
565 // 单次加锁完成所有操作
566 {
567 let mut wheel_guard = self.wheel.lock();
568 wheel_guard.insert(delay, task, notifier);
569 }
570
571 TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
572 }
573
574 /// 批量注册定时器任务到时间轮(注册阶段)
575 ///
576 /// # 参数
577 /// - `tasks`: 通过 `create_batch()` 创建的任务列表
578 ///
579 /// # 返回
580 /// 返回批量定时器句柄
581 ///
582 /// # 示例
583 /// ```no_run
584 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
585 /// use std::time::Duration;
586 ///
587 /// #[tokio::main]
588 /// async fn main() {
589 /// let timer = TimerWheel::with_defaults();
590 ///
591 /// let callbacks: Vec<_> = (0..3)
592 /// .map(|_| (Duration::from_secs(1), None))
593 /// .collect();
594 /// let tasks = TimerWheel::create_batch(callbacks);
595 ///
596 /// let batch = timer.register_batch(tasks);
597 /// println!("Registered {} timers", batch.len());
598 /// }
599 /// ```
600 #[inline]
601 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
602 let task_count = tasks.len();
603 let mut completion_rxs = Vec::with_capacity(task_count);
604 let mut task_ids = Vec::with_capacity(task_count);
605 let mut prepared_tasks = Vec::with_capacity(task_count);
606
607 // 步骤1: 准备所有 channels 和 notifiers(无锁)
608 // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
609 for task in tasks {
610 let (completion_tx, completion_rx) = oneshot::channel();
611 let notifier = crate::task::CompletionNotifier(completion_tx);
612
613 task_ids.push(task.id);
614 completion_rxs.push(completion_rx);
615 prepared_tasks.push((task.delay, task, notifier));
616 }
617
618 // 步骤2: 单次加锁,批量插入
619 {
620 let mut wheel_guard = self.wheel.lock();
621 wheel_guard.insert_batch(prepared_tasks);
622 }
623
624 BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
625 }
626
627 /// 取消定时器
628 ///
629 /// # 参数
630 /// - `task_id`: 任务 ID
631 ///
632 /// # 返回
633 /// 如果任务存在且成功取消返回 true,否则返回 false
634 ///
635 /// # 示例
636 /// ```no_run
637 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
638 ///
639 /// use std::time::Duration;
640 ///
641 /// #[tokio::main]
642 /// async fn main() {
643 /// let timer = TimerWheel::with_defaults();
644 ///
645 /// let task = TimerWheel::create_task(Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
646 /// println!("Timer fired!");
647 /// })));
648 /// let task_id = task.get_id();
649 /// let _handle = timer.register(task);
650 ///
651 /// // 使用任务 ID 取消
652 /// let cancelled = timer.cancel(task_id);
653 /// println!("取消成功: {}", cancelled);
654 /// }
655 /// ```
656 #[inline]
657 pub fn cancel(&self, task_id: TaskId) -> bool {
658 let mut wheel = self.wheel.lock();
659 wheel.cancel(task_id)
660 }
661
662 /// 批量取消定时器
663 ///
664 /// # 参数
665 /// - `task_ids`: 要取消的任务 ID 列表
666 ///
667 /// # 返回
668 /// 成功取消的任务数量
669 ///
670 /// # 性能优势
671 /// - 批量处理减少锁竞争
672 /// - 内部优化批量取消操作
673 ///
674 /// # 示例
675 /// ```no_run
676 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
677 /// use std::time::Duration;
678 ///
679 /// #[tokio::main]
680 /// async fn main() {
681 /// let timer = TimerWheel::with_defaults();
682 ///
683 /// // 创建多个定时器
684 /// let task1 = TimerWheel::create_task(Duration::from_secs(10), None);
685 /// let task2 = TimerWheel::create_task(Duration::from_secs(10), None);
686 /// let task3 = TimerWheel::create_task(Duration::from_secs(10), None);
687 ///
688 /// let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
689 ///
690 /// let _h1 = timer.register(task1);
691 /// let _h2 = timer.register(task2);
692 /// let _h3 = timer.register(task3);
693 ///
694 /// // 批量取消
695 /// let cancelled = timer.cancel_batch(&task_ids);
696 /// println!("已取消 {} 个定时器", cancelled);
697 /// }
698 /// ```
699 #[inline]
700 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
701 let mut wheel = self.wheel.lock();
702 wheel.cancel_batch(task_ids)
703 }
704
705 /// 推迟定时器
706 ///
707 /// # 参数
708 /// - `task_id`: 要推迟的任务 ID
709 /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
710 /// - `callback`: 新的回调函数,传入 `None` 保持原回调不变,传入 `Some` 替换为新回调
711 ///
712 /// # 返回
713 /// 如果任务存在且成功推迟返回 true,否则返回 false
714 ///
715 /// # 注意
716 /// - 推迟后任务 ID 保持不变
717 /// - 原有的 completion_receiver 仍然有效
718 ///
719 /// # 示例
720 ///
721 /// ## 保持原回调
722 /// ```no_run
723 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
724 /// use std::time::Duration;
725 ///
726 ///
727 /// #[tokio::main]
728 /// async fn main() {
729 /// let timer = TimerWheel::with_defaults();
730 ///
731 /// let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
732 /// println!("Timer fired!");
733 /// })));
734 /// let task_id = task.get_id();
735 /// let _handle = timer.register(task);
736 ///
737 /// // 推迟到 10 秒后触发(保持原回调)
738 /// let success = timer.postpone(task_id, Duration::from_secs(10), None);
739 /// println!("推迟成功: {}", success);
740 /// }
741 /// ```
742 ///
743 /// ## 替换为新回调
744 /// ```no_run
745 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
746 /// use std::time::Duration;
747 ///
748 /// #[tokio::main]
749 /// async fn main() {
750 /// let timer = TimerWheel::with_defaults();
751 ///
752 /// let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
753 /// println!("Original callback!");
754 /// })));
755 /// let task_id = task.get_id();
756 /// let _handle = timer.register(task);
757 ///
758 /// // 推迟到 10 秒后触发(并替换为新回调)
759 /// let success = timer.postpone(task_id, Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
760 /// println!("New callback!");
761 /// })));
762 /// println!("推迟成功: {}", success);
763 /// }
764 /// ```
765 #[inline]
766 pub fn postpone(
767 &self,
768 task_id: TaskId,
769 new_delay: Duration,
770 callback: Option<CallbackWrapper>,
771 ) -> bool {
772 let mut wheel = self.wheel.lock();
773 wheel.postpone(task_id, new_delay, callback)
774 }
775
776 /// 批量推迟定时器(保持原回调)
777 ///
778 /// # 参数
779 /// - `updates`: (任务ID, 新延迟) 的元组列表
780 ///
781 /// # 返回
782 /// 成功推迟的任务数量
783 ///
784 /// # 注意
785 /// - 此方法会保持所有任务的原回调不变
786 /// - 如需替换回调,请使用 `postpone_batch_with_callbacks`
787 ///
788 /// # 性能优势
789 /// - 批量处理减少锁竞争
790 /// - 内部优化批量推迟操作
791 ///
792 /// # 示例
793 /// ```no_run
794 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
795 /// use std::time::Duration;
796 ///
797 /// #[tokio::main]
798 /// async fn main() {
799 /// let timer = TimerWheel::with_defaults();
800 ///
801 /// // 创建多个带回调的定时器
802 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
803 /// println!("Task 1 fired!");
804 /// })));
805 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
806 /// println!("Task 2 fired!");
807 /// })));
808 /// let task3 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
809 /// println!("Task 3 fired!");
810 /// })));
811 ///
812 /// let task_ids = vec![
813 /// (task1.get_id(), Duration::from_secs(10)),
814 /// (task2.get_id(), Duration::from_secs(15)),
815 /// (task3.get_id(), Duration::from_secs(20)),
816 /// ];
817 ///
818 /// timer.register(task1);
819 /// timer.register(task2);
820 /// timer.register(task3);
821 ///
822 /// // 批量推迟(保持原回调)
823 /// let postponed = timer.postpone_batch(&task_ids);
824 /// println!("已推迟 {} 个定时器", postponed);
825 /// }
826 /// ```
827 #[inline]
828 pub fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
829 let mut wheel = self.wheel.lock();
830 wheel.postpone_batch(updates.to_vec())
831 }
832
833 /// 批量推迟定时器(替换回调)
834 ///
835 /// # 参数
836 /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
837 ///
838 /// # 返回
839 /// 成功推迟的任务数量
840 ///
841 /// # 性能优势
842 /// - 批量处理减少锁竞争
843 /// - 内部优化批量推迟操作
844 ///
845 /// # 示例
846 /// ```no_run
847 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
848 /// use std::time::Duration;
849 /// use std::sync::Arc;
850 /// use std::sync::atomic::{AtomicU32, Ordering};
851 ///
852 /// #[tokio::main]
853 /// async fn main() {
854 /// let timer = TimerWheel::with_defaults();
855 /// let counter = Arc::new(AtomicU32::new(0));
856 ///
857 /// // 创建多个定时器
858 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), None);
859 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), None);
860 ///
861 /// let id1 = task1.get_id();
862 /// let id2 = task2.get_id();
863 ///
864 /// timer.register(task1);
865 /// timer.register(task2);
866 ///
867 /// // 批量推迟并替换回调
868 /// let updates: Vec<_> = vec![id1, id2]
869 /// .into_iter()
870 /// .map(|id| {
871 /// let counter = Arc::clone(&counter);
872 /// (id, Duration::from_secs(10), Some(CallbackWrapper::new(move || {
873 /// let counter = Arc::clone(&counter);
874 /// async move { counter.fetch_add(1, Ordering::SeqCst); }
875 /// })))
876 /// })
877 /// .collect();
878 /// let postponed = timer.postpone_batch_with_callbacks(updates);
879 /// println!("已推迟 {} 个定时器", postponed);
880 /// }
881 /// ```
882 #[inline]
883 pub fn postpone_batch_with_callbacks(
884 &self,
885 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
886 ) -> usize {
887 let mut wheel = self.wheel.lock();
888 wheel.postpone_batch_with_callbacks(updates.to_vec())
889 }
890
891 /// 核心 tick 循环
892 async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
893 let mut interval = tokio::time::interval(tick_duration);
894 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
895
896 loop {
897 interval.tick().await;
898
899 // 推进时间轮并获取到期任务
900 let expired_tasks = {
901 let mut wheel_guard = wheel.lock();
902 wheel_guard.advance()
903 };
904
905 // 执行到期任务
906 for task in expired_tasks {
907 let callback = task.get_callback();
908
909 // 移动task的所有权来获取completion_notifier
910 let notifier = task.completion_notifier;
911
912 // 只有注册过的任务才有 notifier
913 if let Some(notifier) = notifier {
914 // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
915 if let Some(callback) = callback {
916 tokio::spawn(async move {
917 // 执行回调
918 let future = callback.call();
919 future.await;
920
921 // 回调执行完成后发送通知
922 let _ = notifier.0.send(TaskCompletionReason::Expired);
923 });
924 } else {
925 // 如果没有回调,立即发送完成通知
926 let _ = notifier.0.send(TaskCompletionReason::Expired);
927 }
928 }
929 }
930 }
931 }
932
933 /// 停止定时器管理器
934 pub async fn shutdown(mut self) {
935 if let Some(handle) = self.tick_handle.take() {
936 handle.abort();
937 let _ = handle.await;
938 }
939 }
940}
941
942impl Drop for TimerWheel {
943 fn drop(&mut self) {
944 if let Some(handle) = self.tick_handle.take() {
945 handle.abort();
946 }
947 }
948}
949
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A wheel with the default configuration can be created inside a runtime.
    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    /// A registered 50ms timer runs its callback exactly once.
    #[tokio::test]
    async fn test_schedule_once() {
        use std::sync::Arc;
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let _handle = timer.register(task);

        // Wait for the timer to fire
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// A timer cancelled through its handle never runs its callback.
    #[tokio::test]
    async fn test_cancel_timer() {
        use std::sync::Arc;
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let handle = timer.register(task);

        // Cancel immediately
        let cancel_result = handle.cancel();
        assert!(cancel_result);

        // Wait long enough to make sure the timer does not fire
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    /// Cancelling right after registration prevents the callback from running.
    ///
    /// NOTE(review): the body is currently identical to `test_cancel_timer`.
    #[tokio::test]
    async fn test_cancel_immediate() {
        use std::sync::Arc;
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let handle = timer.register(task);

        // Cancel immediately
        let cancel_result = handle.cancel();
        assert!(cancel_result);

        // Wait long enough to make sure the timer does not fire
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    /// Postponing with `None` keeps the original callback and delays firing.
    #[tokio::test]
    async fn test_postpone_timer() {
        use std::sync::Arc;
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task to 150ms
        let postponed = timer.postpone(task_id, Duration::from_millis(150), None);
        assert!(postponed);

        // Wait past the original 50ms deadline; the task must not have fired
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the new trigger time (counted from the postpone call, about 150ms in total)
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0
        ).await;
        assert!(result.is_ok());

        // Give the callback time to run
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Postponing with `Some(callback)` replaces the original callback.
    #[tokio::test]
    async fn test_postpone_with_callback() {
        use std::sync::Arc;
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone1 = Arc::clone(&counter);
        let counter_clone2 = Arc::clone(&counter);

        // Create the task; the original callback adds 1
        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone1);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task and replace the callback; the new callback adds 10
        let postponed = timer.postpone(
            task_id,
            Duration::from_millis(100),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone2);
                async move {
                    counter.fetch_add(10, Ordering::SeqCst);
                }
            })),
        );
        assert!(postponed);

        // Wait for the task to fire (100ms after postponing, plus some margin)
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0
        ).await;
        assert!(result.is_ok());

        // Give the callback time to run
        tokio::time::sleep(Duration::from_millis(20)).await;

        // Verify that the new callback ran (added 10 rather than 1)
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    /// Postponing a task that was never registered reports failure.
    #[tokio::test]
    async fn test_postpone_nonexistent_timer() {
        let timer = TimerWheel::with_defaults();

        // Try to postpone a task that does not exist on the wheel
        let fake_task = TimerWheel::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();
        // This task is deliberately not registered

        let postponed = timer.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed);
    }

    /// Batch postponing delays every task while keeping the original callbacks.
    #[tokio::test]
    async fn test_postpone_batch() {
        use std::sync::Arc;
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create 3 tasks
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            let task = TimerWheel::create_task(
                Duration::from_millis(50),
                Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })),
            );
            task_ids.push((task.get_id(), Duration::from_millis(150)));
            timer.register(task);
        }

        // Postpone in batch
        let postponed = timer.postpone_batch(&task_ids);
        assert_eq!(postponed, 3);

        // Wait past the original 50ms deadline; the tasks must not have fired
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the new trigger time (counted from the postpone call, about 150ms in total)
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Give the callbacks time to run
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// Batch postponing with callbacks installs the new callbacks on every task.
    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        use std::sync::Arc;
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create 3 tasks
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerWheel::create_task(
                Duration::from_millis(50),
                None
            );
            task_ids.push(task.get_id());
            timer.register(task);
        }

        // Postpone in batch and replace the callbacks
        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| {
                let counter_clone = Arc::clone(&counter);
                (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                })))
            })
            .collect();

        let postponed = timer.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3);

        // Wait past the original 50ms deadline; the tasks must not have fired
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the new trigger time (counted from the postpone call, about 150ms in total)
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Give the callbacks time to run
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// The completion receiver obtained at registration still resolves after
    /// the task has been postponed.
    #[tokio::test]
    async fn test_postpone_keeps_completion_receiver_valid() {
        use std::sync::Arc;
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            Some(CallbackWrapper::new(move || {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })),
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task
        timer.postpone(task_id, Duration::from_millis(100), None);

        // Verify the original completion receiver still works (100ms after postponing, plus some margin)
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0
        ).await;
        assert!(result.is_ok(), "Completion receiver should still work after postpone");

        // Give the callback time to run
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
1256