alloy_transport_http/layers/
auth.rs1use crate::hyper::{header::AUTHORIZATION, Request, Response};
2use alloy_rpc_types_engine::{Claims, JwtSecret};
3use alloy_transport::{TransportError, TransportErrorKind};
4use hyper::header::HeaderValue;
5use jsonwebtoken::get_current_timestamp;
6use std::{
7    future::Future,
8    pin::Pin,
9    time::{Duration, SystemTime, UNIX_EPOCH},
10};
11use tower::{Layer, Service};
12
13#[derive(Clone, Debug)]
18pub struct AuthLayer {
19    secret: JwtSecret,
20    latency_buffer: u64,
21}
22
23impl AuthLayer {
24    pub const fn new(secret: JwtSecret) -> Self {
26        Self { secret, latency_buffer: 5000 }
27    }
28
29    pub const fn with_latency_buffer(self, latency_buffer: u64) -> Self {
34        Self { latency_buffer, ..self }
35    }
36}
37
38impl<S> Layer<S> for AuthLayer {
39    type Service = AuthService<S>;
40
41    fn layer(&self, inner: S) -> Self::Service {
42        AuthService::new(inner, self.secret, self.latency_buffer)
43    }
44}
45
46#[derive(Clone, Debug)]
48pub struct AuthService<S> {
49    inner: S,
50    secret: JwtSecret,
51    latency_buffer: u64,
53    most_recent_claim: Option<Claims>,
54}
55
56impl<S> AuthService<S> {
57    pub const fn new(inner: S, secret: JwtSecret, latency_buffer: u64) -> Self {
59        Self { inner, secret, latency_buffer, most_recent_claim: None }
60    }
61
62    fn validate(&self) -> bool {
66        if let Some(claim) = self.most_recent_claim.as_ref() {
67            let curr_secs = get_current_timestamp();
68            if claim.iat.abs_diff(curr_secs) * 1000 > self.latency_buffer {
69                return true;
70            }
71        }
72
73        false
74    }
75
76    fn create_token_from_secret(&mut self) -> Result<String, jsonwebtoken::errors::Error> {
80        let claims = Claims {
81            iat: (SystemTime::now().duration_since(UNIX_EPOCH).unwrap() + Duration::from_secs(60))
82                .as_secs(),
83            exp: None,
84        };
85
86        self.most_recent_claim = Some(claims);
87
88        let token = self.secret.encode(&claims)?;
89
90        Ok(format!("Bearer {token}"))
91    }
92}
93
94impl<S, B, ResBody> Service<Request<B>> for AuthService<S>
95where
96    S: Service<hyper::Request<B>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
97    S::Future: Send,
98    S::Error: std::error::Error + Send + Sync + 'static,
99    B: From<Vec<u8>> + Send + 'static + Clone + Sync,
100    ResBody: hyper::body::Body + Send + 'static,
101    ResBody::Error: std::error::Error + Send + Sync + 'static,
102    ResBody::Data: Send,
103{
104    type Response = Response<ResBody>;
105    type Error = TransportError;
106    type Future =
107        Pin<Box<dyn Future<Output = Result<Response<ResBody>, Self::Error>> + Send + 'static>>;
108
109    fn poll_ready(
110        &mut self,
111        cx: &mut std::task::Context<'_>,
112    ) -> std::task::Poll<Result<(), Self::Error>> {
113        self.inner.poll_ready(cx).map_err(TransportErrorKind::custom)
114    }
115
116    fn call(&mut self, req: Request<B>) -> Self::Future {
117        let mut req = req;
118        let res = if self.validate() {
119            self.secret.encode(self.most_recent_claim.as_ref().unwrap())
121        } else {
122            self.create_token_from_secret()
124        };
125
126        match res {
127            Ok(token) => {
128                req.headers_mut().insert(AUTHORIZATION, HeaderValue::from_str(&token).unwrap());
129
130                let mut this = self.clone();
131
132                Box::pin(
133                    async move { this.inner.call(req).await.map_err(TransportErrorKind::custom) },
134                )
135            }
136            Err(e) => {
137                let e = TransportErrorKind::custom(e);
138                Box::pin(async move { Err(e) })
139            }
140        }
141    }
142}