jotta/auth/
oauth2.rs

1#![allow(clippy::doc_markdown)]
2
3use async_trait::async_trait;
4use jsonwebtoken::{DecodingKey, Validation};
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use time::{Duration, OffsetDateTime};
8use tracing::instrument;
9
10use super::{AccessToken, AccessTokenCache, TokenStore};
11
/// OAuth2 token endpoint URL for Tele2 Cloud (formerly ComHem Cloud).
///
/// Suitable as the `token_url` argument of [`OAuth2::init`].
pub const TELE2_TOKEN_URL: &str =
    "https://mittcloud-auth.tele2.se/auth/realms/comhem/protocol/openid-connect/token";
15
16/// An OAuth2 client.
17#[derive(Debug)]
18pub struct OAuth2 {
19    access_token: AccessTokenCache,
20    refresh_token: String,
21    username: String,
22    token_url: &'static str,
23}
24
25fn extract_username(refresh_token: &str) -> Option<String> {
26    #[derive(Deserialize)]
27    struct Payload {
28        sub: String,
29    }
30
31    let mut validation = Validation::default();
32    validation.insecure_disable_signature_validation();
33    validation.validate_exp = false;
34    let jwt =
35        jsonwebtoken::decode::<Payload>(refresh_token, &DecodingKey::from_secret(&[]), &validation)
36            .ok()?;
37
38    jwt.claims.sub.split(':').last().map(Into::into)
39}
40
41impl OAuth2 {
42    /// Initialize an OAuth2 client.
43    ///
44    /// # Errors
45    ///
46    /// If the username cannot be extracted from the refresh token, this function will
47    /// return an error.
48    pub fn init(token_url: &'static str, refresh_token: impl Into<String>) -> crate::Result<Self> {
49        let refresh_token = refresh_token.into();
50
51        Ok(Self {
52            access_token: AccessTokenCache::default(),
53            username: extract_username(&refresh_token).ok_or(crate::Error::TokenRenewalFailed)?,
54            refresh_token,
55            token_url,
56        })
57    }
58}
59
60#[async_trait]
61impl TokenStore for OAuth2 {
62    #[instrument(skip_all)]
63    async fn get_access_token(&self, client: &Client) -> crate::Result<AccessToken> {
64        #[derive(Serialize)]
65        struct Params<'a> {
66            grant_type: &'static str,
67            refresh_token: &'a str,
68            client_id: &'static str,
69        }
70
71        #[derive(Deserialize)]
72        struct Response {
73            access_token: String,
74            expires_in: i64,
75        }
76
77        if let Some(access_token) = self.access_token.get_fresh().await {
78            return Ok(access_token);
79        }
80
81        let mut w = self.access_token.write().await;
82
83        let res: Response = client
84            .post(self.token_url)
85            .form(&Params {
86                grant_type: "refresh_token",
87                refresh_token: &self.refresh_token,
88                client_id: "desktop",
89            })
90            .send()
91            .await?
92            .json()
93            .await?;
94
95        let access_token = AccessToken::new(
96            res.access_token,
97            OffsetDateTime::now_utc() + Duration::seconds(res.expires_in),
98        );
99
100        *w = Some(access_token.clone());
101        Ok(access_token)
102    }
103
104    fn username(&self) -> &str {
105        &self.username
106    }
107}