1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//! [`Or`] used to combine two services into one.

use super::FromEmptyRouter;
use crate::body::BoxBody;
use futures_util::ready;
use http::{Request, Response};
use pin_project_lite::pin_project;
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tower::{util::Oneshot, ServiceExt};
use tower_service::Service;

/// [`tower::Service`] that is the combination of two routers.
///
/// Requests are sent to `first`. If the response produced by `first` carries a
/// `FromEmptyRouter` extension, the request stored in that extension is sent
/// to `second` instead and `second`'s response is returned.
///
/// See [`Router::or`] for more details.
///
/// [`Router::or`]: super::Router::or
#[derive(Debug, Clone, Copy)]
pub struct Or<A, B> {
    // Service that gets the first chance at every request.
    pub(super) first: A,
    // Service used when `first`'s response carries a `FromEmptyRouter` extension.
    pub(super) second: B,
}

impl<A, B, ReqBody> Service<Request<ReqBody>> for Or<A, B>
where
    A: Service<Request<ReqBody>, Response = Response<BoxBody>> + Clone,
    B: Service<Request<ReqBody>, Response = Response<BoxBody>, Error = A::Error> + Clone,
    ReqBody: Send + Sync + 'static,
    A: Send + 'static,
    B: Send + 'static,
    A::Future: Send + 'static,
    B::Future: Send + 'static,
{
    type Response = Response<BoxBody>;
    type Error = A::Error;
    type Future = ResponseFuture<A, B, ReqBody>;

    /// Always reports ready.
    ///
    /// Neither inner service is polled here: `call` hands a fresh clone of each
    /// service to `oneshot`, which waits for that clone to become ready before
    /// calling it.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Starts the request on a clone of `first` and keeps a clone of `second`
    /// around in case the returned future has to fall back to it.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // Remember the URI as it was received; the response future restores it
        // on the request before handing that request to `second`.
        let original_uri = req.uri().clone();

        ResponseFuture {
            state: State::FirstFuture {
                f: self.first.clone().oneshot(req),
            },
            second: Some(self.second.clone()),
            original_uri: Some(original_uri),
        }
    }
}

pin_project! {
    /// Response future for [`Or`].
    ///
    /// Resolves to the first service's response unless that response carries a
    /// `FromEmptyRouter` extension, in which case the request stored in the
    /// extension is sent to the second service and that response is returned.
    pub struct ResponseFuture<A, B, ReqBody>
    where
        A: Service<Request<ReqBody>>,
        B: Service<Request<ReqBody>>,
    {
        // The in-flight call: first service, then (possibly) the second one.
        #[pin]
        state: State<A, B, ReqBody>,
        // The fallback service. Taken (leaving `None`) at the moment the
        // future switches over to it.
        second: Option<B>,
        // Some services, namely `Nested`, mutate the request URI so we must
        // restore it to its original state before calling `second`
        original_uri: Option<http::Uri>,
    }
}

pin_project! {
    // Two-step state machine driven by `ResponseFuture::poll`: the call to the
    // first service, then the call to the second service if the first
    // response asked for a fallback.
    #[project = StateProj]
    enum State<A, B, ReqBody>
    where
        A: Service<Request<ReqBody>>,
        B: Service<Request<ReqBody>>,
    {
        // Waiting on the first service's `oneshot` call.
        FirstFuture { #[pin] f: Oneshot<A, Request<ReqBody>> },
        // Waiting on the second service's `oneshot` call; its output is the
        // final output of the whole future.
        SecondFuture {
            #[pin]
            f: Oneshot<B, Request<ReqBody>>,
        }
    }
}

impl<A, B, ReqBody> Future for ResponseFuture<A, B, ReqBody>
where
    A: Service<Request<ReqBody>, Response = Response<BoxBody>>,
    B: Service<Request<ReqBody>, Response = Response<BoxBody>, Error = A::Error>,
    ReqBody: Send + Sync + 'static,
{
    type Output = Result<Response<BoxBody>, A::Error>;

    /// Drives the first service to completion and, if its response carries a
    /// `FromEmptyRouter` extension, re-dispatches the recovered request to the
    /// second service.
    ///
    /// # Panics
    ///
    /// Panics if the fallback service or the original URI has already been
    /// consumed when the fallback is attempted.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Loop so that a freshly created `SecondFuture` is polled right away
        // instead of returning `Pending` without having registered a waker.
        loop {
            let mut this = self.as_mut().project();

            let new_state = match this.state.as_mut().project() {
                StateProj::FirstFuture { f } => {
                    // Errors from the first service are returned as-is.
                    let mut response = ready!(f.poll(cx)?);

                    // The extension hands the request back so it can be
                    // retried; without it the first response is final.
                    let mut req = match response
                        .extensions_mut()
                        .remove::<FromEmptyRouter<ReqBody>>()
                    {
                        Some(ext) => ext.request,
                        None => return Poll::Ready(Ok(response)),
                    };

                    // Undo any URI rewriting done while the request travelled
                    // through the first service.
                    *req.uri_mut() = this
                        .original_uri
                        .take()
                        .expect("original URI must still be present when falling back to the second service");

                    let second = this.second.take().expect("future polled after completion");

                    State::SecondFuture {
                        f: second.oneshot(req),
                    }
                }
                // Whatever the second service produces is the final result.
                StateProj::SecondFuture { f } => return f.poll(cx),
            };

            this.state.set(new_state);
        }
    }
}