memscope_rs/async_memory/
task_id.rs1use std::cell::Cell;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::task::Context;
10
11use crate::async_memory::error::{AsyncError, AsyncResult, TaskOperation};
12
/// Integer type used for task IDs.
///
/// It is 128 bits wide so `generate_task_id` can pack a 64-bit epoch (high
/// half) together with a 64-bit waker address (low half).
pub type TaskId = u128;

/// Process-wide counter supplying the epoch (high 64 bits) of every generated
/// task ID. Starts at 1 and is advanced once per `generate_task_id` call.
static TASK_EPOCH: AtomicU64 = AtomicU64::new(1_u64);

25#[derive(Clone, Copy, Debug, Default)]
30pub struct TaskInfo {
31 pub waker_id: TaskId,
33 pub span_id: Option<u64>,
35 pub created_at: u64,
37}
38
39impl TaskInfo {
40 pub fn new(waker_id: TaskId, span_id: Option<u64>) -> Self {
42 Self {
43 waker_id,
44 span_id,
45 created_at: current_timestamp(),
46 }
47 }
48
49 pub fn has_tracking_id(&self) -> bool {
51 self.waker_id != 0 || self.span_id.is_some()
52 }
53
54 pub fn primary_id(&self) -> TaskId {
56 if self.waker_id != 0 {
57 self.waker_id
58 } else {
59 self.span_id.map(|id| id as TaskId).unwrap_or(0)
61 }
62 }
63}
64
65thread_local! {
70 static CURRENT_TASK: Cell<TaskInfo> = const { Cell::new(TaskInfo {
71 waker_id: 0,
72 span_id: None,
73 created_at: 0,
74 }) };
75}
76
77#[inline(always)]
82pub fn generate_task_id(cx: &Context<'_>) -> AsyncResult<TaskId> {
83 let waker_addr = cx.waker() as *const _ as u64;
86
87 let epoch = TASK_EPOCH.fetch_add(1, Ordering::Relaxed);
89
90 let task_id = ((epoch as u128) << 64) | (waker_addr as u128);
92
93 if task_id == 0 {
95 return Err(AsyncError::task_tracking(
96 TaskOperation::IdGeneration,
97 "Generated zero task ID - invalid waker or epoch overflow",
98 None,
99 ));
100 }
101
102 Ok(task_id)
103}
104
105#[inline(always)]
110pub fn set_current_task(task_info: TaskInfo) {
111 CURRENT_TASK.with(|current| current.set(task_info));
112}
113
114#[inline(always)]
119pub fn get_current_task() -> TaskInfo {
120 CURRENT_TASK.with(|current| current.get())
121}
122
123#[inline(always)]
128pub fn update_span_id(span_id: Option<u64>) -> AsyncResult<()> {
129 CURRENT_TASK.with(|current| {
130 let mut info = current.get();
131 info.span_id = span_id;
132 current.set(info);
133 });
134 Ok(())
135}
136
137#[inline(always)]
142pub fn clear_current_task() {
143 CURRENT_TASK.with(|current| current.set(TaskInfo::default()));
144}
145
/// Returns a cheap timestamp used to stamp `TaskInfo::created_at`.
///
/// On x86_64 this is the raw time-stamp counter (CPU ticks), not wall-clock
/// time; values are therefore not comparable with those produced on other
/// architectures.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn current_timestamp() -> u64 {
    // SAFETY: RDTSC takes no pointers and has no memory-safety preconditions;
    // it only reads the processor's time-stamp counter.
    unsafe { std::arch::x86_64::_rdtsc() }
}

/// Returns a cheap timestamp used to stamp `TaskInfo::created_at`.
///
/// On non-x86_64 targets this is nanoseconds since the Unix epoch (truncated
/// to `u64`), or `0` if the system clock reads earlier than the epoch.
#[cfg(not(target_arch = "x86_64"))]
#[inline(always)]
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_nanos() as u64)
}

168pub fn current_epoch() -> u64 {
170 TASK_EPOCH.load(Ordering::Relaxed)
171}
172
/// Test-only helper: rewinds `TASK_EPOCH` to its initial value of 1.
///
/// The counter is process-global, so tests that call this (or that assert on
/// exact epoch values) must not run concurrently with one another.
#[cfg(test)]
pub fn reset_epoch() {
    AtomicU64::store(&TASK_EPOCH, 1, Ordering::Relaxed);
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex, MutexGuard};
    use std::task::{RawWaker, RawWakerVTable, Waker};
    use std::thread;

    /// Serializes the tests in this module that reset or make exact
    /// assertions about the process-global `TASK_EPOCH`.
    ///
    /// The libtest harness runs tests on parallel threads by default; without
    /// this lock a `reset_epoch()` or `generate_task_id()` in one test can land
    /// between two observations in another and make the exact-epoch assertions
    /// flaky. NOTE(review): this only covers tests in this module — if tests
    /// elsewhere in the crate also drive `TASK_EPOCH`, they must take part in
    /// the same serialization.
    static EPOCH_LOCK: Mutex<()> = Mutex::new(());

    /// Takes `EPOCH_LOCK`, recovering from poisoning so one failed test does
    /// not cascade into spurious failures of the other epoch tests.
    fn lock_epoch() -> MutexGuard<'static, ()> {
        EPOCH_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Builds a `Waker` whose vtable functions are all no-ops.
    fn create_test_waker() -> Waker {
        fn noop(_: *const ()) {}
        fn clone_waker(data: *const ()) -> RawWaker {
            RawWaker::new(data, &VTABLE)
        }

        const VTABLE: RawWakerVTable = RawWakerVTable::new(clone_waker, noop, noop, noop);

        // SAFETY: none of the vtable functions dereference the data pointer,
        // so a null data pointer is never read.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
    }

    #[test]
    fn test_task_id_generation() {
        let _epoch = lock_epoch();
        reset_epoch();

        let waker = create_test_waker();
        let cx = Context::from_waker(&waker);

        let id1 = generate_task_id(&cx).expect("Failed to generate task ID");
        let id2 = generate_task_id(&cx).expect("Failed to generate task ID");

        assert_ne!(id1, id2);
        assert_ne!(id1, 0);
        assert_ne!(id2, 0);

        // The high 64 bits carry the epoch, which advances by one per call.
        let epoch1 = (id1 >> 64) as u64;
        let epoch2 = (id2 >> 64) as u64;
        assert_eq!(epoch2, epoch1 + 1);
    }

    #[test]
    fn test_task_info_operations() {
        let info = TaskInfo::new(12345, Some(67890));

        assert!(info.has_tracking_id());
        assert_eq!(info.primary_id(), 12345);
        assert_ne!(info.created_at, 0);

        let info_no_waker = TaskInfo::new(0, Some(67890));
        assert!(info_no_waker.has_tracking_id());
        assert_eq!(info_no_waker.primary_id(), 67890);

        let info_empty = TaskInfo::default();
        assert!(!info_empty.has_tracking_id());
        assert_eq!(info_empty.primary_id(), 0);
    }

    #[test]
    fn test_thread_local_storage() {
        let info = TaskInfo::new(12345, Some(67890));

        assert!(!get_current_task().has_tracking_id());

        set_current_task(info);
        let retrieved = get_current_task();
        assert_eq!(retrieved.waker_id, 12345);
        assert_eq!(retrieved.span_id, Some(67890));

        update_span_id(Some(99999)).expect("Failed to update span ID");
        let updated = get_current_task();
        assert_eq!(updated.waker_id, 12345);
        assert_eq!(updated.span_id, Some(99999));

        clear_current_task();
        assert!(!get_current_task().has_tracking_id());
    }

    #[test]
    fn test_epoch_progression() {
        let _epoch = lock_epoch();
        reset_epoch();
        let initial_epoch = current_epoch();

        let waker = create_test_waker();
        let cx = Context::from_waker(&waker);

        let mut previous_epoch = initial_epoch;
        for _ in 0..5 {
            let _id = generate_task_id(&cx).expect("Failed to generate task ID");
            let current = current_epoch();
            assert!(
                current > previous_epoch,
                "Epoch should progress: {} -> {}",
                previous_epoch,
                current
            );
            previous_epoch = current;
        }

        assert_eq!(current_epoch(), initial_epoch + 5);
    }

    #[test]
    fn test_timestamp_generation() {
        let ts1 = current_timestamp();
        let ts2 = current_timestamp();

        assert_ne!(ts1, 0);
        assert_ne!(ts2, 0);
        assert!(ts2 >= ts1);
    }

    #[test]
    fn test_concurrent_task_id_generation() {
        let _epoch = lock_epoch();
        reset_epoch();

        let ids = Arc::new(Mutex::new(Vec::new()));
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let ids_clone = Arc::clone(&ids);
                thread::spawn(move || {
                    let waker = create_test_waker();
                    let cx = Context::from_waker(&waker);
                    let id = generate_task_id(&cx).expect("Failed to generate task ID");
                    ids_clone.lock().expect("Lock poisoned").push(id);
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        let ids = ids.lock().expect("Lock poisoned");

        // Every generated ID must be distinct and non-zero.
        let mut sorted_ids = ids.clone();
        sorted_ids.sort_unstable();
        sorted_ids.dedup();
        assert_eq!(sorted_ids.len(), ids.len());

        assert!(ids.iter().all(|&id| id != 0));
    }
}