hyper_util/client/pool/
singleton.rs

1//! Singleton pools
2//!
3//! This ensures that only one active connection is made.
4//!
5//! The singleton pool wraps a `MakeService<T, Req>` so that it only produces a
6//! single `Service<Req>`. It bundles all concurrent calls to it, so that only
7//! one connection is made. All calls to the singleton will return a clone of
8//! the inner service once established.
9//!
10//! This fits the HTTP/2 case well.
11//!
12//! ## Example
13//!
14//! ```rust,ignore
15//! let mut pool = Singleton::new(some_make_svc);
16//!
17//! let svc1 = pool.call(some_dst).await?;
18//!
19//! let svc2 = pool.call(some_dst).await?;
20//! // svc1 == svc2
21//! ```
22
23use std::sync::{Arc, Mutex};
24use std::task::{self, Poll};
25
26use tokio::sync::oneshot;
27use tower_service::Service;
28
29use self::internal::{DitchGuard, SingletonError, SingletonFuture, State};
30
/// Boxed, thread-safe error used to erase the maker's concrete error type.
type BoxError = Box<dyn std::error::Error + Sync + Send>;
32
33#[cfg(docsrs)]
34pub use self::internal::Singled;
35
36/// A singleton pool over an inner service.
37///
38/// The singleton wraps an inner service maker, bundling all calls to ensure
39/// only one service is created. Once made, it returns clones of the made
40/// service.
41#[derive(Debug)]
42pub struct Singleton<M, Dst>
43where
44    M: Service<Dst>,
45{
46    mk_svc: M,
47    state: Arc<Mutex<State<M::Response>>>,
48}
49
50impl<M, Target> Singleton<M, Target>
51where
52    M: Service<Target>,
53    M::Response: Clone,
54{
55    /// Create a new singleton pool over an inner make service.
56    pub fn new(mk_svc: M) -> Self {
57        Singleton {
58            mk_svc,
59            state: Arc::new(Mutex::new(State::Empty)),
60        }
61    }
62
63    // pub fn clear? cancel?
64
65    /// Retains the inner made service if specified by the predicate.
66    pub fn retain<F>(&mut self, mut predicate: F)
67    where
68        F: FnMut(&mut M::Response) -> bool,
69    {
70        let mut locked = self.state.lock().unwrap();
71        match *locked {
72            State::Empty => {}
73            State::Making(..) => {}
74            State::Made(ref mut svc) => {
75                if !predicate(svc) {
76                    *locked = State::Empty;
77                }
78            }
79        }
80    }
81
82    /// Returns whether this singleton pool is empty.
83    ///
84    /// If this pool has created a shared instance, or is currently in the
85    /// process of creating one, this returns false.
86    pub fn is_empty(&self) -> bool {
87        matches!(*self.state.lock().unwrap(), State::Empty)
88    }
89}
90
91impl<M, Target> Service<Target> for Singleton<M, Target>
92where
93    M: Service<Target>,
94    M::Response: Clone,
95    M::Error: Into<BoxError>,
96{
97    type Response = internal::Singled<M::Response>;
98    type Error = SingletonError;
99    type Future = SingletonFuture<M::Future, M::Response>;
100
101    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
102        if let State::Empty = *self.state.lock().unwrap() {
103            return self
104                .mk_svc
105                .poll_ready(cx)
106                .map_err(|e| SingletonError(e.into()));
107        }
108        Poll::Ready(Ok(()))
109    }
110
111    fn call(&mut self, dst: Target) -> Self::Future {
112        let mut locked = self.state.lock().unwrap();
113        match *locked {
114            State::Empty => {
115                let fut = self.mk_svc.call(dst);
116                *locked = State::Making(Vec::new());
117                SingletonFuture::Driving {
118                    future: fut,
119                    singleton: DitchGuard(Arc::downgrade(&self.state)),
120                }
121            }
122            State::Making(ref mut waiters) => {
123                let (tx, rx) = oneshot::channel();
124                waiters.push(tx);
125                SingletonFuture::Waiting {
126                    rx,
127                    state: Arc::downgrade(&self.state),
128                }
129            }
130            State::Made(ref svc) => SingletonFuture::Made {
131                svc: Some(svc.clone()),
132                state: Arc::downgrade(&self.state),
133            },
134        }
135    }
136}
137
138impl<M, Target> Clone for Singleton<M, Target>
139where
140    M: Service<Target> + Clone,
141{
142    fn clone(&self) -> Self {
143        Self {
144            mk_svc: self.mk_svc.clone(),
145            state: self.state.clone(),
146        }
147    }
148}
149
150// Holds some "pub" items that otherwise shouldn't be public.
151mod internal {
152    use std::future::Future;
153    use std::pin::Pin;
154    use std::sync::{Mutex, Weak};
155    use std::task::{self, Poll};
156
157    use futures_core::ready;
158    use pin_project_lite::pin_project;
159    use tokio::sync::oneshot;
160    use tower_service::Service;
161
162    use super::BoxError;
163
164    pin_project! {
165        #[project = SingletonFutureProj]
166        pub enum SingletonFuture<F, S> {
167            Driving {
168                #[pin]
169                future: F,
170                singleton: DitchGuard<S>,
171            },
172            Waiting {
173                rx: oneshot::Receiver<S>,
174                state: Weak<Mutex<State<S>>>,
175            },
176            Made {
177                svc: Option<S>,
178                state: Weak<Mutex<State<S>>>,
179            },
180        }
181    }
182
183    // XXX: pub because of the enum SingletonFuture
184    #[derive(Debug)]
185    pub enum State<S> {
186        Empty,
187        Making(Vec<oneshot::Sender<S>>),
188        Made(S),
189    }
190
191    // XXX: pub because of the enum SingletonFuture
192    pub struct DitchGuard<S>(pub(super) Weak<Mutex<State<S>>>);
193
194    /// A cached service returned from a [`Singleton`].
195    ///
196    /// Implements `Service` by delegating to the inner service. If
197    /// `poll_ready` returns an error, this will clear the cache in the related
198    /// `Singleton`.
199    ///
200    /// [`Singleton`]: super::Singleton
201    ///
202    /// # Unnameable
203    ///
204    /// This type is normally unnameable, forbidding naming of the type within
205    /// code. The type is exposed in the documentation to show which methods
206    /// can be publicly called.
207    #[derive(Debug)]
208    pub struct Singled<S> {
209        inner: S,
210        state: Weak<Mutex<State<S>>>,
211    }
212
213    impl<F, S, E> Future for SingletonFuture<F, S>
214    where
215        F: Future<Output = Result<S, E>>,
216        E: Into<BoxError>,
217        S: Clone,
218    {
219        type Output = Result<Singled<S>, SingletonError>;
220
221        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
222            match self.project() {
223                SingletonFutureProj::Driving { future, singleton } => {
224                    match ready!(future.poll(cx)) {
225                        Ok(svc) => {
226                            if let Some(state) = singleton.0.upgrade() {
227                                let mut locked = state.lock().unwrap();
228                                match std::mem::replace(&mut *locked, State::Made(svc.clone())) {
229                                    State::Making(waiters) => {
230                                        for tx in waiters {
231                                            let _ = tx.send(svc.clone());
232                                        }
233                                    }
234                                    State::Empty | State::Made(_) => {
235                                        // shouldn't happen!
236                                        unreachable!()
237                                    }
238                                }
239                            }
240                            // take out of the DitchGuard so it doesn't treat as "ditched"
241                            let state = std::mem::replace(&mut singleton.0, Weak::new());
242                            Poll::Ready(Ok(Singled::new(svc, state)))
243                        }
244                        Err(e) => {
245                            if let Some(state) = singleton.0.upgrade() {
246                                let mut locked = state.lock().unwrap();
247                                singleton.0 = Weak::new();
248                                *locked = State::Empty;
249                            }
250                            Poll::Ready(Err(SingletonError(e.into())))
251                        }
252                    }
253                }
254                SingletonFutureProj::Waiting { rx, state } => match ready!(Pin::new(rx).poll(cx)) {
255                    Ok(svc) => Poll::Ready(Ok(Singled::new(svc, state.clone()))),
256                    Err(_canceled) => Poll::Ready(Err(SingletonError(Canceled.into()))),
257                },
258                SingletonFutureProj::Made { svc, state } => {
259                    Poll::Ready(Ok(Singled::new(svc.take().unwrap(), state.clone())))
260                }
261            }
262        }
263    }
264
265    impl<S> Drop for DitchGuard<S> {
266        fn drop(&mut self) {
267            if let Some(state) = self.0.upgrade() {
268                if let Ok(mut locked) = state.lock() {
269                    *locked = State::Empty;
270                }
271            }
272        }
273    }
274
275    impl<S> Singled<S> {
276        fn new(inner: S, state: Weak<Mutex<State<S>>>) -> Self {
277            Singled { inner, state }
278        }
279    }
280
281    impl<S, Req> Service<Req> for Singled<S>
282    where
283        S: Service<Req>,
284    {
285        type Response = S::Response;
286        type Error = S::Error;
287        type Future = S::Future;
288
289        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
290            // We notice if the cached service dies, and clear the singleton cache.
291            match self.inner.poll_ready(cx) {
292                Poll::Ready(Err(err)) => {
293                    if let Some(state) = self.state.upgrade() {
294                        *state.lock().unwrap() = State::Empty;
295                    }
296                    Poll::Ready(Err(err))
297                }
298                other => other,
299            }
300        }
301
302        fn call(&mut self, req: Req) -> Self::Future {
303            self.inner.call(req)
304        }
305    }
306
307    // An opaque error type. By not exposing the type, nor being specifically
308    // Box<dyn Error>, we can _change_ the type once we no longer need the Canceled
309    // error type. This will be possible with the refactor to baton passing.
310    #[derive(Debug)]
311    pub struct SingletonError(pub(super) BoxError);
312
313    impl std::fmt::Display for SingletonError {
314        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315            f.write_str("singleton connection error")
316        }
317    }
318
319    impl std::error::Error for SingletonError {
320        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
321            Some(&*self.0)
322        }
323    }
324
325    #[derive(Debug)]
326    struct Canceled;
327
328    impl std::fmt::Display for Canceled {
329        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330            f.write_str("singleton connection canceled")
331        }
332    }
333
334    impl std::error::Error for Canceled {}
335}
336
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Poll;

    use tower_service::Service;

    use super::Singleton;

    #[tokio::test]
    async fn first_call_drives_subsequent_wait() {
        let (maker, mut maker_handle) = tower_test::mock::pair::<(), &'static str>();
        let mut pool = Singleton::new(maker);

        maker_handle.allow(1);
        crate::common::future::poll_fn(|cx| pool.poll_ready(cx))
            .await
            .unwrap();

        // The first call drives the maker; the second queues behind it.
        let driver = pool.call(());
        let waiter = pool.call(());

        // Only a single request should reach the maker.
        let ((), respond) = maker_handle.next_request().await.unwrap();
        respond.send_response("svc");

        // Driver and waiter both resolve successfully.
        driver.await.unwrap();
        waiter.await.unwrap();
    }

    #[tokio::test]
    async fn made_state_returns_immediately() {
        let (maker, mut maker_handle) = tower_test::mock::pair::<(), &'static str>();
        let mut pool = Singleton::new(maker);

        maker_handle.allow(1);
        crate::common::future::poll_fn(|cx| pool.poll_ready(cx))
            .await
            .unwrap();

        // Complete one make so the service gets cached.
        let driver = pool.call(());
        let ((), respond) = maker_handle.next_request().await.unwrap();
        respond.send_response("svc");
        driver.await.unwrap();

        // A cached service is handed out without touching the maker.
        pool.call(()).await.unwrap();
    }

    #[tokio::test]
    async fn cached_service_poll_ready_error_clears_singleton() {
        // The maker is a mock whose responses are themselves mock services.
        let (maker, mut maker_handle) =
            tower_test::mock::pair::<(), tower_test::mock::Mock<(), &'static str>>();
        let mut pool = Singleton::new(maker);

        maker_handle.allow(2);
        crate::common::future::poll_fn(|cx| pool.poll_ready(cx))
            .await
            .unwrap();

        // First make: hand back the first inner mock.
        let first_make = pool.call(());
        let ((), respond_first) = maker_handle.next_request().await.unwrap();
        let (first_inner, mut first_inner_handle) = tower_test::mock::pair::<(), &'static str>();
        respond_first.send_response(first_inner);
        let mut first_cached = first_make.await.unwrap();

        // Permit readiness on the inner mock, then make its next poll_ready fail.
        first_inner_handle.allow(1);
        first_inner_handle.send_error(std::io::Error::new(
            std::io::ErrorKind::Other,
            "cached poll_ready failed",
        ));

        let err = crate::common::future::poll_fn(|cx| first_cached.poll_ready(cx))
            .await
            .err()
            .expect("expected poll_ready error");
        assert_eq!(err.to_string(), "cached poll_ready failed");

        // The failure emptied the pool, so the maker is hit again.
        maker_handle.allow(1);
        crate::common::future::poll_fn(|cx| pool.poll_ready(cx))
            .await
            .unwrap();
        let second_make = pool.call(());
        let ((), respond_second) = maker_handle.next_request().await.unwrap();
        let (second_inner, mut second_inner_handle) =
            tower_test::mock::pair::<(), &'static str>();
        respond_second.send_response(second_inner);
        let mut second_cached = second_make.await.unwrap();

        // The replacement service works end to end.
        second_inner_handle.allow(1);
        crate::common::future::poll_fn(|cx| second_cached.poll_ready(cx))
            .await
            .expect("expected poll_ready");
        let response = second_cached.call(());
        let ((), respond_call) = second_inner_handle.next_request().await.unwrap();
        respond_call.send_response("svc2");
        response.await.unwrap();
    }

    #[tokio::test]
    async fn cancel_waiter_does_not_affect_others() {
        let (maker, mut maker_handle) = tower_test::mock::pair::<(), &'static str>();
        let mut pool = Singleton::new(maker);

        crate::common::future::poll_fn(|cx| pool.poll_ready(cx))
            .await
            .unwrap();
        let driver = pool.call(());
        // Abandon a waiter before the service has been made.
        drop(pool.call(()));

        let ((), respond) = maker_handle.next_request().await.unwrap();
        respond.send_response("svc");

        // The driver still completes even though a waiter went away.
        driver.await.unwrap();
    }

    // TODO: this should be able to be improved with a cooperative baton refactor
    #[tokio::test]
    async fn cancel_driver_cancels_all() {
        let (maker, mut maker_handle) = tower_test::mock::pair::<(), &'static str>();
        let mut pool = Singleton::new(maker);

        crate::common::future::poll_fn(|cx| pool.poll_ready(cx))
            .await
            .unwrap();
        let mut driver = pool.call(());
        let waiter = pool.call(());

        // Poll the driver exactly once. The closure owns it, so it is dropped
        // once this statement finishes.
        crate::common::future::poll_fn(move |cx| {
            let _ = Pin::new(&mut driver).poll(cx);
            Poll::Ready(())
        })
        .await;

        let ((), respond) = maker_handle.next_request().await.unwrap();
        respond.send_response("svc");

        let err = waiter.await.unwrap_err();
        assert_eq!(err.0.to_string(), "singleton connection canceled");
    }
}