use crate::task::{DetachableTask, TaskSpawner};
use std::fmt::Debug;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
/// A guard callback that is consumed and invoked once with the guarded context.
///
/// The `SYNC` / `ASYNC` const parameters select which implementation in this module applies:
/// - `<false, false>`: an `FnOnce(Context)` returning `()`;
/// - `<true, _>`: an `FnOnce(Context) -> R`, whose return value becomes `Output`;
/// - `<false, true>`: an `AsyncGuard`, or an `FnOnce(Context) -> impl Future` using the
///   default spawner, whose produced task is wrapped in a `DetachableTask`.
pub trait CallableGuard<const SYNC: bool, const ASYNC: bool, Context> {
/// Value produced by running the guard.
type Output;
/// Consumes the guard and runs it with `context`.
fn call(self, context: Context) -> Self::Output;
}
/// Fire-and-forget guards: any `FnOnce(Context)` whose return type is `()`.
///
/// This is the `SYNC = false, ASYNC = false` configuration; running the guard yields `()`.
impl<Context, Guard: FnOnce(Context)> CallableGuard<false, false, Context> for Guard {
    type Output = ();

    /// Invokes the closure with `context`.
    #[inline]
    fn call(self, context: Context) {
        self(context);
    }
}
/// Value-returning guards: any `FnOnce(Context) -> R`, selected with `SYNC = true`
/// (for either value of `ASYNC`). The closure's return value is passed through unchanged.
impl<const ASYNC: bool, Context, R, Guard: FnOnce(Context) -> R>
    CallableGuard<true, ASYNC, Context> for Guard
{
    type Output = R;

    /// Invokes the closure with `context` and returns whatever it returns.
    #[inline]
    fn call(self, context: Context) -> R {
        let guard = self;
        guard(context)
    }
}
/// An async guard bundled with the spawner for the task it produces.
///
/// When called through `CallableGuard<false, true, Context>`, `guard` is invoked with the
/// context and the returned task is wrapped together with `spawner` in a [`DetachableTask`].
/// What then happens to that task is decided by `DetachableTask`, not by this type.
/// Built by `ContextGuard::with_spawner` and by the default-spawner `CallableGuard` impl.
pub struct AsyncGuard<Spawner, Guard> {
/// Passed to `DetachableTask::with_spawner` when the guard is called.
spawner: Spawner,
/// Closure expected to be `FnOnce(Context) -> Task` (required by the `CallableGuard` impl).
guard: Guard,
}
impl<Context, Spawner: TaskSpawner<Task>, Guard, Task> CallableGuard<false, true, Context>
for AsyncGuard<Spawner, Guard>
where
Guard: FnOnce(Context) -> Task,
{
type Output = DetachableTask<Spawner, Task>;
#[inline]
fn call(self, context: Context) -> Self::Output {
DetachableTask::with_spawner(self.spawner, (self.guard)(context))
}
}
// Spawner used for async guards that do not name one explicitly.
// With the `tokio` feature it is `TokioHandle`; otherwise it is `()`.
#[cfg(feature = "tokio")]
use crate::task::TokioHandle;

/// Spawner type used by the default-spawner async `CallableGuard` impl.
#[cfg(feature = "tokio")]
type DefaultAsyncSpawner = TokioHandle;
/// Value of [`DefaultAsyncSpawner`] handed to `AsyncGuard` by the default impl.
#[cfg(feature = "tokio")]
const DEFAULT_ASYNC_SPAWNER: DefaultAsyncSpawner = TokioHandle;

/// Spawner type used by the default-spawner async `CallableGuard` impl.
#[cfg(not(feature = "tokio"))]
type DefaultAsyncSpawner = ();
/// Value of [`DefaultAsyncSpawner`] handed to `AsyncGuard` by the default impl.
#[cfg(not(feature = "tokio"))]
const DEFAULT_ASYNC_SPAWNER: DefaultAsyncSpawner = ();
/// Async guards without an explicit spawner: any `FnOnce(Context) -> Task` where `Task`
/// is a `Future` accepted by the feature-selected `DefaultAsyncSpawner`.
impl<Context, Guard, Task> CallableGuard<false, true, Context> for Guard
where
    Task: Future,
    Guard: FnOnce(Context) -> Task,
    DefaultAsyncSpawner: TaskSpawner<Task>,
{
    type Output =
        <AsyncGuard<DefaultAsyncSpawner, Guard> as CallableGuard<false, true, Context>>::Output;

    /// Wraps `self` in an `AsyncGuard` using `DEFAULT_ASYNC_SPAWNER` and delegates to it.
    #[inline]
    fn call(self, context: Context) -> Self::Output {
        let delegate = AsyncGuard {
            spawner: DEFAULT_ASYNC_SPAWNER,
            guard: self,
        };
        <AsyncGuard<DefaultAsyncSpawner, Guard> as CallableGuard<false, true, Context>>::call(
            delegate, context,
        )
    }
}
/// Payload of a `ContextGuard`: the guarded context together with its guard.
///
/// `ContextGuard` keeps this inside `ManuallyDrop` so that both fields are moved out
/// exactly once, by `trigger`, `defuse`, or `Drop`, rather than being dropped in place.
struct ContextGuardInner<
const SYNC: bool,
const ASYNC: bool,
Context,
Guard: CallableGuard<SYNC, ASYNC, Context>,
> {
/// The value exposed through `Deref`/`DerefMut` and handed to the guard when it runs.
context: Context,
/// The callback run with `context` (or discarded by `defuse`).
guard: Guard,
}
pub struct ContextGuard<
const SYNC: bool,
const ASYNC: bool,
Context,
Guard: CallableGuard<SYNC, ASYNC, Context>,
>(ManuallyDrop<ContextGuardInner<SYNC, ASYNC, Context, Guard>>);
impl<const SYNC: bool, const ASYNC: bool, Context, Guard> Deref
    for ContextGuard<SYNC, ASYNC, Context, Guard>
where
    Guard: CallableGuard<SYNC, ASYNC, Context>,
{
    type Target = Context;

    /// Borrows the context held by the guard.
    #[inline]
    fn deref(&self) -> &Context {
        let inner = &*self.0;
        &inner.context
    }
}
impl<const SYNC: bool, const ASYNC: bool, Context, Guard> DerefMut
    for ContextGuard<SYNC, ASYNC, Context, Guard>
where
    Guard: CallableGuard<SYNC, ASYNC, Context>,
{
    /// Mutably borrows the context.
    ///
    /// Mutations are visible to the guard, because the guard receives this same value
    /// when it eventually runs.
    #[inline]
    fn deref_mut(&mut self) -> &mut Context {
        let inner = &mut *self.0;
        &mut inner.context
    }
}
impl<const SYNC: bool, const ASYNC: bool, Context, Guard> Debug
    for ContextGuard<SYNC, ASYNC, Context, Guard>
where
    Context: Debug,
    Guard: CallableGuard<SYNC, ASYNC, Context>,
{
    /// Prints only the context, never the guard.
    ///
    /// The `<SYNC = false, ASYNC = true>` configuration is labelled `ContextGuard::Async`.
    /// Every other configuration is labelled `ContextGuard::Sync`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match (SYNC, ASYNC) {
            (false, true) => "ContextGuard::Async",
            _ => "ContextGuard::Sync",
        };
        let mut out = f.debug_struct(label);
        out.field("context", &self.0.context);
        out.finish_non_exhaustive()
    }
}
impl<const SYNC: bool, const ASYNC: bool, Context, Guard>
    ContextGuard<SYNC, ASYNC, Context, Guard>
where
    Guard: CallableGuard<SYNC, ASYNC, Context>,
{
    /// Pairs `context` with `guard`. `guard` runs on the context when the returned value
    /// is dropped, unless it is triggered or defused first.
    #[inline]
    pub fn with_guard(context: Context, guard: Guard) -> Self {
        let payload = ContextGuardInner { context, guard };
        Self(ManuallyDrop::new(payload))
    }

    /// Same as [`ContextGuard::with_guard`], but only accepts guards that are
    /// `FnOnce(Context) -> R`.
    ///
    /// NOTE(review): the extra bound presumably exists so that closure argument types can
    /// be inferred from `Context` at call sites — confirm.
    #[inline]
    pub fn new<R>(context: Context, guard: Guard) -> Self
    where
        Guard: FnOnce(Context) -> R,
    {
        Self::with_guard(context, guard)
    }
}
impl<Context, Spawner, Guard, Task> ContextGuard<false, true, Context, AsyncGuard<Spawner, Guard>>
where
    Spawner: TaskSpawner<Task>,
    Guard: FnOnce(Context) -> Task,
{
    /// Builds an async `ContextGuard` that uses an explicitly provided `spawner`.
    ///
    /// When run, `guard` is called with the context. The task it returns is wrapped with
    /// `spawner` in a `DetachableTask`.
    #[inline]
    pub fn with_spawner(spawner: Spawner, context: Context, guard: Guard) -> Self {
        let async_guard = AsyncGuard { spawner, guard };
        Self::with_guard(context, async_guard)
    }
}
impl<const SYNC: bool, const ASYNC: bool, Context, Guard: CallableGuard<SYNC, ASYNC, Context>>
    ContextGuard<SYNC, ASYNC, Context, Guard>
{
    /// Moves the context and guard out of `self` and runs the guard on the context.
    ///
    /// # Safety
    ///
    /// This takes the payload out of the inner `ManuallyDrop`, leaving `self.0` logically
    /// uninitialized, even if the guard later unwinds. The caller must guarantee two things:
    /// - the payload has not already been taken;
    /// - `self` is never used again afterwards: no further `call`, no deref, and no
    ///   `Drop::drop` that would call this again.
    ///
    /// Within this module it is called only from `trigger`, after wrapping `self` in
    /// `ManuallyDrop`, and from `Drop::drop`.
    #[inline]
    unsafe fn call(&mut self) -> Guard::Output {
        // SAFETY: per this function's contract, the payload is still initialized and
        // `self` is not touched again, so it is taken exactly once.
        let ContextGuardInner { context, guard } = unsafe { ManuallyDrop::take(&mut self.0) };
        guard.call(context)
    }

    /// Consumes the `ContextGuard`, runs the guard now, and returns its output.
    ///
    /// The guard will not run a second time when the value goes out of scope.
    #[inline]
    pub fn trigger(self) -> Guard::Output {
        // Suppress our own `Drop` impl so the guard is not run again.
        let mut this = ManuallyDrop::new(self);
        // SAFETY: `this` was just created from an untriggered guard, so the payload is
        // intact. It is `ManuallyDrop`, so `Drop::drop` never runs on it, and it is not
        // used after this call.
        unsafe { this.call() }
    }

    /// Consumes the `ContextGuard` without running the guard and returns the context.
    ///
    /// The guard value is dropped (not called) before the context is returned.
    #[inline]
    pub fn defuse(self) -> Context {
        // Suppress our own `Drop` impl so the guard is not run.
        let mut this = ManuallyDrop::new(self);
        // SAFETY: the payload is intact (the guard has not been triggered). `this` is
        // `ManuallyDrop`, so `Drop::drop` never runs on it, and it is not used after
        // this take.
        let ContextGuardInner { context, guard } = unsafe { ManuallyDrop::take(&mut this.0) };
        drop(guard);
        context
    }
}
impl<const SYNC: bool, const ASYNC: bool, Context, Guard> Drop
    for ContextGuard<SYNC, ASYNC, Context, Guard>
where
    Guard: CallableGuard<SYNC, ASYNC, Context>,
{
    /// Runs the guard on the context and immediately drops whatever the guard returns.
    ///
    /// For async guards, the dropped value is the `DetachableTask`. Its own `Drop` decides
    /// what that means for the task. This is not reached after `trigger` or `defuse`,
    /// which wrap `self` in `ManuallyDrop`.
    #[inline]
    fn drop(&mut self) {
        // SAFETY: `drop` runs at most once, and only while the payload is still present
        // (`trigger`/`defuse` suppress it). Nothing observes `self` afterwards.
        let discarded = unsafe { self.call() };
        std::mem::drop(discarded);
    }
}