memscope_rs/core/
sampling_tracker.rs

//! High-performance sampling-based memory tracker using binary serialization.
//!
//! This module implements the core multi-thread tracking system using:
//! - bincode binary serialization for compact, low-overhead data storage
//! - Intelligent sampling (frequency + size dimensions)
//! - Thread-local storage with file-based communication
//! - Batch writing for optimal performance
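//!
//! A minimal usage sketch; the `memscope_rs::core::sampling_tracker` import path
//! is assumed for illustration, so the example is marked `ignore`:
//!
//! ```rust,ignore
//! use memscope_rs::core::sampling_tracker::{
//!     get_sampling_tracker, init_sampling_tracker, SamplingConfig,
//! };
//!
//! // Optionally install a custom configuration before the tracker is first used.
//! init_sampling_tracker(SamplingConfig::default());
//!
//! let tracker = get_sampling_tracker();
//! tracker
//!     .track_variable(0x1000, 4096, "buf".to_string(), "Vec<u8>".to_string())
//!     .unwrap();
//! tracker.track_access(0x1000).unwrap();
//! tracker.track_drop(0x1000).unwrap();
//!
//! // Write this thread's buffered events to its `memscope_thread_*.bin` file.
//! tracker.flush_current_thread().unwrap();
//! ```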

use crate::core::types::TrackingResult;
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::collections::HashMap;
use std::fs::OpenOptions;
use std::io::Write;
use std::thread;
use std::time::{SystemTime, UNIX_EPOCH};

/// Core event data structure optimized for binary serialization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Event {
    /// Timestamp when the event occurred (nanoseconds since epoch)
    pub timestamp: u64,
    /// Memory pointer address
    pub ptr: usize,
    /// Size of memory allocation/operation
    pub size: usize,
    /// Hash of the call stack for frequency analysis
    pub call_stack_hash: u64,
    /// Type of memory operation
    pub event_type: EventType,
    /// Variable name (if available)
    pub var_name: Option<String>,
    /// Type name (if available)
    pub type_name: Option<String>,
}

/// Types of memory operations we track
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EventType {
    /// Memory allocation
    Allocate,
    /// Memory access/read
    Access,
    /// Memory modification/write
    Modify,
    /// Memory deallocation
    Drop,
    /// Variable clone operation
    Clone { target_ptr: usize },
}

/// Frequency data for call stack analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrequencyData {
    /// Hash of the call stack
    pub call_stack_hash: u64,
    /// How many times this call stack occurred
    pub frequency: u64,
    /// Total memory allocated by this call stack
    pub total_size: usize,
    /// Representative variable name from this call stack
    pub sample_var_name: String,
    /// Representative type name from this call stack
    pub sample_type_name: String,
}

/// Thread-local data structure for high-performance tracking
#[derive(Debug)]
struct ThreadLocalData {
    /// Buffer for events before batch writing
    event_buffer: Vec<Event>,
    /// Call stack frequency tracking
    call_stack_frequencies: HashMap<u64, (u64, usize, String, String)>, // (count, total_size, sample_var, sample_type)
    /// File handle for this thread's data
    file_handle: Option<std::fs::File>,
    /// Thread ID for identification
    thread_id: String,
    /// Sample counter for frequency-based sampling
    sample_counter: u64,
    /// Total operations performed by this thread
    total_operations: u64,
}

impl ThreadLocalData {
    fn new() -> Self {
        Self {
            event_buffer: Vec::with_capacity(1000), // Pre-allocate buffer
            call_stack_frequencies: HashMap::new(),
            file_handle: None,
            thread_id: format!("{:?}", thread::current().id()),
            sample_counter: 0,
            total_operations: 0,
        }
    }

    /// Initialize the file handle for this thread
    fn ensure_file_handle(&mut self) -> std::io::Result<()> {
        if self.file_handle.is_none() {
            let filename = format!("memscope_thread_{}.bin", self.thread_id);
            let file = OpenOptions::new()
                .create(true)
                .append(true)
                .open(filename)?;
            self.file_handle = Some(file);
        }
        Ok(())
    }

    /// Flush the event buffer to disk using binary serialization
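    ///
    /// Each flush appends one record to the thread's file: a 4-byte little-endian
    /// length prefix followed by the serialized `Vec<Event>` batch.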
    fn flush_events(&mut self) -> std::io::Result<()> {
        if self.event_buffer.is_empty() {
            return Ok(());
        }

        self.ensure_file_handle()?;

        if let Some(ref mut file) = self.file_handle {
            // Serialize the event batch to a compact binary blob (assumes the
            // bincode 1.x serde API named in the module docs)
            let serialized = bincode::serialize(&self.event_buffer)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;

            // Write length prefix for easier parsing
            let len = serialized.len() as u32;
            file.write_all(&len.to_le_bytes())?;
            file.write_all(&serialized)?;
            file.flush()?;
        }

        self.event_buffer.clear();
        Ok(())
    }

    /// Flush frequency data to disk
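    ///
    /// Frequency records are distinguished from event records by a leading
    /// `0xFEEDFACE` marker, followed by a 4-byte little-endian length prefix and
    /// the serialized `Vec<FrequencyData>`.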
    fn flush_frequencies(&mut self) -> std::io::Result<()> {
        if self.call_stack_frequencies.is_empty() {
            return Ok(());
        }

        self.ensure_file_handle()?;

        let freq_data: Vec<FrequencyData> = self
            .call_stack_frequencies
            .iter()
            .map(
                |(&hash, &(freq, total_size, ref var_name, ref type_name))| FrequencyData {
                    call_stack_hash: hash,
                    frequency: freq,
                    total_size,
                    sample_var_name: var_name.clone(),
                    sample_type_name: type_name.clone(),
                },
            )
            .collect();

        if let Some(ref mut file) = self.file_handle {
            let serialized = bincode::serialize(&freq_data)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;

            // Write marker for frequency data
            let marker = 0xFEEDFACEu32;
            file.write_all(&marker.to_le_bytes())?;
            file.write_all(&(serialized.len() as u32).to_le_bytes())?;
            file.write_all(&serialized)?;
            file.flush()?;
        }

        self.call_stack_frequencies.clear();
        Ok(())
    }
}

// Thread-local storage for tracking data
thread_local! {
    static THREAD_DATA: RefCell<ThreadLocalData> = RefCell::new(ThreadLocalData::new());
}

/// High-performance sampling tracker with intelligent sampling strategies
pub struct SamplingTracker {
    /// Configuration for sampling behavior
    config: SamplingConfig,
}

/// Configuration for sampling behavior
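///
/// A sketch of a custom configuration (crate paths are assumed, so the example is
/// marked `ignore`); unspecified fields fall back to `Default`:
///
/// ```rust,ignore
/// use memscope_rs::core::sampling_tracker::SamplingConfig;
///
/// // Always record allocations of 4 KiB or more, and record fewer small ones.
/// let config = SamplingConfig {
///     large_size_threshold: 4 * 1024,
///     small_sample_rate: 0.001,
///     ..SamplingConfig::default()
/// };
/// ```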
#[derive(Debug, Clone)]
pub struct SamplingConfig {
    /// Size threshold for guaranteed sampling (bytes)
    pub large_size_threshold: usize,
    /// Size threshold for medium probability sampling (bytes)
    pub medium_size_threshold: usize,
    /// Sampling rate for medium-sized allocations (0.0-1.0)
    pub medium_sample_rate: f64,
    /// Sampling rate for small allocations (0.0-1.0)
    pub small_sample_rate: f64,
    /// Buffer size before flushing to disk
    pub buffer_size: usize,
}

impl Default for SamplingConfig {
    fn default() -> Self {
        Self {
            large_size_threshold: 10 * 1024, // 10KB - always sample
            medium_size_threshold: 1024,     // 1KB - 10% sample rate
            medium_sample_rate: 0.1,         // 10%
            small_sample_rate: 0.01,         // 1%
            buffer_size: 1000,               // Flush after 1000 events
        }
    }
}

impl SamplingTracker {
    /// Create a new sampling tracker with default configuration
    pub fn new() -> Self {
        Self {
            config: SamplingConfig::default(),
        }
    }

    /// Create a new sampling tracker with custom configuration
    pub fn with_config(config: SamplingConfig) -> Self {
        Self { config }
    }

    /// Track a variable allocation with intelligent sampling
    pub fn track_variable(
        &self,
        ptr: usize,
        size: usize,
        var_name: String,
        type_name: String,
    ) -> TrackingResult<()> {
        let call_stack_hash = self.calculate_call_stack_hash(&var_name, &type_name);

        THREAD_DATA.with(|data| {
            let mut data = data.borrow_mut();
            data.total_operations += 1;

            // Update frequency tracking (always track frequency, even if we don't sample the event)
            let entry = data
                .call_stack_frequencies
                .entry(call_stack_hash)
                .or_insert((0, 0, var_name.clone(), type_name.clone()));
            entry.0 += 1; // Increment frequency
            entry.1 += size; // Add to total size

            // Intelligent sampling decision
            if self.should_sample(size, &mut data) {
                let event = Event {
                    timestamp: get_timestamp(),
                    ptr,
                    size,
                    call_stack_hash,
                    event_type: EventType::Allocate,
                    var_name: Some(var_name),
                    type_name: Some(type_name),
                };

                data.event_buffer.push(event);

                // Flush if buffer is full
                if data.event_buffer.len() >= self.config.buffer_size {
                    data.flush_events()
                        .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
                }
            }

            // Periodically flush frequency data
            if data.total_operations % 10000 == 0 {
                data.flush_frequencies()
                    .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
            }

            Ok(())
        })
    }

    /// Track variable access
    pub fn track_access(&self, ptr: usize) -> TrackingResult<()> {
        self.track_operation(ptr, EventType::Access)
    }

    /// Track variable modification
    pub fn track_modify(&self, ptr: usize) -> TrackingResult<()> {
        self.track_operation(ptr, EventType::Modify)
    }

    /// Track variable drop
    pub fn track_drop(&self, ptr: usize) -> TrackingResult<()> {
        self.track_operation(ptr, EventType::Drop)
    }

    /// Generic operation tracking with lighter sampling
    fn track_operation(&self, ptr: usize, event_type: EventType) -> TrackingResult<()> {
        THREAD_DATA.with(|data| {
            let mut data = data.borrow_mut();
            data.sample_counter += 1;

            // Sample operations less frequently than allocations
            if data.sample_counter % 10 == 0 || matches!(event_type, EventType::Drop) {
                let event = Event {
                    timestamp: get_timestamp(),
                    ptr,
                    size: 0,                     // Operations don't have size
                    call_stack_hash: ptr as u64, // Use ptr as a simple hash for operations
                    event_type,
                    var_name: None,
                    type_name: None,
                };

                data.event_buffer.push(event);

                if data.event_buffer.len() >= self.config.buffer_size {
                    data.flush_events()
                        .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
                }
            }

            Ok(())
        })
    }

    /// Intelligent sampling decision based on size and frequency
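    ///
    /// With the default `SamplingConfig`: allocations of at least 10 KiB are
    /// always recorded, allocations of at least 1 KiB are recorded with ~10%
    /// probability, and smaller allocations are recorded either when the
    /// thread-local sample counter reaches a multiple of 100 or with ~1%
    /// probability.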
    fn should_sample(&self, size: usize, data: &mut ThreadLocalData) -> bool {
        // Always sample large allocations
        if size >= self.config.large_size_threshold {
            return true;
        }

        // Medium-sized allocations: probabilistic sampling
        if size >= self.config.medium_size_threshold {
            return rand::random::<f64>() < self.config.medium_sample_rate;
        }

        // Small allocations: very low probability, but frequency-aware
        data.sample_counter += 1;
        if data.sample_counter % 100 == 0 {
            // Every 100th small allocation gets sampled
            return true;
        }

        // Otherwise, use configured small sample rate
        rand::random::<f64>() < self.config.small_sample_rate
    }

    /// Calculate a simple call stack hash for frequency tracking
    fn calculate_call_stack_hash(&self, var_name: &str, type_name: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        var_name.hash(&mut hasher);
        type_name.hash(&mut hasher);
        // In a more sophisticated implementation, we'd include actual call stack
        hasher.finish()
    }

    /// Flush all pending data for the current thread
    pub fn flush_current_thread(&self) -> TrackingResult<()> {
        THREAD_DATA.with(|data| {
            let mut data = data.borrow_mut();
            data.flush_events()
                .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
            data.flush_frequencies()
                .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
            Ok(())
        })
    }

    /// Get current thread's basic statistics
    pub fn get_current_thread_stats(&self) -> ThreadStats {
        THREAD_DATA.with(|data| {
            let data = data.borrow();
            ThreadStats {
                thread_id: data.thread_id.clone(),
                total_operations: data.total_operations,
                events_buffered: data.event_buffer.len(),
                unique_call_stacks: data.call_stack_frequencies.len(),
            }
        })
    }
}

impl Default for SamplingTracker {
    fn default() -> Self {
        Self::new()
    }
}

/// Basic thread statistics
#[derive(Debug, Clone)]
pub struct ThreadStats {
    pub thread_id: String,
    pub total_operations: u64,
    pub events_buffered: usize,
    pub unique_call_stacks: usize,
}

/// Get current timestamp in nanoseconds
fn get_timestamp() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64
}

/// Global sampling tracker instance
static GLOBAL_SAMPLING_TRACKER: std::sync::OnceLock<SamplingTracker> = std::sync::OnceLock::new();

/// Get the global sampling tracker
pub fn get_sampling_tracker() -> &'static SamplingTracker {
    GLOBAL_SAMPLING_TRACKER.get_or_init(SamplingTracker::new)
}

/// Initialize the global sampling tracker with a custom configuration.
///
/// Call this before the first `get_sampling_tracker()`; once the global tracker
/// has been created, later calls are silently ignored.
pub fn init_sampling_tracker(config: SamplingConfig) {
    GLOBAL_SAMPLING_TRACKER
        .set(SamplingTracker::with_config(config))
        .ok();
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_basic_sampling_tracker() {
        let tracker = SamplingTracker::new();

        // Test variable tracking
        tracker
            .track_variable(0x1000, 1024, "test_var".to_string(), "Vec<i32>".to_string())
            .unwrap();

        // Test operations
        tracker.track_access(0x1000).unwrap();
        tracker.track_modify(0x1000).unwrap();
        tracker.track_drop(0x1000).unwrap();

        let stats = tracker.get_current_thread_stats();
        assert!(stats.total_operations > 0);

        // Flush data
        tracker.flush_current_thread().unwrap();
    }
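
    // Illustrative round-trip of the length-prefixed record layout produced by
    // `flush_events`. This is a sketch of how a reader could parse one record;
    // it assumes the bincode 1.x `serialize`/`deserialize` API referenced in the
    // module docs and does not touch the real per-thread files.
    #[test]
    fn test_event_record_round_trip() {
        let events = vec![Event {
            timestamp: get_timestamp(),
            ptr: 0x1000,
            size: 256,
            call_stack_hash: 42,
            event_type: EventType::Allocate,
            var_name: Some("example_var".to_string()),
            type_name: Some("Vec<u8>".to_string()),
        }];

        // Encode exactly like `flush_events`: a 4-byte little-endian length
        // prefix, then the serialized event batch.
        let payload = bincode::serialize(&events).unwrap();
        let mut record = Vec::new();
        record.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        record.extend_from_slice(&payload);

        // Decode: read the length prefix, then deserialize that many bytes.
        let mut len_bytes = [0u8; 4];
        len_bytes.copy_from_slice(&record[..4]);
        let len = u32::from_le_bytes(len_bytes) as usize;
        let decoded: Vec<Event> = bincode::deserialize(&record[4..4 + len]).unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].ptr, 0x1000);
        assert_eq!(decoded[0].size, 256);
    }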

    #[test]
    fn test_intelligent_sampling() {
        let config = SamplingConfig {
            large_size_threshold: 100,
            medium_size_threshold: 50,
            medium_sample_rate: 1.0, // 100% for testing
            small_sample_rate: 0.0,  // 0% for testing
            buffer_size: 10,
        };

        let tracker = SamplingTracker::with_config(config);

        // Large allocation should always be sampled
        tracker
            .track_variable(0x1000, 200, "large_var".to_string(), "Vec<u8>".to_string())
            .unwrap();

        // Medium allocation should be sampled (100% rate)
        tracker
            .track_variable(0x2000, 75, "medium_var".to_string(), "String".to_string())
            .unwrap();

        let stats = tracker.get_current_thread_stats();
        assert_eq!(stats.total_operations, 2);
    }

    #[test]
    fn test_multithread_sampling() {
        let tracker = Arc::new(SamplingTracker::new());
        let mut handles = vec![];

        // Test with multiple threads
        for i in 0..5 {
            let tracker_clone = tracker.clone();
            let handle = thread::spawn(move || {
                for j in 0..10 {
                    let ptr = (i * 1000 + j) as usize;
                    tracker_clone
                        .track_variable(
                            ptr,
                            64,
                            format!("thread_{}_var_{}", i, j),
                            "TestType".to_string(),
                        )
                        .unwrap();
                }

                tracker_clone.flush_current_thread().unwrap();
            });
            handles.push(handle);
        }

        // Wait for all threads
        for handle in handles {
            handle.join().unwrap();
        }

        // Verify files were created
        let files = std::fs::read_dir(".")
            .unwrap()
            .filter_map(|entry| entry.ok())
            .filter(|entry| {
                entry
                    .file_name()
                    .to_str()
                    .map(|name| name.starts_with("memscope_thread_"))
                    .unwrap_or(false)
            })
            .count();

        assert!(files >= 5); // At least 5 thread files should be created
    }
}