memscope_rs/core/
allocator.rs

//! Custom global allocator for tracking memory allocations.

3use crate::core::enhanced_type_inference::TypeInferenceEngine;
4use std::alloc::{GlobalAlloc, Layout, System};
5use std::sync::Mutex;
6
/// A custom allocator that tracks memory allocations and deallocations.
///
/// This allocator wraps the system allocator and records all allocation
/// and deallocation events through the global memory tracker.
///
/// Zero-sized unit struct; presumably installed at the crate root via
/// `#[global_allocator]` (not visible in this file — confirm at crate root).
pub struct TrackingAllocator;
12
// Global type inference engine for the allocator.
//
// NOTE(review): currently unused — the leading underscore suppresses the
// dead-code warning. Kept alongside the other `_`-prefixed helpers for a
// future richer inference path. `LazyLock` defers construction until (if
// ever) first access, so this costs nothing at startup.
static _TYPE_INFERENCE_ENGINE: std::sync::LazyLock<Mutex<TypeInferenceEngine>> =
    std::sync::LazyLock::new(|| Mutex::new(TypeInferenceEngine::new()));
16
17impl TrackingAllocator {
18    /// Create a new tracking allocator instance.
19    pub const fn new() -> Self {
20        Self
21    }
22
23    /// Simple type inference using static strings to avoid recursive allocations
24    fn _infer_type_from_allocation_context(size: usize) -> &'static str {
25        // CRITICAL FIX: Use static strings to prevent recursive allocations
26        match size {
27            // Common Rust type sizes
28            1 => "u8",
29            2 => "u16",
30            4 => "u32",
31            8 => "u64",
32            16 => "u128",
33
34            // String and Vec common sizes
35            24 => "String",
36            32 => "Vec<T>",
37            48 => "HashMap<K,V>",
38
39            // Smart pointer sizes
40            size if size == std::mem::size_of::<std::sync::Arc<String>>() => "Arc<T>",
41            size if size == std::mem::size_of::<std::rc::Rc<String>>() => "Rc<T>",
42            size if size == std::mem::size_of::<Box<String>>() => "Box<T>",
43
44            // Default for other sizes - use static strings
45            _ => "unknown",
46        }
47    }
48
49    // REMOVED: fallback_type_inference - no longer needed with static strings
50
51    /// Get a simplified call stack for context
52    fn _get_simplified_call_stack() -> Vec<String> {
53        // For now, return a simple placeholder
54        // In a real implementation, this could use backtrace crate
55        vec!["global_allocator".to_string(), "system_alloc".to_string()]
56    }
57
58    /// Simple variable name inference using static strings to avoid recursive allocations
59    fn _infer_variable_from_allocation_context(size: usize) -> &'static str {
60        // CRITICAL FIX: Use static strings to prevent recursive allocations
61        match size {
62            // Small allocations - likely primitives
63            1..=8 => "primitive_data",
64
65            // Medium allocations - likely structs or small collections
66            9..=64 => "struct_data",
67
68            // Large allocations - likely collections or buffers
69            65..=1024 => "collection_data",
70
71            // Very large allocations - likely buffers or large data structures
72            _ => "buffer_data",
73        }
74    }
75}
76
// Thread-local flag to prevent recursive tracking.
//
// Recording an allocation may itself allocate (tracker bookkeeping); while
// this flag is `true` the `GlobalAlloc` impl below skips tracking on the
// current thread, so those internal allocations cannot re-enter the tracker
// and recurse. `const` init keeps thread-local setup allocation-free.
thread_local! {
    static TRACKING_DISABLED: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
81
82unsafe impl GlobalAlloc for TrackingAllocator {
83    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
84        // Allocate memory first
85        let ptr = System.alloc(layout);
86
87        // Track the allocation if it succeeded and tracking is not disabled
88        if !ptr.is_null() {
89            // Check if tracking is disabled for this thread to prevent recursion
90            let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());
91
92            if should_track {
93                // Temporarily disable tracking to prevent recursion during tracking operations
94                TRACKING_DISABLED.with(|disabled| disabled.set(true));
95
96                // CRITICAL FIX: Use simple tracking like master branch to avoid recursion
97                if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_tracker) {
98                    // Simple tracking without context to prevent recursive allocations
99                    let _ = tracker.track_allocation(ptr as usize, layout.size());
100                }
101
102                // Re-enable tracking
103                TRACKING_DISABLED.with(|disabled| disabled.set(false));
104            }
105        }
106
107        ptr
108    }
109
110    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
111        // Track the deallocation first
112        let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());
113
114        if should_track {
115            // Temporarily disable tracking to prevent recursion
116            TRACKING_DISABLED.with(|disabled| disabled.set(true));
117
118            // Track the deallocation - use try_lock approach to avoid deadlocks
119            if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_tracker) {
120                // Ignore errors to prevent deallocation failures from breaking the program
121                let _ = tracker.track_deallocation(ptr as usize);
122            }
123
124            // Re-enable tracking
125            TRACKING_DISABLED.with(|disabled| disabled.set(false));
126        }
127
128        // Deallocate the memory
129        System.dealloc(ptr, layout);
130    }
131}
132
133impl Default for TrackingAllocator {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::alloc::{GlobalAlloc, Layout};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Once;

    /// Reset the thread-local tracking guard so tests are order-independent.
    fn reset_thread_local_state() {
        TRACKING_DISABLED.with(|disabled| disabled.set(false));
    }

    #[test]
    fn test_allocation_tracking() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024, 8).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            assert!(!ptr.is_null());
            allocator.dealloc(ptr, layout);
        }
    }

    #[test]
    fn test_zero_sized_allocation() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(0, 1).unwrap();

        // NOTE(review): zero-sized layouts are outside the documented
        // `GlobalAlloc` safety contract; `System` tolerates them in
        // practice and may return null or a non-null pointer.
        unsafe {
            let ptr = allocator.alloc(layout);
            // FIX: only deallocate when a block was actually returned.
            // The previous version called `dealloc` unconditionally, which
            // violates the trait contract when `ptr` is null.
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_large_allocation() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024 * 1024, 8).unwrap(); // 1MB

        unsafe {
            let ptr = allocator.alloc(layout);
            // Only test deallocation if allocation succeeded.
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_multiple_allocations() {
        let allocator = TrackingAllocator::new();
        let mut ptrs = Vec::new();

        // Allocate ten blocks of increasing size (64..=640 bytes).
        for i in 1..=10 {
            let layout = Layout::from_size_align(i * 64, 8).unwrap();
            unsafe {
                let ptr = allocator.alloc(layout);
                if !ptr.is_null() {
                    ptrs.push((ptr, layout));
                }
            }
        }

        // Deallocate every block that was successfully allocated.
        for (ptr, layout) in ptrs {
            unsafe {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_type_inference_from_size() {
        // Table-driven spot checks of the static type inference.
        for (size, expected) in [
            (1, "u8"),
            (4, "u32"),
            (8, "u64"),
            (24, "String"),
            (32, "Vec<T>"),
            (999, "unknown"),
        ] {
            assert_eq!(
                TrackingAllocator::_infer_type_from_allocation_context(size),
                expected,
                "size {size}"
            );
        }
    }

    #[test]
    fn test_variable_inference_from_size() {
        // One representative size per bucket.
        for (size, expected) in [
            (4, "primitive_data"),
            (32, "struct_data"),
            (512, "collection_data"),
            (2048, "buffer_data"),
        ] {
            assert_eq!(
                TrackingAllocator::_infer_variable_from_allocation_context(size),
                expected,
                "size {size}"
            );
        }
    }

    #[test]
    fn test_default_implementation() {
        // FIX: actually exercise `Default` — the old test called `new()`,
        // leaving `default()` untested despite the test's name.
        let allocator = TrackingAllocator::default();
        assert_eq!(
            std::mem::size_of_val(&allocator),
            std::mem::size_of::<TrackingAllocator>()
        );
    }

    #[test]
    fn test_type_inference() {
        // Exhaustive check of every recognized size plus the fallback.
        for (size, expected) in [
            (1, "u8"),
            (2, "u16"),
            (4, "u32"),
            (8, "u64"),
            (16, "u128"),
            (24, "String"),
            (32, "Vec<T>"),
            (48, "HashMap<K,V>"),
            (12345, "unknown"),
        ] {
            assert_eq!(
                TrackingAllocator::_infer_type_from_allocation_context(size),
                expected,
                "size {size}"
            );
        }
    }

    #[test]
    fn test_variable_inference() {
        // Boundary values of every bucket, including the size-0 quirk
        // (0 is excluded from `1..=8` and falls through to "buffer_data").
        for (size, expected) in [
            (0, "buffer_data"),
            (4, "primitive_data"),
            (8, "primitive_data"),
            (16, "struct_data"),
            (32, "struct_data"),
            (64, "struct_data"),
            (65, "collection_data"),
            (128, "collection_data"),
            (1024, "collection_data"),
            (1025, "buffer_data"),
            (usize::MAX, "buffer_data"),
        ] {
            assert_eq!(
                TrackingAllocator::_infer_variable_from_allocation_context(size),
                expected,
                "size {size}"
            );
        }
    }

    #[test]
    fn test_thread_local_tracking() {
        reset_thread_local_state();

        // Tracking is enabled (flag cleared) by default.
        TRACKING_DISABLED.with(|disabled| {
            assert!(!disabled.get());
        });

        // The flag can be toggled and restored.
        TRACKING_DISABLED.with(|disabled| {
            disabled.set(true);
            assert!(disabled.get());
            disabled.set(false);
        });
    }

    #[test]
    fn test_simplified_call_stack() {
        let stack = TrackingAllocator::_get_simplified_call_stack();
        assert_eq!(stack.len(), 2);
        assert_eq!(stack[0], "global_allocator");
        assert_eq!(stack[1], "system_alloc");
    }

    #[test]
    fn test_allocation_edge_cases() {
        let allocator = TrackingAllocator::new();

        // Double-pointer-width alignment must be honored.
        let max_align = std::mem::size_of::<usize>() * 2;
        let layout = Layout::from_size_align(16, max_align).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                // The returned pointer must satisfy the requested alignment.
                assert_eq!((ptr as usize) % max_align, 0);
                allocator.dealloc(ptr, layout);
            }
        }

        // Minimal non-zero allocation (1 byte, 1-byte aligned).
        let layout = Layout::from_size_align(1, 1).unwrap();
        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_recursive_allocation_handling() {
        // Verifies that recursive allocations don't cause infinite loops.
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(64, 8).unwrap();

        // Flag set by the panic hook if a stack overflow is observed.
        static RECURSION_DETECTED: AtomicBool = AtomicBool::new(false);
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            // Chain a hook in front of the existing one so other tests'
            // panic reporting is preserved.
            let original_hook = std::panic::take_hook();
            std::panic::set_hook(Box::new(move |panic_info| {
                if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
                    if s.contains("stack overflow") {
                        RECURSION_DETECTED.store(true, Ordering::SeqCst);
                    }
                }
                original_hook(panic_info);
            }));
        });

        // This allocation triggers tracking; the thread-local guard must
        // prevent the tracker's own allocations from recursing.
        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }

        assert!(
            !RECURSION_DETECTED.load(Ordering::SeqCst),
            "Recursive allocation detected - thread-local tracking failed"
        );
    }
}