jotta_fs/
auth.rs

//! Authentication and authorization for Jottacloud itself and white-label providers.
//!
//! ```
//! use jotta_fs::auth::{DefaultTokenStore, provider::Tele2};
//!
//! let store = DefaultTokenStore::<Tele2>::new("refresh_token", "session_id");
//! ```
8use std::{
9    fmt::Debug,
10    marker::PhantomData,
11    sync::{Arc, RwLock},
12};
13
14use async_trait::async_trait;
15use chrono::{DateTime, Duration, Utc};
16
17use reqwest::{header, Client};
18use serde::{Deserialize, Serialize};
19use serde_with::serde_as;
20use tracing::{instrument, trace};
21
22use crate::Error;
23
/// Compile-time description of an authentication backend: Jottacloud itself or one of
/// its white-label partners.
///
/// The associated constants tell a token store which host to ask for tokens and which
/// cookie carries the session identifier. The built-in implementors in the `provider`
/// module are unit structs.
pub trait Provider: Debug + Send + Sync {
    /// Host name of the provider's web frontend, for example `jottacloud.com`.
    const DOMAIN: &'static str;

    /// Name of the cookie holding the session identifier, for example
    /// `jottacloud.session`.
    const SESSION_COOKIE_NAME: &'static str;
}
32
33/// A thread-safe caching token store.
34#[derive(Debug)]
35pub struct DefaultTokenStore<P> {
36    refresh_token: String,
37    session_id: String,
38    access_token: Arc<RwLock<Option<AccessToken>>>,
39    provider: PhantomData<P>,
40}
41
42impl<P> DefaultTokenStore<P> {
43    /// Construct a new [`DefaultTokenStore`].
44    #[must_use]
45    pub fn new(refresh_token: impl Into<String>, session_id: impl Into<String>) -> Self {
46        Self {
47            refresh_token: refresh_token.into(),
48            session_id: session_id.into(),
49            access_token: Arc::new(RwLock::new(None)),
50            provider: PhantomData::default(),
51        }
52    }
53}
54
55#[async_trait]
56impl<P: Provider> TokenStore for DefaultTokenStore<P> {
57    async fn get_refresh_token(&self, _client: &Client) -> crate::Result<String> {
58        Ok(self.refresh_token.clone())
59    }
60
61    #[instrument(level = "trace", skip_all)]
62    async fn get_access_token(&self, client: &Client) -> crate::Result<AccessToken> {
63        {
64            let lock = self.access_token.read().unwrap();
65
66            if let Some(ref access_token) = *lock {
67                if access_token.exp() >= Utc::now() + Duration::minutes(5) {
68                    trace!("found fresh cached access token");
69                    return Ok(access_token.clone());
70                }
71            }
72        }
73
74        trace!("renewing access token");
75
76        let res = client
77            .get(format!("https://{}/web/token", P::DOMAIN))
78            .header(
79                header::COOKIE,
80                format!(
81                    "refresh_token={}; {}={}",
82                    self.get_refresh_token(client).await?,
83                    P::SESSION_COOKIE_NAME,
84                    self.session_id,
85                ),
86            )
87            .send()
88            .await?;
89
90        let cookie = res
91            .cookies()
92            .find(|c| c.name() == "access_token")
93            .ok_or(Error::TokenRenewalFailed)?;
94
95        let access_token = AccessToken::new(cookie.value().into());
96
97        *self.access_token.write().unwrap() = Some(access_token.clone());
98
99        Ok(access_token)
100    }
101}
102
103/// A [`TokenStore`] manages authentication tokens.
104#[async_trait]
105pub trait TokenStore: Debug + Send + Sync {
106    /// Get the cached refresh token or renew it.
107    async fn get_refresh_token(&self, client: &Client) -> crate::Result<String>;
108
109    /// Get the cached access token or renew it if it needs to be renewed.
110    async fn get_access_token(&self, client: &Client) -> crate::Result<AccessToken>;
111}
112
113/// Auth providers.
114pub mod provider {
115    use super::Provider;
116
117    macro_rules! provider {
118        ($name:ident, $domain:literal, $cookie_name:literal) => {
119            /// Authentication provider with domain
120            #[doc=$domain]
121            #[derive(Debug, Clone)]
122            pub struct $name;
123
124            impl Provider for $name {
125                const DOMAIN: &'static str = $domain;
126
127                const SESSION_COOKIE_NAME: &'static str = $cookie_name;
128            }
129        };
130    }
131
132    provider!(Jottacloud, "jottacloud.com", "jottacloud.session");
133    provider!(Tele2, "mittcloud.tele2.se", "tele2.se.session");
134}
135
136/// JWT claims for the [`AccessToken`].
137#[serde_as]
138#[derive(Debug, Deserialize)]
139pub struct AccessTokenClaims {
140    /// Username associated with this access token.
141    pub username: String,
142    #[serde_as(as = "serde_with::TimestampSeconds<i64>")]
143    /// Expiration date of the token.
144    pub exp: DateTime<Utc>,
145}
146
147/// An access token used to authenticate with all Jottacloud services.
148#[derive(Debug, Clone, Serialize)]
149pub struct AccessToken(String);
150
151impl AccessToken {
152    /// Construct a new access token.
153    #[must_use]
154    pub fn new(value: String) -> Self {
155        Self(value)
156    }
157
158    /// Parse claims.
159    ///
160    /// # Panics
161    ///
162    /// Panics if the access token isn't a JWT or is missing some or all [`AccessTokenClaims`].
163    #[must_use]
164    pub fn claims(&self) -> AccessTokenClaims {
165        let mut segments = self.0.split('.');
166        let _header = segments.next();
167        let payload = segments.next().expect("malformed token");
168        let json = base64::decode_config(payload, base64::URL_SAFE_NO_PAD).expect("invalid base64");
169        let json = String::from_utf8(json).expect("invalid utf-8");
170        let claims: AccessTokenClaims = serde_json::from_str(&json).expect("parse claims failed");
171
172        claims
173    }
174
175    /// Get the associated username.
176    #[must_use]
177    pub fn username(&self) -> String {
178        self.claims().username
179    }
180
181    /// Expiration time.
182    #[must_use]
183    pub fn exp(&self) -> DateTime<Utc> {
184        self.claims().exp
185    }
186}
187
188impl std::fmt::Display for AccessToken {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        write!(f, "{}", self.0)
191    }
192}