kestrel_protocol_timer/task.rs
1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use tokio::sync::oneshot;
6
7/// 全局唯一的任务 ID 生成器
8static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);
9
10/// 任务完成原因
11///
12/// 表示定时器任务完成的原因,可以是正常到期或被取消。
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
14pub enum TaskCompletionReason {
15 /// 任务正常到期
16 Expired,
17 /// 任务被取消
18 Cancelled,
19}
20
21/// 定时器任务的唯一标识符
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
23pub struct TaskId(u64);
24
25impl TaskId {
26 /// 生成一个新的唯一任务 ID(内部使用)
27 pub(crate) fn new() -> Self {
28 TaskId(NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed))
29 }
30
31 /// 获取任务 ID 的数值
32 pub fn as_u64(&self) -> u64 {
33 self.0
34 }
35}
36
37impl Default for TaskId {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43/// 定时器回调 trait
44///
45/// 实现此 trait 的类型可以作为定时器的回调函数使用。
46///
47/// # 示例
48///
49/// ```
50/// use kestrel_protocol_timer::TimerCallback;
51/// use std::future::Future;
52/// use std::pin::Pin;
53///
54/// struct MyCallback;
55///
56/// impl TimerCallback for MyCallback {
57/// fn call(&self) -> Pin<Box<dyn Future<Output = ()> + Send>> {
58/// Box::pin(async {
59/// println!("Timer callback executed!");
60/// })
61/// }
62/// }
63/// ```
64pub trait TimerCallback: Send + Sync + 'static {
65 /// 执行回调,返回一个 Future
66 fn call(&self) -> Pin<Box<dyn Future<Output = ()> + Send>>;
67}
68
69/// 为闭包实现 TimerCallback trait
70/// 支持 Fn() -> Future 类型的闭包(可以多次调用,适合周期性任务)
71impl<F, Fut> TimerCallback for F
72where
73 F: Fn() -> Fut + Send + Sync + 'static,
74 Fut: Future<Output = ()> + Send + 'static,
75{
76 fn call(&self) -> Pin<Box<dyn Future<Output = ()> + Send>> {
77 Box::pin(self())
78 }
79}
80
81/// 回调包装器类型
82pub type CallbackWrapper = Arc<dyn TimerCallback>;
83
84/// 完成通知器,用于在任务完成时发送通知
85pub struct CompletionNotifier(pub oneshot::Sender<TaskCompletionReason>);
86
87/// 定时器任务
88///
89/// 用户通过两步式 API 使用:
90/// 1. 使用 `TimerTask::new()` 创建任务
91/// 2. 使用 `TimerWheel::register()` 或 `TimerService::register()` 注册任务
92pub struct TimerTask {
93 /// 任务唯一标识符
94 pub(crate) id: TaskId,
95
96 /// 用户指定的延迟时间
97 pub(crate) delay: std::time::Duration,
98
99 /// 到期时间(相对于时间轮的 tick 数)
100 pub(crate) deadline_tick: u64,
101
102 /// 轮次计数(用于超出时间轮范围的任务)
103 pub(crate) rounds: u32,
104
105 /// 异步回调函数(可选)
106 pub(crate) callback: Option<CallbackWrapper>,
107
108 /// 完成通知器(用于在任务完成时发送通知,注册时创建)
109 pub(crate) completion_notifier: Option<CompletionNotifier>,
110}
111
112impl TimerTask {
113 /// 创建新的定时器任务
114 ///
115 /// # 参数
116 /// - `delay`: 延迟时间
117 /// - `callback`: 回调函数(可选)
118 ///
119 /// # 示例
120 /// ```no_run
121 /// use kestrel_protocol_timer::TimerTask;
122 /// use std::time::Duration;
123 /// use std::sync::Arc;
124 ///
125 /// // 创建带回调的任务
126 /// let callback = Arc::new(|| async {
127 /// println!("Timer fired!");
128 /// });
129 /// let task = TimerTask::new(Duration::from_secs(1), Some(callback));
130 ///
131 /// // 创建仅通知的任务
132 /// let task = TimerTask::new(Duration::from_secs(1), None);
133 /// ```
134 pub fn new(delay: std::time::Duration, callback: Option<CallbackWrapper>) -> Self {
135 Self {
136 id: TaskId::new(),
137 delay,
138 deadline_tick: 0,
139 rounds: 0,
140 callback,
141 completion_notifier: None,
142 }
143 }
144
145 /// 获取任务 ID
146 ///
147 /// # 示例
148 /// ```no_run
149 /// use kestrel_protocol_timer::TimerTask;
150 /// use std::time::Duration;
151 ///
152 /// let task = TimerTask::new(Duration::from_secs(1), None);
153 /// let task_id = task.get_id();
154 /// println!("Task ID: {:?}", task_id);
155 /// ```
156 pub fn get_id(&self) -> TaskId {
157 self.id
158 }
159
160 /// 内部方法:准备注册(在注册时由时间轮调用)
161 pub(crate) fn prepare_for_registration(
162 &mut self,
163 completion_notifier: CompletionNotifier,
164 deadline_tick: u64,
165 rounds: u32,
166 ) {
167 self.completion_notifier = Some(completion_notifier);
168 self.deadline_tick = deadline_tick;
169 self.rounds = rounds;
170 }
171
172 /// 获取回调函数的克隆(如果存在)
173 pub(crate) fn get_callback(&self) -> Option<CallbackWrapper> {
174 self.callback.as_ref().map(Arc::clone)
175 }
176}
177
178/// 任务位置信息,用于取消操作
179#[derive(Debug, Clone)]
180pub struct TaskLocation {
181 pub slot_index: usize,
182 /// 任务在槽位 Vec 中的索引位置(用于 O(1) 取消)
183 pub vec_index: usize,
184 #[allow(dead_code)]
185 pub task_id: TaskId,
186}
187
188impl TaskLocation {
189 pub fn new(slot_index: usize, vec_index: usize, task_id: TaskId) -> Self {
190 Self {
191 slot_index,
192 vec_index,
193 task_id,
194 }
195 }
196}
197