use async_trait::async_trait;
use std::convert::Infallible;
use std::fmt::{Debug, Display};
// The three `refreshing-token-*` TLS feature flags are mutually exclusive (see the message
// below): fail the build if any two of them are enabled together.
#[cfg(any(
    all(
        feature = "refreshing-token-rustls-native-roots",
        feature = "refreshing-token-rustls-webpki-roots"
    ),
    all(
        feature = "refreshing-token-native-tls",
        feature = "refreshing-token-rustls-webpki-roots"
    ),
    all(
        feature = "refreshing-token-native-tls",
        feature = "refreshing-token-rustls-native-roots"
    ),
))]
compile_error!(
    "`refreshing-token-native-tls`, `refreshing-token-rustls-native-roots` and `refreshing-token-rustls-webpki-roots` feature flags are mutually exclusive, enable at most one of them"
);
#[cfg(feature = "__refreshing-token")]
use {
chrono::DateTime,
chrono::Utc,
reqwest::ClientBuilder,
std::{sync::Arc, time::Duration},
thiserror::Error,
tokio::sync::Mutex,
};
#[cfg(feature = "with-serde")]
use {serde::Deserialize, serde::Serialize};
/// A login name together with an optional access token.
///
/// `token` is `None` for credentials that carry no token, such as
/// `StaticLoginCredentials::anonymous()`.
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct CredentialsPair {
    /// The login name.
    pub login: String,
    /// The access token, if any. It is redacted in this type's `Debug` output.
    pub token: Option<String>,
}
/// Formats the pair without revealing the token.
///
/// The login is shown as-is. The token shows as `Some("[redacted]")` when present and as
/// `None` when absent.
impl Debug for CredentialsPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let masked_token: Option<&str> = match self.token {
            Some(_) => Some("[redacted]"),
            None => None,
        };
        let mut out = f.debug_struct("CredentialsPair");
        out.field("login", &self.login);
        out.field("token", &masked_token);
        out.finish()
    }
}
/// A source of login credentials that can be queried asynchronously.
///
/// Implementors produce a [`CredentialsPair`] on demand. Implementations may do work such as
/// refreshing a token before returning, so the call can fail with `Self::Error`.
#[async_trait]
pub trait LoginCredentials: Debug + Send + Sync + 'static {
    /// Error produced when credentials cannot be obtained.
    type Error: Send + Sync + Debug + Display;

    /// Returns the current login name and (optional) token.
    async fn get_credentials(&self) -> Result<CredentialsPair, Self::Error>;
}
/// Credentials that never change: the same [`CredentialsPair`] is returned on every request.
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct StaticLoginCredentials {
    /// The fixed login/token pair handed out by `get_credentials`.
    pub credentials: CredentialsPair,
}
impl StaticLoginCredentials {
    /// Wraps the given login name and optional token as fixed credentials.
    #[must_use]
    pub fn new(login: String, token: Option<String>) -> StaticLoginCredentials {
        let credentials = CredentialsPair { login, token };
        StaticLoginCredentials { credentials }
    }

    /// Fixed credentials with the login `justinfan12345` and no token, for anonymous use.
    #[must_use]
    pub fn anonymous() -> StaticLoginCredentials {
        let login = String::from("justinfan12345");
        StaticLoginCredentials::new(login, None)
    }
}
#[async_trait]
impl LoginCredentials for StaticLoginCredentials {
type Error = Infallible;
async fn get_credentials(&self) -> Result<CredentialsPair, Infallible> {
Ok(self.credentials.clone())
}
}
/// A user access token plus the refresh token needed to renew it.
///
/// Secrets are redacted in this type's `Debug` output.
#[cfg(feature = "__refreshing-token")]
#[derive(Clone, Deserialize, Serialize)]
pub struct UserAccessToken {
    /// Token sent as the bearer credential and handed out in `CredentialsPair::token`.
    pub access_token: String,
    /// Token exchanged for a new access token when a refresh is due.
    pub refresh_token: String,
    /// When this token was obtained; its age is measured from here.
    pub created_at: DateTime<Utc>,
    /// When the token expires, if known. `None` makes the refresh logic assume a 24-hour lifetime.
    pub expires_at: Option<DateTime<Utc>>,
}
/// Formats the token's timestamps while hiding both secret strings.
#[cfg(feature = "__refreshing-token")]
impl Debug for UserAccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MASK: &str = "[redacted]";
        let UserAccessToken {
            created_at,
            expires_at,
            ..
        } = self;
        f.debug_struct("UserAccessToken")
            .field("access_token", &MASK)
            .field("refresh_token", &MASK)
            .field("created_at", created_at)
            .field("expires_at", expires_at)
            .finish()
    }
}
/// JSON body returned by the token refresh endpoint.
///
/// Converted into a [`UserAccessToken`] via `From`.
#[cfg(feature = "__refreshing-token")]
#[derive(Deserialize, Serialize)]
pub struct GetAccessTokenResponse {
    /// The newly issued access token.
    pub access_token: String,
    /// The refresh token to use for the next refresh.
    pub refresh_token: String,
    /// Remaining lifetime of the access token in seconds, if provided.
    pub expires_in: Option<u64>,
}
#[cfg(feature = "__refreshing-token")]
impl From<GetAccessTokenResponse> for UserAccessToken {
    /// Converts a token endpoint response into a [`UserAccessToken`] created "now".
    ///
    /// `expires_at` is `created_at + expires_in` seconds. If `expires_in` is absent, or is so
    /// large that it cannot be represented as a `chrono` duration or added to the current
    /// time, `expires_at` is `None`. Previously such values panicked, even though they come
    /// straight from a network response.
    fn from(response: GetAccessTokenResponse) -> Self {
        let now = Utc::now();
        let expires_at = response.expires_in.and_then(|secs| {
            let lifetime = chrono::Duration::from_std(Duration::from_secs(secs)).ok()?;
            now.checked_add_signed(lifetime)
        });
        UserAccessToken {
            access_token: response.access_token,
            refresh_token: response.refresh_token,
            created_at: now,
            expires_at,
        }
    }
}
/// Persistent storage for the [`UserAccessToken`] used by `RefreshingLoginCredentials`.
///
/// The token is loaded on every credentials request and written back after each refresh.
#[cfg(feature = "__refreshing-token")]
#[async_trait]
pub trait TokenStorage: Debug + Send + 'static {
    /// Error produced when the token cannot be loaded.
    type LoadError: Send + Sync + Debug + Display;
    /// Error produced when a refreshed token cannot be saved.
    type UpdateError: Send + Sync + Debug + Display;

    /// Returns the currently stored token.
    async fn load_token(&mut self) -> Result<UserAccessToken, Self::LoadError>;

    /// Replaces the stored token with `token`, which has just been refreshed.
    async fn update_token(&mut self, token: &UserAccessToken) -> Result<(), Self::UpdateError>;
}
/// Login credentials backed by a user access token that is refreshed on demand.
///
/// Clones share the same token storage and cached login name (both are behind `Arc`).
#[cfg(feature = "__refreshing-token")]
pub struct RefreshingLoginCredentials<S: TokenStorage> {
    /// HTTP client used for token refreshes and the login-name lookup.
    http_client: reqwest::Client,
    /// Login name; `None` until it has been looked up (or supplied at construction).
    user_login: Arc<Mutex<Option<String>>>,
    /// Application client ID, sent when refreshing and as the `Client-Id` header.
    client_id: String,
    /// Application client secret, sent when refreshing. Redacted in `Debug` output.
    client_secret: String,
    /// Storage the token is loaded from and written back to.
    token_storage: Arc<Mutex<S>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an `S: Clone` bound,
// but every field (including the `Arc`-wrapped storage) can be cloned for any `S`.
#[cfg(feature = "__refreshing-token")]
impl<S: TokenStorage> Clone for RefreshingLoginCredentials<S> {
    fn clone(&self) -> Self {
        RefreshingLoginCredentials {
            http_client: self.http_client.clone(),
            user_login: Arc::clone(&self.user_login),
            client_id: self.client_id.clone(),
            client_secret: self.client_secret.clone(),
            token_storage: Arc::clone(&self.token_storage),
        }
    }
}
/// Formats all fields except the client secret, which is replaced with `"[redacted]"`.
#[cfg(feature = "__refreshing-token")]
impl<S: TokenStorage> Debug for RefreshingLoginCredentials<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RefreshingLoginCredentials {
            http_client,
            user_login,
            client_id,
            client_secret: _,
            token_storage,
        } = self;
        let mut out = f.debug_struct("RefreshingLoginCredentials");
        out.field("http_client", http_client);
        out.field("user_login", user_login);
        out.field("client_id", client_id);
        out.field("client_secret", &"[redacted]");
        out.field("token_storage", token_storage);
        out.finish()
    }
}
#[cfg(feature = "__refreshing-token")]
impl<S: TokenStorage> RefreshingLoginCredentials<S> {
    /// Creates credentials whose login name is looked up from the Twitch API on first use.
    ///
    /// `client_id` and `client_secret` are sent to the token endpoint when refreshing.
    /// `client_id` is also sent as the `Client-Id` header for the login lookup.
    /// `token_storage` supplies the token and receives refreshed ones.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built (see `init_with_username`).
    #[must_use]
    pub fn init(
        client_id: String,
        client_secret: String,
        token_storage: S,
    ) -> RefreshingLoginCredentials<S> {
        RefreshingLoginCredentials::init_with_username(
            None,
            client_id,
            client_secret,
            token_storage,
        )
    }

    /// Creates credentials with an optionally pre-known login name.
    ///
    /// When `user_login` is `Some`, no lookup request is made for the login name. When it is
    /// `None`, the name is fetched on the first `get_credentials` call and cached.
    ///
    /// # Panics
    ///
    /// Panics if the `reqwest` client cannot be built. With the
    /// `refreshing-token-rustls-webpki-roots` feature, it also panics if a bundled root
    /// certificate is not valid DER.
    #[must_use]
    pub fn init_with_username(
        user_login: Option<String>,
        client_id: String,
        client_secret: String,
        token_storage: S,
    ) -> RefreshingLoginCredentials<S> {
        let builder = ClientBuilder::new();
        // With the webpki feature, configure TLS roots from the bundled `webpki_root_certs`
        // set. Shadowing avoids the need for a conditionally-unused `mut`.
        #[cfg(feature = "refreshing-token-rustls-webpki-roots")]
        let builder = builder.tls_certs_only(
            webpki_root_certs::TLS_SERVER_ROOT_CERTS.iter().map(|cert| {
                reqwest::tls::Certificate::from_der(cert)
                    .expect("bundled webpki root certificate must be valid DER")
            }),
        );
        let http_client = builder
            .build()
            .expect("failed to build the HTTP client used for token refresh");
        RefreshingLoginCredentials {
            http_client,
            user_login: Arc::new(Mutex::new(user_login)),
            client_id,
            client_secret,
            token_storage: Arc::new(Mutex::new(token_storage)),
        }
    }
}
/// Errors returned by `RefreshingLoginCredentials::get_credentials`.
#[cfg(feature = "__refreshing-token")]
#[derive(Debug, Error)]
pub enum RefreshingLoginError<S: TokenStorage> {
    /// `TokenStorage::load_token` failed.
    #[error("Failed to retrieve token from storage: {0}")]
    LoadError(S::LoadError),
    /// An HTTP request (token refresh or login-name lookup) failed, or its response could not
    /// be decoded.
    #[error("Failed to refresh token: {0}")]
    RefreshError(reqwest::Error),
    /// `TokenStorage::update_token` failed after a refresh.
    #[error("Failed to update token in storage: {0}")]
    UpdateError(S::UpdateError),
}
/// Fraction of a token's lifetime after which it is refreshed ahead of expiry.
#[cfg(feature = "__refreshing-token")]
const SHOULD_REFRESH_AFTER_FACTOR: f64 = 0.90;
#[cfg(feature = "__refreshing-token")]
#[async_trait]
impl<S: TokenStorage> LoginCredentials for RefreshingLoginCredentials<S> {
    type Error = RefreshingLoginError<S>;

    /// Returns the login name and an access token, refreshing the token first if it is due.
    ///
    /// The token is loaded from storage. It is refreshed once it has used
    /// `SHOULD_REFRESH_AFTER_FACTOR` of its lifetime; the lifetime is `expires_at - created_at`,
    /// or 24 hours when `expires_at` is `None`. The refreshed token is written back to
    /// storage. If the login name is not yet known, it is fetched from the
    /// `helix/users` endpoint and cached for later calls.
    ///
    /// # Errors
    ///
    /// - `LoadError` if the token cannot be loaded from storage.
    /// - `RefreshError` if an HTTP request fails, returns a non-success status, or returns a
    ///   body that cannot be decoded.
    /// - `UpdateError` if the refreshed token cannot be saved.
    ///
    /// # Panics
    ///
    /// Panics if the login name must be fetched and the users endpoint returns an empty
    /// `data` list.
    async fn get_credentials(&self) -> Result<CredentialsPair, RefreshingLoginError<S>> {
        // The storage guard lives until the end of this function, so concurrent callers are
        // serialized through the whole load/refresh/store sequence.
        let mut token_storage = self.token_storage.lock().await;
        let mut current_token = token_storage
            .load_token()
            .await
            .map_err(RefreshingLoginError::LoadError)?;

        // A negative lifetime (expires_at before created_at) clamps to zero and forces a refresh.
        let token_expires_after = match current_token.expires_at {
            Some(expires_at) => (expires_at - current_token.created_at)
                .to_std()
                .unwrap_or_default(),
            None => Duration::from_secs(24 * 60 * 60),
        };
        // `created_at` can lie in the future (clock skew, token issued on another machine);
        // treat that as a brand-new token instead of panicking.
        let token_age = (Utc::now() - current_token.created_at)
            .to_std()
            .unwrap_or_default();
        let max_token_age = token_expires_after.mul_f64(SHOULD_REFRESH_AFTER_FACTOR);
        let is_token_expired = token_age >= max_token_age;

        if is_token_expired {
            let response = self
                .http_client
                .post("https://id.twitch.tv/oauth2/token")
                .query(&[
                    ("grant_type", "refresh_token"),
                    ("refresh_token", &current_token.refresh_token),
                    ("client_id", &self.client_id),
                    ("client_secret", &self.client_secret),
                ])
                .send()
                .await
                // Report 4xx/5xx as status errors instead of confusing JSON decode errors.
                .and_then(reqwest::Response::error_for_status)
                .map_err(RefreshingLoginError::RefreshError)?
                .json::<GetAccessTokenResponse>()
                .await
                .map_err(RefreshingLoginError::RefreshError)?;
            current_token = UserAccessToken::from(response);
            token_storage
                .update_token(&current_token)
                .await
                .map_err(RefreshingLoginError::UpdateError)?;
        }

        let mut current_login = self.user_login.lock().await;
        let login = if let Some(login) = &*current_login {
            login.clone()
        } else {
            let users_response = self
                .http_client
                .get("https://api.twitch.tv/helix/users")
                .header("Client-Id", &self.client_id)
                .bearer_auth(&current_token.access_token)
                .send()
                .await
                .and_then(reqwest::Response::error_for_status)
                .map_err(RefreshingLoginError::RefreshError)?
                .json::<UsersResponse>()
                .await
                .map_err(RefreshingLoginError::RefreshError)?;
            let user = users_response
                .data
                .into_iter()
                .next()
                .expect("users endpoint returned no user for the provided access token");
            tracing::info!(
                "Fetched login name `{}` for provided auth token",
                &user.login
            );
            *current_login = Some(user.login.clone());
            user.login
        };

        Ok(CredentialsPair {
            login,
            token: Some(current_token.access_token.clone()),
        })
    }
}
/// Body of the `helix/users` response; only the `data` array is read.
#[cfg(feature = "__refreshing-token")]
#[derive(Deserialize)]
struct UsersResponse {
    /// Returned users; the first entry is taken as the token's owner.
    data: Vec<UserObject>,
}
/// A single user entry from `helix/users`; only the login name is kept.
#[cfg(feature = "__refreshing-token")]
#[derive(Deserialize)]
struct UserObject {
    /// The user's login name.
    login: String,
}