Skip to main content

modkit_auth/oauth2/
config.rs

1use std::fmt;
2use std::time::Duration;
3use url::Url;
4
5use super::error::TokenError;
6use super::types::{ClientAuthMethod, SecretString};
7
8/// Configuration for an outbound `OAuth2` client credentials flow.
9///
10/// Exactly one of [`token_endpoint`](Self::token_endpoint) or
11/// [`issuer_url`](Self::issuer_url) must be set.  Call
12/// [`validate`](Self::validate) to enforce this constraint.
13///
14/// `Debug` is manually implemented to redact [`client_secret`](Self::client_secret).
15pub struct OAuthClientConfig {
16    // ---- endpoint resolution ------------------------------------------------
17    /// Direct token endpoint URL (mutually exclusive with `issuer_url`).
18    pub token_endpoint: Option<Url>,
19
20    /// OIDC issuer URL for discovery (mutually exclusive with `token_endpoint`).
21    /// The actual token endpoint is resolved via
22    /// `{issuer_url}/.well-known/openid-configuration`.
23    pub issuer_url: Option<Url>,
24
25    // ---- credentials --------------------------------------------------------
26    /// `OAuth2` client identifier.
27    pub client_id: String,
28
29    /// `OAuth2` client secret (redacted in `Debug` output).
30    pub client_secret: SecretString,
31
32    /// Requested scopes (normalized once, stable order).
33    pub scopes: Vec<String>,
34
35    /// How client credentials are transmitted to the token endpoint.
36    pub auth_method: ClientAuthMethod,
37
38    /// Extra headers attached to every token request (vendor quirks).
39    pub extra_headers: Vec<(String, String)>,
40
41    // ---- refresh policy -----------------------------------------------------
42    /// How far before expiry the token should be refreshed (default: 30 min).
43    pub refresh_offset: Duration,
44
45    /// Maximum random jitter added to the refresh offset (default: 5 min).
46    pub jitter_max: Duration,
47
48    /// Minimum period between consecutive refresh attempts (default: 10 s).
49    pub min_refresh_period: Duration,
50
51    /// Fallback TTL when the token endpoint omits `expires_in` (default: 5 min).
52    pub default_ttl: Duration,
53
54    // ---- HTTP client --------------------------------------------------------
55    /// Override for the internal HTTP client configuration.
56    /// When `None`,
57    /// [`HttpClientConfig::token_endpoint()`](modkit_http::HttpClientConfig::token_endpoint)
58    /// is used.
59    pub http_config: Option<modkit_http::HttpClientConfig>,
60}
61
62impl OAuthClientConfig {
63    /// Validate that the configuration is self-consistent.
64    ///
65    /// # Errors
66    ///
67    /// Returns [`TokenError::ConfigError`] if:
68    /// - both `token_endpoint` and `issuer_url` are set, or
69    /// - neither is set.
70    pub fn validate(&self) -> Result<(), TokenError> {
71        if self.client_id.trim().is_empty() {
72            return Err(TokenError::ConfigError(
73                "client_id must not be empty".into(),
74            ));
75        }
76        if self.client_secret.expose().is_empty() {
77            return Err(TokenError::ConfigError(
78                "client_secret must not be empty".into(),
79            ));
80        }
81        match (&self.token_endpoint, &self.issuer_url) {
82            (Some(_), Some(_)) => Err(TokenError::ConfigError(
83                "token_endpoint and issuer_url are mutually exclusive".into(),
84            )),
85            (None, None) => Err(TokenError::ConfigError(
86                "one of token_endpoint or issuer_url must be set".into(),
87            )),
88            _ => Ok(()),
89        }
90    }
91}
92
93impl Clone for OAuthClientConfig {
94    fn clone(&self) -> Self {
95        Self {
96            token_endpoint: self.token_endpoint.clone(),
97            issuer_url: self.issuer_url.clone(),
98            client_id: self.client_id.clone(),
99            client_secret: self.client_secret.clone(),
100            scopes: self.scopes.clone(),
101            auth_method: self.auth_method,
102            extra_headers: self.extra_headers.clone(),
103            refresh_offset: self.refresh_offset,
104            jitter_max: self.jitter_max,
105            min_refresh_period: self.min_refresh_period,
106            default_ttl: self.default_ttl,
107            http_config: self.http_config.clone(),
108        }
109    }
110}
111
112/// `Debug` redacts `client_secret` to prevent accidental exposure in logs.
113impl fmt::Debug for OAuthClientConfig {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        let redacted_headers: Vec<_> = self
116            .extra_headers
117            .iter()
118            .map(|(k, _)| (k.as_str(), "[REDACTED]"))
119            .collect();
120        f.debug_struct("OAuthClientConfig")
121            .field("token_endpoint", &self.token_endpoint)
122            .field("issuer_url", &self.issuer_url)
123            .field("client_id", &self.client_id)
124            .field("client_secret", &"[REDACTED]")
125            .field("scopes", &self.scopes)
126            .field("auth_method", &self.auth_method)
127            .field("extra_headers", &redacted_headers)
128            .field("refresh_offset", &self.refresh_offset)
129            .field("jitter_max", &self.jitter_max)
130            .field("min_refresh_period", &self.min_refresh_period)
131            .field("default_ttl", &self.default_ttl)
132            .field("http_config", &self.http_config)
133            .finish()
134    }
135}
136
137impl Default for OAuthClientConfig {
138    fn default() -> Self {
139        Self {
140            token_endpoint: None,
141            issuer_url: None,
142            client_id: String::new(),
143            client_secret: SecretString::new(String::new()),
144            scopes: Vec::new(),
145            auth_method: ClientAuthMethod::default(),
146            extra_headers: Vec::new(),
147            refresh_offset: Duration::from_secs(30 * 60),
148            jitter_max: Duration::from_secs(5 * 60),
149            min_refresh_period: Duration::from_secs(10),
150            default_ttl: Duration::from_secs(5 * 60),
151            http_config: None,
152        }
153    }
154}
155
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// Parses a URL literal; panics on malformed input (test-only helper).
    fn test_url(s: &str) -> Url {
        Url::parse(s).unwrap()
    }

    /// Returns a config with valid credentials but NO endpoint configured.
    ///
    /// On its own this fails `validate` with the "must be set" error; tests
    /// add `token_endpoint` and/or `issuer_url` on top of it.
    fn valid_base() -> OAuthClientConfig {
        OAuthClientConfig {
            client_id: "my-client".into(),
            client_secret: SecretString::new("my-secret"),
            ..Default::default()
        }
    }

    /// Returns `valid_base()` plus a direct token endpoint, i.e. a config
    /// that passes `validate`.
    fn valid_with_token_endpoint() -> OAuthClientConfig {
        OAuthClientConfig {
            token_endpoint: Some(test_url("https://auth.example.com/token")),
            ..valid_base()
        }
    }

    // ---- validate -----------------------------------------------------------

    #[test]
    fn validate_ok_with_token_endpoint_only() {
        let cfg = valid_with_token_endpoint();
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn validate_ok_with_issuer_url_only() {
        let cfg = OAuthClientConfig {
            issuer_url: Some(test_url("https://auth.example.com")),
            ..valid_base()
        };
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn validate_err_when_both_set() {
        let cfg = OAuthClientConfig {
            token_endpoint: Some(test_url("https://a.example.com/token")),
            issuer_url: Some(test_url("https://b.example.com")),
            ..valid_base()
        };
        let err = cfg.validate().unwrap_err();
        assert!(
            err.to_string().contains("mutually exclusive"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn validate_err_when_neither_set() {
        let cfg = valid_base();
        let err = cfg.validate().unwrap_err();
        assert!(
            err.to_string().contains("must be set"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn validate_err_when_client_id_empty() {
        let cfg = OAuthClientConfig {
            client_id: String::new(),
            ..valid_with_token_endpoint()
        };
        let err = cfg.validate().unwrap_err();
        assert!(
            err.to_string().contains("client_id"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn validate_err_when_client_id_whitespace() {
        let cfg = OAuthClientConfig {
            client_id: "   ".into(),
            ..valid_with_token_endpoint()
        };
        let err = cfg.validate().unwrap_err();
        assert!(
            err.to_string().contains("client_id"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn validate_err_when_client_secret_empty() {
        let cfg = OAuthClientConfig {
            client_secret: SecretString::new(""),
            ..valid_with_token_endpoint()
        };
        let err = cfg.validate().unwrap_err();
        assert!(
            err.to_string().contains("client_secret"),
            "unexpected error: {err}"
        );
    }

    // ---- Clone --------------------------------------------------------------

    #[test]
    fn clone_copies_plain_fields() {
        let cfg = OAuthClientConfig {
            scopes: vec!["read".into(), "write".into()],
            extra_headers: vec![("x-tenant".into(), "acme".into())],
            refresh_offset: Duration::from_secs(42),
            ..valid_with_token_endpoint()
        };
        let copy = cfg.clone();
        assert_eq!(copy.token_endpoint, cfg.token_endpoint);
        assert_eq!(copy.issuer_url, cfg.issuer_url);
        assert_eq!(copy.client_id, cfg.client_id);
        assert_eq!(copy.scopes, cfg.scopes);
        assert_eq!(copy.extra_headers, cfg.extra_headers);
        assert_eq!(copy.refresh_offset, cfg.refresh_offset);
        assert_eq!(copy.jitter_max, cfg.jitter_max);
        assert_eq!(copy.min_refresh_period, cfg.min_refresh_period);
        assert_eq!(copy.default_ttl, cfg.default_ttl);
    }

    // ---- Debug redaction ----------------------------------------------------

    #[test]
    fn debug_redacts_client_secret() {
        let cfg = OAuthClientConfig {
            client_secret: SecretString::new("super-secret"),
            ..valid_with_token_endpoint()
        };
        let dbg = format!("{cfg:?}");
        assert!(dbg.contains("[REDACTED]"), "Debug must contain [REDACTED]");
        assert!(
            !dbg.contains("super-secret"),
            "Debug must not contain the raw secret"
        );
        assert!(dbg.contains("my-client"), "Debug should contain client_id");
    }

    #[test]
    fn debug_redacts_extra_header_values() {
        let cfg = OAuthClientConfig {
            client_secret: SecretString::new("s"),
            extra_headers: vec![("x-api-key".into(), "secret-api-key-value".into())],
            ..valid_with_token_endpoint()
        };
        let dbg = format!("{cfg:?}");
        assert!(
            dbg.contains("x-api-key"),
            "Debug should contain header name"
        );
        assert!(
            !dbg.contains("secret-api-key-value"),
            "Debug must not contain header value"
        );
    }

    // ---- Default ------------------------------------------------------------

    #[test]
    fn default_durations() {
        let cfg = OAuthClientConfig::default();
        assert_eq!(cfg.refresh_offset, Duration::from_secs(30 * 60));
        assert_eq!(cfg.jitter_max, Duration::from_secs(5 * 60));
        assert_eq!(cfg.min_refresh_period, Duration::from_secs(10));
        assert_eq!(cfg.default_ttl, Duration::from_secs(5 * 60));
        assert_eq!(cfg.auth_method, ClientAuthMethod::Basic);
    }
}