extern crate self as clients;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::mem::{self, MaybeUninit};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{OnceLock, RwLock};
use std::thread;
pub use clients_macros::{Depends, client};
mod builtins;
pub use builtins::*;
/// One layer of overrides: maps a dependency's `TypeId` to its type-erased replacement value.
type DependencyMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;
/// A pinned, heap-allocated, `Send + 'static` future resolving to `T`.
pub type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// Error describing why a dependency could not be provided.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DependencyError {
    /// No implementation is available for the dependency at the given path.
    Missing(&'static str),
    /// A fixed explanation known at compile time.
    Message(&'static str),
    /// An explanation assembled at runtime.
    Owned(String),
}

impl DependencyError {
    /// Creates a [`DependencyError::Missing`] for the dependency identified by `path`.
    pub const fn missing(path: &'static str) -> Self {
        DependencyError::Missing(path)
    }

    /// Creates a [`DependencyError::Message`] carrying the static text `message`.
    pub const fn message(message: &'static str) -> Self {
        DependencyError::Message(message)
    }
}

impl fmt::Display for DependencyError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Missing` is the only variant that needs formatting; the other two
        // are emitted verbatim.
        let text: &str = match self {
            DependencyError::Missing(path) => {
                return write!(formatter, "missing dependency implementation for `{path}`");
            }
            DependencyError::Message(fixed) => *fixed,
            DependencyError::Owned(built) => built.as_str(),
        };
        formatter.write_str(text)
    }
}

impl std::error::Error for DependencyError {}
/// A cloneable, thread-safe value that can be resolved through [`get`].
pub trait Dependency: Clone + Send + Sync + 'static {
/// Builds the live implementation; [`get`] calls this when no override is active for the type.
fn live() -> Self;
}
/// Moves `future` onto the heap and returns it as a pinned, type-erased [`BoxFuture`].
pub fn boxed<Fut>(future: Fut) -> BoxFuture<Fut::Output>
where
    Fut: Future + Send + 'static,
{
    // Erase the concrete future type first, then pin the resulting box.
    let erased: Box<dyn Future<Output = Fut::Output> + Send + 'static> = Box::new(future);
    Box::into_pin(erased)
}
/// Resolves the dependency `D`: the innermost active override if there is one,
/// otherwise the value built by `D::live()`.
pub fn get<D>() -> D
where
    D: Dependency,
{
    match current_override::<D>() {
        Some(overridden) => overridden,
        // `live()` is only invoked when no override layer provides `D`.
        None => D::live(),
    }
}
/// Reports that the dependency at `path` was called without a live
/// implementation or an active test override.
///
/// The panic location is attributed to the caller (`#[track_caller]`), so the
/// report points at the code that reached the missing dependency rather than
/// at this helper.
///
/// # Panics
///
/// Always panics; the message names `path`.
#[cold]
#[track_caller]
pub fn unimplemented_dependency(path: &'static str) -> ! {
    panic!(
        "dependency `{path}` has no live implementation and no active test override; add `= ...` in `client!` or install a `test_deps!` override"
    )
}
/// Returns the process-wide stack of override layers, creating it empty on first use.
fn active_overrides() -> &'static RwLock<Vec<DependencyMap>> {
    static STACK: OnceLock<RwLock<Vec<DependencyMap>>> = OnceLock::new();
    // The default `RwLock<Vec<_>>` wraps an empty vector.
    STACK.get_or_init(RwLock::default)
}
/// Looks up the innermost active override for `D`.
///
/// Layers are searched from the most recently pushed to the oldest; the first
/// layer containing an entry for `D` wins and its value is cloned. Returns
/// `None` when no layer overrides `D`.
///
/// # Panics
///
/// Panics if the override lock is poisoned or if the stored value is not a `D`.
fn current_override<D>() -> Option<D>
where
    D: Dependency,
{
    let key = TypeId::of::<D>();
    let stack = active_overrides()
        .read()
        .expect("dependency override lock poisoned");
    stack
        .iter()
        .rev()
        .find_map(|layer| layer.get(&key))
        .map(|erased| {
            erased
                .downcast_ref::<D>()
                .expect("dependency override stored with the wrong type")
                .clone()
        })
}
/// Pushes `entries` as the new innermost override layer.
///
/// # Panics
///
/// Panics if the override lock is poisoned.
fn push_overrides(entries: DependencyMap) {
    let mut stack = active_overrides()
        .write()
        .expect("dependency override lock poisoned");
    stack.push(entries);
}
/// Removes the innermost override layer.
///
/// # Panics
///
/// Panics if the override lock is poisoned or if there is no layer to remove.
fn pop_overrides() {
    let removed = {
        let mut stack = active_overrides()
            .write()
            .expect("dependency override lock poisoned");
        stack.pop()
    };
    // The write guard is already released here, so a failed assertion does
    // not poison the lock.
    assert!(removed.is_some(), "dependency override stack underflow");
}
/// Returns the process-wide flag that is `true` while a test scope holds the lock.
fn test_scope_lock() -> &'static AtomicBool {
    static HELD: AtomicBool = AtomicBool::new(false);
    let flag: &'static AtomicBool = &HELD;
    flag
}
/// Blocks until the test-scope flag can be flipped from `false` to `true`,
/// yielding the thread between attempts.
fn acquire_test_lock() {
    let flag = test_scope_lock();
    loop {
        match flag.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) {
            Ok(_) => break,
            // Another scope holds the lock; let other threads run and retry.
            Err(_) => thread::yield_now(),
        }
    }
}
/// Clears the test-scope flag so that a waiting `acquire_test_lock` can proceed.
fn release_test_lock() {
    let flag = test_scope_lock();
    flag.store(false, Ordering::Release);
}
/// Collects dependency overrides and installs them as a single layer through
/// `enter` or `enter_test`.
#[derive(Default)]
pub struct OverrideBuilder {
/// Pending overrides, keyed by the `TypeId` of the dependency they replace.
entries: DependencyMap,
}
impl OverrideBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn set<D>(&mut self, dependency: D) -> &mut Self
where
D: Dependency,
{
self.entries.insert(
TypeId::of::<D>(),
Box::new(dependency) as Box<dyn Any + Send + Sync>,
);
self
}
pub fn update<D, F>(&mut self, update: F) -> &mut Self
where
D: Dependency,
F: FnOnce(D) -> D,
{
let current = self.take_or_resolve::<D>();
self.set(update(current))
}
pub fn enter(self) -> OverrideGuard {
push_overrides(self.entries);
OverrideGuard {
release_test_lock: false,
}
}
pub fn enter_test(self) -> OverrideGuard {
acquire_test_lock();
push_overrides(self.entries);
OverrideGuard {
release_test_lock: true,
}
}
fn take_or_resolve<D>(&mut self) -> D
where
D: Dependency,
{
if let Some(entry) = self.entries.remove(&TypeId::of::<D>()) {
*entry
.downcast::<D>()
.expect("dependency override stored with the wrong type")
} else {
get::<D>()
}
}
}
/// Guard returned by `OverrideBuilder::enter` and `OverrideBuilder::enter_test`.
///
/// Dropping it pops the innermost override layer and, when `release_test_lock`
/// is set, releases the test-scope lock.
#[must_use = "keep the guard alive for as long as the dependency overrides should remain active"]
pub struct OverrideGuard {
/// Whether dropping this guard must also release the test-scope lock
/// (`true` for guards created by `enter_test`).
release_test_lock: bool,
}
impl Drop for OverrideGuard {
    /// Pops the innermost override layer and then, for test scopes, releases
    /// the test-scope lock.
    ///
    /// The lock is released even when popping panics; otherwise every later
    /// `enter_test` would spin forever waiting for a lock nobody can release.
    fn drop(&mut self) {
        /// Releases the test-scope lock when dropped, if armed.
        struct ReleaseTestLock {
            armed: bool,
        }

        impl Drop for ReleaseTestLock {
            fn drop(&mut self) {
                if self.armed {
                    release_test_lock();
                }
            }
        }

        // Created before popping so that it runs after `pop_overrides`
        // returns normally *and* while unwinding out of it.
        let _release = ReleaseTestLock {
            armed: self.release_test_lock,
        };
        pop_overrides();
    }
}
/// Checks that `F` is a zero-sized type, as non-capturing closures and
/// function items are.
///
/// # Panics
///
/// Panics when `F` occupies any storage.
fn assert_non_capturing<F>() {
    if mem::size_of::<F>() != 0 {
        panic!("dependency implementations must be non-capturing closures or function items");
    }
}
/// Produces a value of the zero-sized type `F` without an existing instance.
///
/// # Safety
///
/// `F` must be zero-sized; this is only checked by a debug assertion. The
/// callers in this file additionally bound `F` by `Copy` and receive an
/// instance of `F` before handing out a trampoline that calls this function.
/// NOTE(review): presumably that is what justifies conjuring further copies —
/// confirm before adding new callers.
// The lint flags `MaybeUninit::uninit().assume_init()`; it is allowed here
// because the value has no bytes that could be left uninitialized.
#[allow(clippy::uninit_assumed_init)]
unsafe fn resurrect_zst<F>() -> F {
debug_assert_eq!(mem::size_of::<F>(), 0);
// SAFETY: the caller guarantees that `F` is zero-sized (see `# Safety`).
unsafe { MaybeUninit::<F>::uninit().assume_init() }
}
// Generates, for every `sync_name, async_name, (arg: Ty, ...)` entry, a pair of
// public functions that turn a non-capturing callable into a plain `fn` pointer:
// - `sync_name` returns `fn(args...) -> R`;
// - `async_name` returns `fn(args...) -> BoxFuture<R>`, boxing the future the callable returns.
// Both panic (through `assert_non_capturing`) when the callable's type is not zero-sized.
macro_rules! define_erasers {
($( $sync_name:ident, $async_name:ident, ( $( $arg:ident : $arg_ty:ident ),* ) );* $(;)?) => {
$(
#[doc = "Internal helper that erases a non-capturing synchronous closure into a raw function pointer."]
#[doc(hidden)]
pub fn $sync_name<F, R $(, $arg_ty)*>(_: F) -> fn($( $arg_ty ),*) -> R
where
F: Fn($( $arg_ty ),*) -> R + Copy + 'static,
{
// The value itself is discarded; only its type `F` is carried into the trampoline.
assert_non_capturing::<F>();
fn trampoline<F, R $(, $arg_ty)*>($( $arg : $arg_ty ),*) -> R
where
F: Fn($( $arg_ty ),*) -> R + Copy + 'static,
{
// SAFETY: this nested function is only handed out by the enclosing
// function, after `assert_non_capturing::<F>()` verified that `F` is zero-sized.
let function: F = unsafe { resurrect_zst() };
function($( $arg ),*)
}
trampoline::<F, R $(, $arg_ty)*>
}
#[doc = "Internal helper that erases a non-capturing asynchronous closure into a raw function pointer returning `BoxFuture`."]
#[doc(hidden)]
pub fn $async_name<F, Fut, R $(, $arg_ty)*>(_: F) -> fn($( $arg_ty ),*) -> BoxFuture<R>
where
F: Fn($( $arg_ty ),*) -> Fut + Copy + 'static,
Fut: Future<Output = R> + Send + 'static,
{
// Same scheme as the synchronous variant; see the comments there.
assert_non_capturing::<F>();
fn trampoline<F, Fut, R $(, $arg_ty)*>($( $arg : $arg_ty ),*) -> BoxFuture<R>
where
F: Fn($( $arg_ty ),*) -> Fut + Copy + 'static,
Fut: Future<Output = R> + Send + 'static,
{
// SAFETY: this nested function is only handed out by the enclosing
// function, after `assert_non_capturing::<F>()` verified that `F` is zero-sized.
let function: F = unsafe { resurrect_zst() };
// Box and pin the returned future so the pointer has a nameable return type.
Box::pin(function($( $arg ),*))
}
trampoline::<F, Fut, R $(, $arg_ty)*>
}
)*
};
}
// Instantiates `erase_sync_N` / `erase_async_N` for callables taking 0 through 4 arguments.
define_erasers! {
erase_sync_0, erase_async_0, ();
erase_sync_1, erase_async_1, (arg0: A0);
erase_sync_2, erase_async_2, (arg0: A0, arg1: A1);
erase_sync_3, erase_async_3, (arg0: A0, arg1: A1, arg2: A2);
erase_sync_4, erase_async_4, (arg0: A0, arg1: A1, arg2: A2, arg3: A3);
}
// Picks the `erase_sync_N` helper whose arity matches the number of
// `name: Type` parameters (0 to 4) in the parenthesized list and applies it to
// the implementation expression. The parameter names and types are matched
// only to count them; they do not appear in the expansion.
#[doc(hidden)]
#[macro_export]
macro_rules! __dep_to_sync_fn {
(() => $implementation:expr) => {
$crate::erase_sync_0($implementation)
};
(($arg0:ident : $arg0_ty:ty) => $implementation:expr) => {
$crate::erase_sync_1($implementation)
};
(($arg0:ident : $arg0_ty:ty, $arg1:ident : $arg1_ty:ty) => $implementation:expr) => {
$crate::erase_sync_2($implementation)
};
(($arg0:ident : $arg0_ty:ty, $arg1:ident : $arg1_ty:ty, $arg2:ident : $arg2_ty:ty) => $implementation:expr) => {
$crate::erase_sync_3($implementation)
};
(($arg0:ident : $arg0_ty:ty, $arg1:ident : $arg1_ty:ty, $arg2:ident : $arg2_ty:ty, $arg3:ident : $arg3_ty:ty) => $implementation:expr) => {
$crate::erase_sync_4($implementation)
};
}
// Picks the `erase_async_N` helper whose arity matches the number of
// `name: Type` parameters (0 to 4) in the parenthesized list and applies it to
// the implementation expression. The parameter names and types are matched
// only to count them; they do not appear in the expansion.
#[doc(hidden)]
#[macro_export]
macro_rules! __dep_to_async_fn {
(() => $implementation:expr) => {
$crate::erase_async_0($implementation)
};
(($arg0:ident : $arg0_ty:ty) => $implementation:expr) => {
$crate::erase_async_1($implementation)
};
(($arg0:ident : $arg0_ty:ty, $arg1:ident : $arg1_ty:ty) => $implementation:expr) => {
$crate::erase_async_2($implementation)
};
(($arg0:ident : $arg0_ty:ty, $arg1:ident : $arg1_ty:ty, $arg2:ident : $arg2_ty:ty) => $implementation:expr) => {
$crate::erase_async_3($implementation)
};
(($arg0:ident : $arg0_ty:ty, $arg1:ident : $arg1_ty:ty, $arg2:ident : $arg2_ty:ty, $arg3:ident : $arg3_ty:ty) => $implementation:expr) => {
$crate::erase_async_4($implementation)
};
}
/// Binds dependency methods to local variables.
///
/// Each `binding = client.method` entry expands to
/// `let binding = client::method::get();`. Entries are separated by commas and
/// a trailing comma is accepted; an empty invocation expands to nothing.
#[macro_export]
macro_rules! deps {
() => {};
($binding:ident = $client:ident.$method:ident $(, $($rest:tt)*)?) => {
let $binding = $client::$method::get();
// Recurse on whatever follows the comma (possibly nothing).
$crate::deps!($($($rest)*)?);
};
}
/// Installs test overrides for the rest of the enclosing block.
///
/// With no arguments it enters an empty test scope through
/// `OverrideBuilder::new().enter_test()`. Otherwise every
/// `client.method => implementation` entry is applied to a fresh builder with
/// `client::method::override_with(&mut builder, implementation)` before the
/// builder is entered with `enter_test()`. The resulting guard is bound to a
/// local variable in the expansion, so it is dropped when the enclosing block ends.
#[macro_export]
macro_rules! test_deps {
() => {
let __dep_test_scope_guard = $crate::OverrideBuilder::new().enter_test();
// Mark the guard as used without dropping it.
let _ = &__dep_test_scope_guard;
};
($client:ident.$method:ident => $implementation:expr $(, $($rest:tt)*)?) => {
let mut __dep_builder = $crate::OverrideBuilder::new();
$client::$method::override_with(&mut __dep_builder, $implementation);
// Remaining entries, if any, are applied by the helper macro.
$(
$crate::__dep_test_deps_more!(__dep_builder, $($rest)*);
)?
let __dep_test_scope_guard = __dep_builder.enter_test();
// Mark the guard as used without dropping it.
let _ = &__dep_test_scope_guard;
};
}
// Helper for `test_deps!`: applies each remaining `client.method => implementation`
// entry to the named builder via `override_with`, recursing until the entry
// list is empty.
#[doc(hidden)]
#[macro_export]
macro_rules! __dep_test_deps_more {
// End of the list (also reached after a trailing comma).
($builder:ident, ) => {};
($builder:ident, $client:ident.$method:ident => $implementation:expr $(, $($rest:tt)*)?) => {
$client::$method::override_with(&mut $builder, $implementation);
$(
$crate::__dep_test_deps_more!($builder, $($rest)*);
)?
};
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::panic::{AssertUnwindSafe, catch_unwind};
    use std::sync::mpsc;
    use std::time::Duration;

    // Fixture: a clock client whose live implementation always reports 100.
    client! {
        struct UnitClock as unit_clock {
            fn now_millis() -> u64 = || 100;
        }
    }

    // Fixture: a client with one synchronous and one asynchronous method.
    client! {
        struct UnitMath as unit_math {
            fn add(lhs: u64, rhs: u64) -> u64 = |lhs, rhs| lhs + rhs;
            async fn add_async(lhs: u64, rhs: u64) -> u64 = |lhs, rhs| async move { lhs + rhs };
        }
    }

    // Fixture: a client whose live implementation resolves another client through `deps!`.
    client! {
        struct UnitGreeter as unit_greeter {
            fn greeting(id: u64) -> String = |id| {
                deps! {
                    now = unit_clock.now_millis,
                }
                format!("{id}@{}", now())
            };
        }
    }

    /// Drives `future` to completion on the current thread by polling it in a
    /// loop with a waker that does nothing, yielding between polls.
    fn block_on<F>(future: F) -> F::Output
    where
        F: Future,
    {
        use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
        let mut future = Box::pin(future);
        unsafe fn clone(_: *const ()) -> RawWaker {
            RawWaker::new(std::ptr::null(), &VTABLE)
        }
        unsafe fn wake(_: *const ()) {}
        unsafe fn wake_by_ref(_: *const ()) {}
        unsafe fn drop(_: *const ()) {}
        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);
        // SAFETY: every vtable entry ignores the data pointer, so the null
        // pointer is never dereferenced.
        let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) };
        let mut context = Context::from_waker(&waker);
        loop {
            match Pin::as_mut(&mut future).poll(&mut context) {
                Poll::Ready(value) => return value,
                Poll::Pending => std::thread::yield_now(),
            }
        }
    }

    /// Extracts the text of a panic payload, whether it is a `String` or a `&'static str`.
    fn panic_message(payload: Box<dyn Any + Send>) -> String {
        match payload.downcast::<String>() {
            Ok(message) => *message,
            Err(payload) => match payload.downcast::<&'static str>() {
                Ok(message) => (*message).to_string(),
                Err(_) => "<non-string panic payload>".into(),
            },
        }
    }

    #[test]
    fn dependency_error_formats_all_variants() {
        assert_eq!(
            DependencyError::missing("clock.now_millis").to_string(),
            "missing dependency implementation for `clock.now_millis`"
        );
        assert_eq!(DependencyError::message("nope").to_string(), "nope");
        assert_eq!(DependencyError::Owned("owned".into()).to_string(), "owned");
    }

    #[test]
    fn get_falls_back_to_live_dependency() {
        let _test_scope = OverrideBuilder::new().enter_test();
        assert_eq!(get::<UnitClock>().now_millis(), 100);
    }

    #[test]
    fn deps_macro_reads_dependency_methods_in_free_functions() {
        let _test_scope = OverrideBuilder::new().enter_test();
        assert_eq!(get::<UnitGreeter>().greeting(7), "7@100");
    }

    #[test]
    fn override_builder_set_replaces_a_whole_client() {
        let _test_scope = OverrideBuilder::new().enter_test();
        let override_clock = UnitClock {
            now_millis: erase_sync_0(|| 777),
        };
        let mut builder = OverrideBuilder::new();
        builder.set(override_clock);
        let _guard = builder.enter();
        assert_eq!(get::<UnitClock>().now_millis(), 777);
    }

    #[test]
    fn override_builder_update_uses_pending_entries_before_live_values() {
        let _test_scope = OverrideBuilder::new().enter_test();
        let mut builder = OverrideBuilder::new();
        builder.set(UnitClock {
            now_millis: erase_sync_0(|| 222),
        });
        builder.update::<UnitClock, _>(|mut dependency| {
            assert_eq!(dependency.now_millis(), 222);
            dependency.now_millis = erase_sync_0(|| 333);
            dependency
        });
        let _guard = builder.enter();
        assert_eq!(get::<UnitClock>().now_millis(), 333);
    }

    #[test]
    fn nested_override_layers_restore_previous_values() {
        let _test_scope = OverrideBuilder::new().enter_test();
        let mut outer_builder = OverrideBuilder::new();
        outer_builder.update::<UnitClock, _>(|mut dependency| {
            dependency.now_millis = erase_sync_0(|| 200);
            dependency
        });
        let outer = outer_builder.enter();
        assert_eq!(get::<UnitClock>().now_millis(), 200);
        {
            let mut inner_builder = OverrideBuilder::new();
            inner_builder.update::<UnitClock, _>(|mut dependency| {
                dependency.now_millis = erase_sync_0(|| 300);
                dependency
            });
            let _inner = inner_builder.enter();
            assert_eq!(get::<UnitClock>().now_millis(), 300);
        }
        assert_eq!(get::<UnitClock>().now_millis(), 200);
        drop(outer);
        assert_eq!(get::<UnitClock>().now_millis(), 100);
    }

    #[test]
    fn boxed_and_async_erasers_work_without_a_runtime() {
        assert_eq!(block_on(boxed(async { 9 })), 9);
        let add_sync: fn(u64, u64) -> u64 = erase_sync_2(|lhs: u64, rhs: u64| lhs + rhs);
        assert_eq!(add_sync(2, 3), 5);
        let add_async: fn(u64, u64) -> BoxFuture<u64> =
            erase_async_2(|lhs: u64, rhs: u64| async move { lhs + rhs });
        assert_eq!(block_on(add_async(2, 3)), 5);
    }

    #[test]
    fn capturing_closures_are_rejected_when_erased() {
        let captured = 9u64;
        let panic = catch_unwind(AssertUnwindSafe(|| {
            let _: fn() -> u64 = erase_sync_0(move || captured);
        }))
        .expect_err("capturing closures should panic");
        assert!(panic_message(panic).contains("non-capturing closures"));
    }

    #[test]
    fn unimplemented_dependency_panics_with_the_requested_path() {
        let panic = catch_unwind(AssertUnwindSafe(|| {
            unimplemented_dependency("unit_clock.now_millis")
        }))
        .expect_err("unimplemented dependency should panic");
        assert!(panic_message(panic).contains("unit_clock.now_millis"));
    }

    #[test]
    fn test_deps_macro_installs_multiple_overrides() {
        test_deps! {
            unit_clock.now_millis => || 444,
            unit_math.add => |lhs, rhs| lhs * rhs,
            unit_math.add_async => |lhs, rhs| async move { lhs * rhs },
        }
        assert_eq!(get::<UnitClock>().now_millis(), 444);
        assert_eq!(get::<UnitMath>().add(2, 3), 6);
        assert_eq!(block_on(get::<UnitMath>().add_async(2, 3)), 6);
    }

    #[test]
    fn enter_test_serializes_parallel_override_scopes() {
        let guard = OverrideBuilder::new().enter_test();
        let (ready_tx, ready_rx) = mpsc::channel();
        let (done_tx, done_rx) = mpsc::channel();
        let handle = std::thread::spawn(move || {
            ready_tx.send(()).expect("thread should signal readiness");
            let _guard = OverrideBuilder::new().enter_test();
            done_tx.send(()).expect("thread should signal acquisition");
        });
        ready_rx
            .recv_timeout(Duration::from_secs(1))
            .expect("thread should start");
        // While `guard` is alive the spawned thread must still be waiting.
        assert!(
            done_rx.recv_timeout(Duration::from_millis(50)).is_err(),
            "second test scope should block while the first is held"
        );
        drop(guard);
        done_rx
            .recv_timeout(Duration::from_secs(1))
            .expect("second test scope should acquire once the first drops");
        handle.join().expect("thread should join cleanly");
    }

    #[test]
    fn pop_overrides_panics_when_the_stack_is_empty() {
        acquire_test_lock();
        let outcome = catch_unwind(AssertUnwindSafe(pop_overrides));
        // Release before inspecting the outcome: if the expectation below
        // fails, the lock must not stay held, or every other test that calls
        // `enter_test` would wait forever.
        release_test_lock();
        let panic = outcome.expect_err("empty stack should panic");
        assert!(panic_message(panic).contains("stack underflow"));
    }
}