#![forbid(unsafe_code)]
use crate::internal_prelude::*;
pub use tor_memquota_cost::memory_cost::HasMemoryCost;
/// A [`Participation`] that only claims and releases costs typed by `T`.
///
/// Costs are supplied as anything implementing [`HasTypedMemoryCost`] for
/// this same `T`. A participation for one kind of item therefore cannot be
/// charged with the cost of another kind by accident.
///
/// Dereferences (immutably only) to the underlying untyped [`Participation`];
/// `as_raw` gives mutable access.
#[derive(Deref, Educe)]
#[educe(Clone)]
// Debug output omits field names and skips the marker.
#[educe(Debug(named_field = false))]
pub struct TypedParticipation<T> {
    /// The underlying untyped participation.
    #[deref]
    raw: Participation,
    /// Type-level tag for `T`; no `T` is stored.
    ///
    /// `fn(T)` is used rather than `T`. The struct then does not own a `T`,
    /// and its `Send`/`Sync` do not depend on `T`.
    #[educe(Debug(ignore))]
    marker: PhantomData<fn(T)>,
}
/// A raw memory cost (`usize`) tagged with the type `T` whose cost it is.
///
/// Obtained from [`HasTypedMemoryCost::typed_memory_cost`] or
/// `TypedMemoryCost::from_raw`. It can be claimed or released through a
/// `TypedParticipation<T>`. `Display` prints just the raw number.
#[derive(Educe, derive_more::Display)]
// Copy/Clone come from educe rather than std derive. Presumably this avoids
// requiring `T: Copy` / `T: Clone` — TODO confirm.
#[educe(Copy, Clone)]
// Debug output omits field names and skips the marker.
#[educe(Debug(named_field = false))]
#[display("{raw}")]
pub struct TypedMemoryCost<T> {
    /// The untyped cost value.
    raw: usize,
    /// Type-level tag for `T`; no `T` is stored.
    #[educe(Debug(ignore))]
    marker: PhantomData<fn(T)>,
}
/// Something whose memory cost can be expressed as a [`TypedMemoryCost<T>`].
///
/// It is implemented for:
/// * every `T: HasMemoryCost`, typed by `T` itself;
/// * `TypedMemoryCost<T>`, which reports itself.
pub trait HasTypedMemoryCost<T>: Sized {
    /// Returns the memory cost of `self`, typed by `T`.
    ///
    /// Requires an `EnabledToken`, presumably as proof that memory-quota
    /// tracking is enabled. Callers in this file obtain one from
    /// `EnabledToken::new_if_compiled_in`.
    fn typed_memory_cost(&self, _: EnabledToken) -> TypedMemoryCost<T>;
}
impl<T: HasMemoryCost> HasTypedMemoryCost<T> for T {
fn typed_memory_cost(&self, enabled: EnabledToken) -> TypedMemoryCost<T> {
TypedMemoryCost::from_raw(self.memory_cost(enabled))
}
}
impl<T> HasTypedMemoryCost<T> for TypedMemoryCost<T> {
fn typed_memory_cost(&self, _: EnabledToken) -> TypedMemoryCost<T> {
*self
}
}
impl<T> TypedParticipation<T> {
    /// Wraps an untyped [`Participation`] so that it only accepts costs typed by `T`.
    pub fn new(raw: Participation) -> Self {
        Self {
            raw,
            marker: PhantomData,
        }
    }

    /// Records that memory costing `t` is now in use.
    ///
    /// Does nothing, and succeeds, when `EnabledToken::new_if_compiled_in`
    /// yields `None`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying `Participation::claim` returns.
    pub fn claim(&mut self, t: &impl HasTypedMemoryCost<T>) -> Result<(), Error> {
        match EnabledToken::new_if_compiled_in() {
            None => Ok(()),
            Some(enabled) => {
                let amount = t.typed_memory_cost(enabled).into_raw();
                self.raw.claim(amount)
            }
        }
    }

    /// Records that memory costing `t` is no longer in use.
    ///
    /// Does nothing when `EnabledToken::new_if_compiled_in` yields `None`.
    pub fn release(&mut self, t: &impl HasTypedMemoryCost<T>) {
        if let Some(enabled) = EnabledToken::new_if_compiled_in() {
            let amount = t.typed_memory_cost(enabled).into_raw();
            self.raw.release(amount);
        }
    }

    /// Claims the cost of `item`, then passes `item` to `call`.
    ///
    /// Behaves like `try_claim_or_return`, except that when the claim itself
    /// fails, `item` is dropped and only the error is returned.
    ///
    /// The claim is kept only if `call` returns `Ok`. If `call` returns
    /// `Err` or panics, the claim is released (and any panic is resumed).
    ///
    /// # Errors
    ///
    /// The outer `Err` means the claim failed, and `call` was not invoked.
    /// The inner `Result` is whatever `call` returned.
    pub fn try_claim<C, F, E, R>(&mut self, item: C, call: F) -> Result<Result<R, E>, Error>
    where
        C: HasTypedMemoryCost<T>,
        F: FnOnce(C) -> Result<R, E>,
    {
        match self.try_claim_or_return(item, call) {
            Ok(outcome) => Ok(outcome),
            Err((error, _unclaimed_item)) => Err(error),
        }
    }

    /// Claims the cost of `item`, then passes `item` to `call`.
    /// If the claim fails, `item` is handed back to the caller.
    ///
    /// * When `EnabledToken::new_if_compiled_in` yields `None`, no claim is
    ///   made: `call(item)` runs and its result is returned in `Ok`.
    /// * Otherwise the cost of `item` is claimed first. If that fails,
    ///   `call` is not invoked and `Err((error, item))` is returned.
    /// * If `call` returns `Ok`, the claim is kept.
    /// * If `call` returns `Err`, the claim is released and the caller's
    ///   error is returned as `Ok(Err(..))`.
    /// * If `call` panics, the claim is released and the panic is resumed.
    pub fn try_claim_or_return<C, F, E, R>(
        &mut self,
        item: C,
        call: F,
    ) -> Result<Result<R, E>, (Error, C)>
    where
        C: HasTypedMemoryCost<T>,
        F: FnOnce(C) -> Result<R, E>,
    {
        let Some(enabled) = EnabledToken::new_if_compiled_in() else {
            // Nothing was claimed, so there is nothing to undo if `call` fails.
            return Ok(call(item));
        };

        let cost = item.typed_memory_cost(enabled);
        if let Err(error) = self.claim(&cost) {
            return Err((error, item));
        }

        // Unwind safety:
        // * `call` and `item` are moved into the closure and consumed there.
        // * `self` is not touched inside the closure.
        // * After a panic we only release the claim, then resume unwinding.
        let outcome = catch_unwind(AssertUnwindSafe(move || call(item)));
        match outcome {
            Ok(Ok(value)) => Ok(Ok(value)),
            Ok(Err(caller_error)) => {
                self.release(&cost);
                Ok(Err(caller_error))
            }
            Err(panic_payload) => {
                self.release(&cost);
                std::panic::resume_unwind(panic_payload)
            }
        }
    }

    /// Mutable access to the underlying untyped [`Participation`].
    pub fn as_raw(&mut self) -> &mut Participation {
        &mut self.raw
    }

    /// Discards the type tag, returning the underlying untyped [`Participation`].
    pub fn into_raw(self) -> Participation {
        let Self { raw, marker: _ } = self;
        raw
    }
}
impl<T> From<Participation> for TypedParticipation<T> {
fn from(untyped: Participation) -> TypedParticipation<T> {
TypedParticipation::new(untyped)
}
}
impl<T> TypedMemoryCost<T> {
pub fn from_raw(raw: usize) -> Self {
TypedMemoryCost {
raw,
marker: PhantomData,
}
}
pub fn into_raw(self) -> usize {
self.raw
}
}
#[cfg(all(test, feature = "memquota", not(miri) /* coarsetime */))]
mod test {
    // Lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)]
    #![allow(clippy::arithmetic_side_effects)]
    use super::*;
    use crate::mtracker::test::*;
    use crate::mtracker::*;
    use tor_rtmock::MockRuntime;

    /// A participant with nothing to report and nothing to reclaim.
    ///
    /// `reclaim` panics. Presumably this test never expects the tracker to
    /// ask it to reclaim memory.
    #[derive(Debug)]
    struct DummyParticipant;
    impl IsParticipant for DummyParticipant {
        fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
            None
        }
        fn reclaim(self: Arc<Self>, _: EnabledToken) -> ReclaimFuture {
            panic!()
        }
    }

    /// Test item whose cost is `mbytes(1)` less than `TEST_DEFAULT_LIMIT`.
    struct Costed;
    impl HasMemoryCost for Costed {
        fn memory_cost(&self, _: EnabledToken) -> usize {
            TEST_DEFAULT_LIMIT - mbytes(1)
        }
    }

    /// Exercises the `TypedParticipation` API end to end on a mock runtime.
    #[test]
    fn api() {
        MockRuntime::test_with_various(|rt| async move {
            let trk = mk_tracker(&rt);
            let acct = trk.new_account(None).unwrap();
            let particip = Arc::new(DummyParticipant);
            let partn = acct
                .register_participant(Arc::downgrade(&particip) as _)
                .unwrap();
            let mut partn: TypedParticipation<Costed> = partn.into();

            // Claim and release using the item itself...
            partn.claim(&Costed).unwrap();
            partn.release(&Costed);
            // ...and using a precomputed typed cost.
            let cost = Costed.typed_memory_cost(EnabledToken::new());
            partn.claim(&cost).unwrap();
            partn.release(&cost);

            // Callback error: the claim succeeds (outer `Ok`), and the
            // caller's error is passed through (inner `Err`).
            partn
                .try_claim(Costed, |_: Costed| Err::<Void, _>(()))
                .unwrap()
                .unwrap_err();

            // Callback panic: `try_claim` must propagate that same panic.
            // The fallback `panic!` below only runs if `try_claim` returned
            // normally. Its payload is a formatted `String`, whereas the
            // callback's payload is a `&'static str`, so checking the payload
            // tells the two cases apart. Without this check the outer
            // `unwrap_err` would pass either way.
            let payload = catch_unwind(AssertUnwindSafe(|| {
                let didnt_panic = partn.try_claim(Costed, |_: Costed| -> Result<Void, Void> {
                    panic!("callback panicked")
                });
                panic!("{:?}", didnt_panic);
            }))
            .unwrap_err();
            assert_eq!(
                payload.downcast_ref::<&str>(),
                Some(&"callback panicked"),
                "try_claim did not propagate the callback's panic"
            );

            // Callback success: the claim is kept, so the tracker reports
            // nonzero usage until we release it.
            let did_claim = partn
                .try_claim(Costed, |c: Costed| Ok::<Costed, Void>(c))
                .unwrap()
                .void_unwrap();
            assert!(trk.used_current_approx().unwrap() > 0);
            partn.release(&did_claim);

            // With the account, participant and tracker gone, the claim must
            // fail (outer `Err`) before the callback, which would panic, runs.
            drop(acct);
            drop(particip);
            drop(trk);
            partn
                .try_claim(Costed, |_| -> Result<Void, Void> { panic!() })
                .unwrap_err();
            rt.advance_until_stalled().await;
        });
    }
}