mev_share_rpc_api/
auth.rs1use std::{
6 error::Error,
7 task::{Context, Poll},
8};
9
10use ethers_core::{types::H256, utils::keccak256};
11use ethers_signers::Signer;
12use futures_util::future::BoxFuture;
13
14use http::{header::HeaderValue, HeaderName, Request};
15use hyper::Body;
16
17use tower::{Layer, Service};
18
19const FLASHBOTS_HEADER: HeaderName = HeaderName::from_static("x-flashbots-signature");
20
/// A [`Layer`] that wraps a service in [`FlashbotsSigner`].
///
/// The layer only stores the signer; a clone of it is handed to every service produced by the
/// layer.
#[derive(Clone)]
pub struct FlashbotsSignerLayer<S> {
    /// Signer that is cloned into each [`FlashbotsSigner`] created by this layer.
    signer: S,
}
26
27impl<S> FlashbotsSignerLayer<S> {
28 pub fn new(signer: S) -> Self {
30 FlashbotsSignerLayer { signer }
31 }
32}
33
34impl<S: Clone, I> Layer<I> for FlashbotsSignerLayer<S> {
35 type Service = FlashbotsSigner<S, I>;
36
37 fn layer(&self, inner: I) -> Self::Service {
38 FlashbotsSigner { signer: self.signer.clone(), inner }
39 }
40}
41
/// Middleware service that signs outgoing request bodies and attaches the signature as the
/// `x-flashbots-signature` header before delegating to the wrapped service.
#[derive(Clone)]
pub struct FlashbotsSigner<S, I> {
    /// Signer used to produce the signature and the address placed in the header.
    signer: S,
    /// The wrapped service that receives the (possibly signed) request.
    inner: I,
}
49
50impl<S, I> Service<Request<Body>> for FlashbotsSigner<S, I>
51where
52 I: Service<Request<Body>> + Clone + Send + 'static,
53 I::Future: Send,
54 I::Error: Into<Box<dyn Error + Send + Sync>> + 'static,
55 S: Signer + Clone + Send + 'static,
56{
57 type Response = I::Response;
58 type Error = Box<dyn Error + Send + Sync>;
59 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
60
61 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62 self.inner.poll_ready(cx).map_err(Into::into)
63 }
64
65 fn call(&mut self, request: Request<Body>) -> Self::Future {
66 let clone = self.inner.clone();
67 let mut inner = std::mem::replace(&mut self.inner, clone);
69 let signer = self.signer.clone();
70
71 let (mut parts, body) = request.into_parts();
72
73 if parts.method != http::Method::POST {
75 return Box::pin(async move {
76 Err(format!("Invalid method: {}", parts.method.as_str()).into())
77 })
78 }
79
80 let is_json = parts
82 .headers
83 .get(http::header::CONTENT_TYPE)
84 .map(|v| v == HeaderValue::from_static("application/json"))
85 .unwrap_or(false);
86 let has_sig = parts.headers.contains_key(FLASHBOTS_HEADER);
87
88 if !is_json || has_sig {
89 return Box::pin(async move {
90 let request = Request::from_parts(parts, body);
91 inner.call(request).await.map_err(Into::into)
92 })
93 }
94
95 Box::pin(async move {
97 let body_bytes = hyper::body::to_bytes(body).await?;
98
99 let signature = signer
101 .sign_message(format!("0x{:x}", H256::from(keccak256(body_bytes.as_ref()))))
102 .await?;
103
104 let header_val =
105 HeaderValue::from_str(&format!("{:?}:0x{}", signer.address(), signature))
106 .expect("Header contains invalid characters");
107 parts.headers.insert(FLASHBOTS_HEADER, header_val);
108
109 let request = Request::from_parts(parts, Body::from(body_bytes.clone()));
110 inner.call(request).await.map_err(Into::into)
111 })
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use ethers_core::rand::thread_rng;
    use ethers_signers::LocalWallet;
    use http::Response;
    use hyper::Body;
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    /// Inner service shared by all tests: answers with an empty body and copies every request
    /// header onto the response, so a test can inspect what the signer forwarded.
    async fn echo_headers(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let (parts, _body) = req.into_parts();

        let mut builder = Response::builder();
        for (name, value) in parts.headers.iter() {
            builder = builder.header(name, value);
        }
        Ok(builder.body(Body::empty()).unwrap())
    }

    /// Builds a request with the given method and `content-type`, carrying `payload` as body.
    fn build_request(method: http::Method, content_type: &str, payload: &[u8]) -> Request<Body> {
        Request::builder()
            .method(method)
            .header(http::header::CONTENT_TYPE, content_type)
            .body(Body::from(payload.to_vec()))
            .unwrap()
    }

    #[tokio::test]
    async fn test_signature() {
        let fb_signer = LocalWallet::new(&mut thread_rng());

        // The wallet is still needed below to compute the expected values, hence the clone.
        let svc = FlashbotsSigner { signer: fb_signer.clone(), inner: service_fn(echo_headers) };

        let bytes = vec![1u8; 32];
        let req = build_request(http::Method::POST, "application/json", &bytes);

        let res = svc.oneshot(req).await.unwrap();

        let header = res.headers().get("x-flashbots-signature").unwrap();
        let header = header.to_str().unwrap();
        let (header_address, header_signature) =
            header.split_once(":0x").expect("header must have the form <address>:0x<signature>");

        let signer_address = format!("{:?}", fb_signer.address());
        let expected_signature = fb_signer
            .sign_message(format!("0x{:x}", H256::from(keccak256(&bytes))))
            .await
            .unwrap()
            .to_string();

        assert_eq!(header_address, signer_address);
        assert_eq!(header_signature, expected_signature);
    }

    #[tokio::test]
    async fn test_skips_non_json() {
        let fb_signer = LocalWallet::new(&mut thread_rng());
        let svc = FlashbotsSigner { signer: fb_signer, inner: service_fn(echo_headers) };

        let bytes = vec![1u8; 32];
        let req = build_request(http::Method::POST, "text/plain", &bytes);

        let res = svc.oneshot(req).await.unwrap();

        // A non-JSON request must be forwarded without a signature header.
        let header = res.headers().get("x-flashbots-signature");
        assert!(header.is_none());
    }

    #[tokio::test]
    async fn test_returns_error_when_not_post() {
        let fb_signer = LocalWallet::new(&mut thread_rng());
        let svc = FlashbotsSigner { signer: fb_signer, inner: service_fn(echo_headers) };

        let bytes = vec![1u8; 32];
        let req = build_request(http::Method::GET, "application/json", &bytes);

        let res = svc.oneshot(req).await;

        assert!(res.is_err());
    }
}