#[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
use std::alloc::GlobalAlloc;
#[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
use std::alloc::Layout;
use std::cell::Cell;
use std::ffi::CStr;
use std::future::Future;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
#[cfg(feature = "dhat-heap")]
use parking_lot::Mutex;
// `DHAT_AD_HOC_PROFILER` also needs `Mutex` when only `dhat-ad-hoc` is
// enabled; without this import that configuration fails to compile.
#[cfg(all(feature = "dhat-ad-hoc", not(feature = "dhat-heap")))]
use parking_lot::Mutex;
/// A one-shot allocation budget attached to a tree of `AllocationStats`.
///
/// When tracked allocations exceed `bytes`, `on_exceeded` is invoked once with
/// the byte count observed at that moment; `exceeded` latches so the callback
/// does not fire again.
#[allow(dead_code)] pub(crate) struct AllocationLimit {
    /// Budget in bytes that triggers the callback when exceeded.
    bytes: usize,
    /// Invoked (from inside the allocator) with the current allocated byte
    /// count the first time the budget is exceeded.
    on_exceeded: Box<dyn Fn(usize) + Send + Sync>,
    /// Latch ensuring the callback fires at most once.
    exceeded: AtomicBool,
}
impl std::fmt::Debug for AllocationLimit {
    /// Formats the limit for diagnostics, omitting `on_exceeded` (closures are
    /// not `Debug`).
    ///
    /// Uses `debug_struct` rather than a hand-rolled `write!` so the output
    /// also respects the `{:#?}` alternate (pretty, multiline) flag; the
    /// non-alternate output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AllocationLimit")
            .field("bytes", &self.bytes)
            .field("exceeded", &self.exceeded.load(Ordering::Relaxed))
            .finish()
    }
}
/// Per-scope allocation counters forming a tree: every update to a node is
/// also applied to all of its ancestors (see the `track_*` methods), so a
/// parent's counters aggregate its whole subtree.
#[derive(Debug)]
#[allow(dead_code)] pub(crate) struct AllocationStats {
    /// Human-readable scope name, used by tests and diagnostics.
    name: &'static str,
    /// Parent scope, if any; counter updates propagate up this chain.
    parent: Option<Arc<AllocationStats>>,
    /// Bytes obtained via `alloc`.
    bytes_allocated: AtomicUsize,
    /// Bytes released via `dealloc`.
    bytes_deallocated: AtomicUsize,
    /// Bytes obtained via `alloc_zeroed` (counted separately from `bytes_allocated`).
    bytes_zeroed: AtomicUsize,
    /// Bytes requested via `realloc` (gross new sizes, not deltas).
    bytes_reallocated: AtomicUsize,
    /// Shared with the whole tree (children clone the parent's Arc), so a
    /// limit set anywhere applies tree-wide.
    allocation_limit: Arc<OnceLock<AllocationLimit>>,
}
impl AllocationStats {
    /// Creates a root stats node with all counters at zero and no limit installed.
    fn new(name: &'static str) -> Self {
        Self {
            name,
            parent: None,
            bytes_allocated: AtomicUsize::new(0),
            bytes_deallocated: AtomicUsize::new(0),
            bytes_zeroed: AtomicUsize::new(0),
            bytes_reallocated: AtomicUsize::new(0),
            allocation_limit: Arc::new(OnceLock::new()),
        }
    }

    /// Creates a child node under `parent`. Counters start at zero, but the
    /// allocation limit slot is shared (Arc-cloned) with the parent so a limit
    /// set anywhere in the tree applies everywhere.
    fn with_parent(name: &'static str, parent: Arc<AllocationStats>) -> Self {
        let allocation_limit = parent.allocation_limit.clone();
        Self {
            name,
            parent: Some(parent),
            bytes_allocated: AtomicUsize::new(0),
            bytes_deallocated: AtomicUsize::new(0),
            bytes_zeroed: AtomicUsize::new(0),
            bytes_reallocated: AtomicUsize::new(0),
            allocation_limit,
        }
    }

    /// Scope name this node was created with.
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    pub(crate) fn name(&self) -> &'static str {
        self.name
    }

    /// Adds `size` to the counter selected by `counter` on `self` and on every
    /// ancestor, so parents aggregate their subtree. Shared by the four
    /// `track_*` methods, which previously duplicated this walk verbatim.
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    fn track(&self, size: usize, counter: for<'a> fn(&'a AllocationStats) -> &'a AtomicUsize) {
        let mut current = Some(self);
        while let Some(stats) = current {
            counter(stats).fetch_add(size, Ordering::Relaxed);
            current = stats.parent.as_deref();
        }
    }

    /// Records `size` bytes allocated (propagates to ancestors).
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    fn track_alloc(&self, size: usize) {
        self.track(size, |s| &s.bytes_allocated);
    }

    /// Records `size` bytes deallocated (propagates to ancestors).
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    fn track_dealloc(&self, size: usize) {
        self.track(size, |s| &s.bytes_deallocated);
    }

    /// Records `size` bytes allocated zeroed (propagates to ancestors).
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    fn track_zeroed(&self, size: usize) {
        self.track(size, |s| &s.bytes_zeroed);
    }

    /// Records a reallocation to `size` bytes (propagates to ancestors).
    /// NOTE: this records the full new size, not the delta.
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    fn track_realloc(&self, size: usize) {
        self.track(size, |s| &s.bytes_reallocated);
    }

    /// Total bytes allocated in this scope (including descendants).
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    pub(crate) fn bytes_allocated(&self) -> usize {
        self.bytes_allocated.load(Ordering::Relaxed)
    }

    /// Total bytes deallocated in this scope (including descendants).
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    pub(crate) fn bytes_deallocated(&self) -> usize {
        self.bytes_deallocated.load(Ordering::Relaxed)
    }

    /// Total bytes zero-allocated in this scope (including descendants).
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    pub(crate) fn bytes_zeroed(&self) -> usize {
        self.bytes_zeroed.load(Ordering::Relaxed)
    }

    /// Total bytes reallocated in this scope (including descendants).
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
    pub(crate) fn bytes_reallocated(&self) -> usize {
        self.bytes_reallocated.load(Ordering::Relaxed)
    }

    /// Bytes still outstanding: allocated + zeroed - deallocated, saturating
    /// so concurrent counter skew cannot underflow.
    #[inline]
    #[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix, test))]
    pub(crate) fn net_allocated(&self) -> usize {
        let allocated = self.bytes_allocated();
        let zeroed = self.bytes_zeroed();
        let deallocated = self.bytes_deallocated();
        allocated.saturating_add(zeroed).saturating_sub(deallocated)
    }

    /// Installs an allocation limit for the whole stats tree.
    ///
    /// The limit lives in a shared `OnceLock`, so only the first call anywhere
    /// in the tree takes effect; later calls are silently ignored (the `set`
    /// result is deliberately discarded).
    pub(crate) fn set_allocation_limit(
        &self,
        bytes: usize,
        on_exceeded: Box<dyn Fn(usize) + Send + Sync>,
    ) {
        let _ = self.allocation_limit.set(AllocationLimit {
            bytes,
            on_exceeded,
            exceeded: AtomicBool::new(false),
        });
    }
}
thread_local! {
    // Raw pointer to the `AllocationStats` of the tracking scope currently
    // active on this thread (`None` outside any scope). Stored as a `NonNull`
    // rather than an `Arc` so the global allocator can read it without touching
    // the refcount; the scope that installed it holds an `Arc` that keeps the
    // pointee alive while the pointer is set.
    static CURRENT_TASK_STATS: Cell<Option<NonNull<AllocationStats>>> = const { Cell::new(None) };
}
/// Future wrapper that installs `stats` into `CURRENT_TASK_STATS` for the
/// duration of every poll of `inner`, so allocations made while the future
/// runs are attributed to this scope.
pub(crate) struct MemoryTrackedFuture<F> {
    /// The wrapped future.
    inner: F,
    /// Stats scope kept alive by this wrapper; the thread-local holds a raw
    /// pointer into this Arc while `inner` is being polled.
    stats: Arc<AllocationStats>,
}
impl<F: Future> Future for MemoryTrackedFuture<F> {
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
let stats_ptr = unsafe { NonNull::new_unchecked(Arc::as_ptr(&this.stats) as *mut _) };
let previous = CURRENT_TASK_STATS.with(|cell| cell.replace(Some(stats_ptr)));
let result = inner.poll(cx);
CURRENT_TASK_STATS.with(|cell| cell.set(previous));
result
}
}
#[must_use]
pub(crate) fn current() -> Option<Arc<AllocationStats>> {
CURRENT_TASK_STATS.with(|cell| {
cell.get().map(|ptr| {
unsafe {
Arc::increment_strong_count(ptr.as_ptr());
Arc::from_raw(ptr.as_ptr())
}
})
})
}
/// Runs `f` inside a memory-tracking scope named `name`.
///
/// If a scope is already active on this thread, the new scope is created as
/// its child (so allocations also roll up to the parent); otherwise a fresh
/// root scope is created.
#[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix, test))]
pub(crate) fn with_memory_tracking<F, R>(name: &'static str, f: F) -> R
where
    F: FnOnce() -> R,
{
    // `current()` performs the same increment-strong-count upgrade of the
    // thread-local pointer that was previously open-coded here.
    let stats = match current() {
        Some(parent) => Arc::new(AllocationStats::with_parent(name, parent)),
        None => Arc::new(AllocationStats::new(name)),
    };
    with_explicit_memory_tracking(stats, f)
}
/// Runs `f` inside a memory-tracking scope named `name` that is a child of
/// the explicitly supplied `parent` (used to carry a scope across threads).
pub(crate) fn with_parented_memory_tracking<F, R>(
    name: &'static str,
    parent: Arc<AllocationStats>,
    f: F,
) -> R
where
    F: FnOnce() -> R,
{
    with_explicit_memory_tracking(Arc::new(AllocationStats::with_parent(name, parent)), f)
}
/// Installs `stats` into the thread-local for the duration of `f`, restoring
/// the previous value afterwards.
fn with_explicit_memory_tracking<F, R>(stats: Arc<AllocationStats>, f: F) -> R
where
    F: FnOnce() -> R,
{
    // Restores the previous thread-local pointer on drop. Using a guard fixes
    // a panic-safety bug: if `f` panics and the panic is caught upstream, the
    // thread-local would otherwise keep pointing at `stats` after its `Arc`
    // is dropped — a use-after-free inside the allocator.
    struct Restore(Option<NonNull<AllocationStats>>);
    impl Drop for Restore {
        fn drop(&mut self) {
            CURRENT_TASK_STATS.with(|cell| cell.set(self.0));
        }
    }

    // SAFETY: `Arc::as_ptr` never yields null, and the local `stats` Arc keeps
    // the pointee alive until after `_restore` has reset the thread-local
    // (locals drop in reverse order, and parameters drop after locals).
    let stats_ptr = unsafe { NonNull::new_unchecked(Arc::as_ptr(&stats) as *mut _) };
    let previous = CURRENT_TASK_STATS.with(|cell| cell.replace(Some(stats_ptr)));
    let _restore = Restore(previous);
    f()
}
/// Extension trait attaching per-task allocation tracking to any future.
pub(crate) trait WithMemoryTracking: Future + Sized {
    /// Wraps `self` so allocations made while it is polled are attributed to
    /// a stats scope named `name` (parented under any scope active at the
    /// time of this call).
    fn with_memory_tracking(self, name: &'static str) -> MemoryTrackedFuture<Self>;
}
impl<F: Future> WithMemoryTracking for F {
fn with_memory_tracking(self, name: &'static str) -> MemoryTrackedFuture<Self> {
let stats = CURRENT_TASK_STATS.with(|cell| {
cell.get().map_or_else(
|| Arc::new(AllocationStats::new(name)),
|ptr| {
let parent = unsafe {
Arc::increment_strong_count(ptr.as_ptr());
Arc::from_raw(ptr.as_ptr())
};
Arc::new(AllocationStats::with_parent(name, parent))
},
)
});
MemoryTrackedFuture { inner: self, stats }
}
}
/// Global allocator that delegates to jemalloc and records every successful
/// allocation/deallocation against the thread's active `AllocationStats`
/// scope, if one is installed.
#[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
struct CustomAllocator {
    /// The real allocator all requests are forwarded to.
    inner: tikv_jemallocator::Jemalloc,
}
#[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
impl CustomAllocator {
    /// Creates an allocator delegating to jemalloc; `const` so it can
    /// initialize the `#[global_allocator]` static.
    #[allow(dead_code)] // only referenced when registered as the global allocator
    const fn new() -> Self {
        Self { inner: tikv_jemallocator::Jemalloc }
    }
}
#[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
unsafe impl GlobalAlloc for CustomAllocator {
    /// Allocates via jemalloc; on success, attributes the bytes to the active
    /// scope and fires the scope's allocation-limit callback (at most once)
    /// when the budget is exceeded.
    #[inline]
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        unsafe {
            let ptr = self.inner.alloc(layout);
            if !ptr.is_null() {
                CURRENT_TASK_STATS.with(|cell| {
                    if let Some(stats_ptr) = cell.get() {
                        // SAFETY: the installing scope holds an Arc that keeps
                        // the pointee alive while the thread-local is set.
                        let stats = stats_ptr.as_ref();
                        stats.track_alloc(layout.size());
                        if let Some(limit) = stats.allocation_limit.get() {
                            let bytes_allocated = stats.bytes_allocated();
                            // compare_exchange latches `exceeded` atomically so
                            // the callback fires exactly once even under
                            // concurrent allocation (the previous load/store
                            // pair allowed two threads to both fire it).
                            // Setting the flag before the call also stops
                            // re-entry if the callback itself allocates.
                            if bytes_allocated > limit.bytes
                                && limit
                                    .exceeded
                                    .compare_exchange(
                                        false,
                                        true,
                                        Ordering::Relaxed,
                                        Ordering::Relaxed,
                                    )
                                    .is_ok()
                            {
                                (limit.on_exceeded)(bytes_allocated);
                            }
                        }
                    }
                });
            }
            ptr
        }
    }

    /// Deallocates via jemalloc and attributes the freed bytes to the active scope.
    #[inline]
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        unsafe {
            self.inner.dealloc(ptr, layout);
            CURRENT_TASK_STATS.with(|cell| {
                if let Some(stats_ptr) = cell.get() {
                    // SAFETY: see `alloc` — the scope's Arc keeps this alive.
                    stats_ptr.as_ref().track_dealloc(layout.size());
                }
            });
        }
    }

    /// Zero-allocates via jemalloc; successful allocations are counted in the
    /// separate `bytes_zeroed` counter (no limit check, matching `dealloc`/`realloc`).
    #[inline]
    unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
        unsafe {
            let ptr = self.inner.alloc_zeroed(layout);
            if !ptr.is_null() {
                CURRENT_TASK_STATS.with(|cell| {
                    if let Some(stats_ptr) = cell.get() {
                        // SAFETY: see `alloc` — the scope's Arc keeps this alive.
                        stats_ptr.as_ref().track_zeroed(layout.size());
                    }
                });
            }
            ptr
        }
    }

    /// Reallocates via jemalloc; on success records the full `new_size`
    /// (not the delta) in `bytes_reallocated`.
    #[inline]
    unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
        unsafe {
            let new_ptr = self.inner.realloc(ptr, layout, new_size);
            if !new_ptr.is_null() {
                CURRENT_TASK_STATS.with(|cell| {
                    if let Some(stats_ptr) = cell.get() {
                        // SAFETY: see `alloc` — the scope's Arc keeps this alive.
                        stats_ptr.as_ref().track_realloc(new_size);
                    }
                });
            }
            new_ptr
        }
    }
}
// Install the tracking allocator as the process-wide global allocator
// (only when jemalloc tracking is enabled and dhat profiling is not).
#[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
#[global_allocator]
static ALLOC: CustomAllocator = CustomAllocator::new();
// Under dhat heap profiling, dhat's allocator replaces the jemalloc-backed one.
#[cfg(feature = "dhat-heap")]
#[global_allocator]
pub(crate) static ALLOC: dhat::Alloc = dhat::Alloc;
// Process-wide profiler handles, wrapped in a Mutex so the atexit handlers
// below can take and drop them exactly once.
#[cfg(feature = "dhat-heap")]
pub(crate) static DHAT_HEAP_PROFILER: Mutex<Option<dhat::Profiler>> = Mutex::new(None);
#[cfg(feature = "dhat-ad-hoc")]
pub(crate) static DHAT_AD_HOC_PROFILER: Mutex<Option<dhat::Profiler>> = Mutex::new(None);
/// Installs a dhat heap profiler and registers an atexit handler so its
/// report is written when the process exits.
#[cfg(feature = "dhat-heap")]
pub(crate) fn create_heap_profiler() {
    let mut slot = DHAT_HEAP_PROFILER.lock();
    *slot = Some(dhat::Profiler::new_heap());
    drop(slot);
    println!("heap profiler installed");
    // Dropping the profiler at exit flushes its report.
    unsafe { libc::atexit(drop_heap_profiler) };
}
/// atexit handler: drops the heap profiler (writing its report) if one was
/// installed; a no-op otherwise.
#[cfg(feature = "dhat-heap")]
#[unsafe(no_mangle)]
extern "C" fn drop_heap_profiler() {
    // Assigning `None` drops any installed profiler in place.
    *DHAT_HEAP_PROFILER.lock() = None;
}
/// Installs a dhat ad-hoc profiler and registers an atexit handler so its
/// report is written when the process exits.
#[cfg(feature = "dhat-ad-hoc")]
pub(crate) fn create_ad_hoc_profiler() {
    // Bug fix: this previously called `Profiler::new_heap()`, so the "ad-hoc
    // profiler" was actually a second heap profiler. Ad-hoc profiling
    // (`dhat::ad_hoc_event`) requires `new_ad_hoc()`.
    *DHAT_AD_HOC_PROFILER.lock() = Some(dhat::Profiler::new_ad_hoc());
    println!("ad-hoc profiler installed");
    // Dropping the profiler at exit flushes its report.
    unsafe { libc::atexit(drop_ad_hoc_profiler) };
}
/// atexit handler: drops the ad-hoc profiler (writing its report) if one was
/// installed; a no-op otherwise.
#[cfg(feature = "dhat-ad-hoc")]
#[unsafe(no_mangle)]
extern "C" fn drop_ad_hoc_profiler() {
    // Assigning `None` drops any installed profiler in place.
    *DHAT_AD_HOC_PROFILER.lock() = None;
}
// jemalloc reads the exported `_rjem_malloc_conf` symbol at startup: here it
// enables the profiling machinery but leaves it inactive until toggled.
#[allow(non_upper_case_globals)]
#[unsafe(export_name = "_rjem_malloc_conf")]
static malloc_conf: Option<&'static libc::c_char> = Some(unsafe {
    // SAFETY: the C string literal has static storage, so a reference to its
    // first byte is valid for 'static. jemalloc treats the referenced byte as
    // the start of a NUL-terminated configuration string.
    let data: &'static CStr = c"prof:true,prof_active:false";
    let ptr: *const libc::c_char = data.as_ptr();
    let output: &'static libc::c_char = &*ptr;
    output
});
#[cfg(test)]
#[cfg(all(feature = "global-allocator", not(feature = "dhat-heap"), unix))]
mod tests {
    use std::ffi::CStr;
    use std::thread;
    use tokio::task;
    use super::*;

    // Sanity-checks the exported jemalloc config string: valid UTF-8, expected
    // contents, NUL-terminated.
    #[test]
    fn test_malloc_conf_is_valid_c_string() {
        if let Some(conf_ptr) = malloc_conf {
            let c_str = unsafe { CStr::from_ptr(conf_ptr) };
            let rust_str = c_str.to_str().expect("malloc_conf should be valid UTF-8");
            assert_eq!(rust_str, "prof:true,prof_active:false");
            let bytes = c_str.to_bytes_with_nul();
            assert!(
                bytes.ends_with(&[0u8]),
                "C string should be null-terminated"
            );
            assert_eq!(bytes, b"prof:true,prof_active:false\0");
        } else {
            panic!("malloc_conf should not be None");
        }
    }

    // A tracked async block records its own allocations and nets out near zero
    // once its temporaries are dropped.
    #[tokio::test]
    async fn test_async_memory_tracking() {
        let result = async {
            let _v = Vec::<u8>::with_capacity(10000);
            current().expect("stats should be set")
        }
        .with_memory_tracking("test")
        .await;
        assert_eq!(result.name(), "test");
        assert!(
            result.bytes_allocated() >= 10000,
            "should track at least 10000 bytes, got {}",
            result.bytes_allocated()
        );
        assert!(
            result.net_allocated() < 100,
            "net allocated should be near 0 after Vec is dropped, got {}",
            result.net_allocated()
        );
    }

    // A spawned child task's allocations roll up into the parent scope because
    // the child's stats node is parented under the parent's.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_spawned_task_memory_tracking() {
        async {
            let parent_stats = current().expect("stats should be set in parent");
            assert_eq!(parent_stats.name(), "parent");
            let child_future = async {
                let child_stats = current().expect("stats should be set in child");
                assert_eq!(child_stats.name(), "child");
                let _v = Vec::<u8>::with_capacity(5000);
            }
            .with_memory_tracking("child");
            task::spawn(child_future).await.unwrap();
            let final_stats = current().expect("stats should still be set");
            assert!(
                final_stats.bytes_allocated() >= 5000,
                "child task allocations should be tracked in parent, got {}",
                final_stats.bytes_allocated()
            );
            assert!(
                Arc::ptr_eq(&parent_stats, &final_stats),
                "should be the same Arc"
            );
        }
        .with_memory_tracking("parent")
        .await;
    }

    // Synchronous tracking scopes, including carrying a parent scope onto a
    // new OS thread via `with_parented_memory_tracking`.
    #[test]
    fn test_sync_memory_tracking() {
        let stats = with_memory_tracking("sync_test", || {
            let stats = current().expect("stats should be set");
            assert_eq!(stats.name(), "sync_test");
            {
                let _v = Vec::<u8>::with_capacity(8000);
                assert!(
                    stats.bytes_allocated() >= 8000,
                    "should track at least 8000 bytes, got {}",
                    stats.bytes_allocated()
                );
            }
            assert!(
                stats.net_allocated() < 100,
                "net allocated should be near 0 after Vec is dropped, got {}",
                stats.net_allocated()
            );
            let first_allocated = stats.bytes_allocated();
            let parent_stats = stats.clone();
            let handle = thread::spawn(move || {
                with_parented_memory_tracking("sync_test_child", parent_stats, || {
                    let child_stats = current().expect("child stats should be set");
                    assert_eq!(child_stats.name(), "sync_test_child");
                    let _v = Vec::<u8>::with_capacity(3000);
                })
            });
            handle.join().unwrap();
            assert!(
                stats.bytes_allocated() >= first_allocated + 3000,
                "should track allocations from both contexts, got {} (expected at least {})",
                stats.bytes_allocated(),
                first_allocated + 3000
            );
            stats
        });
        assert!(
            stats.net_allocated() < 200,
            "net allocated should be near 0 after all Vecs are dropped, got {}",
            stats.net_allocated()
        );
    }

    // Three-level nesting: each level sees its own subtree's allocations, and
    // ancestors aggregate their descendants'.
    #[tokio::test]
    async fn test_nested_memory_tracking() {
        async {
            let root_stats = current().expect("root stats should be set");
            assert_eq!(root_stats.name(), "root");
            let _root_vec = Vec::<u8>::with_capacity(1000);
            async {
                let child_stats = current().expect("child stats should be set");
                assert_eq!(child_stats.name(), "child");
                let _child_vec = Vec::<u8>::with_capacity(2000);
                assert!(
                    child_stats.bytes_allocated() >= 2000,
                    "child should track its own allocations, got {}",
                    child_stats.bytes_allocated()
                );
                async {
                    let grandchild_stats = current().expect("grandchild stats should be set");
                    assert_eq!(grandchild_stats.name(), "grandchild");
                    let _grandchild_vec = Vec::<u8>::with_capacity(3000);
                    assert!(
                        grandchild_stats.bytes_allocated() >= 3000,
                        "grandchild should track its own allocations, got {}",
                        grandchild_stats.bytes_allocated()
                    );
                }
                .with_memory_tracking("grandchild")
                .await;
                assert!(
                    child_stats.bytes_allocated() >= 5000,
                    "child should track child + grandchild allocations, got {}",
                    child_stats.bytes_allocated()
                );
            }
            .with_memory_tracking("child")
            .await;
            assert!(
                root_stats.bytes_allocated() >= 6000,
                "root should track root + child + grandchild allocations, got {}",
                root_stats.bytes_allocated()
            );
        }
        .with_memory_tracking("root")
        .await;
    }

    // Dropping a Vec inside the scope is recorded as a deallocation.
    // NOTE(review): the exact-equality assertion assumes jemalloc services a
    // 1000-byte request with exactly 1000 bytes — holds here, but brittle if
    // the allocator changes.
    #[test]
    fn test_dealloc_tracking() {
        let stats = with_memory_tracking("dealloc_test", || {
            let _v = Vec::<u8>::with_capacity(1000);
            current().expect("stats should be set")
        });
        assert_eq!(stats.bytes_deallocated(), 1000);
    }

    // `alloc_zeroed` is counted in the separate zeroed counter.
    #[test]
    fn test_zeroed_tracking() {
        let stats = with_memory_tracking("zeroed_test", || {
            unsafe {
                let layout = Layout::new::<u64>();
                let ptr = std::alloc::alloc_zeroed(layout);
                std::alloc::dealloc(ptr, layout);
            }
            current().expect("stats should be set")
        });
        assert_eq!(stats.bytes_zeroed(), 8);
    }

    // `realloc` records the full new size (8 * 4 = 32 bytes), not the delta.
    #[test]
    fn test_realloc_tracking() {
        let stats = with_memory_tracking("realloc_test", || {
            let layout = Layout::array::<u32>(4).unwrap();
            unsafe {
                let ptr = std::alloc::alloc(layout);
                let new_size = 8 * std::mem::size_of::<u32>();
                let new_ptr = std::alloc::realloc(ptr, layout, new_size);
                let final_layout = Layout::from_size_align(new_size, layout.align()).unwrap();
                std::alloc::dealloc(new_ptr, final_layout);
            }
            current().expect("stats should be set")
        });
        assert_eq!(stats.bytes_reallocated(), 32);
    }
}