use futures_core::{Future, Stream};
use futures_util::ready;
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
/// Tracks a set of outstanding [`Handle`]s through a shared semaphore.
///
/// Each call to `add_handle` adds one permit to the semaphore and hands it
/// out wrapped in a [`Handle`]; `initialized` then waits until it can acquire
/// as many permits as were issued.
#[derive(Debug)]
pub struct Initialized {
// Semaphore shared with every issued handle (each handle owns one permit).
semaphore: Arc<Semaphore>,
// Number of handles created via `add_handle`.
issued: u32,
}
/// A single owned semaphore permit issued by `Initialized::add_handle`.
///
/// The permit is held for as long as the handle is alive; marked `#[must_use]`
/// so that a handle is not discarded by accident.
#[derive(Debug)]
#[must_use]
pub struct Handle(OwnedSemaphorePermit);
pin_project_lite::pin_project! {
/// Wraps an inner future or stream together with a [`Handle`].
///
/// The `Future` and `Stream` implementations below drop the handle as soon
/// as the inner value yields `Poll::Ready`.
#[derive(Debug)]
pub struct ReleasesOnReady<T> {
// The wrapped future/stream; structurally pinned so it can be polled.
#[pin]
inner: T,
// `Some` until the inner value first becomes ready, `None` afterwards.
handle: Option<Handle>,
}
}
impl Default for Initialized {
    /// Builds a tracker with an empty semaphore and no issued handles.
    fn default() -> Self {
        // Start with zero permits: permits only come into existence through
        // `add_handle`, one per handle.
        let semaphore = Arc::new(Semaphore::new(0));
        Self {
            semaphore,
            issued: 0,
        }
    }
}
impl Initialized {
pub fn add_handle(&mut self) -> Handle {
let sem = self.semaphore.clone();
sem.add_permits(1);
let permit = sem
.try_acquire_owned()
.expect("semaphore must issue permit");
self.issued += 1;
Handle(permit)
}
pub async fn initialized(self) {
let _permit = self
.semaphore
.acquire_many(self.issued)
.await
.expect("semaphore cannot be closed");
}
}
impl Handle {
pub fn release_on_ready<T>(self, unready: T) -> ReleasesOnReady<T> {
ReleasesOnReady::new(unready, self)
}
}
impl<T> ReleasesOnReady<T> {
    /// Pairs `inner` with `handle`; the handle is kept until `inner` is ready.
    fn new(inner: T, handle: Handle) -> Self {
        let handle = Some(handle);
        Self { inner, handle }
    }
}
impl<F: Future> Future for ReleasesOnReady<F> {
    type Output = F::Output;

    /// Polls the inner future, dropping the held handle when it completes.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        let this = self.project();
        // Returns early with `Poll::Pending` while the inner future is not
        // ready; the handle stays in place in that case.
        let output = ready!(this.inner.poll(cx));
        // Overwriting the slot drops the handle (if still present).
        *this.handle = None;
        Poll::Ready(output)
    }
}
impl<S: Stream> Stream for ReleasesOnReady<S> {
    type Item = S::Item;

    /// Polls the inner stream, dropping the held handle the first time the
    /// stream yields `Poll::Ready` (whether with an item or with `None`).
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
        let this = self.project();
        // Returns early with `Poll::Pending` while nothing is available; the
        // handle stays in place in that case.
        let item = ready!(this.inner.poll_next(cx));
        // Overwriting the slot drops the handle; later polls find `None` here.
        *this.handle = None;
        Poll::Ready(item)
    }
}
// Unit tests covering the tracker on its own and in combination with
// `ReleasesOnReady` wrapping a future and a stream.
#[cfg(test)]
mod test {
use super::*;
use tokio_stream::wrappers::ReceiverStream;
use tokio_test::{assert_pending, assert_ready, task};
// With no handles issued, `initialized` is ready on the first poll.
#[tokio::test]
async fn initializes() {
let mut init = task::spawn(Initialized::default().initialized());
assert_ready!(init.poll());
}
// `initialized` stays pending until every issued handle has been dropped.
#[tokio::test]
async fn initializes_on_drop() {
let mut init = Initialized::default();
let handle0 = init.add_handle();
let handle1 = init.add_handle();
let mut init = task::spawn(init.initialized());
assert_pending!(init.poll());
drop(handle0);
// One handle is still outstanding, so initialization is not complete.
assert_pending!(init.poll());
drop(handle1);
assert_ready!(init.poll());
}
// A handle attached to a future is released when that future completes.
#[tokio::test]
async fn initializes_on_future() {
let mut init = Initialized::default();
let (tx, mut rx) = {
let (tx, rx) = tokio::sync::oneshot::channel();
let rx = task::spawn(ReleasesOnReady::new(rx, init.add_handle()));
(tx, rx)
};
let mut init = task::spawn(init.initialized());
// Nothing has been sent yet: both the receiver and the tracker are pending.
assert_pending!(rx.poll());
assert_pending!(init.poll());
tx.send("hello").unwrap();
// The receiver becoming ready drops its handle, unblocking the tracker.
assert_ready!(rx.poll()).unwrap();
assert_ready!(init.poll());
}
// A handle attached to a stream is released when the stream first yields.
#[tokio::test]
async fn initializes_on_stream() {
let mut init = Initialized::default();
let (tx, mut rx) = {
let (tx, rx) = tokio::sync::mpsc::channel(2);
let rx = task::spawn(ReleasesOnReady::new(
ReceiverStream::new(rx),
init.add_handle(),
));
(tx, rx)
};
let mut init = task::spawn(init.initialized());
// Nothing has been sent yet: both the stream and the tracker are pending.
assert_pending!(rx.poll_next());
assert_pending!(init.poll());
tx.try_send("hello").unwrap();
// The first ready item drops the handle, unblocking the tracker.
assert_ready!(rx.poll_next());
assert_ready!(init.poll());
}
}