// File: memscope_rs/capture/backends/async_tracker.rs

1//! Async memory tracker implementation.
2//!
3//! This module contains async-specific memory tracking functionality
4//! including task tracking, efficiency scoring, and bottleneck analysis.
5
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::{Arc, Mutex};
10use std::thread::ThreadId;
11
12use super::async_types::{
13    AsyncAllocation, AsyncError, AsyncMemorySnapshot, AsyncResult, AsyncSnapshot, AsyncStats,
14    TrackedFuture,
15};
16use super::task_profile::{TaskMemoryProfile, TaskType};
17
/// Global task ID counter for unique task identification.
/// Tokio task IDs are recycled after task completion, so we need
/// our own counter to ensure unique identification across all tasks.
static TASK_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Generate a new unique task ID.
/// This ID is never recycled, ensuring unique identification.
///
/// IDs start at 1 and increase monotonically. `Relaxed` ordering is
/// sufficient: only the atomicity of the increment matters, not its
/// ordering relative to other memory operations.
pub fn generate_unique_task_id() -> u64 {
    TASK_COUNTER.fetch_add(1, Ordering::Relaxed)
}
28
/// Global thread ID counter for unique thread identification.
static THREAD_COUNTER: AtomicU64 = AtomicU64::new(1);

thread_local! {
    // Lazily assigned on the thread's first access and stable for the
    // thread's lifetime; values start at 1.
    static THREAD_ID: u64 = THREAD_COUNTER.fetch_add(1, Ordering::Relaxed);
}
35
/// Get the current thread's unique ID.
/// This ID is assigned once per thread and never changes.
///
/// Unlike `std::thread::ThreadId`, this is a plain `u64` that is cheap to
/// store and compare inside allocation records.
pub fn current_thread_id() -> u64 {
    THREAD_ID.with(|id| *id)
}
41
/// Context for tracking memory allocations.
/// Captures both thread and task information for accurate attribution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TrackerContext {
    /// Process-unique thread ID from `current_thread_id()`.
    pub thread_id: u64,
    /// Internal (never-recycled) task ID, if a task context is active.
    pub task_id: Option<u64>,
    /// The runtime's tokio task ID, if currently inside a tokio task.
    /// Tokio recycles these after task completion (see `TASK_COUNTER`
    /// above), so prefer `task_id` for long-lived attribution.
    pub tokio_task_id: Option<u64>,
}
50
51impl TrackerContext {
52    /// Capture the current tracking context.
53    /// Returns thread ID and task ID (if in a task context).
54    pub fn capture() -> Self {
55        let task_id_from_context = TASK_CONTEXT.try_with(|ctx| *ctx).ok().flatten();
56        let tokio_task_id = tokio::task::try_id().and_then(|id| id.to_string().parse().ok());
57
58        Self {
59            thread_id: current_thread_id(),
60            task_id: task_id_from_context.or(CURRENT_TASK_ID.with(|cell| cell.get())),
61            tokio_task_id,
62        }
63    }
64}
65
/// Task efficiency report
///
/// Produced by [`AsyncTracker::analyze_task`]; all efficiency components
/// are heuristics clamped to the `[0.0, 1.0]` range.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskReport {
    /// Human-readable task name copied from the task's profile.
    pub task_name: String,
    /// Task classification the scores were computed for.
    pub task_type: TaskType,
    /// Mean of the three efficiency components below.
    pub efficiency_score: f64,
    pub cpu_efficiency: f64,
    pub memory_efficiency: f64,
    pub io_efficiency: f64,
    /// Dominant bottleneck label: "Execution Time", "Memory",
    /// "Allocations", or "None".
    pub bottleneck: String,
    /// Suggested remediations; never empty.
    pub recommendations: Vec<String>,
}
78
/// Resource ranking entry
///
/// One row of [`AsyncTracker::get_resource_rankings`], sorted descending
/// by `overall_score`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ResourceRanking {
    pub task_name: String,
    pub task_type: TaskType,
    /// Currently filled with the task's allocation rate (CPU proxy).
    pub cpu_usage: f64,
    pub memory_usage_mb: f64,
    /// Always 0.0 for now — I/O usage is not yet measured.
    pub io_usage_mb: f64,
    /// Always 0.0 for now — network usage is not yet measured.
    pub network_usage_mb: f64,
    /// Always 0.0 for now — GPU usage is not yet measured.
    pub gpu_usage: f64,
    /// Weighted blend of the fields above, used for sorting.
    pub overall_score: f64,
}
91
/// Global async tracker instance
///
/// `None` until [`initialize`] runs; reset to `None` by [`shutdown`].
static GLOBAL_TRACKER: Mutex<Option<Arc<AsyncTracker>>> = Mutex::new(None);

thread_local! {
    // Manually-managed current task ID for this thread; set/cleared via
    // `AsyncTracker::set_current_task`/`clear_current_task` and `TaskGuard`.
    static CURRENT_TASK_ID: std::cell::Cell<Option<u64>> = const { std::cell::Cell::new(None) };
}

tokio::task_local! {
    // Task-scoped ID installed by `spawn_tracked`; unlike CURRENT_TASK_ID
    // it follows the task when it migrates between worker threads.
    static TASK_CONTEXT: Option<u64>;
}
102
/// RAII guard for automatic task cleanup.
/// When dropped, it automatically clears the current task ID from thread-local storage.
pub struct TaskGuard {
    // The task this guard was created for (informational; cleanup clears
    // whatever ID is current on the dropping thread).
    task_id: u64,
    // Set by `release()` so Drop does not clear a second time.
    cleaned_up: bool,
}
109
// NOTE(review): both unsafe impls below are redundant — TaskGuard contains
// only `u64` and `bool`, so the compiler auto-derives Send + Sync for it.
// They are kept for their documentation value, but could be removed
// without changing the public API.
// SAFETY: TaskGuard can be safely sent between threads because:
// 1. It only contains primitive types (u64, bool) that are both Send and Sync.
// 2. It does NOT hold any references, raw pointers, or non-thread-safe data.
// 3. The thread-local state it manages (CURRENT_TASK_ID) is accessed via thread_local!(),
//    which provides per-thread isolation - each thread has its own copy of the task ID.
// 4. When TaskGuard is moved to another thread, it clears the task ID in the CURRENT
//    thread's thread-local storage (not the original thread's), which is the expected
//    behavior for task context management in async runtimes.
// 5. The Drop implementation clears the task ID of whichever thread is currently
//    executing, which is correct for RAII cleanup in async contexts where tasks
//    may migrate between threads.
unsafe impl Send for TaskGuard {}

// SAFETY: TaskGuard can be safely shared between threads because:
// 1. All fields (task_id, cleaned_up) are primitive types that are safe for concurrent access.
// 2. TaskGuard does not provide mutable access to its fields through &self.
// 3. The thread-local manipulation via clear_current_task_internal() operates on the
//    CURRENT thread's storage, not a shared resource, so concurrent calls from different
//    threads are inherently isolated.
// 4. No interior mutability or shared mutable state is exposed.
unsafe impl Sync for TaskGuard {}
131
132impl TaskGuard {
133    fn new(task_id: u64) -> Self {
134        Self {
135            task_id,
136            cleaned_up: false,
137        }
138    }
139
140    /// Get the task ID this guard is associated with.
141    pub fn task_id(&self) -> u64 {
142        self.task_id
143    }
144
145    /// Manually release the guard (prevents double cleanup).
146    pub fn release(mut self) {
147        self.cleaned_up = true;
148        TaskGuard::clear_current_task_internal();
149    }
150}
151
152impl Drop for TaskGuard {
153    fn drop(&mut self) {
154        if !self.cleaned_up {
155            TaskGuard::clear_current_task_internal();
156        }
157    }
158}
159
impl TaskGuard {
    // Clears the CURRENT thread's task ID — not necessarily the thread
    // that created the guard (see the Send safety notes above).
    fn clear_current_task_internal() {
        CURRENT_TASK_ID.with(|cell| cell.set(None));
    }
}
165
/// Async memory tracker for task-aware memory tracking.
pub struct AsyncTracker {
    /// Live allocations keyed by pointer address.
    allocations: Arc<Mutex<HashMap<usize, AsyncAllocation>>>,
    /// Aggregate counters across all tasks.
    stats: Arc<Mutex<AsyncStats>>,
    /// Per-task memory profiles keyed by internal (never-recycled) task ID.
    profiles: Arc<Mutex<HashMap<u64, TaskMemoryProfile>>>,
    /// Whether `set_initialized` has been called.
    /// NOTE(review): an `AtomicBool` would avoid this lock entirely.
    initialized: Arc<Mutex<bool>>,
}
173
174impl AsyncTracker {
175    pub fn new() -> Self {
176        Self {
177            allocations: Arc::new(Mutex::new(HashMap::new())),
178            stats: Arc::new(Mutex::new(AsyncStats::default())),
179            profiles: Arc::new(Mutex::new(HashMap::new())),
180            initialized: Arc::new(Mutex::new(false)),
181        }
182    }
183
    /// Set the current task ID for this thread.
    ///
    /// Prefer `enter_task`/`with_task`, which clear the ID automatically.
    pub fn set_current_task(task_id: u64) {
        CURRENT_TASK_ID.with(|cell| cell.set(Some(task_id)));
    }

    /// Clear the current task ID for this thread.
    pub fn clear_current_task() {
        CURRENT_TASK_ID.with(|cell| cell.set(None));
    }

    /// Get the current task ID from the thread-local storage.
    ///
    /// Note: This only returns the manually set task ID.
    /// For automatic tokio task detection, use `track_in_tokio_task()`.
    pub fn get_current_task() -> Option<u64> {
        CURRENT_TASK_ID.with(|cell| cell.get())
    }
199
    /// Enter a task context with automatic cleanup.
    /// Returns a TaskGuard that will clear the task ID when dropped.
    pub fn enter_task(task_id: u64) -> TaskGuard {
        Self::set_current_task(task_id);
        TaskGuard::new(task_id)
    }

    /// Execute a closure within a task context with automatic cleanup.
    /// The task ID is automatically cleared after the closure completes,
    /// even if the closure panics.
    pub fn with_task<F, T>(task_id: u64, f: F) -> T
    where
        F: FnOnce() -> T,
    {
        // The guard's Drop also runs during unwinding, so the thread-local
        // task ID cannot leak past this call on panic.
        let _guard = Self::enter_task(task_id);
        f()
    }
217
    /// Register a new task without an associated tokio task ID.
    ///
    /// # Errors
    ///
    /// Fails if `task_id` is already registered or a lock is poisoned.
    pub fn track_task_start(
        &self,
        task_id: u64,
        name: String,
        _thread_id: ThreadId, // currently unused; kept for API stability
    ) -> Result<(), AsyncError> {
        self.track_task_start_internal(task_id, None, name)
    }

    /// Register a new task together with its tokio task ID.
    ///
    /// # Errors
    ///
    /// Same failure modes as [`Self::track_task_start`].
    pub fn track_task_start_with_tokio(
        &self,
        task_id: u64,
        tokio_task_id: u64,
        name: String,
        _thread_id: ThreadId, // currently unused; kept for API stability
    ) -> Result<(), AsyncError> {
        self.track_task_start_internal(task_id, Some(tokio_task_id), name)
    }
236
    /// Shared implementation for the two `track_task_start*` entry points.
    ///
    /// Creates the task's memory profile, bumps the task counters, and
    /// makes `task_id` the current task for this thread.
    fn track_task_start_internal(
        &self,
        task_id: u64,
        tokio_task_id: Option<u64>,
        name: String,
    ) -> Result<(), AsyncError> {
        // Scope the profiles lock so it is released before the stats lock
        // is taken — the two locks are never held simultaneously.
        {
            let mut profiles = self
                .profiles
                .lock()
                .map_err(|e| AsyncError::mutex_lock_failed("profiles", &e.to_string()))?;

            if profiles.contains_key(&task_id) {
                return Err(AsyncError::duplicate_task(task_id));
            }

            let profile = match tokio_task_id {
                Some(id) => {
                    TaskMemoryProfile::with_tokio_id(task_id, id, name, TaskType::default())
                }
                None => TaskMemoryProfile::new(task_id, name, TaskType::default()),
            };
            profiles.insert(task_id, profile);
        }

        // NOTE(review): if this lock fails, the profile above remains
        // registered while the counters were never bumped — confirm that
        // partial state is acceptable.
        let mut stats = self
            .stats
            .lock()
            .map_err(|e| AsyncError::mutex_lock_failed("stats", &e.to_string()))?;
        stats.total_tasks += 1;
        stats.active_tasks += 1;

        Self::set_current_task(task_id);

        Ok(())
    }
273
    /// Track a task end.
    ///
    /// Marks the task's profile completed (idempotent: a second call for a
    /// completed task is an `Ok` no-op), decrements the active-task count,
    /// and clears this thread's current task ID.
    ///
    /// # Errors
    ///
    /// Fails if the task was never registered or a lock is poisoned.
    pub fn track_task_end(&self, task_id: u64) -> Result<(), AsyncError> {
        // Scope the profiles lock so it is released before the stats lock.
        {
            let mut profiles = self
                .profiles
                .lock()
                .map_err(|e| AsyncError::mutex_lock_failed("profiles", &e.to_string()))?;

            let profile = profiles
                .get_mut(&task_id)
                .ok_or_else(|| AsyncError::task_not_found(task_id))?;

            // Already completed: skip the counters so active_tasks is not
            // decremented twice for the same task.
            if profile.is_completed() {
                return Ok(());
            }

            profile.mark_completed();
        }

        let mut stats = self
            .stats
            .lock()
            .map_err(|e| AsyncError::mutex_lock_failed("stats", &e.to_string()))?;
        stats.active_tasks = stats.active_tasks.saturating_sub(1);

        // NOTE(review): this clears the thread's current task ID even when a
        // different task is current on this thread — confirm intended.
        Self::clear_current_task();

        Ok(())
    }
303
304    /// Execute an async block within a tokio task context.
305    ///
306    /// This method automatically detects the tokio task ID and sets up tracking.
307    /// When the future completes, the task is automatically marked as ended.
308    ///
309    /// # Arguments
310    ///
311    /// * `name` - Task name for identification
312    /// * `future` - The async block to execute
313    ///
314    /// # Returns
315    ///
316    /// A tuple of (unique_task_id, output).
317    /// The unique_task_id is our internal ID (not the tokio task ID).
318    pub async fn track_in_tokio_task<F, T>(&self, name: String, future: F) -> (u64, T)
319    where
320        F: Future<Output = T>,
321    {
322        let unique_task_id = generate_unique_task_id();
323        let tokio_task_id = tokio::task::try_id().and_then(|id| id.to_string().parse().ok());
324        let thread_id = std::thread::current().id();
325
326        if let Some(tokio_id) = tokio_task_id {
327            if let Err(e) =
328                self.track_task_start_with_tokio(unique_task_id, tokio_id, name.clone(), thread_id)
329            {
330                tracing::warn!("Failed to track task start: {e}");
331            }
332        } else {
333            if let Err(e) = self.track_task_start(unique_task_id, name.clone(), thread_id) {
334                tracing::warn!("Failed to track task start: {e}");
335            }
336        }
337
338        let output = future.await;
339
340        if let Err(e) = self.track_task_end(unique_task_id) {
341            tracing::warn!("Failed to track task end: {e}");
342        }
343
344        (unique_task_id, output)
345    }
346
347    /// Detect zombie tasks.
348    ///
349    /// A zombie task is a task that was started but never completed.
350    /// These tasks may indicate memory leaks or improper task cleanup.
351    ///
352    /// # Returns
353    ///
354    /// A vector of task IDs for zombie tasks.
355    pub fn detect_zombie_tasks(&self) -> Vec<u64> {
356        let profiles = self.profiles.lock().unwrap();
357        profiles
358            .iter()
359            .filter(|(_, p)| !p.is_completed())
360            .map(|(&id, _)| id)
361            .collect()
362    }
363
364    /// Get statistics about zombie tasks.
365    ///
366    /// Optimized to acquire the lock only once.
367    pub fn zombie_task_stats(&self) -> (usize, usize) {
368        let profiles = self.profiles.lock().unwrap();
369        let zombies = profiles.iter().filter(|(_, p)| !p.is_completed()).count();
370        let total = profiles.len();
371        (zombies, total)
372    }
373
374    pub fn track_allocation_auto(
375        &self,
376        ptr: usize,
377        size: usize,
378        var_name: Option<String>,
379        type_name: Option<String>,
380    ) {
381        if let Some(task_id) = Self::get_current_task() {
382            self.track_allocation_with_location(ptr, size, task_id, var_name, type_name, None);
383        }
384    }
385
    /// Track an allocation associated with a task.
    ///
    /// Convenience wrapper over `track_allocation_with_location` with no
    /// variable name, type name, or source location.
    pub fn track_allocation(&self, ptr: usize, size: usize, task_id: u64) {
        self.track_allocation_with_location(ptr, size, task_id, None, None, None);
    }
390
391    /// Track an allocation with source location.
392    pub fn track_allocation_with_location(
393        &self,
394        ptr: usize,
395        size: usize,
396        task_id: u64,
397        var_name: Option<String>,
398        type_name: Option<String>,
399        source_location: Option<super::async_types::SourceLocation>,
400    ) {
401        let allocation = AsyncAllocation {
402            ptr,
403            size,
404            timestamp: Self::now(),
405            task_id,
406            var_name,
407            type_name,
408            source_location,
409        };
410
411        {
412            if let Ok(mut allocations) = self.allocations.lock() {
413                allocations.insert(ptr, allocation);
414            } else {
415                tracing::error!("Failed to acquire allocations lock during track_allocation");
416            }
417        }
418
419        {
420            if let Ok(mut profiles) = self.profiles.lock() {
421                if let Some(profile) = profiles.get_mut(&task_id) {
422                    profile.record_allocation(size as u64);
423                }
424            } else {
425                tracing::error!("Failed to acquire profiles lock during track_allocation");
426            }
427        }
428
429        {
430            if let Ok(mut stats) = self.stats.lock() {
431                stats.total_allocations += 1;
432                stats.total_memory += size;
433                stats.active_memory += size;
434                if stats.active_memory > stats.peak_memory {
435                    stats.peak_memory = stats.active_memory;
436                }
437            } else {
438                tracing::error!("Failed to acquire stats lock during track_allocation");
439            }
440        }
441    }
442
443    /// Track a deallocation associated with a task.
444    pub fn track_deallocation(&self, ptr: usize) {
445        let (task_id, size) = {
446            if let Ok(mut allocations) = self.allocations.lock() {
447                allocations
448                    .remove(&ptr)
449                    .map(|alloc| (alloc.task_id, alloc.size))
450                    .unwrap_or((0, 0))
451            } else {
452                tracing::error!("Failed to acquire allocations lock during track_deallocation");
453                (0, 0)
454            }
455        };
456
457        if task_id != 0 {
458            if let Ok(mut profiles) = self.profiles.lock() {
459                if let Some(profile) = profiles.get_mut(&task_id) {
460                    profile.record_deallocation(size as u64);
461                }
462            } else {
463                tracing::error!("Failed to acquire profiles lock during track_deallocation");
464            }
465        }
466
467        if size > 0 {
468            if let Ok(mut stats) = self.stats.lock() {
469                stats.active_memory = stats.active_memory.saturating_sub(size);
470                stats.total_deallocations += 1;
471                stats.total_deallocated += size as u64;
472            } else {
473                tracing::error!("Failed to acquire stats lock during track_deallocation");
474            }
475        }
476    }
477
478    /// Get current statistics.
479    pub fn get_stats(&self) -> AsyncStats {
480        if let Ok(stats) = self.stats.lock() {
481            stats.clone()
482        } else {
483            tracing::error!("Failed to acquire stats lock in get_stats");
484            AsyncStats::default()
485        }
486    }
487
    /// Take a snapshot of current state.
    ///
    /// Captures still-running tasks, all live allocations, and the
    /// aggregate stats. Each store is locked independently, so the three
    /// sections are individually consistent but not atomic with respect to
    /// each other. Returns a default (empty) snapshot if the profiles lock
    /// is poisoned.
    pub fn snapshot(&self) -> AsyncSnapshot {
        let profiles = if let Ok(p) = self.profiles.lock() {
            p
        } else {
            tracing::error!("Failed to acquire profiles lock in snapshot");
            return AsyncSnapshot::default();
        };

        let tasks: Vec<super::async_types::TaskInfo> = profiles
            .values()
            // Only tasks that have not completed yet.
            .filter(|p| p.completed_at_ms.is_none())
            .map(|p| super::async_types::TaskInfo {
                task_id: p.task_id,
                name: p.task_name.clone(),
                // NOTE(review): this records the snapshotting thread, not the
                // thread the task ran on — confirm that is intended.
                thread_id: std::thread::current().id(),
                // Milliseconds -> nanoseconds.
                created_at: p.created_at_ms * 1_000_000,
                active_allocations: p.total_allocations as usize,
                total_memory: p.current_memory as usize,
            })
            .collect();
        // Release the profiles lock before taking the allocations lock.
        drop(profiles);

        let allocations = {
            if let Ok(allocs) = self.allocations.lock() {
                allocs.values().cloned().collect()
            } else {
                tracing::error!("Failed to acquire allocations lock in snapshot");
                Vec::new()
            }
        };

        let stats = self.get_stats();

        AsyncSnapshot {
            timestamp: Self::now(),
            tasks,
            allocations,
            stats,
        }
    }
529
530    /// Get task memory profile
531    pub fn get_task_profile(&self, task_id: u64) -> Option<TaskMemoryProfile> {
532        if let Ok(profiles) = self.profiles.lock() {
533            profiles.get(&task_id).cloned()
534        } else {
535            tracing::error!("Failed to acquire profiles lock in get_task_profile");
536            None
537        }
538    }
539
540    /// Get all task profiles
541    pub fn get_all_profiles(&self) -> Vec<TaskMemoryProfile> {
542        if let Ok(profiles) = self.profiles.lock() {
543            profiles.values().cloned().collect()
544        } else {
545            tracing::error!("Failed to acquire profiles lock in get_all_profiles");
546            Vec::new()
547        }
548    }
549
550    /// Check if tracker is initialized
551    pub fn is_initialized(&self) -> bool {
552        if let Ok(initialized) = self.initialized.lock() {
553            *initialized
554        } else {
555            tracing::error!("Failed to acquire initialized lock in is_initialized");
556            false
557        }
558    }
559
560    /// Mark tracker as initialized
561    pub fn set_initialized(&self) {
562        if let Ok(mut initialized) = self.initialized.lock() {
563            *initialized = true;
564        } else {
565            tracing::error!("Failed to acquire initialized lock in set_initialized");
566        }
567    }
568
569    /// Generate task efficiency report
570    pub fn analyze_task(&self, task_id: u64, task_type: TaskType) -> Option<TaskReport> {
571        let profile = self.get_task_profile(task_id)?;
572
573        let total_bytes = profile.total_bytes as f64;
574        let total_allocations = profile.total_allocations as f64;
575        let peak_memory = profile.peak_memory as f64;
576        let duration_ms = profile.duration_ns as f64 / 1_000_000.0;
577
578        let compute_efficiency = if duration_ms > 0.0 {
579            (total_allocations / duration_ms * 1000.0).min(1.0)
580        } else {
581            0.0
582        };
583
584        let cpu_efficiency = match task_type {
585            TaskType::CpuIntensive | TaskType::IoIntensive | TaskType::GpuCompute => {
586                compute_efficiency
587            }
588            TaskType::MemoryIntensive => {
589                if total_bytes > 0.0 {
590                    (peak_memory / total_bytes).min(1.0)
591                } else {
592                    0.0
593                }
594            }
595            TaskType::NetworkIntensive => {
596                if total_bytes > 0.0 {
597                    (total_allocations / total_bytes * 1000.0).min(1.0)
598                } else {
599                    0.0
600                }
601            }
602            _ => compute_efficiency,
603        };
604
605        let memory_efficiency = if total_bytes > 0.0 {
606            (total_allocations / total_bytes * 1000.0).min(1.0)
607        } else {
608            0.0
609        };
610
611        let io_efficiency = if duration_ms > 0.0 {
612            (total_bytes / duration_ms / 1_048_576.0).min(1.0)
613        } else {
614            0.0
615        };
616
617        let efficiency_score = (cpu_efficiency + memory_efficiency + io_efficiency) / 3.0;
618
619        let bottleneck = if duration_ms > 5000.0 {
620            "Execution Time".to_string()
621        } else if peak_memory > 100.0 * 1024.0 * 1024.0 {
622            "Memory".to_string()
623        } else if total_allocations > 10000.0 {
624            "Allocations".to_string()
625        } else {
626            "None".to_string()
627        };
628
629        let mut recommendations = Vec::new();
630        if duration_ms > 5000.0 {
631            recommendations.push("Consider optimizing task execution time".to_string());
632        }
633        if peak_memory > 100.0 * 1024.0 * 1024.0 {
634            recommendations.push("Reduce peak memory usage".to_string());
635        }
636        if total_allocations > 10000.0 {
637            recommendations.push("Reduce number of allocations".to_string());
638        }
639        if recommendations.is_empty() {
640            recommendations.push("Performance is good".to_string());
641        }
642
643        Some(TaskReport {
644            task_name: profile.task_name.clone(),
645            task_type,
646            efficiency_score,
647            cpu_efficiency,
648            memory_efficiency,
649            io_efficiency,
650            bottleneck,
651            recommendations,
652        })
653    }
654
655    /// Get resource rankings for all tasks
656    pub fn get_resource_rankings(&self) -> Vec<ResourceRanking> {
657        let profiles = self.get_all_profiles();
658
659        let mut rankings: Vec<ResourceRanking> = profiles
660            .into_iter()
661            .map(|profile| {
662                let memory_mb = profile.total_bytes as f64 / 1_048_576.0;
663                let peak_memory_mb = profile.peak_memory as f64 / 1_048_576.0;
664                let duration_ms = profile.duration_ns as f64 / 1_000_000.0;
665                let allocation_rate = profile.allocation_rate;
666
667                let overall_score = memory_mb * 0.3
668                    + peak_memory_mb * 0.2
669                    + allocation_rate * 0.0001
670                    + duration_ms * 0.0001;
671
672                ResourceRanking {
673                    task_name: profile.task_name.clone(),
674                    task_type: profile.task_type,
675                    cpu_usage: allocation_rate,
676                    memory_usage_mb: memory_mb,
677                    io_usage_mb: 0.0,
678                    network_usage_mb: 0.0,
679                    gpu_usage: 0.0,
680                    overall_score,
681                }
682            })
683            .collect();
684
685        rankings.sort_by(|a, b| {
686            b.overall_score
687                .partial_cmp(&a.overall_score)
688                .unwrap_or(std::cmp::Ordering::Equal)
689        });
690
691        rankings
692    }
693
694    /// Get current timestamp.
695    fn now() -> u64 {
696        std::time::SystemTime::now()
697            .duration_since(std::time::UNIX_EPOCH)
698            .unwrap_or_default()
699            .as_nanos() as u64
700    }
701}
702
impl Default for AsyncTracker {
    /// Equivalent to [`AsyncTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
708
impl Drop for AsyncTracker {
    fn drop(&mut self) {
        // NOTE(review): dropping ANY tracker instance clears the dropping
        // thread's current task ID, even if that ID was set on behalf of a
        // different tracker — confirm this cleanup scope is intended.
        Self::clear_current_task();
    }
}
714
715/// Initialize async memory tracking system
716pub fn initialize() -> AsyncResult<()> {
717    let mut global = GLOBAL_TRACKER.lock().map_err(|_| AsyncError::System {
718        operation: Arc::from("initialize"),
719        message: Arc::from("Failed to acquire global tracker lock"),
720    })?;
721
722    if global.is_none() {
723        let tracker = AsyncTracker::new();
724        tracker.set_initialized();
725        *global = Some(Arc::new(tracker));
726        tracing::info!("Async memory tracking system initialized");
727        Ok(())
728    } else {
729        Err(AsyncError::initialization(
730            "tracker",
731            "Already initialized",
732            true,
733        ))
734    }
735}
736
737/// Shutdown async memory tracking system
738pub fn shutdown() -> AsyncResult<()> {
739    let mut global = GLOBAL_TRACKER.lock().map_err(|_| AsyncError::System {
740        operation: Arc::from("shutdown"),
741        message: Arc::from("Failed to acquire global tracker lock"),
742    })?;
743
744    if global.is_some() {
745        *global = None;
746        tracing::info!("Async memory tracking system shutdown");
747        Ok(())
748    } else {
749        Err(AsyncError::initialization(
750            "tracker",
751            "Not initialized",
752            true,
753        ))
754    }
755}
756
/// Reset global tracker state (for testing only)
///
/// Unconditionally discards any installed tracker so tests can start from
/// a clean slate regardless of earlier initialize/shutdown calls.
#[cfg(test)]
pub fn reset_global_tracker() {
    if let Ok(mut global) = GLOBAL_TRACKER.lock() {
        *global = None;
    } else {
        tracing::error!("Failed to acquire global tracker lock in reset_global_tracker");
    }
}
766
767/// Get the global async tracker
768fn get_global_tracker() -> AsyncResult<Arc<AsyncTracker>> {
769    GLOBAL_TRACKER
770        .lock()
771        .map_err(|_| AsyncError::System {
772            operation: Arc::from("get_global_tracker"),
773            message: Arc::from("Failed to acquire global tracker lock"),
774        })?
775        .clone()
776        .ok_or_else(|| {
777            AsyncError::initialization("tracker", "Tracking system not initialized", true)
778        })
779}
780
/// Create a tracked future wrapper
///
/// Thin constructor delegating to [`TrackedFuture::new`].
pub fn create_tracked<F>(future: F) -> TrackedFuture<F>
where
    F: Future,
{
    TrackedFuture::new(future)
}
788
/// Spawn a tracked async task with automatic context management.
///
/// This function wraps a future with memscope tracking context.
/// The task ID is automatically generated and managed, and the
/// context is automatically cleaned up when the task completes.
///
/// # Arguments
///
/// * `future` - The async block to execute
///
/// # Returns
///
/// A `tokio::task::JoinHandle` that resolves to the future's output
///
/// # Example
///
/// ```ignore
/// let handle = spawn_tracked(async {
///     let data = vec![1u8; 1024];
///     tracker.track_as(&data, "buffer", file!(), line!());
///     // ... async work
/// });
/// ```
pub fn spawn_tracked<F>(future: F) -> tokio::task::JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let task_id = generate_unique_task_id();

    // The task-local TASK_CONTEXT scope travels with the spawned task, so
    // TrackerContext::capture() can observe `task_id` from any worker thread.
    // NOTE(review): unlike `track_in_tokio_task`, no profile is registered
    // for this ID with the global tracker — confirm that is intended.
    tokio::spawn(async move { TASK_CONTEXT.scope(Some(task_id), future).await })
}
821
822/// Get current memory usage snapshot
823pub fn get_memory_snapshot() -> AsyncMemorySnapshot {
824    if let Ok(tracker) = get_global_tracker() {
825        let stats = tracker.get_stats();
826
827        AsyncMemorySnapshot {
828            active_task_count: stats.active_tasks,
829            total_allocated_bytes: stats.total_memory as u64,
830            allocation_events: stats.total_allocations as u64,
831            events_dropped: 0,
832            buffer_utilization: 0.0,
833        }
834    } else {
835        AsyncMemorySnapshot {
836            active_task_count: 0,
837            total_allocated_bytes: 0,
838            allocation_events: 0,
839            events_dropped: 0,
840            buffer_utilization: 0.0,
841        }
842    }
843}
844
/// Check if async memory tracking is currently active
///
/// Returns `false` when uninitialized or when the global lock is poisoned.
pub fn is_tracking_active() -> bool {
    GLOBAL_TRACKER.lock().is_ok_and(|global| global.is_some())
}
849
/// Track allocation for current task
///
/// # Errors
///
/// Fails only when the tracking system is not initialized; an allocation
/// made outside any task context is silently ignored.
pub fn track_current_allocation(ptr: usize, size: usize) -> AsyncResult<()> {
    let tracker = get_global_tracker()?;
    let task_info = super::async_types::get_current_task();

    if task_info.has_tracking_id() {
        // NOTE(review): the primary ID is truncated to its low 32 bits here
        // while task IDs are full u64 elsewhere — confirm the masking is
        // intentional.
        tracker.track_allocation(ptr, size, (task_info.primary_id() & 0xFFFFFFFF) as u64);
    }

    Ok(())
}
861
/// Track deallocation for current task
///
/// # Errors
///
/// Fails only when the tracking system is not initialized.
pub fn track_current_deallocation(ptr: usize) -> AsyncResult<()> {
    let tracker = get_global_tracker()?;
    tracker.track_deallocation(ptr);
    Ok(())
}
868
869#[cfg(test)]
870mod tests {
871    use super::*;
872    use crate::capture::backends::async_types::TaskOperation;
873
    // A fresh tracker starts with zeroed stats.
    #[test]
    fn test_async_tracker_creation() {
        let tracker = AsyncTracker::new();
        let stats = tracker.get_stats();
        assert_eq!(stats.total_tasks, 0);
    }

    // Starting a task bumps both counters; ending it decrements only the
    // active count (total_tasks is cumulative).
    #[test]
    fn test_task_tracking() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();
        tracker
            .track_task_start(1, "test_task".to_string(), thread_id)
            .unwrap();

        let stats = tracker.get_stats();
        assert_eq!(stats.total_tasks, 1);
        assert_eq!(stats.active_tasks, 1);

        tracker.track_task_end(1).unwrap();
        let stats = tracker.get_stats();
        assert_eq!(stats.active_tasks, 0);
    }

    // An allocation attributed to a task is reflected in that task's profile.
    #[test]
    fn test_allocation_tracking() {
        let tracker = AsyncTracker::new();
        let thread_id = std::thread::current().id();
        tracker
            .track_task_start(1, "test_task".to_string(), thread_id)
            .unwrap();
        tracker.track_allocation(0x1000, 1024, 1);

        let profile = tracker.get_task_profile(1);
        assert!(profile.is_some());
        let profile = profile.unwrap();
        assert_eq!(profile.total_allocations, 1);
        assert_eq!(profile.total_bytes, 1024);
    }
913
    // NOTE(review): the lifecycle tests below all mutate the shared
    // GLOBAL_TRACKER and may race when the harness runs tests in parallel —
    // consider serializing them.

    // initialize() succeeds once and reports "Already initialized" after.
    #[test]
    fn test_initialization() {
        reset_global_tracker();

        let result = initialize();
        assert!(result.is_ok());

        let result2 = initialize();
        if let Err(e) = result2 {
            assert!(e.message().contains("Already initialized"));
        }

        let _ = shutdown();
    }

    // shutdown() succeeds after initialize and errors when called again.
    #[test]
    fn test_shutdown() {
        reset_global_tracker();

        initialize().unwrap();
        let result = shutdown();
        assert!(result.is_ok());

        let result2 = shutdown();
        if let Err(e) = result2 {
            assert!(e.message().contains("Not initialized"));
        }
    }

    // A freshly initialized system reports no active tasks.
    #[test]
    fn test_memory_snapshot() {
        reset_global_tracker();

        initialize().unwrap();
        let snapshot = get_memory_snapshot();
        assert_eq!(snapshot.active_task_count, 0);
        let _ = shutdown();
    }

    // is_tracking_active() mirrors the initialize/shutdown lifecycle.
    #[test]
    fn test_is_tracking_active() {
        reset_global_tracker();

        assert!(!is_tracking_active());
        initialize().unwrap();
        assert!(is_tracking_active());
        let _ = shutdown();
        assert!(!is_tracking_active());
    }
963
964    #[test]
965    fn test_task_memory_profile() {
966        let tracker = AsyncTracker::new();
967        let thread_id = std::thread::current().id();
968        tracker
969            .track_task_start(1, "test_task".to_string(), thread_id)
970            .unwrap();
971        tracker.track_allocation(0x1000, 1024, 1);
972        tracker.track_allocation(0x2000, 2048, 1);
973        tracker.track_task_end(1).unwrap();
974
975        let profile = tracker.get_task_profile(1);
976        assert!(profile.is_some());
977        let profile = profile.unwrap();
978        assert_eq!(profile.task_id, 1);
979        assert_eq!(profile.total_allocations, 2);
980        assert_eq!(profile.total_bytes, 3072);
981    }
982
983    #[test]
984    fn test_duplicate_task_tracking() {
985        let tracker = AsyncTracker::new();
986        let thread_id = std::thread::current().id();
987
988        // First registration should succeed
989        let result = tracker.track_task_start(1, "test_task".to_string(), thread_id);
990        assert!(result.is_ok());
991
992        // Second registration should fail
993        let result = tracker.track_task_start(1, "duplicate_task".to_string(), thread_id);
994        assert!(result.is_err());
995        let error = result.unwrap_err();
996        assert!(
997            matches!(error, AsyncError::TaskTracking { operation, .. } if matches!(operation, TaskOperation::Duplicate))
998        );
999    }
1000
1001    #[test]
1002    fn test_task_not_found() {
1003        let tracker = AsyncTracker::new();
1004
1005        // Calling track_task_end with non-existent task should fail
1006        let result = tracker.track_task_end(999);
1007        assert!(result.is_err());
1008        let error = result.unwrap_err();
1009        assert!(
1010            matches!(error, AsyncError::TaskTracking { operation, .. } if matches!(operation, TaskOperation::TaskNotFound))
1011        );
1012    }
1013
1014    #[test]
1015    fn test_task_guard_cleanup() {
1016        assert!(AsyncTracker::get_current_task().is_none());
1017
1018        {
1019            let _guard = AsyncTracker::enter_task(42);
1020            assert_eq!(AsyncTracker::get_current_task(), Some(42));
1021        }
1022
1023        assert!(AsyncTracker::get_current_task().is_none());
1024    }
1025
1026    #[test]
1027    fn test_with_task_closure() {
1028        assert!(AsyncTracker::get_current_task().is_none());
1029
1030        let result = AsyncTracker::with_task(123, || {
1031            assert_eq!(AsyncTracker::get_current_task(), Some(123));
1032            "test_result"
1033        });
1034
1035        assert_eq!(result, "test_result");
1036        assert!(AsyncTracker::get_current_task().is_none());
1037    }
1038
1039    #[test]
1040    fn test_with_task_panic_cleanup() {
1041        assert!(AsyncTracker::get_current_task().is_none());
1042
1043        let result = std::panic::catch_unwind(|| {
1044            AsyncTracker::with_task(999, || {
1045                assert_eq!(AsyncTracker::get_current_task(), Some(999));
1046                panic!("intentional panic");
1047            });
1048        });
1049
1050        assert!(result.is_err());
1051        assert!(AsyncTracker::get_current_task().is_none());
1052    }
1053
1054    #[test]
1055    fn test_generate_unique_task_id() {
1056        let id1 = generate_unique_task_id();
1057        let id2 = generate_unique_task_id();
1058        let id3 = generate_unique_task_id();
1059
1060        assert!(id1 > 0);
1061        assert!(id2 > id1);
1062        assert!(id3 > id2);
1063    }
1064
1065    #[test]
1066    fn test_track_start_with_tokio() {
1067        let tracker = AsyncTracker::new();
1068        let thread_id = std::thread::current().id();
1069
1070        let result =
1071            tracker.track_task_start_with_tokio(1, 100, "tokio_task".to_string(), thread_id);
1072        assert!(result.is_ok());
1073
1074        let profile = tracker.get_task_profile(1);
1075        assert!(profile.is_some());
1076        let profile = profile.unwrap();
1077        assert_eq!(profile.task_id, 1);
1078        assert_eq!(profile.tokio_task_id, Some(100));
1079        assert_eq!(profile.task_name, "tokio_task");
1080    }
1081
1082    #[test]
1083    fn test_track_task_internal_without_tokio() {
1084        let tracker = AsyncTracker::new();
1085        let thread_id = std::thread::current().id();
1086
1087        let result = tracker.track_task_start(2, "normal_task".to_string(), thread_id);
1088        assert!(result.is_ok());
1089
1090        let profile = tracker.get_task_profile(2);
1091        assert!(profile.is_some());
1092        let profile = profile.unwrap();
1093        assert_eq!(profile.task_id, 2);
1094        assert_eq!(profile.tokio_task_id, None);
1095    }
1096
1097    #[test]
1098    fn test_detect_zombie_tasks() {
1099        let tracker = AsyncTracker::new();
1100        let thread_id = std::thread::current().id();
1101
1102        tracker
1103            .track_task_start(1, "task1".to_string(), thread_id)
1104            .unwrap();
1105        tracker
1106            .track_task_start(2, "task2".to_string(), thread_id)
1107            .unwrap();
1108        tracker
1109            .track_task_start(3, "task3".to_string(), thread_id)
1110            .unwrap();
1111
1112        tracker.track_task_end(1).unwrap();
1113
1114        let zombies = tracker.detect_zombie_tasks();
1115        assert_eq!(zombies.len(), 2);
1116        assert!(zombies.contains(&2));
1117        assert!(zombies.contains(&3));
1118    }
1119
1120    #[test]
1121    fn test_zombie_task_stats() {
1122        let tracker = AsyncTracker::new();
1123        let thread_id = std::thread::current().id();
1124
1125        tracker
1126            .track_task_start(1, "task1".to_string(), thread_id)
1127            .unwrap();
1128        tracker
1129            .track_task_start(2, "task2".to_string(), thread_id)
1130            .unwrap();
1131
1132        tracker.track_task_end(1).unwrap();
1133
1134        let (zombie_count, total) = tracker.zombie_task_stats();
1135        assert_eq!(zombie_count, 1);
1136        assert_eq!(total, 2);
1137    }
1138
1139    #[test]
1140    fn test_no_zombie_tasks_when_all_complete() {
1141        let tracker = AsyncTracker::new();
1142        let thread_id = std::thread::current().id();
1143
1144        tracker
1145            .track_task_start(1, "task1".to_string(), thread_id)
1146            .unwrap();
1147        tracker
1148            .track_task_start(2, "task2".to_string(), thread_id)
1149            .unwrap();
1150
1151        tracker.track_task_end(1).unwrap();
1152        tracker.track_task_end(2).unwrap();
1153
1154        let zombies = tracker.detect_zombie_tasks();
1155        assert!(zombies.is_empty());
1156    }
1157
1158    #[test]
1159    fn test_task_memory_profile_with_tokio_id() {
1160        let profile = TaskMemoryProfile::with_tokio_id(1, 999, "test".to_string(), TaskType::Mixed);
1161
1162        assert_eq!(profile.task_id, 1);
1163        assert_eq!(profile.tokio_task_id, Some(999));
1164        assert_eq!(profile.task_name, "test");
1165        assert_eq!(profile.task_type, TaskType::Mixed);
1166        assert!(!profile.is_completed());
1167    }
1168
1169    #[tokio::test]
1170    async fn test_track_in_tokio_task_basic() {
1171        let tracker = AsyncTracker::new();
1172
1173        let (task_id, result) = tracker
1174            .track_in_tokio_task("async_task".to_string(), async {
1175                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
1176                42
1177            })
1178            .await;
1179
1180        assert!(task_id > 0);
1181        assert_eq!(result, 42);
1182
1183        let profile = tracker.get_task_profile(task_id);
1184        assert!(profile.is_some());
1185        let profile = profile.unwrap();
1186        assert_eq!(profile.task_name, "async_task");
1187        assert!(profile.is_completed());
1188    }
1189
1190    #[tokio::test]
1191    async fn test_track_in_tokio_task_basic_functionality() {
1192        let tracker = AsyncTracker::new();
1193
1194        let (task_id, result) = tracker
1195            .track_in_tokio_task("test_task".to_string(), async {
1196                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
1197                "completed"
1198            })
1199            .await;
1200
1201        assert!(task_id > 0);
1202        assert_eq!(result, "completed");
1203
1204        let profile = tracker.get_task_profile(task_id);
1205        assert!(profile.is_some());
1206        let profile = profile.unwrap();
1207        assert_eq!(profile.task_name, "test_task");
1208        assert!(profile.is_completed());
1209    }
1210
1211    #[test]
1212    fn test_global_tracker_integration() {
1213        reset_global_tracker();
1214
1215        let result = initialize();
1216        assert!(result.is_ok());
1217
1218        let tracker = get_global_tracker();
1219        assert!(tracker.is_ok());
1220
1221        let tracker = tracker.unwrap();
1222        let stats = tracker.get_stats();
1223        assert_eq!(stats.total_tasks, 0);
1224        assert_eq!(stats.active_tasks, 0);
1225
1226        let _ = shutdown();
1227    }
1228}