memscope_rs/core/
allocator.rs1use crate::core::enhanced_type_inference::TypeInferenceEngine;
4use std::alloc::{GlobalAlloc, Layout, System};
5use std::sync::Mutex;
6
/// Global allocator wrapper that forwards every request to the system
/// allocator while reporting allocations/deallocations to the crate tracker.
pub struct TrackingAllocator;
12
// Lazily-initialized, mutex-guarded type-inference engine. Currently unused
// (hence the leading underscore); kept for future size-to-type heuristics.
// NOTE(review): taking this Mutex from inside allocator code could deadlock or
// recurse if the engine itself allocates — confirm before wiring it in.
static _TYPE_INFERENCE_ENGINE: std::sync::LazyLock<Mutex<TypeInferenceEngine>> =
    std::sync::LazyLock::new(|| Mutex::new(TypeInferenceEngine::new()));
16
17impl TrackingAllocator {
18 pub const fn new() -> Self {
20 Self
21 }
22
23 fn _infer_type_from_allocation_context(size: usize) -> &'static str {
25 match size {
27 1 => "u8",
29 2 => "u16",
30 4 => "u32",
31 8 => "u64",
32 16 => "u128",
33
34 24 => "String",
36 32 => "Vec<T>",
37 48 => "HashMap<K,V>",
38
39 size if size == std::mem::size_of::<std::sync::Arc<String>>() => "Arc<T>",
41 size if size == std::mem::size_of::<std::rc::Rc<String>>() => "Rc<T>",
42 size if size == std::mem::size_of::<Box<String>>() => "Box<T>",
43
44 _ => "unknown",
46 }
47 }
48
49 fn _get_simplified_call_stack() -> Vec<String> {
53 vec!["global_allocator".to_string(), "system_alloc".to_string()]
56 }
57
58 fn _infer_variable_from_allocation_context(size: usize) -> &'static str {
60 match size {
62 1..=8 => "primitive_data",
64
65 9..=64 => "struct_data",
67
68 65..=1024 => "collection_data",
70
71 _ => "buffer_data",
73 }
74 }
75}
76
thread_local! {
    // Per-thread reentrancy guard: set to `true` while the tracker itself is
    // running so that any allocations the tracker performs are forwarded to
    // the system allocator untracked (otherwise alloc() would recurse
    // infinitely into itself).
    static TRACKING_DISABLED: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
81
unsafe impl GlobalAlloc for TrackingAllocator {
    // Allocate via the system allocator, then (best-effort) record the
    // allocation with the tracker. Tracking failures never affect the caller:
    // the pointer from `System.alloc` is returned regardless.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let ptr = System.alloc(layout);

        if !ptr.is_null() {
            // Only track when this thread is not already inside tracking code.
            let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());

            if should_track {
                // Disable tracking for the duration of the tracker call so any
                // allocations the tracker makes are not re-entered.
                TRACKING_DISABLED.with(|disabled| disabled.set(true));

                // catch_unwind guards against a panicking tracker accessor.
                // NOTE(review): a panic inside `track_allocation` itself is
                // NOT caught here and would both unwind out of the allocator
                // and leave TRACKING_DISABLED set for this thread — confirm
                // track_allocation cannot panic.
                if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_tracker) {
                    let _ = tracker.track_allocation(ptr as usize, layout.size());
                }

                TRACKING_DISABLED.with(|disabled| disabled.set(false));
            }
        }

        ptr
    }

    // Mirror of `alloc`: record the deallocation (guarded by the same
    // reentrancy flag), then release the memory through the system allocator.
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());

        if should_track {
            TRACKING_DISABLED.with(|disabled| disabled.set(true));

            if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_tracker) {
                let _ = tracker.track_deallocation(ptr as usize);
            }

            TRACKING_DISABLED.with(|disabled| disabled.set(false));
        }

        // Free last so the tracker sees the pointer while it is still valid.
        System.dealloc(ptr, layout);
    }
}
132
133impl Default for TrackingAllocator {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::alloc::{GlobalAlloc, Layout};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Once;

    /// Clears the per-thread reentrancy guard so a test starts from a known state.
    fn reset_thread_local_state() {
        TRACKING_DISABLED.with(|disabled| disabled.set(false));
    }

    /// Basic round-trip: a normal allocation succeeds and can be freed.
    #[test]
    fn test_allocation_tracking() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024, 8).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            assert!(!ptr.is_null());

            allocator.dealloc(ptr, layout);
        }
    }

    /// Minimal-size allocation round-trip.
    ///
    /// This test previously called `alloc` with a zero-sized layout, which is
    /// undefined behavior per the `GlobalAlloc` safety contract ("layout must
    /// have non-zero size"), and then called `dealloc` on the possibly-null
    /// result — also UB. It now exercises the smallest *valid* layout with a
    /// null guard instead.
    #[test]
    fn test_zero_sized_allocation() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1, 1).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    /// A 1 MiB allocation may legitimately fail, so only dealloc on success.
    #[test]
    fn test_large_allocation() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024 * 1024, 8).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    /// Several live allocations of varying sizes, freed afterwards in order.
    #[test]
    fn test_multiple_allocations() {
        let allocator = TrackingAllocator::new();
        let mut ptrs = Vec::new();

        for i in 1..=10 {
            let layout = Layout::from_size_align(i * 64, 8).unwrap();
            unsafe {
                let ptr = allocator.alloc(layout);
                if !ptr.is_null() {
                    ptrs.push((ptr, layout));
                }
            }
        }

        for (ptr, layout) in ptrs {
            unsafe {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    /// Spot-checks of the size -> type-name heuristic.
    #[test]
    fn test_type_inference_from_size() {
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(1),
            "u8"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(4),
            "u32"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(8),
            "u64"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(24),
            "String"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(32),
            "Vec<T>"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(999),
            "unknown"
        );
    }

    /// Spot-checks of the size-bucket classifier (one value per bucket).
    #[test]
    fn test_variable_inference_from_size() {
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(4),
            "primitive_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(32),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(512),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(2048),
            "buffer_data"
        );
    }

    /// `TrackingAllocator` is a zero-sized unit struct; `Default`/`new` agree.
    #[test]
    fn test_default_implementation() {
        let allocator = TrackingAllocator::new();
        assert_eq!(
            std::mem::size_of_val(&allocator),
            std::mem::size_of::<TrackingAllocator>()
        );
    }

    /// Exhaustive coverage of every literal arm of the type heuristic.
    #[test]
    fn test_type_inference() {
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(1),
            "u8"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(2),
            "u16"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(4),
            "u32"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(8),
            "u64"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(16),
            "u128"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(24),
            "String"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(32),
            "Vec<T>"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(48),
            "HashMap<K,V>"
        );

        // Any unrecognized size falls through to "unknown".
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(12345),
            "unknown"
        );
    }

    /// Bucket-boundary coverage for the size classifier (0, 8/9, 64/65,
    /// 1024/1025, usize::MAX).
    #[test]
    fn test_variable_inference() {
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(0),
            "buffer_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(4),
            "primitive_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(8),
            "primitive_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(16),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(32),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(64),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(65),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(128),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(1024),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(1025),
            "buffer_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(usize::MAX),
            "buffer_data"
        );
    }

    /// The thread-local guard defaults to "tracking enabled" and toggles.
    #[test]
    fn test_thread_local_tracking() {
        reset_thread_local_state();

        TRACKING_DISABLED.with(|disabled| {
            assert!(!disabled.get());
        });

        TRACKING_DISABLED.with(|disabled| {
            disabled.set(true);
            assert!(disabled.get());
            disabled.set(false);
        });
    }

    /// The placeholder call stack has exactly the two fixed frames.
    #[test]
    fn test_simplified_call_stack() {
        let stack = TrackingAllocator::_get_simplified_call_stack();
        assert_eq!(stack.len(), 2);
        assert_eq!(stack[0], "global_allocator");
        assert_eq!(stack[1], "system_alloc");
    }

    /// Over-aligned and minimally-aligned layouts both round-trip, and the
    /// returned pointer honors the requested alignment.
    #[test]
    fn test_allocation_edge_cases() {
        let allocator = TrackingAllocator::new();

        let max_align = std::mem::size_of::<usize>() * 2;
        let layout = Layout::from_size_align(16, max_align).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                assert_eq!((ptr as usize) % max_align, 0);
                allocator.dealloc(ptr, layout);
            }
        }

        let layout = Layout::from_size_align(1, 1).unwrap();
        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    /// Installs a panic hook that flags stack-overflow panics, then performs a
    /// tracked round-trip: if the thread-local reentrancy guard failed, the
    /// allocator would recurse and the flag would be set.
    #[test]
    fn test_recursive_allocation_handling() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(64, 8).unwrap();

        static RECURSION_DETECTED: AtomicBool = AtomicBool::new(false);
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            // Chain onto the existing hook so normal panic reporting survives.
            let original_hook = std::panic::take_hook();
            std::panic::set_hook(Box::new(move |panic_info| {
                if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
                    if s.contains("stack overflow") {
                        RECURSION_DETECTED.store(true, Ordering::SeqCst);
                    }
                }
                original_hook(panic_info);
            }));
        });

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }

        assert!(
            !RECURSION_DETECTED.load(Ordering::SeqCst),
            "Recursive allocation detected - thread-local tracking failed"
        );
    }
}