1use crate::core::types::TrackingResult;
10use serde::{Deserialize, Serialize};
11use std::cell::RefCell;
12use std::collections::HashMap;
13use std::fs::OpenOptions;
14use std::io::Write;
15use std::thread;
16use std::time::{SystemTime, UNIX_EPOCH};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Event {
21 pub timestamp: u64,
23 pub ptr: usize,
25 pub size: usize,
27 pub call_stack_hash: u64,
29 pub event_type: EventType,
31 pub var_name: Option<String>,
33 pub type_name: Option<String>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub enum EventType {
40 Allocate,
42 Access,
44 Modify,
46 Drop,
48 Clone { target_ptr: usize },
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct FrequencyData {
55 pub call_stack_hash: u64,
57 pub frequency: u64,
59 pub total_size: usize,
61 pub sample_var_name: String,
63 pub sample_type_name: String,
65}
66
67#[derive(Debug)]
69struct ThreadLocalData {
70 event_buffer: Vec<Event>,
72 call_stack_frequencies: HashMap<u64, (u64, usize, String, String)>, file_handle: Option<std::fs::File>,
76 thread_id: String,
78 sample_counter: u64,
80 total_operations: u64,
82}
83
84impl ThreadLocalData {
85 fn new() -> Self {
86 Self {
87 event_buffer: Vec::with_capacity(1000), call_stack_frequencies: HashMap::new(),
89 file_handle: None,
90 thread_id: format!("{:?}", thread::current().id()),
91 sample_counter: 0,
92 total_operations: 0,
93 }
94 }
95
96 fn ensure_file_handle(&mut self) -> std::io::Result<()> {
98 if self.file_handle.is_none() {
99 let filename = format!("memscope_thread_{}.bin", self.thread_id);
100 let file = OpenOptions::new()
101 .create(true)
102 .append(true)
103 .open(filename)?;
104 self.file_handle = Some(file);
105 }
106 Ok(())
107 }
108
109 fn flush_events(&mut self) -> std::io::Result<()> {
111 if self.event_buffer.is_empty() {
112 return Ok(());
113 }
114
115 self.ensure_file_handle()?;
116
117 if let Some(ref mut file) = self.file_handle {
118 let serialized = serde_json::to_vec(&self.event_buffer)
120 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
121
122 let len = serialized.len() as u32;
124 file.write_all(&len.to_le_bytes())?;
125 file.write_all(&serialized)?;
126 file.flush()?;
127 }
128
129 self.event_buffer.clear();
130 Ok(())
131 }
132
133 fn flush_frequencies(&mut self) -> std::io::Result<()> {
135 if self.call_stack_frequencies.is_empty() {
136 return Ok(());
137 }
138
139 self.ensure_file_handle()?;
140
141 let freq_data: Vec<FrequencyData> = self
142 .call_stack_frequencies
143 .iter()
144 .map(
145 |(&hash, &(freq, total_size, ref var_name, ref type_name))| FrequencyData {
146 call_stack_hash: hash,
147 frequency: freq,
148 total_size,
149 sample_var_name: var_name.clone(),
150 sample_type_name: type_name.clone(),
151 },
152 )
153 .collect();
154
155 if let Some(ref mut file) = self.file_handle {
156 let serialized = serde_json::to_vec(&freq_data)
157 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
158
159 let marker = 0xFEEDFACEu32;
161 file.write_all(&marker.to_le_bytes())?;
162 file.write_all(&(serialized.len() as u32).to_le_bytes())?;
163 file.write_all(&serialized)?;
164 file.flush()?;
165 }
166
167 self.call_stack_frequencies.clear();
168 Ok(())
169 }
170}
171
172thread_local! {
174 static THREAD_DATA: RefCell<ThreadLocalData> = RefCell::new(ThreadLocalData::new());
175}
176
177pub struct SamplingTracker {
179 config: SamplingConfig,
181}
182
/// Tunable parameters for [`SamplingTracker`].
#[derive(Debug, Clone)]
pub struct SamplingConfig {
    /// Allocations of at least this size are always sampled.
    pub large_size_threshold: usize,
    /// Allocations of at least this size (but below the large threshold) are
    /// sampled with probability `medium_sample_rate`.
    pub medium_size_threshold: usize,
    /// Sampling probability for medium allocations.
    pub medium_sample_rate: f64,
    /// Sampling probability for allocations below the medium threshold.
    pub small_sample_rate: f64,
    /// Number of buffered events that triggers a flush to disk.
    pub buffer_size: usize,
}
197
198impl Default for SamplingConfig {
199 fn default() -> Self {
200 Self {
201 large_size_threshold: 10 * 1024, medium_size_threshold: 1024, medium_sample_rate: 0.1, small_sample_rate: 0.01, buffer_size: 1000, }
207 }
208}
209
210impl SamplingTracker {
211 pub fn new() -> Self {
213 Self {
214 config: SamplingConfig::default(),
215 }
216 }
217
218 pub fn with_config(config: SamplingConfig) -> Self {
220 Self { config }
221 }
222
223 pub fn track_variable(
225 &self,
226 ptr: usize,
227 size: usize,
228 var_name: String,
229 type_name: String,
230 ) -> TrackingResult<()> {
231 let call_stack_hash = self.calculate_call_stack_hash(&var_name, &type_name);
232
233 THREAD_DATA.with(|data| {
234 let mut data = data.borrow_mut();
235 data.total_operations += 1;
236
237 let entry = data
239 .call_stack_frequencies
240 .entry(call_stack_hash)
241 .or_insert((0, 0, var_name.clone(), type_name.clone()));
242 entry.0 += 1; entry.1 += size; if self.should_sample(size, &mut data) {
247 let event = Event {
248 timestamp: get_timestamp(),
249 ptr,
250 size,
251 call_stack_hash,
252 event_type: EventType::Allocate,
253 var_name: Some(var_name),
254 type_name: Some(type_name),
255 };
256
257 data.event_buffer.push(event);
258
259 if data.event_buffer.len() >= self.config.buffer_size {
261 data.flush_events()
262 .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
263 }
264 }
265
266 if data.total_operations % 10000 == 0 {
268 data.flush_frequencies()
269 .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
270 }
271
272 Ok(())
273 })
274 }
275
276 pub fn track_access(&self, ptr: usize) -> TrackingResult<()> {
278 self.track_operation(ptr, EventType::Access)
279 }
280
281 pub fn track_modify(&self, ptr: usize) -> TrackingResult<()> {
283 self.track_operation(ptr, EventType::Modify)
284 }
285
286 pub fn track_drop(&self, ptr: usize) -> TrackingResult<()> {
288 self.track_operation(ptr, EventType::Drop)
289 }
290
291 fn track_operation(&self, ptr: usize, event_type: EventType) -> TrackingResult<()> {
293 THREAD_DATA.with(|data| {
294 let mut data = data.borrow_mut();
295 data.sample_counter += 1;
296
297 if data.sample_counter % 10 == 0 || matches!(event_type, EventType::Drop) {
299 let event = Event {
300 timestamp: get_timestamp(),
301 ptr,
302 size: 0, call_stack_hash: ptr as u64, event_type,
305 var_name: None,
306 type_name: None,
307 };
308
309 data.event_buffer.push(event);
310
311 if data.event_buffer.len() >= self.config.buffer_size {
312 data.flush_events()
313 .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
314 }
315 }
316
317 Ok(())
318 })
319 }
320
321 fn should_sample(&self, size: usize, data: &mut ThreadLocalData) -> bool {
323 if size >= self.config.large_size_threshold {
325 return true;
326 }
327
328 if size >= self.config.medium_size_threshold {
330 return rand::random::<f64>() < self.config.medium_sample_rate;
331 }
332
333 data.sample_counter += 1;
335 if data.sample_counter % 100 == 0 {
336 return true;
338 }
339
340 rand::random::<f64>() < self.config.small_sample_rate
342 }
343
344 fn calculate_call_stack_hash(&self, var_name: &str, type_name: &str) -> u64 {
346 use std::collections::hash_map::DefaultHasher;
347 use std::hash::{Hash, Hasher};
348
349 let mut hasher = DefaultHasher::new();
350 var_name.hash(&mut hasher);
351 type_name.hash(&mut hasher);
352 hasher.finish()
354 }
355
356 pub fn flush_current_thread(&self) -> TrackingResult<()> {
358 THREAD_DATA.with(|data| {
359 let mut data = data.borrow_mut();
360 data.flush_events()
361 .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
362 data.flush_frequencies()
363 .map_err(|e| crate::core::types::TrackingError::IoError(e.to_string()))?;
364 Ok(())
365 })
366 }
367
368 pub fn get_current_thread_stats(&self) -> ThreadStats {
370 THREAD_DATA.with(|data| {
371 let data = data.borrow();
372 ThreadStats {
373 thread_id: data.thread_id.clone(),
374 total_operations: data.total_operations,
375 events_buffered: data.event_buffer.len(),
376 unique_call_stacks: data.call_stack_frequencies.len(),
377 }
378 })
379 }
380}
381
382impl Default for SamplingTracker {
383 fn default() -> Self {
384 Self::new()
385 }
386}
387
/// Snapshot of one thread's tracker counters, as returned by
/// `SamplingTracker::get_current_thread_stats`.
#[derive(Debug, Clone)]
pub struct ThreadStats {
    /// `Debug` rendering of the thread's id.
    pub thread_id: String,
    /// Number of `track_variable` calls made on the thread.
    pub total_operations: u64,
    /// Events currently waiting in the thread's buffer.
    pub events_buffered: usize,
    /// Distinct hash keys accumulated since the last frequency flush.
    pub unique_call_stacks: usize,
}
396
/// Returns the current wall-clock time as nanoseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch. The
/// nanosecond count is narrowed from `u128` to `u64`.
fn get_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    }
}
404
405static GLOBAL_SAMPLING_TRACKER: std::sync::OnceLock<SamplingTracker> = std::sync::OnceLock::new();
407
408pub fn get_sampling_tracker() -> &'static SamplingTracker {
410 GLOBAL_SAMPLING_TRACKER.get_or_init(SamplingTracker::new)
411}
412
413pub fn init_sampling_tracker(config: SamplingConfig) {
415 GLOBAL_SAMPLING_TRACKER
416 .set(SamplingTracker::with_config(config))
417 .ok();
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::Arc;
    use std::thread;

    /// Path of the data file the tracker creates for the given thread id
    /// (mirrors the name built in `ThreadLocalData::ensure_file_handle`).
    fn thread_file(thread_id: &str) -> PathBuf {
        PathBuf::from(format!("memscope_thread_{}.bin", thread_id))
    }

    /// Smoke test: every tracking entry point succeeds and a flush works.
    #[test]
    fn test_basic_sampling_tracker() {
        let tracker = SamplingTracker::new();

        tracker
            .track_variable(0x1000, 1024, "test_var".to_string(), "Vec<i32>".to_string())
            .unwrap();

        tracker.track_access(0x1000).unwrap();
        tracker.track_modify(0x1000).unwrap();
        tracker.track_drop(0x1000).unwrap();

        let stats = tracker.get_current_thread_stats();
        assert!(stats.total_operations > 0);

        tracker.flush_current_thread().unwrap();

        // Best-effort cleanup so the test leaves no artefacts behind; removal
        // may fail on platforms that refuse to delete an open file.
        let _ = std::fs::remove_file(thread_file(&stats.thread_id));
    }

    /// Large and medium allocations are both counted as operations.
    #[test]
    fn test_intelligent_sampling() {
        let config = SamplingConfig {
            large_size_threshold: 100,
            medium_size_threshold: 50,
            medium_sample_rate: 1.0,
            small_sample_rate: 0.0,
            buffer_size: 10,
        };

        let tracker = SamplingTracker::with_config(config);

        // Compare against a baseline: the counters are thread-local, so they
        // may already be non-zero if the harness reuses this thread.
        let before = tracker.get_current_thread_stats().total_operations;

        tracker
            .track_variable(0x1000, 200, "large_var".to_string(), "Vec<u8>".to_string())
            .unwrap();

        tracker
            .track_variable(0x2000, 75, "medium_var".to_string(), "String".to_string())
            .unwrap();

        let stats = tracker.get_current_thread_stats();
        assert_eq!(stats.total_operations - before, 2);
    }

    /// Each worker thread writes its own data file when flushed.
    #[test]
    fn test_multithread_sampling() {
        let tracker = Arc::new(SamplingTracker::new());
        let mut handles = vec![];

        for i in 0..5 {
            let tracker_clone = Arc::clone(&tracker);
            let handle = thread::spawn(move || {
                for j in 0..10 {
                    let ptr = (i * 1000 + j) as usize;
                    tracker_clone
                        .track_variable(
                            ptr,
                            64,
                            format!("thread_{}_var_{}", i, j),
                            "TestType".to_string(),
                        )
                        .unwrap();
                }

                tracker_clone.flush_current_thread().unwrap();
                // Report which file this thread wrote to.
                tracker_clone.get_current_thread_stats().thread_id
            });
            handles.push(handle);
        }

        let paths: Vec<PathBuf> = handles
            .into_iter()
            .map(|handle| thread_file(&handle.join().unwrap()))
            .collect();

        // Check exactly the files these threads produced, so leftovers from
        // earlier runs cannot satisfy the assertion, then clean up before
        // asserting so a failure does not leak files either.
        let created = paths.iter().filter(|path| path.exists()).count();
        for path in &paths {
            let _ = std::fs::remove_file(path);
        }

        assert_eq!(created, 5);
    }
}