memscope_rs/core/allocator.rs

use crate::core::enhanced_type_inference::TypeInferenceEngine;
use std::alloc::{GlobalAlloc, Layout, System};
use std::sync::Mutex;

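/// Global allocator that forwards every request to the [`System`] allocator
/// and reports successful allocations and deallocations to the crate's global
/// tracker.
///
/// Typical installation (a sketch; the import path assumes this module is
/// exported as `memscope_rs::core::allocator`):
///
/// ```ignore
/// use memscope_rs::core::allocator::TrackingAllocator;
///
/// #[global_allocator]
/// static GLOBAL: TrackingAllocator = TrackingAllocator::new();
/// ```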
pub struct TrackingAllocator;

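/// Lazily initialized type-inference engine; currently unused by the
/// allocation hooks (hence the leading underscore).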
static _TYPE_INFERENCE_ENGINE: std::sync::LazyLock<Mutex<TypeInferenceEngine>> =
    std::sync::LazyLock::new(|| Mutex::new(TypeInferenceEngine::new()));

impl TrackingAllocator {
    pub const fn new() -> Self {
        Self
    }

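    /// Heuristically maps an allocation size to a likely Rust type name.
    /// The mapping is a best guess from the size alone, not an exact identification.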
    fn _infer_type_from_allocation_context(size: usize) -> &'static str {
        match size {
            1 => "u8",
            2 => "u16",
            4 => "u32",
            8 => "u64",
            16 => "u128",

            24 => "String",
            32 => "Vec<T>",
            48 => "HashMap<K,V>",

            // Note: on common 64-bit targets these smart pointers are 8 bytes,
            // so the `8 => "u64"` arm above matches first and these guards are
            // effectively unreachable there.
            size if size == std::mem::size_of::<std::sync::Arc<String>>() => "Arc<T>",
            size if size == std::mem::size_of::<std::rc::Rc<String>>() => "Rc<T>",
            size if size == std::mem::size_of::<Box<String>>() => "Box<T>",

            _ => "unknown",
        }
    }

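    /// Returns a fixed placeholder call stack instead of capturing a real backtrace.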
    fn _get_simplified_call_stack() -> Vec<String> {
        vec!["global_allocator".to_string(), "system_alloc".to_string()]
    }

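    /// Heuristically buckets an allocation size into a coarse variable category.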
    fn _infer_variable_from_allocation_context(size: usize) -> &'static str {
        match size {
            1..=8 => "primitive_data",

            9..=64 => "struct_data",

            65..=1024 => "collection_data",

            _ => "buffer_data",
        }
    }
}

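// Per-thread re-entrancy guard: set to `true` while an allocation is being
// reported to the tracker so that allocations made by the tracker itself are
// not tracked recursively.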
thread_local! {
    static TRACKING_DISABLED: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}

unsafe impl GlobalAlloc for TrackingAllocator {
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let ptr = System.alloc(layout);

        if !ptr.is_null() {
            // Only track if this thread is not already inside a tracking call.
            let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());

            if should_track {
                // Disable tracking while talking to the tracker so that any
                // allocations it performs are not tracked recursively.
                TRACKING_DISABLED.with(|disabled| disabled.set(true));

                // catch_unwind guards against panics while obtaining the tracker.
                if let Ok(tracker) =
                    std::panic::catch_unwind(crate::core::tracker::get_global_tracker)
                {
                    let _ = tracker.track_allocation(ptr as usize, layout.size());
                }

                TRACKING_DISABLED.with(|disabled| disabled.set(false));
            }
        }

        ptr
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());

        if should_track {
            TRACKING_DISABLED.with(|disabled| disabled.set(true));

            if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_global_tracker)
            {
                let _ = tracker.track_deallocation(ptr as usize);
            }

            TRACKING_DISABLED.with(|disabled| disabled.set(false));
        }

        System.dealloc(ptr, layout);
    }
}

impl Default for TrackingAllocator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::alloc::{GlobalAlloc, Layout};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Once;

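    /// Clears the per-thread tracking flag so tests start from a known state.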
    fn reset_thread_local_state() {
        TRACKING_DISABLED.with(|disabled| disabled.set(false));
    }

    #[test]
    fn test_allocation_tracking() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024, 8).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            assert!(!ptr.is_null());

            allocator.dealloc(ptr, layout);
        }
    }

    #[test]
    fn test_zero_sized_allocation() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(0, 1).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            allocator.dealloc(ptr, layout);
        }
    }

    #[test]
    fn test_large_allocation() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024 * 1024, 8).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_multiple_allocations() {
        let allocator = TrackingAllocator::new();
        let mut ptrs = Vec::new();

        for i in 1..=10 {
            let layout = Layout::from_size_align(i * 64, 8).unwrap();
            unsafe {
                let ptr = allocator.alloc(layout);
                if !ptr.is_null() {
                    ptrs.push((ptr, layout));
                }
            }
        }

        for (ptr, layout) in ptrs {
            unsafe {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_type_inference_from_size() {
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(1),
            "u8"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(4),
            "u32"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(8),
            "u64"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(24),
            "String"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(32),
            "Vec<T>"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(999),
            "unknown"
        );
    }

    #[test]
    fn test_variable_inference_from_size() {
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(4),
            "primitive_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(32),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(512),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(2048),
            "buffer_data"
        );
    }

    #[test]
    fn test_default_implementation() {
        // Construct via `Default` so this test actually exercises the impl.
        let allocator = TrackingAllocator::default();
        assert_eq!(
            std::mem::size_of_val(&allocator),
            std::mem::size_of::<TrackingAllocator>()
        );
    }

    #[test]
    fn test_type_inference() {
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(1),
            "u8"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(2),
            "u16"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(4),
            "u32"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(8),
            "u64"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(16),
            "u128"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(24),
            "String"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(32),
            "Vec<T>"
        );
        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(48),
            "HashMap<K,V>"
        );

        assert_eq!(
            TrackingAllocator::_infer_type_from_allocation_context(12345),
            "unknown"
        );
    }

    #[test]
    fn test_variable_inference() {
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(0),
            "buffer_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(4),
            "primitive_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(8),
            "primitive_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(16),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(32),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(64),
            "struct_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(65),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(128),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(1024),
            "collection_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(1025),
            "buffer_data"
        );
        assert_eq!(
            TrackingAllocator::_infer_variable_from_allocation_context(usize::MAX),
            "buffer_data"
        );
    }

    #[test]
    fn test_thread_local_tracking() {
        reset_thread_local_state();

        TRACKING_DISABLED.with(|disabled| {
            assert!(!disabled.get());
        });

        TRACKING_DISABLED.with(|disabled| {
            disabled.set(true);
            assert!(disabled.get());
            disabled.set(false);
        });
    }

    #[test]
    fn test_simplified_call_stack() {
        let stack = TrackingAllocator::_get_simplified_call_stack();
        assert_eq!(stack.len(), 2);
        assert_eq!(stack[0], "global_allocator");
        assert_eq!(stack[1], "system_alloc");
    }

    #[test]
    fn test_allocation_edge_cases() {
        let allocator = TrackingAllocator::new();

        // Over-aligned allocation: the returned pointer must honor the alignment.
        let max_align = std::mem::size_of::<usize>() * 2;
        let layout = Layout::from_size_align(16, max_align).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                assert_eq!((ptr as usize) % max_align, 0);
                allocator.dealloc(ptr, layout);
            }
        }

        // Minimal one-byte allocation.
        let layout = Layout::from_size_align(1, 1).unwrap();
        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

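    // Verifies that the thread-local guard prevents recursive allocation:
    // if tracking re-entered the allocator, the stack would overflow, which
    // the installed panic hook would record.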
    #[test]
    fn test_recursive_allocation_handling() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(64, 8).unwrap();

        static RECURSION_DETECTED: AtomicBool = AtomicBool::new(false);
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            let original_hook = std::panic::take_hook();
            std::panic::set_hook(Box::new(move |panic_info| {
                if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
                    if s.contains("stack overflow") {
                        RECURSION_DETECTED.store(true, Ordering::SeqCst);
                    }
                }
                original_hook(panic_info);
            }));
        });

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }

        assert!(
            !RECURSION_DETECTED.load(Ordering::SeqCst),
            "Recursive allocation detected - thread-local tracking failed"
        );
    }
}