use std::sync::{Arc, Mutex};
use std::task::{self, Poll};
use tokio::sync::oneshot;
use tower_service::Service;
use self::internal::{DitchGuard, SingletonError, SingletonFuture, State};
type BoxError = Box<dyn std::error::Error + Send + Sync>;
#[cfg(docsrs)]
pub use self::internal::Singled;
/// A service wrapper that creates one inner service through `mk_svc` and
/// shares it with every caller.
///
/// The shared slot lives behind an `Arc<Mutex<_>>`, so clones of a
/// `Singleton` observe and update the same cached service.
#[derive(Debug)]
pub struct Singleton<M, Dst>
where
M: Service<Dst>,
{
// The "maker" service; its response is the value that gets cached.
mk_svc: M,
// Shared slot: empty, being made (with queued waiters), or made.
state: Arc<Mutex<State<M::Response>>>,
}
impl<M, Target> Singleton<M, Target>
where
    M: Service<Target>,
    M::Response: Clone,
{
    /// Creates a singleton with nothing cached yet; `mk_svc` is used to
    /// build the shared service when it is first requested.
    pub fn new(mk_svc: M) -> Self {
        let state = Arc::new(Mutex::new(State::Empty));
        Self { mk_svc, state }
    }

    /// Runs `predicate` against the cached service, if one has been made,
    /// and discards it when the predicate returns `false`.
    ///
    /// Nothing happens while the slot is empty or a service is still being
    /// made.
    ///
    /// # Panics
    ///
    /// Panics if the shared state mutex is poisoned.
    pub fn retain<F>(&mut self, mut predicate: F)
    where
        F: FnMut(&mut M::Response) -> bool,
    {
        let mut slot = self.state.lock().unwrap();
        // Decide first, then mutate: only a made service is ever judged.
        let keep = match &mut *slot {
            State::Made(svc) => predicate(svc),
            State::Empty | State::Making(..) => true,
        };
        if !keep {
            *slot = State::Empty;
        }
    }

    /// Returns `true` when no service is cached and none is being made.
    ///
    /// # Panics
    ///
    /// Panics if the shared state mutex is poisoned.
    pub fn is_empty(&self) -> bool {
        let slot = self.state.lock().unwrap();
        matches!(*slot, State::Empty)
    }
}
impl<M, Target> Service<Target> for Singleton<M, Target>
where
    M: Service<Target>,
    M::Response: Clone,
    M::Error: Into<BoxError>,
{
    type Response = internal::Singled<M::Response>;
    type Error = SingletonError;
    type Future = SingletonFuture<M::Future, M::Response>;

    /// Ready immediately unless the slot is empty, in which case readiness
    /// is delegated to the maker service (its error is wrapped in
    /// `SingletonError`).
    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        let locked = self.state.lock().unwrap();
        match *locked {
            // Only an empty slot will need the maker on the next `call`.
            State::Empty => self
                .mk_svc
                .poll_ready(cx)
                .map_err(|e| SingletonError(e.into())),
            State::Making(..) | State::Made(..) => Poll::Ready(Ok(())),
        }
    }

    /// Returns a future resolving to the shared service.
    ///
    /// Depending on the slot at call time the future either yields a clone
    /// of the cached service, waits on the in-flight creation, or starts
    /// (and drives) the creation itself.
    fn call(&mut self, dst: Target) -> Self::Future {
        // Every variant carries a weak handle back to the shared slot.
        let weak_state = Arc::downgrade(&self.state);
        let mut locked = self.state.lock().unwrap();
        match *locked {
            State::Made(ref svc) => SingletonFuture::Made {
                svc: Some(svc.clone()),
                state: weak_state,
            },
            State::Making(ref mut waiters) => {
                // Queue up behind the driver; it sends the service when done.
                let (tx, rx) = oneshot::channel();
                waiters.push(tx);
                SingletonFuture::Waiting {
                    rx,
                    state: weak_state,
                }
            }
            State::Empty => {
                let future = self.mk_svc.call(dst);
                *locked = State::Making(Vec::new());
                SingletonFuture::Driving {
                    future,
                    singleton: DitchGuard(weak_state),
                }
            }
        }
    }
}
impl<M, Target> Clone for Singleton<M, Target>
where
M: Service<Target> + Clone,
{
fn clone(&self) -> Self {
Self {
mk_svc: self.mk_svc.clone(),
state: self.state.clone(),
}
}
}
mod internal {
use std::future::Future;
use std::pin::Pin;
use std::sync::{Mutex, Weak};
use std::task::{self, ready, Poll};
use pin_project_lite::pin_project;
use tokio::sync::oneshot;
use tower_service::Service;
use super::BoxError;
pin_project! {
// Future returned by `Singleton::call`. The variant is chosen from the
// shared state observed at call time.
#[project = SingletonFutureProj]
pub enum SingletonFuture<F, S> {
// The state was empty: this future polls the maker's future itself.
Driving {
#[pin]
future: F,
// Resets the shared state to empty if dropped while still armed.
singleton: DitchGuard<S>,
},
// Another future is already making the service: wait for it to be sent.
Waiting {
rx: oneshot::Receiver<S>,
state: Weak<Mutex<State<S>>>,
},
// A service was already cached: holds a clone to yield on first poll.
Made {
svc: Option<S>,
state: Weak<Mutex<State<S>>>,
},
}
}
/// The shared slot tracked by a singleton.
#[derive(Debug)]
pub enum State<S> {
/// Nothing is cached and nothing is being made.
Empty,
/// A service is being made; each sender belongs to a queued waiter.
Making(Vec<oneshot::Sender<S>>),
/// The cached service.
Made(S),
}
/// Drop guard holding a weak handle to the shared state; on drop it resets
/// the state to `Empty` if the handle can still be upgraded.
pub struct DitchGuard<S>(pub(super) Weak<Mutex<State<S>>>);
/// The service handed out by a `SingletonFuture`: the inner service plus a
/// weak handle to the shared state it came from.
#[derive(Debug)]
pub struct Singled<S> {
// The wrapped service that requests are forwarded to.
inner: S,
// Used by `poll_ready` to clear the shared slot when `inner` reports an error.
state: Weak<Mutex<State<S>>>,
}
impl<F, S, E> Future for SingletonFuture<F, S>
where
    F: Future<Output = Result<S, E>>,
    E: Into<BoxError>,
    S: Clone,
{
    type Output = Result<Singled<S>, SingletonError>;

    /// Resolves to a `Singled` wrapper around the shared service, or to a
    /// `SingletonError` if making it failed or the driver went away.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        match self.project() {
            // A cached clone was captured at call time; hand it out directly.
            SingletonFutureProj::Made { svc, state } => {
                let svc = svc.take().unwrap();
                Poll::Ready(Ok(Singled::new(svc, state.clone())))
            }
            // Wait for the driver to send the service; a dropped sender
            // surfaces as a cancellation error.
            SingletonFutureProj::Waiting { rx, state } => {
                let received = ready!(Pin::new(rx).poll(cx));
                Poll::Ready(match received {
                    Ok(svc) => Ok(Singled::new(svc, state.clone())),
                    Err(_canceled) => Err(SingletonError(Canceled.into())),
                })
            }
            SingletonFutureProj::Driving { future, singleton } => {
                let outcome = ready!(future.poll(cx));
                let svc = match outcome {
                    Ok(svc) => svc,
                    Err(e) => {
                        // Making failed: empty the slot and disarm the guard.
                        if let Some(state) = singleton.0.upgrade() {
                            let mut locked = state.lock().unwrap();
                            singleton.0 = Weak::new();
                            *locked = State::Empty;
                        }
                        return Poll::Ready(Err(SingletonError(e.into())));
                    }
                };

                if let Some(state) = singleton.0.upgrade() {
                    let mut locked = state.lock().unwrap();
                    // Publish the service, then hand a clone to every waiter.
                    let previous = std::mem::replace(&mut *locked, State::Made(svc.clone()));
                    match previous {
                        State::Making(waiters) => {
                            for tx in waiters {
                                // A waiter that was dropped is simply skipped.
                                let _ = tx.send(svc.clone());
                            }
                        }
                        // The slot is expected to still be `Making` here.
                        State::Empty | State::Made(_) => unreachable!(),
                    }
                }

                // Move the weak handle out so the guard's drop becomes a no-op.
                let state = std::mem::take(&mut singleton.0);
                Poll::Ready(Ok(Singled::new(svc, state)))
            }
        }
    }
}
impl<S> Drop for DitchGuard<S> {
    /// Resets the shared state to `Empty` when the guard is dropped while
    /// its weak handle is still live.
    fn drop(&mut self) {
        if let Some(shared) = self.0.upgrade() {
            match shared.lock() {
                Ok(mut slot) => *slot = State::Empty,
                // A poisoned lock is left untouched rather than panicking in drop.
                Err(_poisoned) => {}
            }
        }
    }
}
impl<S> Singled<S> {
fn new(inner: S, state: Weak<Mutex<State<S>>>) -> Self {
Singled { inner, state }
}
}
impl<S, Req> Service<Req> for Singled<S>
where
    S: Service<Req>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Polls the inner service for readiness.
    ///
    /// When the inner service reports an error, the cached service is
    /// evicted from the shared state so the next caller makes a fresh one.
    /// The error itself is returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the shared state mutex is poisoned.
    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.inner.poll_ready(cx) {
            Poll::Ready(Err(err)) => {
                if let Some(state) = self.state.upgrade() {
                    let mut locked = state.lock().unwrap();
                    // Only evict a cached service. If a replacement is already
                    // being made (`Making`), overwriting it would drop the
                    // queued waiters and leave the driver to find a state it
                    // treats as unreachable, panicking with the lock held.
                    // NOTE(review): a stale handle can still evict a newer
                    // cached service; nothing here identifies which service
                    // the slot currently holds.
                    if matches!(*locked, State::Made(_)) {
                        *locked = State::Empty;
                    }
                }
                Poll::Ready(Err(err))
            }
            other => other,
        }
    }

    /// Forwards the request to the inner service.
    fn call(&mut self, req: Req) -> Self::Future {
        self.inner.call(req)
    }
}
/// Error returned by the singleton service and its futures; wraps the
/// underlying cause, which is exposed through `source()`.
#[derive(Debug)]
pub struct SingletonError(pub(super) BoxError);

impl std::fmt::Display for SingletonError {
    /// Writes a fixed message; the wrapped cause is not included here.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str("singleton connection error")
    }
}

impl std::error::Error for SingletonError {
    /// Returns the wrapped error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &*self.0;
        Some(cause)
    }
}
/// Cause reported to a waiting future when its sender is dropped before a
/// service was delivered.
#[derive(Debug)]
struct Canceled;

impl std::fmt::Display for Canceled {
    /// Writes the fixed cancellation message.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str("singleton connection canceled")
    }
}

// No underlying cause: the default `source()` (returning `None`) is used.
impl std::error::Error for Canceled {}
}
#[cfg(test)]
mod tests {
use std::future::Future;
use std::pin::Pin;
use std::task::Poll;
use tower_service::Service;
use super::Singleton;
// Two calls issued before the maker responds: the first drives the maker,
// the second waits, and both succeed from a single mock response.
#[tokio::test]
async fn first_call_drives_subsequent_wait() {
let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
let mut singleton = Singleton::new(mock_svc);
handle.allow(1);
std::future::poll_fn(|cx| singleton.poll_ready(cx))
.await
.unwrap();
// `fut1` finds the slot empty and calls the maker; `fut2` is queued behind it.
let fut1 = singleton.call(());
let fut2 = singleton.call(());
// Only one request is answered on the mock handle.
let ((), send_response) = handle.next_request().await.unwrap();
send_response.send_response("svc");
fut1.await.unwrap();
fut2.await.unwrap();
}
// Once a service has been made, a later call resolves without another
// request being answered on the mock handle.
#[tokio::test]
async fn made_state_returns_immediately() {
let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
let mut singleton = Singleton::new(mock_svc);
handle.allow(1);
std::future::poll_fn(|cx| singleton.poll_ready(cx))
.await
.unwrap();
let fut1 = singleton.call(());
let ((), send_response) = handle.next_request().await.unwrap();
send_response.send_response("svc");
fut1.await.unwrap();
// The slot now holds the made service, so this call needs no new response.
singleton.call(()).await.unwrap();
}
// A readiness error from the cached service must clear the singleton so
// that the next call makes (and can use) a fresh service.
#[tokio::test]
async fn cached_service_poll_ready_error_clears_singleton() {
// The outer mock plays the maker; its responses are themselves mock services.
let (outer, mut outer_handle) =
tower_test::mock::pair::<(), tower_test::mock::Mock<(), &'static str>>();
let mut singleton = Singleton::new(outer);
outer_handle.allow(2);
std::future::poll_fn(|cx| singleton.poll_ready(cx))
.await
.unwrap();
let fut1 = singleton.call(());
let ((), send_inner) = outer_handle.next_request().await.unwrap();
let (inner, mut inner_handle) = tower_test::mock::pair::<(), &'static str>();
send_inner.send_response(inner);
let mut cached = fut1.await.unwrap();
// Make the cached service fail its readiness check.
inner_handle.allow(1);
inner_handle.send_error(std::io::Error::new(
std::io::ErrorKind::Other,
"cached poll_ready failed",
));
let err = std::future::poll_fn(|cx| cached.poll_ready(cx))
.await
.err()
.expect("expected poll_ready error");
// The inner error is passed through unchanged.
assert_eq!(err.to_string(), "cached poll_ready failed");
// The next call reaches the maker again, showing the cached service was dropped.
outer_handle.allow(1);
std::future::poll_fn(|cx| singleton.poll_ready(cx))
.await
.unwrap();
let fut2 = singleton.call(());
let ((), send_inner2) = outer_handle.next_request().await.unwrap();
let (inner2, mut inner_handle2) = tower_test::mock::pair::<(), &'static str>();
send_inner2.send_response(inner2);
let mut cached2 = fut2.await.unwrap();
// The replacement service is ready and serves a request end to end.
inner_handle2.allow(1);
std::future::poll_fn(|cx| cached2.poll_ready(cx))
.await
.expect("expected poll_ready");
let cfut2 = cached2.call(());
let ((), send_cached2) = inner_handle2.next_request().await.unwrap();
send_cached2.send_response("svc2");
cfut2.await.unwrap();
}
// Dropping a waiting future must not prevent the driving future from
// completing.
// NOTE(review): unlike the tests above, no `handle.allow(..)` is issued here;
// presumably the mock's default allowance is sufficient — confirm.
#[tokio::test]
async fn cancel_waiter_does_not_affect_others() {
let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
let mut singleton = Singleton::new(mock_svc);
std::future::poll_fn(|cx| singleton.poll_ready(cx))
.await
.unwrap();
let fut1 = singleton.call(());
let fut2 = singleton.call(());
// Cancel the waiter before the maker has responded.
drop(fut2);
let ((), send_response) = handle.next_request().await.unwrap();
send_response.send_response("svc");
fut1.await.unwrap();
}
// Dropping the driving future before it completes must fail the waiting
// future with the cancellation error.
// NOTE(review): no `handle.allow(..)` is issued here either; presumably the
// mock's default allowance is sufficient — confirm.
#[tokio::test]
async fn cancel_driver_cancels_all() {
let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
let mut singleton = Singleton::new(mock_svc);
std::future::poll_fn(|cx| singleton.poll_ready(cx))
.await
.unwrap();
let mut fut1 = singleton.call(());
let fut2 = singleton.call(());
// `fut1` is moved into the closure, polled once, and dropped with the closure.
std::future::poll_fn(move |cx| {
let _ = Pin::new(&mut fut1).poll(cx);
Poll::Ready(())
})
.await;
// The mock still answers the request, but the driver is gone by now.
let ((), send_response) = handle.next_request().await.unwrap();
send_response.send_response("svc");
assert_eq!(
fut2.await.unwrap_err().0.to_string(),
"singleton connection canceled"
);
}
}