1use std::{convert::Infallible, future::Future, pin::Pin};
2
3use axum_core::{
4 body::BoxBody,
5 response::{IntoResponse, Response},
6};
7use hyper::{body::HttpBody, Body};
8use tower_service::Service;
9
10type Request = http::Request<hyper::body::Body>;
11
12pub mod comparators;
13
/// Internal state of a [`BodyProxy`].
pub enum BodyProxyState {
    /// The body has not been read yet.
    Raw(BoxBody),
    /// The body has been buffered in memory (reading stops at the first chunk error).
    Read(Vec<u8>),
    /// Placeholder holding no body; `BodyProxy`'s accessors treat reaching it as unreachable.
    Interim,
}
20
/// Lazily buffered view of a response body, handed to a [`ResponseComparator`].
///
/// The body is read into memory on the first call to `read_body_as_vec` or
/// `read_body_as_slice`; later calls reuse that buffer.
pub struct BodyProxy {
    state: BodyProxyState,
}
24
25impl BodyProxy {
26 fn new(body: BoxBody) -> Self {
27 Self {
28 state: BodyProxyState::Raw(body),
29 }
30 }
31
32 async fn read(&mut self) {
34 if let BodyProxyState::Read(_) = self.state {
35 return;
36 }
37
38 let mut interim = BodyProxyState::Interim;
39 std::mem::swap(&mut self.state, &mut interim);
40
41 let result = match interim {
42 BodyProxyState::Raw(mut b) => {
43 let mut body = Vec::<u8>::new();
44
45 while let Some(Ok(bytes)) = b.data().await {
46 body.extend_from_slice(&bytes);
47 }
48 body
49 }
50 BodyProxyState::Read(r) => r,
51 BodyProxyState::Interim => unreachable!(),
52 };
53
54 self.state = BodyProxyState::Read(result);
55 }
56
57 pub async fn read_body_as_vec(&mut self) -> Vec<u8> {
58 self.read().await;
59
60 match &self.state {
61 BodyProxyState::Read(r) => r.clone(),
62 _ => unreachable!(),
63 }
64 }
65
66 pub async fn read_body_as_slice(&mut self) -> &[u8] {
67 self.read().await;
68
69 match &self.state {
70 BodyProxyState::Read(r) => r,
71 _ => unreachable!(),
72 }
73 }
74
75 fn into_body(self) -> BoxBody {
76 match self.state {
77 BodyProxyState::Raw(b) => b,
78 BodyProxyState::Read(bytes) => Body::from(bytes)
79 .map_err(axum_core::Error::new)
80 .boxed_unsync(),
81 BodyProxyState::Interim => unreachable!(),
82 }
83 }
84}
85
/// Compares the two responses that a [`Scientist`] obtained for one request.
///
/// It is called once per request, after both services have responded and before the left
/// response is returned to the caller. Implementations may read either body through its
/// [`BodyProxy`]. A body that was read is put back into its response from the in-memory
/// buffer afterwards.
#[async_trait::async_trait]
pub trait ResponseComparator {
    /// Inspects both responses.
    ///
    /// `response_parts_left` / `response_body_left` come from the left service, and
    /// `response_parts_right` / `response_body_right` come from the right service.
    async fn compare_response_bodies(
        &self,
        response_parts_left: &http::response::Parts,
        response_body_left: &mut BodyProxy,
        response_parts_right: &http::response::Parts,
        response_body_right: &mut BodyProxy,
    );
}
96
97async fn make_both_calls<L, R>(
98 mut left_service: L,
99 mut right_service: R,
100 req: Request,
101) -> (Response, Response)
102where
103 L: Service<Request, Error = Infallible> + Clone + Send + 'static,
104 R: Service<Request, Error = Infallible> + Clone + Send + 'static,
105 L::Response: IntoResponse,
106 R::Response: IntoResponse,
107 L::Future: Send + 'static,
108 R::Future: Send + 'static,
109{
110 let (req_l, req_r) = duplicate_request(req).await;
111
112 let response_left = left_service.call(req_l).await.unwrap().into_response();
113 let response_right = right_service.call(req_r).await.unwrap().into_response();
114 (response_left, response_right)
115}
116
117async fn compare_responses<C: ResponseComparator>(
118 comparator: C,
119 response_left: Response,
120 response_right: Response,
121) -> (Response, Response) {
122 let (parts_left, body_left) = response_left.into_parts();
123 let (parts_right, body_right) = response_right.into_parts();
124
125 let mut body_proxy_left = BodyProxy::new(body_left);
126 let mut body_proxy_right = BodyProxy::new(body_right);
127
128 comparator
129 .compare_response_bodies(
130 &parts_left,
131 &mut body_proxy_left,
132 &parts_right,
133 &mut body_proxy_right,
134 )
135 .await;
136 (
137 Response::from_parts(parts_left, body_proxy_left.into_body()),
138 Response::from_parts(parts_right, body_proxy_right.into_body()),
139 )
140}
141
142async fn duplicate_body(mut body: hyper::body::Body) -> (Body, Body) {
143 let mut body1 = Vec::<u8>::new();
144 let mut body2 = Vec::<u8>::new();
145
146 while let Some(Ok(bytes)) = body.data().await {
147 body1.extend_from_slice(&bytes);
148 body2.extend_from_slice(&bytes);
149 }
150
151 (Body::from(body1), Body::from(body2))
152}
153
154async fn duplicate_request(req: Request) -> (Request, Request) {
155 let (parts, body) = req.into_parts();
156
157 let mut request_builder = http::Request::builder()
158 .method(parts.method.clone())
159 .uri(parts.uri.clone());
160
161 if let Some(headers) = request_builder.headers_mut() {
162 *headers = parts.headers.clone();
163 }
164
165 let (body1, body2) = duplicate_body(body).await;
166
167 let req1 = Request::from_parts(parts, body1);
168 let req2 = request_builder.body(body2).unwrap();
169
170 (req1, req2)
171}
172
/// A `Service` that sends every request to two services and hands both responses to a
/// [`ResponseComparator`].
///
/// Only the left service's response is returned to the caller; the right one is used for
/// comparison and then dropped.
#[derive(Clone)]
pub struct Scientist<L, R, C> {
    /// Service whose response is returned to the caller.
    left_service: L,
    /// Service whose response is only passed to the comparator.
    right_service: R,
    /// Receives both responses for every request.
    comparator: C,
}
179
180impl<L, R, C> Scientist<L, R, C>
181where
182 L: Service<Request, Error = Infallible> + Clone + Send + 'static,
183 R: Service<Request, Error = Infallible> + Clone + Send + 'static,
184 L::Response: IntoResponse,
185 R::Response: IntoResponse,
186 L::Future: Send + 'static,
187 R::Future: Send + 'static,
188 C: ResponseComparator + Clone + Send + 'static,
189{
190 pub fn new(left: L, right: R, comparator: C) -> Self {
191 Self {
192 left_service: left,
193 right_service: right,
194 comparator,
195 }
196 }
197}
198
199impl<L, R, C> Service<Request> for Scientist<L, R, C>
200where
201 L: Service<Request, Error = Infallible> + Clone + Send + 'static,
202 R: Service<Request, Error = Infallible> + Clone + Send + 'static,
203 L::Response: IntoResponse,
204 R::Response: IntoResponse,
205 L::Future: Send + 'static,
206 R::Future: Send + 'static,
207 C: ResponseComparator + Clone + Send + Sync + 'static,
208{
209 type Response = Response;
210
211 type Error = Infallible;
212
213 type Future =
214 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
215
216 fn poll_ready(
217 &mut self,
218 _cx: &mut std::task::Context<'_>,
219 ) -> std::task::Poll<Result<(), Self::Error>> {
220 std::task::Poll::Ready(Ok(()))
221 }
222
223 fn call(&mut self, req: Request) -> Self::Future {
224 let left_service = self.left_service.clone();
225 let right_service = self.right_service.clone();
226 let comparator = self.comparator.clone();
227
228 let future = async move {
229 let left_service = left_service.clone();
230 let right_service = right_service.clone();
231 let comparator = comparator.clone();
232
233 let (response_left, response_right) =
234 make_both_calls(left_service, right_service, req).await;
235 let (response_left, _response_right) =
236 compare_responses(comparator, response_left, response_right).await;
237
238 Ok(response_left)
239 };
240 Box::pin(future)
241 }
242}
243
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use http::response::Parts;
    use tokio::sync::Mutex;

    use super::*;

    #[tokio::test]
    async fn body_proxy_read_test() {
        let bytes = vec![1, 2, 3, 4, 5, 6, 7];
        let body = Body::from(bytes.clone())
            .map_err(axum_core::Error::new)
            .boxed_unsync();
        let mut proxy = BodyProxy::new(body);

        let v = proxy.read_body_as_vec().await;
        assert_eq!(bytes, v);

        // The second access is served from the buffer filled by the first one.
        let s = proxy.read_body_as_slice().await;
        assert_eq!(&bytes, s);
    }

    #[tokio::test]
    async fn duplicate_request_copies_method_uri_headers_and_body() {
        let req = http::Request::builder()
            .uri("http://localhost/")
            .method("POST")
            .header("x-test", "1")
            .body(hyper::body::Body::from(vec![1, 2, 3]))
            .unwrap();

        let (left, right) = duplicate_request(req).await;
        assert_eq!(left.method(), right.method());
        assert_eq!(left.uri(), right.uri());
        assert_eq!(left.headers(), right.headers());

        let left_body = hyper::body::to_bytes(left.into_body()).await.unwrap();
        let right_body = hyper::body::to_bytes(right.into_body()).await.unwrap();
        assert_eq!(&left_body[..], &[1u8, 2, 3][..]);
        assert_eq!(left_body, right_body);
    }

    #[tokio::test]
    async fn handle_request() {
        use axum::routing::{any, get};

        struct Log {
            status_left: http::StatusCode,
            status_right: http::StatusCode,
        }

        #[derive(Clone)]
        struct StatusCodeComparator {
            log: Arc<Mutex<Vec<Log>>>,
        }

        impl StatusCodeComparator {
            fn new() -> Self {
                Self {
                    log: Arc::new(Mutex::new(Vec::new())),
                }
            }
        }

        #[async_trait::async_trait]
        impl ResponseComparator for StatusCodeComparator {
            async fn compare_response_bodies(
                &self,
                response_parts_left: &Parts,
                _response_body_left: &mut BodyProxy,
                response_parts_right: &Parts,
                _response_body_right: &mut BodyProxy,
            ) {
                let mut log = self.log.lock().await;
                // `StatusCode` is `Copy`, so no clone is needed.
                log.push(Log {
                    status_left: response_parts_left.status,
                    status_right: response_parts_right.status,
                })
            }
        }

        let comparator = StatusCodeComparator::new();

        let mut scientist = Scientist::new(
            get(|| async { "LEFT" }),
            any(|| async { "RIGHT" }),
            comparator.clone(),
        );

        let req = http::Request::builder()
            .uri("http://localhost/")
            .method("GET")
            .body(hyper::body::Body::from(vec![1, 2, 3]))
            .unwrap();

        let _response = scientist.call(req).await.unwrap();

        {
            let log = comparator.log.lock().await;
            assert_eq!(1, log.len());
            assert_eq!(http::StatusCode::OK, log[0].status_left);
            assert_eq!(http::StatusCode::OK, log[0].status_right);
        }

        let req = http::Request::builder()
            .uri("http://localhost/")
            .method("POST")
            .body(hyper::body::Body::from(vec![1, 2, 3]))
            .unwrap();

        let _response = scientist.call(req).await.unwrap();

        {
            let log = comparator.log.lock().await;
            assert_eq!(2, log.len());
            assert_eq!(http::StatusCode::METHOD_NOT_ALLOWED, log[1].status_left);
            assert_eq!(http::StatusCode::OK, log[1].status_right);
        }
    }

    #[tokio::test]
    async fn can_be_used_in_axum() {
        #[derive(Clone)]
        struct NoOpComparator;

        #[async_trait::async_trait]
        impl ResponseComparator for NoOpComparator {
            async fn compare_response_bodies(
                &self,
                _response_parts_left: &Parts,
                _response_body_left: &mut BodyProxy,
                _response_parts_right: &Parts,
                _response_body_right: &mut BodyProxy,
            ) {
            }
        }

        let scientist = Scientist::new(
            axum::routing::get(|| async { "LEFT" }),
            axum::routing::any(|| async { "RIGHT" }),
            NoOpComparator,
        );

        let _: axum::Router<()> = axum::Router::new()
            .route_service("/", scientist)
            .with_state(());
    }
}