memscope_rs/core/
allocator.rs1use std::alloc::{GlobalAlloc, Layout, System};
4
/// A global-allocator shim that forwards to the system allocator while
/// reporting every allocation/deallocation to the crate's tracker.
///
/// The struct itself is a zero-sized unit type; all state lives in the
/// tracker and in a per-thread recursion guard.
pub struct TrackingAllocator;

impl TrackingAllocator {
    /// Creates a new `TrackingAllocator`.
    ///
    /// `const` so the allocator can be used in a `#[global_allocator]` static.
    pub const fn new() -> Self {
        Self
    }

    /// Best-effort guess of the Rust type behind an allocation, based only on
    /// its size in bytes.
    ///
    /// This is a heuristic: many unrelated types share these sizes, so the
    /// result is suitable for diagnostics/labelling only, never for
    /// correctness decisions.
    fn _infer_type_from_allocation_context(size: usize) -> &'static str {
        match size {
            // Exact integer widths.
            1 => "u8",
            2 => "u16",
            4 => "u32",
            8 => "u64",
            16 => "u128",

            // Common container headers (ptr + len + cap, hash-map header, ...).
            24 => "String",
            32 => "Vec<T>",
            48 => "HashMap<K,V>",

            // NOTE(review): earlier revisions also had guard arms matching
            // `size_of::<Arc<String>>()`, `size_of::<Rc<String>>()` and
            // `size_of::<Box<String>>()`. Those smart pointers are always
            // pointer-sized, so the fixed-width arms above (2/4/8) matched
            // first on every target and the guards were unreachable; they
            // were removed without changing behaviour.
            _ => "unknown",
        }
    }

    /// Returns a fixed, two-frame placeholder call stack.
    ///
    /// Capturing a real backtrace inside an allocator would itself allocate,
    /// so static labels are used instead.
    fn _get_simplified_call_stack() -> Vec<String> {
        vec!["global_allocator".to_string(), "system_alloc".to_string()]
    }

    /// Coarse classification of an allocation into a variable-kind bucket by
    /// size in bytes. Heuristic, diagnostics only.
    fn _infer_variable_from_allocation_context(size: usize) -> &'static str {
        match size {
            // Fits in a machine word: scalar-like data.
            1..=8 => "primitive_data",
            // Small aggregate.
            9..=64 => "struct_data",
            // Typical container payload.
            65..=1024 => "collection_data",
            // Anything larger — and size 0 — is treated as a raw buffer.
            _ => "buffer_data",
        }
    }
}
70
thread_local! {
    // Per-thread re-entrancy guard: set to `true` while the allocator is
    // calling into the tracker, so that any allocations made *by* the
    // tracker itself are not tracked again (which would recurse forever).
    static TRACKING_DISABLED: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
75
76unsafe impl GlobalAlloc for TrackingAllocator {
77 unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
78 let ptr = System.alloc(layout);
80
81 if !ptr.is_null() {
83 let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());
85
86 if should_track {
87 TRACKING_DISABLED.with(|disabled| disabled.set(true));
89
90 if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_tracker) {
92 let _ = tracker.track_allocation(ptr as usize, layout.size());
94 }
95
96 TRACKING_DISABLED.with(|disabled| disabled.set(false));
98 }
99 }
100
101 ptr
102 }
103
104 unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
105 let should_track = TRACKING_DISABLED.with(|disabled| !disabled.get());
107
108 if should_track {
109 TRACKING_DISABLED.with(|disabled| disabled.set(true));
111
112 if let Ok(tracker) = std::panic::catch_unwind(crate::core::tracker::get_tracker) {
114 let _ = tracker.track_deallocation(ptr as usize);
116 }
117
118 TRACKING_DISABLED.with(|disabled| disabled.set(false));
120 }
121
122 System.dealloc(ptr, layout);
124 }
125}
126
127impl Default for TrackingAllocator {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use std::alloc::{GlobalAlloc, Layout};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Once;

    /// Clears the per-thread recursion guard so a test starts from a known
    /// state regardless of what previously ran on this thread.
    fn reset_thread_local_state() {
        TRACKING_DISABLED.with(|disabled| disabled.set(false));
    }

    #[test]
    fn test_allocation_tracking() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024, 8).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            assert!(!ptr.is_null());
            allocator.dealloc(ptr, layout);
        }
    }

    #[test]
    fn test_zero_sized_allocation() {
        // `GlobalAlloc::alloc` is undefined behaviour for zero-size layouts
        // (see the trait's safety contract), so the previous version of this
        // test — which called `alloc` with size 0 — was UB. We now only
        // verify that a zero-size layout is constructible, and exercise the
        // allocator with the smallest legal size instead.
        let zero = Layout::from_size_align(0, 1).unwrap();
        assert_eq!(zero.size(), 0);

        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1, 1).unwrap();
        unsafe {
            let ptr = allocator.alloc(layout);
            assert!(!ptr.is_null());
            allocator.dealloc(ptr, layout);
        }
    }

    #[test]
    fn test_large_allocation() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(1024 * 1024, 8).unwrap();

        unsafe {
            let ptr = allocator.alloc(layout);
            // A 1 MiB request may legitimately fail; only free on success.
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_multiple_allocations() {
        let allocator = TrackingAllocator::new();
        let mut ptrs = Vec::new();

        // Allocate ten blocks of increasing size (64..=640 bytes).
        for i in 1..=10 {
            let layout = Layout::from_size_align(i * 64, 8).unwrap();
            unsafe {
                let ptr = allocator.alloc(layout);
                if !ptr.is_null() {
                    ptrs.push((ptr, layout));
                }
            }
        }

        // Free everything that was successfully allocated.
        for (ptr, layout) in ptrs {
            unsafe {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_type_inference_from_size() {
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(1), "u8");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(4), "u32");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(8), "u64");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(24), "String");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(32), "Vec<T>");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(999), "unknown");
    }

    #[test]
    fn test_variable_inference_from_size() {
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(4), "primitive_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(32), "struct_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(512), "collection_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(2048), "buffer_data");
    }

    #[test]
    fn test_default_implementation() {
        // Go through `Default` explicitly: the previous version called
        // `new()` and never actually exercised the trait impl.
        let allocator = TrackingAllocator::default();
        assert_eq!(
            std::mem::size_of_val(&allocator),
            std::mem::size_of::<TrackingAllocator>()
        );
    }

    #[test]
    fn test_type_inference() {
        // Exhaustive pass over every exact-size arm plus the fallback.
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(1), "u8");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(2), "u16");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(4), "u32");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(8), "u64");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(16), "u128");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(24), "String");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(32), "Vec<T>");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(48), "HashMap<K,V>");
        assert_eq!(TrackingAllocator::_infer_type_from_allocation_context(12345), "unknown");
    }

    #[test]
    fn test_variable_inference() {
        // Bucket boundaries: 0 falls through to the default arm, 1..=8 is
        // primitive, 9..=64 struct, 65..=1024 collection, everything larger
        // (up to and including usize::MAX) buffer.
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(0), "buffer_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(4), "primitive_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(8), "primitive_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(16), "struct_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(32), "struct_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(64), "struct_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(65), "collection_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(128), "collection_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(1024), "collection_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(1025), "buffer_data");
        assert_eq!(TrackingAllocator::_infer_variable_from_allocation_context(usize::MAX), "buffer_data");
    }

    #[test]
    fn test_thread_local_tracking() {
        reset_thread_local_state();

        // Guard starts cleared after a reset.
        TRACKING_DISABLED.with(|disabled| {
            assert!(!disabled.get());
        });

        // The flag is writable and readable; restore it before leaving so
        // other tests on this thread are unaffected.
        TRACKING_DISABLED.with(|disabled| {
            disabled.set(true);
            assert!(disabled.get());
            disabled.set(false);
        });
    }

    #[test]
    fn test_simplified_call_stack() {
        let stack = TrackingAllocator::_get_simplified_call_stack();
        assert_eq!(stack.len(), 2);
        assert_eq!(stack[0], "global_allocator");
        assert_eq!(stack[1], "system_alloc");
    }

    #[test]
    fn test_allocation_edge_cases() {
        let allocator = TrackingAllocator::new();

        // Over-aligned allocation: the returned pointer must honour the
        // requested alignment.
        let max_align = std::mem::size_of::<usize>() * 2;
        let layout = Layout::from_size_align(16, max_align).unwrap();
        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                assert_eq!((ptr as usize) % max_align, 0);
                allocator.dealloc(ptr, layout);
            }
        }

        // Minimal legal allocation: one byte, alignment one.
        let layout = Layout::from_size_align(1, 1).unwrap();
        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }
    }

    #[test]
    fn test_recursive_allocation_handling() {
        let allocator = TrackingAllocator::new();
        let layout = Layout::from_size_align(64, 8).unwrap();

        static RECURSION_DETECTED: AtomicBool = AtomicBool::new(false);
        static INIT: Once = Once::new();

        // Best-effort recursion detector: install (once, process-wide) a
        // panic hook that flags a stack-overflow message, then chains to the
        // original hook so normal panic reporting still works.
        INIT.call_once(|| {
            let original_hook = std::panic::take_hook();
            std::panic::set_hook(Box::new(move |panic_info| {
                if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
                    if s.contains("stack overflow") {
                        RECURSION_DETECTED.store(true, Ordering::SeqCst);
                    }
                }
                original_hook(panic_info);
            }));
        });

        unsafe {
            let ptr = allocator.alloc(layout);
            if !ptr.is_null() {
                allocator.dealloc(ptr, layout);
            }
        }

        assert!(
            !RECURSION_DETECTED.load(Ordering::SeqCst),
            "Recursive allocation detected - thread-local tracking failed"
        );
    }
}