// modkit_auth/oauth2/builder_ext.rs

1use http::header::HeaderName;
2use tower::ServiceExt;
3
4use super::layer::BearerAuthLayer;
5use super::token::Token;
6
7/// Extension trait for adding bearer auth to [`modkit_http::HttpClientBuilder`].
8///
9/// # Example
10///
11/// ```ignore
12/// use modkit_auth::HttpClientBuilderExt;
13///
14/// let token = Token::new(config).await?;
15/// let client = HttpClientBuilder::new()
16///     .with_bearer_auth(token)
17///     .build()?;
18/// ```
19pub trait HttpClientBuilderExt {
20    /// Add `Authorization: Bearer <token>` injection to the HTTP client.
21    #[must_use]
22    fn with_bearer_auth(self, token: Token) -> Self;
23
24    /// Add `<header_name>: Bearer <token>` injection to the HTTP client.
25    #[must_use]
26    fn with_bearer_auth_header(self, token: Token, header_name: HeaderName) -> Self;
27}
28
29impl HttpClientBuilderExt for modkit_http::HttpClientBuilder {
30    fn with_bearer_auth(self, token: Token) -> Self {
31        let layer = BearerAuthLayer::new(token);
32        self.with_auth_layer(move |svc| {
33            tower::ServiceBuilder::new()
34                .layer(layer)
35                .service(svc)
36                .boxed_clone()
37        })
38    }
39
40    fn with_bearer_auth_header(self, token: Token, header_name: HeaderName) -> Self {
41        let layer = BearerAuthLayer::with_header_name(token, header_name);
42        self.with_auth_layer(move |svc| {
43            tower::ServiceBuilder::new()
44                .layer(layer)
45                .service(svc)
46                .boxed_clone()
47        })
48    }
49}
50
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use httpmock::prelude::*;
    use modkit_utils::SecretString;
    use std::time::Duration;
    use url::Url;

    use crate::oauth2::config::OAuthClientConfig;

    /// OAuth client settings whose token endpoint is `/token` on `server`.
    ///
    /// Jitter is zeroed and the test HTTP config is used so token acquisition
    /// talks to the local mock.
    fn token_config(server: &MockServer) -> OAuthClientConfig {
        let endpoint = format!("http://localhost:{}/token", server.port());
        OAuthClientConfig {
            token_endpoint: Some(Url::parse(&endpoint).unwrap()),
            client_id: "test-client".into(),
            client_secret: SecretString::new("test-secret"),
            http_config: Some(modkit_http::HttpClientConfig::for_testing()),
            jitter_max: Duration::from_millis(0),
            min_refresh_period: Duration::from_millis(100),
            ..Default::default()
        }
    }

    /// JSON body of a successful token response.
    fn token_json(token: &str, expires_in: u64) -> String {
        format!(r#"{{"access_token":"{token}","expires_in":{expires_in},"token_type":"Bearer"}}"#)
    }

    /// Full URL of the `/api/data` endpoint on `server`.
    fn data_url(server: &MockServer) -> String {
        format!("http://localhost:{}/api/data", server.port())
    }

    #[tokio::test]
    async fn with_bearer_auth_injects_header() {
        // Identity provider: issues the token the client should forward.
        let idp = MockServer::start();
        let _token_endpoint = idp.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-builder-ext", 3600));
        });

        // Upstream API: only matches when the standard Authorization header is present.
        let upstream = MockServer::start();
        let authorized_call = upstream.mock(|when, then| {
            when.method(GET)
                .path("/api/data")
                .header("authorization", "Bearer tok-builder-ext");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"ok":true}"#);
        });

        let token = Token::new(token_config(&idp)).await.unwrap();

        let http = modkit_http::HttpClientBuilder::new()
            .allow_insecure_http()
            .with_bearer_auth(token)
            .build()
            .unwrap();

        let target = data_url(&upstream);
        let _response = http.get(&target).send().await.unwrap();

        authorized_call.assert();
    }

    #[tokio::test]
    async fn with_bearer_auth_header_injects_custom_header() {
        let idp = MockServer::start();
        let _token_endpoint = idp.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(token_json("tok-custom-hdr-ext", 3600));
        });

        // Upstream API: expects the bearer value under a non-standard header name.
        let upstream = MockServer::start();
        let authorized_call = upstream.mock(|when, then| {
            when.method(GET)
                .path("/api/data")
                .header("x-api-key", "Bearer tok-custom-hdr-ext");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"ok":true}"#);
        });

        let token = Token::new(token_config(&idp)).await.unwrap();

        let http = modkit_http::HttpClientBuilder::new()
            .allow_insecure_http()
            .with_bearer_auth_header(token, HeaderName::from_static("x-api-key"))
            .build()
            .unwrap();

        let target = data_url(&upstream);
        let _response = http.get(&target).send().await.unwrap();

        authorized_call.assert();
    }

    #[tokio::test]
    async fn without_bearer_auth_no_header() {
        let upstream = MockServer::start();

        // Registered first on purpose: it matches only if an Authorization
        // header leaks onto the request, and it must stay unhit.
        let leaked_auth = upstream.mock(|when, then| {
            when.method(GET)
                .path("/api/data")
                .header_exists("authorization");
            then.status(200).body("authed");
        });

        // Catch-all for the same route; this is the one that should answer.
        let plain_call = upstream.mock(|when, then| {
            when.method(GET).path("/api/data");
            then.status(200).body("no-auth");
        });

        let http = modkit_http::HttpClientBuilder::new()
            .allow_insecure_http()
            .build()
            .unwrap();

        let target = data_url(&upstream);
        let _response = http.get(&target).send().await.unwrap();

        assert_eq!(
            leaked_auth.calls(),
            0,
            "No Authorization header should be sent without bearer auth"
        );
        plain_call.assert();
    }
}