// hyper_util/client/pool/cache.rs
//! A cache of services
//!
//! The cache is a single list of cached services, bundled with a `MakeService`.
//! Calling the cache returns either an existing service, or makes a new one.
//! The returned `impl Service` can be used to send requests, and when dropped,
//! it will try to be returned back to the cache.

pub use self::internal::builder;

// Only exposed when building docs (the `docsrs` cfg); these types are
// otherwise unnameable outside this module (see the "# Unnameable" notes).
#[cfg(docsrs)]
pub use self::internal::Builder;
#[cfg(docsrs)]
pub use self::internal::Cache;
#[cfg(docsrs)]
pub use self::internal::Cached;
// For now, nothing else in this module is nameable. We can always make things
// more public, but we can't change type shapes (generics) once things are
// public.
20mod internal {
21    use std::fmt;
22    use std::future::Future;
23    use std::pin::Pin;
24    use std::sync::{Arc, Mutex, Weak};
25    use std::task::{self, ready, Poll};
26
27    use futures_util::future;
28    use tokio::sync::oneshot;
29    use tower_service::Service;
30
31    use super::events;
32
33    /// Start a builder to construct a `Cache` pool.
34    pub fn builder() -> Builder<events::Ignore> {
35        Builder {
36            events: events::Ignore,
37        }
38    }
39
    /// A cache pool of services from the inner make service.
    ///
    /// Created with [`builder()`].
    ///
    /// # Unnameable
    ///
    /// This type is normally unnameable, forbidding naming of the type within
    /// code. The type is exposed in the documentation to show which methods
    /// can be publicly called.
    #[derive(Debug)]
    pub struct Cache<M, Dst, Ev>
    where
        M: Service<Dst>,
    {
        // Makes new services when nothing is cached.
        connector: M,
        // Idle services and pending waiters; shared with `Cached` handles
        // (via `Weak`) and with `BackgroundConnect` futures.
        shared: Arc<Mutex<Shared<M::Response>>>,
        // Hooks for optional background behavior (see the `events` module).
        events: Ev,
    }
58
    /// A builder to configure a `Cache`.
    ///
    /// # Unnameable
    ///
    /// This type is normally unnameable, forbidding naming of the type within
    /// code. The type is exposed in the documentation to show which methods
    /// can be publicly called.
    #[derive(Debug)]
    pub struct Builder<Ev> {
        // Event hooks to install in the built `Cache`; `events::Ignore` by
        // default, `events::WithExecutor<E>` after `executor()`.
        events: Ev,
    }
70
    /// A cached service returned from a [`Cache`].
    ///
    /// Implements `Service` by delegating to the inner service. Once dropped,
    /// tries to reinsert into the `Cache`.
    ///
    /// # Unnameable
    ///
    /// This type is normally unnameable, forbidding naming of the type within
    /// code. The type is exposed in the documentation to show which methods
    /// can be publicly called.
    pub struct Cached<S> {
        // Set when `poll_ready` returns an error; a closed service is
        // never returned to the pool in `drop`.
        is_closed: bool,
        // Invariant: always `Some` until taken in `drop`.
        inner: Option<S>,
        // Weak handle back to the pool, upgraded in `drop` to reinsert.
        shared: Weak<Mutex<Shared<S>>>,
        // todo: on_idle
    }
87
    /// Future returned by a `Cache`'s `Service::call`.
    pub enum CacheFuture<M, Dst, Ev>
    where
        M: Service<Dst>,
    {
        /// Racing a fresh connect against a waiter for a newly idle service.
        Racing {
            shared: Arc<Mutex<Shared<M::Response>>>,
            select: future::Select<oneshot::Receiver<M::Response>, M::Future>,
            events: Ev,
        },
        /// The waiter channel errored (pool dropped); only the in-flight
        /// connect can produce a service now.
        Connecting {
            // TODO: could be Weak even here...
            shared: Arc<Mutex<Shared<M::Response>>>,
            future: M::Future,
        },
        /// A cached service was immediately available; `Option` so `poll`
        /// can move it out by value.
        Cached {
            svc: Option<Cached<M::Response>>,
        },
    }
106
    // shouldn't be pub
    // (only pub because it appears in `CacheFuture`'s type; see module note)
    #[derive(Debug)]
    pub struct Shared<S> {
        // Idle services, available for reuse.
        services: Vec<S>,
        // Pending callers waiting for the next idle service (see `put`).
        waiters: Vec<oneshot::Sender<S>>,
    }
113
114    // impl Builder
115
    impl<Ev> Builder<Ev> {
        /// Provide a `Future` executor to be used by the `Cache`.
        ///
        /// The executor is used to handle some optional background tasks that
        /// can improve the behavior of the cache, such as reducing connection
        /// thrashing when a race is won. If not configured with an executor,
        /// the default behavior is to ignore any of these optional background
        /// tasks.
        ///
        /// The executor should implement [`hyper::rt::Executor`].
        ///
        /// # Example
        ///
        /// ```rust
        /// # #[cfg(feature = "tokio")]
        /// # fn run() {
        /// let builder = hyper_util::client::pool::cache::builder()
        ///     .executor(hyper_util::rt::TokioExecutor::new());
        /// # }
        /// ```
        pub fn executor<E>(self, exec: E) -> Builder<events::WithExecutor<E>> {
            Builder {
                events: events::WithExecutor(exec),
            }
        }

        /// Build a `Cache` pool around the `connector`.
        pub fn build<M, Dst>(self, connector: M) -> Cache<M, Dst, Ev>
        where
            M: Service<Dst>,
        {
            Cache {
                connector,
                events: self.events,
                // Starts empty: no cached services, no waiters.
                shared: Arc::new(Mutex::new(Shared {
                    services: Vec::new(),
                    waiters: Vec::new(),
                })),
            }
        }
    }
157
158    // impl Cache
159
160    impl<M, Dst, Ev> Cache<M, Dst, Ev>
161    where
162        M: Service<Dst>,
163    {
164        /// Retain all cached services indicated by the predicate.
165        pub fn retain<F>(&mut self, predicate: F)
166        where
167            F: FnMut(&mut M::Response) -> bool,
168        {
169            self.shared.lock().unwrap().services.retain_mut(predicate);
170        }
171
172        /// Check whether this cache has no cached services.
173        pub fn is_empty(&self) -> bool {
174            self.shared.lock().unwrap().services.is_empty()
175        }
176    }
177
    impl<M, Dst, Ev> Service<Dst> for Cache<M, Dst, Ev>
    where
        M: Service<Dst>,
        M::Future: Unpin,
        M::Response: Unpin,
        Ev: events::Events<BackgroundConnect<M::Future, M::Response>> + Clone + Unpin,
    {
        type Response = Cached<M::Response>;
        type Error = M::Error;
        type Future = CacheFuture<M, Dst, Ev>;

        // Ready immediately if something is cached; otherwise defer to the
        // inner connector's readiness.
        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
            if !self.shared.lock().unwrap().services.is_empty() {
                Poll::Ready(Ok(()))
            } else {
                self.connector.poll_ready(cx)
            }
        }

        fn call(&mut self, target: Dst) -> Self::Future {
            // 1. If already cached, easy!
            let waiter = {
                let mut locked = self.shared.lock().unwrap();
                if let Some(found) = locked.take() {
                    return CacheFuture::Cached {
                        svc: Some(Cached::new(found, Arc::downgrade(&self.shared))),
                    };
                }

                // Register the waiter while still holding the lock, so a
                // service returned between the `take()` above and the first
                // poll of the future cannot be missed.
                let (tx, rx) = oneshot::channel();
                locked.waiters.push(tx);
                rx
            };

            // 2. Otherwise, we start a new connect, and also listen for
            //    any newly idle.
            CacheFuture::Racing {
                shared: self.shared.clone(),
                select: future::select(waiter, self.connector.call(target)),
                events: self.events.clone(),
            }
        }
    }
221
222    impl<M, Dst, Ev> Clone for Cache<M, Dst, Ev>
223    where
224        M: Service<Dst> + Clone,
225        Ev: Clone,
226    {
227        fn clone(&self) -> Self {
228            Self {
229                connector: self.connector.clone(),
230                events: self.events.clone(),
231                shared: self.shared.clone(),
232            }
233        }
234    }
235
    impl<M, Dst, Ev> Future for CacheFuture<M, Dst, Ev>
    where
        M: Service<Dst>,
        M::Future: Unpin,
        M::Response: Unpin,
        Ev: events::Events<BackgroundConnect<M::Future, M::Response>> + Unpin,
    {
        type Output = Result<Cached<M::Response>, M::Error>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            // Loop so that a state transition (Racing -> Connecting) polls the
            // new state immediately instead of returning Pending.
            loop {
                match &mut *self.as_mut() {
                    CacheFuture::Racing {
                        shared,
                        select,
                        events,
                    } => {
                        match ready!(Pin::new(select).poll(cx)) {
                            future::Either::Left((Err(_pool_closed), connecting)) => {
                                // pool was dropped, so we'll never get it from a waiter,
                                // but if this future still exists, then the user still
                                // wants a connection. just wait for the connecting
                                *self = CacheFuture::Connecting {
                                    shared: shared.clone(),
                                    future: connecting,
                                };
                            }
                            future::Either::Left((Ok(pool_got), connecting)) => {
                                // A newly idle service won the race; hand the
                                // still-pending connect to the events hook so
                                // it can optionally finish in the background.
                                events.on_race_lost(BackgroundConnect {
                                    future: connecting,
                                    shared: Arc::downgrade(&shared),
                                });
                                return Poll::Ready(Ok(Cached::new(
                                    pool_got,
                                    Arc::downgrade(&shared),
                                )));
                            }
                            future::Either::Right((connected, _waiter)) => {
                                // The connect won; dropping `_waiter` closes
                                // the channel, so `Shared::put` will skip it.
                                let inner = connected?;
                                return Poll::Ready(Ok(Cached::new(
                                    inner,
                                    Arc::downgrade(&shared),
                                )));
                            }
                        }
                    }
                    CacheFuture::Connecting { shared, future } => {
                        let inner = ready!(Pin::new(future).poll(cx))?;
                        return Poll::Ready(Ok(Cached::new(inner, Arc::downgrade(&shared))));
                    }
                    CacheFuture::Cached { svc } => {
                        // `svc` is only `None` if polled again after Ready,
                        // which the `Future` contract forbids.
                        return Poll::Ready(Ok(svc.take().unwrap()));
                    }
                }
            }
        }
    }
293
294    // impl Cached
295
296    impl<S> Cached<S> {
297        fn new(inner: S, shared: Weak<Mutex<Shared<S>>>) -> Self {
298            Cached {
299                is_closed: false,
300                inner: Some(inner),
301                shared,
302            }
303        }
304
305        // TODO: inner()? looks like `tower` likes `get_ref()` and `get_mut()`.
306
307        /// Get a reference to the inner service.
308        pub fn inner(&self) -> &S {
309            self.inner.as_ref().expect("inner only taken in drop")
310        }
311
312        /// Get a mutable reference to the inner service.
313        pub fn inner_mut(&mut self) -> &mut S {
314            self.inner.as_mut().expect("inner only taken in drop")
315        }
316    }
317
318    impl<S, Req> Service<Req> for Cached<S>
319    where
320        S: Service<Req>,
321    {
322        type Response = S::Response;
323        type Error = S::Error;
324        type Future = S::Future;
325
326        fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
327            self.inner.as_mut().unwrap().poll_ready(cx).map_err(|err| {
328                self.is_closed = true;
329                err
330            })
331        }
332
333        fn call(&mut self, req: Req) -> Self::Future {
334            self.inner.as_mut().unwrap().call(req)
335        }
336    }
337
338    impl<S> Drop for Cached<S> {
339        fn drop(&mut self) {
340            if self.is_closed {
341                return;
342            }
343            if let Some(value) = self.inner.take() {
344                if let Some(shared) = self.shared.upgrade() {
345                    if let Ok(mut shared) = shared.lock() {
346                        shared.put(value);
347                    }
348                }
349            }
350        }
351    }
352
353    impl<S: fmt::Debug> fmt::Debug for Cached<S> {
354        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355            f.debug_tuple("Cached")
356                .field(self.inner.as_ref().unwrap())
357                .finish()
358        }
359    }
360
361    // impl Shared
362
363    impl<V> Shared<V> {
364        fn put(&mut self, val: V) {
365            let mut val = Some(val);
366            while let Some(tx) = self.waiters.pop() {
367                if !tx.is_closed() {
368                    match tx.send(val.take().unwrap()) {
369                        Ok(()) => break,
370                        Err(v) => {
371                            val = Some(v);
372                        }
373                    }
374                }
375            }
376
377            if let Some(val) = val {
378                self.services.push(val);
379            }
380        }
381
382        fn take(&mut self) -> Option<V> {
383            // TODO: take in a loop
384            self.services.pop()
385        }
386    }
387
    /// A connect future detached from a `CacheFuture` that lost its race.
    ///
    /// Passed to `events::Events::on_race_lost`; when executed (see
    /// `events::WithExecutor`), it finishes the connect and puts the result
    /// into the pool.
    pub struct BackgroundConnect<CF, S> {
        // The still-pending connect future.
        future: CF,
        // Weak: a background connect must not keep the pool alive.
        shared: Weak<Mutex<Shared<S>>>,
    }
392
393    impl<CF, S, E> Future for BackgroundConnect<CF, S>
394    where
395        CF: Future<Output = Result<S, E>> + Unpin,
396    {
397        type Output = ();
398
399        fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
400            match ready!(Pin::new(&mut self.future).poll(cx)) {
401                Ok(svc) => {
402                    if let Some(shared) = self.shared.upgrade() {
403                        if let Ok(mut locked) = shared.lock() {
404                            locked.put(svc);
405                        }
406                    }
407                    Poll::Ready(())
408                }
409                Err(_e) => Poll::Ready(()),
410            }
411        }
412    }
413}
414
mod events {
    /// Ignore all cache events; the default for `builder()`.
    #[derive(Clone, Debug)]
    #[non_exhaustive]
    pub struct Ignore;

    /// Handle cache events by running them on the wrapped executor.
    #[derive(Clone, Debug)]
    pub struct WithExecutor<E>(pub(super) E);

    /// Hooks the cache invokes for optional background work.
    pub trait Events<CF> {
        /// Called with the still-pending connect future when a pooled
        /// service won the race against a new connect.
        fn on_race_lost(&self, fut: CF);
    }

    impl<CF> Events<CF> for Ignore {
        // Dropping the future cancels the in-flight connect.
        fn on_race_lost(&self, _fut: CF) {}
    }

    impl<E, CF> Events<CF> for WithExecutor<E>
    where
        E: hyper::rt::Executor<CF>,
    {
        fn on_race_lost(&self, fut: CF) {
            // Let the connect finish in the background so the nearly-done
            // connection isn't wasted.
            self.0.execute(fut);
        }
    }
}
440
#[cfg(test)]
mod tests {
    use futures_util::future;
    use tower_service::Service;
    use tower_test::assert_request_eq;

    // An empty cache should fall through to the connector and make a
    // brand-new service.
    #[tokio::test]
    async fn test_makes_svc_when_empty() {
        let (mock, mut handle) = tower_test::mock::pair();
        let mut cache = super::builder().build(mock);
        handle.allow(1);

        std::future::poll_fn(|cx| cache.poll_ready(cx))
            .await
            .unwrap();

        let f = cache.call(1);

        // Drive the call while the mock answers the connect request.
        future::join(f, async move {
            assert_request_eq!(handle, 1).send_response("one");
        })
        .await
        .0
        .expect("call");
    }

    // Dropping a `Cached` should return the service to the pool, so a
    // second call reuses it instead of connecting again.
    #[tokio::test]
    async fn test_reuses_after_idle() {
        let (mock, mut handle) = tower_test::mock::pair();
        let mut cache = super::builder().build(mock);

        // only 1 connection should ever be made
        handle.allow(1);

        std::future::poll_fn(|cx| cache.poll_ready(cx))
            .await
            .unwrap();
        let f = cache.call(1);
        let cached = future::join(f, async {
            assert_request_eq!(handle, 1).send_response("one");
        })
        .await
        .0
        .expect("call");
        // Returning the service to the pool happens on drop.
        drop(cached);

        // This second call must resolve from the cache: the mock's allowance
        // of 1 means another connect could not complete.
        std::future::poll_fn(|cx| cache.poll_ready(cx))
            .await
            .unwrap();
        let f = cache.call(1);
        let cached = f.await.expect("call");
        drop(cached);
    }
}
494}