alloy_transport_http/layers/auth.rs

1use crate::hyper::{header::AUTHORIZATION, Request, Response};
2use alloy_rpc_types_engine::{Claims, JwtSecret};
3use alloy_transport::{TransportError, TransportErrorKind};
4use hyper::header::HeaderValue;
5use jsonwebtoken::get_current_timestamp;
6use std::{
7    future::Future,
8    pin::Pin,
9    time::{SystemTime, UNIX_EPOCH},
10};
11use tower::{Layer, Service};
12
/// The [`AuthLayer`] uses the provided [`JwtSecret`] to generate the jwt token attached to
/// outgoing requests, reusing a previously generated token while it is still fresh.
///
/// The generated token is inserted into the [`AUTHORIZATION`] header of the request.
#[derive(Clone, Debug)]
pub struct AuthLayer {
    /// Secret used to sign the JWT attached to each request.
    secret: JwtSecret,
    /// In milliseconds: how far a cached claim's `iat` may be from the current time before the
    /// wrapped service mints a new token instead of reusing it. Set to 5000 by
    /// [`AuthLayer::new`].
    latency_buffer: u64,
}
22
23impl AuthLayer {
24    /// Create a new [`AuthLayer`].
25    pub const fn new(secret: JwtSecret) -> Self {
26        Self { secret, latency_buffer: 5000 }
27    }
28
29    /// We use this buffer to perform an extra check on the `iat` field to prevent sending any
30    /// requests with tokens that are valid now but may not be upon reaching the server.
31    ///
32    /// In milliseconds. Default is 5s.
33    pub const fn with_latency_buffer(self, latency_buffer: u64) -> Self {
34        Self { latency_buffer, ..self }
35    }
36}
37
38impl<S> Layer<S> for AuthLayer {
39    type Service = AuthService<S>;
40
41    fn layer(&self, inner: S) -> Self::Service {
42        AuthService::new(inner, self.secret, self.latency_buffer)
43    }
44}
45
/// A service that attaches a jwt token, signed with the provided secret, to every request it
/// forwards to the inner service.
#[derive(Clone, Debug)]
pub struct AuthService<S> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Secret used to sign the JWT.
    secret: JwtSecret,
    /// In milliseconds. Maximum distance between `most_recent_claim.iat` and the current time
    /// for that claim to be reused.
    latency_buffer: u64,
    /// The last claim minted by this service; `None` until the first token is created.
    most_recent_claim: Option<Claims>,
}
55
56impl<S> AuthService<S> {
57    /// Create a new [`AuthService`] with the given inner service.
58    pub const fn new(inner: S, secret: JwtSecret, latency_buffer: u64) -> Self {
59        Self { inner, secret, latency_buffer, most_recent_claim: None }
60    }
61
62    /// Validate the token in the request headers.
63    ///
64    /// Returns `true` if the token is still valid and `iat` is within the grace buffer.
65    /// A token is considered valid if the time difference between the current time
66    /// and the issued-at time is within the configured latency buffer.
67    fn validate(&self) -> bool {
68        if let Some(claim) = self.most_recent_claim.as_ref() {
69            let curr_secs = get_current_timestamp();
70            // Check if the token is not too old (within latency buffer)
71            // Convert seconds to milliseconds for comparison with latency_buffer
72            if claim.iat.abs_diff(curr_secs) * 1000 <= self.latency_buffer {
73                return true;
74            }
75        }
76
77        false
78    }
79
80    /// Create a new token from the secret.
81    ///
82    /// Updates the most_recent_claim with the new claim.
83    /// The issued-at time is set to the current timestamp to ensure proper validation.
84    fn create_token_from_secret(&mut self) -> Result<String, jsonwebtoken::errors::Error> {
85        let claims = Claims {
86            // Set iat to current time (not future time) for proper validation
87            iat: SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
88            exp: None,
89        };
90
91        self.most_recent_claim = Some(claims);
92
93        let token = self.secret.encode(&claims)?;
94
95        Ok(format!("Bearer {token}"))
96    }
97}
98
99impl<S, B, ResBody> Service<Request<B>> for AuthService<S>
100where
101    S: Service<hyper::Request<B>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
102    S::Future: Send,
103    S::Error: std::error::Error + Send + Sync + 'static,
104    B: From<Vec<u8>> + Send + 'static + Clone + Sync,
105    ResBody: hyper::body::Body + Send + 'static,
106    ResBody::Error: std::error::Error + Send + Sync + 'static,
107    ResBody::Data: Send,
108{
109    type Response = Response<ResBody>;
110    type Error = TransportError;
111    type Future =
112        Pin<Box<dyn Future<Output = Result<Response<ResBody>, Self::Error>> + Send + 'static>>;
113
114    fn poll_ready(
115        &mut self,
116        cx: &mut std::task::Context<'_>,
117    ) -> std::task::Poll<Result<(), Self::Error>> {
118        self.inner.poll_ready(cx).map_err(TransportErrorKind::custom)
119    }
120
121    fn call(&mut self, req: Request<B>) -> Self::Future {
122        let mut req = req;
123        let res = if self.validate() {
124            // Encodes the most recent claim into a token.
125            self.secret.encode(self.most_recent_claim.as_ref().unwrap())
126        } else {
127            // Creates a new Claim and encodes it into a token.
128            self.create_token_from_secret()
129        };
130
131        match res {
132            Ok(token) => {
133                req.headers_mut().insert(AUTHORIZATION, HeaderValue::from_str(&token).unwrap());
134
135                let mut this = self.clone();
136
137                Box::pin(
138                    async move { this.inner.call(req).await.map_err(TransportErrorKind::custom) },
139                )
140            }
141            Err(e) => {
142                let e = TransportErrorKind::custom(e);
143                Box::pin(async move { Err(e) })
144            }
145        }
146    }
147}