drogue_client/openid/provider/
openid.rs

1use crate::openid::Expires;
2use core::fmt::{self, Debug, Formatter};
3use std::{ops::Deref, sync::Arc};
4use tokio::sync::RwLock;
5
6/// A provider which provides access tokens for clients.
7#[derive(Clone)]
8pub struct OpenIdTokenProvider {
9    pub client: Arc<openid::Client>,
10    current_token: Arc<RwLock<Option<openid::Bearer>>>,
11    refresh_before: chrono::Duration,
12}
13
14impl Debug for OpenIdTokenProvider {
15    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
16        f.debug_struct("TokenProvider")
17            .field(
18                "client",
19                &format!("{} / {:?}", self.client.client_id, self.client.http_client),
20            )
21            .field("current_token", &"...")
22            .finish()
23    }
24}
25
26impl OpenIdTokenProvider {
27    /// Create a new provider using the provided client.
28    pub fn new(client: openid::Client, refresh_before: chrono::Duration) -> Self {
29        Self {
30            client: Arc::new(client),
31            current_token: Arc::new(RwLock::new(None)),
32            refresh_before,
33        }
34    }
35
36    /// return a fresh token, this may be an existing (non-expired) token
37    /// a newly refreshed token.
38    pub async fn provide_token(&self) -> Result<openid::Bearer, openid::error::Error> {
39        match self.current_token.read().await.deref() {
40            Some(token) if !token.expires_before(self.refresh_before) => {
41                log::debug!("Token still valid");
42                return Ok(token.clone());
43            }
44            _ => {}
45        }
46
47        // fetch fresh token after releasing the read lock
48
49        self.fetch_fresh_token().await
50    }
51
52    async fn fetch_fresh_token(&self) -> Result<openid::Bearer, openid::error::Error> {
53        log::debug!("Fetching fresh token...");
54
55        let mut lock = self.current_token.write().await;
56
57        match lock.deref() {
58            // check if someone else refreshed the token in the meantime
59            Some(token) if !token.expires_before(self.refresh_before) => {
60                log::debug!("Token already got refreshed");
61                return Ok(token.clone());
62            }
63            _ => {}
64        }
65
66        // we hold the write-lock now, and can perform the refresh operation
67
68        let next_token = match lock.take() {
69            // if we don't have any token, fetch an initial one
70            None => {
71                log::debug!("Fetching initial token... ");
72                self.initial_token().await?
73            }
74            // if we have an expired one, refresh it
75            Some(current_token) => {
76                log::debug!("Refreshing token ... ");
77                match current_token.refresh_token.is_some() {
78                    true => self.client.refresh_token(current_token, None).await?,
79                    false => self.initial_token().await?,
80                }
81            }
82        };
83
84        log::debug!("Next token: {:?}", next_token);
85
86        lock.replace(next_token.clone());
87
88        // done
89
90        Ok(next_token)
91    }
92
93    async fn initial_token(&self) -> Result<openid::Bearer, openid::error::Error> {
94        Ok(self.client.request_token_using_client_credentials().await?)
95    }
96}
97
98#[cfg(all(feature = "reqwest", not(target_arch = "wasm32")))]
99use crate::{
100    error::ClientError,
101    openid::{provider::TokenProvider, Credentials},
102};
103
#[cfg(all(feature = "reqwest", not(target_arch = "wasm32")))]
#[async_trait::async_trait]
impl TokenProvider for OpenIdTokenProvider {
    /// Expose the current OpenID access token as bearer credentials.
    ///
    /// This always yields `Some` on success. Failures from the OpenID flow are
    /// boxed into `ClientError::Token`.
    async fn provide_access_token(&self) -> Result<Option<Credentials>, ClientError> {
        match self.provide_token().await {
            Ok(bearer) => Ok(Some(Credentials::Bearer(bearer.access_token))),
            Err(cause) => Err(ClientError::Token(Box::new(cause))),
        }
    }
}