use crate::mock_store::{MockLayer, MockStore};
use std::{any::{Any, TypeId}, marker::Tuple};
use std::marker::PhantomData;
use std::mem::transmute;
/// A function-like value whose calls can be intercepted by a mock.
///
/// `T` is the tuple of argument types and `O` the return type. The blanket implementation
/// in this module covers every `F: FnOnce<T, Output = O>` and keeps mocks in a
/// thread-local store, so registering a mock only affects the current thread.
pub trait Mockable<T: Tuple, O> {
    /// Registers `mock` for `self` in the current thread's mock layer without requiring
    /// the mock to be `'static`.
    ///
    /// # Safety
    ///
    /// The mock is type-erased and its lifetime is extended to `'static` before it is
    /// stored. The caller must ensure that `mock`, and everything it borrows, stays valid
    /// for as long as the store may still invoke or drop it. That means at least until it
    /// is cleared with [`Mockable::clear_mock`] or [`clear_mocks`].
    unsafe fn mock_raw<M: FnMut<T, Output = MockResult<T, O>>>(&self, mock: M);
    /// Registers a `'static` mock for `self` in the current thread's mock layer.
    ///
    /// This is the safe counterpart of [`Mockable::mock_raw`]: the `'static` bound rules
    /// out borrows that could dangle while the mock is stored.
    fn mock_safe<M: FnMut<T, Output = MockResult<T, O>> + 'static>(&self, mock: M);
    /// Clears the mock registered for `self` from the current thread's store.
    fn clear_mock(&self);
    /// Invokes the mock registered for `self` with `input` and returns its result.
    ///
    /// Hidden from the docs. NOTE(review): presumably called from instrumented function
    /// bodies. Confirm against the code that generates calls to it.
    #[doc(hidden)]
    fn call_mock(&self, input: T) -> MockResult<T, O>;
    /// Returns the key under which mocks for `self` are stored.
    ///
    /// The blanket implementation derives it from a closure defined inside the generic
    /// impl, so it differs for every `(T, O, F)` instantiation.
    ///
    /// # Safety
    ///
    /// NOTE(review): the visible implementation performs no unsafe operation. The
    /// `unsafe` marker presumably guards the invariant that this id is only ever paired
    /// with mocks of signature `(T) -> MockResult<T, O>`, because the store type-erases
    /// them. Confirm before relying on this.
    #[doc(hidden)]
    unsafe fn get_mock_id(&self) -> TypeId;
}
/// Outcome of invoking a mock.
///
/// `T` is the tuple of arguments the mocked function was called with, and `O` is its
/// return type.
#[derive(Debug)]
pub enum MockResult<T, O> {
    /// Hands the (possibly modified) arguments back. NOTE(review): presumably the real
    /// function body then runs with them. Confirm against the instrumenting code.
    Continue(T),
    /// Supplies the value the mocked call should produce. Presumably the real function
    /// body is skipped.
    Return(O),
}
thread_local! {
    /// The current thread's registry of mocks. Every access is keyed by a `TypeId`
    /// obtained from `Mockable::get_mock_id`.
    static MOCK_STORE: MockStore = MockStore::default();
}
/// Clears the current thread's mock store by delegating to `MockStore::clear`.
pub fn clear_mocks() {
    MOCK_STORE.with(|store| {
        store.clear();
    });
}
impl<T: Tuple, O, F: FnOnce<T, Output = O>> Mockable<T, O> for F {
    unsafe fn mock_raw<M: FnMut<T, Output = MockResult<T, O>>>(&self, mock: M) {
        let key = self.get_mock_id();
        // Erase the concrete closure type first; the object lifetime is inferred from `M`.
        let erased = Box::new(mock) as Box<dyn FnMut<_, Output = _>>;
        // The caller's contract (see the trait docs) is what makes pretending the box
        // is `'static` acceptable.
        let extended: Box<dyn FnMut<T, Output = MockResult<T, O>> + 'static> =
            transmute(erased);
        MOCK_STORE.with(move |store| store.add_to_thread_layer(key, extended))
    }

    fn mock_safe<M: FnMut<T, Output = MockResult<T, O>> + 'static>(&self, mock: M) {
        // SAFETY: `M: 'static`, so extending the erased box to `'static` changes nothing.
        unsafe { <Self as Mockable<T, O>>::mock_raw(self, mock) }
    }

    fn clear_mock(&self) {
        MOCK_STORE.with(|store| {
            // SAFETY: the id is only used as the key of the entry to clear.
            let key = unsafe { self.get_mock_id() };
            store.clear_id(key)
        })
    }

    fn call_mock(&self, input: T) -> MockResult<T, O> {
        // SAFETY: the key comes from this same impl, so it matches the one used when a
        // mock with signature `(T) -> MockResult<T, O>` was registered for `self`.
        unsafe {
            let key = self.get_mock_id();
            MOCK_STORE.with(move |store| store.call(key, input))
        }
    }

    unsafe fn get_mock_id(&self) -> TypeId {
        // A closure declared inside this generic impl gets a distinct type for every
        // `(T, O, F)` instantiation, so its `TypeId` identifies the mocked function.
        let marker = || ();
        marker.type_id()
    }
}
/// Collects mocks that are installed only for the duration of [`MockContext::run`].
///
/// `'a` bounds how long the mocks registered through `mock_safe` must stay valid.
#[derive(Default)]
pub struct MockContext<'a> {
    /// Mocks gathered so far; pushed onto the thread-local store by `run`.
    mock_layer: MockLayer,
    /// Ties the context to `'a` without storing any reference.
    phantom_lifetime: PhantomData<&'a ()>,
}
impl<'a> MockContext<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn mock_safe<I: Tuple, O, F, M>(self, mockable: F, mock: M) -> Self
where
F: Mockable<I, O>,
M: FnMut<I, Output = MockResult<I, O>> + 'a,
{
unsafe { self.mock_raw(mockable, mock) }
}
pub unsafe fn mock_raw<I: Tuple, O, F, M>(mut self, mockable: F, mock: M) -> Self
where
F: Mockable<I, O>,
M: FnMut<I, Output = MockResult<I, O>>,
{
let mock_box = Box::new(mock) as Box<dyn FnMut<_, Output = _>>;
let mock_box_static: Box<dyn FnMut<I, Output = MockResult<I, O>> + 'static> =
std::mem::transmute(mock_box);
self.mock_layer.add(mockable.get_mock_id(), mock_box_static);
self
}
pub fn run<T, F: FnOnce() -> T>(self, f: F) -> T {
MOCK_STORE.with(|mock_store| unsafe { mock_store.add_layer(self.mock_layer) });
let _mock_level_guard = MockLayerGuard;
f()
}
}
/// Removes a mock layer from the current thread's store (`MockStore::remove_layer`)
/// when dropped.
///
/// In this module it is created only by `MockContext::run`, right after pushing that
/// context's layer. The layer is therefore removed both on normal return and when
/// unwinding from a panic.
struct MockLayerGuard;

impl Drop for MockLayerGuard {
    fn drop(&mut self) {
        // SAFETY: each guard is paired with exactly one preceding `add_layer` call in
        // `MockContext::run`, so there is a layer for it to remove.
        MOCK_STORE.with(|mock_store| unsafe { mock_store.remove_layer() });
    }
}