kestrel_protocol_timer/timer.rs
1use crate::config::{ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TaskCompletionReason};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
/// Receiving half of a timer's completion notification.
///
/// Thin wrapper around the `oneshot::Receiver` that yields a
/// [`TaskCompletionReason`]; the inner receiver is public and is awaited
/// through `.0`.
pub struct CompletionReceiver(pub oneshot::Receiver<TaskCompletionReason>);
12
/// Handle used to manage the lifecycle of a single registered timer.
///
/// Note: this type deliberately does not implement `Clone`, so the same timer
/// cannot be cancelled through several copies of its handle. Each timer should
/// have exactly one owner.
pub struct TimerHandle {
    /// ID of the task this handle refers to.
    pub(crate) task_id: TaskId,
    /// Shared wheel the handle operates on; locked when cancelling.
    pub(crate) wheel: Arc<Mutex<Wheel>>,
    /// Receiver for this task's completion notification.
    pub(crate) completion_rx: CompletionReceiver,
}
22
23impl TimerHandle {
24 pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<TaskCompletionReason>) -> Self {
25 Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26 }
27
28 /// 取消定时器
29 ///
30 /// # 返回
31 /// 如果任务存在且成功取消返回 true,否则返回 false
32 ///
33 /// # 示例
34 /// ```no_run
35 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
36 /// # use std::time::Duration;
37 /// #
38 /// # #[tokio::main]
39 /// # async fn main() {
40 /// let timer = TimerWheel::with_defaults();
41 /// let callback = Some(CallbackWrapper::new(|| async {}));
42 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
43 /// let handle = timer.register(task);
44 ///
45 /// // 取消定时器
46 /// let success = handle.cancel();
47 /// println!("取消成功: {}", success);
48 /// # }
49 /// ```
50 pub fn cancel(&self) -> bool {
51 let mut wheel = self.wheel.lock();
52 wheel.cancel(self.task_id)
53 }
54
55 /// 获取完成通知接收器的可变引用
56 ///
57 /// # 示例
58 /// ```no_run
59 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
60 /// # use std::time::Duration;
61 /// #
62 /// # #[tokio::main]
63 /// # async fn main() {
64 /// let timer = TimerWheel::with_defaults();
65 /// let callback = Some(CallbackWrapper::new(|| async {
66 /// println!("Timer fired!");
67 /// }));
68 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
69 /// let handle = timer.register(task);
70 ///
71 /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
72 /// handle.into_completion_receiver().0.await.ok();
73 /// println!("Timer completed!");
74 /// # }
75 /// ```
76 pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
77 &mut self.completion_rx
78 }
79
80 /// 消耗句柄,返回完成通知接收器
81 ///
82 /// # 示例
83 /// ```no_run
84 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
85 /// # use std::time::Duration;
86 /// #
87 /// # #[tokio::main]
88 /// # async fn main() {
89 /// let timer = TimerWheel::with_defaults();
90 /// let callback = Some(CallbackWrapper::new(|| async {
91 /// println!("Timer fired!");
92 /// }));
93 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
94 /// let handle = timer.register(task);
95 ///
96 /// // 等待定时器完成
97 /// handle.into_completion_receiver().0.await.ok();
98 /// println!("Timer completed!");
99 /// # }
100 /// ```
101 pub fn into_completion_receiver(self) -> CompletionReceiver {
102 self.completion_rx
103 }
104}
105
/// Handle for a batch of timers that were registered together.
///
/// A single shared `Wheel` reference serves the whole batch, which keeps the
/// memory overhead low while still offering batch operations and iteration.
///
/// Note: this type deliberately does not implement `Clone`, so the same batch
/// cannot be cancelled twice. To work with individual timers, convert it with
/// `into_iter()` or `into_handles()`.
pub struct BatchHandle {
    /// IDs of the tasks in this batch.
    pub(crate) task_ids: Vec<TaskId>,
    /// Shared wheel the handle operates on.
    pub(crate) wheel: Arc<Mutex<Wheel>>,
    /// Completion receivers, paired with `task_ids` by position.
    pub(crate) completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>,
}
117
118impl BatchHandle {
119 pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>) -> Self {
120 Self { task_ids, wheel, completion_rxs }
121 }
122
123 /// 批量取消所有定时器
124 ///
125 /// # 返回
126 /// 成功取消的任务数量
127 ///
128 /// # 示例
129 /// ```no_run
130 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
131 /// # use std::time::Duration;
132 /// #
133 /// # #[tokio::main]
134 /// # async fn main() {
135 /// let timer = TimerWheel::with_defaults();
136 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..10)
137 /// .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
138 /// .collect();
139 /// let tasks = TimerWheel::create_batch(callbacks);
140 /// let batch = timer.register_batch(tasks);
141 ///
142 /// let cancelled = batch.cancel_all();
143 /// println!("取消了 {} 个定时器", cancelled);
144 /// # }
145 /// ```
146 pub fn cancel_all(self) -> usize {
147 let mut wheel = self.wheel.lock();
148 wheel.cancel_batch(&self.task_ids)
149 }
150
151 /// 将批量句柄转换为单个定时器句柄的 Vec
152 ///
153 /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
154 ///
155 /// # 示例
156 /// ```no_run
157 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
158 /// # use std::time::Duration;
159 /// #
160 /// # #[tokio::main]
161 /// # async fn main() {
162 /// let timer = TimerWheel::with_defaults();
163 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
164 /// .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
165 /// .collect();
166 /// let tasks = TimerWheel::create_batch(callbacks);
167 /// let batch = timer.register_batch(tasks);
168 ///
169 /// // 转换为独立的句柄
170 /// let handles = batch.into_handles();
171 /// for handle in handles {
172 /// // 可以单独操作每个句柄
173 /// }
174 /// # }
175 /// ```
176 pub fn into_handles(self) -> Vec<TimerHandle> {
177 self.task_ids
178 .into_iter()
179 .zip(self.completion_rxs.into_iter())
180 .map(|(task_id, rx)| {
181 TimerHandle::new(task_id, self.wheel.clone(), rx)
182 })
183 .collect()
184 }
185
186 /// 获取批量任务的数量
187 pub fn len(&self) -> usize {
188 self.task_ids.len()
189 }
190
191 /// 检查批量任务是否为空
192 pub fn is_empty(&self) -> bool {
193 self.task_ids.is_empty()
194 }
195
196 /// 获取所有任务 ID 的引用
197 pub fn task_ids(&self) -> &[TaskId] {
198 &self.task_ids
199 }
200
201 /// 获取所有完成通知接收器的引用
202 ///
203 /// # 返回
204 /// 所有任务的完成通知接收器列表引用
205 pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<TaskCompletionReason>> {
206 &mut self.completion_rxs
207 }
208
209 /// 消耗句柄,返回所有完成通知接收器
210 ///
211 /// # 返回
212 /// 所有任务的完成通知接收器列表
213 ///
214 /// # 示例
215 /// ```no_run
216 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
217 /// # use std::time::Duration;
218 /// #
219 /// # #[tokio::main]
220 /// # async fn main() {
221 /// let timer = TimerWheel::with_defaults();
222 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
223 /// .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
224 /// .collect();
225 /// let tasks = TimerWheel::create_batch(callbacks);
226 /// let batch = timer.register_batch(tasks);
227 ///
228 /// // 获取所有完成通知接收器
229 /// let receivers = batch.into_completion_receivers();
230 /// for rx in receivers {
231 /// tokio::spawn(async move {
232 /// if rx.await.is_ok() {
233 /// println!("A timer completed!");
234 /// }
235 /// });
236 /// }
237 /// # }
238 /// ```
239 pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<TaskCompletionReason>> {
240 self.completion_rxs
241 }
242}
243
244/// 实现 IntoIterator,允许直接迭代 BatchHandle
245///
246/// # 示例
247/// ```no_run
248/// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
249/// # use std::time::Duration;
250/// #
251/// # #[tokio::main]
252/// # async fn main() {
253/// let timer = TimerWheel::with_defaults();
254/// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
255/// .map(|_| (Duration::from_secs(1), Some(CallbackWrapper::new(|| async {}))))
256/// .collect();
257/// let tasks = TimerWheel::create_batch(callbacks);
258/// let batch = timer.register_batch(tasks);
259///
260/// // 直接迭代,每个元素都是独立的 TimerHandle
261/// for handle in batch {
262/// // 可以单独操作每个句柄
263/// }
264/// # }
265/// ```
266impl IntoIterator for BatchHandle {
267 type Item = TimerHandle;
268 type IntoIter = BatchHandleIter;
269
270 fn into_iter(self) -> Self::IntoIter {
271 BatchHandleIter {
272 task_ids: self.task_ids.into_iter(),
273 completion_rxs: self.completion_rxs.into_iter(),
274 wheel: self.wheel,
275 }
276 }
277}
278
/// Owning iterator over a `BatchHandle`, yielding one `TimerHandle` per task.
pub struct BatchHandleIter {
    /// Remaining task IDs.
    task_ids: std::vec::IntoIter<TaskId>,
    /// Remaining completion receivers, consumed in step with `task_ids`.
    completion_rxs: std::vec::IntoIter<oneshot::Receiver<TaskCompletionReason>>,
    /// Shared wheel cloned into every handle produced.
    wheel: Arc<Mutex<Wheel>>,
}
285
286impl Iterator for BatchHandleIter {
287 type Item = TimerHandle;
288
289 fn next(&mut self) -> Option<Self::Item> {
290 match (self.task_ids.next(), self.completion_rxs.next()) {
291 (Some(task_id), Some(rx)) => {
292 Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
293 }
294 _ => None,
295 }
296 }
297
298 fn size_hint(&self) -> (usize, Option<usize>) {
299 self.task_ids.size_hint()
300 }
301}
302
303impl ExactSizeIterator for BatchHandleIter {
304 fn len(&self) -> usize {
305 self.task_ids.len()
306 }
307}
308
/// Timing-wheel based timer manager.
pub struct TimerWheel {
    /// The wheel instance, wrapped in `Arc<Mutex<_>>` so it can be shared
    /// across threads.
    wheel: Arc<Mutex<Wheel>>,

    /// Join handle of the background tick-loop task.
    tick_handle: Option<JoinHandle<()>>,
}
319
320impl TimerWheel {
321 /// 创建新的定时器管理器
322 ///
323 /// # 参数
324 /// - `config`: 时间轮配置(已经过验证)
325 ///
326 /// # 示例
327 /// ```no_run
328 /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask};
329 /// use std::time::Duration;
330 ///
331 /// #[tokio::main]
332 /// async fn main() {
333 /// let config = WheelConfig::builder()
334 /// .tick_duration(Duration::from_millis(10))
335 /// .slot_count(512)
336 /// .build()
337 /// .unwrap();
338 /// let timer = TimerWheel::new(config);
339 ///
340 /// // 使用两步式 API
341 /// let task = TimerWheel::create_task(Duration::from_secs(1), None);
342 /// let handle = timer.register(task);
343 /// }
344 /// ```
345 pub fn new(config: WheelConfig) -> Self {
346 let tick_duration = config.tick_duration;
347 let wheel = Wheel::new(config);
348 let wheel = Arc::new(Mutex::new(wheel));
349 let wheel_clone = wheel.clone();
350
351 // 启动后台 tick 循环
352 let tick_handle = tokio::spawn(async move {
353 Self::tick_loop(wheel_clone, tick_duration).await;
354 });
355
356 Self {
357 wheel,
358 tick_handle: Some(tick_handle),
359 }
360 }
361
362 /// 创建带默认配置的定时器管理器
363 /// - tick 时长: 10ms
364 /// - 槽位数量: 512
365 ///
366 /// # 示例
367 /// ```no_run
368 /// use kestrel_protocol_timer::TimerWheel;
369 ///
370 /// #[tokio::main]
371 /// async fn main() {
372 /// let timer = TimerWheel::with_defaults();
373 /// }
374 /// ```
375 pub fn with_defaults() -> Self {
376 Self::new(WheelConfig::default())
377 }
378
379 /// 创建与此时间轮绑定的 TimerService(使用默认配置)
380 ///
381 /// # 返回
382 /// 绑定到此时间轮的 TimerService 实例
383 ///
384 /// # 示例
385 /// ```no_run
386 /// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper};
387 /// use std::time::Duration;
388 ///
389 ///
390 /// #[tokio::main]
391 /// async fn main() {
392 /// let timer = TimerWheel::with_defaults();
393 /// let mut service = timer.create_service();
394 ///
395 /// // 使用两步式 API 通过 service 批量调度定时器
396 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
397 /// .map(|_| (Duration::from_millis(100), Some(CallbackWrapper::new(|| async {}))))
398 /// .collect();
399 /// let tasks = TimerService::create_batch(callbacks);
400 /// service.register_batch(tasks).unwrap();
401 ///
402 /// // 接收超时通知
403 /// let mut rx = service.take_receiver().unwrap();
404 /// while let Some(task_id) = rx.recv().await {
405 /// println!("Task {:?} completed", task_id);
406 /// }
407 /// }
408 /// ```
409 pub fn create_service(&self) -> crate::service::TimerService {
410 crate::service::TimerService::new(self.wheel.clone(), ServiceConfig::default())
411 }
412
413 /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
414 ///
415 /// # 参数
416 /// - `config`: 服务配置
417 ///
418 /// # 返回
419 /// 绑定到此时间轮的 TimerService 实例
420 ///
421 /// # 示例
422 /// ```no_run
423 /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
424 ///
425 /// #[tokio::main]
426 /// async fn main() {
427 /// let timer = TimerWheel::with_defaults();
428 /// let config = ServiceConfig::builder()
429 /// .command_channel_capacity(1024)
430 /// .timeout_channel_capacity(2000)
431 /// .build()
432 /// .unwrap();
433 /// let service = timer.create_service_with_config(config);
434 /// }
435 /// ```
436 pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
437 crate::service::TimerService::new(self.wheel.clone(), config)
438 }
439
440 /// 创建定时器任务(申请阶段)
441 ///
442 /// # 参数
443 /// - `delay`: 延迟时间
444 /// - `callback`: 实现了 TimerCallback trait 的回调对象
445 ///
446 /// # 返回
447 /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
448 ///
449 /// # 示例
450 /// ```no_run
451 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
452 /// use std::time::Duration;
453 ///
454 ///
455 /// #[tokio::main]
456 /// async fn main() {
457 /// let timer = TimerWheel::with_defaults();
458 ///
459 /// // 步骤 1: 创建任务
460 /// let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
461 /// println!("Timer fired!");
462 /// })));
463 ///
464 /// // 获取任务 ID
465 /// let task_id = task.get_id();
466 /// println!("Created task: {:?}", task_id);
467 ///
468 /// // 步骤 2: 注册任务
469 /// let handle = timer.register(task);
470 /// }
471 /// ```
472 #[inline]
473 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
474 crate::task::TimerTask::new(delay, callback)
475 }
476
477 /// 批量创建定时器任务(申请阶段)
478 ///
479 /// # 参数
480 /// - `callbacks`: (延迟时间, 回调) 的元组列表
481 ///
482 /// # 返回
483 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
484 ///
485 /// # 示例
486 /// ```no_run
487 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
488 /// use std::time::Duration;
489 /// use std::sync::Arc;
490 /// use std::sync::atomic::{AtomicU32, Ordering};
491 ///
492 /// #[tokio::main]
493 /// async fn main() {
494 /// let timer = TimerWheel::with_defaults();
495 /// let counter = Arc::new(AtomicU32::new(0));
496 ///
497 /// // 步骤 1: 批量创建任务
498 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..3)
499 /// .map(|i| {
500 /// let counter = Arc::clone(&counter);
501 /// let delay = Duration::from_millis(100 + i * 100);
502 /// let callback = Some(CallbackWrapper::new(move || {
503 /// let counter = Arc::clone(&counter);
504 /// async move {
505 /// counter.fetch_add(1, Ordering::SeqCst);
506 /// }
507 /// }));
508 /// (delay, callback)
509 /// })
510 /// .collect();
511 ///
512 /// let tasks = TimerWheel::create_batch(callbacks);
513 /// println!("Created {} tasks", tasks.len());
514 ///
515 /// // 步骤 2: 批量注册任务
516 /// let batch = timer.register_batch(tasks);
517 /// }
518 /// ```
519 #[inline]
520 pub fn create_batch(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask>
521 {
522 callbacks
523 .into_iter()
524 .map(|(delay, callback)| crate::task::TimerTask::new(delay, callback))
525 .collect()
526 }
527
528 /// 注册定时器任务到时间轮(注册阶段)
529 ///
530 /// # 参数
531 /// - `task`: 通过 `create_task()` 创建的任务
532 ///
533 /// # 返回
534 /// 返回定时器句柄,可用于取消定时器和接收完成通知
535 ///
536 /// # 示例
537 /// ```no_run
538 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
539 ///
540 /// use std::time::Duration;
541 ///
542 /// #[tokio::main]
543 /// async fn main() {
544 /// let timer = TimerWheel::with_defaults();
545 ///
546 /// let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
547 /// println!("Timer fired!");
548 /// })));
549 /// let task_id = task.get_id();
550 ///
551 /// let handle = timer.register(task);
552 ///
553 /// // 等待定时器完成
554 /// handle.into_completion_receiver().0.await.ok();
555 /// }
556 /// ```
557 #[inline]
558 pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
559 let (completion_tx, completion_rx) = oneshot::channel();
560 let notifier = crate::task::CompletionNotifier(completion_tx);
561
562 let delay = task.delay;
563 let task_id = task.id;
564
565 // 单次加锁完成所有操作
566 {
567 let mut wheel_guard = self.wheel.lock();
568 wheel_guard.insert(delay, task, notifier);
569 }
570
571 TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
572 }
573
574 /// 批量注册定时器任务到时间轮(注册阶段)
575 ///
576 /// # 参数
577 /// - `tasks`: 通过 `create_batch()` 创建的任务列表
578 ///
579 /// # 返回
580 /// 返回批量定时器句柄
581 ///
582 /// # 示例
583 /// ```no_run
584 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
585 /// use std::time::Duration;
586 ///
587 /// #[tokio::main]
588 /// async fn main() {
589 /// let timer = TimerWheel::with_defaults();
590 ///
591 /// let callbacks: Vec<_> = (0..3)
592 /// .map(|_| (Duration::from_secs(1), None))
593 /// .collect();
594 /// let tasks = TimerWheel::create_batch(callbacks);
595 ///
596 /// let batch = timer.register_batch(tasks);
597 /// println!("Registered {} timers", batch.len());
598 /// }
599 /// ```
600 #[inline]
601 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
602 let task_count = tasks.len();
603 let mut completion_rxs = Vec::with_capacity(task_count);
604 let mut task_ids = Vec::with_capacity(task_count);
605 let mut prepared_tasks = Vec::with_capacity(task_count);
606
607 // 步骤1: 准备所有 channels 和 notifiers(无锁)
608 // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
609 for task in tasks {
610 let (completion_tx, completion_rx) = oneshot::channel();
611 let notifier = crate::task::CompletionNotifier(completion_tx);
612
613 task_ids.push(task.id);
614 completion_rxs.push(completion_rx);
615 prepared_tasks.push((task.delay, task, notifier));
616 }
617
618 // 步骤2: 单次加锁,批量插入
619 {
620 let mut wheel_guard = self.wheel.lock();
621 wheel_guard.insert_batch(prepared_tasks);
622 }
623
624 BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
625 }
626
627 /// 取消定时器
628 ///
629 /// # 参数
630 /// - `task_id`: 任务 ID
631 ///
632 /// # 返回
633 /// 如果任务存在且成功取消返回 true,否则返回 false
634 ///
635 /// # 示例
636 /// ```no_run
637 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
638 ///
639 /// use std::time::Duration;
640 ///
641 /// #[tokio::main]
642 /// async fn main() {
643 /// let timer = TimerWheel::with_defaults();
644 ///
645 /// let task = TimerWheel::create_task(Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
646 /// println!("Timer fired!");
647 /// })));
648 /// let task_id = task.get_id();
649 /// let _handle = timer.register(task);
650 ///
651 /// // 使用任务 ID 取消
652 /// let cancelled = timer.cancel(task_id);
653 /// println!("取消成功: {}", cancelled);
654 /// }
655 /// ```
656 #[inline]
657 pub fn cancel(&self, task_id: TaskId) -> bool {
658 let mut wheel = self.wheel.lock();
659 wheel.cancel(task_id)
660 }
661
662 /// 批量取消定时器
663 ///
664 /// # 参数
665 /// - `task_ids`: 要取消的任务 ID 列表
666 ///
667 /// # 返回
668 /// 成功取消的任务数量
669 ///
670 /// # 性能优势
671 /// - 批量处理减少锁竞争
672 /// - 内部优化批量取消操作
673 ///
674 /// # 示例
675 /// ```no_run
676 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
677 /// use std::time::Duration;
678 ///
679 /// #[tokio::main]
680 /// async fn main() {
681 /// let timer = TimerWheel::with_defaults();
682 ///
683 /// // 创建多个定时器
684 /// let task1 = TimerWheel::create_task(Duration::from_secs(10), None);
685 /// let task2 = TimerWheel::create_task(Duration::from_secs(10), None);
686 /// let task3 = TimerWheel::create_task(Duration::from_secs(10), None);
687 ///
688 /// let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
689 ///
690 /// let _h1 = timer.register(task1);
691 /// let _h2 = timer.register(task2);
692 /// let _h3 = timer.register(task3);
693 ///
694 /// // 批量取消
695 /// let cancelled = timer.cancel_batch(&task_ids);
696 /// println!("已取消 {} 个定时器", cancelled);
697 /// }
698 /// ```
699 #[inline]
700 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
701 let mut wheel = self.wheel.lock();
702 wheel.cancel_batch(task_ids)
703 }
704
705 /// 推迟定时器(保持原回调)
706 ///
707 /// # 参数
708 /// - `task_id`: 要推迟的任务 ID
709 /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
710 ///
711 /// # 返回
712 /// 如果任务存在且成功推迟返回 true,否则返回 false
713 ///
714 /// # 注意
715 /// - 推迟后任务 ID 保持不变
716 /// - 原有的 completion_receiver 仍然有效
717 /// - 保持原回调函数不变
718 ///
719 /// # 示例
720 /// ```no_run
721 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
722 /// use std::time::Duration;
723 ///
724 ///
725 /// #[tokio::main]
726 /// async fn main() {
727 /// let timer = TimerWheel::with_defaults();
728 ///
729 /// let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
730 /// println!("Timer fired!");
731 /// })));
732 /// let task_id = task.get_id();
733 /// let _handle = timer.register(task);
734 ///
735 /// // 推迟到 10 秒后触发
736 /// let success = timer.postpone(task_id, Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
737 /// println!("Timer fired!");
738 /// })));
739 /// println!("推迟成功: {}", success);
740 /// }
741 /// ```
742 #[inline]
743 pub fn postpone(
744 &self,
745 task_id: TaskId,
746 new_delay: Duration,
747 callback: Option<CallbackWrapper>,
748 ) -> bool {
749 let mut wheel = self.wheel.lock();
750 wheel.postpone(task_id, new_delay, callback)
751 }
752
753 /// 批量推迟定时器(保持原回调)
754 ///
755 /// # 参数
756 /// - `updates`: (任务ID, 新延迟) 的元组列表
757 ///
758 /// # 返回
759 /// 成功推迟的任务数量
760 ///
761 /// # 性能优势
762 /// - 批量处理减少锁竞争
763 /// - 内部优化批量推迟操作
764 ///
765 /// # 示例
766 /// ```no_run
767 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
768 /// use std::time::Duration;
769 ///
770 /// #[tokio::main]
771 /// async fn main() {
772 /// let timer = TimerWheel::with_defaults();
773 ///
774 /// // 创建多个定时器
775 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), None);
776 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), None);
777 /// let task3 = TimerWheel::create_task(Duration::from_secs(5), None);
778 ///
779 /// let task_ids = vec![
780 /// (task1.get_id(), Duration::from_secs(10)),
781 /// (task2.get_id(), Duration::from_secs(15)),
782 /// (task3.get_id(), Duration::from_secs(20)),
783 /// ];
784 ///
785 /// timer.register(task1);
786 /// timer.register(task2);
787 /// timer.register(task3);
788 ///
789 /// // 批量推迟
790 /// let postponed = timer.postpone_batch(&task_ids);
791 /// println!("已推迟 {} 个定时器", postponed);
792 /// }
793 /// ```
794 #[inline]
795 pub fn postpone_batch(&self, updates: &[(TaskId, Duration)]) -> usize {
796 let mut wheel = self.wheel.lock();
797 wheel.postpone_batch(updates.to_vec())
798 }
799
800 /// 批量推迟定时器(替换回调)
801 ///
802 /// # 参数
803 /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
804 ///
805 /// # 返回
806 /// 成功推迟的任务数量
807 ///
808 /// # 性能优势
809 /// - 批量处理减少锁竞争
810 /// - 内部优化批量推迟操作
811 ///
812 /// # 示例
813 /// ```no_run
814 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
815 /// use std::time::Duration;
816 /// use std::sync::Arc;
817 /// use std::sync::atomic::{AtomicU32, Ordering};
818 ///
819 /// #[tokio::main]
820 /// async fn main() {
821 /// let timer = TimerWheel::with_defaults();
822 /// let counter = Arc::new(AtomicU32::new(0));
823 ///
824 /// // 创建多个定时器
825 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), None);
826 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), None);
827 ///
828 /// let id1 = task1.get_id();
829 /// let id2 = task2.get_id();
830 ///
831 /// timer.register(task1);
832 /// timer.register(task2);
833 ///
834 /// // 批量推迟并替换回调
835 /// let updates: Vec<_> = vec![id1, id2]
836 /// .into_iter()
837 /// .map(|id| {
838 /// let counter = Arc::clone(&counter);
839 /// (id, Duration::from_secs(10), Some(CallbackWrapper::new(move || {
840 /// let counter = Arc::clone(&counter);
841 /// async move { counter.fetch_add(1, Ordering::SeqCst); }
842 /// })))
843 /// })
844 /// .collect();
845 /// let postponed = timer.postpone_batch_with_callbacks(updates);
846 /// println!("已推迟 {} 个定时器", postponed);
847 /// }
848 /// ```
849 #[inline]
850 pub fn postpone_batch_with_callbacks(
851 &self,
852 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
853 ) -> usize {
854 let mut wheel = self.wheel.lock();
855 wheel.postpone_batch_with_callbacks(updates.to_vec())
856 }
857
858 /// 核心 tick 循环
859 async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
860 let mut interval = tokio::time::interval(tick_duration);
861 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
862
863 loop {
864 interval.tick().await;
865
866 // 推进时间轮并获取到期任务
867 let expired_tasks = {
868 let mut wheel_guard = wheel.lock();
869 wheel_guard.advance()
870 };
871
872 // 执行到期任务
873 for task in expired_tasks {
874 let callback = task.get_callback();
875
876 // 移动task的所有权来获取completion_notifier
877 let notifier = task.completion_notifier;
878
879 // 只有注册过的任务才有 notifier
880 if let Some(notifier) = notifier {
881 // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
882 if let Some(callback) = callback {
883 tokio::spawn(async move {
884 // 执行回调
885 let future = callback.call();
886 future.await;
887
888 // 回调执行完成后发送通知
889 let _ = notifier.0.send(TaskCompletionReason::Expired);
890 });
891 } else {
892 // 如果没有回调,立即发送完成通知
893 let _ = notifier.0.send(TaskCompletionReason::Expired);
894 }
895 }
896 }
897 }
898 }
899
900 /// 停止定时器管理器
901 pub async fn shutdown(mut self) {
902 if let Some(handle) = self.tick_handle.take() {
903 handle.abort();
904 let _ = handle.await;
905 }
906 }
907}
908
909impl Drop for TimerWheel {
910 fn drop(&mut self) {
911 if let Some(handle) = self.tick_handle.take() {
912 handle.abort();
913 }
914 }
915}
916
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a callback that adds `step` to `counter` each time it fires.
    fn counting_callback(counter: &Arc<AtomicU32>, step: u32) -> Option<CallbackWrapper> {
        let counter = Arc::clone(counter);
        Some(CallbackWrapper::new(move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(step, Ordering::SeqCst);
            }
        }))
    }

    #[tokio::test]
    async fn test_timer_creation() {
        let _timer = TimerWheel::with_defaults();
    }

    #[tokio::test]
    async fn test_schedule_once() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            counting_callback(&counter, 1),
        );
        let _handle = timer.register(task);

        // Give the timer time to fire.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_cancel_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            counting_callback(&counter, 1),
        );
        let handle = timer.register(task);

        // Cancel right away.
        assert!(handle.cancel());

        // Wait long enough to be sure the timer does not fire.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    // NOTE(review): currently exercises exactly the same scenario as
    // `test_cancel_timer`; kept so the test list is unchanged.
    #[tokio::test]
    async fn test_cancel_immediate() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerWheel::create_task(
            Duration::from_millis(100),
            counting_callback(&counter, 1),
        );
        let handle = timer.register(task);

        // Cancel right away.
        assert!(handle.cancel());

        // Wait long enough to be sure the timer does not fire.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_postpone_timer() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            counting_callback(&counter, 1),
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task to 150 ms, keeping its callback.
        assert!(timer.postpone(task_id, Duration::from_millis(150), None));

        // Past the original 50 ms deadline the task must not have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the new deadline (about 150 ms after the postpone).
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(result.is_ok());

        // Leave some slack for the callback to run.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_postpone_with_callback() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // The original callback adds 1.
        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            counting_callback(&counter, 1),
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone and replace the callback; the new one adds 10.
        let postponed = timer.postpone(
            task_id,
            Duration::from_millis(100),
            counting_callback(&counter, 10),
        );
        assert!(postponed);

        // Wait for the task to fire (100 ms after the postpone, plus slack).
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(result.is_ok());

        // Leave some slack for the callback to run.
        tokio::time::sleep(Duration::from_millis(20)).await;

        // The new callback ran (10 was added, not 1).
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_postpone_nonexistent_timer() {
        let timer = TimerWheel::with_defaults();

        // Create a task but never register it, so its ID is unknown to the wheel.
        let fake_task = TimerWheel::create_task(Duration::from_millis(50), None);
        let fake_task_id = fake_task.get_id();

        let postponed = timer.postpone(fake_task_id, Duration::from_millis(100), None);
        assert!(!postponed);
    }

    #[tokio::test]
    async fn test_postpone_batch() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create and register 3 tasks, recording the postponement for each.
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerWheel::create_task(
                Duration::from_millis(50),
                counting_callback(&counter, 1),
            );
            task_ids.push((task.get_id(), Duration::from_millis(150)));
            timer.register(task);
        }

        // Postpone all of them in one call.
        let postponed = timer.postpone_batch(&task_ids);
        assert_eq!(postponed, 3);

        // Past the original 50 ms deadline nothing must have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the new deadline (about 150 ms after the postpone).
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Leave some slack for the callbacks to run.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_batch_with_callbacks() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        // Create and register 3 tasks without callbacks.
        let mut task_ids = Vec::new();
        for _ in 0..3 {
            let task = TimerWheel::create_task(Duration::from_millis(50), None);
            task_ids.push(task.get_id());
            timer.register(task);
        }

        // Postpone all of them and attach a counting callback to each.
        let updates: Vec<_> = task_ids
            .into_iter()
            .map(|id| {
                (
                    id,
                    Duration::from_millis(150),
                    counting_callback(&counter, 1),
                )
            })
            .collect();

        let postponed = timer.postpone_batch_with_callbacks(updates);
        assert_eq!(postponed, 3);

        // Past the original 50 ms deadline nothing must have fired.
        tokio::time::sleep(Duration::from_millis(70)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait for the new deadline (about 150 ms after the postpone).
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Leave some slack for the callbacks to run.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_postpone_keeps_completion_receiver_valid() {
        let timer = TimerWheel::with_defaults();
        let counter = Arc::new(AtomicU32::new(0));

        let task = TimerWheel::create_task(
            Duration::from_millis(50),
            counting_callback(&counter, 1),
        );
        let task_id = task.get_id();
        let handle = timer.register(task);

        // Postpone the task.
        timer.postpone(task_id, Duration::from_millis(100), None);

        // The original receiver must still deliver the completion
        // (100 ms after the postpone, plus slack).
        let result = tokio::time::timeout(
            Duration::from_millis(200),
            handle.into_completion_receiver().0,
        )
        .await;
        assert!(result.is_ok(), "Completion receiver should still work after postpone");

        // Leave some slack for the callback to run.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}
1223