1use actix_web::client::Client;
2
3use actix_web_httpauth::headers::authorization::Bearer;
4
5use serde::{Deserialize, Serialize};
6
7use tokio::sync::RwLock;
8
9use jonases_tracing_util::log_simple_err_callback;
10use jonases_tracing_util::tracing::{event, Level};
11
12use std::sync::Arc;
13use std::time::Duration;
14
15pub mod error {
16 use actix_web::client::{JsonPayloadError, SendRequestError};
17
18 #[derive(Debug)]
19 pub enum Error {
20 SendRequestError(SendRequestError),
21 JsonPayloadError(JsonPayloadError),
22 }
23
24 impl From<SendRequestError> for Error {
25 fn from(e: SendRequestError) -> Self {
26 Self::SendRequestError(e)
27 }
28 }
29
30 impl From<JsonPayloadError> for Error {
31 fn from(e: JsonPayloadError) -> Self {
32 Self::JsonPayloadError(e)
33 }
34 }
35}
36
37#[derive(Clone)]
38pub struct AccessToken {
39 inner: Arc<RwLock<InnerAccessToken>>,
40}
41
42impl AccessToken {
43 pub fn new(endpoint: String, token_request: TokenRequest) -> Self {
44 let inner = InnerAccessToken::new(endpoint, token_request);
45
46 let access_token = AccessToken {
47 inner: Arc::new(RwLock::new(inner)),
48 };
49
50 access_token.periodically_refresh()
51 }
52
53 fn periodically_refresh(self) -> Self {
54 let client = Client::builder().disable_timeout().finish();
55
56 let res = self.clone();
57
58 actix_web::rt::spawn(async move {
59 self.refresh_token(&client).await;
60
61 loop {
62 actix_web::rt::time::delay_for({
63 let expires_in = match self.inner.read().await.expires_in()
64 {
65 Some(expires_in) => expires_in as f64,
66 None => 60.,
67 };
68
69 Duration::from_secs_f64(expires_in * 0.9_f64)
70 })
71 .await;
72
73 self.refresh_token(&client).await;
74 }
75 });
76
77 res
78 }
79
80 pub async fn refresh_token(&self, client: &Client) {
81 self.log_token_request(
82 self.inner.write().await.get_token(client).await,
83 );
84 }
85
86 pub async fn bearer(&self) -> Option<Bearer> {
87 self.inner.read().await.bearer()
88 }
89
90 pub async fn token_response(&self) -> Option<TokenResponse> {
91 self.inner.read().await.token_response()
92 }
93
94 fn log_token_request(
95 &self,
96 token_request_result: Result<(), error::Error>,
97 ) {
98 if let Err(e) = token_request_result {
99 event!(
100 Level::ERROR, msg = "could not refresh token", error = ?e
101 );
102 }
103 }
104}
105
106struct InnerAccessToken {
107 token_response: Option<TokenResponse>,
108 endpoint: String,
109 token_request: TokenRequest,
110}
111
112impl InnerAccessToken {
113 fn new(
114 endpoint: String,
115 token_request: TokenRequest,
116 ) -> InnerAccessToken {
117 InnerAccessToken {
118 token_response: None,
119 endpoint,
120 token_request,
121 }
122 }
123
124 async fn get_token(
125 &mut self,
126 client: &Client,
127 ) -> Result<(), error::Error> {
128 self.token_response = Some(
129 client
130 .post(&self.endpoint)
131 .send_form(&self.token_request)
132 .await?
133 .json()
134 .await?,
135 );
136
137 Ok(())
138 }
139
140 fn expires_in(&self) -> Option<i64> {
141 let token_response = self.token_response.as_ref()?;
142 Some(token_response.expires_in)
143 }
144
145 fn access_token(&self) -> Option<String> {
146 let token_response = self.token_response.as_ref()?;
147 Some(token_response.access_token.clone())
148 }
149
150 fn token_response(&self) -> Option<TokenResponse> {
151 self.token_response.clone()
152 }
153
154 fn bearer(&self) -> Option<Bearer> {
155 Some(Bearer::new(self.access_token()?))
156 }
157}
158
159#[derive(Serialize, Deserialize, Debug, Clone)]
160#[serde(tag = "grant_type")]
161#[serde(rename_all = "snake_case")]
162pub enum TokenRequest {
163 ClientCredentials {
164 client_id: String,
165 client_secret: String,
166 },
167 Password {
168 username: String,
169 password: String,
170 client_id: Option<String>,
171 },
172 RefreshToken {
173 refresh_token: String,
174 client_id: Option<String>,
175 },
176}
177
178impl TokenRequest {
179 pub fn client_credentials(
180 client_id: String,
181 client_secret: String,
182 ) -> Self {
183 Self::ClientCredentials {
184 client_id: client_id,
185 client_secret: client_secret,
186 }
187 }
188
189 pub fn password(username: String, password: String) -> Self {
190 Self::Password {
191 username: username,
192 password: password,
193 client_id: None,
194 }
195 }
196
197 pub fn password_with_client_id(
198 username: String,
199 password: String,
200 client_id: String,
201 ) -> Self {
202 Self::Password {
203 username,
204 password,
205 client_id: Some(client_id),
206 }
207 }
208
209 pub fn refresh_token(refresh_token: String) -> Self {
210 Self::RefreshToken {
211 refresh_token,
212 client_id: None,
213 }
214 }
215
216 pub fn refresh_token_with_client_id(
217 refresh_token: String,
218 client_id: String,
219 ) -> Self {
220 Self::RefreshToken {
221 refresh_token,
222 client_id: Some(client_id),
223 }
224 }
225
226 pub fn add_client_id(self, client_id: String) -> Self {
227 match self {
228 Self::Password {
229 username,
230 password,
231 client_id: _,
232 } => {
233 Self::password_with_client_id(username, password, client_id)
234 }
235 Self::RefreshToken {
236 refresh_token,
237 client_id: _,
238 } => {
239 Self::refresh_token_with_client_id(refresh_token, client_id)
240 }
241 other => other,
242 }
243 }
244
245 pub async fn send(
246 &self,
247 url: &str,
248 ) -> Result<TokenResponse, Error> {
249 let client = Client::builder().disable_timeout().finish();
250 self.send_with_client(url, &client).await
251 }
252
253 pub async fn send_with_client(
254 &self,
255 url: &str,
256 client: &Client,
257 ) -> Result<TokenResponse, Error> {
258 let mut response =
259 client.post(url).send_form(&self).await.map_err(
260 log_simple_err_callback("error during connection"),
261 )?;
262
263 let body = response
264 .body()
265 .await
266 .map_err(log_simple_err_callback("error retrieving payload"))?;
267
268 if response.status().is_success() {
269 Ok(serde_json::from_slice(&*body).map_err(
270 log_simple_err_callback(
271 "could not parse response to TokenResponse",
272 ),
273 )?)
274 } else {
275 event!(
276 Level::ERROR,
277 body = %String::from_utf8_lossy(&*body),
278 status = %response.status(),
279 );
280
281 Err(Error::StatusCode(response.status().as_u16()))
282 }
283 }
284}
285
286#[derive(Deserialize, Clone, Debug, PartialEq)]
287pub struct TokenResponse {
288 pub access_token: String,
289 pub expires_in: i64,
290 pub refresh_token: Option<String>,
291 pub refresh_expires_in: Option<i64>,
292}
293
294impl TokenResponse {
295 pub fn bearer(&self) -> Bearer {
296 Bearer::new(self.access_token.clone())
297 }
298}
299
/// Errors returned by `TokenRequest::send` and `TokenRequest::send_with_client`.
#[derive(Debug)]
pub enum Error {
    /// A successful response body was not a valid `TokenResponse` JSON document.
    ParseError,
    /// The request could not be sent to the token endpoint.
    SendRequestError,
    /// The response body could not be read.
    PayloadError,
    /// The token endpoint answered with this non-success HTTP status code.
    StatusCode(u16),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::ParseError => f.write_str("could not parse token response"),
            Error::SendRequestError => {
                f.write_str("could not send token request")
            }
            Error::PayloadError => {
                f.write_str("could not read token response payload")
            }
            Error::StatusCode(code) => {
                write!(f, "token endpoint returned status code {}", code)
            }
        }
    }
}

// Allows `?` into `Box<dyn std::error::Error>` and use with error-reporting
// crates.
impl std::error::Error for Error {}
307
308impl From<serde_json::Error> for Error {
309 fn from(_: serde_json::Error) -> Self {
310 Error::ParseError
311 }
312}
313
314impl From<actix_web::client::PayloadError> for Error {
315 fn from(_: actix_web::client::PayloadError) -> Self {
316 Error::PayloadError
317 }
318}
319
320impl From<actix_web::client::SendRequestError> for Error {
321 fn from(_: actix_web::client::SendRequestError) -> Self {
322 Error::SendRequestError
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::{TokenRequest, TokenResponse};

    use serde_urlencoded::to_string;

    #[test]
    fn serializing_client_credentials_token_request_to_url_encoded() {
        let token_request = TokenRequest::client_credentials(
            String::from("some id"),
            String::from("some secret"),
        );

        assert_eq!(
            to_string(token_request).unwrap(),
            concat!(
                "grant_type=client_credentials",
                "&client_id=some+id&client_secret=some+secret",
            )
        );
    }

    #[test]
    fn serializing_password_token_request_to_url_encoded() {
        let token_request = TokenRequest::password(
            String::from("some name"),
            String::from("some password"),
        );

        assert_eq!(
            to_string(token_request).unwrap(),
            concat!(
                "grant_type=password",
                "&username=some+name&password=some+password",
            )
        );
    }

    #[test]
    fn serializing_password_token_request_with_id_to_url_encoded() {
        let token_request = TokenRequest::password_with_client_id(
            String::from("some name"),
            String::from("some password"),
            String::from("some id"),
        );

        assert_eq!(
            to_string(token_request).unwrap(),
            concat!(
                "grant_type=password&username=some+name",
                "&password=some+password&client_id=some+id",
            )
        );
    }

    #[test]
    fn serializing_refresh_token_request_to_url_encoded() {
        let token_request =
            TokenRequest::refresh_token(String::from("token"));

        assert_eq!(
            to_string(token_request).unwrap(),
            "grant_type=refresh_token&refresh_token=token".to_owned(),
        );
    }

    #[test]
    fn serializing_refresh_token_request_with_id_to_url_encoded() {
        let token_request = TokenRequest::refresh_token_with_client_id(
            String::from("token"),
            String::from("some id"),
        );

        assert_eq!(
            to_string(token_request).unwrap(),
            concat!(
                "grant_type=refresh_token&refresh_token=token",
                "&client_id=some+id",
            )
        );
    }

    #[test]
    fn add_client_id_sets_client_id_on_password_request() {
        let token_request = TokenRequest::password(
            String::from("some name"),
            String::from("some password"),
        )
        .add_client_id(String::from("some id"));

        assert_eq!(
            to_string(token_request).unwrap(),
            concat!(
                "grant_type=password&username=some+name",
                "&password=some+password&client_id=some+id",
            )
        );
    }

    #[test]
    fn add_client_id_replaces_client_id_on_refresh_token_request() {
        let token_request = TokenRequest::refresh_token_with_client_id(
            String::from("token"),
            String::from("old id"),
        )
        .add_client_id(String::from("new id"));

        assert_eq!(
            to_string(token_request).unwrap(),
            concat!(
                "grant_type=refresh_token&refresh_token=token",
                "&client_id=new+id",
            )
        );
    }

    #[test]
    fn add_client_id_leaves_client_credentials_request_unchanged() {
        let token_request = TokenRequest::client_credentials(
            String::from("some id"),
            String::from("some secret"),
        )
        .add_client_id(String::from("other id"));

        assert_eq!(
            to_string(token_request).unwrap(),
            concat!(
                "grant_type=client_credentials",
                "&client_id=some+id&client_secret=some+secret",
            )
        );
    }

    #[test]
    fn deserializing_token_response_without_optional_fields() {
        // Unknown fields such as `token_type` must be ignored and missing
        // optional fields must default to `None`.
        let token_response: TokenResponse = serde_json::from_str(
            r#"{"access_token":"abc","expires_in":300,"token_type":"Bearer"}"#,
        )
        .unwrap();

        assert_eq!(
            token_response,
            TokenResponse {
                access_token: String::from("abc"),
                expires_in: 300,
                refresh_token: None,
                refresh_expires_in: None,
            }
        );
    }
}