// memscope_rs/capture/backends/lockfree_tracker.rs

//! Lockfree memory tracker implementation.
//!
//! This module contains the ThreadLocalTracker for thread-local memory tracking
//! using lock-free data structures for optimal concurrent performance.
5
6use std::{
7    fs::File,
8    io::Write,
9    path::PathBuf,
10    sync::{
11        atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
12        Arc, OnceLock,
13    },
14    thread::ThreadId,
15};
16
17use crossbeam::queue::SegQueue;
18use dashmap::DashMap;
19
20use super::lockfree_types::{Event, EventType, MemoryStats};
21
/// Process-wide flag flipped on by `trace_all` and off by `stop_tracing`.
static TRACKING_ENABLED: AtomicBool = AtomicBool::new(false);
/// Output directory recorded by the first `trace_all` call; `set` at most
/// once (later calls are ignored). Not read within this file — presumably
/// consumed elsewhere in the crate; TODO confirm.
static OUTPUT_DIRECTORY: OnceLock<std::path::PathBuf> = OnceLock::new();
24
/// Thread-local memory tracker using lock-free data structures.
///
/// This tracker uses SegQueue for events and DashMap for active allocations,
/// ensuring no event loss under high contention.
pub struct ThreadLocalTracker {
    /// Identity of the thread this tracker was created for.
    thread_id: ThreadId,
    /// FIFO queue of recorded events; drained by `finalize`. Wrapped in `Arc`
    /// so `get_current_tracker` snapshots can alias the live queue.
    events: Arc<SegQueue<Event>>,
    /// `ptr -> size` for allocations that have not been deallocated yet.
    active_allocations: Arc<DashMap<usize, usize>>,
    /// Monotonic counters, all updated with `Ordering::Relaxed`.
    total_allocations: AtomicU64,
    total_allocated: AtomicU64,
    total_deallocations: AtomicU64,
    total_deallocated: AtomicU64,
    /// Currently outstanding bytes (decremented with saturating_sub).
    active_memory: AtomicU64,
    /// High-water mark of `active_memory`.
    peak_memory: AtomicU64,
    /// Destination of the binary event log written by `finalize`.
    output_file: PathBuf,
    /// Fraction of allocations recorded, clamped to `[0.0, 1.0]` in `new`.
    sample_rate: f64,
    /// Allocations observed before sampling vs. actually recorded.
    total_seen: AtomicUsize,
    total_tracked: AtomicUsize,
}
44
impl ThreadLocalTracker {
    /// Creates an empty tracker for `thread_id` that will persist its event
    /// log to `output_file` on `finalize`. `sample_rate` is clamped to
    /// `[0.0, 1.0]` (1.0 = record every allocation).
    pub fn new(thread_id: ThreadId, output_file: PathBuf, sample_rate: f64) -> Self {
        Self {
            thread_id,
            events: Arc::new(SegQueue::new()),
            active_allocations: Arc::new(DashMap::new()),
            total_allocations: AtomicU64::new(0),
            total_allocated: AtomicU64::new(0),
            total_deallocations: AtomicU64::new(0),
            total_deallocated: AtomicU64::new(0),
            active_memory: AtomicU64::new(0),
            peak_memory: AtomicU64::new(0),
            output_file,
            // Clamp rather than propagate out-of-range rates from callers.
            sample_rate: sample_rate.clamp(0.0, 1.0),
            total_seen: AtomicUsize::new(0),
            total_tracked: AtomicUsize::new(0),
        }
    }

    /// Records an allocation of `size` bytes at address `ptr`.
    ///
    /// Every call bumps `total_seen`; when `sample_rate < 1.0`, a uniform
    /// random draw may skip the allocation entirely (no event, no other
    /// counters). Recorded allocations are pushed to the event queue,
    /// remembered in `active_allocations`, and folded into the running
    /// totals and the peak-memory watermark.
    pub fn track_allocation(&self, ptr: usize, size: usize, call_stack_hash: u64) {
        self.total_seen.fetch_add(1, Ordering::Relaxed);

        // Sampling gate: keep the allocation with probability `sample_rate`.
        if self.sample_rate < 1.0 {
            let sample_decision = rand::random::<f64>();
            if sample_decision >= self.sample_rate {
                return;
            }
        }

        self.total_tracked.fetch_add(1, Ordering::Relaxed);

        let event = Event::allocation(ptr, size, call_stack_hash, self.thread_id);
        self.events.push(event);

        // Remember the size so the matching deallocation can recover it.
        self.active_allocations.insert(ptr, size);

        self.total_allocations.fetch_add(1, Ordering::Relaxed);
        self.total_allocated
            .fetch_add(size as u64, Ordering::Relaxed);

        // fetch_add returns the PREVIOUS value; add `size` again to get the
        // post-update active total as observed by this call.
        let new_active = self.active_memory.fetch_add(size as u64, Ordering::Relaxed) + size as u64;

        // CAS loop with exponential backoff to prevent CPU waste under high contention
        let mut current_peak = self.peak_memory.load(Ordering::Relaxed);
        let mut backoff_count = 0u32;
        const MAX_BACKOFF_ATTEMPTS: u32 = 10;

        // Raise the watermark only while our value still exceeds it; another
        // thread publishing a higher peak terminates the loop naturally.
        while new_active > current_peak {
            match self.peak_memory.compare_exchange_weak(
                current_peak,
                new_active,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(actual) => {
                    current_peak = actual;
                    backoff_count += 1;

                    // Exponential backoff strategy
                    if backoff_count < MAX_BACKOFF_ATTEMPTS {
                        // Short-term contention: use spin_loop for efficiency
                        std::hint::spin_loop();
                    } else if backoff_count < MAX_BACKOFF_ATTEMPTS * 2 {
                        // Medium-term contention: yield to other threads
                        std::thread::yield_now();
                    } else {
                        // Long-term contention: use small sleep to reduce CPU usage
                        std::thread::sleep(std::time::Duration::from_micros(1));
                    }
                }
            }
        }
    }

    /// Records a deallocation of `ptr`.
    ///
    /// The freed size is recovered from `active_allocations`; pointers that
    /// were never tracked (e.g. sampled out) are logged with size 0. Note the
    /// event is always recorded — deallocations are not sampled.
    pub fn track_deallocation(&self, ptr: usize, call_stack_hash: u64) {
        let size = self
            .active_allocations
            .remove(&ptr)
            .map(|(_, v)| v)
            .unwrap_or(0);

        let event = Event::deallocation(ptr, size, call_stack_hash, self.thread_id);
        self.events.push(event);

        self.total_deallocations.fetch_add(1, Ordering::Relaxed);
        self.total_deallocated
            .fetch_add(size as u64, Ordering::Relaxed);

        // Use fetch_update for atomic saturating_sub to prevent underflow
        let _ = self
            .active_memory
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(size as u64))
            });
    }

    /// Returns a snapshot of the counters. Each field is loaded with a
    /// separate relaxed read, so under concurrent updates the fields are
    /// individually — not mutually — consistent.
    pub fn get_stats(&self) -> MemoryStats {
        MemoryStats {
            total_allocations: self.total_allocations.load(Ordering::Relaxed) as usize,
            total_allocated: self.total_allocated.load(Ordering::Relaxed) as usize,
            total_deallocations: self.total_deallocations.load(Ordering::Relaxed) as usize,
            total_deallocated: self.total_deallocated.load(Ordering::Relaxed) as usize,
            active_memory: self.active_memory.load(Ordering::Relaxed) as usize,
            peak_memory: self.peak_memory.load(Ordering::Relaxed) as usize,
        }
    }

    /// Returns `(total_seen, total_tracked)`: allocations observed before
    /// sampling vs. allocations actually recorded.
    pub fn get_sampling_stats(&self) -> (usize, usize) {
        (
            self.total_seen.load(Ordering::Relaxed),
            self.total_tracked.load(Ordering::Relaxed),
        )
    }

    /// Drains all queued events and writes them to `output_file` in the
    /// binary "MEMSCOPE_LOCKFREE" format. When no events are queued, no file
    /// is created. Parent directories are created as needed.
    ///
    /// # Errors
    /// Propagates any I/O error from directory creation or file writing.
    pub fn finalize(&self) -> std::io::Result<()> {
        // Drain the lock-free queue into a Vec (SegQueue pops in push order).
        let mut events = Vec::new();
        while let Some(event) = self.events.pop() {
            events.push(event);
        }

        // Nothing tracked: skip file creation entirely. This also makes a
        // second finalize (e.g. from Drop after an explicit call) a no-op.
        if events.is_empty() {
            return Ok(());
        }

        if let Some(parent) = self.output_file.parent() {
            std::fs::create_dir_all(parent)?;
        }

        let mut file = File::create(&self.output_file)?;

        // Magic header identifying the file format to readers.
        let header = "MEMSCOPE_LOCKFREE";
        file.write_all(header.as_bytes())?;

        for event in events {
            self.write_event(&mut file, &event)?;
        }

        file.flush()?;
        Ok(())
    }

    /// Serializes one event as little-endian fields in fixed order:
    /// event-type byte, timestamp, ptr, size, call-stack hash.
    fn write_event(&self, file: &mut File, event: &Event) -> std::io::Result<()> {
        let event_type_byte = match event.event_type {
            EventType::Allocation => 1u8,
            EventType::Deallocation => 2u8,
            EventType::Clone => 3u8,
            EventType::Move => 4u8,
            EventType::Borrow => 5u8,
            EventType::MutBorrow => 6u8,
        };
        file.write_all(&event_type_byte.to_le_bytes())?;
        file.write_all(&event.timestamp.to_le_bytes())?;
        file.write_all(&event.ptr.to_le_bytes())?;
        file.write_all(&event.size.to_le_bytes())?;
        file.write_all(&event.call_stack_hash.to_le_bytes())?;
        Ok(())
    }

    /// The thread this tracker was created for.
    pub fn thread_id(&self) -> ThreadId {
        self.thread_id
    }

    /// Path the binary event log will be written to.
    pub fn output_file(&self) -> &PathBuf {
        &self.output_file
    }

    /// Number of events currently queued (not yet finalized).
    pub fn event_count(&self) -> usize {
        self.events.len()
    }

    /// Discards all queued events without writing them anywhere.
    pub fn clear_events(&self) {
        while self.events.pop().is_some() {}
    }
}
220
221impl Drop for ThreadLocalTracker {
222    fn drop(&mut self) {
223        if let Err(e) = self.finalize() {
224            tracing::warn!("Failed to finalize thread-local tracker: {}", e);
225        }
226    }
227}
228
/// Hashes a call stack (a slice of return addresses) into a single `u64`
/// using the standard library's `DefaultHasher`. Deterministic within a
/// process run; the same addresses always yield the same hash.
pub fn calculate_call_stack_hash(call_stack: &[usize]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let hasher = call_stack
        .iter()
        .fold(DefaultHasher::new(), |mut state, addr| {
            addr.hash(&mut state);
            state
        });
    hasher.finish()
}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tracker whose output lands in the OS temp dir. Replaces the
    /// original hard-coded `/tmp/...` paths, which do not exist on Windows
    /// and polluted a shared system directory on Unix.
    fn tracker_with_file(file_name: &str) -> ThreadLocalTracker {
        let output_file = std::env::temp_dir().join(file_name);
        ThreadLocalTracker::new(std::thread::current().id(), output_file, 1.0)
    }

    #[test]
    fn test_thread_local_tracker_creation() {
        let thread_id = std::thread::current().id();
        let tracker = tracker_with_file("test_tracker.bin");

        assert_eq!(tracker.thread_id(), thread_id);
        assert_eq!(tracker.event_count(), 0);
    }

    #[test]
    fn test_allocation_tracking() {
        let tracker = tracker_with_file("test_tracker2.bin");

        tracker.track_allocation(0x1000, 1024, 12345);

        let stats = tracker.get_stats();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.total_allocated, 1024);
        assert_eq!(stats.active_memory, 1024);
        assert_eq!(tracker.event_count(), 1);
    }

    #[test]
    fn test_deallocation_tracking() {
        let tracker = tracker_with_file("test_tracker3.bin");

        tracker.track_allocation(0x1000, 1024, 12345);
        tracker.track_deallocation(0x1000, 12345);

        let stats = tracker.get_stats();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.total_deallocations, 1);
        assert_eq!(stats.active_memory, 0);
    }

    #[test]
    fn test_call_stack_hash() {
        let call_stack = vec![0x1000, 0x2000, 0x3000];
        let hash1 = calculate_call_stack_hash(&call_stack);
        let hash2 = calculate_call_stack_hash(&call_stack);

        // Hashing is deterministic within a process run...
        assert_eq!(hash1, hash2);

        // ...and distinguishes these two stacks.
        let different_stack = vec![0x1000, 0x2000, 0x4000];
        let hash3 = calculate_call_stack_hash(&different_stack);
        assert_ne!(hash1, hash3);
    }
}
297
thread_local! {
    /// Per-thread tracker slot; `None` until `init_thread_tracker` runs on
    /// this thread. `RefCell` is sufficient because the slot is only ever
    /// touched from its owning thread.
    static THREAD_TRACKER: std::cell::RefCell<Option<ThreadLocalTracker>> = const { std::cell::RefCell::new(None) };
}
301
/// Numeric id of the calling thread, used to name per-thread output files.
/// Delegates to `crate::utils::current_thread_id_u64` (defined elsewhere).
fn get_thread_id() -> u64 {
    crate::utils::current_thread_id_u64()
}
305
306pub fn init_thread_tracker(
307    output_dir: &std::path::Path,
308    sample_rate: Option<f64>,
309) -> Result<(), Box<dyn std::error::Error>> {
310    let sample_rate = sample_rate.unwrap_or(1.0);
311    let thread_id = std::thread::current().id();
312    let output_file = output_dir.join(format!("memscope_thread_{}.bin", get_thread_id()));
313
314    let tracker = ThreadLocalTracker::new(thread_id, output_file, sample_rate);
315
316    THREAD_TRACKER.with(|thread_tracker| {
317        *thread_tracker.borrow_mut() = Some(tracker);
318    });
319
320    Ok(())
321}
322
323pub fn track_allocation_lockfree(
324    ptr: usize,
325    size: usize,
326    call_stack_hash: u64,
327) -> Result<(), Box<dyn std::error::Error>> {
328    THREAD_TRACKER.with(|thread_tracker| {
329        if let Some(ref tracker) = *thread_tracker.borrow() {
330            tracker.track_allocation(ptr, size, call_stack_hash);
331            Ok(())
332        } else {
333            Err("Thread tracker not initialized. Call init_thread_tracker() first.".into())
334        }
335    })
336}
337
338pub fn track_deallocation_lockfree(
339    ptr: usize,
340    call_stack_hash: u64,
341) -> Result<(), Box<dyn std::error::Error>> {
342    THREAD_TRACKER.with(|thread_tracker| {
343        if let Some(ref tracker) = *thread_tracker.borrow() {
344            tracker.track_deallocation(ptr, call_stack_hash);
345            Ok(())
346        } else {
347            Err("Thread tracker not initialized. Call init_thread_tracker() first.".into())
348        }
349    })
350}
351
352pub fn finalize_thread_tracker() -> Result<(), Box<dyn std::error::Error>> {
353    THREAD_TRACKER.with(|thread_tracker| {
354        let mut tracker_ref = thread_tracker.borrow_mut();
355        if let Some(ref mut tracker) = *tracker_ref {
356            tracker
357                .finalize()
358                .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
359        } else {
360            Ok(())
361        }
362    })
363}
364
365/// Get a snapshot of the current thread's tracker.
366///
367/// # Warning
368/// The returned tracker shares `events` and `active_allocations` with the original.
369/// **Do NOT call `finalize()` on the returned tracker** - it will drain shared data
370/// and corrupt the original tracker. Use this only for read operations.
371///
372/// # Returns
373/// - `Some(ThreadLocalTracker)` if a tracker is initialized
374/// - `None` if `init_thread_tracker()` was not called
375pub fn get_current_tracker() -> Option<ThreadLocalTracker> {
376    THREAD_TRACKER.with(|thread_tracker| {
377        thread_tracker
378            .borrow()
379            .as_ref()
380            .map(|tracker| ThreadLocalTracker {
381                thread_id: tracker.thread_id,
382                events: Arc::clone(&tracker.events),
383                active_allocations: Arc::clone(&tracker.active_allocations),
384                total_allocations: AtomicU64::new(
385                    tracker.total_allocations.load(Ordering::Relaxed),
386                ),
387                total_allocated: AtomicU64::new(tracker.total_allocated.load(Ordering::Relaxed)),
388                total_deallocations: AtomicU64::new(
389                    tracker.total_deallocations.load(Ordering::Relaxed),
390                ),
391                total_deallocated: AtomicU64::new(
392                    tracker.total_deallocated.load(Ordering::Relaxed),
393                ),
394                active_memory: AtomicU64::new(tracker.active_memory.load(Ordering::Relaxed)),
395                peak_memory: AtomicU64::new(tracker.peak_memory.load(Ordering::Relaxed)),
396                output_file: tracker.output_file.clone(),
397                sample_rate: tracker.sample_rate,
398                total_seen: AtomicUsize::new(tracker.total_seen.load(Ordering::Relaxed)),
399                total_tracked: AtomicUsize::new(tracker.total_tracked.load(Ordering::Relaxed)),
400            })
401    })
402}
403
404pub fn trace_all<P: AsRef<std::path::Path>>(
405    output_dir: &P,
406) -> Result<(), Box<dyn std::error::Error>> {
407    let output_path = output_dir.as_ref().to_path_buf();
408
409    let _ = OUTPUT_DIRECTORY.set(output_path.clone());
410
411    if output_path.exists() {
412        let timestamp = std::time::SystemTime::now()
413            .duration_since(std::time::UNIX_EPOCH)
414            .map(|d| d.as_secs())
415            .unwrap_or(0);
416        let backup_name = format!(
417            "{}.backup.{}",
418            output_path
419                .file_name()
420                .unwrap_or_default()
421                .to_string_lossy(),
422            timestamp
423        );
424        let backup_path = output_path.with_file_name(backup_name);
425        std::fs::rename(&output_path, &backup_path)?;
426        tracing::info!("Existing directory backed up to: {}", backup_path.display());
427    }
428    std::fs::create_dir_all(&output_path)?;
429
430    TRACKING_ENABLED.store(true, Ordering::SeqCst);
431
432    tracing::info!("Lockfree tracking started: {}", output_path.display());
433
434    Ok(())
435}
436
437pub fn trace_thread<P: AsRef<std::path::Path>>(
438    output_dir: &P,
439) -> Result<(), Box<dyn std::error::Error>> {
440    let output_path = output_dir.as_ref().to_path_buf();
441
442    if !output_path.exists() {
443        std::fs::create_dir_all(&output_path)?;
444    }
445
446    init_thread_tracker(&output_path, Some(1.0))?;
447
448    Ok(())
449}
450
451pub fn stop_tracing() -> Result<(), Box<dyn std::error::Error>> {
452    if !TRACKING_ENABLED.load(Ordering::SeqCst) {
453        return Ok(());
454    }
455
456    let _ = finalize_thread_tracker();
457
458    TRACKING_ENABLED.store(false, Ordering::SeqCst);
459
460    Ok(())
461}
462
/// Returns true while global lockfree tracking (started by `trace_all`
/// and not yet stopped by `stop_tracing`) is active.
pub fn is_tracking() -> bool {
    TRACKING_ENABLED.load(Ordering::SeqCst)
}
466
467pub fn memory_snapshot() -> super::lockfree_types::MemorySnapshot {
468    use super::lockfree_types::MemorySnapshot;
469
470    let (current_mb, peak_mb, allocations, deallocations) = THREAD_TRACKER.with(|thread_tracker| {
471        if let Some(tracker) = thread_tracker.borrow().as_ref() {
472            let stats = tracker.get_stats();
473            (
474                stats.active_memory as f64 / (1024.0 * 1024.0),
475                stats.peak_memory as f64 / (1024.0 * 1024.0),
476                stats.total_allocations,
477                stats.total_deallocations,
478            )
479        } else {
480            (0.0, 0.0, 0, 0)
481        }
482    });
483
484    MemorySnapshot {
485        current_mb,
486        peak_mb,
487        allocations: allocations as u64,
488        deallocations: deallocations as u64,
489        active_threads: if TRACKING_ENABLED.load(Ordering::SeqCst) {
490            1
491        } else {
492            0
493        },
494    }
495}
496
497pub fn quick_trace<F, R>(f: F) -> R
498where
499    F: FnOnce() -> R,
500{
501    let temp_dir = std::env::temp_dir().join("memscope_lockfree_quick");
502
503    if trace_all(&temp_dir).is_err() {
504        return f();
505    }
506
507    let result = f();
508
509    let _ = stop_tracing();
510
511    tracing::info!("Quick trace completed - check {}", temp_dir.display());
512
513    result
514}
515
#[cfg(test)]
mod global_api_tests {
    use super::*;

    #[test]
    fn test_init_thread_tracker() {
        let dir = std::env::temp_dir().join("memscope_global_test");
        std::fs::create_dir_all(&dir).unwrap();

        // First initialization installs a tracker for this thread.
        let result = init_thread_tracker(&dir, Some(1.0));
        assert!(result.is_ok(), "Should successfully initialize tracker");

        // Re-initialization simply replaces the existing tracker.
        let result2 = init_thread_tracker(&dir, Some(0.5));
        assert!(result2.is_ok(), "Should handle duplicate initialization");
    }

    #[test]
    fn test_track_without_init() {
        // Drop any tracker left on this thread so the call must fail.
        THREAD_TRACKER.with(|slot| {
            slot.borrow_mut().take();
        });

        let result = track_allocation_lockfree(0x1000, 1024, 12345);
        assert!(result.is_err(), "Should fail without initialization");
    }

    #[test]
    fn test_finalize_without_init() {
        THREAD_TRACKER.with(|slot| {
            slot.borrow_mut().take();
        });

        let result = finalize_thread_tracker();
        assert!(
            result.is_ok(),
            "Should handle finalization without initialization"
        );
    }

    #[test]
    fn test_global_api_workflow() {
        let dir = std::env::temp_dir().join("memscope_workflow_test");
        std::fs::create_dir_all(&dir).unwrap();

        init_thread_tracker(&dir, Some(1.0)).unwrap();

        track_allocation_lockfree(0x1000, 1024, 12345).unwrap();
        track_deallocation_lockfree(0x1000, 12345).unwrap();

        let snapshot = get_current_tracker();
        assert!(snapshot.is_some(), "Should have active tracker");

        if let Some(tracker) = snapshot {
            let stats = tracker.get_stats();
            assert_eq!(stats.total_allocations, 1);
            assert_eq!(stats.total_deallocations, 1);
        }

        finalize_thread_tracker().unwrap();
    }
}