1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, Notify, RwLock};

use super::{settings::AuthServer, ClientConfiguration, TokenError};
#[cfg(feature = "tracing-config")]
use crate::tracing_configuration::TracingConfiguration;
#[cfg(feature = "tracing")]
use urlpattern::UrlPatternMatchInput;

/// A single type containing an access token and an associated refresh token, along with the
/// server that issued them.
///
/// The type is serializable and cloneable; [`TokenDispatcher`] wraps it to provide shared,
/// lock-protected access.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct Tokens {
    /// The `Bearer` token to include in the `Authorization` header.
    pub bearer_access_token: String,
    /// The token used to refresh the access token.
    pub refresh_token: String,
    /// The server that issued the tokens. Its issuer URL and client ID are used when the tokens
    /// are refreshed.
    pub auth_server: AuthServer,
}

/// A wrapper for [`Tokens`] that provides thread-safe access to the inner tokens.
///
/// Every field is behind an [`Arc`], so cloning a dispatcher yields another handle to the same
/// underlying tokens and refresh state.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
pub struct TokenDispatcher {
    /// The current tokens. Readers take the read lock; a refresh holds the write lock.
    lock: Arc<RwLock<Tokens>>,
    /// `true` while a refresh is in flight, so that concurrent refresh requests wait for it
    /// instead of starting another one.
    refreshing: Arc<Mutex<bool>>,
    /// Signals tasks that are waiting on an in-flight refresh once it has finished.
    notify_refreshed: Arc<Notify>,
}

impl From<Tokens> for TokenDispatcher {
    /// Wraps `value` in a fresh dispatcher with no refresh in progress.
    fn from(value: Tokens) -> Self {
        let lock = Arc::new(RwLock::new(value));
        let refreshing = Arc::new(Mutex::new(false));
        let notify_refreshed = Arc::new(Notify::new());

        Self {
            lock,
            refreshing,
            notify_refreshed,
        }
    }
}

impl TokenDispatcher {
    /// Executes a user-provided closure on a reference to the `Tokens` instance managed by the
    /// dispatcher and returns the closure's result.
    ///
    /// A read guard on the inner `RwLock` is held while `f` runs and released when this function
    /// returns. Because a refresh holds the write guard, callers wait here until any in-flight
    /// refresh has completed.
    ///
    /// # Parameters
    /// - `f`: A closure that takes a reference to `Tokens` and returns a value of type `O`. The closure is called
    ///   with the `Tokens` instance as an argument once the read lock has been acquired.
    pub async fn use_tokens<F, O>(&self, f: F) -> O
    where
        F: FnOnce(&Tokens) -> O + Send,
    {
        let tokens = self.lock.read().await;
        f(&tokens)
    }

    /// Get a copy of the current tokens (access token, refresh token and auth server).
    #[must_use]
    pub async fn tokens(&self) -> Tokens {
        self.use_tokens(Clone::clone).await
    }

    /// Refreshes the tokens. Readers will be blocked until the refresh is complete.
    ///
    /// If another refresh is already in flight, this waits for it to finish and returns the
    /// tokens stored at that point instead of issuing a second request.
    ///
    /// # Errors
    ///
    /// See [`TokenError`]
    pub async fn refresh(&self) -> Result<Tokens, TokenError> {
        self.managed_refresh(Self::perform_refresh).await
    }

    /// If tokens are already being refreshed, wait and return the updated tokens. Otherwise, run
    /// ``refresh_fn``.
    ///
    /// Only the task that actually runs ``refresh_fn`` observes its error. Tasks that waited on
    /// an in-flight refresh always receive `Ok` with whatever tokens are stored once that refresh
    /// has finished, even if it failed.
    ///
    /// NOTE(review): the `refreshing` flag is only reset after ``refresh_fn`` completes. If this
    /// future is dropped while ``refresh_fn`` is pending, the flag stays `true` and later callers
    /// will wait for a notification that never arrives.
    async fn managed_refresh<F, Fut>(&self, refresh_fn: F) -> Result<Tokens, TokenError>
    where
        F: FnOnce(Arc<RwLock<Tokens>>) -> Fut + Send,
        Fut: std::future::Future<Output = Result<Tokens, TokenError>> + Send,
    {
        let mut is_refreshing = self.refreshing.lock().await;

        if *is_refreshing {
            // Create the `Notified` future *before* releasing the flag lock. A `Notified` future
            // receives `notify_waiters()` wakeups from the moment it is created, and the
            // refreshing task has to take this same lock before it notifies. Creating the future
            // only after dropping the guard would leave a window in which the notification could
            // be sent and missed, leaving this task waiting forever.
            let refreshed = self.notify_refreshed.notified();
            drop(is_refreshing);
            refreshed.await;
            return Ok(self.tokens().await);
        }

        *is_refreshing = true;
        // Release the flag lock so other callers can observe the flag and start waiting while
        // the refresh runs.
        drop(is_refreshing);

        let result = refresh_fn(self.lock.clone()).await;

        *self.refreshing.lock().await = false;
        self.notify_refreshed.notify_waiters();

        result
    }

    /// Requests new tokens from the auth server and stores them in place.
    ///
    /// The write lock on `lock` is held for the whole request, so readers will be blocked until
    /// the refresh is complete. The request is a form-encoded `POST` to `{issuer}/v1/token` with
    /// a 10 second timeout. On success the stored access and refresh tokens are replaced and a
    /// copy of the updated [`Tokens`] is returned; on failure the stored tokens are left
    /// untouched.
    ///
    /// # Errors
    ///
    /// Returns a [`TokenError`] if the HTTP client cannot be built, the request fails, the server
    /// responds with an error status, or the response body cannot be parsed as a
    /// [`TokenResponse`].
    async fn perform_refresh(lock: Arc<RwLock<Tokens>>) -> Result<Tokens, TokenError> {
        let mut tokens = lock.write().await;
        let auth_server = &tokens.auth_server;

        let token_url = format!("{}/v1/token", auth_server.issuer());
        let data = TokenRefreshRequest::new(auth_server.client_id(), &tokens.refresh_token);
        let resp = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()?
            .post(token_url)
            .form(&data)
            .send()
            .await?;

        // Turn non-success HTTP statuses into errors before attempting to parse the body.
        let response_data: TokenResponse = resp.error_for_status()?.json().await?;
        tokens.bearer_access_token = response_data.access_token;
        tokens.refresh_token = response_data.refresh_token;
        Ok(tokens.clone())
    }
}

/// The form body sent to the auth server's token endpoint when refreshing tokens.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct TokenRefreshRequest<'a> {
    /// The grant type; [`TokenRefreshRequest::new`] always sets this to `"refresh_token"`.
    grant_type: &'static str,
    /// The client ID of the auth server the tokens were issued by.
    client_id: &'a str,
    /// The refresh token being exchanged for new tokens.
    refresh_token: &'a str,
}

impl<'a> TokenRefreshRequest<'a> {
    /// Builds a refresh request for the given client and refresh token, borrowing both for the
    /// lifetime of the request. The grant type is fixed to `"refresh_token"`.
    pub(super) const fn new(client_id: &'a str, refresh_token: &'a str) -> Self {
        let grant_type = "refresh_token";

        Self {
            grant_type,
            client_id,
            refresh_token,
        }
    }
}

/// The body returned by the auth server's token endpoint after a successful refresh.
#[derive(Deserialize, Debug, Serialize)]
pub(super) struct TokenResponse {
    /// The new refresh token, which replaces the one that was just used.
    pub(super) refresh_token: String,
    /// The new access token.
    pub(super) access_token: String,
}

/// Get and refresh access tokens
///
/// With the `tracing` feature enabled, implementors also expose the base URL for requests and
/// can decide per URL whether a request should be traced.
#[async_trait::async_trait]
pub trait TokenRefresher: Clone + Send {
    /// The type to be returned in the event of an error during getting or
    /// refreshing an access token
    type Error;

    /// Get the current access token
    async fn get_access_token(&self) -> Result<String, Self::Error>;

    /// Get a fresh access token
    async fn refresh_access_token(&self) -> Result<String, Self::Error>;

    /// Get the base URL for requests
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str;

    /// Get the tracing configuration
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration>;

    /// Returns whether the given URL should be traced. Following
    /// [`TracingConfiguration::is_enabled`], this defaults to `true`.
    ///
    /// Without the `tracing-config` feature every URL is traced. With it, the decision is
    /// delegated to the configuration returned by [`Self::tracing_configuration`], and a missing
    /// configuration means `true`.
    #[cfg(feature = "tracing")]
    // The explicit `return` below sits in a block that is the last statement of the function
    // when `tracing-config` is disabled, which is presumably what triggers this lint.
    #[allow(clippy::needless_return)]
    fn should_trace(&self, url: &UrlPatternMatchInput) -> bool {
        #[cfg(not(feature = "tracing-config"))]
        {
            // `url` is only inspected when a tracing configuration is available.
            let _ = url;
            return true;
        }

        #[cfg(feature = "tracing-config")]
        self.tracing_configuration()
            .map_or(true, |config| config.is_enabled(url))
    }
}

#[async_trait::async_trait]
impl TokenRefresher for ClientConfiguration {
    type Error = TokenError;

    /// Delegates to [`ClientConfiguration::get_bearer_access_token`].
    async fn get_access_token(&self) -> Result<String, Self::Error> {
        self.get_bearer_access_token().await
    }

    /// Refreshes the configuration's tokens and returns the new bearer access token.
    async fn refresh_access_token(&self) -> Result<String, Self::Error> {
        let refreshed = self.refresh().await?;
        Ok(refreshed.bearer_access_token)
    }

    /// Uses the configured gRPC API URL as the base URL for requests.
    #[cfg(feature = "tracing")]
    fn base_url(&self) -> &str {
        &self.grpc_api_url
    }

    /// Borrows the tracing configuration, if one is set.
    #[cfg(feature = "tracing-config")]
    fn tracing_configuration(&self) -> Option<&TracingConfiguration> {
        self.tracing_configuration.as_ref()
    }
}

#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use httpmock::prelude::*;
    use tokio::time::Instant;

    /// Verifies that a reader calling [`TokenDispatcher::tokens`] while a refresh is in flight
    /// is held back until the refresh completes and then observes the refreshed tokens.
    ///
    /// NOTE(review): this relies on the refresh task acquiring the write lock before the read
    /// task runs, i.e. on the tasks being polled in spawn order — presumably guaranteed by the
    /// runtime `#[tokio::test]` sets up by default; confirm if the runtime flavor changes.
    #[tokio::test]
    async fn test_tokens_blocked_during_refresh() {
        let mock_server = MockServer::start_async().await;
        let auth_server = AuthServer::new("client_id".to_string(), mock_server.base_url());

        let original_tokens = Tokens {
            bearer_access_token: "access".to_string(),
            refresh_token: "refresh".to_string(),
            auth_server,
        };

        let dispatcher: TokenDispatcher = original_tokens.into();
        let dispatcher_clone1 = dispatcher.clone();
        let dispatcher_clone2 = dispatcher.clone();

        // Single source of truth for both the mocked server delay and the timing assertions.
        let refresh_duration = Duration::from_secs(3);

        let issuer_mock = mock_server
            .mock_async(|when, then| {
                when.method(POST).path("/v1/token");

                then.status(200)
                    .delay(refresh_duration)
                    .json_body_obj(&TokenResponse {
                        access_token: "new_access".to_string(),
                        refresh_token: "new_refresh".to_string(),
                    });
            })
            .await;

        let start_write = Instant::now();
        let write_future = tokio::spawn(async move { dispatcher_clone1.refresh().await.unwrap() });

        let start_read = Instant::now();
        let read_future = tokio::spawn(async move { dispatcher_clone2.tokens().await });

        let _ = write_future.await.unwrap();
        let read_result = read_future.await.unwrap();

        let write_duration = start_write.elapsed();
        let read_duration = start_read.elapsed();

        // The token endpoint must have been hit exactly once.
        issuer_mock.assert_async().await;

        assert!(
            write_duration >= refresh_duration,
            "Write operation did not take enough time"
        );
        assert!(
            read_duration >= refresh_duration,
            "Read operation was not blocked by the write operation"
        );
        assert_eq!(read_result.bearer_access_token, "new_access");
        assert_eq!(read_result.refresh_token, "new_refresh");
    }
}