1pub use self::internal::builder;
9
10#[cfg(docsrs)]
11pub use self::internal::Builder;
12#[cfg(docsrs)]
13pub use self::internal::Cache;
14#[cfg(docsrs)]
15pub use self::internal::Cached;
16
17mod internal {
21 use std::fmt;
22 use std::future::Future;
23 use std::pin::Pin;
24 use std::sync::{Arc, Mutex, Weak};
25 use std::task::{self, ready, Poll};
26
27 use futures_util::future;
28 use tokio::sync::oneshot;
29 use tower_service::Service;
30
31 use super::events;
32
33 pub fn builder() -> Builder<events::Ignore> {
35 Builder {
36 events: events::Ignore,
37 }
38 }
39
40 #[derive(Debug)]
50 pub struct Cache<M, Dst, Ev>
51 where
52 M: Service<Dst>,
53 {
54 connector: M,
55 shared: Arc<Mutex<Shared<M::Response>>>,
56 events: Ev,
57 }
58
59 #[derive(Debug)]
67 pub struct Builder<Ev> {
68 events: Ev,
69 }
70
71 pub struct Cached<S> {
82 is_closed: bool,
83 inner: Option<S>,
84 shared: Weak<Mutex<Shared<S>>>,
85 }
87
88 pub enum CacheFuture<M, Dst, Ev>
89 where
90 M: Service<Dst>,
91 {
92 Racing {
93 shared: Arc<Mutex<Shared<M::Response>>>,
94 select: future::Select<oneshot::Receiver<M::Response>, M::Future>,
95 events: Ev,
96 },
97 Connecting {
98 shared: Arc<Mutex<Shared<M::Response>>>,
100 future: M::Future,
101 },
102 Cached {
103 svc: Option<Cached<M::Response>>,
104 },
105 }
106
107 #[derive(Debug)]
109 pub struct Shared<S> {
110 services: Vec<S>,
111 waiters: Vec<oneshot::Sender<S>>,
112 }
113
114 impl<Ev> Builder<Ev> {
117 pub fn executor<E>(self, exec: E) -> Builder<events::WithExecutor<E>> {
137 Builder {
138 events: events::WithExecutor(exec),
139 }
140 }
141
142 pub fn build<M, Dst>(self, connector: M) -> Cache<M, Dst, Ev>
144 where
145 M: Service<Dst>,
146 {
147 Cache {
148 connector,
149 events: self.events,
150 shared: Arc::new(Mutex::new(Shared {
151 services: Vec::new(),
152 waiters: Vec::new(),
153 })),
154 }
155 }
156 }
157
158 impl<M, Dst, Ev> Cache<M, Dst, Ev>
161 where
162 M: Service<Dst>,
163 {
164 pub fn retain<F>(&mut self, predicate: F)
166 where
167 F: FnMut(&mut M::Response) -> bool,
168 {
169 self.shared.lock().unwrap().services.retain_mut(predicate);
170 }
171
172 pub fn is_empty(&self) -> bool {
174 self.shared.lock().unwrap().services.is_empty()
175 }
176 }
177
178 impl<M, Dst, Ev> Service<Dst> for Cache<M, Dst, Ev>
179 where
180 M: Service<Dst>,
181 M::Future: Unpin,
182 M::Response: Unpin,
183 Ev: events::Events<BackgroundConnect<M::Future, M::Response>> + Clone + Unpin,
184 {
185 type Response = Cached<M::Response>;
186 type Error = M::Error;
187 type Future = CacheFuture<M, Dst, Ev>;
188
189 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
190 if !self.shared.lock().unwrap().services.is_empty() {
191 Poll::Ready(Ok(()))
192 } else {
193 self.connector.poll_ready(cx)
194 }
195 }
196
197 fn call(&mut self, target: Dst) -> Self::Future {
198 let waiter = {
200 let mut locked = self.shared.lock().unwrap();
201 if let Some(found) = locked.take() {
202 return CacheFuture::Cached {
203 svc: Some(Cached::new(found, Arc::downgrade(&self.shared))),
204 };
205 }
206
207 let (tx, rx) = oneshot::channel();
208 locked.waiters.push(tx);
209 rx
210 };
211
212 CacheFuture::Racing {
215 shared: self.shared.clone(),
216 select: future::select(waiter, self.connector.call(target)),
217 events: self.events.clone(),
218 }
219 }
220 }
221
222 impl<M, Dst, Ev> Clone for Cache<M, Dst, Ev>
223 where
224 M: Service<Dst> + Clone,
225 Ev: Clone,
226 {
227 fn clone(&self) -> Self {
228 Self {
229 connector: self.connector.clone(),
230 events: self.events.clone(),
231 shared: self.shared.clone(),
232 }
233 }
234 }
235
236 impl<M, Dst, Ev> Future for CacheFuture<M, Dst, Ev>
237 where
238 M: Service<Dst>,
239 M::Future: Unpin,
240 M::Response: Unpin,
241 Ev: events::Events<BackgroundConnect<M::Future, M::Response>> + Unpin,
242 {
243 type Output = Result<Cached<M::Response>, M::Error>;
244
245 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
246 loop {
247 match &mut *self.as_mut() {
248 CacheFuture::Racing {
249 shared,
250 select,
251 events,
252 } => {
253 match ready!(Pin::new(select).poll(cx)) {
254 future::Either::Left((Err(_pool_closed), connecting)) => {
255 *self = CacheFuture::Connecting {
259 shared: shared.clone(),
260 future: connecting,
261 };
262 }
263 future::Either::Left((Ok(pool_got), connecting)) => {
264 events.on_race_lost(BackgroundConnect {
265 future: connecting,
266 shared: Arc::downgrade(&shared),
267 });
268 return Poll::Ready(Ok(Cached::new(
269 pool_got,
270 Arc::downgrade(&shared),
271 )));
272 }
273 future::Either::Right((connected, _waiter)) => {
274 let inner = connected?;
275 return Poll::Ready(Ok(Cached::new(
276 inner,
277 Arc::downgrade(&shared),
278 )));
279 }
280 }
281 }
282 CacheFuture::Connecting { shared, future } => {
283 let inner = ready!(Pin::new(future).poll(cx))?;
284 return Poll::Ready(Ok(Cached::new(inner, Arc::downgrade(&shared))));
285 }
286 CacheFuture::Cached { svc } => {
287 return Poll::Ready(Ok(svc.take().unwrap()));
288 }
289 }
290 }
291 }
292 }
293
294 impl<S> Cached<S> {
297 fn new(inner: S, shared: Weak<Mutex<Shared<S>>>) -> Self {
298 Cached {
299 is_closed: false,
300 inner: Some(inner),
301 shared,
302 }
303 }
304
305 pub fn inner(&self) -> &S {
309 self.inner.as_ref().expect("inner only taken in drop")
310 }
311
312 pub fn inner_mut(&mut self) -> &mut S {
314 self.inner.as_mut().expect("inner only taken in drop")
315 }
316 }
317
318 impl<S, Req> Service<Req> for Cached<S>
319 where
320 S: Service<Req>,
321 {
322 type Response = S::Response;
323 type Error = S::Error;
324 type Future = S::Future;
325
326 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
327 self.inner.as_mut().unwrap().poll_ready(cx).map_err(|err| {
328 self.is_closed = true;
329 err
330 })
331 }
332
333 fn call(&mut self, req: Req) -> Self::Future {
334 self.inner.as_mut().unwrap().call(req)
335 }
336 }
337
338 impl<S> Drop for Cached<S> {
339 fn drop(&mut self) {
340 if self.is_closed {
341 return;
342 }
343 if let Some(value) = self.inner.take() {
344 if let Some(shared) = self.shared.upgrade() {
345 if let Ok(mut shared) = shared.lock() {
346 shared.put(value);
347 }
348 }
349 }
350 }
351 }
352
353 impl<S: fmt::Debug> fmt::Debug for Cached<S> {
354 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355 f.debug_tuple("Cached")
356 .field(self.inner.as_ref().unwrap())
357 .finish()
358 }
359 }
360
361 impl<V> Shared<V> {
364 fn put(&mut self, val: V) {
365 let mut val = Some(val);
366 while let Some(tx) = self.waiters.pop() {
367 if !tx.is_closed() {
368 match tx.send(val.take().unwrap()) {
369 Ok(()) => break,
370 Err(v) => {
371 val = Some(v);
372 }
373 }
374 }
375 }
376
377 if let Some(val) = val {
378 self.services.push(val);
379 }
380 }
381
382 fn take(&mut self) -> Option<V> {
383 self.services.pop()
385 }
386 }
387
388 pub struct BackgroundConnect<CF, S> {
389 future: CF,
390 shared: Weak<Mutex<Shared<S>>>,
391 }
392
393 impl<CF, S, E> Future for BackgroundConnect<CF, S>
394 where
395 CF: Future<Output = Result<S, E>> + Unpin,
396 {
397 type Output = ();
398
399 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
400 match ready!(Pin::new(&mut self.future).poll(cx)) {
401 Ok(svc) => {
402 if let Some(shared) = self.shared.upgrade() {
403 if let Ok(mut locked) = shared.lock() {
404 locked.put(svc);
405 }
406 }
407 Poll::Ready(())
408 }
409 Err(_e) => Poll::Ready(()),
410 }
411 }
412 }
413}
414
415mod events {
416 #[derive(Clone, Debug)]
417 #[non_exhaustive]
418 pub struct Ignore;
419
420 #[derive(Clone, Debug)]
421 pub struct WithExecutor<E>(pub(super) E);
422
423 pub trait Events<CF> {
424 fn on_race_lost(&self, fut: CF);
425 }
426
427 impl<CF> Events<CF> for Ignore {
428 fn on_race_lost(&self, _fut: CF) {}
429 }
430
431 impl<E, CF> Events<CF> for WithExecutor<E>
432 where
433 E: hyper::rt::Executor<CF>,
434 {
435 fn on_race_lost(&self, fut: CF) {
436 self.0.execute(fut);
437 }
438 }
439}
440
#[cfg(test)]
mod tests {
    use futures_util::future;
    use tower_service::Service;
    use tower_test::assert_request_eq;

    /// With nothing cached, a call must reach the connector.
    #[tokio::test]
    async fn test_makes_svc_when_empty() {
        let (connector, mut handle) = tower_test::mock::pair();
        let mut cache = super::builder().build(connector);
        handle.allow(1);

        std::future::poll_fn(|cx| cache.poll_ready(cx))
            .await
            .unwrap();

        let pending = cache.call(1);
        let respond = async move {
            assert_request_eq!(handle, 1).send_response("one");
        };

        let (outcome, ()) = future::join(pending, respond).await;
        outcome.expect("call");
    }

    /// A service dropped back into the cache is reused by the next call
    /// without another connector request.
    #[tokio::test]
    async fn test_reuses_after_idle() {
        let (connector, mut handle) = tower_test::mock::pair();
        let mut cache = super::builder().build(connector);

        // Only one connect is permitted for the whole test.
        handle.allow(1);

        std::future::poll_fn(|cx| cache.poll_ready(cx))
            .await
            .unwrap();
        let first_call = cache.call(1);
        let respond = async {
            assert_request_eq!(handle, 1).send_response("one");
        };
        let (outcome, ()) = future::join(first_call, respond).await;
        let first = outcome.expect("call");
        // Dropping the handle returns the service to the cache.
        drop(first);

        std::future::poll_fn(|cx| cache.poll_ready(cx))
            .await
            .unwrap();
        let second_call = cache.call(1);
        let second = second_call.await.expect("call");
        drop(second);
    }
}