1use std::{
9 fmt::Debug,
10 marker::PhantomData,
11 sync::{Arc, RwLock},
12};
13
14use async_trait::async_trait;
15use chrono::{DateTime, Duration, Utc};
16
17use reqwest::{header, Client};
18use serde::{Deserialize, Serialize};
19use serde_with::serde_as;
20use tracing::{instrument, trace};
21
22use crate::Error;
23
/// A storage host whose token-renewal endpoint and session cookie are fixed at
/// compile time.
///
/// Implementors are zero-sized marker types (see the `provider` module) that
/// parameterise [`DefaultTokenStore`], which builds its renewal request from
/// these two constants.
pub trait Provider: Debug + Sync + Send {
    /// Host name, without scheme, used to build `https://{DOMAIN}/web/token`.
    const DOMAIN: &'static str;

    /// Name of the cookie that carries the session id on renewal requests.
    const SESSION_COOKIE_NAME: &'static str;
}
32
33#[derive(Debug)]
35pub struct DefaultTokenStore<P> {
36 refresh_token: String,
37 session_id: String,
38 access_token: Arc<RwLock<Option<AccessToken>>>,
39 provider: PhantomData<P>,
40}
41
42impl<P> DefaultTokenStore<P> {
43 #[must_use]
45 pub fn new(refresh_token: impl Into<String>, session_id: impl Into<String>) -> Self {
46 Self {
47 refresh_token: refresh_token.into(),
48 session_id: session_id.into(),
49 access_token: Arc::new(RwLock::new(None)),
50 provider: PhantomData::default(),
51 }
52 }
53}
54
55#[async_trait]
56impl<P: Provider> TokenStore for DefaultTokenStore<P> {
57 async fn get_refresh_token(&self, _client: &Client) -> crate::Result<String> {
58 Ok(self.refresh_token.clone())
59 }
60
61 #[instrument(level = "trace", skip_all)]
62 async fn get_access_token(&self, client: &Client) -> crate::Result<AccessToken> {
63 {
64 let lock = self.access_token.read().unwrap();
65
66 if let Some(ref access_token) = *lock {
67 if access_token.exp() >= Utc::now() + Duration::minutes(5) {
68 trace!("found fresh cached access token");
69 return Ok(access_token.clone());
70 }
71 }
72 }
73
74 trace!("renewing access token");
75
76 let res = client
77 .get(format!("https://{}/web/token", P::DOMAIN))
78 .header(
79 header::COOKIE,
80 format!(
81 "refresh_token={}; {}={}",
82 self.get_refresh_token(client).await?,
83 P::SESSION_COOKIE_NAME,
84 self.session_id,
85 ),
86 )
87 .send()
88 .await?;
89
90 let cookie = res
91 .cookies()
92 .find(|c| c.name() == "access_token")
93 .ok_or(Error::TokenRenewalFailed)?;
94
95 let access_token = AccessToken::new(cookie.value().into());
96
97 *self.access_token.write().unwrap() = Some(access_token.clone());
98
99 Ok(access_token)
100 }
101}
102
103#[async_trait]
105pub trait TokenStore: Debug + Send + Sync {
106 async fn get_refresh_token(&self, client: &Client) -> crate::Result<String>;
108
109 async fn get_access_token(&self, client: &Client) -> crate::Result<AccessToken>;
111}
112
113pub mod provider {
115 use super::Provider;
116
117 macro_rules! provider {
118 ($name:ident, $domain:literal, $cookie_name:literal) => {
119 #[doc=$domain]
121 #[derive(Debug, Clone)]
122 pub struct $name;
123
124 impl Provider for $name {
125 const DOMAIN: &'static str = $domain;
126
127 const SESSION_COOKIE_NAME: &'static str = $cookie_name;
128 }
129 };
130 }
131
132 provider!(Jottacloud, "jottacloud.com", "jottacloud.session");
133 provider!(Tele2, "mittcloud.tele2.se", "tele2.se.session");
134}
135
136#[serde_as]
138#[derive(Debug, Deserialize)]
139pub struct AccessTokenClaims {
140 pub username: String,
142 #[serde_as(as = "serde_with::TimestampSeconds<i64>")]
143 pub exp: DateTime<Utc>,
145}
146
147#[derive(Debug, Clone, Serialize)]
149pub struct AccessToken(String);
150
151impl AccessToken {
152 #[must_use]
154 pub fn new(value: String) -> Self {
155 Self(value)
156 }
157
158 #[must_use]
164 pub fn claims(&self) -> AccessTokenClaims {
165 let mut segments = self.0.split('.');
166 let _header = segments.next();
167 let payload = segments.next().expect("malformed token");
168 let json = base64::decode_config(payload, base64::URL_SAFE_NO_PAD).expect("invalid base64");
169 let json = String::from_utf8(json).expect("invalid utf-8");
170 let claims: AccessTokenClaims = serde_json::from_str(&json).expect("parse claims failed");
171
172 claims
173 }
174
175 #[must_use]
177 pub fn username(&self) -> String {
178 self.claims().username
179 }
180
181 #[must_use]
183 pub fn exp(&self) -> DateTime<Utc> {
184 self.claims().exp
185 }
186}
187
188impl std::fmt::Display for AccessToken {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 write!(f, "{}", self.0)
191 }
192}