use std::{
alloc::{GlobalAlloc, Layout},
cell::{Cell, UnsafeCell},
mem,
sync::{Arc, OnceLock},
};
use heapless::FnvIndexMap;
use tracing_core::span;
use crate::events::{EventQueue, SpanMemoryUpdateEvent};
/// Number of distinct allocation groups buffered per thread. Inserting a new
/// group into a full buffer flushes the buffer to the global tracker.
// NOTE(review): heapless index maps require a power-of-two capacity; keep this a
// power of two (confirm against the heapless version in use).
const MAX_GROUPS: usize = 4;
/// Capacity of the per-thread stack of outer groups saved by
/// `enter_allocation_group`; entering beyond it logs a warning instead.
const GROUP_STACK_SIZE: usize = 128;
/// Identifier of the group that allocations are attributed to.
///
/// Built from a `tracing` span id. The raw value `0` (the `Default`) is the
/// "untracked" group: allocations made while it is current are not recorded.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[doc(hidden)]
pub struct AllocationGroup(u64);

impl From<span::Id> for AllocationGroup {
    #[inline]
    fn from(span_id: span::Id) -> Self {
        Self(span_id.into_u64())
    }
}

impl From<&span::Id> for AllocationGroup {
    #[inline]
    fn from(span_id: &span::Id) -> Self {
        Self(span_id.into_u64())
    }
}

impl From<AllocationGroup> for span::Id {
    /// Converts back to the span id.
    ///
    /// NOTE(review): `span::Id::from_u64` is documented to reject `0`, so this
    /// should only be used on tracked (non-zero) groups — confirm callers.
    #[inline]
    fn from(group: AllocationGroup) -> Self {
        Self::from_u64(group.0)
    }
}

impl AllocationGroup {
    /// Returns `true` for every group except the untracked group `0`.
    #[inline]
    #[must_use]
    const fn should_track(self) -> bool {
        self.0 > 0
    }
}
/// Per-thread buffer mapping each allocation group to its statistics.
///
/// Fixed capacity of `MAX_GROUPS` entries: inserting a new group into a full
/// map fails instead of growing, and the recording code then flushes it.
// NOTE(review): presumably a heapless map so that recording from inside the
// allocator never allocates itself — confirm.
pub(crate) type ThreadGroupStatisticsMap =
FnvIndexMap<AllocationGroup, GroupAllocationStatistics, MAX_GROUPS>;
/// Running allocation/deallocation totals for a single allocation group.
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct GroupAllocationStatistics {
    /// Number of allocations recorded for the group.
    pub(crate) allocation_count: u64,
    /// Number of deallocations recorded for the group.
    pub(crate) deallocation_count: u64,
    /// Total object bytes allocated (allocator headers excluded).
    pub(crate) allocated_bytes: u64,
    /// Total object bytes deallocated (allocator headers excluded).
    pub(crate) deallocated_bytes: u64,
}

impl GroupAllocationStatistics {
    /// Allocations not yet matched by a deallocation, clamped at zero.
    ///
    /// Statistics are buffered per thread while frees are recorded on the
    /// freeing thread, so a single buffer may hold more deallocations than
    /// allocations for a group; clamping keeps the result meaningful.
    #[must_use]
    pub(crate) fn in_use_count(&self) -> u64 {
        u64::saturating_sub(self.allocation_count, self.deallocation_count)
    }

    /// Bytes allocated but not yet freed, clamped at zero (see
    /// [`Self::in_use_count`] for why the difference can go negative).
    #[must_use]
    pub(crate) fn in_use_bytes(&self) -> u64 {
        u64::saturating_sub(self.allocated_bytes, self.deallocated_bytes)
    }
}
/// Sink that receives batches of per-thread allocation statistics.
pub(crate) trait AllocationTracker: Send + Sync + 'static {
    /// Takes ownership of a batch of per-group statistics flushed from a
    /// thread's buffer.
    fn collect(&self, allocations: ThreadGroupStatisticsMap);
}

/// [`AllocationTracker`] that forwards every group in a batch as a
/// [`SpanMemoryUpdateEvent`] onto a shared event queue.
pub(crate) struct QueueAllocTracker {
    pub(crate) events: Arc<EventQueue<SpanMemoryUpdateEvent>>,
}

impl AllocationTracker for QueueAllocTracker {
    fn collect(&self, allocations: ThreadGroupStatisticsMap) {
        allocations
            .into_iter()
            .map(|(group, stats)| SpanMemoryUpdateEvent {
                span_id: span::Id::from(group),
                stats,
            })
            .for_each(|event| {
                self.events.push(event);
            });
    }
}
/// Process-wide tracker that receives flushed statistics; set at most once.
static GLOBAL_TRACKER: OnceLock<QueueAllocTracker> = OnceLock::new();

/// Installs `tracker` as the process-wide allocation tracker.
///
/// Only the first call takes effect. Later trackers are dropped and a
/// debug-level event is emitted.
pub(crate) fn set_global_tracker(tracker: QueueAllocTracker) {
    let already_set = GLOBAL_TRACKER.set(tracker).is_err();
    if already_set {
        tracing::debug!("global allocation tracker was already set");
    }
}
/// Per-thread buffer of allocation statistics keyed by allocation group.
///
/// Its contents reach the global tracker when the buffer runs out of slots,
/// when `flush_thread_statistics` is called, or when the thread-local value
/// holding it is destroyed.
pub(crate) struct ThreadAllocationStatistics {
    groups: ThreadGroupStatisticsMap,
}

impl ThreadAllocationStatistics {
    /// Moves every buffered entry to the global tracker, or discards them when
    /// no tracker has been installed. The buffer is empty afterwards.
    fn flush_global(&mut self) {
        match GLOBAL_TRACKER.get() {
            Some(tracker) => {
                let drained = mem::take(&mut self.groups);
                tracker.collect(drained);
            }
            None => self.groups.clear(),
        }
    }
}

impl Drop for ThreadAllocationStatistics {
    /// Flushes remaining statistics so that a thread's final numbers are not
    /// lost when its thread-local storage is torn down.
    fn drop(&mut self) {
        self.flush_global();
    }
}
thread_local! {
    /// This thread's buffered allocation statistics, flushed to the global
    /// tracker when full, on `flush_thread_statistics`, or on thread exit.
    static THREAD_GROUP_STATISTICS: UnsafeCell<ThreadAllocationStatistics> = const {
        UnsafeCell::new(ThreadAllocationStatistics {
            groups: ThreadGroupStatisticsMap::new(),
        })
    };
}

thread_local! {
    /// Group that new allocations on this thread are attributed to;
    /// `AllocationGroup(0)` means "untracked".
    static THREAD_CURRENT_GROUP: Cell<AllocationGroup> = const {
        Cell::new(AllocationGroup(0))
    };
}

thread_local! {
    /// Outer groups saved by `enter_allocation_group` and restored by
    /// `exit_allocation_group`.
    static THREAD_GROUP_STACK: UnsafeCell<heapless::Vec<AllocationGroup, GROUP_STACK_SIZE>> = const {
        UnsafeCell::new(heapless::Vec::new())
    };
}
thread_local! {
    /// Number of `enter_allocation_group` calls on this thread that found the
    /// group stack full and therefore neither saved the outer group nor
    /// switched groups. Each matching `exit_allocation_group` consumes one
    /// count instead of popping, so enters and exits stay paired.
    static THREAD_GROUP_STACK_OVERFLOW: Cell<usize> = const { Cell::new(0) };
}

/// Makes `new_group` the allocation group for this thread.
///
/// The previously active group is saved so that the matching
/// [`exit_allocation_group`] can restore it. Calls must be paired with
/// `exit_allocation_group` in LIFO order.
///
/// The outgoing group is saved even when it is untracked (for example inside
/// `no_track`). This ensures the paired exit restores exactly what was active,
/// rather than popping an outer frame's group.
///
/// If the per-thread stack is full (`GROUP_STACK_SIZE` entries), a warning is
/// logged and the current group is left unchanged.
#[doc(hidden)]
pub fn enter_allocation_group(new_group: AllocationGroup) {
    let current_group = THREAD_CURRENT_GROUP.try_with(Cell::get).unwrap_or_default();
    let saved = THREAD_GROUP_STACK
        .try_with(|stack| {
            // SAFETY: the stack is only accessed by `enter_allocation_group` and
            // `exit_allocation_group`. `heapless::Vec::push` neither allocates
            // nor calls back into either of them, so this is the only live
            // reference to the stack.
            unsafe { &mut *stack.get() }.push(current_group).is_ok()
        })
        .unwrap_or_default();
    if !saved {
        // Nothing was pushed and the group is not switched; remember that so the
        // paired exit neither pops nor switches either.
        _ = THREAD_GROUP_STACK_OVERFLOW.try_with(|overflow| overflow.set(overflow.get() + 1));
        no_track(|| {
            tracing::warn!("maximum allocation group stack size reached");
        });
        return;
    }
    _ = THREAD_CURRENT_GROUP.try_with(|g| g.set(new_group));
}

/// Restores the allocation group that was active before the matching
/// [`enter_allocation_group`] call.
///
/// If that enter overflowed the group stack, it changed nothing, so this call
/// changes nothing either. An exit with no saved group falls back to the
/// untracked default group.
#[doc(hidden)]
pub fn exit_allocation_group() {
    let matched_overflowed_enter = THREAD_GROUP_STACK_OVERFLOW
        .try_with(|overflow| {
            let pending = overflow.get();
            if pending == 0 {
                return false;
            }
            overflow.set(pending - 1);
            true
        })
        .unwrap_or(false);
    if matched_overflowed_enter {
        return;
    }
    let last_group = THREAD_GROUP_STACK
        .try_with(|stack| {
            // SAFETY: same invariant as in `enter_allocation_group`; `pop`
            // neither allocates nor re-enters.
            let stack = unsafe { &mut *stack.get() };
            stack.pop()
        })
        .ok()
        .flatten()
        .unwrap_or_default();
    _ = THREAD_CURRENT_GROUP.try_with(|g| g.set(last_group));
}
fn current_allocation_group() -> AllocationGroup {
THREAD_CURRENT_GROUP.try_with(Cell::get).unwrap_or_default()
}
/// Hands this thread's buffered allocation statistics to the global tracker.
///
/// The buffer is emptied even when no tracker is installed; in that case the
/// drained statistics are discarded. Does nothing when the buffer is empty or
/// thread-local storage is no longer accessible.
pub fn flush_thread_statistics() {
    let Ok(pending) = THREAD_GROUP_STATISTICS.try_with(|cell| {
        // SAFETY: the reference is confined to this closure, and `mem::take` on
        // the fixed-capacity map does not allocate, so the allocator cannot
        // re-enter and alias it.
        let buffered = unsafe { &mut *cell.get() };
        mem::take(&mut buffered.groups)
    }) else {
        return;
    };
    if pending.is_empty() {
        return;
    }
    if let Some(tracker) = GLOBAL_TRACKER.get() {
        tracker.collect(pending);
    }
}
fn record_allocation(group: AllocationGroup, layout: Layout) {
_ = THREAD_GROUP_STATISTICS.try_with(|cell| {
let thread_stats = unsafe { &mut *cell.get() };
if let Some(stats) = thread_stats.groups.get_mut(&group) {
stats.allocation_count += 1;
stats.allocated_bytes += layout.size() as u64;
} else {
let mut stats = GroupAllocationStatistics::default();
stats.allocation_count += 1;
stats.allocated_bytes = layout.size() as u64;
if let Err((k, v)) = thread_stats.groups.insert(group, stats) {
thread_stats.flush_global();
thread_stats.groups.insert(k, v).unwrap();
}
}
});
}
fn record_deallocation(group: AllocationGroup, layout: Layout) {
_ = THREAD_GROUP_STATISTICS.try_with(|cell| {
let thread_stats = unsafe { &mut *cell.get() };
if let Some(stats) = thread_stats.groups.get_mut(&group) {
stats.deallocation_count += 1;
stats.deallocated_bytes += layout.size() as u64;
} else {
let mut stats = GroupAllocationStatistics::default();
stats.deallocation_count += 1;
stats.deallocated_bytes = layout.size() as u64;
if let Err((k, v)) = thread_stats.groups.insert(group, stats) {
thread_stats.flush_global();
thread_stats.groups.insert(k, v).unwrap();
}
}
});
}
pub(crate) fn no_track<F, T>(f: F) -> T
where
F: FnOnce() -> T,
{
let prev_group = THREAD_CURRENT_GROUP
.try_with(Cell::take)
.unwrap_or_default();
let result = f();
_ = THREAD_CURRENT_GROUP.try_with(|c| c.set(prev_group));
result
}
/// Global-allocator adapter that tags each allocation with the allocation
/// group active on the allocating thread.
///
/// Every request is forwarded to `inner` with extra room for a `u64` header
/// that stores the group id, so the matching free can be attributed correctly.
#[must_use]
pub struct TrackingAllocator<A = std::alloc::System> {
    inner: A,
}

impl<A> TrackingAllocator<A> {
    /// Wraps `inner`; usable in `static` / `#[global_allocator]` position.
    pub const fn new(inner: A) -> Self {
        Self { inner }
    }
}

impl TrackingAllocator<std::alloc::System> {
    /// Shorthand for wrapping the operating-system allocator.
    pub const fn system() -> Self {
        Self::new(std::alloc::System)
    }
}
// SAFETY: every block handed out is laid out by `get_wrapped_layout`:
// - a `u64` group header sits at the start of the block obtained from `inner`;
// - the caller's object follows at `offset_to_object`.
// The wrapped layout is derived deterministically from the object layout, so
// `dealloc`/`realloc` recompute exactly the layout used to allocate, and all
// requests are forwarded to `inner` with valid layouts.
unsafe impl<A: GlobalAlloc> GlobalAlloc for TrackingAllocator<A> {
    unsafe fn alloc(&self, object_layout: Layout) -> *mut u8 {
        let (wrapped_layout, offset_to_object) = get_wrapped_layout(object_layout);
        // SAFETY: `wrapped_layout` is never zero-sized, since it contains the header.
        let wrapped_ptr = unsafe { self.inner.alloc(wrapped_layout) };
        if wrapped_ptr.is_null() {
            return std::ptr::null_mut();
        }
        // The wrapped layout's alignment is at least `align_of::<u64>()`,
        // so the header slot at the block start is aligned.
        #[allow(clippy::cast_ptr_alignment)]
        let group_id_ptr = wrapped_ptr.cast::<u64>();
        let group = current_allocation_group();
        if group.should_track() {
            record_allocation(group, object_layout);
        }
        // SAFETY: the header slot is in bounds and aligned (see above).
        unsafe {
            group_id_ptr.write(group.0);
        }
        wrapped_ptr.wrapping_add(offset_to_object)
    }

    unsafe fn dealloc(&self, ptr: *mut u8, object_layout: Layout) {
        let (wrapped_layout, offset_to_object) = get_wrapped_layout(object_layout);
        // `ptr` came from this allocator with `object_layout` (caller contract),
        // so stepping back by the object offset yields the inner block start.
        let wrapped_ptr = ptr.wrapping_sub(offset_to_object);
        #[allow(clippy::cast_ptr_alignment)]
        let group_id_ptr = wrapped_ptr.cast::<u64>();
        // SAFETY: the header was written when the block was allocated.
        let group = AllocationGroup(unsafe { group_id_ptr.read() });
        // SAFETY: `wrapped_ptr`/`wrapped_layout` are exactly what `inner` returned/used.
        unsafe {
            self.inner.dealloc(wrapped_ptr, wrapped_layout);
        }
        if group.should_track() {
            record_deallocation(group, object_layout);
        }
    }

    unsafe fn alloc_zeroed(&self, object_layout: Layout) -> *mut u8 {
        let (wrapped_layout, offset_to_object) = get_wrapped_layout(object_layout);
        // SAFETY: `wrapped_layout` is never zero-sized, since it contains the header.
        let wrapped_ptr = unsafe { self.inner.alloc_zeroed(wrapped_layout) };
        if wrapped_ptr.is_null() {
            return std::ptr::null_mut();
        }
        #[allow(clippy::cast_ptr_alignment)]
        let group_id_ptr = wrapped_ptr.cast::<u64>();
        let group = current_allocation_group();
        if group.should_track() {
            record_allocation(group, object_layout);
        }
        // SAFETY: the header slot is in bounds and aligned; only it is
        // overwritten, so the object bytes stay zeroed.
        unsafe {
            group_id_ptr.write(group.0);
        }
        wrapped_ptr.wrapping_add(offset_to_object)
    }

    unsafe fn realloc(&self, ptr: *mut u8, object_layout: Layout, new_size: usize) -> *mut u8 {
        let (wrapped_layout, offset_to_object) = get_wrapped_layout(object_layout);
        // The caller guarantees this layout is valid. If it is not, fail with
        // null rather than panicking, because unwinding out of an allocator is UB.
        let Ok(new_object_layout) = Layout::from_size_align(new_size, object_layout.align())
        else {
            return std::ptr::null_mut();
        };
        // The header offset depends only on the object alignment, which is
        // unchanged. Derive the new inner size the same way `alloc` does rather
        // than adding `size_of::<u64>()`: that under-allocates whenever the
        // object offset exceeds 8 (alignment > 8) and skips trailing padding.
        let (new_wrapped_layout, _) = get_wrapped_layout(new_object_layout);

        let wrapped_ptr = ptr.wrapping_sub(offset_to_object);
        #[allow(clippy::cast_ptr_alignment)]
        // SAFETY: the header was written when the block was allocated.
        let group = AllocationGroup(unsafe { wrapped_ptr.cast::<u64>().read() });

        // SAFETY: `wrapped_ptr` was allocated by `inner` with `wrapped_layout`,
        // and the new size is non-zero and valid for that alignment.
        let new_wrapped_ptr =
            unsafe { self.inner.realloc(wrapped_ptr, wrapped_layout, new_wrapped_layout.size()) };
        if new_wrapped_ptr.is_null() {
            // The original block is untouched and still owned by the caller,
            // so nothing changed for accounting purposes.
            return std::ptr::null_mut();
        }
        if group.should_track() {
            // Record object sizes only (no header), matching `alloc`/`dealloc`.
            record_deallocation(group, object_layout);
            record_allocation(group, new_object_layout);
        }
        #[allow(clippy::cast_ptr_alignment)]
        let group_id_ptr = new_wrapped_ptr.cast::<u64>();
        // SAFETY: header slot of the new block; in bounds and aligned.
        unsafe {
            group_id_ptr.write(group.0);
        }
        new_wrapped_ptr.wrapping_add(offset_to_object)
    }
}
/// Computes the layout requested from the inner allocator for an object with
/// layout `layout`.
///
/// The wrapped block starts with a `u64` header, followed by padding so the
/// object keeps its own alignment.
///
/// Returns the wrapped layout (padded to a multiple of its alignment) and the
/// byte offset of the object inside the wrapped block. The offset depends only
/// on `layout.align()`.
///
/// If adding the header overflows the maximum layout size, the process aborts
/// instead of panicking: this runs inside `GlobalAlloc` methods, which must not
/// unwind.
fn get_wrapped_layout(layout: Layout) -> (Layout, usize) {
    const HEADER_LAYOUT: Layout = Layout::new::<u64>();
    let Ok((wrapped_layout, offset_to_object)) = HEADER_LAYOUT.extend(layout) else {
        std::process::abort();
    };
    (wrapped_layout.pad_to_align(), offset_to_object)
}