Skip to main content

modkit_auth/oauth2/
layer.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use http::header::{AUTHORIZATION, HeaderName};
6use http::{HeaderValue, Request, Response};
7use tower::{Layer, Service};
8
9use super::token::Token;
10use modkit_http::HttpError;
11
12/// Tower layer that injects a bearer token into outbound HTTP requests.
13///
14/// Wraps an [`Token`] handle and sets the `Authorization: Bearer <token>`
15/// header (or a custom header) on every request before forwarding it to the
16/// inner service.
17#[derive(Clone, Debug)]
18pub struct BearerAuthLayer {
19    token: Token,
20    header_name: HeaderName,
21}
22
23impl BearerAuthLayer {
24    /// Create a layer that injects `Authorization: Bearer <token>`.
25    #[must_use]
26    pub fn new(token: Token) -> Self {
27        Self {
28            token,
29            header_name: AUTHORIZATION,
30        }
31    }
32
33    /// Create a layer that injects `<header_name>: Bearer <token>`.
34    #[must_use]
35    pub fn with_header_name(token: Token, header_name: HeaderName) -> Self {
36        Self { token, header_name }
37    }
38}
39
40impl<S> Layer<S> for BearerAuthLayer {
41    type Service = BearerAuthService<S>;
42
43    fn layer(&self, inner: S) -> Self::Service {
44        BearerAuthService {
45            inner,
46            token: self.token.clone(),
47            header_name: self.header_name.clone(),
48        }
49    }
50}
51
52/// Tower service that injects a bearer token header before forwarding the
53/// request to the inner service.
54///
55/// Created by [`BearerAuthLayer`].
56#[derive(Clone, Debug)]
57pub struct BearerAuthService<S> {
58    inner: S,
59    token: Token,
60    header_name: HeaderName,
61}
62
63impl<S, B, ResBody> Service<Request<B>> for BearerAuthService<S>
64where
65    S: Service<Request<B>, Response = Response<ResBody>, Error = HttpError>
66        + Clone
67        + Send
68        + 'static,
69    S::Future: Send,
70    B: Send + 'static,
71    ResBody: Send + 'static,
72{
73    type Response = Response<ResBody>;
74    type Error = HttpError;
75    type Future = Pin<Box<dyn Future<Output = Result<Response<ResBody>, HttpError>> + Send>>;
76
77    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78        self.inner.poll_ready(cx)
79    }
80
81    fn call(&mut self, mut req: Request<B>) -> Self::Future {
82        let mut bearer_value = match self.token.get() {
83            Ok(secret) => {
84                let raw = zeroize::Zeroizing::new(format!("Bearer {}", secret.expose()));
85                match HeaderValue::from_str(&raw) {
86                    Ok(v) => v,
87                    Err(e) => return Box::pin(async { Err(HttpError::InvalidHeaderValue(e)) }),
88                }
89            }
90            Err(e) => {
91                return Box::pin(async { Err(HttpError::Transport(Box::new(e))) });
92            }
93        };
94        bearer_value.set_sensitive(true);
95
96        req.headers_mut()
97            .insert(self.header_name.clone(), bearer_value);
98
99        // Clone-swap pattern (Tower Service contract).
100        let clone = self.inner.clone();
101        let mut inner = std::mem::replace(&mut self.inner, clone);
102
103        Box::pin(async move { inner.call(req).await })
104    }
105}
106
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::{Method, Request, Response, StatusCode};
    use http_body_util::Full;
    use httpmock::prelude::*;
    use modkit_utils::SecretString;
    use std::time::Duration;
    use url::Url;

    use crate::oauth2::config::OAuthClientConfig;

    /// Build a test config pointing at the given mock server.
    fn test_config(server: &MockServer) -> OAuthClientConfig {
        OAuthClientConfig {
            token_endpoint: Some(
                Url::parse(&format!("http://localhost:{}/token", server.port())).unwrap(),
            ),
            client_id: "test-client".into(),
            client_secret: SecretString::new("test-secret"),
            http_config: Some(modkit_http::HttpClientConfig::for_testing()),
            jitter_max: Duration::from_millis(0),
            min_refresh_period: Duration::from_millis(100),
            ..Default::default()
        }
    }

    /// JSON body of a successful token-endpoint response.
    fn token_json(token: &str, expires_in: u64) -> String {
        format!(r#"{{"access_token":"{token}","expires_in":{expires_in},"token_type":"Bearer"}}"#)
    }

    // -- mock inner service ---------------------------------------------------

    /// Mock service that checks the injected header and returns 200 OK.
    ///
    /// Panics if the header is missing, is not marked sensitive, or does not
    /// equal `expected_value`.
    #[derive(Clone)]
    struct CaptureHeaderService {
        expected_header: HeaderName,
        expected_value: String,
    }

    impl Service<Request<Full<Bytes>>> for CaptureHeaderService {
        type Response = Response<Full<Bytes>>;
        type Error = HttpError;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
            let value = req
                .headers()
                .get(&self.expected_header)
                .expect("expected header not found");
            // The layer calls `set_sensitive(true)` on the credential; pin that
            // here so a regression cannot silently expose the token downstream.
            assert!(
                value.is_sensitive(),
                "injected credential header must be marked sensitive"
            );
            let header = value.to_str().unwrap().to_owned();
            let expected = self.expected_value.clone();

            Box::pin(async move {
                assert_eq!(header, expected);
                Ok(Response::builder()
                    .status(StatusCode::OK)
                    .body(Full::new(Bytes::new()))
                    .unwrap())
            })
        }
    }

    // -- trait assertions -----------------------------------------------------

    #[test]
    fn bearer_auth_is_send_sync_clone() {
        fn assert_traits<T: Send + Sync + Clone>() {}
        assert_traits::<BearerAuthLayer>();
        assert_traits::<BearerAuthService<CaptureHeaderService>>();
    }

    // -- header injection -----------------------------------------------------

    #[tokio::test]
    async fn injects_authorization_header() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-layer-test", 3600));
        });

        let token = Token::new(test_config(&server)).await.unwrap();
        let inner = CaptureHeaderService {
            expected_header: AUTHORIZATION,
            expected_value: "Bearer tok-layer-test".into(),
        };

        let layer = BearerAuthLayer::new(token);
        let mut svc = layer.layer(inner);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com/api")
            .body(Full::new(Bytes::new()))
            .unwrap();

        Service::call(&mut svc, req).await.unwrap();
    }

    #[tokio::test]
    async fn custom_header_name() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-custom-hdr", 3600));
        });

        let token = Token::new(test_config(&server)).await.unwrap();
        let custom_header = HeaderName::from_static("x-api-key");
        let inner = CaptureHeaderService {
            expected_header: custom_header.clone(),
            expected_value: "Bearer tok-custom-hdr".into(),
        };

        let layer = BearerAuthLayer::with_header_name(token, custom_header);
        let mut svc = layer.layer(inner);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com/api")
            .body(Full::new(Bytes::new()))
            .unwrap();

        Service::call(&mut svc, req).await.unwrap();
    }

    // -- error path -----------------------------------------------------------

    #[tokio::test]
    async fn returns_error_when_token_expired() {
        let server = MockServer::start();

        // Initial token fetch succeeds but with very short TTL.
        let mut success_mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-short-lived", 1));
        });

        let token = Token::new(test_config(&server)).await.unwrap();
        assert_eq!(success_mock.calls(), 1);

        // Remove the success mock; refresh attempts will now fail.
        success_mock.delete();
        let _fail_mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(500)
                .header("content-type", "application/json")
                .body(r#"{"error":"server_error"}"#);
        });

        // Wait for token to expire + refresh to fail.
        tokio::time::sleep(Duration::from_secs(3)).await;

        let inner = CaptureHeaderService {
            expected_header: AUTHORIZATION,
            expected_value: String::new(), // won't be reached
        };

        let layer = BearerAuthLayer::new(token);
        let mut svc = layer.layer(inner);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com/api")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let err = Service::call(&mut svc, req).await.unwrap_err();
        assert!(
            matches!(err, HttpError::Transport(_)),
            "expected Transport error, got: {err:?}"
        );
    }

    // -- debug safety ---------------------------------------------------------

    #[tokio::test]
    async fn token_value_not_in_debug() {
        let server = MockServer::start();

        let _mock = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("super-secret-layer", 3600));
        });

        let token = Token::new(test_config(&server)).await.unwrap();
        let layer = BearerAuthLayer::new(token);
        let dbg = format!("{layer:?}");

        assert!(
            !dbg.contains("super-secret-layer"),
            "Debug must not reveal token value: {dbg}"
        );

        // The service derives `Debug` as well and embeds the same token handle;
        // `()` is a trivially-`Debug` inner service for this check.
        let svc = layer.layer(());
        let svc_dbg = format!("{svc:?}");

        assert!(
            !svc_dbg.contains("super-secret-layer"),
            "Service Debug must not reveal token value: {svc_dbg}"
        );
    }
}