Skip to main content

hyper_util/client/pool/
singleton.rs

1//! Singleton pools
2//!
3//! This ensures that only one active connection is made.
4//!
5//! The singleton pool wraps a `MakeService<T, Req>` so that it only produces a
6//! single `Service<Req>`. It bundles all concurrent calls to it, so that only
7//! one connection is made. All calls to the singleton will return a clone of
8//! the inner service once established.
9//!
10//! This fits the HTTP/2 case well.
11//!
12//! ## Example
13//!
14//! ```rust,ignore
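//! let mut pool = Singleton::new(some_make_svc);
//!
//! std::future::poll_fn(|cx| pool.poll_ready(cx)).await?;
//! let svc1 = pool.call(some_dst).await?;
//!
//! std::future::poll_fn(|cx| pool.poll_ready(cx)).await?;
//! let svc2 = pool.call(some_dst).await?;
//! // svc1 and svc2 are clones of the same inner service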
21//! ```
22
23use std::sync::{Arc, Mutex};
24use std::task::{self, Poll};
25
26use tokio::sync::oneshot;
27use tower_service::Service;
28
29use self::internal::{DitchGuard, SingletonError, SingletonFuture, State};
30
/// Type-erased, thread-safe error carried inside `SingletonError`.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
32
33#[cfg(docsrs)]
34pub use self::internal::Singled;
35
36/// A singleton pool over an inner service.
37///
38/// The singleton wraps an inner service maker, bundling all calls to ensure
39/// only one service is created. Once made, it returns clones of the made
40/// service.
41#[derive(Debug)]
42pub struct Singleton<M, Dst>
43where
44    M: Service<Dst>,
45{
46    mk_svc: M,
47    state: Arc<Mutex<State<M::Response>>>,
48}
49
50impl<M, Target> Singleton<M, Target>
51where
52    M: Service<Target>,
53    M::Response: Clone,
54{
55    /// Create a new singleton pool over an inner make service.
56    pub fn new(mk_svc: M) -> Self {
57        Singleton {
58            mk_svc,
59            state: Arc::new(Mutex::new(State::Empty)),
60        }
61    }
62
63    // pub fn clear? cancel?
64
65    /// Retains the inner made service if specified by the predicate.
66    pub fn retain<F>(&mut self, mut predicate: F)
67    where
68        F: FnMut(&mut M::Response) -> bool,
69    {
70        let mut locked = self.state.lock().unwrap();
71        match *locked {
72            State::Empty => {}
73            State::Making(..) => {}
74            State::Made(ref mut svc) => {
75                if !predicate(svc) {
76                    *locked = State::Empty;
77                }
78            }
79        }
80    }
81
82    /// Returns whether this singleton pool is empty.
83    ///
84    /// If this pool has created a shared instance, or is currently in the
85    /// process of creating one, this returns false.
86    pub fn is_empty(&self) -> bool {
87        matches!(*self.state.lock().unwrap(), State::Empty)
88    }
89}
90
91impl<M, Target> Service<Target> for Singleton<M, Target>
92where
93    M: Service<Target>,
94    M::Response: Clone,
95    M::Error: Into<BoxError>,
96{
97    type Response = internal::Singled<M::Response>;
98    type Error = SingletonError;
99    type Future = SingletonFuture<M::Future, M::Response>;
100
101    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
102        if let State::Empty = *self.state.lock().unwrap() {
103            return self
104                .mk_svc
105                .poll_ready(cx)
106                .map_err(|e| SingletonError(e.into()));
107        }
108        Poll::Ready(Ok(()))
109    }
110
111    fn call(&mut self, dst: Target) -> Self::Future {
112        let mut locked = self.state.lock().unwrap();
113        match *locked {
114            State::Empty => {
115                let fut = self.mk_svc.call(dst);
116                *locked = State::Making(Vec::new());
117                SingletonFuture::Driving {
118                    future: fut,
119                    singleton: DitchGuard(Arc::downgrade(&self.state)),
120                }
121            }
122            State::Making(ref mut waiters) => {
123                let (tx, rx) = oneshot::channel();
124                waiters.push(tx);
125                SingletonFuture::Waiting {
126                    rx,
127                    state: Arc::downgrade(&self.state),
128                }
129            }
130            State::Made(ref svc) => SingletonFuture::Made {
131                svc: Some(svc.clone()),
132                state: Arc::downgrade(&self.state),
133            },
134        }
135    }
136}
137
138impl<M, Target> Clone for Singleton<M, Target>
139where
140    M: Service<Target> + Clone,
141{
142    fn clone(&self) -> Self {
143        Self {
144            mk_svc: self.mk_svc.clone(),
145            state: self.state.clone(),
146        }
147    }
148}
149
150// Holds some "pub" items that otherwise shouldn't be public.
151mod internal {
152    use std::future::Future;
153    use std::pin::Pin;
154    use std::sync::{Mutex, Weak};
155    use std::task::{self, ready, Poll};
156
157    use pin_project_lite::pin_project;
158    use tokio::sync::oneshot;
159    use tower_service::Service;
160
161    use super::BoxError;
162
163    pin_project! {
164        #[project = SingletonFutureProj]
165        pub enum SingletonFuture<F, S> {
166            Driving {
167                #[pin]
168                future: F,
169                singleton: DitchGuard<S>,
170            },
171            Waiting {
172                rx: oneshot::Receiver<S>,
173                state: Weak<Mutex<State<S>>>,
174            },
175            Made {
176                svc: Option<S>,
177                state: Weak<Mutex<State<S>>>,
178            },
179        }
180    }
181
182    // XXX: pub because of the enum SingletonFuture
183    #[derive(Debug)]
184    pub enum State<S> {
185        Empty,
186        Making(Vec<oneshot::Sender<S>>),
187        Made(S),
188    }
189
190    // XXX: pub because of the enum SingletonFuture
191    pub struct DitchGuard<S>(pub(super) Weak<Mutex<State<S>>>);
192
193    /// A cached service returned from a [`Singleton`].
194    ///
195    /// Implements `Service` by delegating to the inner service. If
196    /// `poll_ready` returns an error, this will clear the cache in the related
197    /// `Singleton`.
198    ///
199    /// [`Singleton`]: super::Singleton
200    ///
201    /// # Unnameable
202    ///
203    /// This type is normally unnameable, forbidding naming of the type within
204    /// code. The type is exposed in the documentation to show which methods
205    /// can be publicly called.
206    #[derive(Debug)]
207    pub struct Singled<S> {
208        inner: S,
209        state: Weak<Mutex<State<S>>>,
210    }
211
212    impl<F, S, E> Future for SingletonFuture<F, S>
213    where
214        F: Future<Output = Result<S, E>>,
215        E: Into<BoxError>,
216        S: Clone,
217    {
218        type Output = Result<Singled<S>, SingletonError>;
219
220        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
221            match self.project() {
222                SingletonFutureProj::Driving { future, singleton } => {
223                    match ready!(future.poll(cx)) {
224                        Ok(svc) => {
225                            if let Some(state) = singleton.0.upgrade() {
226                                let mut locked = state.lock().unwrap();
227                                match std::mem::replace(&mut *locked, State::Made(svc.clone())) {
228                                    State::Making(waiters) => {
229                                        for tx in waiters {
230                                            let _ = tx.send(svc.clone());
231                                        }
232                                    }
233                                    State::Empty | State::Made(_) => {
234                                        // shouldn't happen!
235                                        unreachable!()
236                                    }
237                                }
238                            }
239                            // take out of the DitchGuard so it doesn't treat as "ditched"
240                            let state = std::mem::replace(&mut singleton.0, Weak::new());
241                            Poll::Ready(Ok(Singled::new(svc, state)))
242                        }
243                        Err(e) => {
244                            if let Some(state) = singleton.0.upgrade() {
245                                let mut locked = state.lock().unwrap();
246                                singleton.0 = Weak::new();
247                                *locked = State::Empty;
248                            }
249                            Poll::Ready(Err(SingletonError(e.into())))
250                        }
251                    }
252                }
253                SingletonFutureProj::Waiting { rx, state } => match ready!(Pin::new(rx).poll(cx)) {
254                    Ok(svc) => Poll::Ready(Ok(Singled::new(svc, state.clone()))),
255                    Err(_canceled) => Poll::Ready(Err(SingletonError(Canceled.into()))),
256                },
257                SingletonFutureProj::Made { svc, state } => {
258                    Poll::Ready(Ok(Singled::new(svc.take().unwrap(), state.clone())))
259                }
260            }
261        }
262    }
263
264    impl<S> Drop for DitchGuard<S> {
265        fn drop(&mut self) {
266            if let Some(state) = self.0.upgrade() {
267                if let Ok(mut locked) = state.lock() {
268                    *locked = State::Empty;
269                }
270            }
271        }
272    }
273
274    impl<S> Singled<S> {
275        fn new(inner: S, state: Weak<Mutex<State<S>>>) -> Self {
276            Singled { inner, state }
277        }
278    }
279
280    impl<S, Req> Service<Req> for Singled<S>
281    where
282        S: Service<Req>,
283    {
284        type Response = S::Response;
285        type Error = S::Error;
286        type Future = S::Future;
287
288        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
289            // We notice if the cached service dies, and clear the singleton cache.
290            match self.inner.poll_ready(cx) {
291                Poll::Ready(Err(err)) => {
292                    if let Some(state) = self.state.upgrade() {
293                        *state.lock().unwrap() = State::Empty;
294                    }
295                    Poll::Ready(Err(err))
296                }
297                other => other,
298            }
299        }
300
301        fn call(&mut self, req: Req) -> Self::Future {
302            self.inner.call(req)
303        }
304    }
305
306    // An opaque error type. By not exposing the type, nor being specifically
307    // Box<dyn Error>, we can _change_ the type once we no longer need the Canceled
308    // error type. This will be possible with the refactor to baton passing.
309    #[derive(Debug)]
310    pub struct SingletonError(pub(super) BoxError);
311
312    impl std::fmt::Display for SingletonError {
313        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314            f.write_str("singleton connection error")
315        }
316    }
317
318    impl std::error::Error for SingletonError {
319        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
320            Some(&*self.0)
321        }
322    }
323
324    #[derive(Debug)]
325    struct Canceled;
326
327    impl std::fmt::Display for Canceled {
328        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329            f.write_str("singleton connection canceled")
330        }
331    }
332
333    impl std::error::Error for Canceled {}
334}
335
// Tests drive the singleton against `tower_test` mock services, so each test
// controls exactly when the inner maker receives a request and what it returns.
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Poll;

    use tower_service::Service;

    use super::Singleton;

    // Two calls made before the first resolves share a single inner request.
    #[tokio::test]
    async fn first_call_drives_subsequent_wait() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();

        let mut singleton = Singleton::new(mock_svc);

        handle.allow(1);
        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        // First call: should go into Driving
        let fut1 = singleton.call(());
        // Second call: should go into Waiting
        let fut2 = singleton.call(());

        // Expect exactly one request to the inner service
        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");

        // Both futures should resolve to the same value
        // NOTE(review): only success is checked here; the two results are not
        // compared.
        fut1.await.unwrap();
        fut2.await.unwrap();
    }

    // Once a service is made, later calls resolve from the cache.
    #[tokio::test]
    async fn made_state_returns_immediately() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
        let mut singleton = Singleton::new(mock_svc);

        handle.allow(1);
        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        // Drive first call to completion
        let fut1 = singleton.call(());
        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");
        fut1.await.unwrap();

        // Second call should not hit inner service
        // (no response is ever sent for a second request, so this can only
        // resolve from the cached `Made` state)
        singleton.call(()).await.unwrap();
    }

    // A `poll_ready` error on the cached service empties the pool, and the
    // next call makes a fresh service through the outer maker.
    #[tokio::test]
    async fn cached_service_poll_ready_error_clears_singleton() {
        // Outer mock returns an inner mock service
        let (outer, mut outer_handle) =
            tower_test::mock::pair::<(), tower_test::mock::Mock<(), &'static str>>();
        let mut singleton = Singleton::new(outer);

        // Allow the singleton to be made
        outer_handle.allow(2);
        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();

        // First call produces an inner mock service
        let fut1 = singleton.call(());
        let ((), send_inner) = outer_handle.next_request().await.unwrap();
        let (inner, mut inner_handle) = tower_test::mock::pair::<(), &'static str>();
        send_inner.send_response(inner);
        let mut cached = fut1.await.unwrap();

        // Now: allow readiness on the inner mock, then inject error
        inner_handle.allow(1);

        // Inject error so next poll_ready fails
        inner_handle.send_error(std::io::Error::new(
            std::io::ErrorKind::Other,
            "cached poll_ready failed",
        ));

        // Drive poll_ready on cached service
        let err = std::future::poll_fn(|cx| cached.poll_ready(cx))
            .await
            .err()
            .expect("expected poll_ready error");
        assert_eq!(err.to_string(), "cached poll_ready failed");

        // After error, the singleton should be cleared, so a new call drives outer again
        outer_handle.allow(1);
        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let fut2 = singleton.call(());
        let ((), send_inner2) = outer_handle.next_request().await.unwrap();
        let (inner2, mut inner_handle2) = tower_test::mock::pair::<(), &'static str>();
        send_inner2.send_response(inner2);
        let mut cached2 = fut2.await.unwrap();

        // The new cached service should still work
        inner_handle2.allow(1);
        std::future::poll_fn(|cx| cached2.poll_ready(cx))
            .await
            .expect("expected poll_ready");
        let cfut2 = cached2.call(());
        let ((), send_cached2) = inner_handle2.next_request().await.unwrap();
        send_cached2.send_response("svc2");
        cfut2.await.unwrap();
    }

    // Dropping a waiter only discards its own receiver; the driver still
    // completes normally.
    #[tokio::test]
    async fn cancel_waiter_does_not_affect_others() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
        let mut singleton = Singleton::new(mock_svc);

        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let fut1 = singleton.call(());
        let fut2 = singleton.call(());
        drop(fut2); // cancel one waiter

        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");

        fut1.await.unwrap();
    }

    // TODO: this should be able to be improved with a cooperative baton refactor
    // Dropping the driver resets the pool, so queued waiters resolve with the
    // "canceled" error even though the inner request was already issued.
    #[tokio::test]
    async fn cancel_driver_cancels_all() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
        let mut singleton = Singleton::new(mock_svc);

        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let mut fut1 = singleton.call(());
        let fut2 = singleton.call(());

        // poll driver just once, and then drop
        // (`move` hands `fut1` to the closure, so it is dropped along with
        // the `poll_fn` future once this `.await` completes)
        std::future::poll_fn(move |cx| {
            let _ = Pin::new(&mut fut1).poll(cx);
            Poll::Ready(())
        })
        .await;

        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");

        assert_eq!(
            fut2.await.unwrap_err().0.to_string(),
            "singleton connection canceled"
        );
    }
}