kestrel_protocol_timer/timer.rs
1use crate::config::{BatchConfig, ServiceConfig, WheelConfig};
2use crate::task::{CallbackWrapper, TaskId, TaskCompletionReason};
3use crate::wheel::Wheel;
4use parking_lot::Mutex;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::oneshot;
8use tokio::task::JoinHandle;
9
10/// 完成通知接收器,用于接收定时器完成通知
11pub struct CompletionReceiver(pub oneshot::Receiver<TaskCompletionReason>);
12
13/// 定时器句柄,用于管理定时器的生命周期
14///
15/// 注意:此类型不实现 Clone,以防止重复取消同一个定时器。
16/// 每个定时器只应有一个所有者。
17pub struct TimerHandle {
18 pub(crate) task_id: TaskId,
19 pub(crate) wheel: Arc<Mutex<Wheel>>,
20 pub(crate) completion_rx: CompletionReceiver,
21}
22
23impl TimerHandle {
24 pub(crate) fn new(task_id: TaskId, wheel: Arc<Mutex<Wheel>>, completion_rx: oneshot::Receiver<TaskCompletionReason>) -> Self {
25 Self { task_id, wheel, completion_rx: CompletionReceiver(completion_rx) }
26 }
27
28 /// 取消定时器
29 ///
30 /// # 返回
31 /// 如果任务存在且成功取消返回 true,否则返回 false
32 ///
33 /// # 示例
34 /// ```no_run
35 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
36 /// # use std::time::Duration;
37 /// #
38 /// # #[tokio::main]
39 /// # async fn main() {
40 /// let timer = TimerWheel::with_defaults();
41 /// let callback = Some(CallbackWrapper::new(|| async {}));
42 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
43 /// let handle = timer.register(task);
44 ///
45 /// // 取消定时器
46 /// let success = handle.cancel();
47 /// println!("取消成功: {}", success);
48 /// # }
49 /// ```
50 pub fn cancel(&self) -> bool {
51 let mut wheel = self.wheel.lock();
52 wheel.cancel(self.task_id)
53 }
54
55 /// 获取完成通知接收器的可变引用
56 ///
57 /// # 示例
58 /// ```no_run
59 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
60 /// # use std::time::Duration;
61 /// #
62 /// # #[tokio::main]
63 /// # async fn main() {
64 /// let timer = TimerWheel::with_defaults();
65 /// let callback = Some(CallbackWrapper::new(|| async {
66 /// println!("Timer fired!");
67 /// }));
68 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
69 /// let handle = timer.register(task);
70 ///
71 /// // 等待定时器完成(使用 into_completion_receiver 消耗句柄)
72 /// handle.into_completion_receiver().0.await.ok();
73 /// println!("Timer completed!");
74 /// # }
75 /// ```
76 pub fn completion_receiver(&mut self) -> &mut CompletionReceiver {
77 &mut self.completion_rx
78 }
79
80 /// 消耗句柄,返回完成通知接收器
81 ///
82 /// # 示例
83 /// ```no_run
84 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
85 /// # use std::time::Duration;
86 /// #
87 /// # #[tokio::main]
88 /// # async fn main() {
89 /// let timer = TimerWheel::with_defaults();
90 /// let callback = Some(CallbackWrapper::new(|| async {
91 /// println!("Timer fired!");
92 /// }));
93 /// let task = TimerWheel::create_task(Duration::from_secs(1), callback);
94 /// let handle = timer.register(task);
95 ///
96 /// // 等待定时器完成
97 /// handle.into_completion_receiver().0.await.ok();
98 /// println!("Timer completed!");
99 /// # }
100 /// ```
101 pub fn into_completion_receiver(self) -> CompletionReceiver {
102 self.completion_rx
103 }
104}
105
106/// 批量定时器句柄,用于管理批量调度的定时器
107///
108/// 通过共享 Wheel 引用减少内存开销,同时提供批量操作和迭代器访问能力。
109///
110/// 注意:此类型不实现 Clone,以防止重复取消同一批定时器。
111/// 如需访问单个定时器句柄,请使用 `into_iter()` 或 `into_handles()` 进行转换。
112pub struct BatchHandle {
113 pub(crate) task_ids: Vec<TaskId>,
114 pub(crate) wheel: Arc<Mutex<Wheel>>,
115 pub(crate) completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>,
116}
117
118impl BatchHandle {
119 pub(crate) fn new(task_ids: Vec<TaskId>, wheel: Arc<Mutex<Wheel>>, completion_rxs: Vec<oneshot::Receiver<TaskCompletionReason>>) -> Self {
120 Self { task_ids, wheel, completion_rxs }
121 }
122
123 /// 批量取消所有定时器
124 ///
125 /// # 返回
126 /// 成功取消的任务数量
127 ///
128 /// # 示例
129 /// ```no_run
130 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
131 /// # use std::time::Duration;
132 /// #
133 /// # #[tokio::main]
134 /// # async fn main() {
135 /// let timer = TimerWheel::with_defaults();
136 /// let delays: Vec<Duration> = (0..10)
137 /// .map(|_| Duration::from_secs(1))
138 /// .collect();
139 /// let tasks = TimerWheel::create_batch(delays);
140 /// let batch = timer.register_batch(tasks);
141 ///
142 /// let cancelled = batch.cancel_all();
143 /// println!("取消了 {} 个定时器", cancelled);
144 /// # }
145 /// ```
146 pub fn cancel_all(self) -> usize {
147 let mut wheel = self.wheel.lock();
148 wheel.cancel_batch(&self.task_ids)
149 }
150
151 /// 将批量句柄转换为单个定时器句柄的 Vec
152 ///
153 /// 消耗 BatchHandle,为每个任务创建独立的 TimerHandle。
154 ///
155 /// # 示例
156 /// ```no_run
157 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
158 /// # use std::time::Duration;
159 /// #
160 /// # #[tokio::main]
161 /// # async fn main() {
162 /// let timer = TimerWheel::with_defaults();
163 /// let delays: Vec<Duration> = (0..3)
164 /// .map(|_| Duration::from_secs(1))
165 /// .collect();
166 /// let tasks = TimerWheel::create_batch(delays);
167 /// let batch = timer.register_batch(tasks);
168 ///
169 /// // 转换为独立的句柄
170 /// let handles = batch.into_handles();
171 /// for handle in handles {
172 /// // 可以单独操作每个句柄
173 /// }
174 /// # }
175 /// ```
176 pub fn into_handles(self) -> Vec<TimerHandle> {
177 self.task_ids
178 .into_iter()
179 .zip(self.completion_rxs.into_iter())
180 .map(|(task_id, rx)| {
181 TimerHandle::new(task_id, self.wheel.clone(), rx)
182 })
183 .collect()
184 }
185
186 /// 获取批量任务的数量
187 pub fn len(&self) -> usize {
188 self.task_ids.len()
189 }
190
191 /// 检查批量任务是否为空
192 pub fn is_empty(&self) -> bool {
193 self.task_ids.is_empty()
194 }
195
196 /// 获取所有任务 ID 的引用
197 pub fn task_ids(&self) -> &[TaskId] {
198 &self.task_ids
199 }
200
201 /// 获取所有完成通知接收器的引用
202 ///
203 /// # 返回
204 /// 所有任务的完成通知接收器列表引用
205 pub fn completion_receivers(&mut self) -> &mut Vec<oneshot::Receiver<TaskCompletionReason>> {
206 &mut self.completion_rxs
207 }
208
209 /// 消耗句柄,返回所有完成通知接收器
210 ///
211 /// # 返回
212 /// 所有任务的完成通知接收器列表
213 ///
214 /// # 示例
215 /// ```no_run
216 /// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
217 /// # use std::time::Duration;
218 /// #
219 /// # #[tokio::main]
220 /// # async fn main() {
221 /// let timer = TimerWheel::with_defaults();
222 /// let delays: Vec<Duration> = (0..3)
223 /// .map(|_| Duration::from_secs(1))
224 /// .collect();
225 /// let tasks = TimerWheel::create_batch(delays);
226 /// let batch = timer.register_batch(tasks);
227 ///
228 /// // 获取所有完成通知接收器
229 /// let receivers = batch.into_completion_receivers();
230 /// for rx in receivers {
231 /// tokio::spawn(async move {
232 /// if rx.await.is_ok() {
233 /// println!("A timer completed!");
234 /// }
235 /// });
236 /// }
237 /// # }
238 /// ```
239 pub fn into_completion_receivers(self) -> Vec<oneshot::Receiver<TaskCompletionReason>> {
240 self.completion_rxs
241 }
242}
243
244/// 实现 IntoIterator,允许直接迭代 BatchHandle
245///
246/// # 示例
247/// ```no_run
248/// # use kestrel_protocol_timer::{TimerWheel, CallbackWrapper};
249/// # use std::time::Duration;
250/// #
251/// # #[tokio::main]
252/// # async fn main() {
253/// let timer = TimerWheel::with_defaults();
254/// let delays: Vec<Duration> = (0..3)
255/// .map(|_| Duration::from_secs(1))
256/// .collect();
257/// let tasks = TimerWheel::create_batch(delays);
258/// let batch = timer.register_batch(tasks);
259///
260/// // 直接迭代,每个元素都是独立的 TimerHandle
261/// for handle in batch {
262/// // 可以单独操作每个句柄
263/// }
264/// # }
265/// ```
266impl IntoIterator for BatchHandle {
267 type Item = TimerHandle;
268 type IntoIter = BatchHandleIter;
269
270 fn into_iter(self) -> Self::IntoIter {
271 BatchHandleIter {
272 task_ids: self.task_ids.into_iter(),
273 completion_rxs: self.completion_rxs.into_iter(),
274 wheel: self.wheel,
275 }
276 }
277}
278
279/// BatchHandle 的迭代器
280pub struct BatchHandleIter {
281 task_ids: std::vec::IntoIter<TaskId>,
282 completion_rxs: std::vec::IntoIter<oneshot::Receiver<TaskCompletionReason>>,
283 wheel: Arc<Mutex<Wheel>>,
284}
285
286impl Iterator for BatchHandleIter {
287 type Item = TimerHandle;
288
289 fn next(&mut self) -> Option<Self::Item> {
290 match (self.task_ids.next(), self.completion_rxs.next()) {
291 (Some(task_id), Some(rx)) => {
292 Some(TimerHandle::new(task_id, self.wheel.clone(), rx))
293 }
294 _ => None,
295 }
296 }
297
298 fn size_hint(&self) -> (usize, Option<usize>) {
299 self.task_ids.size_hint()
300 }
301}
302
303impl ExactSizeIterator for BatchHandleIter {
304 fn len(&self) -> usize {
305 self.task_ids.len()
306 }
307}
308
309/// 时间轮定时器管理器
310pub struct TimerWheel {
311 /// 时间轮唯一标识符
312
313 /// 时间轮实例(使用 Arc<Mutex> 包装以支持多线程访问)
314 wheel: Arc<Mutex<Wheel>>,
315
316 /// 后台 tick 循环任务句柄
317 tick_handle: Option<JoinHandle<()>>,
318}
319
320impl TimerWheel {
321 /// 创建新的定时器管理器
322 ///
323 /// # 参数
324 /// - `config`: 时间轮配置(已经过验证)
325 ///
326 /// # 示例
327 /// ```no_run
328 /// use kestrel_protocol_timer::{TimerWheel, WheelConfig, TimerTask, BatchConfig};
329 /// use std::time::Duration;
330 ///
331 /// #[tokio::main]
332 /// async fn main() {
333 /// let config = WheelConfig::builder()
334 /// .tick_duration(Duration::from_millis(10))
335 /// .slot_count(512)
336 /// .build()
337 /// .unwrap();
338 /// let timer = TimerWheel::new(config, BatchConfig::default());
339 ///
340 /// // 使用两步式 API
341 /// let task = TimerWheel::create_task(Duration::from_secs(1), None);
342 /// let handle = timer.register(task);
343 /// }
344 /// ```
345 pub fn new(config: WheelConfig, batch_config: BatchConfig) -> Self {
346 let tick_duration = config.tick_duration;
347 let wheel = Wheel::new(config, batch_config);
348 let wheel = Arc::new(Mutex::new(wheel));
349 let wheel_clone = wheel.clone();
350
351 // 启动后台 tick 循环
352 let tick_handle = tokio::spawn(async move {
353 Self::tick_loop(wheel_clone, tick_duration).await;
354 });
355
356 Self {
357 wheel,
358 tick_handle: Some(tick_handle),
359 }
360 }
361
362 /// 创建带默认配置的定时器管理器
363 /// - tick 时长: 10ms
364 /// - 槽位数量: 512
365 ///
366 /// # 示例
367 /// ```no_run
368 /// use kestrel_protocol_timer::TimerWheel;
369 ///
370 /// #[tokio::main]
371 /// async fn main() {
372 /// let timer = TimerWheel::with_defaults();
373 /// }
374 /// ```
375 pub fn with_defaults() -> Self {
376 Self::new(WheelConfig::default(), BatchConfig::default())
377 }
378
379 /// 创建与此时间轮绑定的 TimerService(使用默认配置)
380 ///
381 /// # 返回
382 /// 绑定到此时间轮的 TimerService 实例
383 ///
384 /// # 参数
385 /// - `service_config`: 服务配置
386 ///
387 /// # 示例
388 /// ```no_run
389 /// use kestrel_protocol_timer::{TimerWheel, TimerService, CallbackWrapper, ServiceConfig};
390 /// use std::time::Duration;
391 ///
392 ///
393 /// #[tokio::main]
394 /// async fn main() {
395 /// let timer = TimerWheel::with_defaults();
396 /// let mut service = timer.create_service(ServiceConfig::default());
397 ///
398 /// // 使用两步式 API 通过 service 批量调度定时器
399 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = (0..5)
400 /// .map(|_| (Duration::from_millis(100), Some(CallbackWrapper::new(|| async {}))))
401 /// .collect();
402 /// let tasks = TimerService::create_batch_with_callbacks(callbacks);
403 /// service.register_batch(tasks).unwrap();
404 ///
405 /// // 接收超时通知
406 /// let mut rx = service.take_receiver().unwrap();
407 /// while let Some(task_id) = rx.recv().await {
408 /// println!("Task {:?} completed", task_id);
409 /// }
410 /// }
411 /// ```
412 pub fn create_service(&self, service_config: ServiceConfig) -> crate::service::TimerService {
413 crate::service::TimerService::new(self.wheel.clone(), service_config)
414 }
415
416 /// 创建与此时间轮绑定的 TimerService(使用自定义配置)
417 ///
418 /// # 参数
419 /// - `config`: 服务配置
420 ///
421 /// # 返回
422 /// 绑定到此时间轮的 TimerService 实例
423 ///
424 /// # 示例
425 /// ```no_run
426 /// use kestrel_protocol_timer::{TimerWheel, ServiceConfig};
427 ///
428 /// #[tokio::main]
429 /// async fn main() {
430 /// let timer = TimerWheel::with_defaults();
431 /// let config = ServiceConfig::builder()
432 /// .command_channel_capacity(1024)
433 /// .timeout_channel_capacity(2000)
434 /// .build()
435 /// .unwrap();
436 /// let service = timer.create_service_with_config(config);
437 /// }
438 /// ```
439 pub fn create_service_with_config(&self, config: ServiceConfig) -> crate::service::TimerService {
440 crate::service::TimerService::new(self.wheel.clone(), config)
441 }
442
443 /// 创建定时器任务(申请阶段)
444 ///
445 /// # 参数
446 /// - `delay`: 延迟时间
447 /// - `callback`: 实现了 TimerCallback trait 的回调对象
448 ///
449 /// # 返回
450 /// 返回 TimerTask,需要通过 `register()` 注册到时间轮
451 ///
452 /// # 示例
453 /// ```no_run
454 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
455 /// use std::time::Duration;
456 ///
457 ///
458 /// #[tokio::main]
459 /// async fn main() {
460 /// let timer = TimerWheel::with_defaults();
461 ///
462 /// // 步骤 1: 创建任务
463 /// let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
464 /// println!("Timer fired!");
465 /// })));
466 ///
467 /// // 获取任务 ID
468 /// let task_id = task.get_id();
469 /// println!("Created task: {:?}", task_id);
470 ///
471 /// // 步骤 2: 注册任务
472 /// let handle = timer.register(task);
473 /// }
474 /// ```
475 #[inline]
476 pub fn create_task(delay: Duration, callback: Option<CallbackWrapper>) -> crate::task::TimerTask {
477 crate::task::TimerTask::new(delay, callback)
478 }
479
480 /// 批量创建定时器任务(申请阶段)
481 ///
482 /// # 参数
483 /// - `callbacks`: (延迟时间, 回调) 的元组列表
484 ///
485 /// # 返回
486 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
487 ///
488 /// # 示例
489 /// ```no_run
490 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
491 /// use std::time::Duration;
492 /// use std::sync::Arc;
493 /// use std::sync::atomic::{AtomicU32, Ordering};
494 ///
495 /// #[tokio::main]
496 /// async fn main() {
497 /// let timer = TimerWheel::with_defaults();
498 /// let counter = Arc::new(AtomicU32::new(0));
499 ///
500 /// // 步骤 1: 批量创建任务
501 /// let delays: Vec<Duration> = (0..3)
502 /// .map(|_| Duration::from_millis(100))
503 /// .collect();
504 ///
505 /// let tasks = TimerWheel::create_batch(delays);
506 /// println!("Created {} tasks", tasks.len());
507 ///
508 /// // 步骤 2: 批量注册任务
509 /// let batch = timer.register_batch(tasks);
510 /// }
511 /// ```
512 #[inline]
513 pub fn create_batch(delays: Vec<Duration>) -> Vec<crate::task::TimerTask>
514 {
515 delays
516 .into_iter()
517 .map(|delay| crate::task::TimerTask::new(delay, None))
518 .collect()
519 }
520
521 /// 批量创建定时器任务(申请阶段)
522 ///
523 /// # 参数
524 /// - `callbacks`: (延迟时间, 回调) 的元组列表
525 ///
526 /// # 返回
527 /// 返回 TimerTask 列表,需要通过 `register_batch()` 注册到时间轮
528 ///
529 /// # 示例
530 /// ```no_run
531 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
532 /// use std::time::Duration;
533 /// use std::sync::Arc;
534 /// use std::sync::atomic::{AtomicU32, Ordering};
535 ///
536 /// #[tokio::main]
537 /// async fn main() {
538 /// let timer = TimerWheel::with_defaults();
539 /// let counter = Arc::new(AtomicU32::new(0));
540 ///
541 /// // 步骤 1: 批量创建任务
542 /// let delays: Vec<Duration> = (0..3)
543 /// .map(|_| Duration::from_millis(100))
544 /// .collect();
545 /// let callbacks: Vec<(Duration, Option<CallbackWrapper>)> = delays
546 /// .into_iter()
547 /// .map(|delay| {
548 /// let counter = Arc::clone(&counter);
549 /// let callback = Some(CallbackWrapper::new(move || {
550 /// let counter = Arc::clone(&counter);
551 /// async move {
552 /// counter.fetch_add(1, Ordering::SeqCst);
553 /// }
554 /// }));
555 /// (delay, callback)
556 /// })
557 /// .collect();
558 ///
559 /// let tasks = TimerWheel::create_batch_with_callbacks(callbacks);
560 /// println!("Created {} tasks", tasks.len());
561 ///
562 /// // 步骤 2: 批量注册任务
563 /// let batch = timer.register_batch(tasks);
564 /// }
565 /// ```
566 #[inline]
567 pub fn create_batch_with_callbacks(callbacks: Vec<(Duration, Option<CallbackWrapper>)>) -> Vec<crate::task::TimerTask>
568 {
569 callbacks
570 .into_iter()
571 .map(|(delay, callback)| crate::task::TimerTask::new(delay, callback))
572 .collect()
573 }
574
575 /// 注册定时器任务到时间轮(注册阶段)
576 ///
577 /// # 参数
578 /// - `task`: 通过 `create_task()` 创建的任务
579 ///
580 /// # 返回
581 /// 返回定时器句柄,可用于取消定时器和接收完成通知
582 ///
583 /// # 示例
584 /// ```no_run
585 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
586 ///
587 /// use std::time::Duration;
588 ///
589 /// #[tokio::main]
590 /// async fn main() {
591 /// let timer = TimerWheel::with_defaults();
592 ///
593 /// let task = TimerWheel::create_task(Duration::from_secs(1), Some(CallbackWrapper::new(|| async {
594 /// println!("Timer fired!");
595 /// })));
596 /// let task_id = task.get_id();
597 ///
598 /// let handle = timer.register(task);
599 ///
600 /// // 等待定时器完成
601 /// handle.into_completion_receiver().0.await.ok();
602 /// }
603 /// ```
604 #[inline]
605 pub fn register(&self, task: crate::task::TimerTask) -> TimerHandle {
606 let (completion_tx, completion_rx) = oneshot::channel();
607 let notifier = crate::task::CompletionNotifier(completion_tx);
608
609 let task_id = task.id;
610
611 // 单次加锁完成所有操作
612 {
613 let mut wheel_guard = self.wheel.lock();
614 wheel_guard.insert(task, notifier);
615 }
616
617 TimerHandle::new(task_id, self.wheel.clone(), completion_rx)
618 }
619
620 /// 批量注册定时器任务到时间轮(注册阶段)
621 ///
622 /// # 参数
623 /// - `tasks`: 通过 `create_batch()` 创建的任务列表
624 ///
625 /// # 返回
626 /// 返回批量定时器句柄
627 ///
628 /// # 示例
629 /// ```no_run
630 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
631 /// use std::time::Duration;
632 ///
633 /// #[tokio::main]
634 /// async fn main() {
635 /// let timer = TimerWheel::with_defaults();
636 ///
637 /// let delays: Vec<Duration> = (0..3)
638 /// .map(|_| Duration::from_secs(1))
639 /// .collect();
640 /// let tasks = TimerWheel::create_batch(delays);
641 ///
642 /// let batch = timer.register_batch(tasks);
643 /// println!("Registered {} timers", batch.len());
644 /// }
645 /// ```
646 #[inline]
647 pub fn register_batch(&self, tasks: Vec<crate::task::TimerTask>) -> BatchHandle {
648 let task_count = tasks.len();
649 let mut completion_rxs = Vec::with_capacity(task_count);
650 let mut task_ids = Vec::with_capacity(task_count);
651 let mut prepared_tasks = Vec::with_capacity(task_count);
652
653 // 步骤1: 准备所有 channels 和 notifiers(无锁)
654 // 优化:使用 for 循环代替 map + collect,避免闭包捕获开销
655 for task in tasks {
656 let (completion_tx, completion_rx) = oneshot::channel();
657 let notifier = crate::task::CompletionNotifier(completion_tx);
658
659 task_ids.push(task.id);
660 completion_rxs.push(completion_rx);
661 prepared_tasks.push((task, notifier));
662 }
663
664 // 步骤2: 单次加锁,批量插入
665 {
666 let mut wheel_guard = self.wheel.lock();
667 wheel_guard.insert_batch(prepared_tasks);
668 }
669
670 BatchHandle::new(task_ids, self.wheel.clone(), completion_rxs)
671 }
672
673 /// 取消定时器
674 ///
675 /// # 参数
676 /// - `task_id`: 任务 ID
677 ///
678 /// # 返回
679 /// 如果任务存在且成功取消返回 true,否则返回 false
680 ///
681 /// # 示例
682 /// ```no_run
683 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
684 ///
685 /// use std::time::Duration;
686 ///
687 /// #[tokio::main]
688 /// async fn main() {
689 /// let timer = TimerWheel::with_defaults();
690 ///
691 /// let task = TimerWheel::create_task(Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
692 /// println!("Timer fired!");
693 /// })));
694 /// let task_id = task.get_id();
695 /// let _handle = timer.register(task);
696 ///
697 /// // 使用任务 ID 取消
698 /// let cancelled = timer.cancel(task_id);
699 /// println!("取消成功: {}", cancelled);
700 /// }
701 /// ```
702 #[inline]
703 pub fn cancel(&self, task_id: TaskId) -> bool {
704 let mut wheel = self.wheel.lock();
705 wheel.cancel(task_id)
706 }
707
708 /// 批量取消定时器
709 ///
710 /// # 参数
711 /// - `task_ids`: 要取消的任务 ID 列表
712 ///
713 /// # 返回
714 /// 成功取消的任务数量
715 ///
716 /// # 性能优势
717 /// - 批量处理减少锁竞争
718 /// - 内部优化批量取消操作
719 ///
720 /// # 示例
721 /// ```no_run
722 /// use kestrel_protocol_timer::{TimerWheel, TimerTask};
723 /// use std::time::Duration;
724 ///
725 /// #[tokio::main]
726 /// async fn main() {
727 /// let timer = TimerWheel::with_defaults();
728 ///
729 /// // 创建多个定时器
730 /// let task1 = TimerWheel::create_task(Duration::from_secs(10), None);
731 /// let task2 = TimerWheel::create_task(Duration::from_secs(10), None);
732 /// let task3 = TimerWheel::create_task(Duration::from_secs(10), None);
733 ///
734 /// let task_ids = vec![task1.get_id(), task2.get_id(), task3.get_id()];
735 ///
736 /// let _h1 = timer.register(task1);
737 /// let _h2 = timer.register(task2);
738 /// let _h3 = timer.register(task3);
739 ///
740 /// // 批量取消
741 /// let cancelled = timer.cancel_batch(&task_ids);
742 /// println!("已取消 {} 个定时器", cancelled);
743 /// }
744 /// ```
745 #[inline]
746 pub fn cancel_batch(&self, task_ids: &[TaskId]) -> usize {
747 let mut wheel = self.wheel.lock();
748 wheel.cancel_batch(task_ids)
749 }
750
751 /// 推迟定时器
752 ///
753 /// # 参数
754 /// - `task_id`: 要推迟的任务 ID
755 /// - `new_delay`: 新的延迟时间(从当前时间点重新计算)
756 /// - `callback`: 新的回调函数,传入 `None` 保持原回调不变,传入 `Some` 替换为新回调
757 ///
758 /// # 返回
759 /// 如果任务存在且成功推迟返回 true,否则返回 false
760 ///
761 /// # 注意
762 /// - 推迟后任务 ID 保持不变
763 /// - 原有的 completion_receiver 仍然有效
764 ///
765 /// # 示例
766 ///
767 /// ## 保持原回调
768 /// ```no_run
769 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
770 /// use std::time::Duration;
771 ///
772 ///
773 /// #[tokio::main]
774 /// async fn main() {
775 /// let timer = TimerWheel::with_defaults();
776 ///
777 /// let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
778 /// println!("Timer fired!");
779 /// })));
780 /// let task_id = task.get_id();
781 /// let _handle = timer.register(task);
782 ///
783 /// // 推迟到 10 秒后触发(保持原回调)
784 /// let success = timer.postpone(task_id, Duration::from_secs(10), None);
785 /// println!("推迟成功: {}", success);
786 /// }
787 /// ```
788 ///
789 /// ## 替换为新回调
790 /// ```no_run
791 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
792 /// use std::time::Duration;
793 ///
794 /// #[tokio::main]
795 /// async fn main() {
796 /// let timer = TimerWheel::with_defaults();
797 ///
798 /// let task = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
799 /// println!("Original callback!");
800 /// })));
801 /// let task_id = task.get_id();
802 /// let _handle = timer.register(task);
803 ///
804 /// // 推迟到 10 秒后触发(并替换为新回调)
805 /// let success = timer.postpone(task_id, Duration::from_secs(10), Some(CallbackWrapper::new(|| async {
806 /// println!("New callback!");
807 /// })));
808 /// println!("推迟成功: {}", success);
809 /// }
810 /// ```
811 #[inline]
812 pub fn postpone(
813 &self,
814 task_id: TaskId,
815 new_delay: Duration,
816 callback: Option<CallbackWrapper>,
817 ) -> bool {
818 let mut wheel = self.wheel.lock();
819 wheel.postpone(task_id, new_delay, callback)
820 }
821
822 /// 批量推迟定时器(保持原回调)
823 ///
824 /// # 参数
825 /// - `updates`: (任务ID, 新延迟) 的元组列表
826 ///
827 /// # 返回
828 /// 成功推迟的任务数量
829 ///
830 /// # 注意
831 /// - 此方法会保持所有任务的原回调不变
832 /// - 如需替换回调,请使用 `postpone_batch_with_callbacks`
833 ///
834 /// # 性能优势
835 /// - 批量处理减少锁竞争
836 /// - 内部优化批量推迟操作
837 ///
838 /// # 示例
839 /// ```no_run
840 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
841 /// use std::time::Duration;
842 ///
843 /// #[tokio::main]
844 /// async fn main() {
845 /// let timer = TimerWheel::with_defaults();
846 ///
847 /// // 创建多个带回调的定时器
848 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
849 /// println!("Task 1 fired!");
850 /// })));
851 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
852 /// println!("Task 2 fired!");
853 /// })));
854 /// let task3 = TimerWheel::create_task(Duration::from_secs(5), Some(CallbackWrapper::new(|| async {
855 /// println!("Task 3 fired!");
856 /// })));
857 ///
858 /// let task_ids = vec![
859 /// (task1.get_id(), Duration::from_secs(10)),
860 /// (task2.get_id(), Duration::from_secs(15)),
861 /// (task3.get_id(), Duration::from_secs(20)),
862 /// ];
863 ///
864 /// timer.register(task1);
865 /// timer.register(task2);
866 /// timer.register(task3);
867 ///
868 /// // 批量推迟(保持原回调)
869 /// let postponed = timer.postpone_batch(task_ids);
870 /// println!("已推迟 {} 个定时器", postponed);
871 /// }
872 /// ```
873 #[inline]
874 pub fn postpone_batch(&self, updates: Vec<(TaskId, Duration)>) -> usize {
875 let mut wheel = self.wheel.lock();
876 wheel.postpone_batch(updates)
877 }
878
879 /// 批量推迟定时器(替换回调)
880 ///
881 /// # 参数
882 /// - `updates`: (任务ID, 新延迟, 新回调) 的元组列表
883 ///
884 /// # 返回
885 /// 成功推迟的任务数量
886 ///
887 /// # 性能优势
888 /// - 批量处理减少锁竞争
889 /// - 内部优化批量推迟操作
890 ///
891 /// # 示例
892 /// ```no_run
893 /// use kestrel_protocol_timer::{TimerWheel, TimerTask, CallbackWrapper};
894 /// use std::time::Duration;
895 /// use std::sync::Arc;
896 /// use std::sync::atomic::{AtomicU32, Ordering};
897 ///
898 /// #[tokio::main]
899 /// async fn main() {
900 /// let timer = TimerWheel::with_defaults();
901 /// let counter = Arc::new(AtomicU32::new(0));
902 ///
903 /// // 创建多个定时器
904 /// let task1 = TimerWheel::create_task(Duration::from_secs(5), None);
905 /// let task2 = TimerWheel::create_task(Duration::from_secs(5), None);
906 ///
907 /// let id1 = task1.get_id();
908 /// let id2 = task2.get_id();
909 ///
910 /// timer.register(task1);
911 /// timer.register(task2);
912 ///
913 /// // 批量推迟并替换回调
914 /// let updates: Vec<_> = vec![id1, id2]
915 /// .into_iter()
916 /// .map(|id| {
917 /// let counter = Arc::clone(&counter);
918 /// (id, Duration::from_secs(10), Some(CallbackWrapper::new(move || {
919 /// let counter = Arc::clone(&counter);
920 /// async move { counter.fetch_add(1, Ordering::SeqCst); }
921 /// })))
922 /// })
923 /// .collect();
924 /// let postponed = timer.postpone_batch_with_callbacks(updates);
925 /// println!("已推迟 {} 个定时器", postponed);
926 /// }
927 /// ```
928 #[inline]
929 pub fn postpone_batch_with_callbacks(
930 &self,
931 updates: Vec<(TaskId, Duration, Option<CallbackWrapper>)>,
932 ) -> usize {
933 let mut wheel = self.wheel.lock();
934 wheel.postpone_batch_with_callbacks(updates.to_vec())
935 }
936
937 /// 核心 tick 循环
938 async fn tick_loop(wheel: Arc<Mutex<Wheel>>, tick_duration: Duration) {
939 let mut interval = tokio::time::interval(tick_duration);
940 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
941
942 loop {
943 interval.tick().await;
944
945 // 推进时间轮并获取到期任务
946 let expired_tasks = {
947 let mut wheel_guard = wheel.lock();
948 wheel_guard.advance()
949 };
950
951 // 执行到期任务
952 for task in expired_tasks {
953 let callback = task.get_callback();
954
955 // 移动task的所有权来获取completion_notifier
956 let notifier = task.completion_notifier;
957
958 // 只有注册过的任务才有 notifier
959 if let Some(notifier) = notifier {
960 // 在独立的 tokio 任务中执行回调,并在回调完成后发送通知
961 if let Some(callback) = callback {
962 tokio::spawn(async move {
963 // 执行回调
964 let future = callback.call();
965 future.await;
966
967 // 回调执行完成后发送通知
968 let _ = notifier.0.send(TaskCompletionReason::Expired);
969 });
970 } else {
971 // 如果没有回调,立即发送完成通知
972 let _ = notifier.0.send(TaskCompletionReason::Expired);
973 }
974 }
975 }
976 }
977 }
978
979 /// 停止定时器管理器
980 pub async fn shutdown(mut self) {
981 if let Some(handle) = self.tick_handle.take() {
982 handle.abort();
983 let _ = handle.await;
984 }
985 }
986}
987
988impl Drop for TimerWheel {
989 fn drop(&mut self) {
990 if let Some(handle) = self.tick_handle.take() {
991 handle.abort();
992 }
993 }
994}
995
996#[cfg(test)]
997mod tests {
998 use super::*;
999 use std::sync::atomic::{AtomicU32, Ordering};
1000
1001 #[tokio::test]
1002 async fn test_timer_creation() {
1003 let _timer = TimerWheel::with_defaults();
1004 }
1005
1006 #[tokio::test]
1007 async fn test_schedule_once() {
1008 use std::sync::Arc;
1009 let timer = TimerWheel::with_defaults();
1010 let counter = Arc::new(AtomicU32::new(0));
1011 let counter_clone = Arc::clone(&counter);
1012
1013 let task = TimerWheel::create_task(
1014 Duration::from_millis(50),
1015 Some(CallbackWrapper::new(move || {
1016 let counter = Arc::clone(&counter_clone);
1017 async move {
1018 counter.fetch_add(1, Ordering::SeqCst);
1019 }
1020 })),
1021 );
1022 let _handle = timer.register(task);
1023
1024 // 等待定时器触发
1025 tokio::time::sleep(Duration::from_millis(100)).await;
1026 assert_eq!(counter.load(Ordering::SeqCst), 1);
1027 }
1028
1029 #[tokio::test]
1030 async fn test_cancel_timer() {
1031 use std::sync::Arc;
1032 let timer = TimerWheel::with_defaults();
1033 let counter = Arc::new(AtomicU32::new(0));
1034 let counter_clone = Arc::clone(&counter);
1035
1036 let task = TimerWheel::create_task(
1037 Duration::from_millis(100),
1038 Some(CallbackWrapper::new(move || {
1039 let counter = Arc::clone(&counter_clone);
1040 async move {
1041 counter.fetch_add(1, Ordering::SeqCst);
1042 }
1043 })),
1044 );
1045 let handle = timer.register(task);
1046
1047 // 立即取消
1048 let cancel_result = handle.cancel();
1049 assert!(cancel_result);
1050
1051 // 等待足够长时间确保定时器不会触发
1052 tokio::time::sleep(Duration::from_millis(200)).await;
1053 assert_eq!(counter.load(Ordering::SeqCst), 0);
1054 }
1055
1056 #[tokio::test]
1057 async fn test_cancel_immediate() {
1058 use std::sync::Arc;
1059 let timer = TimerWheel::with_defaults();
1060 let counter = Arc::new(AtomicU32::new(0));
1061 let counter_clone = Arc::clone(&counter);
1062
1063 let task = TimerWheel::create_task(
1064 Duration::from_millis(100),
1065 Some(CallbackWrapper::new(move || {
1066 let counter = Arc::clone(&counter_clone);
1067 async move {
1068 counter.fetch_add(1, Ordering::SeqCst);
1069 }
1070 })),
1071 );
1072 let handle = timer.register(task);
1073
1074 // 立即取消
1075 let cancel_result = handle.cancel();
1076 assert!(cancel_result);
1077
1078 // 等待足够长时间确保定时器不会触发
1079 tokio::time::sleep(Duration::from_millis(200)).await;
1080 assert_eq!(counter.load(Ordering::SeqCst), 0);
1081 }
1082
1083 #[tokio::test]
1084 async fn test_postpone_timer() {
1085 use std::sync::Arc;
1086 let timer = TimerWheel::with_defaults();
1087 let counter = Arc::new(AtomicU32::new(0));
1088 let counter_clone = Arc::clone(&counter);
1089
1090 let task = TimerWheel::create_task(
1091 Duration::from_millis(50),
1092 Some(CallbackWrapper::new(move || {
1093 let counter = Arc::clone(&counter_clone);
1094 async move {
1095 counter.fetch_add(1, Ordering::SeqCst);
1096 }
1097 })),
1098 );
1099 let task_id = task.get_id();
1100 let handle = timer.register(task);
1101
1102 // 推迟任务到 150ms
1103 let postponed = timer.postpone(task_id, Duration::from_millis(150), None);
1104 assert!(postponed);
1105
1106 // 等待原定时间 50ms,任务不应该触发
1107 tokio::time::sleep(Duration::from_millis(70)).await;
1108 assert_eq!(counter.load(Ordering::SeqCst), 0);
1109
1110 // 等待新的触发时间(从推迟开始算,还需要等待约 150ms)
1111 let result = tokio::time::timeout(
1112 Duration::from_millis(200),
1113 handle.into_completion_receiver().0
1114 ).await;
1115 assert!(result.is_ok());
1116
1117 // 等待回调执行
1118 tokio::time::sleep(Duration::from_millis(20)).await;
1119 assert_eq!(counter.load(Ordering::SeqCst), 1);
1120 }
1121
1122 #[tokio::test]
1123 async fn test_postpone_with_callback() {
1124 use std::sync::Arc;
1125 let timer = TimerWheel::with_defaults();
1126 let counter = Arc::new(AtomicU32::new(0));
1127 let counter_clone1 = Arc::clone(&counter);
1128 let counter_clone2 = Arc::clone(&counter);
1129
1130 // 创建任务,原始回调增加 1
1131 let task = TimerWheel::create_task(
1132 Duration::from_millis(50),
1133 Some(CallbackWrapper::new(move || {
1134 let counter = Arc::clone(&counter_clone1);
1135 async move {
1136 counter.fetch_add(1, Ordering::SeqCst);
1137 }
1138 })),
1139 );
1140 let task_id = task.get_id();
1141 let handle = timer.register(task);
1142
1143 // 推迟任务并替换回调,新回调增加 10
1144 let postponed = timer.postpone(
1145 task_id,
1146 Duration::from_millis(100),
1147 Some(CallbackWrapper::new(move || {
1148 let counter = Arc::clone(&counter_clone2);
1149 async move {
1150 counter.fetch_add(10, Ordering::SeqCst);
1151 }
1152 })),
1153 );
1154 assert!(postponed);
1155
1156 // 等待任务触发(推迟后需要等待100ms,加上余量)
1157 let result = tokio::time::timeout(
1158 Duration::from_millis(200),
1159 handle.into_completion_receiver().0
1160 ).await;
1161 assert!(result.is_ok());
1162
1163 // 等待回调执行
1164 tokio::time::sleep(Duration::from_millis(20)).await;
1165
1166 // 验证新回调被执行(增加了 10 而不是 1)
1167 assert_eq!(counter.load(Ordering::SeqCst), 10);
1168 }
1169
1170 #[tokio::test]
1171 async fn test_postpone_nonexistent_timer() {
1172 let timer = TimerWheel::with_defaults();
1173
1174 // 尝试推迟不存在的任务
1175 let fake_task = TimerWheel::create_task(Duration::from_millis(50), None);
1176 let fake_task_id = fake_task.get_id();
1177 // 不注册这个任务
1178
1179 let postponed = timer.postpone(fake_task_id, Duration::from_millis(100), None);
1180 assert!(!postponed);
1181 }
1182
1183 #[tokio::test]
1184 async fn test_postpone_batch() {
1185 use std::sync::Arc;
1186 let timer = TimerWheel::with_defaults();
1187 let counter = Arc::new(AtomicU32::new(0));
1188
1189 // 创建 3 个任务
1190 let mut task_ids = Vec::new();
1191 for _ in 0..3 {
1192 let counter_clone = Arc::clone(&counter);
1193 let task = TimerWheel::create_task(
1194 Duration::from_millis(50),
1195 Some(CallbackWrapper::new(move || {
1196 let counter = Arc::clone(&counter_clone);
1197 async move {
1198 counter.fetch_add(1, Ordering::SeqCst);
1199 }
1200 })),
1201 );
1202 task_ids.push((task.get_id(), Duration::from_millis(150)));
1203 timer.register(task);
1204 }
1205
1206 // 批量推迟
1207 let postponed = timer.postpone_batch(task_ids);
1208 assert_eq!(postponed, 3);
1209
1210 // 等待原定时间 50ms,任务不应该触发
1211 tokio::time::sleep(Duration::from_millis(70)).await;
1212 assert_eq!(counter.load(Ordering::SeqCst), 0);
1213
1214 // 等待新的触发时间(从推迟开始算,还需要等待约 150ms)
1215 tokio::time::sleep(Duration::from_millis(200)).await;
1216
1217 // 等待回调执行
1218 tokio::time::sleep(Duration::from_millis(20)).await;
1219 assert_eq!(counter.load(Ordering::SeqCst), 3);
1220 }
1221
1222 #[tokio::test]
1223 async fn test_postpone_batch_with_callbacks() {
1224 use std::sync::Arc;
1225 let timer = TimerWheel::with_defaults();
1226 let counter = Arc::new(AtomicU32::new(0));
1227
1228 // 创建 3 个任务
1229 let mut task_ids = Vec::new();
1230 for _ in 0..3 {
1231 let task = TimerWheel::create_task(
1232 Duration::from_millis(50),
1233 None
1234 );
1235 task_ids.push(task.get_id());
1236 timer.register(task);
1237 }
1238
1239 // 批量推迟并替换回调
1240 let updates: Vec<_> = task_ids
1241 .into_iter()
1242 .map(|id| {
1243 let counter_clone = Arc::clone(&counter);
1244 (id, Duration::from_millis(150), Some(CallbackWrapper::new(move || {
1245 let counter = Arc::clone(&counter_clone);
1246 async move {
1247 counter.fetch_add(1, Ordering::SeqCst);
1248 }
1249 })))
1250 })
1251 .collect();
1252
1253 let postponed = timer.postpone_batch_with_callbacks(updates);
1254 assert_eq!(postponed, 3);
1255
1256 // 等待原定时间 50ms,任务不应该触发
1257 tokio::time::sleep(Duration::from_millis(70)).await;
1258 assert_eq!(counter.load(Ordering::SeqCst), 0);
1259
1260 // 等待新的触发时间(从推迟开始算,还需要等待约 150ms)
1261 tokio::time::sleep(Duration::from_millis(200)).await;
1262
1263 // 等待回调执行
1264 tokio::time::sleep(Duration::from_millis(20)).await;
1265 assert_eq!(counter.load(Ordering::SeqCst), 3);
1266 }
1267
1268 #[tokio::test]
1269 async fn test_postpone_keeps_completion_receiver_valid() {
1270 use std::sync::Arc;
1271 let timer = TimerWheel::with_defaults();
1272 let counter = Arc::new(AtomicU32::new(0));
1273 let counter_clone = Arc::clone(&counter);
1274
1275 let task = TimerWheel::create_task(
1276 Duration::from_millis(50),
1277 Some(CallbackWrapper::new(move || {
1278 let counter = Arc::clone(&counter_clone);
1279 async move {
1280 counter.fetch_add(1, Ordering::SeqCst);
1281 }
1282 })),
1283 );
1284 let task_id = task.get_id();
1285 let handle = timer.register(task);
1286
1287 // 推迟任务
1288 timer.postpone(task_id, Duration::from_millis(100), None);
1289
1290 // 验证原 completion_receiver 仍然有效(推迟后需要等待100ms,加上余量)
1291 let result = tokio::time::timeout(
1292 Duration::from_millis(200),
1293 handle.into_completion_receiver().0
1294 ).await;
1295 assert!(result.is_ok(), "Completion receiver should still work after postpone");
1296
1297 // 等待回调执行
1298 tokio::time::sleep(Duration::from_millis(20)).await;
1299 assert_eq!(counter.load(Ordering::SeqCst), 1);
1300 }
1301}
1302