1use crate::unified::tracking_dispatcher::{
6 MemoryTracker, TrackerConfig, TrackerError, TrackerStatistics, TrackerType,
7};
8use std::cell::RefCell;
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::Arc;
11use std::thread;
12use std::time::Instant;
13use tracing::{debug, info, warn};
14
/// Memory-tracking strategy that stores allocation records in thread-local storage
/// and keeps aggregate counters in shared atomics.
pub struct ThreadLocalStrategy {
    /// Configuration accepted by `initialize`; `None` until `initialize` succeeds.
    config: Option<TrackerConfig>,
    /// Session-level state: active flag, session start time and the next allocation id.
    global_state: Arc<GlobalTrackingState>,
    /// Counters describing how many threads have registered with this strategy.
    thread_registry: Arc<ThreadRegistry>,
    /// Aggregate allocation counters updated by every tracked allocation.
    global_metrics: Arc<GlobalMetrics>,
}
28
/// Session-wide state shared between all threads using one strategy instance.
#[derive(Debug)]
struct GlobalTrackingState {
    /// Tracking flag stored as an integer: `1` while tracking is active, `0` otherwise.
    is_active: AtomicU64,
    /// Reset to zero by `initialize`; not otherwise modified in this file.
    active_threads: AtomicUsize,
    /// Nanoseconds since the UNIX epoch at which `start_tracking` activated the session;
    /// `0` means no session has been started.
    session_start_ns: AtomicU64,
    /// Next id handed out to a tracked allocation; starts at `1`.
    next_allocation_id: AtomicU64,
}
42
/// Counters describing thread participation in tracking.
#[derive(Debug)]
struct ThreadRegistry {
    /// Incremented once each time a thread completes `register_current_thread`.
    total_registered_threads: AtomicUsize,
    /// Reset by `initialize` and reported in the JSON metadata; never incremented
    /// in this file.
    active_tracking_threads: AtomicUsize,
}
52
/// Aggregate counters covering every thread that tracks through one strategy instance.
#[derive(Debug)]
struct GlobalMetrics {
    /// Number of allocations recorded by `track_allocation`.
    total_allocations: AtomicU64,
    /// Sum of the `size` arguments passed to `track_allocation`.
    total_bytes_allocated: AtomicU64,
    /// Reset by `initialize`; never updated elsewhere in this file.
    peak_concurrent_allocations: AtomicU64,
    /// Overhead figure written by `collect_all_thread_data`.
    total_overhead_bytes: AtomicUsize,
}
66
/// One tracked allocation as stored in the recording thread's local storage.
#[derive(Debug, Clone)]
struct ThreadLocalRecord {
    /// Id drawn from the shared `next_allocation_id` counter; used to order exported records.
    global_id: u64,
    /// Value of the thread's `local_sequence` counter when the record was created.
    _local_sequence: u64,
    /// Address of the allocation as passed to `track_allocation`.
    ptr: usize,
    /// Size of the allocation as passed to `track_allocation`.
    size: usize,
    /// Optional variable name supplied by the caller.
    var_name: Option<String>,
    /// Type name supplied by the caller.
    type_name: String,
    /// Wall-clock time of the allocation, in nanoseconds since the UNIX epoch.
    timestamp_alloc: u64,
    /// Wall-clock time of the matching deallocation; `None` while still live.
    timestamp_dealloc: Option<u64>,
    /// Hash of the recording thread's `ThreadId`.
    thread_id: u64,
    /// Name of the recording thread, if it has one.
    thread_name: Option<String>,
}
92
/// Per-thread tracking state held in `THREAD_LOCAL_DATA`.
#[derive(Debug)]
struct ThreadLocalData {
    /// Records for allocations tracked on this thread, in tracking order.
    allocations: Vec<ThreadLocalRecord>,
    /// Count of allocations tracked on this thread since it was registered.
    local_sequence: u64,
    /// Running statistics for this thread.
    thread_metrics: ThreadMetrics,
    /// Set once `register_current_thread` has run on this thread.
    is_registered: bool,
}
106
/// Running statistics for a single thread.
#[derive(Debug, Clone)]
struct ThreadMetrics {
    /// Number of allocations tracked on this thread.
    allocations_count: u64,
    /// Sum of the sizes of allocations tracked on this thread.
    bytes_allocated: u64,
    /// Timestamp (nanoseconds since the UNIX epoch) taken when the thread registered.
    _start_time_ns: u64,
    /// Exponentially weighted moving average of the time spent inside `track_allocation`.
    avg_allocation_time_ns: f64,
    /// Set to `size_of::<ThreadLocalData>()` when the thread registers.
    thread_overhead_bytes: usize,
}
121
// Per-thread tracking state. Each thread starts with an empty, unregistered
// `ThreadLocalData`; the initializer is `const` so it needs no lazy-initialization closure.
thread_local! {
    static THREAD_LOCAL_DATA: RefCell<ThreadLocalData> = const { RefCell::new(ThreadLocalData {
        allocations: Vec::new(),
        local_sequence: 0,
        thread_metrics: ThreadMetrics {
            allocations_count: 0,
            bytes_allocated: 0,
            _start_time_ns: 0,
            avg_allocation_time_ns: 0.0,
            thread_overhead_bytes: 0,
        },
        // Flipped to `true` by `register_current_thread`.
        is_registered: false,
    })};
}
139
140impl Default for GlobalTrackingState {
141 fn default() -> Self {
143 Self {
144 is_active: AtomicU64::new(0),
145 active_threads: AtomicUsize::new(0),
146 session_start_ns: AtomicU64::new(0),
147 next_allocation_id: AtomicU64::new(1),
148 }
149 }
150}
151
152impl Default for ThreadRegistry {
153 fn default() -> Self {
155 Self {
156 total_registered_threads: AtomicUsize::new(0),
157 active_tracking_threads: AtomicUsize::new(0),
158 }
159 }
160}
161
162impl Default for GlobalMetrics {
163 fn default() -> Self {
165 Self {
166 total_allocations: AtomicU64::new(0),
167 total_bytes_allocated: AtomicU64::new(0),
168 peak_concurrent_allocations: AtomicU64::new(0),
169 total_overhead_bytes: AtomicUsize::new(0),
170 }
171 }
172}
173
174impl ThreadLocalStrategy {
175 pub fn new() -> Self {
178 debug!("Creating new thread-local strategy");
179
180 Self {
181 config: None,
182 global_state: Arc::new(GlobalTrackingState::default()),
183 thread_registry: Arc::new(ThreadRegistry::default()),
184 global_metrics: Arc::new(GlobalMetrics::default()),
185 }
186 }
187
188 pub fn register_current_thread(&self) -> Result<(), TrackerError> {
191 let thread_id = Self::get_current_thread_id();
192 let thread_name = thread::current().name().map(|s| s.to_string());
193
194 debug!(
195 "Registering thread for tracking: id={}, name={:?}",
196 thread_id, thread_name
197 );
198
199 THREAD_LOCAL_DATA.with(|data| {
200 let mut thread_data = data.borrow_mut();
201
202 if thread_data.is_registered {
203 debug!("Thread already registered: {}", thread_id);
204 return Ok(());
205 }
206
207 thread_data.is_registered = true;
209 thread_data.local_sequence = 0;
210 thread_data.allocations.clear();
211 thread_data.thread_metrics = ThreadMetrics {
212 allocations_count: 0,
213 bytes_allocated: 0,
214 _start_time_ns: Self::get_timestamp_ns(),
215 avg_allocation_time_ns: 0.0,
216 thread_overhead_bytes: std::mem::size_of::<ThreadLocalData>(),
217 };
218
219 self.thread_registry
221 .total_registered_threads
222 .fetch_add(1, Ordering::Relaxed);
223
224 info!("Thread registered successfully: id={}", thread_id);
225 Ok(())
226 })
227 }
228
229 pub fn track_allocation(
232 &self,
233 ptr: usize,
234 size: usize,
235 var_name: Option<String>,
236 type_name: String,
237 ) -> Result<(), TrackerError> {
238 if self.global_state.is_active.load(Ordering::Relaxed) == 0 {
239 return Err(TrackerError::StartFailed {
240 reason: "Tracking not active".to_string(),
241 });
242 }
243
244 let start_time = Instant::now();
245 let thread_id = Self::get_current_thread_id();
246
247 THREAD_LOCAL_DATA.with(|data| {
248 let mut thread_data = data.borrow_mut();
249
250 if !thread_data.is_registered {
251 drop(thread_data);
253 self.register_current_thread()?;
254 thread_data = data.borrow_mut();
255 }
256
257 let global_id = self
259 .global_state
260 .next_allocation_id
261 .fetch_add(1, Ordering::Relaxed);
262 thread_data.local_sequence += 1;
263
264 let record = ThreadLocalRecord {
266 global_id,
267 _local_sequence: thread_data.local_sequence,
268 ptr,
269 size,
270 var_name,
271 type_name,
272 timestamp_alloc: Self::get_timestamp_ns(),
273 timestamp_dealloc: None,
274 thread_id,
275 thread_name: thread::current().name().map(|s| s.to_string()),
276 };
277
278 thread_data.allocations.push(record);
280
281 thread_data.thread_metrics.allocations_count += 1;
283 thread_data.thread_metrics.bytes_allocated += size as u64;
284
285 let allocation_time_ns = start_time.elapsed().as_nanos() as f64;
287 let weight = 0.1; thread_data.thread_metrics.avg_allocation_time_ns = (1.0 - weight)
289 * thread_data.thread_metrics.avg_allocation_time_ns
290 + weight * allocation_time_ns;
291
292 self.global_metrics
294 .total_allocations
295 .fetch_add(1, Ordering::Relaxed);
296 self.global_metrics
297 .total_bytes_allocated
298 .fetch_add(size as u64, Ordering::Relaxed);
299
300 debug!(
301 "Tracked allocation in thread {}: ptr={:x}, size={}, global_id={}",
302 thread_id, ptr, size, global_id
303 );
304
305 Ok(())
306 })
307 }
308
309 pub fn track_deallocation(&self, ptr: usize) -> Result<(), TrackerError> {
312 if self.global_state.is_active.load(Ordering::Relaxed) == 0 {
313 return Err(TrackerError::StartFailed {
314 reason: "Tracking not active".to_string(),
315 });
316 }
317
318 let timestamp = Self::get_timestamp_ns();
319 let thread_id = Self::get_current_thread_id();
320
321 THREAD_LOCAL_DATA.with(|data| {
322 let mut thread_data = data.borrow_mut();
323
324 if let Some(record) = thread_data
326 .allocations
327 .iter_mut()
328 .find(|r| r.ptr == ptr && r.timestamp_dealloc.is_none())
329 {
330 record.timestamp_dealloc = Some(timestamp);
331 debug!(
332 "Tracked deallocation in thread {}: ptr={:x}, global_id={}",
333 thread_id, ptr, record.global_id
334 );
335 Ok(())
336 } else {
337 warn!(
339 "Deallocation tracked for unknown pointer in thread {}: {:x}",
340 thread_id, ptr
341 );
342 Ok(()) }
344 })
345 }
346
347 fn collect_all_thread_data(&self) -> Result<Vec<ThreadLocalRecord>, TrackerError> {
350 debug!("Collecting data from all registered threads");
351
352 let mut all_records = Vec::new();
353 let mut total_overhead = 0;
354
355 THREAD_LOCAL_DATA.with(|data| {
363 let thread_data = data.borrow();
364 all_records.extend(thread_data.allocations.clone());
365 total_overhead += thread_data.thread_metrics.thread_overhead_bytes;
366 });
367
368 self.global_metrics
370 .total_overhead_bytes
371 .store(total_overhead, Ordering::Relaxed);
372
373 all_records.sort_by_key(|r| r.global_id);
375
376 info!(
377 "Collected {} records from threads, total overhead: {} bytes",
378 all_records.len(),
379 total_overhead
380 );
381
382 Ok(all_records)
383 }
384
385 fn export_as_json(&self) -> Result<String, TrackerError> {
387 let records = self.collect_all_thread_data()?;
388
389 let mut allocations = Vec::new();
391
392 for record in records.iter() {
393 let mut allocation = serde_json::Map::new();
394
395 allocation.insert(
396 "ptr".to_string(),
397 serde_json::Value::String(format!("{:x}", record.ptr)),
398 );
399 allocation.insert(
400 "size".to_string(),
401 serde_json::Value::Number(serde_json::Number::from(record.size)),
402 );
403 allocation.insert(
404 "timestamp_alloc".to_string(),
405 serde_json::Value::Number(serde_json::Number::from(record.timestamp_alloc)),
406 );
407
408 if let Some(var_name) = &record.var_name {
409 allocation.insert(
410 "var_name".to_string(),
411 serde_json::Value::String(var_name.clone()),
412 );
413 }
414
415 allocation.insert(
416 "type_name".to_string(),
417 serde_json::Value::String(record.type_name.clone()),
418 );
419
420 if let Some(timestamp_dealloc) = record.timestamp_dealloc {
421 allocation.insert(
422 "timestamp_dealloc".to_string(),
423 serde_json::Value::Number(serde_json::Number::from(timestamp_dealloc)),
424 );
425 }
426
427 allocation.insert(
429 "thread_id".to_string(),
430 serde_json::Value::Number(serde_json::Number::from(record.thread_id)),
431 );
432
433 if let Some(thread_name) = &record.thread_name {
434 allocation.insert(
435 "thread_name".to_string(),
436 serde_json::Value::String(thread_name.clone()),
437 );
438 }
439
440 allocation.insert(
441 "tracking_strategy".to_string(),
442 serde_json::Value::String("thread_local".to_string()),
443 );
444
445 allocations.push(serde_json::Value::Object(allocation));
446 }
447
448 let mut output = serde_json::Map::new();
449 output.insert(
450 "allocations".to_string(),
451 serde_json::Value::Array(allocations),
452 );
453 output.insert("strategy_metadata".to_string(), serde_json::json!({
454 "strategy_type": "thread_local",
455 "total_allocations": self.global_metrics.total_allocations.load(Ordering::Relaxed),
456 "total_bytes": self.global_metrics.total_bytes_allocated.load(Ordering::Relaxed),
457 "total_threads": self.thread_registry.total_registered_threads.load(Ordering::Relaxed),
458 "active_threads": self.thread_registry.active_tracking_threads.load(Ordering::Relaxed),
459 "overhead_bytes": self.global_metrics.total_overhead_bytes.load(Ordering::Relaxed)
460 }));
461
462 serde_json::to_string_pretty(&output).map_err(|e| TrackerError::DataCollectionFailed {
463 reason: format!("JSON serialization failed: {}", e),
464 })
465 }
466
467 fn get_current_thread_id() -> u64 {
469 use std::collections::hash_map::DefaultHasher;
471 use std::hash::{Hash, Hasher};
472
473 let mut hasher = DefaultHasher::new();
474 thread::current().id().hash(&mut hasher);
475 hasher.finish()
476 }
477
478 fn get_timestamp_ns() -> u64 {
480 std::time::SystemTime::now()
481 .duration_since(std::time::UNIX_EPOCH)
482 .map(|d| d.as_nanos() as u64)
483 .unwrap_or(0)
484 }
485}
486
487impl MemoryTracker for ThreadLocalStrategy {
488 fn initialize(&mut self, config: TrackerConfig) -> Result<(), TrackerError> {
491 debug!(
492 "Initializing thread-local strategy with config: {:?}",
493 config
494 );
495
496 if config.sample_rate < 0.0 || config.sample_rate > 1.0 {
498 return Err(TrackerError::InvalidConfiguration {
499 reason: "Sample rate must be between 0.0 and 1.0".to_string(),
500 });
501 }
502
503 if config.max_overhead_mb == 0 {
504 return Err(TrackerError::InvalidConfiguration {
505 reason: "Maximum overhead must be greater than 0".to_string(),
506 });
507 }
508
509 self.config = Some(config);
511
512 self.global_state.is_active.store(0, Ordering::Relaxed);
514 self.global_state.active_threads.store(0, Ordering::Relaxed);
515 self.global_state
516 .next_allocation_id
517 .store(1, Ordering::Relaxed);
518
519 self.global_metrics
521 .total_allocations
522 .store(0, Ordering::Relaxed);
523 self.global_metrics
524 .total_bytes_allocated
525 .store(0, Ordering::Relaxed);
526 self.global_metrics
527 .peak_concurrent_allocations
528 .store(0, Ordering::Relaxed);
529 self.global_metrics
530 .total_overhead_bytes
531 .store(0, Ordering::Relaxed);
532
533 self.thread_registry
535 .total_registered_threads
536 .store(0, Ordering::Relaxed);
537 self.thread_registry
538 .active_tracking_threads
539 .store(0, Ordering::Relaxed);
540
541 info!("Thread-local strategy initialized successfully");
542 Ok(())
543 }
544
545 fn start_tracking(&mut self) -> Result<(), TrackerError> {
547 debug!("Starting thread-local tracking");
548
549 let was_active = self.global_state.is_active.swap(1, Ordering::Relaxed);
551 if was_active == 1 {
552 warn!("Thread-local tracking was already active");
553 return Ok(()); }
555
556 self.global_state
558 .session_start_ns
559 .store(Self::get_timestamp_ns(), Ordering::Relaxed);
560
561 self.register_current_thread()?;
563
564 info!("Thread-local tracking started successfully");
565 Ok(())
566 }
567
568 fn stop_tracking(&mut self) -> Result<Vec<u8>, TrackerError> {
570 debug!("Stopping thread-local tracking");
571
572 let was_active = self.global_state.is_active.swap(0, Ordering::Relaxed);
574 if was_active == 0 {
575 warn!("Thread-local tracking was not active");
576 }
577
578 let json_data = self.export_as_json()?;
580
581 let total_allocations = self
582 .global_metrics
583 .total_allocations
584 .load(Ordering::Relaxed);
585 let total_threads = self
586 .thread_registry
587 .total_registered_threads
588 .load(Ordering::Relaxed);
589
590 info!(
591 "Thread-local tracking stopped: {} allocations from {} threads",
592 total_allocations, total_threads
593 );
594
595 Ok(json_data.into_bytes())
596 }
597
598 fn get_statistics(&self) -> TrackerStatistics {
600 TrackerStatistics {
601 allocations_tracked: self
602 .global_metrics
603 .total_allocations
604 .load(Ordering::Relaxed),
605 memory_tracked_bytes: self
606 .global_metrics
607 .total_bytes_allocated
608 .load(Ordering::Relaxed),
609 overhead_bytes: self
610 .global_metrics
611 .total_overhead_bytes
612 .load(Ordering::Relaxed) as u64,
613 tracking_duration_ms: {
614 let start_ns = self.global_state.session_start_ns.load(Ordering::Relaxed);
615 if start_ns > 0 {
616 (Self::get_timestamp_ns() - start_ns) / 1_000_000
617 } else {
618 0
619 }
620 },
621 }
622 }
623
624 fn is_active(&self) -> bool {
626 self.global_state.is_active.load(Ordering::Relaxed) == 1
627 }
628
    /// Identifies this strategy as `TrackerType::MultiThread`.
    fn tracker_type(&self) -> TrackerType {
        TrackerType::MultiThread
    }
633}
634
635impl Default for ThreadLocalStrategy {
636 fn default() -> Self {
638 Self::new()
639 }
640}
641
642#[cfg(test)]
643mod tests {
644 use super::*;
645 use std::sync::Barrier;
646 use std::thread;
647 use std::time::Duration;
648
649 #[test]
650 fn test_strategy_creation() {
651 let strategy = ThreadLocalStrategy::new();
652 assert!(!strategy.is_active());
653 assert_eq!(strategy.tracker_type(), TrackerType::MultiThread);
654 }
655
656 #[test]
657 fn test_strategy_initialization() {
658 let mut strategy = ThreadLocalStrategy::new();
659 let config = TrackerConfig::default();
660
661 let result = strategy.initialize(config);
662 assert!(result.is_ok());
663 assert!(!strategy.is_active()); }
665
666 #[test]
667 fn test_thread_registration() {
668 let strategy = ThreadLocalStrategy::new();
669
670 let result = strategy.register_current_thread();
671 assert!(result.is_ok());
672
673 let total_threads = strategy
675 .thread_registry
676 .total_registered_threads
677 .load(Ordering::Relaxed);
678 assert_eq!(total_threads, 1);
679 }
680
681 #[test]
682 fn test_single_thread_tracking() {
683 let mut strategy = ThreadLocalStrategy::new();
684 strategy
685 .initialize(TrackerConfig::default())
686 .expect("Strategy initialization should succeed");
687 strategy
688 .start_tracking()
689 .expect("Strategy should start tracking successfully");
690
691 let result = strategy.track_allocation(
693 0x1000,
694 128,
695 Some("test_var".to_string()),
696 "TestType".to_string(),
697 );
698 assert!(result.is_ok());
699
700 let stats = strategy.get_statistics();
702 assert_eq!(stats.allocations_tracked, 1);
703 assert_eq!(stats.memory_tracked_bytes, 128);
704
705 assert!(strategy.track_deallocation(0x1000).is_ok());
707
708 let data = strategy.stop_tracking();
710 assert!(data.is_ok());
711 }
712
713 #[test]
714 fn test_multi_thread_tracking() {
715 let mut strategy = ThreadLocalStrategy::new();
716 strategy
717 .initialize(TrackerConfig::default())
718 .expect("Multi-thread strategy initialization should succeed");
719 strategy
720 .start_tracking()
721 .expect("Multi-thread strategy should start tracking successfully");
722
723 let strategy = Arc::new(strategy);
724 let barrier = Arc::new(Barrier::new(3)); let mut handles = vec![];
726
727 for thread_id in 0..2 {
729 let strategy_clone = Arc::clone(&strategy);
730 let barrier_clone = Arc::clone(&barrier);
731
732 let handle = thread::spawn(move || {
733 strategy_clone
735 .register_current_thread()
736 .expect("Thread registration should succeed");
737
738 barrier_clone.wait();
740
741 for i in 0..10 {
743 let ptr = 0x1000 + (thread_id * 1000) + (i * 0x10);
744 let size = 64 + i * 8;
745
746 let result = strategy_clone.track_allocation(
747 ptr,
748 size,
749 Some(format!("var_{}_{}", thread_id, i)),
750 "TestType".to_string(),
751 );
752 assert!(result.is_ok());
753
754 thread::sleep(Duration::from_micros(100));
756 }
757 });
758
759 handles.push(handle);
760 }
761
762 barrier.wait();
764
765 for handle in handles {
767 handle
768 .join()
769 .expect("Worker thread should complete successfully");
770 }
771
772 let stats = strategy.get_statistics();
774 assert_eq!(stats.allocations_tracked, 20); let total_threads = strategy
778 .thread_registry
779 .total_registered_threads
780 .load(Ordering::Relaxed);
781 assert_eq!(total_threads, 3); }
783
784 #[test]
785 fn test_json_export_format() {
786 let mut strategy = ThreadLocalStrategy::new();
787 strategy
788 .initialize(TrackerConfig::default())
789 .expect("Strategy initialization should succeed for cleanup test");
790 strategy
791 .start_tracking()
792 .expect("Strategy should start tracking successfully for cleanup test");
793
794 strategy
796 .track_allocation(
797 0x1000,
798 256,
799 Some("test_variable".to_string()),
800 "TestStruct".to_string(),
801 )
802 .expect("Allocation tracking should succeed");
803
804 let data = strategy
805 .stop_tracking()
806 .expect("Strategy should stop tracking and return data successfully");
807 let json_str = String::from_utf8(data).expect("Strategy data should be valid UTF-8");
808
809 let parsed: serde_json::Value =
811 serde_json::from_str(&json_str).expect("Generated JSON should be valid");
812 assert!(parsed["allocations"].is_array());
813 assert!(parsed["strategy_metadata"].is_object());
814
815 let allocations = parsed["allocations"].as_array().unwrap();
816 assert_eq!(allocations.len(), 1);
817
818 let first_alloc = &allocations[0];
819 assert_eq!(first_alloc["ptr"].as_str().unwrap(), "1000");
820 assert_eq!(first_alloc["size"].as_u64().unwrap(), 256);
821 assert_eq!(first_alloc["var_name"].as_str().unwrap(), "test_variable");
822 assert_eq!(first_alloc["type_name"].as_str().unwrap(), "TestStruct");
823 assert_eq!(
824 first_alloc["tracking_strategy"].as_str().unwrap(),
825 "thread_local"
826 );
827 assert!(first_alloc["thread_id"].is_number());
828 }
829
830 #[test]
831 fn test_global_metrics_atomicity() {
832 let mut strategy = ThreadLocalStrategy::new();
833 strategy.initialize(TrackerConfig::default()).unwrap();
834 strategy.start_tracking().unwrap();
835
836 for i in 0..100 {
838 strategy
839 .track_allocation(0x1000 + i * 0x10, 64, None, "TestType".to_string())
840 .unwrap();
841 }
842
843 let stats = strategy.get_statistics();
844 assert_eq!(stats.allocations_tracked, 100);
845 assert_eq!(stats.memory_tracked_bytes, 6400);
846
847 let global_allocs = strategy
849 .global_metrics
850 .total_allocations
851 .load(Ordering::Relaxed);
852 let global_bytes = strategy
853 .global_metrics
854 .total_bytes_allocated
855 .load(Ordering::Relaxed);
856
857 assert_eq!(global_allocs, 100);
858 assert_eq!(global_bytes, 6400);
859 }
860}