mev_share_rpc_api/
auth.rs

1//! A layer responsible for implementing flashbots-style authentication
2//! by signing the request body with a private key and adding the signature
3//! to the request headers.
4
5use std::{
6    error::Error,
7    task::{Context, Poll},
8};
9
10use ethers_core::{types::H256, utils::keccak256};
11use ethers_signers::Signer;
12use futures_util::future::BoxFuture;
13
14use http::{header::HeaderValue, HeaderName, Request};
15use hyper::Body;
16
17use tower::{Layer, Service};
18
/// Name of the request header that carries the Flashbots authentication signature.
const FLASHBOTS_HEADER: HeaderName = HeaderName::from_static("x-flashbots-signature");
20
/// A `tower` layer that wraps services in a [`FlashbotsSigner`], which signs request
/// bodies and attaches the result as an `x-flashbots-signature` header.
#[derive(Clone)]
pub struct FlashbotsSignerLayer<S> {
    /// Signer that is cloned into every service this layer produces.
    signer: S,
}

impl<S> FlashbotsSignerLayer<S> {
    /// Builds a layer whose services will sign requests with `signer`.
    pub fn new(signer: S) -> Self {
        Self { signer }
    }
}
33
34impl<S: Clone, I> Layer<I> for FlashbotsSignerLayer<S> {
35    type Service = FlashbotsSigner<S, I>;
36
37    fn layer(&self, inner: I) -> Self::Service {
38        FlashbotsSigner { signer: self.signer.clone(), inner }
39    }
40}
41
/// Tower middleware that signs JSON request bodies and places the signature in the
/// `x-flashbots-signature` header before forwarding to the wrapped service.
///
/// The authentication scheme is described at
/// <https://docs.flashbots.net/flashbots-auction/searchers/advanced/rpc-endpoint#authentication>
#[derive(Clone)]
pub struct FlashbotsSigner<S, I> {
    /// Key used to sign request bodies.
    signer: S,
    /// Wrapped service that receives the (possibly signed) request.
    inner: I,
}
49
50impl<S, I> Service<Request<Body>> for FlashbotsSigner<S, I>
51where
52    I: Service<Request<Body>> + Clone + Send + 'static,
53    I::Future: Send,
54    I::Error: Into<Box<dyn Error + Send + Sync>> + 'static,
55    S: Signer + Clone + Send + 'static,
56{
57    type Response = I::Response;
58    type Error = Box<dyn Error + Send + Sync>;
59    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
60
61    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62        self.inner.poll_ready(cx).map_err(Into::into)
63    }
64
65    fn call(&mut self, request: Request<Body>) -> Self::Future {
66        let clone = self.inner.clone();
67        // wait for service to be ready
68        let mut inner = std::mem::replace(&mut self.inner, clone);
69        let signer = self.signer.clone();
70
71        let (mut parts, body) = request.into_parts();
72
73        // if method is not POST, return an error.
74        if parts.method != http::Method::POST {
75            return Box::pin(async move {
76                Err(format!("Invalid method: {}", parts.method.as_str()).into())
77            })
78        }
79
80        // if content-type is not json, or signature already exists, just pass through the request
81        let is_json = parts
82            .headers
83            .get(http::header::CONTENT_TYPE)
84            .map(|v| v == HeaderValue::from_static("application/json"))
85            .unwrap_or(false);
86        let has_sig = parts.headers.contains_key(FLASHBOTS_HEADER);
87
88        if !is_json || has_sig {
89            return Box::pin(async move {
90                let request = Request::from_parts(parts, body);
91                inner.call(request).await.map_err(Into::into)
92            })
93        }
94
95        // otherwise, sign the request body and add the signature to the header
96        Box::pin(async move {
97            let body_bytes = hyper::body::to_bytes(body).await?;
98
99            // sign request body and insert header
100            let signature = signer
101                .sign_message(format!("0x{:x}", H256::from(keccak256(body_bytes.as_ref()))))
102                .await?;
103
104            let header_val =
105                HeaderValue::from_str(&format!("{:?}:0x{}", signer.address(), signature))
106                    .expect("Header contains invalid characters");
107            parts.headers.insert(FLASHBOTS_HEADER, header_val);
108
109            let request = Request::from_parts(parts, Body::from(body_bytes.clone()));
110            inner.call(request).await.map_err(Into::into)
111        })
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use ethers_core::rand::thread_rng;
    use ethers_signers::LocalWallet;
    use http::Response;
    use hyper::Body;
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    /// Mock inner service: echoes every request header back as a response header, so tests
    /// can observe exactly what the middleware forwarded.
    async fn echo_headers(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let mut res = Response::builder();
        for (name, value) in req.headers() {
            res = res.header(name, value);
        }
        Ok(res.body(Body::empty()).unwrap())
    }

    #[tokio::test]
    async fn test_signature() {
        let fb_signer = LocalWallet::new(&mut thread_rng());
        let svc = FlashbotsSigner { signer: fb_signer.clone(), inner: service_fn(echo_headers) };

        let bytes = vec![1u8; 32];
        let req = Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, "application/json")
            .body(Body::from(bytes.clone()))
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        // the header has the form `<address>:0x<signature>`
        let header = res.headers().get("x-flashbots-signature").unwrap().to_str().unwrap();
        let (header_address, header_signature) =
            header.split_once(":0x").expect("header should be `address:0x<sig>`");

        let signer_address = format!("{:?}", fb_signer.address());
        let expected_signature = fb_signer
            .sign_message(format!("0x{:x}", H256::from(keccak256(&bytes))))
            .await
            .unwrap()
            .to_string();

        // verify that the header contains expected address and signature
        assert_eq!(header_address, signer_address);
        assert_eq!(header_signature, expected_signature);
    }

    #[tokio::test]
    async fn test_skips_non_json() {
        let fb_signer = LocalWallet::new(&mut thread_rng());
        let svc = FlashbotsSigner { signer: fb_signer, inner: service_fn(echo_headers) };

        // build plain text request
        let req = Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, "text/plain")
            .body(Body::from(vec![1u8; 32]))
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        // response should not contain a signature header
        assert!(res.headers().get("x-flashbots-signature").is_none());
    }

    #[tokio::test]
    async fn test_skips_existing_signature() {
        let fb_signer = LocalWallet::new(&mut thread_rng());
        let svc = FlashbotsSigner { signer: fb_signer, inner: service_fn(echo_headers) };

        // build a JSON request that is already signed
        let req = Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, "application/json")
            .header("x-flashbots-signature", "preexisting")
            .body(Body::from(vec![1u8; 32]))
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        // the existing signature must be forwarded untouched
        assert_eq!(res.headers().get("x-flashbots-signature").unwrap(), "preexisting");
    }

    #[tokio::test]
    async fn test_returns_error_when_not_post() {
        let fb_signer = LocalWallet::new(&mut thread_rng());
        let svc = FlashbotsSigner { signer: fb_signer, inner: service_fn(echo_headers) };

        // build a JSON GET request
        let req = Request::builder()
            .method(http::Method::GET)
            .header(http::header::CONTENT_TYPE, "application/json")
            .body(Body::from(vec![1u8; 32]))
            .unwrap();

        let res = svc.oneshot(req).await;

        // should be an error
        assert!(res.is_err());
    }
}