use super::{ArcAsyncDerived, AsyncDerived};
use crate::{
graph::{AnySource, ToAnySource},
owner::Storage,
signal::guards::{AsyncPlain, Mapped, ReadGuard},
traits::{DefinedAt, Track},
unwrap_signal,
};
use futures::pin_mut;
use or_poisoned::OrPoisoned;
use std::{
future::{Future, IntoFuture},
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock,
},
task::{Context, Poll, Waker},
};
/// Read guard produced by awaiting an [`AsyncDerivedRefFuture`].
///
/// It wraps an [`AsyncPlain`] guard over the stored `Option<T>` and is [`Mapped`] so that it
/// dereferences to the inner `T`.
pub type AsyncDerivedGuard<T> = ReadGuard<T, Mapped<AsyncPlain<Option<T>>, T>>;
/// Future that resolves to `()` once the shared `loading` flag is observed to be `false`.
///
/// It never reads the derived value itself; it only waits for loading to finish.
pub struct AsyncDerivedReadyFuture {
    /// Reactive source that is tracked each time the future is polled.
    pub(crate) source: AnySource,
    /// Shared flag that is checked on every poll; the future is pending while it is `true`.
    pub(crate) loading: Arc<AtomicBool>,
    /// Shared list into which pending polls push their task's waker.
    pub(crate) wakers: Arc<RwLock<Vec<Waker>>>,
}
impl Future for AsyncDerivedReadyFuture {
    type Output = ();

    /// Completes once `loading` is observed to be `false`.
    ///
    /// While loading, the current task's waker is pushed onto `wakers` and
    /// `Poll::Pending` is returned. The reactive `source` is tracked on every poll.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.source.track();
        if self.loading.load(Ordering::Relaxed) {
            let mut wakers = self.wakers.write().or_poisoned();
            // Re-check while holding the waker-list lock. Checking and then pushing
            // without the lock leaves a window in which the loader can clear
            // `loading` and drain `wakers` before our waker is pushed, so this task
            // would never be woken again.
            // NOTE(review): assumes the loader clears `loading` before it drains
            // `wakers` — confirm against the writer side.
            if self.loading.load(Ordering::Relaxed) {
                wakers.push(cx.waker().clone());
                return Poll::Pending;
            }
        }
        Poll::Ready(())
    }
}
impl<T> IntoFuture for ArcAsyncDerived<T>
where
    T: Clone + 'static,
{
    type Output = T;
    type IntoFuture = AsyncDerivedFuture<T>;

    /// Turns this signal into a future that yields a clone of its value.
    ///
    /// The returned future shares (via `Arc` clones) this signal's value lock,
    /// loading flag, and waker list, and tracks it as a reactive source.
    fn into_future(self) -> Self::IntoFuture {
        let source = self.to_any_source();
        let value = Arc::clone(&self.value);
        let loading = Arc::clone(&self.loading);
        let wakers = Arc::clone(&self.wakers);
        AsyncDerivedFuture {
            source,
            value,
            loading,
            wakers,
        }
    }
}
impl<T, S> IntoFuture for AsyncDerived<T, S>
where
    T: Clone + 'static,
    S: Storage<ArcAsyncDerived<T>>,
{
    type Output = T;
    type IntoFuture = AsyncDerivedFuture<T>;

    /// Looks up the underlying [`ArcAsyncDerived`] and delegates to its
    /// `IntoFuture` implementation.
    ///
    /// # Panics
    /// Panics via `unwrap_signal!` when `try_get_value` returns `None`
    /// (presumably because the signal has already been disposed).
    #[track_caller]
    fn into_future(self) -> Self::IntoFuture {
        self.inner
            .try_get_value()
            .unwrap_or_else(unwrap_signal!(self))
            .into_future()
    }
}
/// Future returned by awaiting an async derived signal; it resolves to a clone of the
/// current value once the shared `loading` flag is `false` and the value can be read.
pub struct AsyncDerivedFuture<T> {
    /// Reactive source that is tracked each time the future is polled.
    source: AnySource,
    /// Shared async lock around the (possibly not-yet-set) value.
    value: Arc<async_lock::RwLock<Option<T>>>,
    /// Shared flag; the future is pending while it is `true`.
    loading: Arc<AtomicBool>,
    /// Shared list into which pending polls push their task's waker.
    wakers: Arc<RwLock<Vec<Waker>>>,
}
impl<T> Future for AsyncDerivedFuture<T>
where
    T: Clone + 'static,
{
    type Output = T;

    /// Resolves to a clone of the current value once it is no longer loading.
    ///
    /// While `loading` is `true`, the task's waker is pushed onto `wakers` and
    /// `Poll::Pending` is returned. The reactive `source` is tracked on every poll.
    ///
    /// # Panics
    /// Panics if loading has finished but the stored value is still `None`.
    #[track_caller]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.source.track();

        if self.loading.load(Ordering::Relaxed) {
            let mut wakers = self.wakers.write().or_poisoned();
            // Re-check while holding the waker-list lock so the loader cannot clear
            // `loading` and drain `wakers` between our check and our push (which
            // would lose this wakeup).
            // NOTE(review): assumes the loader clears `loading` before it drains
            // `wakers` — confirm against the writer side.
            if self.loading.load(Ordering::Relaxed) {
                wakers.push(cx.waker().clone());
                return Poll::Pending;
            }
        }

        let value = self.value.read_arc();
        pin_mut!(value);
        match value.poll(cx) {
            Poll::Ready(guard) => Poll::Ready(
                guard
                    .as_ref()
                    .expect("async derived value is None after loading finished")
                    .clone(),
            ),
            Poll::Pending => {
                // The read future is a local that is dropped when this call returns,
                // and its lock listener (holding our waker) goes with it, so nothing
                // would wake us when the lock is released. Ask to be polled again.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
impl<T: 'static> ArcAsyncDerived<T> {
    /// Returns a future that resolves to an [`AsyncDerivedGuard`] over the value
    /// instead of a clone of it, so `T` does not need to implement `Clone`.
    ///
    /// The future shares (via `Arc` clones) this signal's value lock, loading flag,
    /// and waker list, and tracks it as a reactive source.
    #[track_caller]
    pub fn by_ref(&self) -> AsyncDerivedRefFuture<T> {
        let source = self.to_any_source();
        let value = Arc::clone(&self.value);
        let loading = Arc::clone(&self.loading);
        let wakers = Arc::clone(&self.wakers);
        AsyncDerivedRefFuture {
            source,
            value,
            loading,
            wakers,
        }
    }
}
impl<T, S> AsyncDerived<T, S>
where
    T: 'static,
    S: Storage<ArcAsyncDerived<T>>,
{
    /// Looks up the underlying [`ArcAsyncDerived`] and delegates to its `by_ref`.
    ///
    /// # Panics
    /// Panics via `unwrap_signal!` when `try_get_value` returns `None`
    /// (presumably because the signal has already been disposed).
    #[track_caller]
    pub fn by_ref(&self) -> AsyncDerivedRefFuture<T> {
        self.inner
            .try_get_value()
            .unwrap_or_else(unwrap_signal!(self))
            .by_ref()
    }
}
/// Future returned by `by_ref`; it resolves to an [`AsyncDerivedGuard`] that borrows the
/// value through a read lock rather than cloning it.
pub struct AsyncDerivedRefFuture<T> {
    /// Reactive source that is tracked each time the future is polled.
    source: AnySource,
    /// Shared async lock around the (possibly not-yet-set) value.
    value: Arc<async_lock::RwLock<Option<T>>>,
    /// Shared flag; the future is pending while it is `true`.
    loading: Arc<AtomicBool>,
    /// Shared list into which pending polls push their task's waker.
    wakers: Arc<RwLock<Vec<Waker>>>,
}
impl<T> Future for AsyncDerivedRefFuture<T>
where
    T: 'static,
{
    type Output = AsyncDerivedGuard<T>;

    /// Resolves to a read guard over the value once it is no longer loading.
    ///
    /// While `loading` is `true`, the task's waker is pushed onto `wakers` and
    /// `Poll::Pending` is returned. The reactive `source` is tracked on every poll.
    /// Dereferencing the returned guard panics if the stored value is `None`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.source.track();

        if self.loading.load(Ordering::Relaxed) {
            let mut wakers = self.wakers.write().or_poisoned();
            // Re-check while holding the waker-list lock so the loader cannot clear
            // `loading` and drain `wakers` between our check and our push (which
            // would lose this wakeup).
            // NOTE(review): assumes the loader clears `loading` before it drains
            // `wakers` — confirm against the writer side.
            if self.loading.load(Ordering::Relaxed) {
                wakers.push(cx.waker().clone());
                return Poll::Pending;
            }
        }

        let value = self.value.read_arc();
        pin_mut!(value);
        match value.poll(cx) {
            Poll::Ready(guard) => Poll::Ready(ReadGuard::new(
                Mapped::new_with_guard(AsyncPlain { guard }, |guard| {
                    guard
                        .as_ref()
                        .expect("async derived value is None after loading finished")
                }),
            )),
            Poll::Pending => {
                // The read future is a local that is dropped when this call returns,
                // and its lock listener (holding our waker) goes with it, so nothing
                // would wake us when the lock is released. Ask to be polled again.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}