use std::{fmt, future::Future, pin::Pin, sync::Mutex, time::Duration};
use chrono::Utc;
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AccessToken, AuthUrl, AuthorizationCode,
ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, RefreshToken, RevocationUrl,
Scope, TokenResponse as _, TokenUrl,
};
use crate::{errors::TokenError, API_URL, RUNTIME};
#[derive(Debug, thiserror::Error)]
pub enum ClientError {
#[error("failed to refresh token")]
Refresh,
#[error("{0:?}")]
Token(#[from] TokenError),
#[error("{0}")]
Reqwest(#[from] reqwest::Error),
}
#[derive(Debug, Clone)]
pub struct Code(pub String);
#[derive(Debug, Clone)]
pub struct State(pub String);
pub type Callback = Box<
dyn Fn(
reqwest::Url,
State,
) -> Pin<
Box<
dyn Future<Output = Result<(Code, State), Box<dyn std::error::Error>>>
+ Send
+ 'static,
>,
> + Send
+ 'static,
>;
pub struct Auth {
client: BasicClient,
app_token: String,
access_token: Mutex<Option<AccessToken>>,
refresh_token: Mutex<Option<RefreshToken>>,
expires_at: Mutex<Option<u64>>,
scopes: Mutex<Vec<Scope>>,
callback: tokio::sync::Mutex<Callback>,
}
impl fmt::Debug for Auth {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let Auth { client, .. } = self;
f.debug_struct("Auth")
.field("client", &client)
.field("app_token", &"[redacted]")
.field("access_token", &"[redacted]")
.field("refresh_token", &"[redacted]")
.field("expires_at", &"[redacted]")
.field("scopes", &"[redacted]")
.field("callback", &"<ptr>")
.finish()
}
}
impl Auth {
pub(crate) fn new(
client_id: &str,
client_secret: &str,
app_token: &str,
redirect_uri: &str,
) -> Result<Self, TokenError> {
let client = BasicClient::new(
ClientId::new(client_id.to_owned()),
Some(ClientSecret::new(client_secret.to_owned())),
AuthUrl::new(format!("{API_URL}/oauth2/authorize")).unwrap(),
Some(TokenUrl::new(format!("{API_URL}/oauth2/token")).unwrap()),
)
.set_redirect_uri(RedirectUrl::new(redirect_uri.to_owned())?)
.set_revocation_uri(RevocationUrl::new(format!("{API_URL}/oauth2/revoke")).unwrap());
let slf = Self {
client,
app_token: app_token.to_owned(),
access_token: Mutex::new(None),
refresh_token: Mutex::new(None),
expires_at: Mutex::new(None),
scopes: Mutex::new(Vec::new()),
callback: tokio::sync::Mutex::new(Box::new(|_, _| {
unimplemented!("oauth2 callback not implemented")
})),
};
Ok(slf)
}
pub fn app_token(&self) -> &str {
&self.app_token
}
pub fn set_refresh_token(&self, token: Option<&str>) {
let mut lock = self.refresh_token.lock().unwrap();
*lock = token.map(|t| RefreshToken::new(t.to_owned()));
}
pub fn set_access_token(&self, token: &str) {
let mut lock = self.access_token.lock().unwrap();
*lock = Some(AccessToken::new(token.to_owned()));
}
pub fn set_expires_in(&self, duration: Option<Duration>) {
let mut lock = self.expires_at.lock().unwrap();
*lock = duration.map(|d| Utc::now().timestamp() as u64 + d.as_secs());
}
pub fn add_scope(&self, scope: &str) {
let mut lock = self.scopes.lock().unwrap();
lock.push(Scope::new(scope.to_owned()));
}
pub async fn set_callback<
F: Fn(reqwest::Url, State) -> Fut + Send + 'static,
Fut: Future<Output = Result<(Code, State), Box<dyn std::error::Error>>> + 'static + Send,
>(
&self,
f: F,
) {
let mut lock = self.callback.lock().await;
*lock = Box::new(move |url, state| Box::pin(f(url, state)));
}
pub fn set_callback_blocking<
F: Fn(reqwest::Url, State) -> Fut + Send + 'static,
Fut: Future<Output = Result<(Code, State), Box<dyn std::error::Error>>> + 'static + Send,
>(
&self,
f: F,
) {
RUNTIME.block_on(self.set_callback(f))
}
pub fn is_valid(&self) -> bool {
let has_access_token = self.access_token.lock().unwrap().is_some();
let is_active = self
.expires_at
.lock()
.unwrap()
.is_some_and(|t| (Utc::now().timestamp() as u64) < t);
has_access_token && is_active
}
pub fn is_refresh_valid(&self) -> bool {
let has_refresh_token = self.refresh_token.lock().unwrap().is_some();
let is_refresh_active = self
.expires_at
.lock()
.unwrap()
.is_some_and(|t| (Utc::now().timestamp() as u64) < t);
has_refresh_token && is_refresh_active
}
pub async fn revoke_token(&self) -> Result<(), TokenError> {
let token = self.access_token.lock().unwrap().clone();
if let Some(token) = token {
let req = self
.client
.revoke_token(oauth2::StandardRevocableToken::AccessToken(token))
.map_err(|e| TokenError::Revoke(e.to_string()))?;
req.request_async(async_http_client)
.await
.map_err(|e| TokenError::Revoke(e.to_string()))?;
}
Ok(())
}
pub fn revoke_token_blocking(&self) -> Result<(), TokenError> {
RUNTIME.block_on(self.revoke_token())
}
pub async fn revoke_refresh_token(&self) -> Result<(), TokenError> {
let token = self.refresh_token.lock().unwrap().clone();
if let Some(token) = token {
let req = self
.client
.revoke_token(oauth2::StandardRevocableToken::RefreshToken(token.clone()))
.map_err(|e| TokenError::Revoke(e.to_string()))?;
req.request_async(async_http_client)
.await
.map_err(|e| TokenError::Revoke(e.to_string()))?;
}
Ok(())
}
pub fn revoke_refresh_token_blocking(&self) -> Result<(), TokenError> {
RUNTIME.block_on(self.revoke_refresh_token())
}
pub async fn try_refresh(&self) -> Result<(), TokenError> {
if self.is_refresh_valid() {
self.refresh().await?;
}
Ok(())
}
pub fn try_refresh_blocking(&self) -> Result<(), TokenError> {
RUNTIME.block_on(self.try_refresh())
}
pub fn access_token(&self) -> Option<AccessToken> {
self.access_token.lock().unwrap().clone()
}
pub fn refresh_token(&self) -> Option<RefreshToken> {
self.refresh_token.lock().unwrap().clone()
}
pub fn expires_at(&self) -> Option<u64> {
*self.expires_at.lock().unwrap()
}
pub async fn refresh(&self) -> Result<(), TokenError> {
let token = self.refresh_token.lock().unwrap().clone();
if let Some(refresh_token) = token {
let token = self
.client
.exchange_refresh_token(&refresh_token)
.request_async(async_http_client)
.await
.map_err(|e| TokenError::OAuth2(e.to_string()))?;
let mut lock = self.access_token.lock().unwrap();
*lock = Some(token.access_token().clone());
let mut lock = self.refresh_token.lock().unwrap();
*lock = token.refresh_token().cloned();
let mut lock = self.expires_at.lock().unwrap();
*lock = token
.expires_in()
.map(|d| (Utc::now().timestamp() as u64) + d.as_secs());
Ok(())
} else {
Err(TokenError::Refresh)
}
}
pub fn refresh_blocking(&self) -> Result<(), TokenError> {
RUNTIME.block_on(self.refresh())
}
pub async fn regenerate(&self) -> Result<(), TokenError> {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let scopes = self.scopes.lock().unwrap().clone();
let (auth_url, state) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(scopes.into_iter())
.set_pkce_challenge(pkce_challenge)
.url();
let callback = self.callback.lock().await;
let (res_code, res_state) = match callback(auth_url, State(state.secret().clone())).await {
Ok(v) => v,
Err(e) => return Err(TokenError::Callback(e.to_string())),
};
if state.secret() != &res_state.0 {
return Err(TokenError::StateMismatch);
}
let Ok(token) = self
.client
.exchange_code(AuthorizationCode::new(res_code.0))
.set_pkce_verifier(pkce_verifier)
.request_async(async_http_client)
.await
else {
return Err(TokenError::Access);
};
let mut expires_at = self.expires_at.lock().unwrap();
*expires_at = token
.expires_in()
.map(|d| Utc::now().timestamp() as u64 + d.as_secs());
let mut access_token = self.access_token.lock().unwrap();
*access_token = Some(token.access_token().clone());
let mut refresh_token = self.refresh_token.lock().unwrap();
*refresh_token = token.refresh_token().cloned();
Ok(())
}
pub fn regenerate_blocking(&self) -> Result<(), TokenError> {
RUNTIME.block_on(self.regenerate())
}
}