alloy_transport_http/layers/
auth.rs1use crate::hyper::{header::AUTHORIZATION, Request, Response};
2use alloy_rpc_types_engine::{Claims, JwtSecret};
3use alloy_transport::{TransportError, TransportErrorKind};
4use hyper::header::HeaderValue;
5use jsonwebtoken::get_current_timestamp;
6use std::{
7 future::Future,
8 pin::Pin,
9 time::{SystemTime, UNIX_EPOCH},
10};
11use tower::{Layer, Service};
12
13#[derive(Clone, Debug)]
18pub struct AuthLayer {
19 secret: JwtSecret,
20 latency_buffer: u64,
21}
22
23impl AuthLayer {
24 pub const fn new(secret: JwtSecret) -> Self {
26 Self { secret, latency_buffer: 5000 }
27 }
28
29 pub const fn with_latency_buffer(self, latency_buffer: u64) -> Self {
34 Self { latency_buffer, ..self }
35 }
36}
37
38impl<S> Layer<S> for AuthLayer {
39 type Service = AuthService<S>;
40
41 fn layer(&self, inner: S) -> Self::Service {
42 AuthService::new(inner, self.secret, self.latency_buffer)
43 }
44}
45
46#[derive(Clone, Debug)]
48pub struct AuthService<S> {
49 inner: S,
50 secret: JwtSecret,
51 latency_buffer: u64,
53 most_recent_claim: Option<Claims>,
54}
55
56impl<S> AuthService<S> {
57 pub const fn new(inner: S, secret: JwtSecret, latency_buffer: u64) -> Self {
59 Self { inner, secret, latency_buffer, most_recent_claim: None }
60 }
61
62 fn validate(&self) -> bool {
68 if let Some(claim) = self.most_recent_claim.as_ref() {
69 let curr_secs = get_current_timestamp();
70 if claim.iat.abs_diff(curr_secs) * 1000 <= self.latency_buffer {
73 return true;
74 }
75 }
76
77 false
78 }
79
80 fn create_token_from_secret(&mut self) -> Result<String, jsonwebtoken::errors::Error> {
85 let claims = Claims {
86 iat: SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
88 exp: None,
89 };
90
91 self.most_recent_claim = Some(claims);
92
93 let token = self.secret.encode(&claims)?;
94
95 Ok(format!("Bearer {token}"))
96 }
97}
98
99impl<S, B, ResBody> Service<Request<B>> for AuthService<S>
100where
101 S: Service<hyper::Request<B>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
102 S::Future: Send,
103 S::Error: std::error::Error + Send + Sync + 'static,
104 B: From<Vec<u8>> + Send + 'static + Clone + Sync,
105 ResBody: hyper::body::Body + Send + 'static,
106 ResBody::Error: std::error::Error + Send + Sync + 'static,
107 ResBody::Data: Send,
108{
109 type Response = Response<ResBody>;
110 type Error = TransportError;
111 type Future =
112 Pin<Box<dyn Future<Output = Result<Response<ResBody>, Self::Error>> + Send + 'static>>;
113
114 fn poll_ready(
115 &mut self,
116 cx: &mut std::task::Context<'_>,
117 ) -> std::task::Poll<Result<(), Self::Error>> {
118 self.inner.poll_ready(cx).map_err(TransportErrorKind::custom)
119 }
120
121 fn call(&mut self, req: Request<B>) -> Self::Future {
122 let mut req = req;
123 let res = if self.validate() {
124 self.secret.encode(self.most_recent_claim.as_ref().unwrap())
126 } else {
127 self.create_token_from_secret()
129 };
130
131 match res {
132 Ok(token) => {
133 req.headers_mut().insert(AUTHORIZATION, HeaderValue::from_str(&token).unwrap());
134
135 let mut this = self.clone();
136
137 Box::pin(
138 async move { this.inner.call(req).await.map_err(TransportErrorKind::custom) },
139 )
140 }
141 Err(e) => {
142 let e = TransportErrorKind::custom(e);
143 Box::pin(async move { Err(e) })
144 }
145 }
146 }
147}