axum_scientist/
lib.rs

1use std::{convert::Infallible, future::Future, pin::Pin};
2
3use axum_core::{
4    body::BoxBody,
5    response::{IntoResponse, Response},
6};
7use hyper::{body::HttpBody, Body};
8use tower_service::Service;
9
10type Request = http::Request<hyper::body::Body>;
11
12pub mod comparators;
13
14pub enum BodyProxyState {
15    Raw(BoxBody),
16    Read(Vec<u8>),
17    /// Only used to swap around the internal state between the other states.
18    Interim,
19}
20
21pub struct BodyProxy {
22    state: BodyProxyState,
23}
24
25impl BodyProxy {
26    fn new(body: BoxBody) -> Self {
27        Self {
28            state: BodyProxyState::Raw(body),
29        }
30    }
31
32    /// Afterwards, the inner state is always [`BodyProxyState::Read`]
33    async fn read(&mut self) {
34        if let BodyProxyState::Read(_) = self.state {
35            return;
36        }
37
38        let mut interim = BodyProxyState::Interim;
39        std::mem::swap(&mut self.state, &mut interim);
40
41        let result = match interim {
42            BodyProxyState::Raw(mut b) => {
43                let mut body = Vec::<u8>::new();
44
45                while let Some(Ok(bytes)) = b.data().await {
46                    body.extend_from_slice(&bytes);
47                }
48                body
49            }
50            BodyProxyState::Read(r) => r,
51            BodyProxyState::Interim => unreachable!(),
52        };
53
54        self.state = BodyProxyState::Read(result);
55    }
56
57    pub async fn read_body_as_vec(&mut self) -> Vec<u8> {
58        self.read().await;
59
60        match &self.state {
61            BodyProxyState::Read(r) => r.clone(),
62            _ => unreachable!(),
63        }
64    }
65
66    pub async fn read_body_as_slice(&mut self) -> &[u8] {
67        self.read().await;
68
69        match &self.state {
70            BodyProxyState::Read(r) => r,
71            _ => unreachable!(),
72        }
73    }
74
75    fn into_body(self) -> BoxBody {
76        match self.state {
77            BodyProxyState::Raw(b) => b,
78            BodyProxyState::Read(bytes) => Body::from(bytes)
79                .map_err(axum_core::Error::new)
80                .boxed_unsync(),
81            BodyProxyState::Interim => unreachable!(),
82        }
83    }
84}
85
86#[async_trait::async_trait]
87pub trait ResponseComparator {
88    async fn compare_response_bodies(
89        &self,
90        response_parts_left: &http::response::Parts,
91        response_body_left: &mut BodyProxy,
92        response_parts_right: &http::response::Parts,
93        response_body_right: &mut BodyProxy,
94    );
95}
96
97async fn make_both_calls<L, R>(
98    mut left_service: L,
99    mut right_service: R,
100    req: Request,
101) -> (Response, Response)
102where
103    L: Service<Request, Error = Infallible> + Clone + Send + 'static,
104    R: Service<Request, Error = Infallible> + Clone + Send + 'static,
105    L::Response: IntoResponse,
106    R::Response: IntoResponse,
107    L::Future: Send + 'static,
108    R::Future: Send + 'static,
109{
110    let (req_l, req_r) = duplicate_request(req).await;
111
112    let response_left = left_service.call(req_l).await.unwrap().into_response();
113    let response_right = right_service.call(req_r).await.unwrap().into_response();
114    (response_left, response_right)
115}
116
117async fn compare_responses<C: ResponseComparator>(
118    comparator: C,
119    response_left: Response,
120    response_right: Response,
121) -> (Response, Response) {
122    let (parts_left, body_left) = response_left.into_parts();
123    let (parts_right, body_right) = response_right.into_parts();
124
125    let mut body_proxy_left = BodyProxy::new(body_left);
126    let mut body_proxy_right = BodyProxy::new(body_right);
127
128    comparator
129        .compare_response_bodies(
130            &parts_left,
131            &mut body_proxy_left,
132            &parts_right,
133            &mut body_proxy_right,
134        )
135        .await;
136    (
137        Response::from_parts(parts_left, body_proxy_left.into_body()),
138        Response::from_parts(parts_right, body_proxy_right.into_body()),
139    )
140}
141
142async fn duplicate_body(mut body: hyper::body::Body) -> (Body, Body) {
143    let mut body1 = Vec::<u8>::new();
144    let mut body2 = Vec::<u8>::new();
145
146    while let Some(Ok(bytes)) = body.data().await {
147        body1.extend_from_slice(&bytes);
148        body2.extend_from_slice(&bytes);
149    }
150
151    (Body::from(body1), Body::from(body2))
152}
153
154async fn duplicate_request(req: Request) -> (Request, Request) {
155    let (parts, body) = req.into_parts();
156
157    let mut request_builder = http::Request::builder()
158        .method(parts.method.clone())
159        .uri(parts.uri.clone());
160
161    if let Some(headers) = request_builder.headers_mut() {
162        *headers = parts.headers.clone();
163    }
164
165    let (body1, body2) = duplicate_body(body).await;
166
167    let req1 = Request::from_parts(parts, body1);
168    let req2 = request_builder.body(body2).unwrap();
169
170    (req1, req2)
171}
172
/// Pairs two services with a comparator that gets to look at both of their
/// responses.
#[derive(Clone)]
pub struct Scientist<L, R, C> {
    /// Service whose response is labelled "left".
    left_service: L,
    /// Service whose response is labelled "right".
    right_service: R,
    /// Comparator that is handed both responses.
    comparator: C,
}
179
180impl<L, R, C> Scientist<L, R, C>
181where
182    L: Service<Request, Error = Infallible> + Clone + Send + 'static,
183    R: Service<Request, Error = Infallible> + Clone + Send + 'static,
184    L::Response: IntoResponse,
185    R::Response: IntoResponse,
186    L::Future: Send + 'static,
187    R::Future: Send + 'static,
188    C: ResponseComparator + Clone + Send + 'static,
189{
190    pub fn new(left: L, right: R, comparator: C) -> Self {
191        Self {
192            left_service: left,
193            right_service: right,
194            comparator,
195        }
196    }
197}
198
199impl<L, R, C> Service<Request> for Scientist<L, R, C>
200where
201    L: Service<Request, Error = Infallible> + Clone + Send + 'static,
202    R: Service<Request, Error = Infallible> + Clone + Send + 'static,
203    L::Response: IntoResponse,
204    R::Response: IntoResponse,
205    L::Future: Send + 'static,
206    R::Future: Send + 'static,
207    C: ResponseComparator + Clone + Send + Sync + 'static,
208{
209    type Response = Response;
210
211    type Error = Infallible;
212
213    type Future =
214        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
215
216    fn poll_ready(
217        &mut self,
218        _cx: &mut std::task::Context<'_>,
219    ) -> std::task::Poll<Result<(), Self::Error>> {
220        std::task::Poll::Ready(Ok(()))
221    }
222
223    fn call(&mut self, req: Request) -> Self::Future {
224        let left_service = self.left_service.clone();
225        let right_service = self.right_service.clone();
226        let comparator = self.comparator.clone();
227
228        let future = async move {
229            let left_service = left_service.clone();
230            let right_service = right_service.clone();
231            let comparator = comparator.clone();
232
233            let (response_left, response_right) =
234                make_both_calls(left_service, right_service, req).await;
235            let (response_left, _response_right) =
236                compare_responses(comparator, response_left, response_right).await;
237
238            Ok(response_left)
239        };
240        Box::pin(future)
241    }
242}
243
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use http::response::Parts;
    use tokio::sync::Mutex;

    use super::*;

    /// A proxied body can be read as an owned vector and then again as a
    /// slice, yielding the same bytes both times.
    #[tokio::test]
    async fn body_proxy_read_test() {
        let bytes = vec![1, 2, 3, 4, 5, 6, 7];
        let body = Body::from(bytes.clone())
            .map_err(axum_core::Error::new)
            .boxed_unsync();
        let mut proxy = BodyProxy::new(body);

        let v = proxy.read_body_as_vec().await;
        assert_eq!(bytes, v);

        let s = proxy.read_body_as_slice().await;
        assert_eq!(&bytes, s);
    }

    /// The comparator sees the status codes of both services for every
    /// request, including when the left service rejects the method.
    #[tokio::test]
    async fn handle_request() {
        use axum::routing::{any, get};

        /// One comparator invocation, reduced to the two status codes.
        struct Log {
            status_left: http::StatusCode,
            status_right: http::StatusCode,
        }

        /// Comparator that records the status codes of both responses.
        #[derive(Clone)]
        struct StatusCodeComparator {
            log: Arc<Mutex<Vec<Log>>>,
        }

        impl StatusCodeComparator {
            fn new() -> Self {
                Self {
                    log: Arc::new(Mutex::new(Vec::new())),
                }
            }
        }

        #[async_trait::async_trait]
        impl ResponseComparator for StatusCodeComparator {
            async fn compare_response_bodies(
                &self,
                response_parts_left: &Parts,
                _response_body_left: &mut BodyProxy,
                response_parts_right: &Parts,
                _response_body_right: &mut BodyProxy,
            ) {
                let mut log = self.log.lock().await;
                // `StatusCode` is `Copy`, so no explicit clone is needed.
                log.push(Log {
                    status_left: response_parts_left.status,
                    status_right: response_parts_right.status,
                });
            }
        }

        let comparator = StatusCodeComparator::new();

        let mut scientist = Scientist::new(
            get(|| async { "LEFT" }),
            any(|| async { "RIGHT" }),
            comparator.clone(),
        );

        let req = http::Request::builder()
            .uri("http://localhost/")
            .method("GET")
            .body(hyper::body::Body::from(vec![1, 2, 3]))
            .unwrap();

        let _response = scientist.call(req).await.unwrap();

        {
            let log = comparator.log.lock().await;
            assert_eq!(1, log.len());
            assert_eq!(http::StatusCode::OK, log[0].status_left);
            assert_eq!(http::StatusCode::OK, log[0].status_right);
        }

        let req = http::Request::builder()
            .uri("http://localhost/")
            .method("POST")
            .body(hyper::body::Body::from(vec![1, 2, 3]))
            .unwrap();

        let _response = scientist.call(req).await.unwrap();

        {
            let log = comparator.log.lock().await;
            assert_eq!(2, log.len());
            assert_eq!(http::StatusCode::METHOD_NOT_ALLOWED, log[1].status_left);
            assert_eq!(http::StatusCode::OK, log[1].status_right);
        }
    }

    /// Compile-time check that a `Scientist` can be mounted on an axum
    /// router as a route service.
    #[tokio::test]
    async fn can_be_used_in_axum() {
        /// Comparator that ignores both responses.
        #[derive(Clone)]
        struct NoOpComparator;

        #[async_trait::async_trait]
        impl ResponseComparator for NoOpComparator {
            async fn compare_response_bodies(
                &self,
                _response_parts_left: &Parts,
                _response_body_left: &mut BodyProxy,
                _response_parts_right: &Parts,
                _response_body_right: &mut BodyProxy,
            ) {
            }
        }

        let scientist = Scientist::new(
            axum::routing::get(|| async { "LEFT" }),
            axum::routing::any(|| async { "RIGHT" }),
            NoOpComparator,
        );

        let _: axum::Router<()> = axum::Router::new()
            .route_service("/", scientist)
            .with_state(());
    }
}
373}