1use crate::core::tracker::memory_tracker::MemoryTracker;
8use crate::core::types::MemoryStats;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, Weak};
11use std::thread::ThreadId;
12
13static THREAD_REGISTRY: std::sync::OnceLock<Arc<Mutex<ThreadRegistry>>> =
15 std::sync::OnceLock::new();
16
17struct ThreadRegistry {
19 trackers: HashMap<ThreadId, Weak<MemoryTracker>>,
21 cached_thread_data: HashMap<ThreadId, CachedThreadData>,
23 total_threads_registered: usize,
25 active_threads: usize,
27}
28
29#[derive(Debug, Clone)]
31pub struct CachedThreadData {
32 pub thread_id: ThreadId,
34 pub stats: MemoryStats,
36 #[allow(dead_code)]
38 pub cached_at: std::time::SystemTime,
39}
40
41impl ThreadRegistry {
42 fn new() -> Self {
44 Self {
45 trackers: HashMap::new(),
46 cached_thread_data: HashMap::new(),
47 total_threads_registered: 0,
48 active_threads: 0,
49 }
50 }
51
52 fn register_tracker(&mut self, thread_id: ThreadId, tracker: &Arc<MemoryTracker>) {
54 self.cleanup_dead_references();
56
57 self.trackers.insert(thread_id, Arc::downgrade(tracker));
59 self.total_threads_registered += 1;
60 self.active_threads = self.trackers.len();
61
62 tracing::debug!(
63 "Registered thread {:?}, total threads: {}, active threads: {}",
64 thread_id,
65 self.total_threads_registered,
66 self.active_threads
67 );
68 }
69
70 fn collect_active_trackers(&mut self) -> Vec<Arc<MemoryTracker>> {
72 self.cache_all_available_data();
74
75 let mut active_trackers = Vec::new();
77 let mut dead_thread_ids = Vec::new();
78
79 for (thread_id, weak_tracker) in &self.trackers {
80 if let Some(strong_tracker) = weak_tracker.upgrade() {
81 active_trackers.push(strong_tracker);
82 tracing::debug!("Successfully collected tracker for thread {:?}", thread_id);
83 } else {
84 if self.cached_thread_data.contains_key(thread_id) {
86 tracing::debug!("Found cached data for dead thread {:?}", thread_id);
87 } else {
88 tracing::debug!(
89 "Found dead tracker reference with no cached data for thread {:?}",
90 thread_id
91 );
92 }
93 dead_thread_ids.push(*thread_id);
94 }
95 }
96
97 tracing::debug!(
98 "Collected {} active trackers, {} dead with cached data, {} total cached entries",
99 active_trackers.len(),
100 dead_thread_ids.len(),
101 self.cached_thread_data.len()
102 );
103
104 active_trackers
105 }
106
107 fn cache_all_available_data(&mut self) {
109 for (thread_id, weak_tracker) in &self.trackers {
110 if let Some(strong_tracker) = weak_tracker.upgrade() {
111 if let Ok(stats) = strong_tracker.get_stats() {
112 if stats.total_allocations > 0 {
114 let allocations = stats.total_allocations;
115 let allocated = stats.total_allocated;
116
117 self.cached_thread_data.insert(
118 *thread_id,
119 CachedThreadData {
120 thread_id: *thread_id,
121 stats,
122 cached_at: std::time::SystemTime::now(),
123 },
124 );
125 tracing::debug!(
126 "Cached data for thread {:?}: {} allocations, {} bytes",
127 thread_id,
128 allocations,
129 allocated
130 );
131 }
132 }
133 }
134 }
135 }
136
137 fn cleanup_dead_references(&mut self) {
139 let initial_count = self.trackers.len();
140 self.trackers
141 .retain(|_thread_id, weak_tracker| weak_tracker.strong_count() > 0);
142
143 let removed_count = initial_count - self.trackers.len();
144 if removed_count > 0 {
145 tracing::debug!("Cleaned up {} dead tracker references", removed_count);
146 }
147
148 self.active_threads = self.trackers.len();
149 }
150
151 fn get_stats(&self) -> ThreadRegistryStats {
153 ThreadRegistryStats {
154 total_threads_registered: self.total_threads_registered,
155 active_threads: self.active_threads,
156 dead_references: self
157 .trackers
158 .iter()
159 .filter(|(_, weak)| weak.strong_count() == 0)
160 .count(),
161 }
162 }
163}
164
/// Point-in-time counters describing the thread registry.
///
/// `Default` yields all-zero counters. `PartialEq`/`Eq` allow comparing two
/// snapshots directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ThreadRegistryStats {
    /// Number of registrations performed, including repeat registrations of a thread.
    pub total_threads_registered: usize,
    /// Tracker-map size recorded at the last registration or cleanup.
    pub active_threads: usize,
    /// Registered entries whose tracker has already been dropped.
    pub dead_references: usize,
}
175
176#[derive(Debug, Clone)]
178pub struct AggregatedTrackingData {
179 pub tracker_count: usize,
181 pub total_allocations: u64,
183 pub total_bytes_allocated: u64,
185 pub peak_memory_usage: u64,
187 pub active_threads: usize,
189 pub combined_stats: Vec<CombinedTrackerStats>,
191}
192
/// Memory statistics reported for a single thread.
///
/// Derives `PartialEq`/`Eq` so entries can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CombinedTrackerStats {
    /// Thread the statistics belong to.
    pub thread_id: ThreadId,
    /// Label for the tracking mechanism that produced the data.
    pub tracking_mode: String,
    /// Number of allocations recorded for the thread.
    pub allocations: u64,
    /// Bytes allocated by the thread.
    pub bytes_allocated: u64,
    /// Peak memory recorded for the thread.
    pub peak_memory: u64,
}
207
208fn get_registry() -> Arc<Mutex<ThreadRegistry>> {
210 THREAD_REGISTRY
211 .get_or_init(|| Arc::new(Mutex::new(ThreadRegistry::new())))
212 .clone()
213}
214
215pub fn register_current_thread_tracker(tracker: &Arc<MemoryTracker>) {
221 let thread_id = std::thread::current().id();
222
223 if let Ok(mut registry) = get_registry().lock() {
224 registry.register_tracker(thread_id, tracker);
225 } else {
226 tracing::error!("Failed to acquire registry lock for thread registration");
227 }
228}
229
230pub fn collect_unified_tracking_data() -> Result<AggregatedTrackingData, String> {
237 let mut combined_stats = Vec::new();
238 let mut total_allocations = 0u64;
239 let mut total_bytes_allocated = 0u64;
240 let mut peak_memory_usage = 0u64;
241
242 let active_trackers = collect_all_trackers();
244 let cached_data = get_cached_thread_data();
245
246 for tracker in &active_trackers {
248 if let Ok(stats) = tracker.get_stats() {
249 let thread_stats = CombinedTrackerStats {
250 thread_id: std::thread::current().id(), tracking_mode: "track_var!".to_string(),
252 allocations: stats.total_allocations as u64,
253 bytes_allocated: stats.total_allocated as u64,
254 peak_memory: stats.peak_memory as u64,
255 };
256
257 total_allocations += stats.total_allocations as u64;
258 total_bytes_allocated += stats.total_allocated as u64;
259 peak_memory_usage = peak_memory_usage.max(stats.peak_memory as u64);
260
261 combined_stats.push(thread_stats);
262 }
263 }
264
265 for cached in cached_data {
267 let thread_stats = CombinedTrackerStats {
268 thread_id: cached.thread_id,
269 tracking_mode: "track_var!".to_string(),
270 allocations: cached.stats.total_allocations as u64,
271 bytes_allocated: cached.stats.total_allocated as u64,
272 peak_memory: cached.stats.peak_memory as u64,
273 };
274
275 total_allocations += cached.stats.total_allocations as u64;
276 total_bytes_allocated += cached.stats.total_allocated as u64;
277 peak_memory_usage = peak_memory_usage.max(cached.stats.peak_memory as u64);
278
279 combined_stats.push(thread_stats);
280 }
281
282 let aggregated_data = AggregatedTrackingData {
289 tracker_count: active_trackers.len(),
290 total_allocations,
291 total_bytes_allocated,
292 peak_memory_usage,
293 active_threads: active_trackers.len(),
294 combined_stats,
295 };
296
297 tracing::info!(
298 "Collected unified tracking data: {} trackers, {} allocations, {} bytes",
299 aggregated_data.tracker_count,
300 aggregated_data.total_allocations,
301 aggregated_data.total_bytes_allocated
302 );
303
304 Ok(aggregated_data)
305}
306
307pub fn collect_all_trackers() -> Vec<Arc<MemoryTracker>> {
312 match get_registry().lock() {
313 Ok(mut registry) => registry.collect_active_trackers(),
314 Err(e) => {
315 tracing::error!(
316 "Failed to acquire registry lock for tracker collection: {}",
317 e
318 );
319 Vec::new()
320 }
321 }
322}
323
324pub fn get_cached_thread_data() -> Vec<CachedThreadData> {
329 match get_registry().lock() {
330 Ok(registry) => registry.cached_thread_data.values().cloned().collect(),
331 Err(e) => {
332 tracing::error!("Failed to acquire registry lock for cached data: {}", e);
333 Vec::new()
334 }
335 }
336}
337
338pub fn get_registry_stats() -> ThreadRegistryStats {
343 match get_registry().lock() {
344 Ok(registry) => registry.get_stats(),
345 Err(e) => {
346 tracing::error!("Failed to acquire registry lock for stats: {}", e);
347 ThreadRegistryStats {
348 total_threads_registered: 0,
349 active_threads: 0,
350 dead_references: 0,
351 }
352 }
353 }
354}
355
356pub fn cleanup_registry() {
361 if let Ok(mut registry) = get_registry().lock() {
362 registry.cleanup_dead_references();
363 } else {
364 tracing::error!("Failed to acquire registry lock for cleanup");
365 }
366}
367
368pub fn has_active_trackers() -> bool {
370 match get_registry().lock() {
371 Ok(registry) => !registry.trackers.is_empty(),
372 Err(_) => false,
373 }
374}
375
376pub fn enable_precise_tracking() {
381 crate::core::tracker::configure_tracking_strategy(true);
382 tracing::info!("Enabled precise tracking mode with thread-local trackers");
383}
384
385pub fn enable_performance_tracking() {
390 crate::core::tracker::configure_tracking_strategy(false);
391 tracing::info!("Enabled performance tracking mode with global singleton");
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::tracker::memory_tracker::MemoryTracker;
    use std::thread;
    use std::time::Duration;

    /// Creates a tracker, registers it for the calling thread, and returns it
    /// so the caller decides how long it stays alive.
    fn new_registered_tracker() -> Arc<MemoryTracker> {
        let tracker = Arc::new(MemoryTracker::new());
        register_current_thread_tracker(&tracker);
        tracker
    }

    #[test]
    fn test_thread_registry_registration() {
        let _alive = new_registered_tracker();

        let snapshot = get_registry_stats();
        assert!(snapshot.active_threads > 0);
        assert!(snapshot.total_threads_registered > 0);
    }

    #[test]
    fn test_collect_trackers() {
        let _alive = new_registered_tracker();

        assert!(!collect_all_trackers().is_empty());
    }

    #[test]
    fn test_unified_data_collection() {
        let _alive = new_registered_tracker();

        let data = collect_unified_tracking_data().expect("unified collection should succeed");
        assert!(data.tracker_count > 0);
    }

    #[test]
    fn test_precise_tracking_mode() {
        enable_precise_tracking();

        let workers: Vec<thread::JoinHandle<i32>> = (0..3)
            .map(|index| {
                thread::spawn(move || {
                    let _alive = new_registered_tracker();
                    // Hold the tracker briefly before the worker exits.
                    thread::sleep(Duration::from_millis(10));
                    index
                })
            })
            .collect();

        let results: Vec<i32> = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .collect();
        assert_eq!(results.len(), 3);

        let stats = get_registry_stats();
        assert!(stats.active_threads >= 1);
    }
}