1pub use self::internal::builder;
40
41#[cfg(docsrs)]
42pub use self::internal::Builder;
43#[cfg(docsrs)]
44pub use self::internal::Negotiate;
45#[cfg(docsrs)]
46pub use self::internal::Negotiated;
47
48mod internal {
49 use std::future::Future;
50 use std::pin::Pin;
51 use std::sync::{Arc, Mutex};
52 use std::task::{self, Poll};
53
54 use futures_core::ready;
55 use pin_project_lite::pin_project;
56 use tower_layer::Layer;
57 use tower_service::Service;
58
59 type BoxError = Box<dyn std::error::Error + Send + Sync>;
60
61 #[derive(Clone)]
71 pub struct Negotiate<L, R> {
72 left: L,
73 right: R,
74 }
75
    /// The response of [`Negotiate`]: whichever of the two services ended up being produced.
    ///
    /// Use the `fallback_*` / `upgraded_*` accessors to reach the wrapped value; the variants
    /// themselves are hidden from the public docs.
    #[derive(Clone, Debug)]
    pub enum Negotiated<L, R> {
        // Produced by the fallback (left) path.
        #[doc(hidden)]
        Fallback(L),
        // Produced by the upgrade (right) path.
        #[doc(hidden)]
        Upgraded(R),
    }
90
    pin_project! {
        /// Future returned by `Negotiate::call`.
        ///
        /// Holds the current negotiation step together with the two services it may still
        /// need to call.
        pub struct Negotiating<Dst, L, R>
        where
            L: Service<Dst>,
            R: Service<()>,
        {
            // Current step of the negotiation; pinned because it owns an in-flight future.
            #[pin]
            state: State<Dst, L::Future, R::Future>,
            // Fallback service, moved out of `Negotiate` after being polled ready.
            left: L,
            // A clone of the upgrade service, called again if the fallback defers to it.
            right: R,
        }
    }
103
    pin_project! {
        // Steps of the negotiation state machine driven by `Negotiating::poll`.
        #[project = StateProj]
        enum State<Dst, FL, FR> {
            // Waiting on the first upgrade attempt. `dst` is kept (as `Some`) so the
            // fallback service can be called with it if the upgrade defers.
            Eager {
                #[pin]
                future: FR,
                dst: Option<Dst>,
            },
            // Waiting on the fallback service's future.
            Fallback {
                #[pin]
                future: FL,
            },
            // Waiting on the second upgrade attempt, made after the fallback deferred.
            Upgrade {
                #[pin]
                future: FR,
            }
        }
    }
122
    pin_project! {
        // Future returned by `Negotiated::call`: it forwards to the future of whichever
        // inner service the `Negotiated` value wraps.
        #[project = NegotiatedProj]
        pub enum NegotiatedFuture<L, R> {
            // Future from the fallback (left) service.
            Fallback {
                #[pin]
                future: L
            },
            // Future from the upgraded (right) service.
            Upgraded {
                #[pin]
                future: R
            },
        }
    }
136
    /// Type-state builder for [`Negotiate`], created by [`builder`].
    ///
    /// Each type parameter starts out as a `Wants*` placeholder and is replaced by the value
    /// passed to the setter of the same name.
    #[derive(Debug)]
    pub struct Builder<C, I, L, R> {
        /// Inner service whose responses are inspected.
        connect: C,
        /// Predicate run on each `connect` response; `true` diverts it to the upgrade path.
        inspect: I,
        /// Layer applied on top of the inspecting wrapper around `connect` (left path).
        fallback: L,
        /// Layer applied on top of the service that hands out diverted responses (right path).
        upgrade: R,
    }
151
    /// Type-state placeholder: the `connect` service has not been set yet.
    #[derive(Debug)]
    pub struct WantsConnect;
    /// Type-state placeholder: the `inspect` predicate has not been set yet.
    #[derive(Debug)]
    pub struct WantsInspect;
    /// Type-state placeholder: the `fallback` layer has not been set yet.
    #[derive(Debug)]
    pub struct WantsFallback;
    /// Type-state placeholder: the `upgrade` layer has not been set yet.
    #[derive(Debug)]
    pub struct WantsUpgrade;
160
161 pub fn builder() -> Builder<WantsConnect, WantsInspect, WantsFallback, WantsUpgrade> {
163 Builder {
164 connect: WantsConnect,
165 inspect: WantsInspect,
166 fallback: WantsFallback,
167 upgrade: WantsUpgrade,
168 }
169 }
170
171 impl<C, I, L, R> Builder<C, I, L, R> {
172 pub fn connect<CC>(self, connect: CC) -> Builder<CC, I, L, R> {
174 Builder {
175 connect,
176 inspect: self.inspect,
177 fallback: self.fallback,
178 upgrade: self.upgrade,
179 }
180 }
181
182 pub fn inspect<II>(self, inspect: II) -> Builder<C, II, L, R> {
184 Builder {
185 connect: self.connect,
186 inspect,
187 fallback: self.fallback,
188 upgrade: self.upgrade,
189 }
190 }
191
192 pub fn fallback<LL>(self, fallback: LL) -> Builder<C, I, LL, R> {
194 Builder {
195 connect: self.connect,
196 inspect: self.inspect,
197 fallback,
198 upgrade: self.upgrade,
199 }
200 }
201
202 pub fn upgrade<RR>(self, upgrade: RR) -> Builder<C, I, L, RR> {
204 Builder {
205 connect: self.connect,
206 inspect: self.inspect,
207 fallback: self.fallback,
208 upgrade,
209 }
210 }
211
212 pub fn build<Dst>(self) -> Negotiate<L::Service, R::Service>
214 where
215 C: Service<Dst>,
216 C::Error: Into<BoxError>,
217 L: Layer<Inspector<C, C::Response, I>>,
218 L::Service: Service<Dst> + Clone,
219 <L::Service as Service<Dst>>::Error: Into<BoxError>,
220 R: Layer<Inspected<C::Response>>,
221 R::Service: Service<()> + Clone,
222 <R::Service as Service<()>>::Error: Into<BoxError>,
223 I: Fn(&C::Response) -> bool + Clone,
224 {
225 let Builder {
226 connect,
227 inspect,
228 fallback,
229 upgrade,
230 } = self;
231
232 let slot = Arc::new(Mutex::new(None));
233 let wrapped = Inspector {
234 svc: connect,
235 inspect,
236 slot: slot.clone(),
237 };
238 let left = fallback.layer(wrapped);
239
240 let right = upgrade.layer(Inspected { slot });
241
242 Negotiate { left, right }
243 }
244 }
245
246 impl<L, R> Negotiate<L, R> {
247 pub fn fallback_mut(&mut self) -> &mut L {
249 &mut self.left
250 }
251
252 pub fn upgrade_mut(&mut self) -> &mut R {
254 &mut self.right
255 }
256 }
257
258 impl<L, R, Target> Service<Target> for Negotiate<L, R>
259 where
260 L: Service<Target> + Clone,
261 L::Error: Into<BoxError>,
262 R: Service<()> + Clone,
263 R::Error: Into<BoxError>,
264 {
265 type Response = Negotiated<L::Response, R::Response>;
266 type Error = BoxError;
267 type Future = Negotiating<Target, L, R>;
268
269 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
270 self.left.poll_ready(cx).map_err(Into::into)
271 }
272
273 fn call(&mut self, dst: Target) -> Self::Future {
274 let left = self.left.clone();
275 Negotiating {
276 state: State::Eager {
277 future: self.right.call(()),
278 dst: Some(dst),
279 },
280 left: std::mem::replace(&mut self.left, left),
282 right: self.right.clone(),
283 }
284 }
285 }
286
287 impl<Dst, L, R> Future for Negotiating<Dst, L, R>
288 where
289 L: Service<Dst>,
290 L::Error: Into<BoxError>,
291 R: Service<()>,
292 R::Error: Into<BoxError>,
293 {
294 type Output = Result<Negotiated<L::Response, R::Response>, BoxError>;
295
296 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
297 let mut me = self.project();
303 loop {
304 match me.state.as_mut().project() {
305 StateProj::Eager { future, dst } => match ready!(future.poll(cx)) {
306 Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))),
307 Err(err) => {
308 let err = err.into();
309 if UseOther::is(&*err) {
310 let dst = dst.take().unwrap();
311 let f = me.left.call(dst);
312 me.state.set(State::Fallback { future: f });
313 continue;
314 } else {
315 return Poll::Ready(Err(err));
316 }
317 }
318 },
319 StateProj::Fallback { future } => match ready!(future.poll(cx)) {
320 Ok(out) => return Poll::Ready(Ok(Negotiated::Fallback(out))),
321 Err(err) => {
322 let err = err.into();
323 if UseOther::is(&*err) {
324 let f = me.right.call(());
325 me.state.set(State::Upgrade { future: f });
326 continue;
327 } else {
328 return Poll::Ready(Err(err));
329 }
330 }
331 },
332 StateProj::Upgrade { future } => match ready!(future.poll(cx)) {
333 Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))),
334 Err(err) => return Poll::Ready(Err(err.into())),
335 },
336 }
337 }
338 }
339 }
340
341 impl<L, R> Negotiated<L, R> {
342 #[cfg(test)]
344 pub(super) fn is_fallback(&self) -> bool {
345 matches!(self, Negotiated::Fallback(_))
346 }
347
348 #[cfg(test)]
349 pub(super) fn is_upgraded(&self) -> bool {
350 matches!(self, Negotiated::Upgraded(_))
351 }
352
353 pub fn fallback_ref(&self) -> Option<&L> {
357 if let Negotiated::Fallback(ref left) = self {
358 Some(left)
359 } else {
360 None
361 }
362 }
363
364 pub fn fallback_mut(&mut self) -> Option<&mut L> {
366 if let Negotiated::Fallback(ref mut left) = self {
367 Some(left)
368 } else {
369 None
370 }
371 }
372
373 pub fn upgraded_ref(&self) -> Option<&R> {
375 if let Negotiated::Upgraded(ref right) = self {
376 Some(right)
377 } else {
378 None
379 }
380 }
381
382 pub fn upgraded_mut(&mut self) -> Option<&mut R> {
384 if let Negotiated::Upgraded(ref mut right) = self {
385 Some(right)
386 } else {
387 None
388 }
389 }
390 }
391
392 impl<L, R, Req, Res, E> Service<Req> for Negotiated<L, R>
393 where
394 L: Service<Req, Response = Res, Error = E>,
395 R: Service<Req, Response = Res, Error = E>,
396 {
397 type Response = Res;
398 type Error = E;
399 type Future = NegotiatedFuture<L::Future, R::Future>;
400
401 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
402 match self {
403 Negotiated::Fallback(ref mut s) => s.poll_ready(cx),
404 Negotiated::Upgraded(ref mut s) => s.poll_ready(cx),
405 }
406 }
407
408 fn call(&mut self, req: Req) -> Self::Future {
409 match self {
410 Negotiated::Fallback(ref mut s) => NegotiatedFuture::Fallback {
411 future: s.call(req),
412 },
413 Negotiated::Upgraded(ref mut s) => NegotiatedFuture::Upgraded {
414 future: s.call(req),
415 },
416 }
417 }
418 }
419
420 impl<L, R, Out> Future for NegotiatedFuture<L, R>
421 where
422 L: Future<Output = Out>,
423 R: Future<Output = Out>,
424 {
425 type Output = Out;
426
427 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
428 match self.project() {
429 NegotiatedProj::Fallback { future } => future.poll(cx),
430 NegotiatedProj::Upgraded { future } => future.poll(cx),
431 }
432 }
433 }
434
    /// Service wrapper that runs a predicate over each response of the wrapped service.
    ///
    /// Responses for which the predicate returns `true` are moved into `slot` and reported
    /// as the internal `UseOther` sentinel error instead of being returned.
    pub struct Inspector<M, S, I> {
        /// The wrapped service.
        svc: M,
        /// Predicate applied to each response.
        inspect: I,
        /// Slot shared with `Inspected`; holds a diverted response until it is taken.
        slot: Arc<Mutex<Option<S>>>,
    }
442
    pin_project! {
        /// Future returned by `Inspector::call`.
        ///
        /// Resolves the inner future, then applies the inspector predicate to the response.
        pub struct InspectFuture<F, S, I> {
            // The wrapped service's in-flight future.
            #[pin]
            future: F,
            // Predicate cloned from the `Inspector` that created this future.
            inspect: I,
            // Shared slot where responses selected by `inspect` are parked.
            slot: Arc<Mutex<Option<S>>>,
        }
    }
451
452 impl<M: Clone, S, I: Clone> Clone for Inspector<M, S, I> {
453 fn clone(&self) -> Self {
454 Self {
455 svc: self.svc.clone(),
456 inspect: self.inspect.clone(),
457 slot: self.slot.clone(),
458 }
459 }
460 }
461
462 impl<M, S, I, Target> Service<Target> for Inspector<M, S, I>
463 where
464 M: Service<Target, Response = S>,
465 M::Error: Into<BoxError>,
466 I: Clone + Fn(&S) -> bool,
467 {
468 type Response = M::Response;
469 type Error = BoxError;
470 type Future = InspectFuture<M::Future, S, I>;
471
472 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
473 self.svc.poll_ready(cx).map_err(Into::into)
474 }
475
476 fn call(&mut self, dst: Target) -> Self::Future {
477 InspectFuture {
478 future: self.svc.call(dst),
479 inspect: self.inspect.clone(),
480 slot: self.slot.clone(),
481 }
482 }
483 }
484
485 impl<F, I, S, E> Future for InspectFuture<F, S, I>
486 where
487 F: Future<Output = Result<S, E>>,
488 E: Into<BoxError>,
489 I: Fn(&S) -> bool,
490 {
491 type Output = Result<S, BoxError>;
492
493 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
494 let me = self.project();
495 let s = ready!(me.future.poll(cx)).map_err(Into::into)?;
496 Poll::Ready(if (me.inspect)(&s) {
497 *me.slot.lock().unwrap() = Some(s);
498 Err(UseOther.into())
499 } else {
500 Ok(s)
501 })
502 }
503 }
504
    /// Service that hands out the response parked by `Inspector`, if there is one.
    ///
    /// It reports the `UseOther` sentinel when the slot is empty.
    pub struct Inspected<S> {
        /// Slot shared with the `Inspector` that fills it.
        slot: Arc<Mutex<Option<S>>>,
    }
508
509 impl<S, Target> Service<Target> for Inspected<S> {
510 type Response = S;
511 type Error = BoxError;
512 type Future = std::future::Ready<Result<S, BoxError>>;
513
514 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
515 if self.slot.lock().unwrap().is_some() {
516 Poll::Ready(Ok(()))
517 } else {
518 Poll::Ready(Err(UseOther.into()))
519 }
520 }
521
522 fn call(&mut self, _dst: Target) -> Self::Future {
523 let s = self
524 .slot
525 .lock()
526 .unwrap()
527 .take()
528 .ok_or_else(|| UseOther.into());
529 std::future::ready(s)
530 }
531 }
532
533 impl<S> Clone for Inspected<S> {
534 fn clone(&self) -> Inspected<S> {
535 Inspected {
536 slot: self.slot.clone(),
537 }
538 }
539 }
540
541 #[derive(Debug)]
542 struct UseOther;
543
544 impl std::fmt::Display for UseOther {
545 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546 f.write_str("sentinel error; using other")
547 }
548 }
549
550 impl std::error::Error for UseOther {}
551
552 impl UseOther {
553 fn is(err: &(dyn std::error::Error + 'static)) -> bool {
554 let mut source = Some(err);
555 while let Some(err) = source {
556 if err.is::<UseOther>() {
557 return true;
558 }
559 source = err.source();
560 }
561 false
562 }
563 }
564}
565
566#[cfg(test)]
567mod tests {
568 use futures_util::future;
569 use tower_service::Service;
570 use tower_test::assert_request_eq;
571
572 #[tokio::test]
573 async fn not_negotiated_falls_back_to_left() {
574 let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
575
576 let mut negotiate = super::builder()
577 .connect(mock_svc)
578 .inspect(|_: &&str| false)
579 .fallback(layer_fn(|s| s))
580 .upgrade(layer_fn(|s| s))
581 .build();
582
583 crate::common::future::poll_fn(|cx| negotiate.poll_ready(cx))
584 .await
585 .unwrap();
586
587 let fut = negotiate.call(());
588 let nsvc = future::join(fut, async move {
589 assert_request_eq!(handle, ()).send_response("one");
590 })
591 .await
592 .0
593 .expect("call");
594 assert!(nsvc.is_fallback());
595 }
596
597 #[tokio::test]
598 async fn negotiated_uses_right() {
599 let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
600
601 let mut negotiate = super::builder()
602 .connect(mock_svc)
603 .inspect(|_: &&str| true)
604 .fallback(layer_fn(|s| s))
605 .upgrade(layer_fn(|s| s))
606 .build();
607
608 crate::common::future::poll_fn(|cx| negotiate.poll_ready(cx))
609 .await
610 .unwrap();
611
612 let fut = negotiate.call(());
613 let nsvc = future::join(fut, async move {
614 assert_request_eq!(handle, ()).send_response("one");
615 })
616 .await
617 .0
618 .expect("call");
619
620 assert!(nsvc.is_upgraded());
621 }
622
623 fn layer_fn<F>(f: F) -> LayerFn<F> {
624 LayerFn(f)
625 }
626
627 #[derive(Clone)]
628 struct LayerFn<F>(F);
629
630 impl<F, S, Out> tower_layer::Layer<S> for LayerFn<F>
631 where
632 F: Fn(S) -> Out,
633 {
634 type Service = Out;
635 fn layer(&self, inner: S) -> Self::Service {
636 (self.0)(inner)
637 }
638 }
639}