third_wheel/proxy/
mitm.rs

1use std::pin::Pin;
2
3use crate::error::Error;
4use futures::Future;
5use http::{header::HeaderName, Request, Response};
6use hyper::{client::conn::SendRequest, service::Service, Body};
7use log::error;
8use tokio::sync::{mpsc, oneshot};
9use tower::Layer;
10
11pub(crate) struct RequestSendingSynchronizer {
12    request_sender: SendRequest<Body>,
13    receiver: mpsc::UnboundedReceiver<(
14        oneshot::Sender<Result<Response<Body>, Error>>,
15        Request<Body>,
16    )>,
17}
18
19impl RequestSendingSynchronizer {
20    pub(crate) fn new(
21        request_sender: SendRequest<Body>,
22        receiver: mpsc::UnboundedReceiver<(
23            oneshot::Sender<Result<Response<Body>, Error>>,
24            Request<Body>,
25        )>,
26    ) -> Self {
27        Self {
28            request_sender,
29            receiver,
30        }
31    }
32
33    pub(crate) async fn run(&mut self) {
34        while let Some((sender, mut request)) = self.receiver.recv().await {
35            let relativized_uri = request
36                .uri()
37                .path_and_query()
38                .ok_or_else(|| Error::RequestError("URI did not contain a path".to_string()))
39                .and_then(|path| {
40                    path.as_str()
41                        .parse()
42                        .map_err(|_| Error::RequestError("Given URI was invalid".to_string()))
43                });
44            let response_fut = relativized_uri.and_then(|path| {
45                *request.uri_mut() = path;
46                // TODO: don't have this unnecessary overhead every time
47                let proxy_connection: HeaderName = HeaderName::from_lowercase(b"proxy-connection")
48                    .expect("Infallible: hardcoded header name");
49                request.headers_mut().remove(&proxy_connection);
50                Ok(self.request_sender.send_request(request))
51            });
52            let response_to_send = match response_fut {
53                Ok(response) => response.await.map_err(|e| e.into()),
54                Err(e) => Err(e),
55            };
56            if let Err(e) = sender.send(response_to_send) {
57                error!("Requester not available to receive request {:?}", e);
58            }
59        }
60    }
61}
62
63/// A service that will proxy traffic to a target server and return unmodified responses
64#[derive(Clone)]
65pub struct ThirdWheel {
66    sender: mpsc::UnboundedSender<(
67        oneshot::Sender<Result<Response<Body>, Error>>,
68        Request<Body>,
69    )>,
70}
71
72impl ThirdWheel {
73    pub(crate) fn new(
74        sender: mpsc::UnboundedSender<(
75            oneshot::Sender<Result<Response<Body>, Error>>,
76            Request<Body>,
77        )>,
78    ) -> Self {
79        Self { sender }
80    }
81}
82
83impl Service<Request<Body>> for ThirdWheel {
84    type Response = Response<Body>;
85
86    type Error = crate::error::Error;
87
88    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
89
90    fn poll_ready(
91        &mut self,
92        _: &mut std::task::Context<'_>,
93    ) -> std::task::Poll<Result<(), Self::Error>> {
94        std::task::Poll::Ready(Ok(()))
95    }
96
97    /// ThirdWheel performs very little modification of the request before
98    /// transmitting it, but it does remove the proxy-connection header to
99    /// ensure this is not passed to the target
100    fn call(&mut self, request: Request<Body>) -> Self::Future {
101        let (response_sender, response_receiver) = oneshot::channel();
102        let sender = self.sender.clone();
103        let fut = async move {
104            //TODO: clarify what errors are possible here
105            sender.send((response_sender, request)).map_err(|_| {
106                Error::ServerError("Failed to connect to server correctly".to_string())
107            })?;
108            response_receiver
109                .await
110                .map_err(|_| Error::ServerError("Failed to get response from server".to_string()))?
111        };
112        return Box::pin(fut);
113    }
114}
115
/// Service that routes each request through a user-supplied handler.
///
/// The handler `f` is called with the request and a copy of the wrapped
/// service `inner`.
#[derive(Clone)]
pub struct MitmService<F, S>
where
    F: Clone,
    S: Clone,
{
    /// Handler invoked for every request.
    f: F,
    /// Wrapped service handed to the handler.
    inner: S,
}
121
122impl<F, S> Service<Request<Body>> for MitmService<F, S>
123where
124    S: Service<Request<Body>, Error = crate::error::Error> + Clone,
125    F: FnMut(
126            Request<Body>,
127            S,
128        )
129            -> Pin<Box<dyn Future<Output = Result<Response<Body>, crate::error::Error>> + Send>>
130        + Clone,
131{
132    type Response = Response<Body>;
133    type Error = crate::error::Error;
134
135    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
136
137    fn poll_ready(
138        &mut self,
139        cx: &mut std::task::Context<'_>,
140    ) -> std::task::Poll<Result<(), Self::Error>> {
141        self.inner.poll_ready(cx)
142    }
143
144    fn call(&mut self, req: Request<Body>) -> Self::Future {
145        (self.f)(req, self.inner.clone())
146    }
147}
148
/// Layer that wraps a service in a [`MitmService`] using a stored handler.
#[derive(Clone)]
pub struct MitmLayer<F>
where
    F: Clone,
{
    /// Handler cloned into every service this layer produces.
    f: F,
}
153
154impl<S: Clone, F: Clone> Layer<S> for MitmLayer<F> {
155    type Service = MitmService<F, S>;
156    fn layer(&self, inner: S) -> Self::Service {
157        MitmService {
158            f: self.f.clone(),
159            inner,
160        }
161    }
162}
163
164/// A convenience function for generating man-in-the-middle services
165///
166/// This function generates a struct that implements the necessary traits to be
167/// used as a man-in-the-middle service and will suffice for many use cases.
168/// ```ignore
169/// let mitm = mitm_layer(|req: Request<Body>, mut third_wheel: ThirdWheel| third_wheel.call(req));
170/// let mitm_proxy = MitmProxy::builder(mitm, ca).build();
171/// ```
172pub fn mitm_layer<F>(f: F) -> MitmLayer<F>
173where
174    F: FnMut(
175            Request<Body>,
176            ThirdWheel,
177        )
178            -> Pin<Box<dyn Future<Output = Result<Response<Body>, crate::error::Error>> + Send>>
179        + Clone,
180{
181    return MitmLayer { f };
182}