Skip to main content

memscope_rs/capture/backends/
lockfree_tracker.rs

1//! Lockfree memory tracker implementation.
2//!
3//! This module contains the ThreadLocalTracker for thread-local memory tracking
4//! using lock-free data structures for optimal concurrent performance.
5
6use std::{
7    fs::File,
8    io::Write,
9    path::PathBuf,
10    sync::{
11        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
12        Arc, OnceLock,
13    },
14    thread::ThreadId,
15};
16
17use crossbeam::queue::SegQueue;
18use dashmap::DashMap;
19
20use super::lockfree_types::{Event, EventType, MemoryStats};
21
// Global on/off switch toggled by `trace_all` / `stop_tracing`, read by `is_tracking`.
static TRACKING_ENABLED: AtomicBool = AtomicBool::new(false);
// Output directory recorded by `trace_all`; `OnceLock` means it is set at most once per process.
static OUTPUT_DIRECTORY: OnceLock<std::path::PathBuf> = OnceLock::new();
24
/// Thread-local memory tracker using lock-free data structures.
///
/// This tracker uses SegQueue for events and DashMap for active allocations,
/// ensuring no event loss under high contention.
pub struct ThreadLocalTracker {
    /// Thread this tracker was created for; stamped into every event.
    thread_id: ThreadId,
    /// Buffered allocation/deallocation events, drained by `finalize`.
    events: Arc<SegQueue<Event>>,
    /// Live allocations: pointer -> size, used to recover the size on free.
    active_allocations: Arc<DashMap<usize, usize>>,
    /// Number of allocations recorded (post-sampling).
    total_allocations: AtomicU64,
    /// Sum of bytes across recorded allocations.
    total_allocated: AtomicU64,
    /// Number of deallocations recorded.
    total_deallocations: AtomicU64,
    /// Sum of bytes across recorded deallocations.
    total_deallocated: AtomicU64,
    /// Outstanding recorded bytes (allocated minus deallocated, saturating).
    active_memory: AtomicU64,
    /// High-water mark of `active_memory`.
    peak_memory: AtomicU64,
    /// Destination of the binary event dump written by `finalize`.
    output_file: PathBuf,
    /// Fraction of allocations to record; clamped to [0.0, 1.0] in `new`.
    sample_rate: f64,
    /// Allocations observed, including those dropped by sampling.
    total_seen: AtomicUsize,
    /// Allocations actually recorded after sampling.
    total_tracked: AtomicUsize,
}
44
impl ThreadLocalTracker {
    /// Create a tracker for `thread_id` that writes its event dump to `output_file`.
    ///
    /// `sample_rate` is clamped to `[0.0, 1.0]`; at 1.0 every allocation is
    /// recorded, below that only a random fraction is kept (see `track_allocation`).
    pub fn new(thread_id: ThreadId, output_file: PathBuf, sample_rate: f64) -> Self {
        Self {
            thread_id,
            events: Arc::new(SegQueue::new()),
            active_allocations: Arc::new(DashMap::new()),
            total_allocations: AtomicU64::new(0),
            total_allocated: AtomicU64::new(0),
            total_deallocations: AtomicU64::new(0),
            total_deallocated: AtomicU64::new(0),
            active_memory: AtomicU64::new(0),
            peak_memory: AtomicU64::new(0),
            output_file,
            sample_rate: sample_rate.clamp(0.0, 1.0),
            total_seen: AtomicUsize::new(0),
            total_tracked: AtomicUsize::new(0),
        }
    }

    /// Record an allocation of `size` bytes at address `ptr`.
    ///
    /// Always counted in `total_seen`, even when sampling drops it. When kept,
    /// an event is queued, the allocation is remembered for the matching free,
    /// the counters are bumped, and the peak-memory high-water mark is raised
    /// if the new active total exceeds it. All counter updates use `Relaxed`
    /// ordering: counters are independent, so no cross-field ordering is needed.
    pub fn track_allocation(&self, ptr: usize, size: usize, call_stack_hash: u64) {
        self.total_seen.fetch_add(1, Ordering::Relaxed);

        // Sampling: keep this allocation with probability `sample_rate`.
        if self.sample_rate < 1.0 {
            let sample_decision = rand::random::<f64>();
            if sample_decision >= self.sample_rate {
                return;
            }
        }

        self.total_tracked.fetch_add(1, Ordering::Relaxed);

        let event = Event::allocation(ptr, size, call_stack_hash, self.thread_id);
        self.events.push(event);

        // Remember the size so `track_deallocation` can look it up by pointer.
        self.active_allocations.insert(ptr, size);

        self.total_allocations.fetch_add(1, Ordering::Relaxed);
        self.total_allocated
            .fetch_add(size as u64, Ordering::Relaxed);

        // `fetch_add` returns the *previous* value, so add `size` back to get
        // the post-update active-memory figure used for the peak check below.
        let new_active = self.active_memory.fetch_add(size as u64, Ordering::Relaxed) + size as u64;

        // CAS loop with exponential backoff to prevent CPU waste under high contention
        let mut current_peak = self.peak_memory.load(Ordering::Relaxed);
        let mut backoff_count = 0u32;
        const MAX_BACKOFF_ATTEMPTS: u32 = 10;

        while new_active > current_peak {
            match self.peak_memory.compare_exchange_weak(
                current_peak,
                new_active,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(actual) => {
                    // Another thread advanced the peak; retry against its value.
                    // (The loop condition exits if the peak already exceeds ours.)
                    current_peak = actual;
                    backoff_count += 1;

                    // Exponential backoff strategy
                    if backoff_count < MAX_BACKOFF_ATTEMPTS {
                        // Short-term contention: use spin_loop for efficiency
                        std::hint::spin_loop();
                    } else if backoff_count < MAX_BACKOFF_ATTEMPTS * 2 {
                        // Medium-term contention: yield to other threads
                        std::thread::yield_now();
                    } else {
                        // Long-term contention: use small sleep to reduce CPU usage
                        std::thread::sleep(std::time::Duration::from_micros(1));
                    }
                }
            }
        }
    }

    /// Record a deallocation of the block at `ptr`.
    ///
    /// The size is recovered from `active_allocations`; an unknown pointer
    /// (e.g. an allocation that sampling dropped) is recorded with size 0.
    /// Note that deallocations are *not* sampled — every call emits an event
    /// and increments the counters.
    pub fn track_deallocation(&self, ptr: usize, call_stack_hash: u64) {
        let size = self
            .active_allocations
            .remove(&ptr)
            .map(|(_, v)| v)
            .unwrap_or(0);

        let event = Event::deallocation(ptr, size, call_stack_hash, self.thread_id);
        self.events.push(event);

        self.total_deallocations.fetch_add(1, Ordering::Relaxed);
        self.total_deallocated
            .fetch_add(size as u64, Ordering::Relaxed);

        // Use fetch_update for atomic saturating_sub to prevent underflow
        let _ = self
            .active_memory
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(size as u64))
            });
    }

    /// Snapshot all counters into a `MemoryStats`.
    ///
    /// Each field is loaded independently with `Relaxed` ordering, so values
    /// read while other operations are in flight may be mutually slightly
    /// inconsistent (e.g. an allocation counted but its bytes not yet added).
    pub fn get_stats(&self) -> MemoryStats {
        MemoryStats {
            total_allocations: self.total_allocations.load(Ordering::Relaxed) as usize,
            total_allocated: self.total_allocated.load(Ordering::Relaxed) as usize,
            total_deallocations: self.total_deallocations.load(Ordering::Relaxed) as usize,
            total_deallocated: self.total_deallocated.load(Ordering::Relaxed) as usize,
            active_memory: self.active_memory.load(Ordering::Relaxed) as usize,
            peak_memory: self.peak_memory.load(Ordering::Relaxed) as usize,
        }
    }

    /// Returns `(seen, tracked)` allocation counts; equal when `sample_rate` is 1.0.
    pub fn get_sampling_stats(&self) -> (usize, usize) {
        (
            self.total_seen.load(Ordering::Relaxed),
            self.total_tracked.load(Ordering::Relaxed),
        )
    }

    /// Drain all queued events and write them to `output_file` (little-endian
    /// records, see `write_event`) behind a "MEMSCOPE_LOCKFREE" magic header.
    ///
    /// Draining makes this destructive: a second call only sees events queued
    /// after the first. When no events are queued, the file is neither created
    /// nor truncated. Also invoked from `Drop`.
    pub fn finalize(&self) -> std::io::Result<()> {
        let mut events = Vec::new();
        while let Some(event) = self.events.pop() {
            events.push(event);
        }

        // Nothing recorded: leave the filesystem untouched.
        if events.is_empty() {
            return Ok(());
        }

        if let Some(parent) = self.output_file.parent() {
            std::fs::create_dir_all(parent)?;
        }

        let mut file = File::create(&self.output_file)?;

        // Magic header identifying the file format to readers.
        let header = "MEMSCOPE_LOCKFREE";
        file.write_all(header.as_bytes())?;

        for event in events {
            self.write_event(&mut file, &event)?;
        }

        file.flush()?;
        Ok(())
    }

    /// Serialize one event as consecutive little-endian fields:
    /// type byte (1 = allocation, 2 = deallocation), timestamp, ptr, size,
    /// call-stack hash.
    ///
    /// NOTE(review): if `event.ptr` / `event.size` are `usize` (as the tracking
    /// API suggests), the record width is platform-dependent — confirm the
    /// reader agrees on field widths.
    fn write_event(&self, file: &mut File, event: &Event) -> std::io::Result<()> {
        let event_type_byte = match event.event_type {
            EventType::Allocation => 1u8,
            EventType::Deallocation => 2u8,
        };
        file.write_all(&event_type_byte.to_le_bytes())?;
        file.write_all(&event.timestamp.to_le_bytes())?;
        file.write_all(&event.ptr.to_le_bytes())?;
        file.write_all(&event.size.to_le_bytes())?;
        file.write_all(&event.call_stack_hash.to_le_bytes())?;
        Ok(())
    }

    /// The thread this tracker was created for.
    pub fn thread_id(&self) -> ThreadId {
        self.thread_id
    }

    /// Path the event dump will be written to.
    pub fn output_file(&self) -> &PathBuf {
        &self.output_file
    }

    /// Number of events currently queued (not yet drained by `finalize`).
    pub fn event_count(&self) -> usize {
        self.events.len()
    }

    /// Discard all queued events without writing them anywhere.
    pub fn clear_events(&self) {
        while self.events.pop().is_some() {}
    }
}
216
217impl Drop for ThreadLocalTracker {
218    fn drop(&mut self) {
219        if let Err(e) = self.finalize() {
220            tracing::warn!("Failed to finalize thread-local tracker: {}", e);
221        }
222    }
223}
224
/// Hash a call stack (a slice of frame addresses) into a single `u64`.
///
/// Frames are fed to the hasher one by one in slice order, so the same
/// addresses in a different order yield a different hash.
pub fn calculate_call_stack_hash(call_stack: &[usize]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    call_stack
        .iter()
        .for_each(|frame| frame.hash(&mut hasher));
    hasher.finish()
}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Output path inside the platform temp directory.
    ///
    /// Portable (works where a literal `/tmp` does not exist, e.g. Windows)
    /// and consistent with the temp-dir usage elsewhere in this file's tests.
    fn temp_output(name: &str) -> PathBuf {
        std::env::temp_dir().join(name)
    }

    #[test]
    fn test_thread_local_tracker_creation() {
        let thread_id = std::thread::current().id();
        let tracker = ThreadLocalTracker::new(thread_id, temp_output("test_tracker.bin"), 1.0);

        // A fresh tracker remembers its thread and has no events queued.
        assert_eq!(tracker.thread_id(), thread_id);
        assert_eq!(tracker.event_count(), 0);
    }

    #[test]
    fn test_allocation_tracking() {
        let thread_id = std::thread::current().id();
        let tracker = ThreadLocalTracker::new(thread_id, temp_output("test_tracker2.bin"), 1.0);

        tracker.track_allocation(0x1000, 1024, 12345);

        // With sample_rate = 1.0 the allocation must be recorded exactly once.
        let stats = tracker.get_stats();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.total_allocated, 1024);
        assert_eq!(stats.active_memory, 1024);
        assert_eq!(tracker.event_count(), 1);
    }

    #[test]
    fn test_deallocation_tracking() {
        let thread_id = std::thread::current().id();
        let tracker = ThreadLocalTracker::new(thread_id, temp_output("test_tracker3.bin"), 1.0);

        tracker.track_allocation(0x1000, 1024, 12345);
        tracker.track_deallocation(0x1000, 12345);

        // Alloc followed by matching free leaves no active memory.
        let stats = tracker.get_stats();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.total_deallocations, 1);
        assert_eq!(stats.active_memory, 0);
    }

    #[test]
    fn test_call_stack_hash() {
        // Hashing is deterministic for identical input...
        let call_stack = vec![0x1000, 0x2000, 0x3000];
        let hash1 = calculate_call_stack_hash(&call_stack);
        let hash2 = calculate_call_stack_hash(&call_stack);
        assert_eq!(hash1, hash2);

        // ...and distinct stacks should produce distinct hashes.
        let different_stack = vec![0x1000, 0x2000, 0x4000];
        let hash3 = calculate_call_stack_hash(&different_stack);
        assert_ne!(hash1, hash3);
    }
}
293
thread_local! {
    // One tracker slot per thread, lazily filled by `init_thread_tracker`.
    // `RefCell<Option<..>>` because the tracker is installed after thread
    // start and the free functions below need interior mutability to set it.
    static THREAD_TRACKER: std::cell::RefCell<Option<ThreadLocalTracker>> = const { std::cell::RefCell::new(None) };
}
297
/// Numeric id of the current thread, used to build per-thread output file
/// names. Delegates to `crate::utils::current_thread_id_u64`.
fn get_thread_id() -> u64 {
    crate::utils::current_thread_id_u64()
}
301
302pub fn init_thread_tracker(
303    output_dir: &std::path::Path,
304    sample_rate: Option<f64>,
305) -> Result<(), Box<dyn std::error::Error>> {
306    let sample_rate = sample_rate.unwrap_or(1.0);
307    let thread_id = std::thread::current().id();
308    let output_file = output_dir.join(format!("memscope_thread_{}.bin", get_thread_id()));
309
310    let tracker = ThreadLocalTracker::new(thread_id, output_file, sample_rate);
311
312    THREAD_TRACKER.with(|thread_tracker| {
313        *thread_tracker.borrow_mut() = Some(tracker);
314    });
315
316    Ok(())
317}
318
319pub fn track_allocation_lockfree(
320    ptr: usize,
321    size: usize,
322    call_stack_hash: u64,
323) -> Result<(), Box<dyn std::error::Error>> {
324    THREAD_TRACKER.with(|thread_tracker| {
325        if let Some(ref tracker) = *thread_tracker.borrow() {
326            tracker.track_allocation(ptr, size, call_stack_hash);
327            Ok(())
328        } else {
329            Err("Thread tracker not initialized. Call init_thread_tracker() first.".into())
330        }
331    })
332}
333
334pub fn track_deallocation_lockfree(
335    ptr: usize,
336    call_stack_hash: u64,
337) -> Result<(), Box<dyn std::error::Error>> {
338    THREAD_TRACKER.with(|thread_tracker| {
339        if let Some(ref tracker) = *thread_tracker.borrow() {
340            tracker.track_deallocation(ptr, call_stack_hash);
341            Ok(())
342        } else {
343            Err("Thread tracker not initialized. Call init_thread_tracker() first.".into())
344        }
345    })
346}
347
348pub fn finalize_thread_tracker() -> Result<(), Box<dyn std::error::Error>> {
349    THREAD_TRACKER.with(|thread_tracker| {
350        let mut tracker_ref = thread_tracker.borrow_mut();
351        if let Some(ref mut tracker) = *tracker_ref {
352            tracker
353                .finalize()
354                .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
355        } else {
356            Ok(())
357        }
358    })
359}
360
/// Get a snapshot of the current thread's tracker.
///
/// Counter fields are copied by value into fresh atomics; `events` and
/// `active_allocations` are shared `Arc` handles to the live structures.
///
/// # Warning
/// The returned tracker shares `events` and `active_allocations` with the original.
/// **Do NOT call `finalize()` on the returned tracker** - it will drain shared data
/// and corrupt the original tracker. Use this only for read operations.
///
/// NOTE(review): `Drop for ThreadLocalTracker` itself calls `finalize()`, so
/// merely letting this snapshot go out of scope drains the shared event queue
/// and rewrites the output file — the warning above cannot be honored by
/// normal use. Consider returning plain stats instead, or a wrapper that
/// suppresses the drop-time flush.
///
/// # Returns
/// - `Some(ThreadLocalTracker)` if a tracker is initialized
/// - `None` if `init_thread_tracker()` was not called
pub fn get_current_tracker() -> Option<ThreadLocalTracker> {
    THREAD_TRACKER.with(|thread_tracker| {
        thread_tracker
            .borrow()
            .as_ref()
            .map(|tracker| ThreadLocalTracker {
                thread_id: tracker.thread_id,
                events: Arc::clone(&tracker.events),
                active_allocations: Arc::clone(&tracker.active_allocations),
                total_allocations: AtomicU64::new(
                    tracker.total_allocations.load(Ordering::Relaxed),
                ),
                total_allocated: AtomicU64::new(tracker.total_allocated.load(Ordering::Relaxed)),
                total_deallocations: AtomicU64::new(
                    tracker.total_deallocations.load(Ordering::Relaxed),
                ),
                total_deallocated: AtomicU64::new(
                    tracker.total_deallocated.load(Ordering::Relaxed),
                ),
                active_memory: AtomicU64::new(tracker.active_memory.load(Ordering::Relaxed)),
                peak_memory: AtomicU64::new(tracker.peak_memory.load(Ordering::Relaxed)),
                output_file: tracker.output_file.clone(),
                sample_rate: tracker.sample_rate,
                total_seen: AtomicUsize::new(tracker.total_seen.load(Ordering::Relaxed)),
                total_tracked: AtomicUsize::new(tracker.total_tracked.load(Ordering::Relaxed)),
            })
    })
}
399
400pub fn trace_all<P: AsRef<std::path::Path>>(
401    output_dir: &P,
402) -> Result<(), Box<dyn std::error::Error>> {
403    let output_path = output_dir.as_ref().to_path_buf();
404
405    let _ = OUTPUT_DIRECTORY.set(output_path.clone());
406
407    if output_path.exists() {
408        let timestamp = std::time::SystemTime::now()
409            .duration_since(std::time::UNIX_EPOCH)
410            .map(|d| d.as_secs())
411            .unwrap_or(0);
412        let backup_name = format!(
413            "{}.backup.{}",
414            output_path
415                .file_name()
416                .unwrap_or_default()
417                .to_string_lossy(),
418            timestamp
419        );
420        let backup_path = output_path.with_file_name(backup_name);
421        std::fs::rename(&output_path, &backup_path)?;
422        tracing::info!("Existing directory backed up to: {}", backup_path.display());
423    }
424    std::fs::create_dir_all(&output_path)?;
425
426    TRACKING_ENABLED.store(true, Ordering::SeqCst);
427
428    tracing::info!("Lockfree tracking started: {}", output_path.display());
429
430    Ok(())
431}
432
433pub fn trace_thread<P: AsRef<std::path::Path>>(
434    output_dir: &P,
435) -> Result<(), Box<dyn std::error::Error>> {
436    let output_path = output_dir.as_ref().to_path_buf();
437
438    if !output_path.exists() {
439        std::fs::create_dir_all(&output_path)?;
440    }
441
442    init_thread_tracker(&output_path, Some(1.0))?;
443
444    Ok(())
445}
446
447pub fn stop_tracing() -> Result<(), Box<dyn std::error::Error>> {
448    if !TRACKING_ENABLED.load(Ordering::SeqCst) {
449        return Ok(());
450    }
451
452    let _ = finalize_thread_tracker();
453
454    TRACKING_ENABLED.store(false, Ordering::SeqCst);
455
456    Ok(())
457}
458
/// Whether global lock-free tracking is currently enabled
/// (set by `trace_all`, cleared by `stop_tracing`).
pub fn is_tracking() -> bool {
    TRACKING_ENABLED.load(Ordering::SeqCst)
}
462
463pub fn memory_snapshot() -> super::lockfree_types::MemorySnapshot {
464    use super::lockfree_types::MemorySnapshot;
465
466    let (current_mb, peak_mb, allocations, deallocations) = THREAD_TRACKER.with(|thread_tracker| {
467        if let Some(tracker) = thread_tracker.borrow().as_ref() {
468            let stats = tracker.get_stats();
469            (
470                stats.active_memory as f64 / (1024.0 * 1024.0),
471                stats.peak_memory as f64 / (1024.0 * 1024.0),
472                stats.total_allocations,
473                stats.total_deallocations,
474            )
475        } else {
476            (0.0, 0.0, 0, 0)
477        }
478    });
479
480    MemorySnapshot {
481        current_mb,
482        peak_mb,
483        allocations: allocations as u64,
484        deallocations: deallocations as u64,
485        active_threads: if TRACKING_ENABLED.load(Ordering::SeqCst) {
486            1
487        } else {
488            0
489        },
490    }
491}
492
493pub fn quick_trace<F, R>(f: F) -> R
494where
495    F: FnOnce() -> R,
496{
497    let temp_dir = std::env::temp_dir().join("memscope_lockfree_quick");
498
499    if trace_all(&temp_dir).is_err() {
500        return f();
501    }
502
503    let result = f();
504
505    let _ = stop_tracing();
506
507    tracing::info!("Quick trace completed - check {}", temp_dir.display());
508
509    result
510}
511
#[cfg(test)]
mod global_api_tests {
    use super::*;

    #[test]
    fn test_init_thread_tracker() {
        let temp_dir = std::env::temp_dir().join("memscope_global_test");
        std::fs::create_dir_all(&temp_dir).unwrap();

        let result = init_thread_tracker(&temp_dir, Some(1.0));
        assert!(result.is_ok(), "Should successfully initialize tracker");

        // Re-initializing replaces the thread's tracker and must not fail.
        let result2 = init_thread_tracker(&temp_dir, Some(0.5));
        assert!(result2.is_ok(), "Should handle duplicate initialization");
    }

    #[test]
    fn test_track_without_init() {
        // Clear any tracker another test on this thread may have installed.
        THREAD_TRACKER.with(|t| {
            *t.borrow_mut() = None;
        });

        // Tracking without initialization must surface an error, not panic.
        let result = track_allocation_lockfree(0x1000, 1024, 12345);
        assert!(result.is_err(), "Should fail without initialization");
    }

    #[test]
    fn test_finalize_without_init() {
        // Clear any tracker another test on this thread may have installed.
        THREAD_TRACKER.with(|t| {
            *t.borrow_mut() = None;
        });

        // Finalizing with no tracker is defined as a no-op success.
        let result = finalize_thread_tracker();
        assert!(
            result.is_ok(),
            "Should handle finalization without initialization"
        );
    }

    #[test]
    fn test_global_api_workflow() {
        // Full lifecycle: init -> track -> snapshot -> finalize.
        let temp_dir = std::env::temp_dir().join("memscope_workflow_test");
        std::fs::create_dir_all(&temp_dir).unwrap();

        init_thread_tracker(&temp_dir, Some(1.0)).unwrap();

        track_allocation_lockfree(0x1000, 1024, 12345).unwrap();
        track_deallocation_lockfree(0x1000, 12345).unwrap();

        let tracker = get_current_tracker();
        assert!(tracker.is_some(), "Should have active tracker");

        if let Some(t) = tracker {
            let stats = t.get_stats();
            assert_eq!(stats.total_allocations, 1);
            assert_eq!(stats.total_deallocations, 1);
        }

        finalize_thread_tracker().unwrap();
    }
}