use std::{
cell::RefCell,
num::NonZeroUsize,
sync::atomic::{AtomicUsize, Ordering},
};
use crate::{stack::GroupStack, util::PhantomNotSend};
thread_local! {
    /// The thread-local stack of allocation groups.
    ///
    /// Whichever group is on top of this stack is the "active" allocation group for the
    /// current thread; allocations observed on this thread are attributed to it.
    pub(crate) static LOCAL_ALLOCATION_GROUP_STACK: RefCell<GroupStack> =
        RefCell::new(GroupStack::new());
}
/// Pushes `group` onto this thread's allocation group stack, making it the active group.
fn push_group_to_stack(group: AllocationGroupId) {
    LOCAL_ALLOCATION_GROUP_STACK.with(|cell| {
        let mut groups = cell.borrow_mut();
        groups.push(group);
    });
}
/// Pops and returns the most recently pushed allocation group on this thread's stack.
fn pop_group_from_stack() -> AllocationGroupId {
    LOCAL_ALLOCATION_GROUP_STACK.with(|cell| {
        let mut groups = cell.borrow_mut();
        groups.pop()
    })
}
/// The identifier of an allocation group.
///
/// Backed by `NonZeroUsize` so that `Option<AllocationGroupId>` is niche-optimized to the
/// size of a plain `usize`, and so that zero can never be a valid group ID.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AllocationGroupId(NonZeroUsize);

impl AllocationGroupId {
    /// Builds an `AllocationGroupId` from a raw `usize`.
    ///
    /// Returns `None` when `id` is zero, since zero is not a valid group identifier.
    pub(crate) fn from_raw(id: usize) -> Option<Self> {
        let raw = NonZeroUsize::new(id)?;
        Some(Self(raw))
    }
}
impl AllocationGroupId {
    /// The group ID for allocations that are not attributed to any registered group.
    // SAFETY: 1 is non-zero.
    pub const ROOT: Self = Self(unsafe { NonZeroUsize::new_unchecked(1) });

    /// Gets the integer value of this group ID.
    #[must_use]
    pub const fn as_usize(&self) -> NonZeroUsize {
        self.0
    }

    /// Registers a fresh, process-unique allocation group ID.
    ///
    /// Returns `None` once the global counter has wrapped around `usize::MAX`, so that a
    /// previously handed-out ID is never duplicated.
    fn register() -> Option<AllocationGroupId> {
        // Both counters start one past ROOT (1), so ROOT itself is never handed out again.
        static GROUP_ID: AtomicUsize = AtomicUsize::new(AllocationGroupId::ROOT.0.get() + 1);
        static HIGHEST_GROUP_ID: AtomicUsize =
            AtomicUsize::new(AllocationGroupId::ROOT.0.get() + 1);

        let group_id = GROUP_ID.fetch_add(1, Ordering::Relaxed);
        // Track the highest ID observed so far. If the ID we just took is *below* the
        // running maximum, `GROUP_ID` must have wrapped around `usize::MAX`, and handing
        // it out could duplicate an earlier ID — so we refuse.
        let highest_group_id = HIGHEST_GROUP_ID.fetch_max(group_id, Ordering::AcqRel);

        if group_id >= highest_group_id {
            // A wrapped-to-zero ID is rejected by the max check above, so a zero here
            // would indicate a logic bug rather than an expected condition.
            let group_id = NonZeroUsize::new(group_id).expect("bug: GROUP_ID overflowed");
            Some(AllocationGroupId(group_id))
        } else {
            None
        }
    }
}
/// A token controlling when its allocation group is active on the current thread.
pub struct AllocationGroupToken(AllocationGroupId);

impl AllocationGroupToken {
    /// Registers a token for a fresh allocation group.
    ///
    /// Returns `None` when the global group ID counter has been exhausted.
    pub fn register() -> Option<AllocationGroupToken> {
        let id = AllocationGroupId::register()?;
        Some(AllocationGroupToken(id))
    }

    /// Gets the ID associated with this token's allocation group.
    #[must_use]
    pub fn id(&self) -> AllocationGroupId {
        self.0.clone()
    }

    #[cfg(feature = "tracing-compat")]
    pub(crate) fn into_unsafe(self) -> UnsafeAllocationGroupToken {
        let AllocationGroupToken(id) = self;
        UnsafeAllocationGroupToken::new(id)
    }

    /// Makes this token's group the active allocation group on the current thread until
    /// the returned guard is exited or dropped.
    pub fn enter(&mut self) -> AllocationGuard<'_> {
        AllocationGuard::enter(self)
    }
}
#[cfg(feature = "tracing-compat")]
#[cfg_attr(docsrs, doc(cfg(feature = "tracing-compat")))]
impl AllocationGroupToken {
    /// Attaches this allocation group to a `tracing` span, consuming the token.
    ///
    /// This is a no-op unless the span has an ID and the current default dispatcher can be
    /// downcast to `WithAllocationGroup` — i.e. the crate's tracing integration layer is
    /// installed.
    pub fn attach_to_span(self, span: &tracing::Span) {
        use crate::tracing::WithAllocationGroup;

        // `get_default` takes an `FnMut` closure; wrapping the token in an `Option` and
        // `take()`-ing it lets us move it out even though the closure could in principle
        // be called more than once.
        let mut unsafe_token = Some(self.into_unsafe());

        tracing::dispatcher::get_default(move |dispatch| {
            if let Some(id) = span.id() {
                if let Some(ctx) = dispatch.downcast_ref::<WithAllocationGroup>() {
                    let unsafe_token = unsafe_token.take().expect("token already consumed");
                    ctx.with_allocation_group(dispatch, &id, unsafe_token);
                }
            }
        });
    }
}
/// Guard that keeps an allocation group active on the current thread.
///
/// The group is pushed onto the thread-local stack when the guard is created and popped
/// when the guard is exited or dropped.
pub struct AllocationGuard<'token> {
    token: &'token mut AllocationGroupToken,
    // Marker that presumably keeps the guard `!Send` — the group stack is thread-local,
    // so the guard must not migrate to another thread. (NOTE(review): confirm against
    // `util::PhantomNotSend`'s definition.)
    _ns: PhantomNotSend,
}
impl<'token> AllocationGuard<'token> {
    /// Pushes the token's group onto the thread-local stack, making it the active
    /// allocation group until the returned guard is exited or dropped.
    pub(crate) fn enter(token: &'token mut AllocationGroupToken) -> Self {
        push_group_to_stack(token.id());
        Self {
            token,
            _ns: PhantomNotSend::default(),
        }
    }

    /// Pops the group from the thread-local stack, asserting (debug builds only) that the
    /// popped group is the one this guard pushed.
    fn exit_inner(&mut self) {
        #[allow(unused_variables)]
        let current = pop_group_from_stack();
        debug_assert_eq!(
            current,
            self.token.id(),
            "popped group from stack but got unexpected group"
        );
    }

    /// Explicitly exits the allocation group.
    pub fn exit(mut self) {
        self.exit_inner();
        // BUG FIX: without this, `self` is dropped at the end of this method and
        // `Drop::drop` runs `exit_inner` a *second* time, popping the stack twice for a
        // single push (and corrupting attribution for the enclosing group). Forgetting
        // `self` skips the destructor so the pop happens exactly once. The guard owns no
        // resources — just a mutable borrow and a marker — so nothing is leaked.
        std::mem::forget(self);
    }
}
impl<'token> Drop for AllocationGuard<'token> {
    fn drop(&mut self) {
        // Implicit exit: when the guard goes out of scope, deactivate the group by
        // popping it from the thread-local stack.
        self.exit_inner();
    }
}
/// An allocation group token whose enter/exit calls are not scoped by a guard.
///
/// Used by the `tracing-compat` integration, which enters and exits the group from span
/// callbacks. "Unsafe" here means callers are responsible for keeping `enter`/`exit`
/// calls balanced — nothing enforces it.
#[cfg(feature = "tracing-compat")]
pub(crate) struct UnsafeAllocationGroupToken {
    id: AllocationGroupId,
}
#[cfg(feature = "tracing-compat")]
impl UnsafeAllocationGroupToken {
    /// Creates a token wrapping the given allocation group ID.
    pub fn new(id: AllocationGroupId) -> Self {
        UnsafeAllocationGroupToken { id }
    }

    /// Pushes this token's group onto the thread-local stack, making it the active group.
    pub fn enter(&mut self) {
        let group = self.id.clone();
        push_group_to_stack(group);
    }

    /// Pops the active group from the thread-local stack, asserting (debug builds only)
    /// that the popped group matches this token's group.
    pub fn exit(&mut self) {
        #[allow(unused_variables)]
        let popped = pop_group_from_stack();
        debug_assert_eq!(
            popped, self.id,
            "popped group from stack but got unexpected group"
        );
    }
}
/// Calls `f` with the currently active allocation group, holding the stack's mutable
/// borrow for the duration of the call.
///
/// Because the `RefCell` borrow is held while `f` runs, any reentrant attempt on this
/// thread to `try_borrow_mut` the stack (e.g. from a nested call into these helpers while
/// `f` allocates) fails and is skipped — effectively suspending group tracking inside `f`.
/// If the stack is unavailable — already borrowed, or the thread-local has been destroyed
/// during thread teardown (`try_with` returns `Err`) — `f` is simply not called.
#[inline(always)]
pub(crate) fn try_with_suspended_allocation_group<F>(f: F)
where
    F: FnOnce(AllocationGroupId),
{
    let _ = LOCAL_ALLOCATION_GROUP_STACK.try_with(
        #[inline(always)]
        |stack| {
            // Only proceed if nobody else holds the borrow; otherwise skip `f` entirely.
            if let Ok(stack) = stack.try_borrow_mut() {
                f(stack.current());
            }
        },
    );
}
/// Runs `f` with allocation group tracking suspended on the current thread.
///
/// Suspension works by taking (and holding) the mutable borrow of the thread-local group
/// stack across the call to `f`: while it is held, any reentrant `try_borrow_mut` — such
/// as the one in `try_with_suspended_allocation_group` — fails, so work done inside `f`
/// is not attributed to any group.
#[inline(always)]
pub(crate) fn with_suspended_allocation_group<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    LOCAL_ALLOCATION_GROUP_STACK.with(
        #[inline(always)]
        |stack| {
            // NOTE: the binding name `_result` (rather than `_`) is load-bearing — it
            // keeps the `RefMut` (if acquired) alive until `f` returns. Binding to `_`
            // would drop the borrow immediately and nothing would be suspended.
            let _result = stack.try_borrow_mut();
            f()
        },
    )
}