use std::{
    fs::File,
    io::{BufWriter, Write},
    path::PathBuf,
    sync::{
        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
        Arc, OnceLock,
    },
    thread::ThreadId,
};
16
17use crossbeam::queue::SegQueue;
18use dashmap::DashMap;
19
20use super::lockfree_types::{Event, EventType, MemoryStats};
21
/// Global on/off switch for lock-free tracking; flipped by `trace_all` / `stop_tracing`.
static TRACKING_ENABLED: AtomicBool = AtomicBool::new(false);
/// Directory where per-thread event files are written; set at most once by `trace_all`.
static OUTPUT_DIRECTORY: OnceLock<std::path::PathBuf> = OnceLock::new();
24
/// Per-thread, lock-free allocation tracker.
///
/// Events are buffered in a lock-free `SegQueue` and persisted to
/// `output_file` by `finalize()` (also attempted on `Drop`).
pub struct ThreadLocalTracker {
    // Owning thread; stamped into every recorded event.
    thread_id: ThreadId,
    // Buffered events awaiting finalization (shared so snapshots can observe it).
    events: Arc<SegQueue<Event>>,
    // Live allocations: ptr -> size, consulted on deallocation to recover the size.
    active_allocations: Arc<DashMap<usize, usize>>,
    // Cumulative counters (monotonically increasing).
    total_allocations: AtomicU64,
    total_allocated: AtomicU64,
    total_deallocations: AtomicU64,
    total_deallocated: AtomicU64,
    // Currently outstanding bytes and the high-water mark.
    active_memory: AtomicU64,
    peak_memory: AtomicU64,
    // Destination for the binary event stream.
    output_file: PathBuf,
    // Fraction of allocations to record, clamped to [0.0, 1.0] by `new`.
    sample_rate: f64,
    // Sampling bookkeeping: allocations observed vs. actually recorded.
    total_seen: AtomicUsize,
    total_tracked: AtomicUsize,
}
44
45impl ThreadLocalTracker {
46 pub fn new(thread_id: ThreadId, output_file: PathBuf, sample_rate: f64) -> Self {
47 Self {
48 thread_id,
49 events: Arc::new(SegQueue::new()),
50 active_allocations: Arc::new(DashMap::new()),
51 total_allocations: AtomicU64::new(0),
52 total_allocated: AtomicU64::new(0),
53 total_deallocations: AtomicU64::new(0),
54 total_deallocated: AtomicU64::new(0),
55 active_memory: AtomicU64::new(0),
56 peak_memory: AtomicU64::new(0),
57 output_file,
58 sample_rate: sample_rate.clamp(0.0, 1.0),
59 total_seen: AtomicUsize::new(0),
60 total_tracked: AtomicUsize::new(0),
61 }
62 }
63
64 pub fn track_allocation(&self, ptr: usize, size: usize, call_stack_hash: u64) {
65 self.total_seen.fetch_add(1, Ordering::Relaxed);
66
67 if self.sample_rate < 1.0 {
68 let sample_decision = rand::random::<f64>();
69 if sample_decision >= self.sample_rate {
70 return;
71 }
72 }
73
74 self.total_tracked.fetch_add(1, Ordering::Relaxed);
75
76 let event = Event::allocation(ptr, size, call_stack_hash, self.thread_id);
77 self.events.push(event);
78
79 self.active_allocations.insert(ptr, size);
80
81 self.total_allocations.fetch_add(1, Ordering::Relaxed);
82 self.total_allocated
83 .fetch_add(size as u64, Ordering::Relaxed);
84
85 let new_active = self.active_memory.fetch_add(size as u64, Ordering::Relaxed) + size as u64;
86
87 let mut current_peak = self.peak_memory.load(Ordering::Relaxed);
89 let mut backoff_count = 0u32;
90 const MAX_BACKOFF_ATTEMPTS: u32 = 10;
91
92 while new_active > current_peak {
93 match self.peak_memory.compare_exchange_weak(
94 current_peak,
95 new_active,
96 Ordering::Relaxed,
97 Ordering::Relaxed,
98 ) {
99 Ok(_) => break,
100 Err(actual) => {
101 current_peak = actual;
102 backoff_count += 1;
103
104 if backoff_count < MAX_BACKOFF_ATTEMPTS {
106 std::hint::spin_loop();
108 } else if backoff_count < MAX_BACKOFF_ATTEMPTS * 2 {
109 std::thread::yield_now();
111 } else {
112 std::thread::sleep(std::time::Duration::from_micros(1));
114 }
115 }
116 }
117 }
118 }
119
120 pub fn track_deallocation(&self, ptr: usize, call_stack_hash: u64) {
121 let size = self
122 .active_allocations
123 .remove(&ptr)
124 .map(|(_, v)| v)
125 .unwrap_or(0);
126
127 let event = Event::deallocation(ptr, size, call_stack_hash, self.thread_id);
128 self.events.push(event);
129
130 self.total_deallocations.fetch_add(1, Ordering::Relaxed);
131 self.total_deallocated
132 .fetch_add(size as u64, Ordering::Relaxed);
133
134 let _ = self
136 .active_memory
137 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
138 Some(current.saturating_sub(size as u64))
139 });
140 }
141
142 pub fn get_stats(&self) -> MemoryStats {
143 MemoryStats {
144 total_allocations: self.total_allocations.load(Ordering::Relaxed) as usize,
145 total_allocated: self.total_allocated.load(Ordering::Relaxed) as usize,
146 total_deallocations: self.total_deallocations.load(Ordering::Relaxed) as usize,
147 total_deallocated: self.total_deallocated.load(Ordering::Relaxed) as usize,
148 active_memory: self.active_memory.load(Ordering::Relaxed) as usize,
149 peak_memory: self.peak_memory.load(Ordering::Relaxed) as usize,
150 }
151 }
152
153 pub fn get_sampling_stats(&self) -> (usize, usize) {
154 (
155 self.total_seen.load(Ordering::Relaxed),
156 self.total_tracked.load(Ordering::Relaxed),
157 )
158 }
159
160 pub fn finalize(&self) -> std::io::Result<()> {
161 let mut events = Vec::new();
162 while let Some(event) = self.events.pop() {
163 events.push(event);
164 }
165
166 if events.is_empty() {
167 return Ok(());
168 }
169
170 if let Some(parent) = self.output_file.parent() {
171 std::fs::create_dir_all(parent)?;
172 }
173
174 let mut file = File::create(&self.output_file)?;
175
176 let header = "MEMSCOPE_LOCKFREE";
177 file.write_all(header.as_bytes())?;
178
179 for event in events {
180 self.write_event(&mut file, &event)?;
181 }
182
183 file.flush()?;
184 Ok(())
185 }
186
187 fn write_event(&self, file: &mut File, event: &Event) -> std::io::Result<()> {
188 let event_type_byte = match event.event_type {
189 EventType::Allocation => 1u8,
190 EventType::Deallocation => 2u8,
191 EventType::Clone => 3u8,
192 EventType::Move => 4u8,
193 EventType::Borrow => 5u8,
194 EventType::MutBorrow => 6u8,
195 };
196 file.write_all(&event_type_byte.to_le_bytes())?;
197 file.write_all(&event.timestamp.to_le_bytes())?;
198 file.write_all(&event.ptr.to_le_bytes())?;
199 file.write_all(&event.size.to_le_bytes())?;
200 file.write_all(&event.call_stack_hash.to_le_bytes())?;
201 Ok(())
202 }
203
204 pub fn thread_id(&self) -> ThreadId {
205 self.thread_id
206 }
207
208 pub fn output_file(&self) -> &PathBuf {
209 &self.output_file
210 }
211
212 pub fn event_count(&self) -> usize {
213 self.events.len()
214 }
215
216 pub fn clear_events(&self) {
217 while self.events.pop().is_some() {}
218 }
219}
220
221impl Drop for ThreadLocalTracker {
222 fn drop(&mut self) {
223 if let Err(e) = self.finalize() {
224 tracing::warn!("Failed to finalize thread-local tracker: {}", e);
225 }
226 }
227}
228
/// Folds a call stack (a slice of return addresses) into a single `u64`
/// using the standard library's `DefaultHasher`.
///
/// Addresses are hashed element-by-element (no slice length prefix), so the
/// result depends on both the addresses and their order.
pub fn calculate_call_stack_hash(call_stack: &[usize]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    call_stack
        .iter()
        .fold(DefaultHasher::new(), |mut hasher, addr| {
            addr.hash(&mut hasher);
            hasher
        })
        .finish()
}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built tracker reports its thread id and has no buffered events.
    #[test]
    fn test_thread_local_tracker_creation() {
        let thread_id = std::thread::current().id();
        // NOTE(review): hard-coded /tmp path — presumably these tests only run
        // on Unix-like CI; confirm or switch to std::env::temp_dir().
        let output_file = PathBuf::from("/tmp/test_tracker.bin");
        let tracker = ThreadLocalTracker::new(thread_id, output_file, 1.0);

        assert_eq!(tracker.thread_id(), thread_id);
        assert_eq!(tracker.event_count(), 0);
    }

    /// With sample_rate 1.0, a single allocation updates every counter and
    /// buffers exactly one event.
    #[test]
    fn test_allocation_tracking() {
        let thread_id = std::thread::current().id();
        let output_file = PathBuf::from("/tmp/test_tracker2.bin");
        let tracker = ThreadLocalTracker::new(thread_id, output_file, 1.0);

        tracker.track_allocation(0x1000, 1024, 12345);

        let stats = tracker.get_stats();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.total_allocated, 1024);
        assert_eq!(stats.active_memory, 1024);
        assert_eq!(tracker.event_count(), 1);
    }

    /// A matched alloc/dealloc pair returns active_memory to zero.
    #[test]
    fn test_deallocation_tracking() {
        let thread_id = std::thread::current().id();
        let output_file = PathBuf::from("/tmp/test_tracker3.bin");
        let tracker = ThreadLocalTracker::new(thread_id, output_file, 1.0);

        tracker.track_allocation(0x1000, 1024, 12345);
        tracker.track_deallocation(0x1000, 12345);

        let stats = tracker.get_stats();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.total_deallocations, 1);
        assert_eq!(stats.active_memory, 0);
    }

    /// The call-stack hash is deterministic and sensitive to content.
    #[test]
    fn test_call_stack_hash() {
        let call_stack = vec![0x1000, 0x2000, 0x3000];
        let hash1 = calculate_call_stack_hash(&call_stack);
        let hash2 = calculate_call_stack_hash(&call_stack);

        assert_eq!(hash1, hash2);

        let different_stack = vec![0x1000, 0x2000, 0x4000];
        let hash3 = calculate_call_stack_hash(&different_stack);
        assert_ne!(hash1, hash3);
    }
}
297
thread_local! {
    /// Per-thread tracker slot; populated by `init_thread_tracker`, consumed
    /// by the `*_lockfree` free functions below.
    static THREAD_TRACKER: std::cell::RefCell<Option<ThreadLocalTracker>> = const { std::cell::RefCell::new(None) };
}
301
/// Returns a numeric id for the current thread, used to name the per-thread
/// event file.
fn get_thread_id() -> u64 {
    crate::utils::current_thread_id_u64()
}
305
306pub fn init_thread_tracker(
307 output_dir: &std::path::Path,
308 sample_rate: Option<f64>,
309) -> Result<(), Box<dyn std::error::Error>> {
310 let sample_rate = sample_rate.unwrap_or(1.0);
311 let thread_id = std::thread::current().id();
312 let output_file = output_dir.join(format!("memscope_thread_{}.bin", get_thread_id()));
313
314 let tracker = ThreadLocalTracker::new(thread_id, output_file, sample_rate);
315
316 THREAD_TRACKER.with(|thread_tracker| {
317 *thread_tracker.borrow_mut() = Some(tracker);
318 });
319
320 Ok(())
321}
322
323pub fn track_allocation_lockfree(
324 ptr: usize,
325 size: usize,
326 call_stack_hash: u64,
327) -> Result<(), Box<dyn std::error::Error>> {
328 THREAD_TRACKER.with(|thread_tracker| {
329 if let Some(ref tracker) = *thread_tracker.borrow() {
330 tracker.track_allocation(ptr, size, call_stack_hash);
331 Ok(())
332 } else {
333 Err("Thread tracker not initialized. Call init_thread_tracker() first.".into())
334 }
335 })
336}
337
338pub fn track_deallocation_lockfree(
339 ptr: usize,
340 call_stack_hash: u64,
341) -> Result<(), Box<dyn std::error::Error>> {
342 THREAD_TRACKER.with(|thread_tracker| {
343 if let Some(ref tracker) = *thread_tracker.borrow() {
344 tracker.track_deallocation(ptr, call_stack_hash);
345 Ok(())
346 } else {
347 Err("Thread tracker not initialized. Call init_thread_tracker() first.".into())
348 }
349 })
350}
351
352pub fn finalize_thread_tracker() -> Result<(), Box<dyn std::error::Error>> {
353 THREAD_TRACKER.with(|thread_tracker| {
354 let mut tracker_ref = thread_tracker.borrow_mut();
355 if let Some(ref mut tracker) = *tracker_ref {
356 tracker
357 .finalize()
358 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
359 } else {
360 Ok(())
361 }
362 })
363}
364
/// Returns a point-in-time copy of the calling thread's tracker, or `None`
/// when no tracker is installed.
///
/// The copy snapshots the atomic counters but SHARES the event queue and
/// live-allocation map via `Arc::clone`.
///
/// NOTE(review): because the copy shares `events` and carries the same
/// `output_file`, dropping it runs `Drop::finalize`, which drains the shared
/// queue and rewrites that file — presumably unintended; confirm callers use
/// the returned value only to read stats.
pub fn get_current_tracker() -> Option<ThreadLocalTracker> {
    THREAD_TRACKER.with(|thread_tracker| {
        thread_tracker
            .borrow()
            .as_ref()
            .map(|tracker| ThreadLocalTracker {
                thread_id: tracker.thread_id,
                events: Arc::clone(&tracker.events),
                active_allocations: Arc::clone(&tracker.active_allocations),
                total_allocations: AtomicU64::new(
                    tracker.total_allocations.load(Ordering::Relaxed),
                ),
                total_allocated: AtomicU64::new(tracker.total_allocated.load(Ordering::Relaxed)),
                total_deallocations: AtomicU64::new(
                    tracker.total_deallocations.load(Ordering::Relaxed),
                ),
                total_deallocated: AtomicU64::new(
                    tracker.total_deallocated.load(Ordering::Relaxed),
                ),
                active_memory: AtomicU64::new(tracker.active_memory.load(Ordering::Relaxed)),
                peak_memory: AtomicU64::new(tracker.peak_memory.load(Ordering::Relaxed)),
                output_file: tracker.output_file.clone(),
                sample_rate: tracker.sample_rate,
                total_seen: AtomicUsize::new(tracker.total_seen.load(Ordering::Relaxed)),
                total_tracked: AtomicUsize::new(tracker.total_tracked.load(Ordering::Relaxed)),
            })
    })
}
403
404pub fn trace_all<P: AsRef<std::path::Path>>(
405 output_dir: &P,
406) -> Result<(), Box<dyn std::error::Error>> {
407 let output_path = output_dir.as_ref().to_path_buf();
408
409 let _ = OUTPUT_DIRECTORY.set(output_path.clone());
410
411 if output_path.exists() {
412 let timestamp = std::time::SystemTime::now()
413 .duration_since(std::time::UNIX_EPOCH)
414 .map(|d| d.as_secs())
415 .unwrap_or(0);
416 let backup_name = format!(
417 "{}.backup.{}",
418 output_path
419 .file_name()
420 .unwrap_or_default()
421 .to_string_lossy(),
422 timestamp
423 );
424 let backup_path = output_path.with_file_name(backup_name);
425 std::fs::rename(&output_path, &backup_path)?;
426 tracing::info!("Existing directory backed up to: {}", backup_path.display());
427 }
428 std::fs::create_dir_all(&output_path)?;
429
430 TRACKING_ENABLED.store(true, Ordering::SeqCst);
431
432 tracing::info!("Lockfree tracking started: {}", output_path.display());
433
434 Ok(())
435}
436
437pub fn trace_thread<P: AsRef<std::path::Path>>(
438 output_dir: &P,
439) -> Result<(), Box<dyn std::error::Error>> {
440 let output_path = output_dir.as_ref().to_path_buf();
441
442 if !output_path.exists() {
443 std::fs::create_dir_all(&output_path)?;
444 }
445
446 init_thread_tracker(&output_path, Some(1.0))?;
447
448 Ok(())
449}
450
451pub fn stop_tracing() -> Result<(), Box<dyn std::error::Error>> {
452 if !TRACKING_ENABLED.load(Ordering::SeqCst) {
453 return Ok(());
454 }
455
456 let _ = finalize_thread_tracker();
457
458 TRACKING_ENABLED.store(false, Ordering::SeqCst);
459
460 Ok(())
461}
462
/// Returns whether global lock-free tracking is currently enabled.
pub fn is_tracking() -> bool {
    TRACKING_ENABLED.load(Ordering::SeqCst)
}
466
467pub fn memory_snapshot() -> super::lockfree_types::MemorySnapshot {
468 use super::lockfree_types::MemorySnapshot;
469
470 let (current_mb, peak_mb, allocations, deallocations) = THREAD_TRACKER.with(|thread_tracker| {
471 if let Some(tracker) = thread_tracker.borrow().as_ref() {
472 let stats = tracker.get_stats();
473 (
474 stats.active_memory as f64 / (1024.0 * 1024.0),
475 stats.peak_memory as f64 / (1024.0 * 1024.0),
476 stats.total_allocations,
477 stats.total_deallocations,
478 )
479 } else {
480 (0.0, 0.0, 0, 0)
481 }
482 });
483
484 MemorySnapshot {
485 current_mb,
486 peak_mb,
487 allocations: allocations as u64,
488 deallocations: deallocations as u64,
489 active_threads: if TRACKING_ENABLED.load(Ordering::SeqCst) {
490 1
491 } else {
492 0
493 },
494 }
495}
496
497pub fn quick_trace<F, R>(f: F) -> R
498where
499 F: FnOnce() -> R,
500{
501 let temp_dir = std::env::temp_dir().join("memscope_lockfree_quick");
502
503 if trace_all(&temp_dir).is_err() {
504 return f();
505 }
506
507 let result = f();
508
509 let _ = stop_tracing();
510
511 tracing::info!("Quick trace completed - check {}", temp_dir.display());
512
513 result
514}
515
#[cfg(test)]
mod global_api_tests {
    use super::*;

    /// Initialization succeeds and may be repeated (the new tracker replaces
    /// the old one).
    #[test]
    fn test_init_thread_tracker() {
        let temp_dir = std::env::temp_dir().join("memscope_global_test");
        std::fs::create_dir_all(&temp_dir).unwrap();

        let result = init_thread_tracker(&temp_dir, Some(1.0));
        assert!(result.is_ok(), "Should successfully initialize tracker");

        let result2 = init_thread_tracker(&temp_dir, Some(0.5));
        assert!(result2.is_ok(), "Should handle duplicate initialization");
    }

    /// Tracking calls fail cleanly when no tracker is installed.
    #[test]
    fn test_track_without_init() {
        // Reset the slot in case another test on this thread installed one.
        THREAD_TRACKER.with(|t| {
            *t.borrow_mut() = None;
        });

        let result = track_allocation_lockfree(0x1000, 1024, 12345);
        assert!(result.is_err(), "Should fail without initialization");
    }

    /// Finalization on an uninitialized thread is a silent success.
    #[test]
    fn test_finalize_without_init() {
        THREAD_TRACKER.with(|t| {
            *t.borrow_mut() = None;
        });

        let result = finalize_thread_tracker();
        assert!(
            result.is_ok(),
            "Should handle finalization without initialization"
        );
    }

    /// End-to-end: init, track a matched alloc/dealloc pair, read stats,
    /// finalize.
    #[test]
    fn test_global_api_workflow() {
        let temp_dir = std::env::temp_dir().join("memscope_workflow_test");
        std::fs::create_dir_all(&temp_dir).unwrap();

        init_thread_tracker(&temp_dir, Some(1.0)).unwrap();

        track_allocation_lockfree(0x1000, 1024, 12345).unwrap();
        track_deallocation_lockfree(0x1000, 12345).unwrap();

        let tracker = get_current_tracker();
        assert!(tracker.is_some(), "Should have active tracker");

        if let Some(t) = tracker {
            let stats = t.get_stats();
            assert_eq!(stats.total_allocations, 1);
            assert_eq!(stats.total_deallocations, 1);
        }

        finalize_thread_tracker().unwrap();
    }
}