1pub use self::internal::builder;
40
41#[cfg(docsrs)]
42pub use self::internal::Builder;
43#[cfg(docsrs)]
44pub use self::internal::Negotiate;
45#[cfg(docsrs)]
46pub use self::internal::Negotiated;
47
48mod internal {
49 use std::future::Future;
50 use std::pin::Pin;
51 use std::sync::{Arc, Mutex};
52 use std::task::{self, ready, Poll};
53
54 use pin_project_lite::pin_project;
55 use tower_layer::Layer;
56 use tower_service::Service;
57
// Type-erased, thread-safe error used as the common error type of the
// services and futures in this module.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
59
/// A service that negotiates between two inner services.
///
/// When called, the `right` service is tried first with a `()` request; if it
/// fails with the internal "use other" sentinel error, the `left` service is
/// called with the original target instead.
#[derive(Clone)]
pub struct Negotiate<L, R> {
    // The fallback side; this is the service whose readiness is polled.
    left: L,
    // The upgrade side; always called with `()`.
    right: R,
}
74
/// The outcome of a negotiation: the response of whichever side won.
///
/// When both payloads are services with the same response and error types,
/// this enum is itself a service that forwards to the active variant.
#[derive(Clone, Debug)]
pub enum Negotiated<L, R> {
    // Response produced by the fallback (`left`) service.
    #[doc(hidden)]
    Fallback(L),
    // Response produced by the upgrade (`right`) service.
    #[doc(hidden)]
    Upgraded(R),
}
89
// Future returned by calling `Negotiate`.
//
// It drives the current `State` and keeps both services around so that it can
// start the other side when the sentinel error is observed.
pin_project! {
    pub struct Negotiating<Dst, L, R>
    where
        L: Service<Dst>,
        R: Service<()>,
    {
        // The in-flight future of the current negotiation step.
        #[pin]
        state: State<Dst, L::Future, R::Future>,
        // Fallback service, called with the stored target when `right` defers.
        left: L,
        // Upgrade service, called again with `()` when `left` defers.
        right: R,
    }
}
102
// Steps of a negotiation, as driven by `Negotiating::poll`.
pin_project! {
    #[project = StateProj]
    enum State<Dst, FL, FR> {
        // First attempt: the `right` future, with the target kept aside so it
        // can still be handed to `left` if `right` defers.
        Eager {
            #[pin]
            future: FR,
            dst: Option<Dst>,
        },
        // The `left` future, started after `right` deferred.
        Fallback {
            #[pin]
            future: FL,
        },
        // A second `right` future, started after `left` deferred.
        Upgrade {
            #[pin]
            future: FR,
        }
    }
}
121
// Future returned by calling a `Negotiated` service; it wraps the future of
// whichever variant handled the request.
pin_project! {
    #[project = NegotiatedProj]
    pub enum NegotiatedFuture<L, R> {
        // Future created by the `Fallback` variant.
        Fallback {
            #[pin]
            future: L
        },
        // Future created by the `Upgraded` variant.
        Upgraded {
            #[pin]
            future: R
        },
    }
}
135
/// A type-state builder for [`Negotiate`].
///
/// Each type parameter starts out as a `Wants*` marker (see [`builder`]) and
/// is replaced by the corresponding setter. `build` only type-checks once all
/// four slots hold values satisfying its bounds.
#[derive(Debug)]
pub struct Builder<C, I, L, R> {
    // The connecting service wrapped by the inspector.
    connect: C,
    // Predicate run on each connect response.
    inspect: I,
    // Layer applied around the inspected connector (the `left` side).
    fallback: L,
    // Layer applied around the stored-response service (the `right` side).
    upgrade: R,
}
150
/// Builder marker: no connector has been provided yet.
#[derive(Debug)]
pub struct WantsConnect;
/// Builder marker: no inspect predicate has been provided yet.
#[derive(Debug)]
pub struct WantsInspect;
/// Builder marker: no fallback layer has been provided yet.
#[derive(Debug)]
pub struct WantsFallback;
/// Builder marker: no upgrade layer has been provided yet.
#[derive(Debug)]
pub struct WantsUpgrade;
159
160 pub fn builder() -> Builder<WantsConnect, WantsInspect, WantsFallback, WantsUpgrade> {
162 Builder {
163 connect: WantsConnect,
164 inspect: WantsInspect,
165 fallback: WantsFallback,
166 upgrade: WantsUpgrade,
167 }
168 }
169
170 impl<C, I, L, R> Builder<C, I, L, R> {
171 pub fn connect<CC>(self, connect: CC) -> Builder<CC, I, L, R> {
173 Builder {
174 connect,
175 inspect: self.inspect,
176 fallback: self.fallback,
177 upgrade: self.upgrade,
178 }
179 }
180
181 pub fn inspect<II>(self, inspect: II) -> Builder<C, II, L, R> {
183 Builder {
184 connect: self.connect,
185 inspect,
186 fallback: self.fallback,
187 upgrade: self.upgrade,
188 }
189 }
190
191 pub fn fallback<LL>(self, fallback: LL) -> Builder<C, I, LL, R> {
193 Builder {
194 connect: self.connect,
195 inspect: self.inspect,
196 fallback,
197 upgrade: self.upgrade,
198 }
199 }
200
201 pub fn upgrade<RR>(self, upgrade: RR) -> Builder<C, I, L, RR> {
203 Builder {
204 connect: self.connect,
205 inspect: self.inspect,
206 fallback: self.fallback,
207 upgrade,
208 }
209 }
210
211 pub fn build<Dst>(self) -> Negotiate<L::Service, R::Service>
213 where
214 C: Service<Dst>,
215 C::Error: Into<BoxError>,
216 L: Layer<Inspector<C, C::Response, I>>,
217 L::Service: Service<Dst> + Clone,
218 <L::Service as Service<Dst>>::Error: Into<BoxError>,
219 R: Layer<Inspected<C::Response>>,
220 R::Service: Service<()> + Clone,
221 <R::Service as Service<()>>::Error: Into<BoxError>,
222 I: Fn(&C::Response) -> bool + Clone,
223 {
224 let Builder {
225 connect,
226 inspect,
227 fallback,
228 upgrade,
229 } = self;
230
231 let slot = Arc::new(Mutex::new(None));
232 let wrapped = Inspector {
233 svc: connect,
234 inspect,
235 slot: slot.clone(),
236 };
237 let left = fallback.layer(wrapped);
238
239 let right = upgrade.layer(Inspected { slot });
240
241 Negotiate { left, right }
242 }
243 }
244
245 impl<L, R> Negotiate<L, R> {
246 pub fn fallback_mut(&mut self) -> &mut L {
248 &mut self.left
249 }
250
251 pub fn upgrade_mut(&mut self) -> &mut R {
253 &mut self.right
254 }
255 }
256
257 impl<L, R, Target> Service<Target> for Negotiate<L, R>
258 where
259 L: Service<Target> + Clone,
260 L::Error: Into<BoxError>,
261 R: Service<()> + Clone,
262 R::Error: Into<BoxError>,
263 {
264 type Response = Negotiated<L::Response, R::Response>;
265 type Error = BoxError;
266 type Future = Negotiating<Target, L, R>;
267
268 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
269 self.left.poll_ready(cx).map_err(Into::into)
270 }
271
272 fn call(&mut self, dst: Target) -> Self::Future {
273 let left = self.left.clone();
274 Negotiating {
275 state: State::Eager {
276 future: self.right.call(()),
277 dst: Some(dst),
278 },
279 left: std::mem::replace(&mut self.left, left),
281 right: self.right.clone(),
282 }
283 }
284 }
285
286 impl<Dst, L, R> Future for Negotiating<Dst, L, R>
287 where
288 L: Service<Dst>,
289 L::Error: Into<BoxError>,
290 R: Service<()>,
291 R::Error: Into<BoxError>,
292 {
293 type Output = Result<Negotiated<L::Response, R::Response>, BoxError>;
294
295 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
296 let mut me = self.project();
302 loop {
303 match me.state.as_mut().project() {
304 StateProj::Eager { future, dst } => match ready!(future.poll(cx)) {
305 Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))),
306 Err(err) => {
307 let err = err.into();
308 if UseOther::is(&*err) {
309 let dst = dst.take().unwrap();
310 let f = me.left.call(dst);
311 me.state.set(State::Fallback { future: f });
312 continue;
313 } else {
314 return Poll::Ready(Err(err));
315 }
316 }
317 },
318 StateProj::Fallback { future } => match ready!(future.poll(cx)) {
319 Ok(out) => return Poll::Ready(Ok(Negotiated::Fallback(out))),
320 Err(err) => {
321 let err = err.into();
322 if UseOther::is(&*err) {
323 let f = me.right.call(());
324 me.state.set(State::Upgrade { future: f });
325 continue;
326 } else {
327 return Poll::Ready(Err(err));
328 }
329 }
330 },
331 StateProj::Upgrade { future } => match ready!(future.poll(cx)) {
332 Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))),
333 Err(err) => return Poll::Ready(Err(err.into())),
334 },
335 }
336 }
337 }
338 }
339
340 impl<L, R> Negotiated<L, R> {
341 #[cfg(test)]
343 pub(super) fn is_fallback(&self) -> bool {
344 matches!(self, Negotiated::Fallback(_))
345 }
346
347 #[cfg(test)]
348 pub(super) fn is_upgraded(&self) -> bool {
349 matches!(self, Negotiated::Upgraded(_))
350 }
351
352 pub fn fallback_ref(&self) -> Option<&L> {
356 if let Negotiated::Fallback(ref left) = self {
357 Some(left)
358 } else {
359 None
360 }
361 }
362
363 pub fn fallback_mut(&mut self) -> Option<&mut L> {
365 if let Negotiated::Fallback(ref mut left) = self {
366 Some(left)
367 } else {
368 None
369 }
370 }
371
372 pub fn upgraded_ref(&self) -> Option<&R> {
374 if let Negotiated::Upgraded(ref right) = self {
375 Some(right)
376 } else {
377 None
378 }
379 }
380
381 pub fn upgraded_mut(&mut self) -> Option<&mut R> {
383 if let Negotiated::Upgraded(ref mut right) = self {
384 Some(right)
385 } else {
386 None
387 }
388 }
389 }
390
391 impl<L, R, Req, Res, E> Service<Req> for Negotiated<L, R>
392 where
393 L: Service<Req, Response = Res, Error = E>,
394 R: Service<Req, Response = Res, Error = E>,
395 {
396 type Response = Res;
397 type Error = E;
398 type Future = NegotiatedFuture<L::Future, R::Future>;
399
400 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
401 match self {
402 Negotiated::Fallback(ref mut s) => s.poll_ready(cx),
403 Negotiated::Upgraded(ref mut s) => s.poll_ready(cx),
404 }
405 }
406
407 fn call(&mut self, req: Req) -> Self::Future {
408 match self {
409 Negotiated::Fallback(ref mut s) => NegotiatedFuture::Fallback {
410 future: s.call(req),
411 },
412 Negotiated::Upgraded(ref mut s) => NegotiatedFuture::Upgraded {
413 future: s.call(req),
414 },
415 }
416 }
417 }
418
419 impl<L, R, Out> Future for NegotiatedFuture<L, R>
420 where
421 L: Future<Output = Out>,
422 R: Future<Output = Out>,
423 {
424 type Output = Out;
425
426 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
427 match self.project() {
428 NegotiatedProj::Fallback { future } => future.poll(cx),
429 NegotiatedProj::Upgraded { future } => future.poll(cx),
430 }
431 }
432 }
433
/// A service wrapper that runs a predicate on each successful response.
///
/// Responses for which the predicate returns `true` are moved into the shared
/// slot and replaced by the internal "use other" sentinel error; all other
/// responses pass through unchanged.
pub struct Inspector<M, S, I> {
    // The wrapped service.
    svc: M,
    // Predicate deciding whether a response is diverted into `slot`.
    inspect: I,
    // Hand-off slot shared with the matching `Inspected` service.
    slot: Arc<Mutex<Option<S>>>,
}
441
// Future returned by `Inspector::call`: it awaits the inner future and then
// applies the inspect predicate to a successful response.
pin_project! {
    pub struct InspectFuture<F, S, I> {
        // The inner service's response future.
        #[pin]
        future: F,
        // Clone of the inspector's predicate.
        inspect: I,
        // Clone of the inspector's hand-off slot.
        slot: Arc<Mutex<Option<S>>>,
    }
}
450
451 impl<M: Clone, S, I: Clone> Clone for Inspector<M, S, I> {
452 fn clone(&self) -> Self {
453 Self {
454 svc: self.svc.clone(),
455 inspect: self.inspect.clone(),
456 slot: self.slot.clone(),
457 }
458 }
459 }
460
461 impl<M, S, I, Target> Service<Target> for Inspector<M, S, I>
462 where
463 M: Service<Target, Response = S>,
464 M::Error: Into<BoxError>,
465 I: Clone + Fn(&S) -> bool,
466 {
467 type Response = M::Response;
468 type Error = BoxError;
469 type Future = InspectFuture<M::Future, S, I>;
470
471 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
472 self.svc.poll_ready(cx).map_err(Into::into)
473 }
474
475 fn call(&mut self, dst: Target) -> Self::Future {
476 InspectFuture {
477 future: self.svc.call(dst),
478 inspect: self.inspect.clone(),
479 slot: self.slot.clone(),
480 }
481 }
482 }
483
484 impl<F, I, S, E> Future for InspectFuture<F, S, I>
485 where
486 F: Future<Output = Result<S, E>>,
487 E: Into<BoxError>,
488 I: Fn(&S) -> bool,
489 {
490 type Output = Result<S, BoxError>;
491
492 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
493 let me = self.project();
494 let s = ready!(me.future.poll(cx)).map_err(Into::into)?;
495 Poll::Ready(if (me.inspect)(&s) {
496 *me.slot.lock().unwrap() = Some(s);
497 Err(UseOther.into())
498 } else {
499 Ok(s)
500 })
501 }
502 }
503
/// A service that yields the response previously parked in the shared slot.
///
/// Calling it takes the value out of the slot; when the slot is empty it
/// fails with the internal "use other" sentinel error.
pub struct Inspected<S> {
    // Hand-off slot shared with the matching `Inspector`.
    slot: Arc<Mutex<Option<S>>>,
}
507
508 impl<S, Target> Service<Target> for Inspected<S> {
509 type Response = S;
510 type Error = BoxError;
511 type Future = std::future::Ready<Result<S, BoxError>>;
512
513 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
514 if self.slot.lock().unwrap().is_some() {
515 Poll::Ready(Ok(()))
516 } else {
517 Poll::Ready(Err(UseOther.into()))
518 }
519 }
520
521 fn call(&mut self, _dst: Target) -> Self::Future {
522 let s = self
523 .slot
524 .lock()
525 .unwrap()
526 .take()
527 .ok_or_else(|| UseOther.into());
528 std::future::ready(s)
529 }
530 }
531
532 impl<S> Clone for Inspected<S> {
533 fn clone(&self) -> Inspected<S> {
534 Inspected {
535 slot: self.slot.clone(),
536 }
537 }
538 }
539
/// Sentinel error used internally to signal "try the other service".
#[derive(Debug)]
struct UseOther;

impl std::fmt::Display for UseOther {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("sentinel error; using other")
    }
}

impl std::error::Error for UseOther {}

impl UseOther {
    /// Returns `true` if `err`, or any error reachable through its `source()`
    /// chain, is the `UseOther` sentinel.
    fn is(err: &(dyn std::error::Error + 'static)) -> bool {
        let mut cursor = err;
        loop {
            if cursor.is::<UseOther>() {
                return true;
            }
            match cursor.source() {
                Some(next) => cursor = next,
                None => return false,
            }
        }
    }
}
563}
564
565#[cfg(test)]
566mod tests {
567 use futures_util::future;
568 use tower_service::Service;
569 use tower_test::assert_request_eq;
570
571 #[tokio::test]
572 async fn not_negotiated_falls_back_to_left() {
573 let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
574
575 let mut negotiate = super::builder()
576 .connect(mock_svc)
577 .inspect(|_: &&str| false)
578 .fallback(layer_fn(|s| s))
579 .upgrade(layer_fn(|s| s))
580 .build();
581
582 std::future::poll_fn(|cx| negotiate.poll_ready(cx))
583 .await
584 .unwrap();
585
586 let fut = negotiate.call(());
587 let nsvc = future::join(fut, async move {
588 assert_request_eq!(handle, ()).send_response("one");
589 })
590 .await
591 .0
592 .expect("call");
593 assert!(nsvc.is_fallback());
594 }
595
596 #[tokio::test]
597 async fn negotiated_uses_right() {
598 let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
599
600 let mut negotiate = super::builder()
601 .connect(mock_svc)
602 .inspect(|_: &&str| true)
603 .fallback(layer_fn(|s| s))
604 .upgrade(layer_fn(|s| s))
605 .build();
606
607 std::future::poll_fn(|cx| negotiate.poll_ready(cx))
608 .await
609 .unwrap();
610
611 let fut = negotiate.call(());
612 let nsvc = future::join(fut, async move {
613 assert_request_eq!(handle, ()).send_response("one");
614 })
615 .await
616 .0
617 .expect("call");
618
619 assert!(nsvc.is_upgraded());
620 }
621
622 fn layer_fn<F>(f: F) -> LayerFn<F> {
623 LayerFn(f)
624 }
625
626 #[derive(Clone)]
627 struct LayerFn<F>(F);
628
629 impl<F, S, Out> tower_layer::Layer<S> for LayerFn<F>
630 where
631 F: Fn(S) -> Out,
632 {
633 type Service = Out;
634 fn layer(&self, inner: S) -> Self::Service {
635 (self.0)(inner)
636 }
637 }
638}