Skip to main content

openwire_core/
interceptor.rs

1use std::task::{Context, Poll};
2
3use http::{Request, Response};
4use tower::layer::Layer;
5use tower::util::BoxCloneSyncService;
6use tower::Service;
7
8use crate::{BoxFuture, CallContext, RequestBody, ResponseBody, WireError};
9
10pub type WireResponse = Response<ResponseBody>;
11pub type BoxWireService = BoxCloneSyncService<Exchange, WireResponse, WireError>;
12pub type SharedInterceptor = std::sync::Arc<dyn Interceptor>;
13
14#[derive(Debug)]
15pub struct Exchange {
16    request: Request<RequestBody>,
17    context: CallContext,
18    attempt: u32,
19}
20
21impl Exchange {
22    pub fn new(request: Request<RequestBody>, context: CallContext, attempt: u32) -> Self {
23        Self {
24            request,
25            context,
26            attempt,
27        }
28    }
29
30    pub fn request(&self) -> &Request<RequestBody> {
31        &self.request
32    }
33
34    pub fn request_mut(&mut self) -> &mut Request<RequestBody> {
35        &mut self.request
36    }
37
38    pub fn into_request(self) -> Request<RequestBody> {
39        self.request
40    }
41
42    pub fn context(&self) -> &CallContext {
43        &self.context
44    }
45
46    pub fn into_parts(self) -> (Request<RequestBody>, CallContext, u32) {
47        (self.request, self.context, self.attempt)
48    }
49
50    pub fn attempt(&self) -> u32 {
51        self.attempt
52    }
53}
54
55#[derive(Clone)]
56pub struct Next {
57    inner: BoxWireService,
58}
59
60impl Next {
61    pub fn new(inner: BoxWireService) -> Self {
62        Self { inner }
63    }
64
65    pub fn run(self, exchange: Exchange) -> BoxFuture<Result<WireResponse, WireError>> {
66        Box::pin(async move {
67            let mut inner = self.inner;
68            inner.call(exchange).await
69        })
70    }
71}
72
73pub trait Interceptor: Send + Sync + 'static {
74    fn intercept(
75        &self,
76        exchange: Exchange,
77        next: Next,
78    ) -> BoxFuture<Result<WireResponse, WireError>>;
79}
80
81#[derive(Clone)]
82pub struct InterceptorLayer {
83    interceptor: SharedInterceptor,
84}
85
86impl InterceptorLayer {
87    pub fn new(interceptor: SharedInterceptor) -> Self {
88        Self { interceptor }
89    }
90}
91
92impl<S> Layer<S> for InterceptorLayer
93where
94    S: Service<Exchange, Response = WireResponse, Error = WireError>
95        + Clone
96        + Send
97        + Sync
98        + 'static,
99    S::Future: Send + 'static,
100{
101    type Service = InterceptorService<S>;
102
103    fn layer(&self, inner: S) -> Self::Service {
104        InterceptorService {
105            inner,
106            interceptor: self.interceptor.clone(),
107        }
108    }
109}
110
111#[derive(Clone)]
112pub struct InterceptorService<S> {
113    inner: S,
114    interceptor: SharedInterceptor,
115}
116
117impl<S> Service<Exchange> for InterceptorService<S>
118where
119    S: Service<Exchange, Response = WireResponse, Error = WireError>
120        + Clone
121        + Send
122        + Sync
123        + 'static,
124    S::Future: Send + 'static,
125{
126    type Response = WireResponse;
127    type Error = WireError;
128    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
129
130    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131        self.inner.poll_ready(cx)
132    }
133
134    fn call(&mut self, exchange: Exchange) -> Self::Future {
135        let replacement = self.inner.clone();
136        let inner = std::mem::replace(&mut self.inner, replacement);
137        let next = Next::new(BoxCloneSyncService::new(inner));
138        self.interceptor.intercept(exchange, next)
139    }
140}
141
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    use http::Request;
    use tower::{Service, ServiceExt};

    use super::{Exchange, Interceptor, InterceptorService, Next, WireResponse};
    use crate::{
        BoxFuture, CallContext, NoopEventListenerFactory, RequestBody, ResponseBody, WireError,
    };

    /// Interceptor that forwards every exchange to the continuation unchanged.
    #[derive(Clone)]
    struct PassthroughInterceptor;

    impl Interceptor for PassthroughInterceptor {
        fn intercept(
            &self,
            exchange: Exchange,
            next: Next,
        ) -> BoxFuture<Result<WireResponse, WireError>> {
            next.run(exchange)
        }
    }

    /// Inner service that records whether it was polled for readiness and
    /// whether it is a clone, so the test can detect a swapped instance.
    struct ReadinessTrackingService {
        was_polled: bool,
        is_clone: bool,
    }

    impl ReadinessTrackingService {
        /// The un-cloned, not-yet-polled instance the test starts from.
        fn original() -> Self {
            ReadinessTrackingService {
                was_polled: false,
                is_clone: false,
            }
        }
    }

    impl Clone for ReadinessTrackingService {
        /// Clones never inherit readiness and are flagged as clones.
        fn clone(&self) -> Self {
            ReadinessTrackingService {
                is_clone: true,
                was_polled: false,
            }
        }
    }

    impl Service<Exchange> for ReadinessTrackingService {
        type Response = WireResponse;
        type Error = WireError;
        type Future = BoxFuture<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            assert!(
                !self.is_clone,
                "poll_ready should not be re-run against a cloned inner service",
            );
            self.was_polled = true;
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _exchange: Exchange) -> Self::Future {
            // Read the flag and reset it in one step, before the future is built.
            let polled_before_call = std::mem::replace(&mut self.was_polled, false);
            Box::pin(async move {
                assert!(
                    polled_before_call,
                    "call must use the exact inner service instance that was polled ready",
                );
                Ok(http::Response::new(ResponseBody::empty()))
            })
        }
    }

    /// Builds a body-less exchange for `http://example.com/` with attempt `1`.
    fn test_exchange() -> Exchange {
        let factory: crate::SharedEventListenerFactory =
            std::sync::Arc::new(NoopEventListenerFactory);
        let request = Request::builder()
            .uri("http://example.com/")
            .body(RequestBody::absent())
            .expect("request");
        let ctx = CallContext::from_factory(&factory, &request, None);
        Exchange::new(request, ctx, 1)
    }

    #[tokio::test]
    async fn interceptor_service_preserves_ready_inner_service() {
        let mut service = InterceptorService {
            inner: ReadinessTrackingService::original(),
            interceptor: std::sync::Arc::new(PassthroughInterceptor),
        };

        // Drive readiness first, then dispatch on the same wrapper.
        let ready = service.ready().await.expect("service ready");
        let response = ready.call(test_exchange()).await.expect("response");

        assert_eq!(response.status(), http::StatusCode::OK);
    }
}