#![cfg(feature = "reentrancy")]
use crate::memory_map::*;
use crate::types::*;
use crate::utils::*;
use candid::Principal;
use ic_stable_structures::StableBTreeMap;
use std::cell::RefCell;
/// Guard that marks the current caller as "inside" a guarded call.
///
/// Constructing it (via [`ReentrancyGuard::new`] or `Default`) records the
/// caller in `REENTRANCY_GUARD_MAP`. Dropping it removes that entry again.
/// The protection therefore lasts exactly as long as the guard value is alive.
#[must_use = "the reentrancy lock is released when the guard is dropped; bind it to a named variable such as `_guard` (not `_`)"]
pub struct ReentrancyGuard {
    /// Principal whose entry is removed from the map on drop.
    caller: Principal,
}
thread_local! {
    /// Set of callers that currently hold a [`ReentrancyGuard`].
    ///
    /// The values are `()`, so the map is used purely as a set of
    /// `StablePrincipal` keys. It is backed by the memory region identified by
    /// `REENTRANCY_GUARD_MEM_ID`, obtained from `MEMORY_MANAGER`.
    static REENTRANCY_GUARD_MAP: RefCell<StableBTreeMap<StablePrincipal, (), VM>> =
        MEMORY_MANAGER.with(|manager| {
            let memory = manager.borrow().get(REENTRANCY_GUARD_MEM_ID);
            RefCell::new(StableBTreeMap::init(memory))
        });
}
impl ReentrancyGuard {
    /// Acquires the reentrancy lock for the current caller.
    ///
    /// The caller is obtained from `canister_caller()` and inserted into
    /// `REENTRANCY_GUARD_MAP`. The entry is removed again when the returned
    /// guard is dropped.
    ///
    /// # Panics
    ///
    /// Calls `ic_cdk::trap("ReentrancyGuard: reentrant call")` if the same
    /// caller already holds a guard. Outside a canister, `ic_cdk::trap`
    /// itself panics.
    pub fn new() -> Self {
        let caller = canister_caller();
        // A single `insert` both records the caller and reports (via the
        // previous value) whether it was already present. This avoids a
        // separate `contains_key` traversal of the stable B-tree. If the key
        // was present, the map ends up exactly as before, still owned by the
        // outer guard.
        let already_held = REENTRANCY_GUARD_MAP
            .with(|map| map.borrow_mut().insert(caller.into(), ()).is_some());
        // Trap only after the `RefCell` borrow above has been released, so
        // any live guards dropped during unwinding can borrow the map again.
        if already_held {
            ic_cdk::trap("ReentrancyGuard: reentrant call");
        }
        Self { caller }
    }
}
impl Default for ReentrancyGuard {
fn default() -> Self {
Self::new()
}
}
impl Drop for ReentrancyGuard {
fn drop(&mut self) {
REENTRANCY_GUARD_MAP.with(|g| g.borrow_mut().remove(&self.caller.into()));
}
}
#[cfg(test)]
mod unit_tests {
    use super::*;

    /// A function that re-enters itself while holding the guard must trap on
    /// the second acquisition. Off-canister, `ic_cdk::trap` panics with the
    /// message checked by `should_panic`.
    #[test]
    #[should_panic(expected = "trap should only be called inside canisters")]
    // The self-call is intentional: the nested `ReentrancyGuard::new()` is
    // expected to trap before the recursion could continue.
    #[allow(unconditional_recursion)]
    fn test_reentrancy_guard_reentrant() {
        let _held = ReentrancyGuard::new();
        test_reentrancy_guard_reentrant();
    }

    /// Acquiring a single guard and letting it drop does not trap.
    #[test]
    fn test_reentrancy_guard_non_reentrant() {
        let _held = ReentrancyGuard::new();
    }

    /// Calling a *different* guarded function while already holding a guard
    /// must also trap. This assumes `canister_caller()` returns the same
    /// principal for both acquisitions in the test environment; TODO confirm.
    #[test]
    #[should_panic(expected = "trap should only be called inside canisters")]
    fn test_reentrancy_guard_cross_reentrant() {
        let _held = ReentrancyGuard::new();
        test_reentrancy_guard_non_reentrant();
    }
}