clawspec_core/client/oauth2/
provider.rs1use std::time::Duration;
4
5use oauth2::{AccessToken, TokenResponse};
6
7use super::config::{OAuth2Config, OAuth2GrantType};
8use super::error::OAuth2Error;
9use super::token::OAuth2Token;
10
11impl OAuth2Config {
12 pub async fn acquire_token(&self) -> Result<OAuth2Token, OAuth2Error> {
25 match self.grant_type {
26 OAuth2GrantType::ClientCredentials => self.acquire_client_credentials_token().await,
27 OAuth2GrantType::PreAcquired => self.get_pre_acquired_token().await,
28 }
29 }
30
31 async fn acquire_client_credentials_token(&self) -> Result<OAuth2Token, OAuth2Error> {
33 let http_client = oauth2::reqwest::ClientBuilder::new()
36 .redirect(oauth2::reqwest::redirect::Policy::none())
37 .build()
38 .map_err(|e| OAuth2Error::TokenAcquisitionFailed {
39 reason: format!("Failed to create HTTP client: {e}"),
40 })?;
41
42 self.acquire_client_credentials_token_with_client(&http_client)
43 .await
44 }
45
46 pub(crate) async fn acquire_client_credentials_token_with_client(
51 &self,
52 http_client: &oauth2::reqwest::Client,
53 ) -> Result<OAuth2Token, OAuth2Error> {
54 use oauth2::basic::BasicClient;
55 use oauth2::{AuthUrl, ClientId, ClientSecret, Scope, TokenUrl};
56
57 let client_id = ClientId::new(self.client_id.clone());
58
59 let auth_url_str = self
61 .auth_url
62 .as_ref()
63 .map(|u| u.to_string())
64 .unwrap_or_else(|| format!("{}/../authorize", self.token_url));
65
66 let auth_url = AuthUrl::new(auth_url_str).map_err(|e| OAuth2Error::ConfigurationError {
67 reason: format!("Invalid authorization URL: {e}"),
68 })?;
69
70 let token_url = TokenUrl::new(self.token_url.to_string()).map_err(|e| {
71 OAuth2Error::ConfigurationError {
72 reason: format!("Invalid token URL: {e}"),
73 }
74 })?;
75
76 let mut client = BasicClient::new(client_id)
80 .set_auth_uri(auth_url)
81 .set_token_uri(token_url);
82
83 if let Some(ref secret) = self.client_secret {
85 client = client.set_client_secret(ClientSecret::new(secret.as_str().to_string()));
86 }
87
88 let mut request = client.exchange_client_credentials();
89
90 for scope in self.scopes.iter().map(|s| Scope::new(s.clone())) {
92 request = request.add_scope(scope);
93 }
94
95 let token_result = request.request_async(http_client).await.map_err(|e| {
97 OAuth2Error::TokenAcquisitionFailed {
98 reason: format!("{e}"),
99 }
100 })?;
101
102 let token =
104 Self::convert_token_response(token_result.access_token(), token_result.expires_in());
105
106 self.set_token(token.clone()).await;
108
109 Ok(token)
110 }
111
112 async fn get_pre_acquired_token(&self) -> Result<OAuth2Token, OAuth2Error> {
114 self.get_token().await.ok_or(OAuth2Error::TokenExpired)
115 }
116
117 fn convert_token_response(
119 access_token: &AccessToken,
120 expires_in: Option<Duration>,
121 ) -> OAuth2Token {
122 if let Some(duration) = expires_in {
123 OAuth2Token::with_expiry(access_token.secret().clone(), duration)
124 } else {
125 OAuth2Token::new(access_token.secret().clone())
126 }
127 }
128
129 pub async fn get_valid_token(&self) -> Result<OAuth2Token, OAuth2Error> {
134 if !self.needs_token().await
136 && let Some(token) = self.get_token().await
137 {
138 return Ok(token);
139 }
140
141 self.acquire_token().await
143 }
144}
145
#[cfg(test)]
mod tests {
    // Unit tests run without a network. The `mock_server_tests` submodule exercises the
    // client-credentials flow against a local wiremock server.
    use super::*;

    #[tokio::test]
    async fn should_return_pre_acquired_token() {
        let config = OAuth2Config::pre_acquired(
            "client-id",
            "https://auth.example.com/token",
            "pre-acquired-access-token",
        )
        .expect("Should create builder")
        .build()
        .expect("Should build config");

        let token = config.get_valid_token().await.expect("Should get token");
        assert_eq!(token.access_token(), "pre-acquired-access-token");
    }

    #[tokio::test]
    async fn should_fail_when_no_pre_acquired_token() {
        let config =
            OAuth2Config::pre_acquired("client-id", "https://auth.example.com/token", "token")
                .expect("Should create builder")
                .build()
                .expect("Should build config");

        // Empty the cache so the pre-acquired grant has nothing to hand back.
        config.token_cache.clear().await;

        match config.get_pre_acquired_token().await {
            Err(OAuth2Error::TokenExpired) => {}
            other => panic!("Expected TokenExpired error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn should_return_cached_token_without_network_call() {
        let config = OAuth2Config::client_credentials(
            "test-client",
            "test-secret",
            "https://auth.example.com/token",
        )
        .expect("Should create builder")
        .build()
        .expect("Should build config");

        let token = OAuth2Token::with_expiry("cached-valid-token", Duration::from_secs(3600));
        config.set_token(token).await;

        // The token URL points at a host that is not mocked, so success here means the
        // cached token was reused instead of making a request.
        let token = config
            .get_valid_token()
            .await
            .expect("Should return cached token");
        assert_eq!(token.access_token(), "cached-valid-token");
    }

    #[test]
    fn should_convert_token_with_expiry() {
        let access_token = oauth2::AccessToken::new("test-token".to_string());
        let expires_in = Some(Duration::from_secs(3600));

        let token = OAuth2Config::convert_token_response(&access_token, expires_in);

        assert_eq!(token.access_token(), "test-token");
        assert!(token.time_until_expiry().is_some());
    }

    #[test]
    fn should_convert_token_without_expiry() {
        let access_token = oauth2::AccessToken::new("no-expiry".to_string());
        let expires_in = None;

        let token = OAuth2Config::convert_token_response(&access_token, expires_in);

        assert_eq!(token.access_token(), "no-expiry");
        assert!(token.time_until_expiry().is_none());
    }

    #[tokio::test]
    async fn should_configure_scopes() {
        let config = OAuth2Config::client_credentials(
            "test-client",
            "test-secret",
            "https://auth.example.com/token",
        )
        .expect("Should create builder")
        .add_scope("read:users")
        .add_scope("write:users")
        .build()
        .expect("Should build config");

        assert_eq!(config.scopes, vec!["read:users", "write:users"]);
    }

    mod mock_server_tests {
        use super::*;
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        #[tokio::test]
        async fn should_acquire_client_credentials_token() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "test-access-token-12345",
                    "token_type": "Bearer",
                    "expires_in": 3600
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .build()
                .expect("Should build config");

            let token = config
                .acquire_token()
                .await
                .expect("Should acquire token successfully");

            assert_eq!(token.access_token(), "test-access-token-12345");
            assert!(token.time_until_expiry().is_some());
        }

        #[tokio::test]
        async fn should_include_scopes_in_token_request() {
            let mock_server = MockServer::start().await;

            // The mock only matches when the URL-encoded scope is present in the form body.
            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .and(wiremock::matchers::body_string_contains(
                    "scope=read%3Ausers",
                ))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "scoped-token",
                    "token_type": "Bearer",
                    "expires_in": 3600
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .add_scope("read:users")
                .build()
                .expect("Should build config");

            let token = config
                .acquire_token()
                .await
                .expect("Should acquire token with scopes");

            assert_eq!(token.access_token(), "scoped-token");
        }

        #[tokio::test]
        async fn should_handle_token_request_failure() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                    "error": "invalid_client",
                    "error_description": "Client authentication failed"
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config =
                OAuth2Config::client_credentials("invalid-client", "wrong-secret", &token_url)
                    .expect("Should create builder")
                    .build()
                    .expect("Should build config");

            match config.acquire_token().await {
                Err(OAuth2Error::TokenAcquisitionFailed { reason }) => {
                    assert!(
                        reason.contains("invalid_client") || reason.contains("Client"),
                        "Error should contain client error info: {reason}"
                    );
                }
                other => panic!("Expected TokenAcquisitionFailed, got {other:?}"),
            }
        }

        #[tokio::test]
        async fn should_handle_invalid_token_url() {
            let result =
                OAuth2Config::client_credentials("test-client", "test-secret", "not-a-valid-url");

            assert!(result.is_err());
        }

        #[tokio::test]
        async fn should_acquire_token_with_multiple_scopes() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .and(wiremock::matchers::body_string_contains(
                    "scope=read%3Ausers",
                ))
                .and(wiremock::matchers::body_string_contains("write%3Ausers"))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "multi-scope-token",
                    "token_type": "Bearer",
                    "expires_in": 3600
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .add_scope("read:users")
                .add_scope("write:users")
                .build()
                .expect("Should build config");

            let token = config
                .acquire_token()
                .await
                .expect("Should acquire token with multiple scopes");

            assert_eq!(token.access_token(), "multi-scope-token");
        }

        #[tokio::test]
        async fn should_handle_token_without_expiry() {
            let mock_server = MockServer::start().await;

            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "no-expiry-token",
                    "token_type": "Bearer"
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .build()
                .expect("Should build config");

            let token = config
                .acquire_token()
                .await
                .expect("Should acquire token without expiry");

            assert_eq!(token.access_token(), "no-expiry-token");
            assert!(
                token.time_until_expiry().is_none(),
                "Token without expires_in should have no expiry"
            );
        }

        #[tokio::test]
        async fn should_cache_token_after_acquisition() {
            let mock_server = MockServer::start().await;

            // `.expect(1)` makes wiremock fail the test on drop if the token endpoint is hit
            // more than once, proving the second call is served from the cache.
            Mock::given(method("POST"))
                .and(path("/oauth/token"))
                .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "access_token": "cached-token-value",
                    "token_type": "Bearer",
                    "expires_in": 3600
                })))
                .expect(1)
                .mount(&mock_server)
                .await;

            let token_url = format!("{}/oauth/token", mock_server.uri());
            let config = OAuth2Config::client_credentials("test-client", "test-secret", &token_url)
                .expect("Should create builder")
                .build()
                .expect("Should build config");

            let token1 = config
                .acquire_token()
                .await
                .expect("First token acquisition should succeed");

            let token2 = config
                .get_valid_token()
                .await
                .expect("Second call should use cached token");

            assert_eq!(token1.access_token(), "cached-token-value");
            assert_eq!(token2.access_token(), "cached-token-value");
        }
    }
}