Skip to main content

memscope_rs/capture/backends/
async_tracker.rs

1//! Async memory tracker implementation.
2//!
3//! This module contains async-specific memory tracking functionality
4//! including task tracking, efficiency scoring, and bottleneck analysis.
5
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::{Arc, Mutex};
10use std::thread::ThreadId;
11
12use super::async_types::{
13    AsyncAllocation, AsyncError, AsyncMemorySnapshot, AsyncResult, AsyncSnapshot, AsyncStats,
14    TrackedFuture,
15};
16use super::task_profile::{TaskMemoryProfile, TaskType};
17
/// Process-wide source of task identifiers.
///
/// Tokio recycles its own task IDs once a task has finished, so this module
/// keeps a separate counter whose values are handed out exactly once.
static TASK_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Hand out the next task identifier.
///
/// The first call returns `1`; each later call returns the previous value
/// plus one, so an identifier is never issued twice.
pub fn generate_unique_task_id() -> u64 {
    let issued = TASK_COUNTER.fetch_add(1, Ordering::Relaxed);
    issued
}
28
/// Process-wide source of thread identifiers.
static THREAD_COUNTER: AtomicU64 = AtomicU64::new(1);

thread_local! {
    // Drawn from THREAD_COUNTER the first time a thread reads it, then kept
    // for the lifetime of that thread.
    static THREAD_ID: u64 = THREAD_COUNTER.fetch_add(1, Ordering::Relaxed);
}

/// Return the identifier assigned to the calling thread.
///
/// The value is allocated lazily on a thread's first call and stays the same
/// for every later call made from that thread.
pub fn current_thread_id() -> u64 {
    THREAD_ID.with(u64::clone)
}
41
42/// Context for tracking memory allocations.
43/// Captures both thread and task information for accurate attribution.
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
45pub struct TrackerContext {
46    pub thread_id: u64,
47    pub task_id: Option<u64>,
48    pub tokio_task_id: Option<u64>,
49}
50
51impl TrackerContext {
52    /// Capture the current tracking context.
53    /// Returns thread ID and task ID (if in a task context).
54    pub fn capture() -> Self {
55        let task_id_from_context = TASK_CONTEXT.try_with(|ctx| *ctx).ok().flatten();
56        let tokio_task_id = tokio::task::try_id().and_then(|id| id.to_string().parse().ok());
57
58        Self {
59            thread_id: current_thread_id(),
60            task_id: task_id_from_context.or(CURRENT_TASK_ID.with(|cell| cell.get())),
61            tokio_task_id,
62        }
63    }
64}
65
/// Task efficiency report, as built by `AsyncTracker::analyze_task`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskReport {
    /// Task name copied from the task's memory profile.
    pub task_name: String,
    /// Task category the caller supplied to `analyze_task`.
    pub task_type: TaskType,
    /// Arithmetic mean of `cpu_efficiency`, `memory_efficiency` and `io_efficiency`.
    pub efficiency_score: f64,
    /// Score capped at 1.0; the formula used depends on `task_type`.
    pub cpu_efficiency: f64,
    /// Allocations per byte, scaled by 1000 and capped at 1.0.
    pub memory_efficiency: f64,
    /// Allocated bytes per millisecond divided by 1,048,576, capped at 1.0.
    pub io_efficiency: f64,
    /// One of "Execution Time", "Memory", "Allocations" or "None".
    pub bottleneck: String,
    /// Human-readable suggestions; always contains at least one entry.
    pub recommendations: Vec<String>,
}
78
/// Resource ranking entry, as built by `AsyncTracker::get_resource_rankings`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ResourceRanking {
    /// Task name copied from the task's memory profile.
    pub task_name: String,
    /// Task category stored in the task's memory profile.
    pub task_type: TaskType,
    /// Currently populated with the profile's allocation rate.
    pub cpu_usage: f64,
    /// Total bytes allocated by the task, divided by 1,048,576.
    pub memory_usage_mb: f64,
    /// Currently always 0.0.
    pub io_usage_mb: f64,
    /// Currently always 0.0.
    pub network_usage_mb: f64,
    /// Currently always 0.0.
    pub gpu_usage: f64,
    /// Weighted sum of memory, peak memory, allocation rate and duration;
    /// rankings are sorted by this value in descending order.
    pub overall_score: f64,
}
91
/// Global async tracker instance.
///
/// Holds `None` until `initialize()` or `register_global()` stores a tracker;
/// `shutdown()` resets it to `None`.
static GLOBAL_TRACKER: Mutex<Option<Arc<AsyncTracker>>> = Mutex::new(None);

thread_local! {
    // Task ID manually associated with the current thread through
    // `AsyncTracker::set_current_task`; `None` when nothing is set.
    static CURRENT_TASK_ID: std::cell::Cell<Option<u64>> = const { std::cell::Cell::new(None) };
}

tokio::task_local! {
    // Task ID scoped to a tokio task; `spawn_tracked` sets it for the duration
    // of the wrapped future.
    static TASK_CONTEXT: Option<u64>;
}
102
103/// RAII guard for automatic task cleanup.
104/// When dropped, it automatically clears the current task ID from thread-local storage.
105pub struct TaskGuard {
106    task_id: u64,
107    cleaned_up: bool,
108}
109
110// SAFETY: TaskGuard can be safely sent between threads because:
111// 1. It only contains primitive types (u64, bool) that are both Send and Sync.
112// 2. It does NOT hold any references, raw pointers, or non-thread-safe data.
113// 3. The thread-local state it manages (CURRENT_TASK_ID) is accessed via thread_local!(),
114//    which provides per-thread isolation - each thread has its own copy of the task ID.
115// 4. When TaskGuard is moved to another thread, it clears the task ID in the CURRENT
116//    thread's thread-local storage (not the original thread's), which is the expected
117//    behavior for task context management in async runtimes.
118// 5. The Drop implementation clears the task ID of whichever thread is currently
119//    executing, which is correct for RAII cleanup in async contexts where tasks
120//    may migrate between threads.
121unsafe impl Send for TaskGuard {}
122
123// SAFETY: TaskGuard can be safely shared between threads because:
124// 1. All fields (task_id, cleaned_up) are primitive types that are safe for concurrent access.
125// 2. TaskGuard does not provide mutable access to its fields through &self.
126// 3. The thread-local manipulation via clear_current_task_internal() operates on the
127//    CURRENT thread's storage, not a shared resource, so concurrent calls from different
128//    threads are inherently isolated.
129// 4. No interior mutability or shared mutable state is exposed.
130unsafe impl Sync for TaskGuard {}
131
132impl TaskGuard {
133    fn new(task_id: u64) -> Self {
134        Self {
135            task_id,
136            cleaned_up: false,
137        }
138    }
139
140    /// Get the task ID this guard is associated with.
141    pub fn task_id(&self) -> u64 {
142        self.task_id
143    }
144
145    /// Manually release the guard (prevents double cleanup).
146    pub fn release(mut self) {
147        self.cleaned_up = true;
148        TaskGuard::clear_current_task_internal();
149    }
150}
151
152impl Drop for TaskGuard {
153    fn drop(&mut self) {
154        if !self.cleaned_up {
155            TaskGuard::clear_current_task_internal();
156        }
157    }
158}
159
160impl TaskGuard {
161    fn clear_current_task_internal() {
162        CURRENT_TASK_ID.with(|cell| cell.set(None));
163    }
164}
165
/// Async memory tracker for task-aware memory tracking.
///
/// Each piece of state sits behind its own mutex, so methods lock them one at
/// a time.
pub struct AsyncTracker {
    // Live allocations keyed by pointer address.
    allocations: Arc<Mutex<HashMap<usize, AsyncAllocation>>>,
    // Aggregate counters (tasks, allocations, memory totals).
    stats: Arc<Mutex<AsyncStats>>,
    // Per-task memory profiles keyed by the tracker's task ID.
    profiles: Arc<Mutex<HashMap<u64, TaskMemoryProfile>>>,
    // Set to `true` by `set_initialized()`.
    initialized: Arc<Mutex<bool>>,
}
173
174impl AsyncTracker {
175    pub fn new() -> Self {
176        Self {
177            allocations: Arc::new(Mutex::new(HashMap::new())),
178            stats: Arc::new(Mutex::new(AsyncStats::default())),
179            profiles: Arc::new(Mutex::new(HashMap::new())),
180            initialized: Arc::new(Mutex::new(false)),
181        }
182    }
183
184    pub fn set_current_task(task_id: u64) {
185        CURRENT_TASK_ID.with(|cell| cell.set(Some(task_id)));
186    }
187
188    pub fn clear_current_task() {
189        CURRENT_TASK_ID.with(|cell| cell.set(None));
190    }
191
192    /// Get the current task ID from the thread-local storage.
193    ///
194    /// Note: This only returns the manually set task ID.
195    /// For automatic tokio task detection, use `track_in_tokio_task()`.
196    pub fn get_current_task() -> Option<u64> {
197        CURRENT_TASK_ID.with(|cell| cell.get())
198    }
199
200    /// Enter a task context with automatic cleanup.
201    /// Returns a TaskGuard that will clear the task ID when dropped.
202    pub fn enter_task(task_id: u64) -> TaskGuard {
203        Self::set_current_task(task_id);
204        TaskGuard::new(task_id)
205    }
206
207    /// Execute a closure within a task context with automatic cleanup.
208    /// The task ID is automatically cleared after the closure completes,
209    /// even if the closure panics.
210    pub fn with_task<F, T>(task_id: u64, f: F) -> T
211    where
212        F: FnOnce() -> T,
213    {
214        let _guard = Self::enter_task(task_id);
215        f()
216    }
217
218    pub fn track_task_start(
219        &self,
220        task_id: u64,
221        name: String,
222        _thread_id: ThreadId,
223    ) -> Result<(), AsyncError> {
224        self.track_task_start_internal(task_id, None, name)
225    }
226
227    pub fn track_task_start_with_tokio(
228        &self,
229        task_id: u64,
230        tokio_task_id: u64,
231        name: String,
232        _thread_id: ThreadId,
233    ) -> Result<(), AsyncError> {
234        self.track_task_start_internal(task_id, Some(tokio_task_id), name)
235    }
236
237    fn track_task_start_internal(
238        &self,
239        task_id: u64,
240        tokio_task_id: Option<u64>,
241        name: String,
242    ) -> Result<(), AsyncError> {
243        {
244            let mut profiles = self
245                .profiles
246                .lock()
247                .map_err(|e| AsyncError::mutex_lock_failed("profiles", &e.to_string()))?;
248
249            if profiles.contains_key(&task_id) {
250                return Err(AsyncError::duplicate_task(task_id));
251            }
252
253            let profile = match tokio_task_id {
254                Some(id) => {
255                    TaskMemoryProfile::with_tokio_id(task_id, id, name, TaskType::default())
256                }
257                None => TaskMemoryProfile::new(task_id, name, TaskType::default()),
258            };
259            profiles.insert(task_id, profile);
260        }
261
262        let mut stats = self
263            .stats
264            .lock()
265            .map_err(|e| AsyncError::mutex_lock_failed("stats", &e.to_string()))?;
266        stats.total_tasks += 1;
267        stats.active_tasks += 1;
268
269        Self::set_current_task(task_id);
270
271        Ok(())
272    }
273
274    /// Track a task end.
275    pub fn track_task_end(&self, task_id: u64) -> Result<(), AsyncError> {
276        {
277            let mut profiles = self
278                .profiles
279                .lock()
280                .map_err(|e| AsyncError::mutex_lock_failed("profiles", &e.to_string()))?;
281
282            let profile = profiles
283                .get_mut(&task_id)
284                .ok_or_else(|| AsyncError::task_not_found(task_id))?;
285
286            if profile.is_completed() {
287                return Ok(());
288            }
289
290            profile.mark_completed();
291        }
292
293        let mut stats = self
294            .stats
295            .lock()
296            .map_err(|e| AsyncError::mutex_lock_failed("stats", &e.to_string()))?;
297        stats.active_tasks = stats.active_tasks.saturating_sub(1);
298
299        Self::clear_current_task();
300
301        Ok(())
302    }
303
304    /// Execute an async block within a tokio task context.
305    ///
306    /// This method automatically detects the tokio task ID and sets up tracking.
307    /// When the future completes, the task is automatically marked as ended.
308    ///
309    /// # Arguments
310    ///
311    /// * `name` - Task name for identification
312    /// * `future` - The async block to execute
313    ///
314    /// # Returns
315    ///
316    /// A tuple of (unique_task_id, output).
317    /// The unique_task_id is our internal ID (not the tokio task ID).
318    pub async fn track_in_tokio_task<F, T>(&self, name: String, future: F) -> (u64, T)
319    where
320        F: Future<Output = T>,
321    {
322        let unique_task_id = generate_unique_task_id();
323        let tokio_task_id = tokio::task::try_id().and_then(|id| id.to_string().parse().ok());
324        let thread_id = std::thread::current().id();
325
326        if let Some(tokio_id) = tokio_task_id {
327            if let Err(e) =
328                self.track_task_start_with_tokio(unique_task_id, tokio_id, name.clone(), thread_id)
329            {
330                tracing::warn!("Failed to track task start: {e}");
331            }
332        } else if let Err(e) = self.track_task_start(unique_task_id, name.clone(), thread_id) {
333            tracing::warn!("Failed to track task start: {e}");
334        }
335
336        let output = future.await;
337
338        if let Err(e) = self.track_task_end(unique_task_id) {
339            tracing::warn!("Failed to track task end: {e}");
340        }
341
342        (unique_task_id, output)
343    }
344
345    /// Detect zombie tasks.
346    ///
347    /// A zombie task is a task that was started but never completed.
348    /// These tasks may indicate memory leaks or improper task cleanup.
349    ///
350    /// # Returns
351    ///
352    /// A vector of task IDs for zombie tasks.
353    pub fn detect_zombie_tasks(&self) -> Vec<u64> {
354        let profiles = self.profiles.lock().unwrap();
355        profiles
356            .iter()
357            .filter(|(_, p)| !p.is_completed())
358            .map(|(&id, _)| id)
359            .collect()
360    }
361
362    /// Get statistics about zombie tasks.
363    ///
364    /// Optimized to acquire the lock only once.
365    pub fn zombie_task_stats(&self) -> (usize, usize) {
366        let profiles = self.profiles.lock().unwrap();
367        let zombies = profiles.iter().filter(|(_, p)| !p.is_completed()).count();
368        let total = profiles.len();
369        (zombies, total)
370    }
371
372    pub fn track_allocation_auto(
373        &self,
374        ptr: usize,
375        size: usize,
376        var_name: Option<String>,
377        type_name: Option<String>,
378    ) {
379        if let Some(task_id) = Self::get_current_task() {
380            self.track_allocation_with_location(ptr, size, task_id, var_name, type_name, None);
381        }
382    }
383
384    /// Track an allocation associated with a task.
385    pub fn track_allocation(&self, ptr: usize, size: usize, task_id: u64) {
386        self.track_allocation_with_location(ptr, size, task_id, None, None, None);
387    }
388
389    /// Track an allocation with source location.
390    pub fn track_allocation_with_location(
391        &self,
392        ptr: usize,
393        size: usize,
394        task_id: u64,
395        var_name: Option<String>,
396        type_name: Option<String>,
397        source_location: Option<super::async_types::SourceLocation>,
398    ) {
399        let allocation = AsyncAllocation {
400            ptr,
401            size,
402            timestamp: Self::now(),
403            task_id,
404            var_name,
405            type_name,
406            source_location,
407        };
408
409        {
410            if let Ok(mut allocations) = self.allocations.lock() {
411                allocations.insert(ptr, allocation);
412            } else {
413                tracing::error!("Failed to acquire allocations lock during track_allocation");
414            }
415        }
416
417        {
418            if let Ok(mut profiles) = self.profiles.lock() {
419                if let Some(profile) = profiles.get_mut(&task_id) {
420                    profile.record_allocation(size as u64);
421                }
422            } else {
423                tracing::error!("Failed to acquire profiles lock during track_allocation");
424            }
425        }
426
427        {
428            if let Ok(mut stats) = self.stats.lock() {
429                stats.total_allocations += 1;
430                stats.total_memory += size;
431                stats.active_memory += size;
432                if stats.active_memory > stats.peak_memory {
433                    stats.peak_memory = stats.active_memory;
434                }
435            } else {
436                tracing::error!("Failed to acquire stats lock during track_allocation");
437            }
438        }
439    }
440
441    /// Track a deallocation associated with a task.
442    pub fn track_deallocation(&self, ptr: usize) {
443        let (task_id, size) = {
444            if let Ok(mut allocations) = self.allocations.lock() {
445                allocations
446                    .remove(&ptr)
447                    .map(|alloc| (alloc.task_id, alloc.size))
448                    .unwrap_or((0, 0))
449            } else {
450                tracing::error!("Failed to acquire allocations lock during track_deallocation");
451                (0, 0)
452            }
453        };
454
455        if task_id != 0 {
456            if let Ok(mut profiles) = self.profiles.lock() {
457                if let Some(profile) = profiles.get_mut(&task_id) {
458                    profile.record_deallocation(size as u64);
459                }
460            } else {
461                tracing::error!("Failed to acquire profiles lock during track_deallocation");
462            }
463        }
464
465        if size > 0 {
466            if let Ok(mut stats) = self.stats.lock() {
467                stats.active_memory = stats.active_memory.saturating_sub(size);
468                stats.total_deallocations += 1;
469                stats.total_deallocated += size as u64;
470            } else {
471                tracing::error!("Failed to acquire stats lock during track_deallocation");
472            }
473        }
474    }
475
476    /// Get current statistics.
477    pub fn get_stats(&self) -> AsyncStats {
478        if let Ok(stats) = self.stats.lock() {
479            stats.clone()
480        } else {
481            tracing::error!("Failed to acquire stats lock in get_stats");
482            AsyncStats::default()
483        }
484    }
485
486    /// Take a snapshot of current state.
487    pub fn snapshot(&self) -> AsyncSnapshot {
488        let profiles = if let Ok(p) = self.profiles.lock() {
489            p
490        } else {
491            tracing::error!("Failed to acquire profiles lock in snapshot");
492            return AsyncSnapshot::default();
493        };
494
495        let tasks: Vec<super::async_types::TaskInfo> = profiles
496            .values()
497            .filter(|p| p.completed_at_ms.is_none())
498            .map(|p| super::async_types::TaskInfo {
499                task_id: p.task_id,
500                name: p.task_name.clone(),
501                thread_id: std::thread::current().id(),
502                created_at: p.created_at_ms * 1_000_000,
503                active_allocations: p.total_allocations as usize,
504                total_memory: p.current_memory as usize,
505            })
506            .collect();
507        drop(profiles);
508
509        let allocations = {
510            if let Ok(allocs) = self.allocations.lock() {
511                allocs.values().cloned().collect()
512            } else {
513                tracing::error!("Failed to acquire allocations lock in snapshot");
514                Vec::new()
515            }
516        };
517
518        let stats = self.get_stats();
519
520        AsyncSnapshot {
521            timestamp: Self::now(),
522            tasks,
523            allocations,
524            stats,
525        }
526    }
527
528    /// Get task memory profile
529    pub fn get_task_profile(&self, task_id: u64) -> Option<TaskMemoryProfile> {
530        if let Ok(profiles) = self.profiles.lock() {
531            profiles.get(&task_id).cloned()
532        } else {
533            tracing::error!("Failed to acquire profiles lock in get_task_profile");
534            None
535        }
536    }
537
538    /// Get all task profiles
539    pub fn get_all_profiles(&self) -> Vec<TaskMemoryProfile> {
540        if let Ok(profiles) = self.profiles.lock() {
541            profiles.values().cloned().collect()
542        } else {
543            tracing::error!("Failed to acquire profiles lock in get_all_profiles");
544            Vec::new()
545        }
546    }
547
548    /// Check if tracker is initialized
549    pub fn is_initialized(&self) -> bool {
550        if let Ok(initialized) = self.initialized.lock() {
551            *initialized
552        } else {
553            tracing::error!("Failed to acquire initialized lock in is_initialized");
554            false
555        }
556    }
557
558    /// Mark tracker as initialized
559    pub fn set_initialized(&self) {
560        if let Ok(mut initialized) = self.initialized.lock() {
561            *initialized = true;
562        } else {
563            tracing::error!("Failed to acquire initialized lock in set_initialized");
564        }
565    }
566
567    /// Generate task efficiency report
568    pub fn analyze_task(&self, task_id: u64, task_type: TaskType) -> Option<TaskReport> {
569        let profile = self.get_task_profile(task_id)?;
570
571        let total_bytes = profile.total_bytes as f64;
572        let total_allocations = profile.total_allocations as f64;
573        let peak_memory = profile.peak_memory as f64;
574        let duration_ms = profile.duration_ns as f64 / 1_000_000.0;
575
576        let compute_efficiency = if duration_ms > 0.0 {
577            (total_allocations / duration_ms * 1000.0).min(1.0)
578        } else {
579            0.0
580        };
581
582        let cpu_efficiency = match task_type {
583            TaskType::CpuIntensive | TaskType::IoIntensive | TaskType::GpuCompute => {
584                compute_efficiency
585            }
586            TaskType::MemoryIntensive => {
587                if total_bytes > 0.0 {
588                    (peak_memory / total_bytes).min(1.0)
589                } else {
590                    0.0
591                }
592            }
593            TaskType::NetworkIntensive => {
594                if total_bytes > 0.0 {
595                    (total_allocations / total_bytes * 1000.0).min(1.0)
596                } else {
597                    0.0
598                }
599            }
600            _ => compute_efficiency,
601        };
602
603        let memory_efficiency = if total_bytes > 0.0 {
604            (total_allocations / total_bytes * 1000.0).min(1.0)
605        } else {
606            0.0
607        };
608
609        let io_efficiency = if duration_ms > 0.0 {
610            (total_bytes / duration_ms / 1_048_576.0).min(1.0)
611        } else {
612            0.0
613        };
614
615        let efficiency_score = (cpu_efficiency + memory_efficiency + io_efficiency) / 3.0;
616
617        let bottleneck = if duration_ms > 5000.0 {
618            "Execution Time".to_string()
619        } else if peak_memory > 100.0 * 1024.0 * 1024.0 {
620            "Memory".to_string()
621        } else if total_allocations > 10000.0 {
622            "Allocations".to_string()
623        } else {
624            "None".to_string()
625        };
626
627        let mut recommendations = Vec::new();
628        if duration_ms > 5000.0 {
629            recommendations.push("Consider optimizing task execution time".to_string());
630        }
631        if peak_memory > 100.0 * 1024.0 * 1024.0 {
632            recommendations.push("Reduce peak memory usage".to_string());
633        }
634        if total_allocations > 10000.0 {
635            recommendations.push("Reduce number of allocations".to_string());
636        }
637        if recommendations.is_empty() {
638            recommendations.push("Performance is good".to_string());
639        }
640
641        Some(TaskReport {
642            task_name: profile.task_name.clone(),
643            task_type,
644            efficiency_score,
645            cpu_efficiency,
646            memory_efficiency,
647            io_efficiency,
648            bottleneck,
649            recommendations,
650        })
651    }
652
653    /// Get resource rankings for all tasks
654    pub fn get_resource_rankings(&self) -> Vec<ResourceRanking> {
655        let profiles = self.get_all_profiles();
656
657        let mut rankings: Vec<ResourceRanking> = profiles
658            .into_iter()
659            .map(|profile| {
660                let memory_mb = profile.total_bytes as f64 / 1_048_576.0;
661                let peak_memory_mb = profile.peak_memory as f64 / 1_048_576.0;
662                let duration_ms = profile.duration_ns as f64 / 1_000_000.0;
663                let allocation_rate = profile.allocation_rate;
664
665                let overall_score = memory_mb * 0.3
666                    + peak_memory_mb * 0.2
667                    + allocation_rate * 0.0001
668                    + duration_ms * 0.0001;
669
670                ResourceRanking {
671                    task_name: profile.task_name.clone(),
672                    task_type: profile.task_type,
673                    cpu_usage: allocation_rate,
674                    memory_usage_mb: memory_mb,
675                    io_usage_mb: 0.0,
676                    network_usage_mb: 0.0,
677                    gpu_usage: 0.0,
678                    overall_score,
679                }
680            })
681            .collect();
682
683        rankings.sort_by(|a, b| {
684            b.overall_score
685                .partial_cmp(&a.overall_score)
686                .unwrap_or(std::cmp::Ordering::Equal)
687        });
688
689        rankings
690    }
691
692    /// Get current timestamp.
693    fn now() -> u64 {
694        std::time::SystemTime::now()
695            .duration_since(std::time::UNIX_EPOCH)
696            .unwrap_or_default()
697            .as_nanos() as u64
698    }
699}
700
701impl Default for AsyncTracker {
702    fn default() -> Self {
703        Self::new()
704    }
705}
706
707impl Drop for AsyncTracker {
708    fn drop(&mut self) {
709        Self::clear_current_task();
710    }
711}
712
713/// Initialize async memory tracking system
714pub fn initialize() -> AsyncResult<()> {
715    let mut global = GLOBAL_TRACKER.lock().map_err(|_| AsyncError::System {
716        operation: Arc::from("initialize"),
717        message: Arc::from("Failed to acquire global tracker lock"),
718    })?;
719
720    if global.is_none() {
721        let tracker = AsyncTracker::new();
722        tracker.set_initialized();
723        *global = Some(Arc::new(tracker));
724        tracing::info!("Async memory tracking system initialized");
725        Ok(())
726    } else {
727        Err(AsyncError::initialization(
728            "tracker",
729            "Already initialized",
730            true,
731        ))
732    }
733}
734
735/// Shutdown async memory tracking system
736pub fn shutdown() -> AsyncResult<()> {
737    let mut global = GLOBAL_TRACKER.lock().map_err(|_| AsyncError::System {
738        operation: Arc::from("shutdown"),
739        message: Arc::from("Failed to acquire global tracker lock"),
740    })?;
741
742    if global.is_some() {
743        *global = None;
744        tracing::info!("Async memory tracking system shutdown");
745        Ok(())
746    } else {
747        Err(AsyncError::initialization(
748            "tracker",
749            "Not initialized",
750            true,
751        ))
752    }
753}
754
/// Remove the global tracker, whatever its state (for testing only).
///
/// A failure to acquire the global tracker lock is logged and otherwise ignored.
#[cfg(test)]
pub fn reset_global_tracker() {
    match GLOBAL_TRACKER.lock() {
        Ok(mut slot) => {
            *slot = None;
        }
        Err(_) => {
            tracing::error!("Failed to acquire global tracker lock in reset_global_tracker");
        }
    }
}
764
765/// Register an existing AsyncTracker instance as the global singleton.
766/// This allows GlobalTracker to share its async_tracker instance with
767/// the async_tracker module, so `spawn_tracked()` can register task lifecycle
768/// events on the same instance that the renderer reads from.
769pub fn register_global(tracker: Arc<AsyncTracker>) -> AsyncResult<()> {
770    let mut global = GLOBAL_TRACKER.lock().map_err(|_| AsyncError::System {
771        operation: Arc::from("register_global"),
772        message: Arc::from("Failed to acquire global tracker lock"),
773    })?;
774    *global = Some(tracker);
775    Ok(())
776}
777
778/// Get the global async tracker
779fn get_global_tracker() -> AsyncResult<Arc<AsyncTracker>> {
780    GLOBAL_TRACKER
781        .lock()
782        .map_err(|_| AsyncError::System {
783            operation: Arc::from("get_global_tracker"),
784            message: Arc::from("Failed to acquire global tracker lock"),
785        })?
786        .clone()
787        .ok_or_else(|| {
788            AsyncError::initialization("tracker", "Tracking system not initialized", true)
789        })
790}
791
792/// Create a tracked future wrapper
793pub fn create_tracked<F>(future: F) -> TrackedFuture<F>
794where
795    F: Future,
796{
797    TrackedFuture::new(future)
798}
799
800/// Spawn a tracked async task with automatic context management.
801///
802/// This function wraps a future with memscope tracking context.
803/// The task ID is automatically generated and managed, and the
804/// context is automatically cleaned up when the task completes.
805///
806/// # Arguments
807///
808/// * `future` - The async block to execute
809///
810/// # Returns
811///
812/// A `tokio::task::JoinHandle` that resolves to the future's output
813///
814/// # Example
815///
816/// ```ignore
817/// let handle = spawn_tracked(async {
818///     let data = vec![1u8; 1024];
819///     tracker.track_as(&data, "buffer", file!(), line!());
820///     // ... async work
821/// });
822/// ```
823pub fn spawn_tracked<F>(future: F) -> tokio::task::JoinHandle<F::Output>
824where
825    F: Future + Send + 'static,
826    F::Output: Send + 'static,
827{
828    let task_id = generate_unique_task_id();
829
830    tokio::spawn(async move {
831        // Register task start with the global async tracker so profiles
832        // are populated for the dashboard renderer.
833        let tracker = get_global_tracker().ok();
834        let task_name = format!("spawned_task_{}", task_id);
835        if let Some(ref tracker) = tracker {
836            let thread_id = std::thread::current().id();
837            let _ = tracker.track_task_start(task_id, task_name.clone(), thread_id);
838        }
839
840        // Register with TaskIdRegistry so the Task Relationship Graph
841        // in the dashboard shows this task with parent-child hierarchy.
842        crate::task_registry::global_registry().register_explicit_task(task_id, &task_name);
843
844        // Set thread-local CURRENT_TASK_ID so that GlobalTracker::track_as()
845        // can associate allocations with this task via AsyncTracker::get_current_task().
846        AsyncTracker::set_current_task(task_id);
847
848        // Run the user's future inside the tokio task-local context.
849        let result = TASK_CONTEXT.scope(Some(task_id), future).await;
850
851        // Cleanup: clear thread-local task ID and signal task completion.
852        AsyncTracker::clear_current_task();
853        crate::task_registry::global_registry().unregister_task(task_id);
854        if let Some(ref tracker) = tracker {
855            let _ = tracker.track_task_end(task_id);
856        }
857
858        result
859    })
860}
861
862/// Get current memory usage snapshot
863pub fn get_memory_snapshot() -> AsyncMemorySnapshot {
864    if let Ok(tracker) = get_global_tracker() {
865        let stats = tracker.get_stats();
866
867        AsyncMemorySnapshot {
868            active_task_count: stats.active_tasks,
869            total_allocated_bytes: stats.total_memory as u64,
870            allocation_events: stats.total_allocations as u64,
871            events_dropped: 0,
872            buffer_utilization: 0.0,
873        }
874    } else {
875        AsyncMemorySnapshot {
876            active_task_count: 0,
877            total_allocated_bytes: 0,
878            allocation_events: 0,
879            events_dropped: 0,
880            buffer_utilization: 0.0,
881        }
882    }
883}
884
885/// Check if async memory tracking is currently active
886pub fn is_tracking_active() -> bool {
887    GLOBAL_TRACKER.lock().is_ok_and(|global| global.is_some())
888}
889
890/// Track allocation for current task
891pub fn track_current_allocation(ptr: usize, size: usize) -> AsyncResult<()> {
892    let tracker = get_global_tracker()?;
893    let task_info = super::async_types::get_current_task();
894
895    if task_info.has_tracking_id() {
896        tracker.track_allocation(ptr, size, (task_info.primary_id() & 0xFFFFFFFF) as u64);
897    }
898
899    Ok(())
900}
901
902/// Track deallocation for current task
903pub fn track_current_deallocation(ptr: usize) -> AsyncResult<()> {
904    let tracker = get_global_tracker()?;
905    tracker.track_deallocation(ptr);
906    Ok(())
907}
908
#[cfg(test)]
mod tests {
    use super::*;
    use crate::capture::backends::async_types::TaskOperation;

    /// Serializes every test in this module that touches the process-wide
    /// tracker.
    ///
    /// The default test harness runs tests on several threads at once, so
    /// tests that reset, initialize, or shut down the global tracker would
    /// otherwise interleave and observe each other's state.
    static GLOBAL_STATE_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Take the global-state lock for the duration of a test.
    ///
    /// Poisoning is ignored on purpose: one failing global-state test should
    /// not make every later one fail on lock acquisition.
    fn lock_global_state() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_STATE_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    // ---------------------------------------------------------------------
    // Per-instance tracker behaviour
    // ---------------------------------------------------------------------

    #[test]
    fn test_async_tracker_creation() {
        let tracker = AsyncTracker::new();
        let stats = tracker.get_stats();
        assert_eq!(stats.total_tasks, 0);
    }

    #[test]
    fn test_task_tracking() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();
        tracker
            .track_task_start(1, "test_task".to_string(), thread_id)
            .unwrap();

        let stats = tracker.get_stats();
        assert_eq!(stats.total_tasks, 1);
        assert_eq!(stats.active_tasks, 1);

        tracker.track_task_end(1).unwrap();
        let stats = tracker.get_stats();
        assert_eq!(stats.active_tasks, 0);
    }

    #[test]
    fn test_allocation_tracking() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();
        tracker
            .track_task_start(1, "test_task".to_string(), thread_id)
            .unwrap();
        tracker.track_allocation(0x1000, 1024, 1);

        let profile = tracker
            .get_task_profile(1)
            .expect("profile must exist for a started task");
        assert_eq!(profile.total_allocations, 1);
        assert_eq!(profile.total_bytes, 1024);
    }

    // ---------------------------------------------------------------------
    // Global tracker lifecycle (serialized via `lock_global_state`)
    // ---------------------------------------------------------------------

    #[test]
    fn test_initialization() {
        let _serial = lock_global_state();
        reset_global_tracker();

        let result = initialize();
        assert!(result.is_ok());

        // A second initialization may be rejected; if it is, the error must
        // say so.
        let result2 = initialize();
        if let Err(e) = result2 {
            assert!(e.message().contains("Already initialized"));
        }

        let _ = shutdown();
    }

    #[test]
    fn test_shutdown() {
        let _serial = lock_global_state();
        reset_global_tracker();

        initialize().unwrap();
        let result = shutdown();
        assert!(result.is_ok());

        // A second shutdown may be rejected; if it is, the error must say so.
        let result2 = shutdown();
        if let Err(e) = result2 {
            assert!(e.message().contains("Not initialized"));
        }
    }

    #[test]
    fn test_memory_snapshot() {
        let _serial = lock_global_state();
        reset_global_tracker();

        initialize().unwrap();
        let snapshot = get_memory_snapshot();
        assert_eq!(snapshot.active_task_count, 0);
        let _ = shutdown();
    }

    #[test]
    fn test_is_tracking_active() {
        let _serial = lock_global_state();
        reset_global_tracker();

        assert!(!is_tracking_active());
        initialize().unwrap();
        assert!(is_tracking_active());
        let _ = shutdown();
        assert!(!is_tracking_active());
    }

    #[test]
    fn test_global_tracker_integration() {
        let _serial = lock_global_state();
        reset_global_tracker();

        let result = initialize();
        assert!(result.is_ok());

        let tracker = get_global_tracker().expect("global tracker must be available");
        let stats = tracker.get_stats();
        assert_eq!(stats.total_tasks, 0);
        assert_eq!(stats.active_tasks, 0);

        let _ = shutdown();
    }

    // ---------------------------------------------------------------------
    // Task profiles and error reporting
    // ---------------------------------------------------------------------

    #[test]
    fn test_task_memory_profile() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();
        tracker
            .track_task_start(1, "test_task".to_string(), thread_id)
            .unwrap();
        tracker.track_allocation(0x1000, 1024, 1);
        tracker.track_allocation(0x2000, 2048, 1);
        tracker.track_task_end(1).unwrap();

        let profile = tracker
            .get_task_profile(1)
            .expect("profile must remain available after the task ends");
        assert_eq!(profile.task_id, 1);
        assert_eq!(profile.total_allocations, 2);
        assert_eq!(profile.total_bytes, 3072);
    }

    #[test]
    fn test_duplicate_task_tracking() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();

        // First registration should succeed
        let result = tracker.track_task_start(1, "test_task".to_string(), thread_id);
        assert!(result.is_ok());

        // Second registration should fail
        let result = tracker.track_task_start(1, "duplicate_task".to_string(), thread_id);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(
            error,
            AsyncError::TaskTracking {
                operation: TaskOperation::Duplicate,
                ..
            }
        ));
    }

    #[test]
    fn test_task_not_found() {
        let tracker = AsyncTracker::new();

        // Calling track_task_end with non-existent task should fail
        let result = tracker.track_task_end(999);
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(matches!(
            error,
            AsyncError::TaskTracking {
                operation: TaskOperation::TaskNotFound,
                ..
            }
        ));
    }

    // ---------------------------------------------------------------------
    // Current-task bookkeeping
    // ---------------------------------------------------------------------

    #[test]
    fn test_task_guard_cleanup() {
        assert!(AsyncTracker::get_current_task().is_none());

        {
            let _guard = AsyncTracker::enter_task(42);
            assert_eq!(AsyncTracker::get_current_task(), Some(42));
        }

        // Leaving the scope drops the guard, which must clear the current task.
        assert!(AsyncTracker::get_current_task().is_none());
    }

    #[test]
    fn test_with_task_closure() {
        assert!(AsyncTracker::get_current_task().is_none());

        let result = AsyncTracker::with_task(123, || {
            assert_eq!(AsyncTracker::get_current_task(), Some(123));
            "test_result"
        });

        assert_eq!(result, "test_result");
        assert!(AsyncTracker::get_current_task().is_none());
    }

    #[test]
    fn test_with_task_panic_cleanup() {
        assert!(AsyncTracker::get_current_task().is_none());

        let result = std::panic::catch_unwind(|| {
            AsyncTracker::with_task(999, || {
                assert_eq!(AsyncTracker::get_current_task(), Some(999));
                panic!("intentional panic");
            });
        });

        // The current task must be cleared even though the closure panicked.
        assert!(result.is_err());
        assert!(AsyncTracker::get_current_task().is_none());
    }

    #[test]
    fn test_generate_unique_task_id() {
        let id1 = generate_unique_task_id();
        let id2 = generate_unique_task_id();
        let id3 = generate_unique_task_id();

        assert!(id1 > 0);
        assert!(id2 > id1);
        assert!(id3 > id2);
    }

    // ---------------------------------------------------------------------
    // Tokio task IDs
    // ---------------------------------------------------------------------

    #[test]
    fn test_track_start_with_tokio() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();

        let result =
            tracker.track_task_start_with_tokio(1, 100, "tokio_task".to_string(), thread_id);
        assert!(result.is_ok());

        let profile = tracker
            .get_task_profile(1)
            .expect("profile must exist for a started task");
        assert_eq!(profile.task_id, 1);
        assert_eq!(profile.tokio_task_id, Some(100));
        assert_eq!(profile.task_name, "tokio_task");
    }

    #[test]
    fn test_track_task_internal_without_tokio() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();

        let result = tracker.track_task_start(2, "normal_task".to_string(), thread_id);
        assert!(result.is_ok());

        let profile = tracker
            .get_task_profile(2)
            .expect("profile must exist for a started task");
        assert_eq!(profile.task_id, 2);
        assert_eq!(profile.tokio_task_id, None);
    }

    // ---------------------------------------------------------------------
    // Zombie (never-completed) task detection
    // ---------------------------------------------------------------------

    #[test]
    fn test_detect_zombie_tasks() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();

        tracker
            .track_task_start(1, "task1".to_string(), thread_id)
            .unwrap();
        tracker
            .track_task_start(2, "task2".to_string(), thread_id)
            .unwrap();
        tracker
            .track_task_start(3, "task3".to_string(), thread_id)
            .unwrap();

        tracker.track_task_end(1).unwrap();

        let zombies = tracker.detect_zombie_tasks();
        assert_eq!(zombies.len(), 2);
        assert!(zombies.contains(&2));
        assert!(zombies.contains(&3));
    }

    #[test]
    fn test_zombie_task_stats() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();

        tracker
            .track_task_start(1, "task1".to_string(), thread_id)
            .unwrap();
        tracker
            .track_task_start(2, "task2".to_string(), thread_id)
            .unwrap();

        tracker.track_task_end(1).unwrap();

        let (zombie_count, total) = tracker.zombie_task_stats();
        assert_eq!(zombie_count, 1);
        assert_eq!(total, 2);
    }

    #[test]
    fn test_no_zombie_tasks_when_all_complete() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();

        tracker
            .track_task_start(1, "task1".to_string(), thread_id)
            .unwrap();
        tracker
            .track_task_start(2, "task2".to_string(), thread_id)
            .unwrap();

        tracker.track_task_end(1).unwrap();
        tracker.track_task_end(2).unwrap();

        let zombies = tracker.detect_zombie_tasks();
        assert!(zombies.is_empty());
    }

    #[test]
    fn test_task_memory_profile_with_tokio_id() {
        let profile = TaskMemoryProfile::with_tokio_id(1, 999, "test".to_string(), TaskType::Mixed);

        assert_eq!(profile.task_id, 1);
        assert_eq!(profile.tokio_task_id, Some(999));
        assert_eq!(profile.task_name, "test");
        assert_eq!(profile.task_type, TaskType::Mixed);
        assert!(!profile.is_completed());
    }

    // ---------------------------------------------------------------------
    // Tracking futures inside a tokio runtime
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_track_in_tokio_task_basic() {
        let tracker = AsyncTracker::new();

        let (task_id, result) = tracker
            .track_in_tokio_task("async_task".to_string(), async {
                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
                42
            })
            .await;

        assert!(task_id > 0);
        assert_eq!(result, 42);

        let profile = tracker
            .get_task_profile(task_id)
            .expect("profile must exist for a tracked tokio task");
        assert_eq!(profile.task_name, "async_task");
        assert!(profile.is_completed());
    }

    #[tokio::test]
    async fn test_track_in_tokio_task_basic_functionality() {
        let tracker = AsyncTracker::new();

        let (task_id, result) = tracker
            .track_in_tokio_task("test_task".to_string(), async {
                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
                "completed"
            })
            .await;

        assert!(task_id > 0);
        assert_eq!(result, "completed");

        let profile = tracker
            .get_task_profile(task_id)
            .expect("profile must exist for a tracked tokio task");
        assert_eq!(profile.task_name, "test_task");
        assert!(profile.is_completed());
    }
}