actix_oidc_token/
lib.rs

1use actix_web::client::Client;
2
3use actix_web_httpauth::headers::authorization::Bearer;
4
5use serde::{Deserialize, Serialize};
6
7use tokio::sync::RwLock;
8
9use jonases_tracing_util::log_simple_err_callback;
10use jonases_tracing_util::tracing::{event, Level};
11
12use std::sync::Arc;
13use std::time::Duration;
14
15pub mod error {
16  use actix_web::client::{JsonPayloadError, SendRequestError};
17
18  #[derive(Debug)]
19  pub enum Error {
20    SendRequestError(SendRequestError),
21    JsonPayloadError(JsonPayloadError),
22  }
23
24  impl From<SendRequestError> for Error {
25    fn from(e: SendRequestError) -> Self {
26      Self::SendRequestError(e)
27    }
28  }
29
30  impl From<JsonPayloadError> for Error {
31    fn from(e: JsonPayloadError) -> Self {
32      Self::JsonPayloadError(e)
33    }
34  }
35}
36
37#[derive(Clone)]
38pub struct AccessToken {
39  inner: Arc<RwLock<InnerAccessToken>>,
40}
41
42impl AccessToken {
43  pub fn new(endpoint: String, token_request: TokenRequest) -> Self {
44    let inner = InnerAccessToken::new(endpoint, token_request);
45
46    let access_token = AccessToken {
47      inner: Arc::new(RwLock::new(inner)),
48    };
49
50    access_token.periodically_refresh()
51  }
52
53  fn periodically_refresh(self) -> Self {
54    let client = Client::builder().disable_timeout().finish();
55
56    let res = self.clone();
57
58    actix_web::rt::spawn(async move {
59      self.refresh_token(&client).await;
60
61      loop {
62        actix_web::rt::time::delay_for({
63          let expires_in = match self.inner.read().await.expires_in()
64          {
65            Some(expires_in) => expires_in as f64,
66            None => 60.,
67          };
68
69          Duration::from_secs_f64(expires_in * 0.9_f64)
70        })
71        .await;
72
73        self.refresh_token(&client).await;
74      }
75    });
76
77    res
78  }
79
80  pub async fn refresh_token(&self, client: &Client) {
81    self.log_token_request(
82      self.inner.write().await.get_token(client).await,
83    );
84  }
85
86  pub async fn bearer(&self) -> Option<Bearer> {
87    self.inner.read().await.bearer()
88  }
89
90  pub async fn token_response(&self) -> Option<TokenResponse> {
91    self.inner.read().await.token_response()
92  }
93
94  fn log_token_request(
95    &self,
96    token_request_result: Result<(), error::Error>,
97  ) {
98    if let Err(e) = token_request_result {
99      event!(
100        Level::ERROR, msg = "could not refresh token", error = ?e
101      );
102    }
103  }
104}
105
106struct InnerAccessToken {
107  token_response: Option<TokenResponse>,
108  endpoint: String,
109  token_request: TokenRequest,
110}
111
112impl InnerAccessToken {
113  fn new(
114    endpoint: String,
115    token_request: TokenRequest,
116  ) -> InnerAccessToken {
117    InnerAccessToken {
118      token_response: None,
119      endpoint,
120      token_request,
121    }
122  }
123
124  async fn get_token(
125    &mut self,
126    client: &Client,
127  ) -> Result<(), error::Error> {
128    self.token_response = Some(
129      client
130        .post(&self.endpoint)
131        .send_form(&self.token_request)
132        .await?
133        .json()
134        .await?,
135    );
136
137    Ok(())
138  }
139
140  fn expires_in(&self) -> Option<i64> {
141    let token_response = self.token_response.as_ref()?;
142    Some(token_response.expires_in)
143  }
144
145  fn access_token(&self) -> Option<String> {
146    let token_response = self.token_response.as_ref()?;
147    Some(token_response.access_token.clone())
148  }
149
150  fn token_response(&self) -> Option<TokenResponse> {
151    self.token_response.clone()
152  }
153
154  fn bearer(&self) -> Option<Bearer> {
155    Some(Bearer::new(self.access_token()?))
156  }
157}
158
159#[derive(Serialize, Deserialize, Debug, Clone)]
160#[serde(tag = "grant_type")]
161#[serde(rename_all = "snake_case")]
162pub enum TokenRequest {
163  ClientCredentials {
164    client_id: String,
165    client_secret: String,
166  },
167  Password {
168    username: String,
169    password: String,
170    client_id: Option<String>,
171  },
172  RefreshToken {
173    refresh_token: String,
174    client_id: Option<String>,
175  },
176}
177
178impl TokenRequest {
179  pub fn client_credentials(
180    client_id: String,
181    client_secret: String,
182  ) -> Self {
183    Self::ClientCredentials {
184      client_id: client_id,
185      client_secret: client_secret,
186    }
187  }
188
189  pub fn password(username: String, password: String) -> Self {
190    Self::Password {
191      username: username,
192      password: password,
193      client_id: None,
194    }
195  }
196
197  pub fn password_with_client_id(
198    username: String,
199    password: String,
200    client_id: String,
201  ) -> Self {
202    Self::Password {
203      username,
204      password,
205      client_id: Some(client_id),
206    }
207  }
208
209  pub fn refresh_token(refresh_token: String) -> Self {
210    Self::RefreshToken {
211      refresh_token,
212      client_id: None,
213    }
214  }
215
216  pub fn refresh_token_with_client_id(
217    refresh_token: String,
218    client_id: String,
219  ) -> Self {
220    Self::RefreshToken {
221      refresh_token,
222      client_id: Some(client_id),
223    }
224  }
225
226  pub fn add_client_id(self, client_id: String) -> Self {
227    match self {
228      Self::Password {
229        username,
230        password,
231        client_id: _,
232      } => {
233        Self::password_with_client_id(username, password, client_id)
234      }
235      Self::RefreshToken {
236        refresh_token,
237        client_id: _,
238      } => {
239        Self::refresh_token_with_client_id(refresh_token, client_id)
240      }
241      other => other,
242    }
243  }
244
245  pub async fn send(
246    &self,
247    url: &str,
248  ) -> Result<TokenResponse, Error> {
249    let client = Client::builder().disable_timeout().finish();
250    self.send_with_client(url, &client).await
251  }
252
253  pub async fn send_with_client(
254    &self,
255    url: &str,
256    client: &Client,
257  ) -> Result<TokenResponse, Error> {
258    let mut response =
259      client.post(url).send_form(&self).await.map_err(
260        log_simple_err_callback("error during connection"),
261      )?;
262
263    let body = response
264      .body()
265      .await
266      .map_err(log_simple_err_callback("error retrieving payload"))?;
267
268    if response.status().is_success() {
269      Ok(serde_json::from_slice(&*body).map_err(
270        log_simple_err_callback(
271          "could not parse response to TokenResponse",
272        ),
273      )?)
274    } else {
275      event!(
276        Level::ERROR,
277        body = %String::from_utf8_lossy(&*body),
278        status = %response.status(),
279      );
280
281      Err(Error::StatusCode(response.status().as_u16()))
282    }
283  }
284}
285
286#[derive(Deserialize, Clone, Debug, PartialEq)]
287pub struct TokenResponse {
288  pub access_token: String,
289  pub expires_in: i64,
290  pub refresh_token: Option<String>,
291  pub refresh_expires_in: Option<i64>,
292}
293
294impl TokenResponse {
295  pub fn bearer(&self) -> Bearer {
296    Bearer::new(self.access_token.clone())
297  }
298}
299
/// Errors returned by [`TokenRequest::send`] and
/// [`TokenRequest::send_with_client`].
#[derive(Debug)]
pub enum Error {
  /// The response body could not be deserialized into a [`TokenResponse`].
  ParseError,
  /// The request could not be sent to the token endpoint.
  SendRequestError,
  /// The response payload could not be read.
  PayloadError,
  /// The token endpoint answered with this non-success HTTP status code.
  StatusCode(u16),
}

impl std::fmt::Display for Error {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      Error::ParseError => f.write_str("could not parse token response"),
      Error::SendRequestError => f.write_str("could not send token request"),
      Error::PayloadError => {
        f.write_str("could not read token response payload")
      }
      Error::StatusCode(code) => {
        write!(f, "token endpoint responded with status code {}", code)
      }
    }
  }
}

// No `source()`: the variants do not retain the underlying error.
impl std::error::Error for Error {}
307
308impl From<serde_json::Error> for Error {
309  fn from(_: serde_json::Error) -> Self {
310    Error::ParseError
311  }
312}
313
314impl From<actix_web::client::PayloadError> for Error {
315  fn from(_: actix_web::client::PayloadError) -> Self {
316    Error::PayloadError
317  }
318}
319
320impl From<actix_web::client::SendRequestError> for Error {
321  fn from(_: actix_web::client::SendRequestError) -> Self {
322    Error::SendRequestError
323  }
324}
325
#[cfg(test)]
mod tests {
  use super::TokenRequest;

  use serde_urlencoded::to_string;

  #[test]
  fn serializing_client_credentials_token_request_to_url_encoded() {
    let token_request = TokenRequest::client_credentials(
      String::from("some id"),
      String::from("some secret"),
    );

    assert_eq!(
      to_string(token_request).unwrap(),
      concat!(
        "grant_type=client_credentials",
        "&client_id=some+id&client_secret=some+secret",
      )
    );
  }

  #[test]
  fn serializing_password_token_request_to_url_encoded() {
    let token_request = TokenRequest::password(
      String::from("some name"),
      String::from("some password"),
    );

    assert_eq!(
      to_string(token_request).unwrap(),
      concat!(
        "grant_type=password",
        "&username=some+name&password=some+password",
      )
    );
  }

  #[test]
  fn serializing_password_token_request_with_id_to_url_encoded() {
    let token_request = TokenRequest::password_with_client_id(
      String::from("some name"),
      String::from("some password"),
      String::from("some id"),
    );

    assert_eq!(
      to_string(token_request).unwrap(),
      concat!(
        "grant_type=password&username=some+name",
        "&password=some+password&client_id=some+id",
      )
    );
  }

  #[test]
  fn serializing_refresh_token_request_to_url_encoded() {
    let token_request =
      TokenRequest::refresh_token(String::from("token"));

    assert_eq!(
      to_string(token_request).unwrap(),
      "grant_type=refresh_token&refresh_token=token".to_owned(),
    );
  }

  #[test]
  fn serializing_refresh_token_request_with_id_to_url_encoded() {
    let token_request = TokenRequest::refresh_token_with_client_id(
      String::from("token"),
      String::from("some id"),
    );

    assert_eq!(
      to_string(token_request).unwrap(),
      concat!(
        "grant_type=refresh_token&refresh_token=token",
        "&client_id=some+id",
      )
    );
  }

  #[test]
  fn add_client_id_sets_client_id_on_password_request() {
    let token_request = TokenRequest::password(
      String::from("some name"),
      String::from("some password"),
    )
    .add_client_id(String::from("some id"));

    assert_eq!(
      to_string(token_request).unwrap(),
      concat!(
        "grant_type=password&username=some+name",
        "&password=some+password&client_id=some+id",
      )
    );
  }

  #[test]
  fn add_client_id_replaces_client_id_on_refresh_token_request() {
    let token_request = TokenRequest::refresh_token_with_client_id(
      String::from("token"),
      String::from("old id"),
    )
    .add_client_id(String::from("some id"));

    assert_eq!(
      to_string(token_request).unwrap(),
      concat!(
        "grant_type=refresh_token&refresh_token=token",
        "&client_id=some+id",
      )
    );
  }

  #[test]
  fn add_client_id_leaves_client_credentials_request_unchanged() {
    let token_request = TokenRequest::client_credentials(
      String::from("some id"),
      String::from("some secret"),
    )
    .add_client_id(String::from("other id"));

    assert_eq!(
      to_string(token_request).unwrap(),
      concat!(
        "grant_type=client_credentials",
        "&client_id=some+id&client_secret=some+secret",
      )
    );
  }
}