memscope_rs/async_memory/
task_id.rs

//! Task identification and propagation for async memory tracking
//!
//! This module provides zero-overhead task identification using Context waker addresses
//! combined with global epoch counters to ensure absolute uniqueness across the
//! application lifetime.

use std::cell::Cell;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::Context;

use crate::async_memory::error::{AsyncError, AsyncResult, TaskOperation};

/// Global monotonic counter to ensure task ID uniqueness
///
/// Combined with waker addresses, this guarantees no ID collisions even if
/// waker memory is reused after task completion.
static TASK_EPOCH: AtomicU64 = AtomicU64::new(1);

/// Unique identifier for async tasks
///
/// Combines a global epoch counter (high 64 bits) with the waker's pointer address
/// (low 64 bits) to ensure absolute uniqueness across the application lifetime.
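///
/// The two halves can be recovered with plain shifts. A minimal sketch of the
/// layout (the value below is hypothetical, not a real generated ID):
///
/// ```ignore
/// let id: TaskId = (7u128 << 64) | 0x5555_aaaa; // epoch 7, waker address 0x5555_aaaa
/// let epoch = (id >> 64) as u64; // high 64 bits: global epoch counter
/// let waker_addr = id as u64;    // low 64 bits: waker address
/// assert_eq!((epoch, waker_addr), (7, 0x5555_aaaa));
/// ```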
pub type TaskId = u128;

/// Extended task information stored in thread-local storage
///
/// Supports a dual-track approach: precise TrackedFuture identification
/// plus tracing::Subscriber integration for broader ecosystem coverage.
#[derive(Clone, Copy, Debug, Default)]
pub struct TaskInfo {
    /// Primary task ID from TrackedFuture (high precision)
    pub waker_id: TaskId,
    /// Secondary span ID from tracing ecosystem (broader coverage)
    pub span_id: Option<u64>,
    /// Task creation timestamp for lifecycle analysis
    pub created_at: u64,
}

impl TaskInfo {
    /// Create new task info with current timestamp
    pub fn new(waker_id: TaskId, span_id: Option<u64>) -> Self {
        Self {
            waker_id,
            span_id,
            created_at: current_timestamp(),
        }
    }

    /// Check if any tracking ID is available
    pub fn has_tracking_id(&self) -> bool {
        self.waker_id != 0 || self.span_id.is_some()
    }

    /// Get the primary tracking ID (waker_id preferred, fallback to span_id)
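    ///
    /// A minimal sketch of the fallback behaviour (values are arbitrary):
    ///
    /// ```ignore
    /// // A waker_id of 0 means "no TrackedFuture ID"; fall back to the span ID.
    /// assert_eq!(TaskInfo::new(0, Some(42)).primary_id(), 42);
    /// // A non-zero waker_id always takes precedence over the span ID.
    /// assert_eq!(TaskInfo::new(7, Some(42)).primary_id(), 7);
    /// ```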
    pub fn primary_id(&self) -> TaskId {
        if self.waker_id != 0 {
            self.waker_id
        } else {
            // Convert span_id to TaskId format if waker_id unavailable
            self.span_id.map(|id| id as TaskId).unwrap_or(0)
        }
    }
}

// Thread-local storage for current task information
//
// Uses Cell for zero-overhead access from the allocator hook path: reads and
// writes are plain copies of a small Copy struct. This avoids the scoped
// future wrapper that tokio::task_local requires, which matters on this hot path.
thread_local! {
    static CURRENT_TASK: Cell<TaskInfo> = const { Cell::new(TaskInfo {
        waker_id: 0,
        span_id: None,
        created_at: 0,
    }) };
}

/// Generate unique task ID from Context waker
///
/// Uses the waker's pointer address combined with a global epoch counter
/// to ensure uniqueness even if waker memory is reused.
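///
/// A minimal sketch, assuming `cx` is the `Context` borrowed inside a `poll`
/// call and no other thread is generating IDs at the same time:
///
/// ```ignore
/// let epoch_before = current_epoch();
/// let id = generate_task_id(cx).expect("task ID generation failed");
/// // The high half of the ID is the epoch value observed at generation time.
/// assert_eq!((id >> 64) as u64, epoch_before);
/// ```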
#[inline(always)]
pub fn generate_task_id(cx: &Context<'_>) -> AsyncResult<TaskId> {
    // Use the waker's pointer address as an identifier (stable within the task's lifetime)
    let waker_addr = cx.waker() as *const _ as u64;

    // Get monotonic epoch counter
    let epoch = TASK_EPOCH.fetch_add(1, Ordering::Relaxed);

    // Combine epoch (high 64 bits) with waker address (low 64 bits)
    let task_id = ((epoch as u128) << 64) | (waker_addr as u128);

    // Validate non-zero result
    if task_id == 0 {
        return Err(AsyncError::task_tracking(
            TaskOperation::IdGeneration,
            "Generated zero task ID - invalid waker or epoch overflow",
            None,
        ));
    }

    Ok(task_id)
}

/// Set current task information in thread-local storage
///
/// Called by TrackedFuture during poll operations to establish task context
/// for allocation tracking.
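///
/// A minimal sketch of the intended call pattern around a poll; the wrapper
/// type is hypothetical, and a real wrapper would generate the ID once on the
/// first poll and cache it rather than minting a new epoch on every poll:
///
/// ```ignore
/// fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
///     let task_id = generate_task_id(cx).expect("task ID generation failed");
///     // Establish the task context so the allocator hook can attribute allocations.
///     set_current_task(TaskInfo::new(task_id, None));
///     // ... poll the wrapped future here ...
///     // Clear the context so later allocations on this thread are not misattributed.
///     clear_current_task();
///     Poll::Ready(())
/// }
/// ```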
#[inline(always)]
pub fn set_current_task(task_info: TaskInfo) {
    CURRENT_TASK.with(|current| current.set(task_info));
}

/// Get current task information from thread-local storage
///
/// Returns the task context for the currently executing async task.
/// Used by the global allocator hook to attribute memory allocations.
#[inline(always)]
pub fn get_current_task() -> TaskInfo {
    CURRENT_TASK.with(|current| current.get())
}

/// Update span ID for tracing integration
///
/// Called by MemScopeSubscriber when entering/exiting tracing spans
/// to provide fallback task identification.
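///
/// A minimal sketch of the intended call sites inside a tracing subscriber's
/// `enter`/`exit` hooks (the surrounding subscriber type is assumed, not shown):
///
/// ```ignore
/// fn enter(&self, span: &tracing::span::Id) {
///     // Record the active span as a fallback task identifier.
///     let _ = update_span_id(Some(span.into_u64()));
/// }
///
/// fn exit(&self, _span: &tracing::span::Id) {
///     let _ = update_span_id(None);
/// }
/// ```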
#[inline(always)]
pub fn update_span_id(span_id: Option<u64>) -> AsyncResult<()> {
    CURRENT_TASK.with(|current| {
        let mut info = current.get();
        info.span_id = span_id;
        current.set(info);
    });
    Ok(())
}

/// Clear current task context
///
/// Called when leaving task scope to prevent attribution of allocations
/// to completed tasks.
#[inline(always)]
pub fn clear_current_task() {
    CURRENT_TASK.with(|current| current.set(TaskInfo::default()));
}

/// Get current timestamp using efficient method
///
/// Uses TSC (Time Stamp Counter) on x86_64 for minimal overhead,
/// falls back to system time on other architectures. Note that the TSC value
/// is a raw cycle count rather than nanoseconds, so timestamps are only
/// comparable within a single run on the same machine.
#[inline(always)]
fn current_timestamp() -> u64 {
    #[cfg(target_arch = "x86_64")]
    {
        // Use hardware timestamp counter for minimal overhead
        unsafe { std::arch::x86_64::_rdtsc() }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        // Fallback to system time for other architectures
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0)
    }
}

/// Get current epoch counter value for diagnostics
pub fn current_epoch() -> u64 {
    TASK_EPOCH.load(Ordering::Relaxed)
}

/// Reset epoch counter (for testing only)
#[cfg(test)]
pub fn reset_epoch() {
    TASK_EPOCH.store(1, Ordering::Relaxed);
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::task::{RawWaker, RawWakerVTable, Waker};

    // Helper to create a dummy waker for testing
    fn create_test_waker() -> Waker {
        fn noop(_: *const ()) {}
        fn clone_waker(data: *const ()) -> RawWaker {
            RawWaker::new(data, &VTABLE)
        }

        const VTABLE: RawWakerVTable = RawWakerVTable::new(clone_waker, noop, noop, noop);

        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
    }

    #[test]
    fn test_task_id_generation() {
        reset_epoch();

        let waker = create_test_waker();
        let cx = Context::from_waker(&waker);

        let id1 = generate_task_id(&cx).expect("Failed to generate task ID");
        let id2 = generate_task_id(&cx).expect("Failed to generate task ID");

        // IDs should be different due to epoch increment
        assert_ne!(id1, id2);

        // IDs should be non-zero
        assert_ne!(id1, 0);
        assert_ne!(id2, 0);

        // High 64 bits should contain epoch values; other tests may increment
        // the shared counter concurrently, so require strict progression
        // rather than an exact +1 step.
        let epoch1 = (id1 >> 64) as u64;
        let epoch2 = (id2 >> 64) as u64;
        assert!(epoch2 > epoch1);
    }

    #[test]
    fn test_task_info_operations() {
        let info = TaskInfo::new(12345, Some(67890));

        assert!(info.has_tracking_id());
        assert_eq!(info.primary_id(), 12345);
        assert_ne!(info.created_at, 0);

        // Test fallback to span_id
        let info_no_waker = TaskInfo::new(0, Some(67890));
        assert!(info_no_waker.has_tracking_id());
        assert_eq!(info_no_waker.primary_id(), 67890);

        // Test no tracking
        let info_empty = TaskInfo::default();
        assert!(!info_empty.has_tracking_id());
        assert_eq!(info_empty.primary_id(), 0);
    }

    #[test]
    fn test_thread_local_storage() {
        let info = TaskInfo::new(12345, Some(67890));

        // Initially empty
        assert!(!get_current_task().has_tracking_id());

        // Set and verify
        set_current_task(info);
        let retrieved = get_current_task();
        assert_eq!(retrieved.waker_id, 12345);
        assert_eq!(retrieved.span_id, Some(67890));

        // Update span ID
        update_span_id(Some(99999)).expect("Failed to update span ID");
        let updated = get_current_task();
        assert_eq!(updated.waker_id, 12345); // Unchanged
        assert_eq!(updated.span_id, Some(99999)); // Updated

        // Clear
        clear_current_task();
        assert!(!get_current_task().has_tracking_id());
    }

    #[test]
    fn test_epoch_progression() {
        reset_epoch();
        let initial_epoch = current_epoch();

        let waker = create_test_waker();
        let cx = Context::from_waker(&waker);

        // Generate some task IDs and verify epoch progression
        let mut previous_epoch = initial_epoch;
        for _i in 0..5 {
            let _id = generate_task_id(&cx).expect("Failed to generate task ID");
            let current = current_epoch();
            // Each generate_task_id should increment the epoch
            assert!(
                current > previous_epoch,
                "Epoch should progress: {} -> {}",
                previous_epoch,
                current
            );
            previous_epoch = current;
        }

        // At least 5 increments should have happened; concurrent tests may
        // add more, so avoid asserting an exact count.
        assert!(current_epoch() >= initial_epoch + 5);
    }

    #[test]
    fn test_timestamp_generation() {
        let ts1 = current_timestamp();
        let ts2 = current_timestamp();

        // Timestamps should be non-zero and monotonic (or at least not decreasing)
        assert_ne!(ts1, 0);
        assert_ne!(ts2, 0);
        assert!(ts2 >= ts1);
    }

    #[test]
    fn test_concurrent_task_id_generation() {
        use std::sync::{Arc, Mutex};
        use std::thread;

        reset_epoch();
        let ids = Arc::new(Mutex::new(Vec::new()));
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let ids_clone = Arc::clone(&ids);
                thread::spawn(move || {
                    let waker = create_test_waker();
                    let cx = Context::from_waker(&waker);
                    let id = generate_task_id(&cx).expect("Failed to generate task ID");
                    ids_clone.lock().expect("Lock poisoned").push(id);
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        let ids = ids.lock().expect("Lock poisoned");

        // All IDs should be unique
        let mut sorted_ids = ids.clone();
        sorted_ids.sort();
        sorted_ids.dedup();
        assert_eq!(sorted_ids.len(), ids.len());

        // All IDs should be non-zero
        assert!(ids.iter().all(|&id| id != 0));
    }
}