openwire_core/
interceptor.rs1use std::task::{Context, Poll};
2
3use http::{Request, Response};
4use tower::layer::Layer;
5use tower::util::BoxCloneSyncService;
6use tower::Service;
7
8use crate::{BoxFuture, CallContext, RequestBody, ResponseBody, WireError};
9
10pub type WireResponse = Response<ResponseBody>;
11pub type BoxWireService = BoxCloneSyncService<Exchange, WireResponse, WireError>;
12pub type SharedInterceptor = std::sync::Arc<dyn Interceptor>;
13
14#[derive(Debug)]
15pub struct Exchange {
16 request: Request<RequestBody>,
17 context: CallContext,
18 attempt: u32,
19}
20
21impl Exchange {
22 pub fn new(request: Request<RequestBody>, context: CallContext, attempt: u32) -> Self {
23 Self {
24 request,
25 context,
26 attempt,
27 }
28 }
29
30 pub fn request(&self) -> &Request<RequestBody> {
31 &self.request
32 }
33
34 pub fn request_mut(&mut self) -> &mut Request<RequestBody> {
35 &mut self.request
36 }
37
38 pub fn into_request(self) -> Request<RequestBody> {
39 self.request
40 }
41
42 pub fn context(&self) -> &CallContext {
43 &self.context
44 }
45
46 pub fn into_parts(self) -> (Request<RequestBody>, CallContext, u32) {
47 (self.request, self.context, self.attempt)
48 }
49
50 pub fn attempt(&self) -> u32 {
51 self.attempt
52 }
53}
54
55#[derive(Clone)]
56pub struct Next {
57 inner: BoxWireService,
58}
59
60impl Next {
61 pub fn new(inner: BoxWireService) -> Self {
62 Self { inner }
63 }
64
65 pub fn run(self, exchange: Exchange) -> BoxFuture<Result<WireResponse, WireError>> {
66 Box::pin(async move {
67 let mut inner = self.inner;
68 inner.call(exchange).await
69 })
70 }
71}
72
73pub trait Interceptor: Send + Sync + 'static {
74 fn intercept(
75 &self,
76 exchange: Exchange,
77 next: Next,
78 ) -> BoxFuture<Result<WireResponse, WireError>>;
79}
80
81#[derive(Clone)]
82pub struct InterceptorLayer {
83 interceptor: SharedInterceptor,
84}
85
86impl InterceptorLayer {
87 pub fn new(interceptor: SharedInterceptor) -> Self {
88 Self { interceptor }
89 }
90}
91
92impl<S> Layer<S> for InterceptorLayer
93where
94 S: Service<Exchange, Response = WireResponse, Error = WireError>
95 + Clone
96 + Send
97 + Sync
98 + 'static,
99 S::Future: Send + 'static,
100{
101 type Service = InterceptorService<S>;
102
103 fn layer(&self, inner: S) -> Self::Service {
104 InterceptorService {
105 inner,
106 interceptor: self.interceptor.clone(),
107 }
108 }
109}
110
111#[derive(Clone)]
112pub struct InterceptorService<S> {
113 inner: S,
114 interceptor: SharedInterceptor,
115}
116
117impl<S> Service<Exchange> for InterceptorService<S>
118where
119 S: Service<Exchange, Response = WireResponse, Error = WireError>
120 + Clone
121 + Send
122 + Sync
123 + 'static,
124 S::Future: Send + 'static,
125{
126 type Response = WireResponse;
127 type Error = WireError;
128 type Future = BoxFuture<Result<Self::Response, Self::Error>>;
129
130 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131 self.inner.poll_ready(cx)
132 }
133
134 fn call(&mut self, exchange: Exchange) -> Self::Future {
135 let replacement = self.inner.clone();
136 let inner = std::mem::replace(&mut self.inner, replacement);
137 let next = Next::new(BoxCloneSyncService::new(inner));
138 self.interceptor.intercept(exchange, next)
139 }
140}
141
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    use http::Request;
    use tower::{Service, ServiceExt};

    use super::{Exchange, Interceptor, InterceptorService, Next, WireResponse};
    use crate::{
        BoxFuture, CallContext, NoopEventListenerFactory, RequestBody, ResponseBody, WireError,
    };

    /// Interceptor that forwards every exchange unchanged.
    #[derive(Clone)]
    struct PassthroughInterceptor;

    impl Interceptor for PassthroughInterceptor {
        fn intercept(
            &self,
            exchange: Exchange,
            next: Next,
        ) -> BoxFuture<Result<WireResponse, WireError>> {
            next.run(exchange)
        }
    }

    /// Inner service that records whether `poll_ready` ran on this exact
    /// instance, and refuses to be polled if it is itself a clone.
    struct ReadinessTrackingService {
        was_polled: bool,
        is_clone: bool,
    }

    impl ReadinessTrackingService {
        /// An original (non-cloned) instance that has not been polled yet.
        fn original() -> Self {
            Self {
                was_polled: false,
                is_clone: false,
            }
        }
    }

    impl Clone for ReadinessTrackingService {
        /// Clones never inherit readiness and are flagged so `poll_ready` rejects them.
        fn clone(&self) -> Self {
            Self {
                is_clone: true,
                was_polled: false,
            }
        }
    }

    impl Service<Exchange> for ReadinessTrackingService {
        type Response = WireResponse;
        type Error = WireError;
        type Future = BoxFuture<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            assert!(
                !self.is_clone,
                "poll_ready should not be re-run against a cloned inner service",
            );
            self.was_polled = true;
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _exchange: Exchange) -> Self::Future {
            // Consume the readiness flag: each call needs its own preceding poll.
            let was_polled = std::mem::replace(&mut self.was_polled, false);
            Box::pin(async move {
                assert!(
                    was_polled,
                    "call must use the exact inner service instance that was polled ready",
                );
                Ok(http::Response::new(ResponseBody::empty()))
            })
        }
    }

    /// Builds a minimal exchange for `http://example.com/` on attempt 1.
    fn test_exchange() -> Exchange {
        let factory: crate::SharedEventListenerFactory =
            std::sync::Arc::new(NoopEventListenerFactory);
        let request = Request::builder()
            .uri("http://example.com/")
            .body(RequestBody::absent())
            .expect("request");
        let ctx = CallContext::from_factory(&factory, &request, None);
        Exchange::new(request, ctx, 1)
    }

    #[tokio::test]
    async fn interceptor_service_preserves_ready_inner_service() {
        let mut service = InterceptorService {
            inner: ReadinessTrackingService::original(),
            interceptor: std::sync::Arc::new(PassthroughInterceptor),
        };

        let ready = service.ready().await.expect("service ready");
        let response = ready.call(test_exchange()).await.expect("response");

        assert_eq!(response.status(), http::StatusCode::OK);
    }
}