1use crate::clock::{Clock, SystemClock};
6use oauth2::basic::BasicClient;
7use oauth2::reqwest;
8use oauth2::{
9 AuthUrl, ClientId, ClientSecret, EndpointNotSet, EndpointSet, Scope, TokenResponse, TokenUrl,
10};
11use thiserror::Error;
12use url;
13
/// Refresh skew, in seconds, used when `ClientCredentialsConfig::refresh_skew_seconds` is
/// `None`: a cached token is treated as stale once its expiry is no more than this many
/// seconds away.
const DEFAULT_REFRESH_SKEW_SECONDS: u64 = 30;
15
/// Errors returned while constructing an `OAuth2TokenProvider` or fetching a token.
#[derive(Debug, Error)]
pub enum OAuth2TokenProviderError {
    /// A required configuration field (`token_url`, `client_id` or `client_secret`) was empty
    /// or whitespace-only; the payload names the offending field.
    #[error("invalid config: {0}")]
    InvalidConfig(&'static str),

    /// `token_url` could not be parsed as a URL.
    #[error("invalid token_url: {0}")]
    InvalidTokenUrl(#[from] url::ParseError),

    /// The blocking HTTP client could not be built; carries the stringified cause.
    #[error("failed to build http client: {0}")]
    HttpClientBuild(String),

    /// The request to the token endpoint failed; carries the stringified cause.
    #[error("token request failed: {0}")]
    TokenRequest(String),

    /// The token response had no `expires_in`, so the token's lifetime is unknown and it is
    /// not cached.
    #[error("token response missing expires_in")]
    MissingExpiresIn,
}
33
/// Settings for an OAuth2 client-credentials token provider.
#[derive(Clone)]
pub struct ClientCredentialsConfig {
    /// Absolute URL of the authorization server's token endpoint.
    pub token_url: String,
    /// OAuth2 client identifier.
    pub client_id: String,
    /// OAuth2 client secret. Redacted from this type's `Debug` output.
    pub client_secret: String,
    /// Scopes to request; blank entries are ignored by the provider.
    pub scopes: Vec<String>,
    /// Seconds before expiry at which a cached token is refreshed; `None` selects the
    /// provider's default.
    pub refresh_skew_seconds: Option<u64>,
}

// Hand-written instead of derived so the client secret never reaches logs through `{:?}`,
// including via any type that embeds this config and derives `Debug`.
impl std::fmt::Debug for ClientCredentialsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientCredentialsConfig")
            .field("token_url", &self.token_url)
            .field("client_id", &self.client_id)
            .field("client_secret", &format_args!("<redacted>"))
            .field("scopes", &self.scopes)
            .field("refresh_skew_seconds", &self.refresh_skew_seconds)
            .finish()
    }
}
42
/// A previously fetched access token together with its expiry.
#[derive(Clone)]
struct CachedToken {
    /// Bearer token returned by the token endpoint. Redacted from `Debug` output.
    access_token: String,
    /// Expiry as an absolute clock reading: the fetch-time reading plus `expires_in` seconds.
    expired_at: u64,
}

// Hand-written so the access token does not leak through `{:?}`, e.g. via the owning
// provider's derived `Debug`, which formats the lock-protected cache contents.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("access_token", &format_args!("<redacted>"))
            .field("expired_at", &self.expired_at)
            .finish()
    }
}
48
/// `BasicClient` typestate produced by `OAuth2TokenProvider::with_clock`: the authorization
/// and token endpoints are set; the device-authorization, introspection and revocation
/// endpoints are not.
type ConfiguredBasicClient =
    BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;
51
/// Fetches OAuth2 access tokens with the client-credentials grant and caches the most recent
/// one until it is within the refresh skew of expiring.
///
/// `C` is the time source used for expiry decisions; it defaults to `SystemClock`.
///
/// NOTE(review): the derived `Debug` formats every field (including `config` and the cached
/// token), so its output is only as redacted as those fields' own `Debug` impls.
#[derive(Debug)]
pub struct OAuth2TokenProvider<C: Clock = SystemClock> {
    // Configuration the provider was built from (scopes and refresh skew are read per call).
    config: ClientCredentialsConfig,
    // OAuth2 client carrying the client id/secret and the token endpoint.
    client: ConfiguredBasicClient,
    // Most recently fetched token, or `None` before the first successful fetch.
    cache: tokio::sync::RwLock<Option<CachedToken>>,
    // Time source used to compute and check token expiry.
    clock: C,
}
59
/// Result type used throughout this module, with `OAuth2TokenProviderError` as the error.
pub type Result<T> = std::result::Result<T, OAuth2TokenProviderError>;
61
62impl OAuth2TokenProvider<SystemClock> {
63 pub fn new(config: ClientCredentialsConfig) -> Result<Self> {
64 Self::with_clock(config, SystemClock)
65 }
66}
67
68impl<C: Clock> OAuth2TokenProvider<C> {
69 pub fn with_clock(config: ClientCredentialsConfig, clock: C) -> Result<Self> {
70 Self::validate_config(&config)?;
71
72 let token_url = TokenUrl::new(config.token_url.clone())?;
73 let auth_url = AuthUrl::new("https://invalid.local/authorize".to_string())
74 .expect("hardcoded url must be valid");
75
76 let client = BasicClient::new(ClientId::new(config.client_id.clone()))
77 .set_client_secret(ClientSecret::new(config.client_secret.clone()))
78 .set_auth_uri(auth_url)
79 .set_token_uri(token_url);
80
81 Ok(Self {
82 config,
83 client,
84 cache: tokio::sync::RwLock::new(None),
85 clock,
86 })
87 }
88
89 fn validate_config(config: &ClientCredentialsConfig) -> Result<()> {
90 if config.token_url.trim().is_empty() {
91 return Err(OAuth2TokenProviderError::InvalidConfig(
92 "token_url must not be empty",
93 ));
94 }
95 if config.client_id.trim().is_empty() {
96 return Err(OAuth2TokenProviderError::InvalidConfig(
97 "client_id must not be empty",
98 ));
99 }
100 if config.client_secret.trim().is_empty() {
101 return Err(OAuth2TokenProviderError::InvalidConfig(
102 "client_secret must not be empty",
103 ));
104 }
105 Ok(())
106 }
107
108 pub fn refresh_skew(&self) -> u64 {
109 match self.config.refresh_skew_seconds {
110 Some(0) => 0,
111 Some(seconds) => seconds,
112 None => DEFAULT_REFRESH_SKEW_SECONDS,
113 }
114 }
115
116 pub fn get_token(&self) -> Result<String> {
117 let guard = self.cache.blocking_read();
118 if let Some(token) = guard.as_ref() {
119 let now = self.clock.now();
120 let skew = self.refresh_skew();
121 if skew == 0 || token.expired_at > now.saturating_add(skew) {
122 return Ok(token.access_token.clone());
123 }
124 }
125 drop(guard); let http_client = reqwest::blocking::ClientBuilder::new()
128 .redirect(reqwest::redirect::Policy::none())
129 .build()
130 .map_err(|e| OAuth2TokenProviderError::HttpClientBuild(e.to_string()))?;
131
132 let mut req = self.client.exchange_client_credentials();
133 for s in &self.config.scopes {
134 let scope = s.trim();
135 if !scope.is_empty() {
136 req = req.add_scope(Scope::new(scope.to_string()));
137 }
138 }
139
140 let token = req
141 .request(&http_client)
142 .map_err(|e| OAuth2TokenProviderError::TokenRequest(e.to_string()))?;
143
144 let access_token = token.access_token().secret().to_string();
145 let expires_in = token
146 .expires_in()
147 .ok_or(OAuth2TokenProviderError::MissingExpiresIn)?
148 .as_secs();
149
150 let expired_at = self.clock.now().saturating_add(expires_in);
151
152 let mut guard = self.cache.blocking_write();
153 *guard = Some(CachedToken {
154 access_token: access_token.clone(),
155 expired_at,
156 });
157 drop(guard); Ok(access_token)
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clock::ManualClock;
    use httpmock::prelude::*;
    use std::sync::Arc;

    /// Valid config pointing at `token_url`, requesting two scopes with a 30s refresh skew.
    fn cfg(token_url: String) -> ClientCredentialsConfig {
        ClientCredentialsConfig {
            token_url,
            client_id: "client-id".to_string(),
            client_secret: "client-secret".to_string(),
            scopes: vec!["scope-a".into(), "scope-b".into()],
            refresh_skew_seconds: Some(30),
        }
    }

    #[test]
    fn new_rejects_empty_fields() {
        let base = ClientCredentialsConfig {
            token_url: "http://localhost/token".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            scopes: vec![],
            refresh_skew_seconds: None,
        };

        let mut c = base.clone();
        c.token_url = " ".into();
        assert!(matches!(
            OAuth2TokenProvider::new(c).unwrap_err(),
            OAuth2TokenProviderError::InvalidConfig(_)
        ));

        let mut c = base.clone();
        c.client_id = "".into();
        assert!(matches!(
            OAuth2TokenProvider::new(c).unwrap_err(),
            OAuth2TokenProviderError::InvalidConfig(_)
        ));

        let mut c = base.clone();
        c.client_secret = " ".into();
        assert!(matches!(
            OAuth2TokenProvider::new(c).unwrap_err(),
            OAuth2TokenProviderError::InvalidConfig(_)
        ));
    }

    #[test]
    fn refresh_skew_none_uses_default() {
        let config = ClientCredentialsConfig {
            token_url: "http://localhost/token".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            scopes: vec![],
            refresh_skew_seconds: None,
        };

        let provider = OAuth2TokenProvider::new(config).unwrap();
        assert_eq!(provider.refresh_skew(), DEFAULT_REFRESH_SKEW_SECONDS);
    }

    #[test]
    fn get_token_fetches_and_caches_token() {
        let server = MockServer::start();

        // Only a form-encoded client-credentials grant carrying both scopes may match.
        // Scopes are space-joined, and form encoding turns the space into `+`.
        let token_mock = server.mock(|when, then| {
            when.method(POST)
                .path("/token")
                .header("content-type", "application/x-www-form-urlencoded")
                .body_contains("grant_type=client_credentials")
                .body_contains("scope=scope-a+scope-b");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"t1","token_type":"bearer","expires_in":3600}"#);
        });

        let clock = Arc::new(ManualClock::new(1000));
        let provider =
            OAuth2TokenProvider::with_clock(cfg(format!("{}/token", server.base_url())), clock)
                .unwrap();

        let t1 = provider.get_token().unwrap();
        let t2 = provider.get_token().unwrap();

        assert_eq!(t1, "t1");
        assert_eq!(t2, "t1");
        token_mock.assert_calls(1);
    }

    #[test]
    fn get_token_refreshes_when_skew_window_reached() {
        let server = MockServer::start();

        let mock_server = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"t1","token_type":"bearer","expires_in":100}"#);
        });

        let clock = Arc::new(ManualClock::new(1000));

        let config = ClientCredentialsConfig {
            token_url: format!("{}/token", server.base_url()),
            client_id: "client-id".into(),
            client_secret: "client-secret".into(),
            scopes: vec![],
            refresh_skew_seconds: Some(30),
        };

        let provider = OAuth2TokenProvider::with_clock(config, clock.clone()).unwrap();

        // Token expires at 1000 + 100 = 1100. It is fresh while now + 30 < 1100.
        assert_eq!(provider.get_token().unwrap(), "t1");

        // now = 1069: 1069 + 30 = 1099 < 1100, so the cached token is still served.
        clock.advance(69);
        assert_eq!(provider.get_token().unwrap(), "t1");
        mock_server.assert_calls(1);

        // now = 1070: 1070 + 30 = 1100 reaches the expiry, so a refresh is triggered.
        clock.advance(1);
        assert_eq!(provider.get_token().unwrap(), "t1");
        mock_server.assert_calls(2);
    }

    #[test]
    fn get_token_does_not_refresh_when_skew_is_zero() {
        let server = MockServer::start();

        let mock_server = server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"t1","token_type":"bearer","expires_in":1}"#);
        });

        let clock = Arc::new(ManualClock::new(1000));

        let config = ClientCredentialsConfig {
            token_url: format!("{}/token", server.base_url()),
            client_id: "client-id".into(),
            client_secret: "client-secret".into(),
            scopes: vec![],
            refresh_skew_seconds: Some(0),
        };

        let provider = OAuth2TokenProvider::with_clock(config, clock.clone()).unwrap();

        // Pins the current contract: with a zero skew the cached token is reused even long
        // after its `expires_in` has elapsed.
        let token = provider.get_token().unwrap();
        clock.advance(1000);
        provider.get_token().unwrap();

        assert_eq!(token, "t1");
        mock_server.assert_calls(1);
    }

    #[test]
    fn missing_expires_in_returns_error() {
        let server = MockServer::start();

        server.mock(|when, then| {
            when.method(POST).path("/token");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"access_token":"t1","token_type":"bearer"}"#);
        });

        let clock = Arc::new(ManualClock::new(1000));
        let provider =
            OAuth2TokenProvider::with_clock(cfg(format!("{}/token", server.base_url())), clock)
                .unwrap();

        let err = provider.get_token().unwrap_err();
        assert!(matches!(err, OAuth2TokenProviderError::MissingExpiresIn));
    }

    #[test]
    fn token_request_error_is_mapped() {
        let clock = Arc::new(ManualClock::new(1000));
        // Port 9 (discard) is expected to refuse the connection, so the request itself fails.
        let provider = OAuth2TokenProvider::with_clock(
            ClientCredentialsConfig {
                token_url: "http://127.0.0.1:9/token".into(),
                client_id: "client-id".into(),
                client_secret: "client-secret".into(),
                scopes: vec![],
                refresh_skew_seconds: None,
            },
            clock,
        )
        .unwrap();

        let err = provider.get_token().unwrap_err();
        assert!(matches!(err, OAuth2TokenProviderError::TokenRequest(_)));
    }
}