1use crate::{ApicizeError, Certificate, Identifiable, Proxy};
3use oauth2::basic::BasicClient;
4use oauth2::{reqwest, AuthType};
5use oauth2::{ClientId, ClientSecret, Scope, TokenResponse, TokenUrl};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::ops::Add;
9use std::sync::LazyLock;
10use std::time::{SystemTime, UNIX_EPOCH};
11use tokio::sync::Mutex;
12
13pub static TOKEN_CACHE: LazyLock<Mutex<HashMap<String, CachedTokenInfo>>> =
14 LazyLock::new(|| Mutex::new(HashMap::new()));
15
/// A token entry stored in [`TOKEN_CACHE`].
///
/// Serialized with camelCase field names.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
#[serde(rename_all = "camelCase")]
pub struct CachedTokenInfo {
    /// The access token value.
    pub access_token: String,
    /// Refresh token, if any. `get_oauth2_client_credentials` always stores `None` here.
    pub refresh_token: Option<String>,
    /// Expiration time; `get_oauth2_client_credentials` stores whole seconds since the
    /// UNIX epoch (fetch time + `expires_in`). Entries with `None` are never served from
    /// the cache by `get_oauth2_client_credentials`.
    pub expiration: Option<u64>,
}
27
/// Result returned by `get_oauth2_client_credentials`.
///
/// Serialized with camelCase field names; `None` optional fields are omitted.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
#[serde(rename_all = "camelCase")]
pub struct TokenResult {
    /// The access token value.
    pub token: String,
    /// `true` when the token was served from the cache rather than fetched.
    pub cached: bool,
    /// Token endpoint URL; only set when the token was freshly fetched.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
    /// Name of the client certificate used; only set when freshly fetched with one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub certificate: Option<String>,
    /// Name of the proxy used; only set when freshly fetched through one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub proxy: Option<String>,
}
46
/// Token data from a PKCE authorization flow.
///
/// NOTE(review): not used in this file; presumably populated by PKCE handling
/// elsewhere — confirm. Serialized with camelCase field names.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PkceTokenResult {
    /// The access token value.
    pub access_token: String,
    /// Refresh token, if one was issued.
    pub refresh_token: Option<String>,
    /// Expiration; presumably UNIX seconds like `CachedTokenInfo::expiration` — TODO confirm.
    pub expiration: Option<u64>,
}
58
59#[allow(clippy::too_many_arguments)]
61pub async fn get_oauth2_client_credentials<'a>(
62 id: &str,
63 token_url: &str,
64 client_id: &str,
65 client_secret: &str,
66 send_credentials_in_body: bool,
67 scopes: &'a Option<String>,
68 audience: &'a Option<String>,
69 certificate: Option<&'a Certificate>,
70 proxy: Option<&'a Proxy>,
71 enable_trace: bool,
72) -> Result<TokenResult, ApicizeError> {
73 let mut locked_cache = TOKEN_CACHE.lock().await;
75 let valid_token = match locked_cache.get(id) {
76 Some(cached_token) => match cached_token.expiration {
77 Some(expiration) => {
78 let now = SystemTime::now()
79 .duration_since(UNIX_EPOCH)
80 .unwrap()
81 .as_secs();
82 if expiration.gt(&now) {
83 Some(cached_token.clone())
84 } else {
85 None
86 }
87 }
88 None => None,
89 },
90 None => None,
91 };
92
93 if let Some(cached_token) = valid_token {
94 return Ok(TokenResult {
95 token: cached_token.access_token,
96 cached: true,
97 url: None,
98 certificate: None,
99 proxy: None,
100 });
101 }
102
103 let mut client = BasicClient::new(ClientId::new(String::from(client_id)))
105 .set_token_uri(
106 TokenUrl::new(String::from(token_url)).expect("Unable to parse OAuth token URL"),
107 )
108 .set_auth_type(if send_credentials_in_body {
109 AuthType::RequestBody
110 } else {
111 AuthType::BasicAuth
112 });
113
114 if !client_secret.trim().is_empty() {
115 client = client.set_client_secret(ClientSecret::new(String::from(client_secret)));
116 }
117
118 let mut token_request = client.exchange_client_credentials();
119
120 if let Some(scope_value) = &scopes {
121 if !scope_value.is_empty() {
122 token_request = token_request.add_scope(Scope::new(scope_value.clone()));
123 }
124 }
125
126 if let Some(audience_value) = &audience {
127 if !audience_value.is_empty() {
128 token_request = token_request.add_extra_param("audience", audience_value);
129 }
130 }
131
132 let mut reqwest_builder = reqwest::ClientBuilder::new()
133 .connection_verbose(enable_trace)
134 .redirect(reqwest::redirect::Policy::none());
135
136 if let Some(active_cert) = certificate {
138 match active_cert.append_to_builder(reqwest_builder) {
139 Ok(updated_builder) => reqwest_builder = updated_builder,
140 Err(err) => {
141 return Err(ApicizeError::OAuth2Client {
142 description: String::from("Error assigning OAuth certificate"),
143 source: Some(Box::new(err)),
144 })
145 }
146 }
147 }
148
149 if let Some(active_proxy) = proxy {
151 match active_proxy.append_to_builder(reqwest_builder) {
152 Ok(updated_builder) => reqwest_builder = updated_builder,
153 Err(err) => {
154 return Err(ApicizeError::OAuth2Client {
155 description: String::from("Error assigning OAuth proxy"),
156 source: Some(Box::new(ApicizeError::from_reqwest(err))),
157 })
158 }
159 }
160 }
161
162 let http_client = match reqwest_builder.build() {
163 Ok(client) => client,
164 Err(err) => {
165 return Err(ApicizeError::OAuth2Client {
166 description: String::from("Error building OAuth request"),
167 source: Some(Box::new(ApicizeError::from_reqwest(err))),
168 })
169 }
170 };
171
172 match token_request.request_async(&http_client).await {
173 Ok(token_response) => {
174 let expiration = token_response.expires_in().map(|token_expires_in|
175 SystemTime::now()
176 .duration_since(UNIX_EPOCH)
177 .unwrap()
178 .as_secs()
179 .add(token_expires_in.as_secs())
180 );
181 let token = token_response.access_token().secret().clone();
182 locked_cache.insert(
183 String::from(id),
184 CachedTokenInfo {
185 access_token: token.clone(),
186 refresh_token: None,
187 expiration,
188 },
189 );
190 Ok(TokenResult {
191 token,
192 cached: false,
193 url: Some(String::from(token_url)),
194 certificate: certificate.map(|c| c.get_name().to_owned()),
195 proxy: proxy.map(|p| p.get_name().to_owned()),
196 })
197 }
198 Err(err) => Err(ApicizeError::OAuth2Client {
199 description: String::from("Error dispatching OAuth2 token request"),
200 source: Some(Box::new(ApicizeError::from_oauth2(err))),
201 }),
202 }
203}
204
205pub async fn store_oauth2_token(authorization_id: &str, token_info: CachedTokenInfo) {
207 let locked_cache = &mut TOKEN_CACHE.lock().await;
208 locked_cache.insert(authorization_id.to_owned(), token_info);
209}
210
211pub async fn clear_all_oauth2_tokens<'a>() -> usize {
213 let locked_cache = &mut TOKEN_CACHE.lock().await;
214 let count = locked_cache.len();
215 locked_cache.clear();
216 count
217}
218
219pub async fn clear_oauth2_token(id: &str) -> bool {
221 let mut locked_cache = TOKEN_CACHE.lock().await;
222 locked_cache.remove(&String::from(id)).is_some()
223}
224
#[cfg(test)]
pub mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use mockall::automock;
    use serial_test::{parallel, serial};

    use crate::oauth2_client_tokens::{
        clear_all_oauth2_tokens, clear_oauth2_token, get_oauth2_client_credentials,
        CachedTokenInfo, TokenResult, TOKEN_CACHE,
    };

    pub struct OAuth2ClientTokens;
    #[automock]
    impl OAuth2ClientTokens {
        pub async fn get_oauth2_client_credentials<'a>(
            _id: &str,
            _token_url: &str,
            _client_id: &str,
            _client_secret: &str,
            _send_credentials_in_body: bool,
            _scope: &'a Option<String>,
            _audience: &'a Option<String>,
            _certificate: Option<&'a crate::Certificate>,
            _proxy: Option<&'a crate::Proxy>,
            _enable_trace: bool,
        ) -> Result<TokenResult, crate::ApicizeError> {
            Ok(TokenResult {
                token: String::from(""),
                cached: false,
                url: None,
                certificate: None,
                proxy: None,
            })
        }
        pub async fn clear_all_oauth2_tokens<'a>() -> usize {
            1
        }
        pub async fn clear_oauth2_token(_id: &str) -> bool {
            true
        }
    }

    const FAKE_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

    /// Whole seconds since the UNIX epoch, used to build cache expirations.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Clears the global cache and seeds it with a single token (`"123"`) under `id`
    /// expiring at `expiration` (UNIX seconds).
    async fn seed_cache(id: &str, expiration: u64) {
        let mut locked_cache = TOKEN_CACHE.lock().await;
        locked_cache.clear();
        let cached_token = CachedTokenInfo {
            expiration: Some(expiration),
            access_token: String::from("123"),
            refresh_token: None,
        };
        locked_cache.insert(String::from(id), cached_token.clone());
        assert_eq!(locked_cache.get(id), Some(&cached_token));
    }

    /// Registers a `POST /` mock on `server` that answers with `FAKE_TOKEN`
    /// (expires in 86400 seconds).
    fn mock_token_endpoint(server: &mut mockito::Server) -> mockito::Mock {
        let oauth2_response = format!(
            "{{\"access_token\":\"{}\",\"expires_in\":86400,\"token_type\":\"Bearer\"}}",
            FAKE_TOKEN
        );
        server
            .mock("POST", "/")
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(oauth2_response)
            .create()
    }

    /// Requests a token for `id` from `token_url` using fixed test credentials.
    async fn fetch_token(id: &str, token_url: &str) -> TokenResult {
        get_oauth2_client_credentials(
            id, token_url, "me", "shhh", false, &None, &None, None, None, false,
        )
        .await
        .unwrap()
    }

    /// Expected result of a fetch that hit the mock token server at `url`.
    fn fresh_result(url: String) -> TokenResult {
        TokenResult {
            token: String::from(FAKE_TOKEN),
            cached: false,
            url: Some(url),
            certificate: None,
            proxy: None,
        }
    }

    /// Expected result of a fetch served from the cache.
    fn cached_result(token: &str) -> TokenResult {
        TokenResult {
            token: String::from(token),
            cached: true,
            url: None,
            certificate: None,
            proxy: None,
        }
    }

    /// Returns the access token currently cached under `id`, if any.
    async fn cached_access_token(id: &str) -> Option<String> {
        TOKEN_CACHE
            .lock()
            .await
            .get(id)
            .map(|token| token.access_token.clone())
    }

    /// Fetches a token for `id` twice and checks that only the first call reaches
    /// the server while the second is served from the cache.
    async fn assert_fetches_once_then_caches(id: &str) {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_token_endpoint(&mut server);

        assert_eq!(
            fetch_token(id, &server.url()).await,
            fresh_result(server.url())
        );
        mock.assert();

        assert_eq!(
            fetch_token(id, &server.url()).await,
            cached_result(FAKE_TOKEN)
        );
        // A mock expects exactly one hit by default, so this fails if the second
        // lookup bypassed the cache.
        mock.assert();
    }

    #[tokio::test()]
    #[serial]
    async fn get_oauth2_client_credentials_returns_cached_token() {
        seed_cache("abc", unix_now_secs() + 10).await;
        assert_eq!(
            fetch_token("abc", "http://server").await,
            cached_result("123")
        );
    }

    #[tokio::test]
    #[serial]
    async fn get_oauth2_client_credentials_calls_server() {
        TOKEN_CACHE.lock().await.clear();
        let mut server = mockito::Server::new_async().await;
        let mock = mock_token_endpoint(&mut server);

        let result = fetch_token("abc", &server.url()).await;

        mock.assert();
        assert_eq!(result, fresh_result(server.url()));
        assert_eq!(
            cached_access_token("abc").await,
            Some(String::from(FAKE_TOKEN))
        );
    }

    #[tokio::test]
    #[serial]
    async fn get_oauth2_client_credentials_ignores_expired_cache() {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_token_endpoint(&mut server);
        seed_cache("abc", unix_now_secs() - 10).await;

        let result = fetch_token("abc", &server.url()).await;

        mock.assert();
        assert_eq!(result, fresh_result(server.url()));
        // The expired entry must have been replaced by the freshly fetched token.
        assert_eq!(
            cached_access_token("abc").await,
            Some(String::from(FAKE_TOKEN))
        );
    }

    #[tokio::test]
    #[serial]
    async fn clear_all_oauth2_tokens_clears_tokens() {
        seed_cache("abc", unix_now_secs() + 10).await;
        assert_eq!(clear_all_oauth2_tokens().await, 1);
        assert!(TOKEN_CACHE.lock().await.is_empty());
    }

    #[tokio::test]
    #[serial]
    async fn clear_oauth2_token_removes_item() {
        seed_cache("abc", unix_now_secs() + 10).await;
        assert!(clear_oauth2_token("abc").await);
        assert_eq!(cached_access_token("abc").await, None);
    }

    #[tokio::test]
    #[serial]
    async fn clear_oauth2_token_ignores_invalid_id() {
        assert!(!clear_oauth2_token("abc_bogus").await);
    }

    #[tokio::test()]
    #[parallel]
    async fn get_oauth2_client_credentials_parallel_1() {
        assert_fetches_once_then_caches("abc1").await;
    }

    #[tokio::test()]
    #[parallel]
    async fn get_oauth2_client_credentials_parallel_2() {
        assert_fetches_once_then_caches("abc2").await;
    }

    #[tokio::test()]
    #[parallel]
    async fn get_oauth2_client_credentials_parallel_3() {
        assert_fetches_once_then_caches("abc3").await;
    }
}