Skip to main content

modkit_auth/oauth2/
builder_ext.rs

1use http::header::HeaderName;
2use tower::ServiceExt;
3
4use super::layer::BearerAuthLayer;
5use super::token::Token;
6
7/// Extension trait for adding bearer auth to [`modkit_http::HttpClientBuilder`].
8///
9/// # Example
10///
11/// ```ignore
12/// use modkit_auth::HttpClientBuilderExt;
13///
14/// let token = Token::new(config).await?;
15/// let client = HttpClientBuilder::new()
16///     .with_bearer_auth(token)
17///     .build()?;
18/// ```
19pub trait HttpClientBuilderExt {
20    /// Add `Authorization: Bearer <token>` injection to the HTTP client.
21    #[must_use]
22    fn with_bearer_auth(self, token: Token) -> Self;
23
24    /// Add `<header_name>: Bearer <token>` injection to the HTTP client.
25    #[must_use]
26    fn with_bearer_auth_header(self, token: Token, header_name: HeaderName) -> Self;
27}
28
29impl HttpClientBuilderExt for modkit_http::HttpClientBuilder {
30    fn with_bearer_auth(self, token: Token) -> Self {
31        let layer = BearerAuthLayer::new(token);
32        self.with_auth_layer(move |svc| {
33            tower::ServiceBuilder::new()
34                .layer(layer)
35                .service(svc)
36                .boxed_clone()
37        })
38    }
39
40    fn with_bearer_auth_header(self, token: Token, header_name: HeaderName) -> Self {
41        let layer = BearerAuthLayer::with_header_name(token, header_name);
42        self.with_auth_layer(move |svc| {
43            tower::ServiceBuilder::new()
44                .layer(layer)
45                .service(svc)
46                .boxed_clone()
47        })
48    }
49}
50
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use httpmock::prelude::*;
    use modkit_utils::SecretString;
    use std::time::Duration;
    use url::Url;

    use crate::oauth2::config::OAuthClientConfig;

    /// Returns `http://localhost:<port><path>` for a path served by `server`.
    fn local_url(server: &MockServer, path: &str) -> String {
        format!("http://localhost:{}{path}", server.port())
    }

    /// OAuth client config whose token endpoint is `/token` on `server`.
    ///
    /// Jitter is zeroed and `min_refresh_period` is 100 ms; every other field
    /// comes from `OAuthClientConfig::default()`.
    fn token_config(server: &MockServer) -> OAuthClientConfig {
        let endpoint = Url::parse(&local_url(server, "/token")).unwrap();
        OAuthClientConfig {
            token_endpoint: Some(endpoint),
            client_id: "test-client".into(),
            client_secret: SecretString::new("test-secret"),
            http_config: Some(modkit_http::HttpClientConfig::for_testing()),
            jitter_max: Duration::from_millis(0),
            min_refresh_period: Duration::from_millis(100),
            ..Default::default()
        }
    }

    /// JSON body of a successful token response issuing `token`.
    fn token_json(token: &str, expires_in: u64) -> String {
        format!(r#"{{"access_token":"{token}","expires_in":{expires_in},"token_type":"Bearer"}}"#)
    }

    #[tokio::test]
    async fn with_bearer_auth_injects_header() {
        // Token endpoint that always issues `tok-builder-ext`.
        let oauth_server = MockServer::start();
        let _token_mock = oauth_server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-builder-ext", 3600));
        });

        // API endpoint that only matches when the standard header carries the token.
        let api_server = MockServer::start();
        let api_mock = api_server.mock(|when, then| {
            when.method(GET)
                .path("/api/data")
                .header("authorization", "Bearer tok-builder-ext");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"ok":true}"#);
        });

        let token = Token::new(token_config(&oauth_server)).await.unwrap();
        let client = modkit_http::HttpClientBuilder::new()
            .with_bearer_auth(token)
            .build()
            .unwrap();

        let url = local_url(&api_server, "/api/data");
        let _resp = client.get(&url).send().await.unwrap();

        api_mock.assert();
    }

    #[tokio::test]
    async fn with_bearer_auth_header_injects_custom_header() {
        // Token endpoint that always issues `tok-custom-hdr-ext`.
        let oauth_server = MockServer::start();
        let _token_mock = oauth_server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-custom-hdr-ext", 3600));
        });

        // API endpoint that only matches when the custom header carries the token.
        let api_server = MockServer::start();
        let api_mock = api_server.mock(|when, then| {
            when.method(GET)
                .path("/api/data")
                .header("x-api-key", "Bearer tok-custom-hdr-ext");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"ok":true}"#);
        });

        let token = Token::new(token_config(&oauth_server)).await.unwrap();
        let header = HeaderName::from_static("x-api-key");
        let client = modkit_http::HttpClientBuilder::new()
            .with_bearer_auth_header(token, header)
            .build()
            .unwrap();

        let url = local_url(&api_server, "/api/data");
        let _resp = client.get(&url).send().await.unwrap();

        api_mock.assert();
    }

    #[tokio::test]
    async fn without_bearer_auth_no_header() {
        let api_server = MockServer::start();

        // Only matches requests that carry an Authorization header; it must never be hit.
        let auth_mock = api_server.mock(|when, then| {
            when.method(GET)
                .path("/api/data")
                .header_exists("authorization");
            then.status(200).body("authed");
        });

        // Matches the same GET without any header requirement.
        let fallback_mock = api_server.mock(|when, then| {
            when.method(GET).path("/api/data");
            then.status(200).body("no-auth");
        });

        let client = modkit_http::HttpClientBuilder::new().build().unwrap();

        let url = local_url(&api_server, "/api/data");
        let _resp = client.get(&url).send().await.unwrap();

        assert_eq!(
            auth_mock.calls(),
            0,
            "No Authorization header should be sent without bearer auth"
        );
        fallback_mock.assert();
    }
}