use std::cell::Cell;
use std::marker::PhantomData;
use crate::dual::Dual;
use crate::float::Float;
use super::BytecodeTape;
thread_local! {
    // One (tape pointer, borrow flag) pair per supported element type. A null
    // pointer means no tape is installed for that type on this thread; the
    // flag is raised for the duration of a `with_active_btape` call.

    /// Active `f32` tape on this thread (null when none is installed).
    static BTAPE_F32: Cell<*mut BytecodeTape<f32>> = const { Cell::new(std::ptr::null_mut()) };
    /// Raised while `with_active_btape::<f32>` is running on this thread.
    static BTAPE_BORROWED_F32: Cell<bool> = const { Cell::new(false) };

    /// Active `f64` tape on this thread (null when none is installed).
    static BTAPE_F64: Cell<*mut BytecodeTape<f64>> = const { Cell::new(std::ptr::null_mut()) };
    /// Raised while `with_active_btape::<f64>` is running on this thread.
    static BTAPE_BORROWED_F64: Cell<bool> = const { Cell::new(false) };

    /// Active `Dual<f32>` tape on this thread (null when none is installed).
    static BTAPE_DUAL_F32: Cell<*mut BytecodeTape<Dual<f32>>> =
        const { Cell::new(std::ptr::null_mut()) };
    /// Raised while `with_active_btape::<Dual<f32>>` is running on this thread.
    static BTAPE_BORROWED_DUAL_F32: Cell<bool> = const { Cell::new(false) };

    /// Active `Dual<f64>` tape on this thread (null when none is installed).
    static BTAPE_DUAL_F64: Cell<*mut BytecodeTape<Dual<f64>>> =
        const { Cell::new(std::ptr::null_mut()) };
    /// Raised while `with_active_btape::<Dual<f64>>` is running on this thread.
    static BTAPE_BORROWED_DUAL_F64: Cell<bool> = const { Cell::new(false) };
}
/// Element types that own a dedicated pair of thread-local tape slots.
///
/// Each implementor is bound to two thread-locals: the pointer to the
/// currently installed [`BytecodeTape`] for that element type, and a flag that
/// [`with_active_btape`] raises to reject reentrant calls.
pub trait BtapeThreadLocal: Float {
    /// Thread-local flag marking that the tape for `Self` is currently lent
    /// out by [`with_active_btape`].
    fn btape_borrow_cell() -> &'static std::thread::LocalKey<Cell<bool>>;

    /// Thread-local slot holding the active tape pointer for `Self`.
    fn btape_cell() -> &'static std::thread::LocalKey<Cell<*mut BytecodeTape<Self>>>;
}
impl BtapeThreadLocal for f32 {
fn btape_cell() -> &'static std::thread::LocalKey<Cell<*mut BytecodeTape<Self>>> {
&BTAPE_F32
}
fn btape_borrow_cell() -> &'static std::thread::LocalKey<Cell<bool>> {
&BTAPE_BORROWED_F32
}
}
impl BtapeThreadLocal for f64 {
fn btape_cell() -> &'static std::thread::LocalKey<Cell<*mut BytecodeTape<Self>>> {
&BTAPE_F64
}
fn btape_borrow_cell() -> &'static std::thread::LocalKey<Cell<bool>> {
&BTAPE_BORROWED_F64
}
}
impl BtapeThreadLocal for Dual<f32> {
fn btape_cell() -> &'static std::thread::LocalKey<Cell<*mut BytecodeTape<Self>>> {
&BTAPE_DUAL_F32
}
fn btape_borrow_cell() -> &'static std::thread::LocalKey<Cell<bool>> {
&BTAPE_BORROWED_DUAL_F32
}
}
impl BtapeThreadLocal for Dual<f64> {
fn btape_cell() -> &'static std::thread::LocalKey<Cell<*mut BytecodeTape<Self>>> {
&BTAPE_DUAL_F64
}
fn btape_borrow_cell() -> &'static std::thread::LocalKey<Cell<bool>> {
&BTAPE_BORROWED_DUAL_F64
}
}
/// Scope marker that keeps the per-type "tape is borrowed" flag raised while
/// it is alive, so a nested `with_active_btape` call for the same element
/// type on the same thread is rejected instead of yielding a second `&mut`.
struct BtapeBorrowGuard {
    /// The thread-local flag this guard raised and will lower on drop.
    flag: &'static std::thread::LocalKey<Cell<bool>>,
}

impl BtapeBorrowGuard {
    /// Raises the borrow flag for element type `F` on the current thread.
    ///
    /// # Panics
    ///
    /// Panics if the flag for `F` is already raised, i.e. this is a reentrant
    /// call. The flag is left untouched in that case (it still belongs to the
    /// enclosing guard).
    fn new<F: BtapeThreadLocal>() -> Self {
        let flag = F::btape_borrow_cell();
        let already_borrowed = flag.with(Cell::get);
        assert!(
            !already_borrowed,
            "reentrant with_active_btape call detected — this would create aliased &mut references"
        );
        flag.with(|busy| busy.set(true));
        Self { flag }
    }
}

impl Drop for BtapeBorrowGuard {
    /// Lowers the borrow flag; runs on normal exit and during unwinding.
    fn drop(&mut self) {
        self.flag.with(|busy| busy.set(false));
    }
}
#[inline]
pub fn with_active_btape<F: BtapeThreadLocal, R>(f: impl FnOnce(&mut BytecodeTape<F>) -> R) -> R {
let _guard = BtapeBorrowGuard::new::<F>();
F::btape_cell().with(|cell| {
let ptr = cell.get();
assert!(
!ptr.is_null(),
"No active bytecode tape. Use echidna::record() to record a function."
);
let tape = unsafe { &mut *ptr };
f(tape)
})
}
/// RAII guard that installs a bytecode tape as the active tape for element
/// type `F` on the current thread, for as long as the guard is alive.
///
/// While the guard lives, `with_active_btape::<F>` on this thread operates on
/// the installed tape. Dropping the guard reinstates whatever tape (or none)
/// was active when it was created, so guards nest naturally in lexical scopes.
///
/// Guards must be dropped in reverse order of creation. If a guard is dropped
/// while a different tape is active (an outer guard dropped before an inner
/// one), the active slot is cleared rather than restored. The pointer it
/// would restore may refer to a tape whose borrow has already ended. See the
/// `Drop` impl for the resulting panic.
///
/// NOTE(review): leaking a guard with `std::mem::forget` still leaves the
/// thread-local pointing at a tape whose borrow has ended; that cannot be
/// detected here and would need a closure-scoped API to rule out.
pub struct BtapeGuard<'a, F: BtapeThreadLocal> {
    /// Tape pointer that was active before this guard was created.
    prev: *mut BytecodeTape<F>,
    /// Tape pointer this guard installed; used to detect out-of-order drops.
    installed: *mut BytecodeTape<F>,
    /// Keeps the exclusive borrow of the installed tape alive for `'a`.
    _borrow: PhantomData<&'a mut BytecodeTape<F>>,
}

impl<'a, F: BtapeThreadLocal> BtapeGuard<'a, F> {
    /// Installs `tape` as the active tape for `F` on this thread and returns a
    /// guard that reinstates the previously active tape when dropped.
    #[must_use = "dropping the guard immediately deactivates the tape; bind it to extend the recording scope"]
    pub fn new(tape: &'a mut BytecodeTape<F>) -> Self {
        let installed: *mut BytecodeTape<F> = tape;
        let prev = F::btape_cell().with(|slot| slot.replace(installed));
        Self {
            prev,
            installed,
            _borrow: PhantomData,
        }
    }
}

impl<F: BtapeThreadLocal> Drop for BtapeGuard<'_, F> {
    /// Reinstates the tape that was active before this guard was created.
    ///
    /// # Panics
    ///
    /// Panics, unless the thread is already panicking, if the active tape is
    /// not the one this guard installed, meaning guards were dropped out of
    /// order. The active slot is cleared to null first, so later lookups report
    /// "no active tape". Without this, they would dereference a tape whose
    /// borrow has ended.
    fn drop(&mut self) {
        let restored = F::btape_cell().with(|slot| {
            if slot.get() == self.installed {
                slot.set(self.prev);
                true
            } else {
                // `self.prev` may point at a tape whose guard (and borrow) is
                // already gone; never resurrect it.
                slot.set(std::ptr::null_mut());
                false
            }
        });
        if !restored && !std::thread::panicking() {
            panic!(
                "BtapeGuard dropped out of order; tape guards must be dropped in reverse order of creation"
            );
        }
    }
}