apicize_lib/execution/
oauth2_client_tokens.rs

1//! This module implements OAuth2 client flow support, including support for caching tokens
2use crate::{ApicizeError, Certificate, Identifiable, Proxy};
3use oauth2::basic::BasicClient;
4use oauth2::{reqwest, AuthType};
5use oauth2::{ClientId, ClientSecret, Scope, TokenResponse, TokenUrl};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::ops::Add;
9use std::sync::LazyLock;
10use std::time::{SystemTime, UNIX_EPOCH};
11use tokio::sync::Mutex;
12
13pub static TOKEN_CACHE: LazyLock<Mutex<HashMap<String, CachedTokenInfo>>> =
14    LazyLock::new(|| Mutex::new(HashMap::new()));
15
16/// Cached token
17#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
18#[serde(rename_all = "camelCase")]
19pub struct CachedTokenInfo {
20    /// Access token
21    pub access_token: String,
22    /// Refresh token
23    pub refresh_token: Option<String>,
24    /// Expiration of token in seconds past Unix epoch
25    pub expiration: Option<u64>,
26}
27
28/// OAuth2 issued client token result
29#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
30#[serde(rename_all = "camelCase")]
31pub struct TokenResult {
32    /// Issued token
33    pub token: String,
34    /// Set to True if token was retrieved via cache
35    pub cached: bool,
36    /// URL used to retrieve token
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub url: Option<String>,
39    /// Name of the certificate parameter, if any
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub certificate: Option<String>,
42    /// Name of the proxy parameter, if any
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub proxy: Option<String>,
45}
46
47/// OAuth2 issued PKCE token result
48#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
49#[serde(rename_all = "camelCase")]
50pub struct PkceTokenResult {
51    /// Access token
52    pub access_token: String,
53    /// Refresh token
54    pub refresh_token: Option<String>,
55    /// Expiration of token in seconds past Unix epoch
56    pub expiration: Option<u64>,
57}
58
59/// Return cached oauth2 token, with indicator of whether value was retrieved from cache
60#[allow(clippy::too_many_arguments)]
61pub async fn get_oauth2_client_credentials<'a>(
62    id: &str,
63    token_url: &str,
64    client_id: &str,
65    client_secret: &str,
66    send_credentials_in_body: bool,
67    scopes: &'a Option<String>,
68    audience: &'a Option<String>,
69    certificate: Option<&'a Certificate>,
70    proxy: Option<&'a Proxy>,
71    enable_trace: bool,
72) -> Result<TokenResult, ApicizeError> {
73    // Check cache and return if token found and not expired
74    let mut locked_cache = TOKEN_CACHE.lock().await;
75    let valid_token = match locked_cache.get(id) {
76        Some(cached_token) => match cached_token.expiration {
77            Some(expiration) => {
78                let now = SystemTime::now()
79                    .duration_since(UNIX_EPOCH)
80                    .unwrap()
81                    .as_secs();
82                if expiration.gt(&now) {
83                    Some(cached_token.clone())
84                } else {
85                    None
86                }
87            }
88            None => None,
89        },
90        None => None,
91    };
92
93    if let Some(cached_token) = valid_token {
94        return Ok(TokenResult {
95            token: cached_token.access_token,
96            cached: true,
97            url: None,
98            certificate: None,
99            proxy: None,
100        });
101    }
102
103    // Retrieve an access token
104    let mut client = BasicClient::new(ClientId::new(String::from(client_id)))
105        .set_token_uri(
106            TokenUrl::new(String::from(token_url)).expect("Unable to parse OAuth token URL"),
107        )
108        .set_auth_type(if send_credentials_in_body {
109            AuthType::RequestBody
110        } else {
111            AuthType::BasicAuth
112        });
113
114    if !client_secret.trim().is_empty() {
115        client = client.set_client_secret(ClientSecret::new(String::from(client_secret)));
116    }
117
118    let mut token_request = client.exchange_client_credentials();
119
120    if let Some(scope_value) = &scopes {
121        if !scope_value.is_empty() {
122            token_request = token_request.add_scope(Scope::new(scope_value.clone()));
123        }
124    }
125
126    if let Some(audience_value) = &audience {
127        if !audience_value.is_empty() {
128            token_request = token_request.add_extra_param("audience", audience_value);
129        }
130    }
131
132    let mut reqwest_builder = reqwest::ClientBuilder::new()
133        .connection_verbose(enable_trace)
134        .redirect(reqwest::redirect::Policy::none());
135
136    // Add certificate to builder if configured
137    if let Some(active_cert) = certificate {
138        match active_cert.append_to_builder(reqwest_builder) {
139            Ok(updated_builder) => reqwest_builder = updated_builder,
140            Err(err) => {
141                return Err(ApicizeError::OAuth2Client {
142                    description: String::from("Error assigning OAuth certificate"),
143                    source: Some(Box::new(err)),
144                })
145            }
146        }
147    }
148
149    // Add proxy to builder if configured
150    if let Some(active_proxy) = proxy {
151        match active_proxy.append_to_builder(reqwest_builder) {
152            Ok(updated_builder) => reqwest_builder = updated_builder,
153            Err(err) => {
154                return Err(ApicizeError::OAuth2Client {
155                    description: String::from("Error assigning OAuth proxy"),
156                    source: Some(Box::new(ApicizeError::from_reqwest(err))),
157                })
158            }
159        }
160    }
161
162    let http_client = match reqwest_builder.build() {
163        Ok(client) => client,
164        Err(err) => {
165            return Err(ApicizeError::OAuth2Client {
166                description: String::from("Error building OAuth request"),
167                source: Some(Box::new(ApicizeError::from_reqwest(err))),
168            })
169        }
170    };
171
172    match token_request.request_async(&http_client).await {
173        Ok(token_response) => {
174            let expiration = token_response.expires_in().map(|token_expires_in|
175                SystemTime::now()
176                    .duration_since(UNIX_EPOCH)
177                    .unwrap()
178                    .as_secs()
179                    .add(token_expires_in.as_secs())
180            );
181            let token = token_response.access_token().secret().clone();
182            locked_cache.insert(
183                String::from(id),
184                CachedTokenInfo {
185                    access_token: token.clone(),
186                    refresh_token: None,
187                    expiration,
188                },
189            );
190            Ok(TokenResult {
191                token,
192                cached: false,
193                url: Some(String::from(token_url)),
194                certificate: certificate.map(|c| c.get_name().to_owned()),
195                proxy: proxy.map(|p| p.get_name().to_owned()),
196            })
197        }
198        Err(err) => Err(ApicizeError::OAuth2Client {
199            description: String::from("Error dispatching OAuth2 token request"),
200            source: Some(Box::new(ApicizeError::from_oauth2(err))),
201        }),
202    }
203}
204
205/// Store OAuth2 token in cache
206pub async fn store_oauth2_token(authorization_id: &str, token_info: CachedTokenInfo) {
207    let locked_cache = &mut TOKEN_CACHE.lock().await;
208    locked_cache.insert(authorization_id.to_owned(), token_info);
209}
210
211/// Clear all cached OAuth2 tokens
212pub async fn clear_all_oauth2_tokens<'a>() -> usize {
213    let locked_cache = &mut TOKEN_CACHE.lock().await;
214    let count = locked_cache.len();
215    locked_cache.clear();
216    count
217}
218
219/// Clear specified cached OAuth2 credentials, returning true if value was cached
220pub async fn clear_oauth2_token(id: &str) -> bool {
221    let mut locked_cache = TOKEN_CACHE.lock().await;
222    locked_cache.remove(&String::from(id)).is_some()
223}
224
225#[cfg(test)]
226pub mod tests {
227    use std::ops::{Add, Sub};
228    use std::time::{SystemTime, UNIX_EPOCH};
229
230    use mockall::automock;
231    use serial_test::{parallel, serial};
232
233    use crate::oauth2_client_tokens::{
234        clear_all_oauth2_tokens, clear_oauth2_token, get_oauth2_client_credentials,
235        CachedTokenInfo, TokenResult, TOKEN_CACHE,
236    };
237
238    pub struct OAuth2ClientTokens;
239    #[automock]
240    impl OAuth2ClientTokens {
241        pub async fn get_oauth2_client_credentials<'a>(
242            _id: &str,
243            _token_url: &str,
244            _client_id: &str,
245            _client_secret: &str,
246            _send_credentials_in_body: bool,
247            _scope: &'a Option<String>,
248            _audience: &'a Option<String>,
249            _certificate: Option<&'a crate::Certificate>,
250            _proxy: Option<&'a crate::Proxy>,
251            _enable_trace: bool,
252        ) -> Result<TokenResult, crate::ApicizeError> {
253            Ok(TokenResult {
254                token: String::from(""),
255                cached: false,
256                url: None,
257                certificate: None,
258                proxy: None,
259            })
260        }
261        pub async fn clear_all_oauth2_tokens<'a>() -> usize {
262            1
263        }
264        pub async fn clear_oauth2_token(_id: &str) -> bool {
265            true
266        }
267    }
268
    // Note: cached tokens live in shared storage, so tests that clear or seed the whole cache
    // cannot run in parallel and carry the "serial" attribute. Other tests are deliberately
    // marked "parallel" to check that the module itself is thread-safe.

    /// Access token value that the mock token endpoints in these tests hand back
    const FAKE_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
273
274    #[tokio::test()]
275    #[serial]
276    async fn get_oauth2_client_credentials_returns_cached_token() {
277        {
278            let mut locked_cache = TOKEN_CACHE.lock().await;
279            locked_cache.clear();
280            let expiration = Some(
281                SystemTime::now()
282                    .duration_since(UNIX_EPOCH)
283                    .unwrap()
284                    .as_secs()
285                    .add(10),
286            );
287            locked_cache.insert(
288                String::from("abc"),
289                CachedTokenInfo {
290                    expiration,
291                    access_token: String::from("123"),
292                    refresh_token: None,
293                },
294            );
295        }
296        assert_eq!(
297            (get_oauth2_client_credentials(
298                "abc",
299                "http://server",
300                "me",
301                "shhh",
302                false,
303                &None,
304                &None,
305                None,
306                None,
307                false,
308            )
309            .await)
310                .unwrap(),
311            TokenResult {
312                token: String::from("123"),
313                cached: true,
314                url: None,
315                certificate: None,
316                proxy: None
317            }
318        );
319    }
320
321    #[tokio::test]
322    #[serial]
323    async fn get_oauth2_client_credentials_calls_server() {
324        {
325            let mut locked_cache = TOKEN_CACHE.lock().await;
326            locked_cache.clear();
327        }
328        let mut server = mockito::Server::new_async().await;
329        let oauth2_response = format!(
330            "{{\"access_token\":\"{}\",\"expires_in\":86400,\"token_type\":\"Bearer\"}}",
331            FAKE_TOKEN
332        );
333        let mock = server
334            .mock("POST", "/")
335            // .match_body("foo")
336            .with_status(200)
337            .with_header("Content-Type", "application/json")
338            .with_body(oauth2_response)
339            .create();
340
341        let result = get_oauth2_client_credentials(
342            "abc",
343            server.url().as_str(),
344            "me",
345            "shhh",
346            false,
347            &None,
348            &None,
349            None,
350            None,
351            false,
352        )
353        .await;
354
355        mock.assert();
356
357        assert_eq!(
358            result.unwrap(),
359            TokenResult {
360                token: String::from(FAKE_TOKEN),
361                cached: false,
362                url: Some(server.url()),
363                certificate: None,
364                proxy: None
365            }
366        );
367
368        {
369            let locked_cache = TOKEN_CACHE.lock().await;
370            assert!(locked_cache.get(&String::from("abc")).is_some());
371        }
372    }
373
374    #[tokio::test]
375    #[serial]
376    async fn get_oauth2_client_credentials_ignores_expired_cache() {
377        let mut server = mockito::Server::new_async().await;
378        let oauth2_response = format!(
379            "{{\"access_token\":\"{}\",\"expires_in\":86400,\"token_type\":\"Bearer\"}}",
380            FAKE_TOKEN
381        );
382        let mock = server
383            .mock("POST", "/")
384            // .match_body("foo")
385            .with_status(200)
386            .with_header("Content-Type", "application/json")
387            .with_body(oauth2_response)
388            .create();
389
390        {
391            let mut locked_cache = TOKEN_CACHE.lock().await;
392            locked_cache.clear();
393            let expiration = Some(
394                SystemTime::now()
395                    .duration_since(UNIX_EPOCH)
396                    .unwrap()
397                    .as_secs()
398                    .sub(10),
399            );
400            let cached_token = CachedTokenInfo {
401                expiration,
402                access_token: String::from("123"),
403                refresh_token: None,
404            };
405            locked_cache.insert(String::from("abc"), cached_token.clone());
406            assert_eq!(locked_cache.get(&String::from("abc")), Some(&cached_token));
407        }
408
409        let result = get_oauth2_client_credentials(
410            "abc",
411            server.url().as_str(),
412            "me",
413            "shhh",
414            false,
415            &None,
416            &None,
417            None,
418            None,
419            false,
420        )
421        .await;
422
423        mock.assert();
424
425        assert_eq!(
426            result.unwrap(),
427            TokenResult {
428                token: String::from(FAKE_TOKEN),
429                cached: false,
430                url: Some(server.url()),
431                certificate: None,
432                proxy: None
433            }
434        );
435        {
436            let locked_cache = TOKEN_CACHE.lock().await;
437            assert!(locked_cache.get(&String::from("abc")).is_some());
438        }
439    }
440
441    #[tokio::test]
442    #[serial]
443    async fn clear_all_oauth2_tokens_clears_tokens() {
444        {
445            let mut locked_cache = TOKEN_CACHE.lock().await;
446            locked_cache.clear();
447            let expiration = Some(
448                SystemTime::now()
449                    .duration_since(UNIX_EPOCH)
450                    .unwrap()
451                    .as_secs()
452                    .add(10),
453            );
454            let cached_token = CachedTokenInfo {
455                expiration,
456                access_token: String::from("123"),
457                refresh_token: None,
458            };
459            locked_cache.insert(String::from("abc"), cached_token.clone());
460            assert_eq!(locked_cache.get(&String::from("abc")), Some(&cached_token));
461        }
462        assert_eq!(clear_all_oauth2_tokens().await, 1);
463        {
464            let locked_cache = TOKEN_CACHE.lock().await;
465            assert_eq!(locked_cache.len(), 0);
466        }
467    }
468
469    #[tokio::test]
470    #[serial]
471    async fn clear_oauth2_token_removes_item() {
472        {
473            let mut locked_cache = TOKEN_CACHE.lock().await;
474            locked_cache.clear();
475            let expiration = Some(
476                SystemTime::now()
477                    .duration_since(UNIX_EPOCH)
478                    .unwrap()
479                    .as_secs()
480                    .add(10),
481            );
482            let cached_token = CachedTokenInfo {
483                expiration,
484                access_token: String::from("123"),
485                refresh_token: None,
486            };
487            locked_cache.insert(String::from("abc"), cached_token.clone());
488            assert_eq!(locked_cache.get(&String::from("abc")), Some(&cached_token));
489        }
490        assert_eq!(clear_oauth2_token("abc").await, true);
491        {
492            let locked_cache = TOKEN_CACHE.lock().await;
493            assert_eq!(locked_cache.get(&String::from("abc")), None);
494        }
495    }
496
497    #[tokio::test]
498    #[serial]
499    async fn clear_oauth2_token_ignores_invalid_id() {
500        assert_eq!(clear_oauth2_token("abc_bogus").await, false);
501    }
502
503    #[tokio::test()]
504    #[parallel]
505    async fn get_oauth2_client_credentials_parallel_1() {
506        let mut server = mockito::Server::new_async().await;
507        let oauth2_response = format!(
508            "{{\"access_token\":\"{}\",\"expires_in\":86400,\"token_type\":\"Bearer\"}}",
509            FAKE_TOKEN
510        );
511        let mock = server
512            .mock("POST", "/")
513            // .match_body("foo")
514            .with_status(200)
515            .with_header("Content-Type", "application/json")
516            .with_body(oauth2_response)
517            .create();
518        assert_eq!(
519            (get_oauth2_client_credentials(
520                "abc1",
521                &server.url(),
522                "me",
523                "shhh",
524                false,
525                &None,
526                &None,
527                None,
528                None,
529                false,
530            )
531            .await)
532                .unwrap(),
533            TokenResult {
534                token: String::from(FAKE_TOKEN),
535                cached: false,
536                url: Some(server.url()),
537                certificate: None,
538                proxy: None
539            }
540        );
541        mock.assert();
542
543        // Second attempt will use cache
544        assert_eq!(
545            (get_oauth2_client_credentials(
546                "abc1",
547                &server.url(),
548                "me",
549                "shhh",
550                false,
551                &None,
552                &None,
553                None,
554                None,
555                false,
556            )
557            .await)
558                .unwrap(),
559            TokenResult {
560                token: String::from(FAKE_TOKEN),
561                cached: true,
562                url: None,
563                certificate: None,
564                proxy: None
565            }
566        );
567        mock.expect_at_most(0);
568    }
569
570    #[tokio::test()]
571    #[parallel]
572    async fn get_oauth2_client_credentials_parallel_2() {
573        let mut server = mockito::Server::new_async().await;
574        let oauth2_response = format!(
575            "{{\"access_token\":\"{}\",\"expires_in\":86400,\"token_type\":\"Bearer\"}}",
576            FAKE_TOKEN
577        );
578        let mock = server
579            .mock("POST", "/")
580            // .match_body("foo")
581            .with_status(200)
582            .with_header("Content-Type", "application/json")
583            .with_body(oauth2_response)
584            .create();
585        assert_eq!(
586            (get_oauth2_client_credentials(
587                "abc2",
588                &server.url(),
589                "me",
590                "shhh",
591                false,
592                &None,
593                &None,
594                None,
595                None,
596                false,
597            )
598            .await)
599                .unwrap(),
600            TokenResult {
601                token: String::from(FAKE_TOKEN),
602                cached: false,
603                url: Some(server.url()),
604                certificate: None,
605                proxy: None
606            }
607        );
608        mock.assert();
609
610        // Second attempt will use cache
611        assert_eq!(
612            (get_oauth2_client_credentials(
613                "abc2",
614                &server.url(),
615                "me",
616                "shhh",
617                false,
618                &None,
619                &None,
620                None,
621                None,
622                false,
623            )
624            .await)
625                .unwrap(),
626            TokenResult {
627                token: String::from(FAKE_TOKEN),
628                cached: true,
629                url: None,
630                certificate: None,
631                proxy: None
632            }
633        );
634        mock.expect_at_most(0);
635    }
636
637    #[tokio::test()]
638    #[parallel]
639    async fn get_oauth2_client_credentials_parallel_3() {
640        let mut server = mockito::Server::new_async().await;
641        let oauth2_response = format!(
642            "{{\"access_token\":\"{}\",\"expires_in\":86400,\"token_type\":\"Bearer\"}}",
643            FAKE_TOKEN
644        );
645        let mock = server
646            .mock("POST", "/")
647            // .match_body("foo")
648            .with_status(200)
649            .with_header("Content-Type", "application/json")
650            .with_body(oauth2_response)
651            .create();
652        assert_eq!(
653            (get_oauth2_client_credentials(
654                "abc3",
655                &server.url(),
656                "me",
657                "shhh",
658                false,
659                &None,
660                &None,
661                None,
662                None,
663                false,
664            )
665            .await)
666                .unwrap(),
667            TokenResult {
668                token: String::from(FAKE_TOKEN),
669                cached: false,
670                url: Some(server.url()),
671                certificate: None,
672                proxy: None
673            }
674        );
675        mock.assert();
676
677        // Second attempt will use cache
678        assert_eq!(
679            (get_oauth2_client_credentials(
680                "abc3",
681                &server.url(),
682                "me",
683                "shhh",
684                false,
685                &None,
686                &None,
687                None,
688                None,
689                false,
690            )
691            .await)
692                .unwrap(),
693            TokenResult {
694                token: String::from(FAKE_TOKEN),
695                cached: true,
696                url: None,
697                certificate: None,
698                proxy: None
699            }
700        );
701        mock.expect_at_most(0);
702    }
703}