// memscope_rs/core/allocator.rs

1//! Custom global allocator for tracking memory allocations.
2
3use crate::core::enhanced_type_inference::TypeInferenceEngine;
4use std::alloc::{GlobalAlloc, Layout, System};
5use std::sync::Mutex;
6
/// A custom allocator that tracks memory allocations and deallocations.
///
/// This allocator wraps the system allocator and records all allocation
/// and deallocation events through the global memory tracker.
///
/// Zero-sized unit struct: it carries no state of its own; all bookkeeping
/// lives in thread-local flags and the global tracker.
pub struct TrackingAllocator;
12
// Global type inference engine for the allocator.
// NOTE(review): the leading underscore marks this as currently unused by the
// alloc/dealloc paths below; it is lazily initialized on first access and
// guarded by a Mutex. Confirm whether it is still needed before removal.
static _TYPE_INFERENCE_ENGINE: std::sync::LazyLock<Mutex<TypeInferenceEngine>> =
    std::sync::LazyLock::new(|| Mutex::new(TypeInferenceEngine::new()));
16
17impl TrackingAllocator {
18    /// Create a new tracking allocator instance.
19    pub const fn new() -> Self {
20        Self
21    }
22
23    /// Simple type inference using static strings to avoid recursive allocations
24    fn _infer_type_from_allocation_context(size: usize) -> &'static str {
25        // CRITICAL FIX: Use static strings to prevent recursive allocations
26        match size {
27            // Common Rust type sizes
28            1 => "u8",
29            2 => "u16",
30            4 => "u32",
31            8 => "u64",
32            16 => "u128",
33
34            // String and Vec common sizes
35            24 => "String",
36            32 => "Vec<T>",
37            48 => "HashMap<K,V>",
38
39            // Smart pointer sizes
40            size if size == std::mem::size_of::<std::sync::Arc<String>>() => "Arc<T>",
41            size if size == std::mem::size_of::<std::rc::Rc<String>>() => "Rc<T>",
42            size if size == std::mem::size_of::<Box<String>>() => "Box<T>",
43
44            // Default for other sizes - use static strings
45            _ => "unknown",
46        }
47    }
48
49    // REMOVED: fallback_type_inference - no longer needed with static strings
50
51    /// Get a simplified call stack for context
52    fn _get_simplified_call_stack() -> Vec<String> {
53        // For now, return a simple placeholder
54        // In a real implementation, this could use backtrace crate
55        vec!["global_allocator".to_string(), "system_alloc".to_string()]
56    }
57
58    /// Simple variable name inference using static strings to avoid recursive allocations
59    fn _infer_variable_from_allocation_context(size: usize) -> &'static str {
60        // CRITICAL FIX: Use static strings to prevent recursive allocations
61        match size {
62            // Small allocations - likely primitives
63            1..=8 => "primitive_data",
64
65            // Medium allocations - likely structs or small collections
66            9..=64 => "struct_data",
67
68            // Large allocations - likely collections or buffers
69            65..=1024 => "collection_data",
70
71            // Very large allocations - likely buffers or large data structures
72            _ => "buffer_data",
73        }
74    }
75}
76
// Thread-local flag to prevent recursive tracking: it is set while tracker
// bookkeeping runs so that any allocations made by the tracker itself are
// not tracked again (which would recurse back into `alloc`).
thread_local! {
    static TRACKING_DISABLED: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
81
82unsafe impl GlobalAlloc for TrackingAllocator {
83    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
84        // Allocate memory first
85        let ptr = System.alloc(layout);
86
87        // Track the allocation if it succeeded and tracking is not disabled
88        if !ptr.is_null() {
89            // Check if tracking is disabled for this thread to prevent recursion
90            let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());
91
92            if should_track {
93                // Temporarily disable tracking to prevent recursion during tracking operations
94                TRACKING_DISABLED.with(|disabled| disabled.set(true));
95
96                // CRITICAL FIX: Use simple tracking like master branch to avoid recursion
97                if let Ok(tracker) =
98                    std::panic::catch_unwind(crate::core::tracker::get_global_tracker)
99                {
100                    // Simple tracking without context to prevent recursive allocations
101                    let _ = tracker.track_allocation(ptr as usize, layout.size());
102                }
103
104                // Re-enable tracking
105                TRACKING_DISABLED.with(|disabled| disabled.set(false));
106            }
107        }
108
109        ptr
110    }
111
112    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
113        // Track the deallocation first
114        let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());
115
116        if should_track {
117            // Temporarily disable tracking to prevent recursion
118            TRACKING_DISABLED.with(|disabled| disabled.set(true));
119
120            // Track the deallocation - use try_lock approach to avoid deadlocks
121            if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_global_tracker)
122            {
123                // Ignore errors to prevent deallocation failures from breaking the program
124                let _ = tracker.track_deallocation(ptr as usize);
125            }
126
127            // Re-enable tracking
128            TRACKING_DISABLED.with(|disabled| disabled.set(false));
129        }
130
131        // Deallocate the memory
132        System.dealloc(ptr, layout);
133    }
134}
135
136impl Default for TrackingAllocator {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142#[cfg(test)]
143mod tests {
144    use super::*;
145    use std::alloc::{GlobalAlloc, Layout};
146    use std::sync::atomic::{AtomicBool, Ordering};
147    use std::sync::Once;
148
149    // Helper to reset thread-local state between tests
150    fn reset_thread_local_state() {
151        TRACKING_DISABLED.with(|disabled| disabled.set(false));
152    }
153
154    #[test]
155    fn test_allocation_tracking() {
156        let allocator = TrackingAllocator::new();
157        let layout = Layout::from_size_align(1024, 8).unwrap();
158
159        unsafe {
160            let ptr = allocator.alloc(layout);
161            assert!(!ptr.is_null());
162
163            // Test deallocation
164            allocator.dealloc(ptr, layout);
165        }
166    }
167
168    #[test]
169    fn test_zero_sized_allocation() {
170        let allocator = TrackingAllocator::new();
171        let layout = Layout::from_size_align(0, 1).unwrap();
172
173        unsafe {
174            let ptr = allocator.alloc(layout);
175            // Zero-sized allocations may return null or a valid pointer
176            // Both are acceptable according to the GlobalAlloc trait
177            allocator.dealloc(ptr, layout);
178        }
179    }
180
181    #[test]
182    fn test_large_allocation() {
183        let allocator = TrackingAllocator::new();
184        let layout = Layout::from_size_align(1024 * 1024, 8).unwrap(); // 1MB
185
186        unsafe {
187            let ptr = allocator.alloc(layout);
188            if !ptr.is_null() {
189                // Only test deallocation if allocation succeeded
190                allocator.dealloc(ptr, layout);
191            }
192        }
193    }
194
195    #[test]
196    fn test_multiple_allocations() {
197        let allocator = TrackingAllocator::new();
198        let mut ptrs = Vec::new();
199
200        // Allocate multiple blocks
201        for i in 1..=10 {
202            let layout = Layout::from_size_align(i * 64, 8).unwrap();
203            unsafe {
204                let ptr = allocator.alloc(layout);
205                if !ptr.is_null() {
206                    ptrs.push((ptr, layout));
207                }
208            }
209        }
210
211        // Deallocate all blocks
212        for (ptr, layout) in ptrs {
213            unsafe {
214                allocator.dealloc(ptr, layout);
215            }
216        }
217    }
218
219    #[test]
220    fn test_type_inference_from_size() {
221        // Test the static type inference
222        assert_eq!(
223            TrackingAllocator::_infer_type_from_allocation_context(1),
224            "u8"
225        );
226        assert_eq!(
227            TrackingAllocator::_infer_type_from_allocation_context(4),
228            "u32"
229        );
230        assert_eq!(
231            TrackingAllocator::_infer_type_from_allocation_context(8),
232            "u64"
233        );
234        assert_eq!(
235            TrackingAllocator::_infer_type_from_allocation_context(24),
236            "String"
237        );
238        assert_eq!(
239            TrackingAllocator::_infer_type_from_allocation_context(32),
240            "Vec<T>"
241        );
242        assert_eq!(
243            TrackingAllocator::_infer_type_from_allocation_context(999),
244            "unknown"
245        );
246    }
247
248    #[test]
249    fn test_variable_inference_from_size() {
250        // Test the static variable inference
251        assert_eq!(
252            TrackingAllocator::_infer_variable_from_allocation_context(4),
253            "primitive_data"
254        );
255        assert_eq!(
256            TrackingAllocator::_infer_variable_from_allocation_context(32),
257            "struct_data"
258        );
259        assert_eq!(
260            TrackingAllocator::_infer_variable_from_allocation_context(512),
261            "collection_data"
262        );
263        assert_eq!(
264            TrackingAllocator::_infer_variable_from_allocation_context(2048),
265            "buffer_data"
266        );
267    }
268
269    #[test]
270    fn test_default_implementation() {
271        let allocator = TrackingAllocator::new();
272        assert_eq!(
273            std::mem::size_of_val(&allocator),
274            std::mem::size_of::<TrackingAllocator>()
275        );
276    }
277
278    #[test]
279    fn test_type_inference() {
280        // Test type inference for various sizes
281        assert_eq!(
282            TrackingAllocator::_infer_type_from_allocation_context(1),
283            "u8"
284        );
285        assert_eq!(
286            TrackingAllocator::_infer_type_from_allocation_context(2),
287            "u16"
288        );
289        assert_eq!(
290            TrackingAllocator::_infer_type_from_allocation_context(4),
291            "u32"
292        );
293        assert_eq!(
294            TrackingAllocator::_infer_type_from_allocation_context(8),
295            "u64"
296        );
297        assert_eq!(
298            TrackingAllocator::_infer_type_from_allocation_context(16),
299            "u128"
300        );
301        assert_eq!(
302            TrackingAllocator::_infer_type_from_allocation_context(24),
303            "String"
304        );
305        assert_eq!(
306            TrackingAllocator::_infer_type_from_allocation_context(32),
307            "Vec<T>"
308        );
309        assert_eq!(
310            TrackingAllocator::_infer_type_from_allocation_context(48),
311            "HashMap<K,V>"
312        );
313
314        // Test unknown size
315        assert_eq!(
316            TrackingAllocator::_infer_type_from_allocation_context(12345),
317            "unknown"
318        );
319    }
320
321    #[test]
322    fn test_variable_inference() {
323        // Test variable inference for different size ranges
324        assert_eq!(
325            TrackingAllocator::_infer_variable_from_allocation_context(0),
326            "buffer_data"
327        );
328        assert_eq!(
329            TrackingAllocator::_infer_variable_from_allocation_context(4),
330            "primitive_data"
331        );
332        assert_eq!(
333            TrackingAllocator::_infer_variable_from_allocation_context(8),
334            "primitive_data"
335        );
336        assert_eq!(
337            TrackingAllocator::_infer_variable_from_allocation_context(16),
338            "struct_data"
339        );
340        assert_eq!(
341            TrackingAllocator::_infer_variable_from_allocation_context(32),
342            "struct_data"
343        );
344        assert_eq!(
345            TrackingAllocator::_infer_variable_from_allocation_context(64),
346            "struct_data"
347        );
348        assert_eq!(
349            TrackingAllocator::_infer_variable_from_allocation_context(65),
350            "collection_data"
351        );
352        assert_eq!(
353            TrackingAllocator::_infer_variable_from_allocation_context(128),
354            "collection_data"
355        );
356        assert_eq!(
357            TrackingAllocator::_infer_variable_from_allocation_context(1024),
358            "collection_data"
359        );
360        assert_eq!(
361            TrackingAllocator::_infer_variable_from_allocation_context(1025),
362            "buffer_data"
363        );
364        assert_eq!(
365            TrackingAllocator::_infer_variable_from_allocation_context(usize::MAX),
366            "buffer_data"
367        );
368    }
369
370    #[test]
371    fn test_thread_local_tracking() {
372        reset_thread_local_state();
373
374        // Test that tracking is enabled by default
375        TRACKING_DISABLED.with(|disabled| {
376            assert!(!disabled.get());
377        });
378
379        // Test disabling tracking
380        TRACKING_DISABLED.with(|disabled| {
381            disabled.set(true);
382            assert!(disabled.get());
383            disabled.set(false);
384        });
385    }
386
387    #[test]
388    fn test_simplified_call_stack() {
389        let stack = TrackingAllocator::_get_simplified_call_stack();
390        assert_eq!(stack.len(), 2);
391        assert_eq!(stack[0], "global_allocator");
392        assert_eq!(stack[1], "system_alloc");
393    }
394
395    #[test]
396    fn test_allocation_edge_cases() {
397        let allocator = TrackingAllocator::new();
398
399        // Test with maximum alignment
400        let max_align = std::mem::size_of::<usize>() * 2;
401        let layout = Layout::from_size_align(16, max_align).unwrap();
402
403        unsafe {
404            let ptr = allocator.alloc(layout);
405            if !ptr.is_null() {
406                // Test that the pointer is properly aligned
407                assert_eq!((ptr as usize) % max_align, 0);
408                allocator.dealloc(ptr, layout);
409            }
410        }
411
412        // Test with minimal size but non-zero
413        let layout = Layout::from_size_align(1, 1).unwrap();
414        unsafe {
415            let ptr = allocator.alloc(layout);
416            if !ptr.is_null() {
417                allocator.dealloc(ptr, layout);
418            }
419        }
420    }
421
    /// Verify that an alloc/dealloc round-trip through the tracking path does
    /// not recurse infinitely (the thread-local guard must break the cycle).
    ///
    /// NOTE(review): a real stack overflow typically aborts the process rather
    /// than raising a catchable panic, so the hook below may never observe
    /// one — confirm this detection mechanism actually fires in practice.
    #[test]
    fn test_recursive_allocation_handling() {
        // This test verifies that recursive allocations don't cause infinite loops
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(64, 8).unwrap();

        // Set up a flag to detect if we're in a recursive call
        static RECURSION_DETECTED: AtomicBool = AtomicBool::new(false);
        static INIT: Once = Once::new();

        // `Once` guarantees the hook is installed a single time even if this
        // test runs concurrently with others touching the panic hook.
        INIT.call_once(|| {
            // Install a panic hook to detect if we hit a stack overflow
            let original_hook = std::panic::take_hook();
            std::panic::set_hook(Box::new(move |panic_info| {
                if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
                    if s.contains("stack overflow") {
                        RECURSION_DETECTED.store(true, Ordering::SeqCst);
                    }
                }
                // Delegate to the previously installed hook so normal panic
                // reporting still happens.
                original_hook(panic_info);
            }));
        });

        // This allocation will trigger tracking, but the thread-local flag should prevent recursion
        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }

        // Verify we didn't hit a stack overflow
        assert!(
            !RECURSION_DETECTED.load(Ordering::SeqCst),
            "Recursive allocation detected - thread-local tracking failed"
        );
    }
459}