1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use http::header::{AUTHORIZATION, HeaderName};
6use http::{HeaderValue, Request, Response};
7use tower::{Layer, Service};
8
9use super::token::Token;
10use modkit_http::HttpError;
11
12#[derive(Clone, Debug)]
18pub struct BearerAuthLayer {
19 token: Token,
20 header_name: HeaderName,
21}
22
23impl BearerAuthLayer {
24 #[must_use]
26 pub fn new(token: Token) -> Self {
27 Self {
28 token,
29 header_name: AUTHORIZATION,
30 }
31 }
32
33 #[must_use]
35 pub fn with_header_name(token: Token, header_name: HeaderName) -> Self {
36 Self { token, header_name }
37 }
38}
39
40impl<S> Layer<S> for BearerAuthLayer {
41 type Service = BearerAuthService<S>;
42
43 fn layer(&self, inner: S) -> Self::Service {
44 BearerAuthService {
45 inner,
46 token: self.token.clone(),
47 header_name: self.header_name.clone(),
48 }
49 }
50}
51
52#[derive(Clone, Debug)]
57pub struct BearerAuthService<S> {
58 inner: S,
59 token: Token,
60 header_name: HeaderName,
61}
62
63impl<S, B, ResBody> Service<Request<B>> for BearerAuthService<S>
64where
65 S: Service<Request<B>, Response = Response<ResBody>, Error = HttpError>
66 + Clone
67 + Send
68 + 'static,
69 S::Future: Send,
70 B: Send + 'static,
71 ResBody: Send + 'static,
72{
73 type Response = Response<ResBody>;
74 type Error = HttpError;
75 type Future = Pin<Box<dyn Future<Output = Result<Response<ResBody>, HttpError>> + Send>>;
76
77 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78 self.inner.poll_ready(cx)
79 }
80
81 fn call(&mut self, mut req: Request<B>) -> Self::Future {
82 let mut bearer_value = match self.token.get() {
83 Ok(secret) => {
84 let raw = zeroize::Zeroizing::new(format!("Bearer {}", secret.expose()));
85 match HeaderValue::from_str(&raw) {
86 Ok(v) => v,
87 Err(e) => return Box::pin(async { Err(HttpError::InvalidHeaderValue(e)) }),
88 }
89 }
90 Err(e) => {
91 return Box::pin(async { Err(HttpError::Transport(Box::new(e))) });
92 }
93 };
94 bearer_value.set_sensitive(true);
95
96 req.headers_mut()
97 .insert(self.header_name.clone(), bearer_value);
98
99 let clone = self.inner.clone();
101 let mut inner = std::mem::replace(&mut self.inner, clone);
102
103 Box::pin(async move { inner.call(req).await })
104 }
105}
106
107#[cfg(test)]
108#[cfg_attr(coverage_nightly, coverage(off))]
109mod tests {
110 use super::*;
111 use bytes::Bytes;
112 use http::{Method, Request, Response, StatusCode};
113 use http_body_util::Full;
114 use httpmock::prelude::*;
115 use modkit_utils::SecretString;
116 use std::time::Duration;
117 use url::Url;
118
119 use crate::oauth2::config::OAuthClientConfig;
120
121 fn test_config(server: &MockServer) -> OAuthClientConfig {
123 OAuthClientConfig {
124 token_endpoint: Some(
125 Url::parse(&format!("http://localhost:{}/token", server.port())).unwrap(),
126 ),
127 client_id: "test-client".into(),
128 client_secret: SecretString::new("test-secret"),
129 http_config: Some(modkit_http::HttpClientConfig::for_testing()),
130 jitter_max: Duration::from_millis(0),
131 min_refresh_period: Duration::from_millis(100),
132 ..Default::default()
133 }
134 }
135
/// Renders the JSON body of a successful token-endpoint response.
///
/// `token` is escaped according to the JSON string rules, so a value containing
/// quotes, backslashes or control characters still yields a well-formed document.
/// Tokens without such characters are emitted unchanged.
///
/// * `token` - value of the `access_token` member.
/// * `expires_in` - value of the `expires_in` member, written as a bare number.
fn token_json(token: &str, expires_in: u64) -> String {
    use std::fmt::Write as _;

    let mut escaped = String::with_capacity(token.len());
    for ch in token.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if u32::from(c) < 0x20 => {
                // Remaining control characters must use the \uXXXX form.
                // Writing into a `String` cannot fail, so the result is ignored.
                let _ = write!(escaped, "\\u{:04x}", u32::from(c));
            }
            c => escaped.push(c),
        }
    }

    format!(r#"{{"access_token":"{escaped}","expires_in":{expires_in},"token_type":"Bearer"}}"#)
}
139
140 #[derive(Clone)]
144 struct CaptureHeaderService {
145 expected_header: HeaderName,
146 expected_value: String,
147 }
148
149 impl Service<Request<Full<Bytes>>> for CaptureHeaderService {
150 type Response = Response<Full<Bytes>>;
151 type Error = HttpError;
152 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
153
154 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
155 Poll::Ready(Ok(()))
156 }
157
158 fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
159 let header = req
160 .headers()
161 .get(&self.expected_header)
162 .expect("expected header not found")
163 .to_str()
164 .unwrap()
165 .to_owned();
166 let expected = self.expected_value.clone();
167
168 Box::pin(async move {
169 assert_eq!(header, expected);
170 Ok(Response::builder()
171 .status(StatusCode::OK)
172 .body(Full::new(Bytes::new()))
173 .unwrap())
174 })
175 }
176 }
177
/// Compile-time check that the layer and the service it produces are
/// `Send + Sync + Clone`.
#[test]
fn bearer_auth_is_send_sync_clone() {
    fn require_send_sync_clone<T>()
    where
        T: Send + Sync + Clone,
    {
    }

    require_send_sync_clone::<BearerAuthLayer>();
    require_send_sync_clone::<BearerAuthService<CaptureHeaderService>>();
}
186
187 #[tokio::test]
190 async fn injects_authorization_header() {
191 let server = MockServer::start();
192
193 let _mock = server.mock(|when, then| {
194 when.method(POST).path("/token");
195 then.status(200)
196 .header("content-type", "application/json")
197 .body(token_json("tok-layer-test", 3600));
198 });
199
200 let token = Token::new(test_config(&server)).await.unwrap();
201 let inner = CaptureHeaderService {
202 expected_header: AUTHORIZATION,
203 expected_value: "Bearer tok-layer-test".into(),
204 };
205
206 let layer = BearerAuthLayer::new(token);
207 let mut svc = layer.layer(inner);
208
209 let req = Request::builder()
210 .method(Method::GET)
211 .uri("http://example.com/api")
212 .body(Full::new(Bytes::new()))
213 .unwrap();
214
215 Service::call(&mut svc, req).await.unwrap();
216 }
217
218 #[tokio::test]
219 async fn custom_header_name() {
220 let server = MockServer::start();
221
222 let _mock = server.mock(|when, then| {
223 when.method(POST).path("/token");
224 then.status(200)
225 .header("content-type", "application/json")
226 .body(token_json("tok-custom-hdr", 3600));
227 });
228
229 let token = Token::new(test_config(&server)).await.unwrap();
230 let custom_header = HeaderName::from_static("x-api-key");
231 let inner = CaptureHeaderService {
232 expected_header: custom_header.clone(),
233 expected_value: "Bearer tok-custom-hdr".into(),
234 };
235
236 let layer = BearerAuthLayer::with_header_name(token, custom_header);
237 let mut svc = layer.layer(inner);
238
239 let req = Request::builder()
240 .method(Method::GET)
241 .uri("http://example.com/api")
242 .body(Full::new(Bytes::new()))
243 .unwrap();
244
245 Service::call(&mut svc, req).await.unwrap();
246 }
247
248 #[tokio::test]
251 async fn returns_error_when_token_expired() {
252 let server = MockServer::start();
253
254 let mut success_mock = server.mock(|when, then| {
256 when.method(POST).path("/token");
257 then.status(200)
258 .header("content-type", "application/json")
259 .body(token_json("tok-short-lived", 1));
260 });
261
262 let token = Token::new(test_config(&server)).await.unwrap();
263 assert_eq!(success_mock.calls(), 1);
264
265 success_mock.delete();
267 let _fail_mock = server.mock(|when, then| {
268 when.method(POST).path("/token");
269 then.status(500)
270 .header("content-type", "application/json")
271 .body(r#"{"error":"server_error"}"#);
272 });
273
274 tokio::time::sleep(Duration::from_secs(3)).await;
276
277 let inner = CaptureHeaderService {
278 expected_header: AUTHORIZATION,
279 expected_value: String::new(), };
281
282 let layer = BearerAuthLayer::new(token);
283 let mut svc = layer.layer(inner);
284
285 let req = Request::builder()
286 .method(Method::GET)
287 .uri("http://example.com/api")
288 .body(Full::new(Bytes::new()))
289 .unwrap();
290
291 let err = Service::call(&mut svc, req).await.unwrap_err();
292 assert!(
293 matches!(err, HttpError::Transport(_)),
294 "expected Transport error, got: {err:?}"
295 );
296 }
297
298 #[tokio::test]
301 async fn token_value_not_in_debug() {
302 let server = MockServer::start();
303
304 let _mock = server.mock(|when, then| {
305 when.method(POST).path("/token");
306 then.status(200)
307 .header("content-type", "application/json")
308 .body(token_json("super-secret-layer", 3600));
309 });
310
311 let token = Token::new(test_config(&server)).await.unwrap();
312 let layer = BearerAuthLayer::new(token);
313 let dbg = format!("{layer:?}");
314
315 assert!(
316 !dbg.contains("super-secret-layer"),
317 "Debug must not reveal token value: {dbg}"
318 );
319 }
320}