use futures::FutureExt;
use futures::future::{BoxFuture, Shared};
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
pub use std::sync::Mutex as StdMutex;
pub use tokio::sync::Mutex as TokioMutex; use tokio_util::sync::CancellationToken;
use crate::metrics::{MetricLabel, REGISTRY};
use tracing::Instrument;
use tokio::time::{Duration, Instant};
/// Time source used by effects.
///
/// [`Ctx`] carries an `Arc<dyn Clock>` and `Effect::sleep` waits through it, so the time
/// source can be chosen per context (`Ctx::new` installs [`LiveClock`]).
pub trait Clock: Send + Sync + 'static {
/// Returns a future that is expected to complete once `duration` has elapsed on this clock.
fn sleep(&self, duration: Duration) -> BoxFuture<'static, ()>;
/// Returns the current instant according to this clock.
fn now(&self) -> Instant;
}
/// [`Clock`] backed by tokio's timer and tokio's `Instant::now`.
pub struct LiveClock;
impl Clock for LiveClock {
    fn now(&self) -> Instant {
        Instant::now()
    }

    fn sleep(&self, duration: Duration) -> BoxFuture<'static, ()> {
        // The tokio sleep is created inside the async block, so its deadline is fixed on
        // first poll rather than when this method is called.
        async move { tokio::time::sleep(duration).await }.boxed()
    }
}
/// Identifier of a fiber; the root context built by `Ctx::new` uses `FiberId(0)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FiberId(pub usize);
/// Environment value handed to an effect's underlying function when it is run.
#[derive(Clone)]
pub struct EnvRef<R> {
/// The environment itself (`()` for effects that need none).
pub value: R,
}
/// Type-indexed service map holding at most one value per concrete type.
///
/// Values are stored behind `Arc`, so cloning an `Env` shares the stored values.
#[derive(Clone, Default)]
pub struct Env {
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
impl Env {
    /// Creates an empty environment.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `val` under its type, replacing any previous value of the same type.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = TypeId::of::<T>();
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Returns a shared handle to the stored value of type `T`, or `None` if absent.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        let entry = self.map.get(&TypeId::of::<T>())?;
        Arc::clone(entry).downcast::<T>().ok()
    }
}
/// Handle to a forked computation.
///
/// `join_future` is shared, so any clone of the handle can await the same outcome;
/// `token` is the cancellation token the forked work observes.
#[derive(Clone)]
pub struct Fiber<E, A> {
    pub id: FiberId,
    pub join_future: Shared<BoxFuture<'static, Exit<E, A>>>,
    pub token: CancellationToken,
}
impl<E, A> Fiber<E, A>
where
    E: Send + Sync + Clone + 'static,
    A: Send + Sync + Clone + 'static,
{
    /// Waits for the fiber's exit.
    pub async fn join(self) -> Exit<E, A> {
        let Fiber { join_future, .. } = self;
        join_future.await
    }

    /// Cancels the fiber's token, then waits for (and returns) its exit.
    pub async fn interrupt(self) -> Exit<E, A> {
        let Fiber {
            token, join_future, ..
        } = self;
        token.cancel();
        join_future.await
    }
}
/// Runtime context passed to every effect when it runs.
#[derive(Clone)]
pub struct Ctx {
    /// Cancellation token of the fiber running the effect.
    pub token: CancellationToken,
    /// Scope where finalizers are registered.
    pub scope: ScopeHandle,
    /// Identifier of the running fiber.
    pub fiber_id: FiberId,
    /// Storage behind `FiberRef`, keyed by each `FiberRef`'s id.
    pub locals: Arc<TokioMutex<HashMap<usize, Arc<dyn Any + Send + Sync>>>>,
    /// Time source used by time-based effects such as `Effect::sleep`.
    pub clock: Arc<dyn Clock>,
}
impl Ctx {
    /// Builds a root context: fresh token, empty scope, fiber id 0, empty locals, live clock.
    pub fn new() -> Self {
        let clock: Arc<dyn Clock> = Arc::new(LiveClock);
        let locals = Arc::new(TokioMutex::new(HashMap::new()));
        Self {
            fiber_id: FiberId(0),
            token: CancellationToken::new(),
            scope: ScopeHandle::new(),
            locals,
            clock,
        }
    }
}
impl Default for Ctx {
    fn default() -> Self {
        Self::new()
    }
}
/// A typed slot in the fiber-locals map carried by [`Ctx`].
///
/// Each `FiberRef` receives a distinct `id` from a single process-wide counter (a `static`
/// inside a generic impl is one item shared by every `T`). The value is stored in
/// `ctx.locals` under that id; until something is stored there, `get` yields a clone of
/// `initial`.
#[derive(Clone)]
pub struct FiberRef<T> {
    id: usize,
    initial: Arc<T>,
}
impl<T: Send + Sync + 'static + Clone> FiberRef<T> {
    /// Creates a new slot with a fresh id and the given fallback value.
    pub fn new(initial: T) -> Self {
        static NEXT_ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        let id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        Self {
            id,
            initial: Arc::new(initial),
        }
    }

    /// Reads this slot from the running context's locals, falling back to the initial value.
    ///
    /// # Panics
    /// Panics if the entry stored under this slot's id is not a `T`. `set` on this
    /// `FiberRef<T>` only ever stores a `T`, so this indicates the locals map was written
    /// under this id by something else.
    pub fn get(&self) -> Effect<(), (), T> {
        let id = self.id;
        let initial = self.initial.clone();
        Effect::access_async(move |_, ctx| {
            let initial = initial.clone();
            async move {
                let locals = ctx.locals.lock().await;
                match locals.get(&id) {
                    Some(val) => val
                        .downcast_ref::<T>()
                        .cloned()
                        .expect("FiberRef slot holds a value of an unexpected type"),
                    None => (*initial).clone(),
                }
            }
        })
    }

    /// Stores `value` in the running context's locals under this slot's id.
    pub fn set(&self, value: T) -> Effect<(), (), ()> {
        let id = self.id;
        Effect::<(), (), ()>::access_async(move |_, ctx| async move {
            let mut locals = ctx.locals.lock().await;
            locals.insert(id, Arc::new(value));
        })
    }
}
#[derive(Clone)]
pub struct Ref<A> {
value: Arc<TokioMutex<A>>,
}
impl<A> Ref<A>
where
A: Send + Sync + 'static + Clone,
{
pub fn new(value: A) -> Self {
Self {
value: Arc::new(TokioMutex::new(value)),
}
}
pub fn get(&self) -> Effect<(), (), A> {
let value = self.value.clone();
Effect::<(), (), A>::async_effect(move || async move {
let guard = value.lock().await;
guard.clone()
})
}
pub fn set(&self, new_value: A) -> Effect<(), (), ()> {
let value = self.value.clone();
Effect::<(), (), ()>::async_effect(move || async move {
let mut guard = value.lock().await;
*guard = new_value;
})
}
pub fn update<F>(&self, f: F) -> Effect<(), (), A>
where
F: FnOnce(A) -> A + Send + Sync + 'static + Clone,
{
let value = self.value.clone();
Effect::<(), (), A>::async_effect(move || async move {
let mut guard = value.lock().await;
let new_val = f(guard.clone());
*guard = new_val.clone();
new_val
})
}
}
/// Sending half of a one-shot channel; one is registered per pending `await_result` call.
type Waiter<E, A> = tokio::sync::oneshot::Sender<Exit<E, A>>;
/// A write-once cell: it can be completed once and awaited any number of times.
///
/// `state` holds the final [`Exit`] once set; `waiters` holds the senders of callers that
/// started awaiting before completion. Both `complete` and `await_result` take the locks in
/// the order `state` then `waiters`, so they cannot deadlock each other.
#[derive(Clone)]
pub struct Deferred<E, A> {
    state: Arc<TokioMutex<Option<Exit<E, A>>>>,
    waiters: Arc<TokioMutex<Vec<Waiter<E, A>>>>,
}
impl<E, A> Default for Deferred<E, A>
where
    E: Send + Sync + Clone + 'static,
    A: Send + Sync + Clone + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}
impl<E, A> Deferred<E, A>
where
    E: Send + Sync + Clone + 'static,
    A: Send + Sync + Clone + 'static,
{
    /// Creates an uncompleted deferred with no waiters.
    pub fn new() -> Self {
        Self {
            state: Arc::new(TokioMutex::new(None)),
            waiters: Arc::new(TokioMutex::new(Vec::new())),
        }
    }

    /// Stores `exit` if nothing has been stored yet and sends it to every pending waiter.
    ///
    /// Yields `true` if this call completed the deferred, `false` if it was already
    /// completed (in which case `exit` is discarded).
    pub fn complete(&self, exit: Exit<E, A>) -> Effect<(), (), bool> {
        let state = self.state.clone();
        let waiters = self.waiters.clone();
        Effect::<(), (), bool>::async_effect(move || async move {
            let mut guard = state.lock().await;
            if guard.is_some() {
                false
            } else {
                *guard = Some(exit.clone());
                let mut waiters = waiters.lock().await;
                for sender in waiters.drain(..) {
                    // A waiter whose receiver is gone simply misses the value.
                    let _ = sender.send(exit.clone());
                }
                true
            }
        })
    }

    /// Completes the deferred with a success value; see [`Deferred::complete`].
    pub fn succeed(&self, value: A) -> Effect<(), (), bool> {
        self.complete(Exit::Success(value))
    }

    /// Completes the deferred with a typed failure; see [`Deferred::complete`].
    pub fn fail(&self, error: E) -> Effect<(), (), bool> {
        self.complete(Exit::Failure(Cause::Fail(error)))
    }

    /// Waits for the deferred's exit and re-raises it: succeeds with the stored `A`, or fails
    /// with the stored cause.
    ///
    /// If the registered sender is dropped without sending, fails with a `Cause::Die` whose
    /// payload is the string `"Sender dropped"`.
    pub fn await_result(&self) -> Effect<(), E, A> {
        let state = self.state.clone();
        let waiters = self.waiters.clone();
        Effect::<(), E, Exit<E, A>>::async_effect(move || async move {
            let rx = {
                let guard = state.lock().await;
                if let Some(exit) = guard.as_ref() {
                    return exit.clone();
                }
                let (tx, rx) = tokio::sync::oneshot::channel();
                // Registered while `state` is still locked, so `complete` (which locks
                // `state` first) cannot run between the check above and this push.
                waiters.lock().await.push(tx);
                rx
            };
            rx.await.unwrap_or_else(|_| {
                Exit::Failure(Cause::Die(Arc::new("Sender dropped".to_string())))
            })
        })
        .flat_map(Effect::done)
    }
}
#[derive(Clone)]
pub struct Queue<A> {
sender: tokio::sync::mpsc::Sender<A>,
receiver: Arc<TokioMutex<tokio::sync::mpsc::Receiver<A>>>,
}
impl<A> Queue<A>
where
A: Send + Sync + 'static + Clone,
{
pub fn new(capacity: usize) -> Self {
let (sender, receiver) = tokio::sync::mpsc::channel(capacity);
Self {
sender,
receiver: Arc::new(TokioMutex::new(receiver)),
}
}
pub fn offer(&self, value: A) -> Effect<(), (), bool> {
let sender = self.sender.clone();
Effect::<(), (), bool>::async_effect(
move || async move { (sender.send(value).await).is_ok() },
)
}
pub fn take(&self) -> Effect<(), (), Option<A>> {
let receiver = self.receiver.clone();
Effect::<(), (), Option<A>>::async_effect(move || async move {
let mut options = receiver.lock().await;
options.recv().await
})
}
}
/// Why an effect did not succeed.
#[derive(Clone, Debug)]
pub enum Cause<E> {
/// Typed failure, e.g. produced by `Effect::fail`.
Fail(E),
/// Defect carrying an arbitrary type-erased payload.
Die(Defect),
/// The effect was interrupted (e.g. `fork` reports this when the child's token is cancelled).
Interrupt,
}
/// Coarse outcome passed to scope finalizers when a scope is closed.
#[derive(Clone, Copy, Debug)]
pub enum ScopeExit {
    /// The scoped work succeeded.
    Success,
    /// The scoped work failed with a typed error or a defect.
    Failure,
    /// The scoped work was interrupted.
    Interrupt,
}
/// One-shot cleanup action; receives how the scope ended.
type Finalizer = Box<dyn FnOnce(ScopeExit) -> BoxFuture<'static, ()> + Send>;
/// Handle to a list of finalizers; clones share the same list.
#[derive(Clone)]
pub struct ScopeHandle {
finalizers: Arc<TokioMutex<Vec<Finalizer>>>,
}
impl<E, A> From<&Exit<E, A>> for ScopeExit {
    /// Summarises an [`Exit`]: success, interruption, or (for any other cause) failure.
    fn from(exit: &Exit<E, A>) -> Self {
        match exit {
            Exit::Success(_) => Self::Success,
            Exit::Failure(cause) => match cause {
                Cause::Interrupt => Self::Interrupt,
                Cause::Fail(_) | Cause::Die(_) => Self::Failure,
            },
        }
    }
}
impl Default for ScopeHandle {
    fn default() -> Self {
        Self::new()
    }
}
impl ScopeHandle {
    /// Creates a scope with no registered finalizers.
    pub fn new() -> Self {
        Self {
            finalizers: Arc::new(TokioMutex::new(Vec::new())),
        }
    }

    /// Registers `f` to be run (with the scope's exit) when the scope is closed.
    pub async fn add_finalizer<F>(&self, f: F)
    where
        F: FnOnce(ScopeExit) -> BoxFuture<'static, ()> + Send + 'static,
    {
        let mut finalizers = self.finalizers.lock().await;
        finalizers.push(Box::new(f));
    }

    /// Removes and runs every registered finalizer, most recently added first, passing
    /// `exit` to each.
    ///
    /// The list is locked only long enough to pop the next entry, never while a finalizer
    /// runs. A finalizer may therefore register further finalizers on this same scope
    /// without deadlocking; those late registrations are popped and run by this same call.
    pub async fn close(&self, exit: ScopeExit) {
        loop {
            // The guard is a temporary of this statement, so it is released before the
            // finalizer below is awaited.
            let next = self.finalizers.lock().await.pop();
            match next {
                Some(finalizer) => finalizer(exit).await,
                None => break,
            }
        }
    }
}
/// Final outcome of running an effect.
#[derive(Debug, Clone)]
pub enum Exit<E, A> {
/// The effect produced a value.
Success(A),
/// The effect did not produce a value; see [`Cause`] for why.
Failure(Cause<E>),
}
impl<E> Cause<E> {
pub fn map<E2, F>(self, f: &F) -> Cause<E2>
where
F: Fn(E) -> E2,
{
match self {
Cause::Fail(e) => Cause::Fail(f(e)),
Cause::Die(d) => Cause::Die(d),
Cause::Interrupt => Cause::Interrupt,
}
}
}
/// Type-erased payload carried by [`Cause::Die`].
pub type Defect = Arc<dyn std::any::Any + Send + Sync>;
/// Signature of an effect's underlying function: given an environment and a context,
/// produce a boxed future of the effect's [`Exit`].
type EffectFn<R, E, A> = dyn Fn(EnvRef<R>, Ctx) -> BoxFuture<'static, Exit<E, A>> + Send + Sync;
/// Description of an async computation that needs an environment `R` and ends in an
/// [`Exit<E, A>`].
///
/// Nothing runs until `inner` is called with an environment and a [`Ctx`]. Because `inner`
/// is a `Fn`, the same effect can be run any number of times.
pub struct Effect<R, E, A> {
/// Shared function, so cloning an effect only bumps a reference count.
pub(crate) inner: Arc<EffectFn<R, E, A>>,
}
impl<R, E, A> Clone for Effect<R, E, A> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<R, E, A> Effect<R, E, A>
where
R: Clone + Send + Sync + 'static,
E: Send + Sync + 'static,
A: Send + Sync + 'static,
{
pub fn succeed(value: A) -> Self
where
A: Send + Sync + Clone,
{
Self {
inner: Arc::new(move |_, _| {
let value = value.clone();
Box::pin(async move { Exit::Success(value) })
}),
}
}
pub fn fail(error: E) -> Self
where
E: Send + Sync + Clone,
{
Self {
inner: Arc::new(move |_, _| {
let error = error.clone();
Box::pin(async move { Exit::Failure(Cause::Fail(error)) })
}),
}
}
pub fn sync<F>(f: F) -> Self
where
F: FnOnce() -> A + Send + Sync + 'static + Clone,
A: Send,
{
Self {
inner: Arc::new(move |_, _| {
let f = f.clone();
Box::pin(async move { Exit::Success(f()) })
}),
}
}
pub fn async_effect<F, Fut>(f: F) -> Self
where
F: FnOnce() -> Fut + Send + Sync + 'static + Clone,
Fut: futures::Future<Output = A> + Send + 'static,
A: Send,
{
Self {
inner: Arc::new(move |_, _| {
let f = f.clone();
Box::pin(async move { Exit::Success(f().await) })
}),
}
}
pub fn sleep(duration: Duration) -> Self
where
A: From<()>, {
Self {
inner: Arc::new(move |_, ctx| {
Box::pin(async move {
ctx.clock.sleep(duration).await;
Exit::Success(A::from(()))
})
}),
}
}
/// Increments the counter `name` (with `labels`) by one each time this effect succeeds.
///
/// Failures do not touch the counter; the success value passes through unchanged.
pub fn with_metric_increment(self, name: &str, labels: Vec<MetricLabel>) -> Self
where
    A: Send + Sync + 'static + Clone,
    E: Send + Sync + 'static + Clone,
{
    let counter_name = name.to_owned();
    let bump = move |value: A| {
        REGISTRY
            .get_counter(&counter_name, labels.clone())
            .increment(1);
        value
    };
    self.map(bump)
}
pub fn with_metric_duration(self, name: &str, labels: Vec<MetricLabel>) -> Self
where
A: Send + Sync + 'static + Clone,
E: Send + Sync + 'static + Clone,
R: 'static + Clone + Send + Sync,
{
self.timed(name, labels)
}
pub fn timed(self, name: &str, labels: Vec<MetricLabel>) -> Self
where
A: Send + Sync + 'static + Clone,
E: Send + Sync + 'static + Clone,
R: 'static + Clone + Send + Sync,
{
let name = name.to_string();
Effect::sync(Instant::now).flat_map(move |start| {
let labels = labels.clone();
let name = name.clone();
self.clone().map(move |res| {
let elapsed = start.elapsed().as_secs_f64();
REGISTRY
.get_histogram(&name, labels, vec![0.001, 0.01, 0.1, 1.0, 10.0])
.record(elapsed);
res
})
})
}
/// Effect that passes the environment and context to a clone of `f`, awaits the returned
/// future, and succeeds with its output.
pub fn access_async<F, Fut>(f: F) -> Self
where
    R: Send + Sync,
    F: FnOnce(EnvRef<R>, Ctx) -> Fut + Send + Sync + 'static + Clone,
    Fut: futures::Future<Output = A> + Send + 'static,
    A: Send,
{
    Self {
        inner: Arc::new(move |env: EnvRef<R>, ctx: Ctx| {
            let body = f.clone();
            async move {
                let output = body(env, ctx).await;
                Exit::Success(output)
            }
            .boxed()
        }),
    }
}
pub fn provide(self, env: R) -> Effect<(), E, A>
where
R: Clone + Send + Sync + 'static,
E: Send + Sync + 'static,
A: Send + Sync + 'static,
{
Effect {
inner: Arc::new(move |_, ctx| {
let effect = self.clone();
let env = env.clone();
Box::pin(async move { (effect.inner)(EnvRef { value: env }, ctx).await })
}),
}
}
pub fn done(exit: Exit<E, A>) -> Self
where
E: Send + Sync + Clone,
A: Send + Sync + Clone,
{
Self {
inner: Arc::new(move |_, _| {
let exit = exit.clone();
Box::pin(async move { exit })
}),
}
}
/// Transforms the success value with `f`; failures pass through untouched.
pub fn map<B, F>(self, f: F) -> Effect<R, E, B>
where
    F: FnOnce(A) -> B + Send + Sync + 'static + Clone,
    B: Send + Sync + 'static + Clone,
    R: Clone + Send + Sync + 'static,
    E: Send + Sync + 'static,
    A: Send + Sync + 'static,
{
    let lift = move |input: A| {
        let output = f(input);
        Effect::<R, E, B>::succeed(output)
    };
    self.flat_map(lift)
}
/// Converts typed failures with `f` (via [`Cause::map`]); successes, defects and
/// interruptions pass through.
pub fn map_error<E2, F>(self, f: F) -> Effect<R, E2, A>
where
    F: Fn(E) -> E2 + Send + Sync + 'static + Clone,
    R: Send + Sync + 'static,
    A: Send + Sync + 'static,
    E: Send + Sync + 'static,
    E2: Send + Sync + 'static,
{
    let Effect { inner } = self;
    Effect {
        inner: Arc::new(move |env: EnvRef<R>, ctx: Ctx| {
            let run = Arc::clone(&inner);
            let convert = f.clone();
            async move {
                let exit = run(env, ctx).await;
                match exit {
                    Exit::Failure(cause) => Exit::Failure(cause.map(&convert)),
                    Exit::Success(value) => Exit::Success(value),
                }
            }
            .boxed()
        }),
    }
}
/// Runs this effect, then feeds its success value to `f` and runs the resulting effect
/// with the same environment and context. A failure short-circuits without calling `f`.
pub fn flat_map<B, F>(self, f: F) -> Effect<R, E, B>
where
    F: FnOnce(A) -> Effect<R, E, B> + Send + Sync + 'static + Clone,
    B: Send + 'static,
    R: Clone + Send + Sync + 'static,
    E: Send + Sync + 'static,
    A: Send + Sync,
{
    let Effect { inner: first } = self;
    Effect {
        inner: Arc::new(move |env: EnvRef<R>, ctx: Ctx| {
            let first = Arc::clone(&first);
            let next = f.clone();
            async move {
                let value = match first(env.clone(), ctx.clone()).await {
                    Exit::Success(value) => value,
                    Exit::Failure(cause) => return Exit::Failure(cause),
                };
                (next(value).inner)(env, ctx).await
            }
            .boxed()
        }),
    }
}
/// Waits `duration` on the context's clock before running this effect.
pub fn delay(self, duration: Duration) -> Effect<R, E, A>
where
    R: Clone + Send + Sync + 'static,
    E: Send + Sync + 'static,
    A: Send + Sync + 'static,
{
    let pause = Effect::<R, E, ()>::sleep(duration);
    pause.flat_map(move |()| self)
}
pub fn trace(self, name: &'static str) -> Effect<R, E, A>
where
R: Clone + Send + Sync + 'static,
E: Send + Sync + 'static,
A: Send + Sync + 'static,
{
Effect {
inner: Arc::new(move |env, ctx| {
let effect = self.clone();
let span = tracing::info_span!("effect", name = name);
async move { (effect.inner)(env, ctx).await }
.instrument(span)
.boxed()
}),
}
}
/// Runs this effect after registering an interrupt-only finalizer on `ctx.scope`.
///
/// The finalizer is added *before* the effect runs and is never removed. Whenever that scope
/// is closed with `ScopeExit::Interrupt`, it runs `cleanup()` with the environment converted
/// via `R2::from` and a clone of the current context. Any other scope exit makes it a no-op.
/// The cleanup effect's own exit is discarded.
pub fn on_interrupt<F, R2, E2, X>(self, cleanup: F) -> Effect<R, E, A>
where
F: Fn() -> Effect<R2, E2, X> + Send + Sync + 'static + Clone,
R2: From<R> + Send + Sync + 'static + Clone,
E2: Send + Sync + 'static,
X: Send + Sync + 'static,
{
Effect {
inner: Arc::new(move |env, ctx| {
let effect = self.clone();
let cleanup = cleanup.clone();
Box::pin(async move {
// Converted eagerly, on every run, even if the finalizer never fires.
let env_for_cleanup = R2::from(env.value.clone());
let ctx_for_finalizer = ctx.clone();
let finalizer = move |exit: ScopeExit| {
let cleanup = cleanup.clone();
let env = env_for_cleanup.clone();
let ctx = ctx_for_finalizer.clone();
async move {
// Only interruption triggers the cleanup; success/failure closes skip it.
if let ScopeExit::Interrupt = exit {
let _ = (cleanup().inner)(EnvRef { value: env }, ctx).await;
}
}
.boxed()
};
ctx.scope.add_finalizer(finalizer).await;
(effect.inner)(env, ctx).await
})
}),
}
}
/// Treats this effect as the *acquire* step of a resource.
///
/// If acquisition succeeds with `a`, a finalizer is registered on `ctx.scope`. When that
/// scope is closed, the finalizer runs `release(a, exit)` with the scope's exit, the
/// environment converted via `R2::from`, and a clone of the context. If acquisition fails,
/// nothing is registered. The release effect's own exit is discarded. The acquire exit is
/// returned unchanged.
pub fn acquire_release<F, R2, E2, X>(self, release: F) -> Effect<R, E, A>
where
F: FnOnce(A, ScopeExit) -> Effect<R2, E2, X> + Send + Sync + 'static + Clone,
R: Clone + Send + Sync + 'static,
R2: From<R> + Send + Sync + 'static + Clone,
E: Send + Sync + 'static,
A: Send + Sync + Clone + 'static,
X: Send + Sync + 'static,
E2: Send + Sync + 'static,
{
Effect {
inner: Arc::new(move |env: EnvRef<R>, ctx: Ctx| {
let acquire = self.clone();
let release = release.clone();
let env_for_release = R2::from(env.value.clone());
Box::pin(async move {
let ctx_clone = ctx.clone();
let finalizer_env = env_for_release.clone();
let result: Exit<E, A> = (acquire.inner)(env.clone(), ctx.clone()).await;
if let Exit::Success(a) = &result {
// The finalizer owns its own copy of the acquired value.
let a_for_release = a.clone();
let release = release.clone();
let finalizer = move |exit| {
// `release` is invoked when the scope closes, not at registration time.
let release_effect = release(a_for_release, exit);
async move {
let _ = (release_effect.inner)(
EnvRef {
value: finalizer_env,
},
ctx_clone,
)
.await;
}
.boxed()
};
ctx.scope.add_finalizer(finalizer).await;
}
result
})
}),
}
}
pub fn fork(self) -> Effect<R, E, Fiber<E, A>>
where
R: Clone + Send + Sync + 'static,
E: Send + Sync + Clone + 'static,
A: Send + Sync + Clone + 'static,
{
Effect {
inner: Arc::new(move |env, ctx| {
let effect = self.clone();
let locals = ctx.locals.clone();
Box::pin(async move {
let child_token = CancellationToken::new();
let child_scope = ScopeHandle::new();
let child_ctx = Ctx {
token: child_token.clone(),
scope: child_scope.clone(),
fiber_id: FiberId(0), locals: locals.clone(), clock: ctx.clock.clone(), };
let env_for_child = EnvRef {
value: env.value.clone(),
};
let fut = async move {
let result = tokio::select! {
res = (effect.inner)(env_for_child, child_ctx.clone()) => res,
_ = child_ctx.token.cancelled() => Exit::Failure(Cause::Interrupt),
};
let scope_exit = match &result {
Exit::Success(_) => ScopeExit::Success,
Exit::Failure(Cause::Interrupt) => ScopeExit::Interrupt,
Exit::Failure(_) => ScopeExit::Failure,
};
child_ctx.scope.close(scope_exit).await;
result
};
let future = fut.boxed().shared();
let fiber = Fiber {
id: FiberId(0),
join_future: future,
token: child_token,
};
Exit::Success(fiber)
})
}),
}
}
/// Runs `self` and `other` as forked fibers and pairs their results.
///
/// Both joins are polled concurrently. If the first fiber to finish failed, the other
/// fiber is interrupted and that failure is returned. Otherwise the other fiber is awaited;
/// its failure (if any) is returned, else `(a, b)`.
pub fn zip_par<B>(self, other: Effect<R, E, B>) -> Effect<R, E, (A, B)>
where
R: Clone + Send + Sync + 'static,
E: Send + Sync + Clone + 'static,
A: Send + Sync + Clone + 'static,
B: Send + Sync + Clone + 'static,
{
self.fork().flat_map(move |f1: Fiber<E, A>| {
other.clone().fork().flat_map(move |f2: Fiber<E, B>| {
Effect::async_effect(move || async move {
// Clones feed the race below; the originals are kept to finish or
// interrupt the loser afterwards (the join future is shared).
let f1a = f1.clone();
let f2a = f2.clone();
tokio::select! {
e1 = f1a.join() => {
match e1 {
Exit::Success(a) => {
match f2.join().await {
Exit::Success(b) => Exit::Success((a, b)),
Exit::Failure(c) => Exit::Failure(c),
}
}
Exit::Failure(c) => {
let _ = f2.interrupt().await;
Exit::Failure(c)
}
}
}
e2 = f2a.join() => {
match e2 {
Exit::Success(b) => {
match f1.join().await {
Exit::Success(a) => Exit::Success((a, b)),
Exit::Failure(c) => Exit::Failure(c),
}
}
Exit::Failure(c) => {
let _ = f1.interrupt().await;
Exit::Failure(c)
}
}
}
}
})
// Re-raise the combined Exit as this effect's own outcome.
.flat_map(Effect::done)
})
})
}
/// Runs `self` and `other` as forked fibers and returns the exit of whichever finishes
/// first, success or failure, after interrupting (and awaiting) the other one.
pub fn race(self, other: Effect<R, E, A>) -> Effect<R, E, A>
where
R: Clone + Send + Sync + 'static,
E: Send + Sync + Clone + 'static,
A: Send + Sync + Clone + 'static,
{
self.fork().flat_map(move |f1: Fiber<E, A>| {
other.clone().fork().flat_map(move |f2: Fiber<E, A>| {
Effect::async_effect(move || async move {
// Clones feed the race; the originals are used to interrupt the loser.
let f1a = f1.clone();
let f2a = f2.clone();
tokio::select! {
e1 = f1a.join() => {
let _ = f2.interrupt().await;
e1
}
e2 = f2a.join() => {
let _ = f1.interrupt().await;
e2
}
}
})
// Re-raise the winner's Exit as this effect's own outcome.
.flat_map(Effect::done)
})
})
}
pub fn collect_all_par<I>(effects: I) -> Effect<R, E, Vec<A>>
where
I: IntoIterator<Item = Effect<R, E, A>>,
I::IntoIter: Send,
R: Clone + Send + Sync + 'static,
E: Send + Sync + Clone + 'static,
A: Send + Sync + Clone + 'static,
{
let effects: Vec<_> = effects.into_iter().collect();
effects
.into_iter()
.fold(Effect::<R, E, Vec<A>>::succeed(Vec::new()), |acc, eff| {
acc.zip_par(eff).map(|(mut list, item): (Vec<A>, A)| {
list.push(item);
list
})
})
}
}