memscope_rs/core/
thread_registry.rs

//! Thread registry for managing thread-local memory trackers and data aggregation.
//!
//! This module provides a global registry that tracks all thread-local memory trackers
//! for data aggregation purposes. It is the collection point for the unified tracking
//! system, which is meant to cover all tracking modes: track_var!, lockfree, and
//! async_memory. Currently only track_var! data is aggregated here; the lockfree and
//! async_memory integrations are still TODO.
6
7use crate::core::tracker::memory_tracker::MemoryTracker;
8use crate::core::types::MemoryStats;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, Weak};
11use std::thread::ThreadId;
12
13/// Global thread registry for tracking all thread-local memory trackers
14static THREAD_REGISTRY: std::sync::OnceLock<Arc<Mutex<ThreadRegistry>>> =
15    std::sync::OnceLock::new();
16
17/// Thread registry that maintains weak references to all thread-local trackers
18struct ThreadRegistry {
19    /// Map of thread ID to weak reference of memory tracker
20    trackers: HashMap<ThreadId, Weak<MemoryTracker>>,
21    /// Cached data from completed threads (persisted after thread exit)
22    cached_thread_data: HashMap<ThreadId, CachedThreadData>,
23    /// Total number of threads ever registered
24    total_threads_registered: usize,
25    /// Number of currently active threads
26    active_threads: usize,
27}
28
29/// Cached tracking data from a completed thread
30#[derive(Debug, Clone)]
31pub struct CachedThreadData {
32    /// Thread ID
33    pub thread_id: ThreadId,
34    /// Cached memory stats
35    pub stats: MemoryStats,
36    /// Timestamp when data was cached
37    #[allow(dead_code)]
38    pub cached_at: std::time::SystemTime,
39}
40
41impl ThreadRegistry {
42    /// Create a new thread registry
43    fn new() -> Self {
44        Self {
45            trackers: HashMap::new(),
46            cached_thread_data: HashMap::new(),
47            total_threads_registered: 0,
48            active_threads: 0,
49        }
50    }
51
52    /// Register a thread-local tracker
53    fn register_tracker(&mut self, thread_id: ThreadId, tracker: &Arc<MemoryTracker>) {
54        // Clean up any dead weak references first
55        self.cleanup_dead_references();
56
57        // Register the new tracker
58        self.trackers.insert(thread_id, Arc::downgrade(tracker));
59        self.total_threads_registered += 1;
60        self.active_threads = self.trackers.len();
61
62        tracing::debug!(
63            "Registered thread {:?}, total threads: {}, active threads: {}",
64            thread_id,
65            self.total_threads_registered,
66            self.active_threads
67        );
68    }
69
70    /// Collect all currently active trackers for data aggregation
71    fn collect_active_trackers(&mut self) -> Vec<Arc<MemoryTracker>> {
72        // First, try to cache data from all trackers that are still upgradeable
73        self.cache_all_available_data();
74
75        // Collect all strong references that are still alive
76        let mut active_trackers = Vec::new();
77        let mut dead_thread_ids = Vec::new();
78
79        for (thread_id, weak_tracker) in &self.trackers {
80            if let Some(strong_tracker) = weak_tracker.upgrade() {
81                active_trackers.push(strong_tracker);
82                tracing::debug!("Successfully collected tracker for thread {:?}", thread_id);
83            } else {
84                // Tracker is dead but we might have cached data
85                if self.cached_thread_data.contains_key(thread_id) {
86                    tracing::debug!("Found cached data for dead thread {:?}", thread_id);
87                } else {
88                    tracing::debug!(
89                        "Found dead tracker reference with no cached data for thread {:?}",
90                        thread_id
91                    );
92                }
93                dead_thread_ids.push(*thread_id);
94            }
95        }
96
97        tracing::debug!(
98            "Collected {} active trackers, {} dead with cached data, {} total cached entries",
99            active_trackers.len(),
100            dead_thread_ids.len(),
101            self.cached_thread_data.len()
102        );
103
104        active_trackers
105    }
106
107    /// Cache data from all currently available trackers
108    fn cache_all_available_data(&mut self) {
109        for (thread_id, weak_tracker) in &self.trackers {
110            if let Some(strong_tracker) = weak_tracker.upgrade() {
111                if let Ok(stats) = strong_tracker.get_stats() {
112                    // Only cache if we have meaningful data
113                    if stats.total_allocations > 0 {
114                        let allocations = stats.total_allocations;
115                        let allocated = stats.total_allocated;
116
117                        self.cached_thread_data.insert(
118                            *thread_id,
119                            CachedThreadData {
120                                thread_id: *thread_id,
121                                stats,
122                                cached_at: std::time::SystemTime::now(),
123                            },
124                        );
125                        tracing::debug!(
126                            "Cached data for thread {:?}: {} allocations, {} bytes",
127                            thread_id,
128                            allocations,
129                            allocated
130                        );
131                    }
132                }
133            }
134        }
135    }
136
137    /// Remove dead weak references from the registry
138    fn cleanup_dead_references(&mut self) {
139        let initial_count = self.trackers.len();
140        self.trackers
141            .retain(|_thread_id, weak_tracker| weak_tracker.strong_count() > 0);
142
143        let removed_count = initial_count - self.trackers.len();
144        if removed_count > 0 {
145            tracing::debug!("Cleaned up {} dead tracker references", removed_count);
146        }
147
148        self.active_threads = self.trackers.len();
149    }
150
151    /// Get registry statistics for monitoring
152    fn get_stats(&self) -> ThreadRegistryStats {
153        ThreadRegistryStats {
154            total_threads_registered: self.total_threads_registered,
155            active_threads: self.active_threads,
156            dead_references: self
157                .trackers
158                .iter()
159                .filter(|(_, weak)| weak.strong_count() == 0)
160                .count(),
161        }
162    }
163}
164
/// Counters describing the state of the thread registry, intended for monitoring
/// and debugging.
#[derive(Debug, Clone)]
pub struct ThreadRegistryStats {
    /// How many thread registrations the registry has counted so far.
    pub total_threads_registered: usize,
    /// Number of tracker entries the registry held at its last registration or cleanup.
    pub active_threads: usize,
    /// Entries whose tracker has already been dropped and that a cleanup would remove.
    pub dead_references: usize,
}
175
176/// Data aggregation result from all tracking modes
177#[derive(Debug, Clone)]
178pub struct AggregatedTrackingData {
179    /// Number of trackers included in this aggregation
180    pub tracker_count: usize,
181    /// Total allocations across all trackers
182    pub total_allocations: u64,
183    /// Total bytes allocated across all trackers
184    pub total_bytes_allocated: u64,
185    /// Peak memory usage across all trackers
186    pub peak_memory_usage: u64,
187    /// Number of active threads that contributed data
188    pub active_threads: usize,
189    /// Combined statistics from all tracking modes
190    pub combined_stats: Vec<CombinedTrackerStats>,
191}
192
/// Per-tracker statistics in a mode-independent shape (the tracker may be
/// track_var!, lockfree, or async).
#[derive(Debug, Clone)]
pub struct CombinedTrackerStats {
    /// ID of the thread the tracker belongs to.
    pub thread_id: ThreadId,
    /// Name of the tracking mode that produced these numbers.
    pub tracking_mode: String,
    /// Allocation count reported by the tracker.
    pub allocations: u64,
    /// Bytes allocated as reported by the tracker.
    pub bytes_allocated: u64,
    /// Peak memory reported by the tracker.
    pub peak_memory: u64,
}
207
208/// Get the global thread registry instance
209fn get_registry() -> Arc<Mutex<ThreadRegistry>> {
210    THREAD_REGISTRY
211        .get_or_init(|| Arc::new(Mutex::new(ThreadRegistry::new())))
212        .clone()
213}
214
215/// Register the current thread's tracker with the global registry.
216///
217/// This function should be called automatically when a thread-local tracker
218/// is first accessed. It stores a weak reference to avoid preventing
219/// tracker cleanup when threads exit.
220pub fn register_current_thread_tracker(tracker: &Arc<MemoryTracker>) {
221    let thread_id = std::thread::current().id();
222
223    if let Ok(mut registry) = get_registry().lock() {
224        registry.register_tracker(thread_id, tracker);
225    } else {
226        tracing::error!("Failed to acquire registry lock for thread registration");
227    }
228}
229
230/// Collect and aggregate data from all tracking modes.
231///
232/// This is the main function for unified data collection that combines:
233/// - track_var! data from all threads
234/// - lockfree module data
235/// - async_memory module data
236pub fn collect_unified_tracking_data() -> Result<AggregatedTrackingData, String> {
237    let mut combined_stats = Vec::new();
238    let mut total_allocations = 0u64;
239    let mut total_bytes_allocated = 0u64;
240    let mut peak_memory_usage = 0u64;
241
242    // Collect track_var! data from all thread-local trackers (active + cached)
243    let active_trackers = collect_all_trackers();
244    let cached_data = get_cached_thread_data();
245
246    // Process active trackers
247    for tracker in &active_trackers {
248        if let Ok(stats) = tracker.get_stats() {
249            let thread_stats = CombinedTrackerStats {
250                thread_id: std::thread::current().id(), // Will be improved with actual thread IDs
251                tracking_mode: "track_var!".to_string(),
252                allocations: stats.total_allocations as u64,
253                bytes_allocated: stats.total_allocated as u64,
254                peak_memory: stats.peak_memory as u64,
255            };
256
257            total_allocations += stats.total_allocations as u64;
258            total_bytes_allocated += stats.total_allocated as u64;
259            peak_memory_usage = peak_memory_usage.max(stats.peak_memory as u64);
260
261            combined_stats.push(thread_stats);
262        }
263    }
264
265    // Process cached data from completed threads
266    for cached in cached_data {
267        let thread_stats = CombinedTrackerStats {
268            thread_id: cached.thread_id,
269            tracking_mode: "track_var!".to_string(),
270            allocations: cached.stats.total_allocations as u64,
271            bytes_allocated: cached.stats.total_allocated as u64,
272            peak_memory: cached.stats.peak_memory as u64,
273        };
274
275        total_allocations += cached.stats.total_allocations as u64;
276        total_bytes_allocated += cached.stats.total_allocated as u64;
277        peak_memory_usage = peak_memory_usage.max(cached.stats.peak_memory as u64);
278
279        combined_stats.push(thread_stats);
280    }
281
282    // TODO: Integrate with lockfree module data
283    // This will be implemented to collect data from lockfree aggregators
284
285    // TODO: Integrate with async_memory module data
286    // This will be implemented to collect data from async trackers
287
288    let aggregated_data = AggregatedTrackingData {
289        tracker_count: active_trackers.len(),
290        total_allocations,
291        total_bytes_allocated,
292        peak_memory_usage,
293        active_threads: active_trackers.len(),
294        combined_stats,
295    };
296
297    tracing::info!(
298        "Collected unified tracking data: {} trackers, {} allocations, {} bytes",
299        aggregated_data.tracker_count,
300        aggregated_data.total_allocations,
301        aggregated_data.total_bytes_allocated
302    );
303
304    Ok(aggregated_data)
305}
306
307/// Collect all currently active thread-local memory trackers.
308///
309/// This function is used by the aggregation system to gather data from
310/// all active threads when running in thread-local mode.
311pub fn collect_all_trackers() -> Vec<Arc<MemoryTracker>> {
312    match get_registry().lock() {
313        Ok(mut registry) => registry.collect_active_trackers(),
314        Err(e) => {
315            tracing::error!(
316                "Failed to acquire registry lock for tracker collection: {}",
317                e
318            );
319            Vec::new()
320        }
321    }
322}
323
324/// Get cached thread data from completed threads.
325///
326/// This function returns data that was cached from threads that have already
327/// completed but whose tracking data is still valuable for aggregation.
328pub fn get_cached_thread_data() -> Vec<CachedThreadData> {
329    match get_registry().lock() {
330        Ok(registry) => registry.cached_thread_data.values().cloned().collect(),
331        Err(e) => {
332            tracing::error!("Failed to acquire registry lock for cached data: {}", e);
333            Vec::new()
334        }
335    }
336}
337
338/// Get statistics about the thread registry.
339///
340/// This function provides information about how many threads have been
341/// registered and how many are currently active.
342pub fn get_registry_stats() -> ThreadRegistryStats {
343    match get_registry().lock() {
344        Ok(registry) => registry.get_stats(),
345        Err(e) => {
346            tracing::error!("Failed to acquire registry lock for stats: {}", e);
347            ThreadRegistryStats {
348                total_threads_registered: 0,
349                active_threads: 0,
350                dead_references: 0,
351            }
352        }
353    }
354}
355
356/// Clean up dead references from the registry.
357///
358/// This function can be called periodically to remove weak references
359/// to trackers whose threads have exited.
360pub fn cleanup_registry() {
361    if let Ok(mut registry) = get_registry().lock() {
362        registry.cleanup_dead_references();
363    } else {
364        tracing::error!("Failed to acquire registry lock for cleanup");
365    }
366}
367
368/// Check if the registry has any active trackers.
369pub fn has_active_trackers() -> bool {
370    match get_registry().lock() {
371        Ok(registry) => !registry.trackers.is_empty(),
372        Err(_) => false,
373    }
374}
375
376/// Enable precise tracking mode for maximum accuracy.
377///
378/// This configures all trackers to use thread-local mode and enables
379/// detailed tracking for precise allocation tracking.
380pub fn enable_precise_tracking() {
381    crate::core::tracker::configure_tracking_strategy(true);
382    tracing::info!("Enabled precise tracking mode with thread-local trackers");
383}
384
385/// Enable performance tracking mode for production use.
386///
387/// This configures trackers for minimal overhead while still providing
388/// useful tracking data.
389pub fn enable_performance_tracking() {
390    crate::core::tracker::configure_tracking_strategy(false);
391    tracing::info!("Enabled performance tracking mode with global singleton");
392}
393
#[cfg(test)]
mod tests {
    //! Tests for the global thread registry.
    //!
    //! The registry is process-wide and shared by every test, so assertions only
    //! use lower bounds. Each test keeps its tracker alive until after its
    //! assertions, because the registry holds weak references only.

    use super::*;
    use crate::core::tracker::memory_tracker::MemoryTracker;
    use std::thread;
    use std::time::Duration;

    /// Registering a tracker is reflected in the registry counters.
    #[test]
    fn test_thread_registry_registration() {
        let tracker = Arc::new(MemoryTracker::new());
        register_current_thread_tracker(&tracker);

        let stats = get_registry_stats();
        assert!(stats.active_threads > 0);
        assert!(stats.total_threads_registered > 0);
    }

    /// A registered, still-alive tracker is returned by `collect_all_trackers`.
    #[test]
    fn test_collect_trackers() {
        let tracker = Arc::new(MemoryTracker::new());
        register_current_thread_tracker(&tracker);

        let collected = collect_all_trackers();
        assert!(!collected.is_empty());
    }

    /// Unified collection succeeds and counts the live tracker.
    #[test]
    fn test_unified_data_collection() {
        let tracker = Arc::new(MemoryTracker::new());
        register_current_thread_tracker(&tracker);

        let result = collect_unified_tracking_data();
        assert!(result.is_ok());

        let data = result.unwrap();
        assert!(data.tracker_count > 0);
    }

    /// Several threads can register their own trackers independently.
    #[test]
    fn test_precise_tracking_mode() {
        enable_precise_tracking();

        // Register the calling thread as well and keep its tracker alive, so the
        // "at least the main thread" assertion below is backed by a live tracker
        // rather than by a counter that has not been refreshed yet.
        let main_tracker = Arc::new(MemoryTracker::new());
        register_current_thread_tracker(&main_tracker);

        // Test that multiple threads can register independently
        let handles: Vec<_> = (0..3)
            .map(|i| {
                thread::spawn(move || {
                    let tracker = Arc::new(MemoryTracker::new());
                    register_current_thread_tracker(&tracker);
                    thread::sleep(Duration::from_millis(10));
                    i
                })
            })
            .collect();

        let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        assert_eq!(results.len(), 3);

        // Verify that the registrations from all four threads were counted
        let stats = get_registry_stats();
        assert!(stats.active_threads >= 1); // At least the main thread
        assert!(stats.total_threads_registered >= 4);

        drop(main_tracker);
    }
}
456}