//! Custom global allocator for tracking memory allocations.

use std::alloc::{GlobalAlloc, Layout, System};
/// Zero-sized global allocator that records allocation events.
///
/// All real memory management is delegated to the system allocator
/// ([`std::alloc::System`]); each successful allocation and each
/// deallocation is additionally reported to the global memory tracker.
pub struct TrackingAllocator;
10
11impl TrackingAllocator {
12    /// Create a new tracking allocator instance.
13    pub const fn new() -> Self {
14        Self
15    }
16
17    /// Simple type inference using static strings to avoid recursive allocations
18    fn _infer_type_from_allocation_context(size: usize) -> &'static str {
19        // CRITICAL FIX: Use static strings to prevent recursive allocations
20        match size {
21            // Common Rust type sizes
22            1 => "u8",
23            2 => "u16",
24            4 => "u32",
25            8 => "u64",
26            16 => "u128",
27
28            // String and Vec common sizes
29            24 => "String",
30            32 => "Vec<T>",
31            48 => "HashMap<K,V>",
32
33            // Smart pointer sizes
34            size if size == std::mem::size_of::<std::sync::Arc<String>>() => "Arc<T>",
35            size if size == std::mem::size_of::<std::rc::Rc<String>>() => "Rc<T>",
36            size if size == std::mem::size_of::<Box<String>>() => "Box<T>",
37
38            // Default for other sizes - use static strings
39            _ => "unknown",
40        }
41    }
42
43    // REMOVED: fallback_type_inference - no longer needed with static strings
44
45    /// Get a simplified call stack for context
46    fn _get_simplified_call_stack() -> Vec<String> {
47        // For now, return a simple placeholder
48        // In a real implementation, this could use backtrace crate
49        vec!["global_allocator".to_string(), "system_alloc".to_string()]
50    }
51
52    /// Simple variable name inference using static strings to avoid recursive allocations
53    fn _infer_variable_from_allocation_context(size: usize) -> &'static str {
54        // CRITICAL FIX: Use static strings to prevent recursive allocations
55        match size {
56            // Small allocations - likely primitives
57            1..=8 => "primitive_data",
58
59            // Medium allocations - likely structs or small collections
60            9..=64 => "struct_data",
61
62            // Large allocations - likely collections or buffers
63            65..=1024 => "collection_data",
64
65            // Very large allocations - likely buffers or large data structures
66            _ => "buffer_data",
67        }
68    }
69}
70
// Thread-local re-entrancy guard: set to `true` while tracker code runs
// so that allocations performed by the tracking machinery itself are not
// tracked recursively (which would loop forever).
thread_local! {
    static TRACKING_DISABLED: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
75
76unsafe impl GlobalAlloc for TrackingAllocator {
77    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
78        // Allocate memory first
79        let ptr = System.alloc(layout);
80
81        // Track the allocation if it succeeded and tracking is not disabled
82        if !ptr.is_null() {
83            // Check if tracking is disabled for this thread to prevent recursion
84            let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());
85
86            if should_track {
87                // Temporarily disable tracking to prevent recursion during tracking operations
88                TRACKING_DISABLED.with(|disabled| disabled.set(true));
89
90                // CRITICAL FIX: Use simple tracking like master branch to avoid recursion
91                if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_tracker) {
92                    // Simple tracking without context to prevent recursive allocations
93                    let _ = tracker.track_allocation(ptr as usize, layout.size());
94                }
95
96                // Re-enable tracking
97                TRACKING_DISABLED.with(|disabled| disabled.set(false));
98            }
99        }
100
101        ptr
102    }
103
104    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
105        // Track the deallocation first
106        let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());
107
108        if should_track {
109            // Temporarily disable tracking to prevent recursion
110            TRACKING_DISABLED.with(|disabled| disabled.set(true));
111
112            // Track the deallocation - use try_lock approach to avoid deadlocks
113            if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_tracker) {
114                // Ignore errors to prevent deallocation failures from breaking the program
115                let _ = tracker.track_deallocation(ptr as usize);
116            }
117
118            // Re-enable tracking
119            TRACKING_DISABLED.with(|disabled| disabled.set(false));
120        }
121
122        // Deallocate the memory
123        System.dealloc(ptr, layout);
124    }
125}
126
impl Default for TrackingAllocator {
    /// Equivalent to [`TrackingAllocator::new`].
    fn default() -> Self {
        Self::new()
    }
}
132
#[cfg(test)]
mod tests {
    use super::*;
    use std::alloc::{GlobalAlloc, Layout};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Once;

    // Helper to reset thread-local state between tests.
    fn reset_thread_local_state() {
        TRACKING_DISABLED.with(|disabled| disabled.set(false));
    }

    #[test]
    fn test_allocation_tracking() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024, 8).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            assert!(!ptr.is_null());

            // Test deallocation
            allocator.dealloc(ptr, layout);
        }
    }

    #[test]
    fn test_zero_sized_allocation() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(0, 1).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            // NOTE: zero-sized requests are outside the strict GlobalAlloc
            // contract, so the allocator may legitimately return null.
            // FIX: only deallocate a non-null pointer — the previous code
            // called `dealloc` unconditionally, and `dealloc(null)` is
            // undefined behavior.
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_large_allocation() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024 * 1024, 8).unwrap(); // 1MB

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                // Only test deallocation if allocation succeeded
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_multiple_allocations() {
        let allocator = TrackingAllocator::new();
        let mut ptrs = Vec::new();

        // Allocate multiple blocks of increasing size.
        for i in 1..=10 {
            let layout = Layout::from_size_align(i * 64, 8).unwrap();
            unsafe {
                let ptr = allocator.alloc(layout);
                if !ptr.is_null() {
                    ptrs.push((ptr, layout));
                }
            }
        }

        // Deallocate all blocks.
        for (ptr, layout) in ptrs {
            unsafe {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_type_inference_from_size() {
        // Test the static type inference
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(1),
            "u8"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(4),
            "u32"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(8),
            "u64"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(24),
            "String"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(32),
            "Vec<T>"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(999),
            "unknown"
        );
    }

    #[test]
    fn test_variable_inference_from_size() {
        // Test the static variable inference
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(4),
            "primitive_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(32),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(512),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(2048),
            "buffer_data"
        );
    }

    #[test]
    fn test_default_implementation() {
        let allocator = TrackingAllocator::new();
        assert_eq!(
            std::mem::size_of_val(&allocator),
            std::mem::size_of::<TrackingAllocator>()
        );
    }

    #[test]
    fn test_type_inference() {
        // Test type inference for various sizes
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(1),
            "u8"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(2),
            "u16"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(4),
            "u32"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(8),
            "u64"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(16),
            "u128"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(24),
            "String"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(32),
            "Vec<T>"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(48),
            "HashMap<K,V>"
        );

        // Test unknown size
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(12345),
            "unknown"
        );
    }

    #[test]
    fn test_variable_inference() {
        // Test variable inference for different size ranges.
        // Size 0 falls through to the catch-all "buffer_data" arm.
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(0),
            "buffer_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(4),
            "primitive_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(8),
            "primitive_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(16),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(32),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(64),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(65),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(128),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(1024),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(1025),
            "buffer_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(usize::MAX),
            "buffer_data"
        );
    }

    #[test]
    fn test_thread_local_tracking() {
        reset_thread_local_state();

        // Test that tracking is enabled by default
        TRACKING_DISABLED.with(|disabled| {
            assert!(!disabled.get());
        });

        // Test disabling tracking
        TRACKING_DISABLED.with(|disabled| {
            disabled.set(true);
            assert!(disabled.get());
            disabled.set(false);
        });
    }

    #[test]
    fn test_simplified_call_stack() {
        let stack = TrackingAllocator::_get_simplified_call_stack();
        assert_eq!(stack.len(), 2);
        assert_eq!(stack[0], "global_allocator");
        assert_eq!(stack[1], "system_alloc");
    }

    #[test]
    fn test_allocation_edge_cases() {
        let allocator = TrackingAllocator::new();

        // Test with a larger-than-usual alignment.
        let max_align = std::mem::size_of::<usize>() * 2;
        let layout = Layout::from_size_align(16, max_align).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                // Test that the pointer is properly aligned
                assert_eq!((ptr as usize) % max_align, 0);
                allocator.dealloc(ptr, layout);
            }
        }

        // Test with minimal size but non-zero
        let layout = Layout::from_size_align(1, 1).unwrap();
        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_recursive_allocation_handling() {
        // This test verifies that recursive allocations don't cause infinite loops
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(64, 8).unwrap();

        // Set up a flag to detect if we're in a recursive call
        static RECURSION_DETECTED: AtomicBool = AtomicBool::new(false);
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            // Install a panic hook to detect if we hit a stack overflow
            let original_hook = std::panic::take_hook();
            std::panic::set_hook(Box::new(move |panic_info| {
                if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
                    if s.contains("stack overflow") {
                        RECURSION_DETECTED.store(true, Ordering::SeqCst);
                    }
                }
                original_hook(panic_info);
            }));
        });

        // This allocation will trigger tracking, but the thread-local flag should prevent recursion
        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }

        // Verify we didn't hit a stack overflow
        assert!(
            !RECURSION_DETECTED.load(Ordering::SeqCst),
            "Recursive allocation detected - thread-local tracking failed"
        );
    }
}