#[cfg(doc)]
use crate::deps::tracing::Instrument;
use crate::{
deps::tracing::{debug, warn},
utils::from_env::FromEnv,
};
use core::fmt;
use eyre::eyre;
use oauth2::{
basic::{BasicClient, BasicTokenType},
AccessToken, AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, EndpointNotSet,
EndpointSet, HttpClientError, RefreshToken, RequestTokenError, Scope, StandardErrorResponse,
StandardTokenResponse, TokenResponse, TokenUrl,
};
use std::{future::IntoFuture, pin::Pin};
use tokio::{
sync::watch::{self, Ref},
task::JoinHandle,
time::MissedTickBehavior,
};
/// Token response returned by the OAuth2 token endpoint: no extra fields,
/// `BasicTokenType` as the token type.
type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
/// OAuth2 client type used by [`Authenticator`].
///
/// Only two endpoints are configured in `Authenticator::new`: the
/// authorization URL (`set_auth_uri`) and the token URL (`set_token_uri`).
/// The type-state parameters mirror that: the first and last are
/// `EndpointSet`, and the rest are `EndpointNotSet`.
type MyOAuthClient =
    BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;
#[derive(Debug, Clone, FromEnv)]
#[from_env(crate)]
pub struct OAuthConfig {
#[from_env(var = "OAUTH_CLIENT_ID", desc = "OAuth client ID for the builder")]
pub oauth_client_id: String,
#[from_env(
var = "OAUTH_CLIENT_SECRET",
desc = "OAuth client secret for the builder"
)]
pub oauth_client_secret: String,
#[from_env(
var = "OAUTH_AUTHENTICATE_URL",
desc = "OAuth authenticate URL for the builder for performing OAuth logins"
)]
pub oauth_authenticate_url: url::Url,
#[from_env(
var = "OAUTH_TOKEN_URL",
desc = "OAuth token URL for the builder to get an OAuth2 access token"
)]
pub oauth_token_url: url::Url,
#[from_env(
var = "AUTH_TOKEN_REFRESH_INTERVAL",
desc = "The oauth token refresh interval in seconds"
)]
pub oauth_token_refresh_interval: u64,
}
impl OAuthConfig {
    /// Build an [`Authenticator`] from this configuration.
    ///
    /// The configuration is cloned into the authenticator. No token is fetched
    /// until the authenticator is authenticated or its refresh task runs.
    #[must_use]
    pub fn authenticator(&self) -> Authenticator {
        Authenticator::new(self)
    }
}
/// Fetches OAuth2 access tokens via the client-credentials grant and publishes
/// them to [`SharedToken`] handles.
///
/// The refresh loop runs via [`Authenticator::spawn`] or by awaiting the
/// authenticator, which works through its `IntoFuture` impl.
#[derive(Debug)]
pub struct Authenticator {
    /// Configuration this authenticator was built from.
    config: OAuthConfig,
    /// OAuth2 client configured with the client ID/secret, auth URL and token URL.
    client: MyOAuthClient,
    /// HTTP client used to send token requests.
    reqwest: reqwest::Client,
    /// Sending half of the channel holding the latest token.
    ///
    /// It holds `None` until the first successful authentication.
    token: watch::Sender<Option<Token>>,
}
impl Authenticator {
pub fn new(config: &OAuthConfig) -> Self {
let client = BasicClient::new(ClientId::new(config.oauth_client_id.clone()))
.set_client_secret(ClientSecret::new(config.oauth_client_secret.clone()))
.set_auth_uri(AuthUrl::from_url(config.oauth_authenticate_url.clone()))
.set_token_uri(TokenUrl::from_url(config.oauth_token_url.clone()));
let rq_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.pool_max_idle_per_host(0)
.build()
.unwrap();
Self {
config: config.clone(),
client,
reqwest: rq_client,
token: watch::channel(None).0,
}
}
pub async fn authenticate(
&self,
) -> Result<
(),
RequestTokenError<
HttpClientError<reqwest::Error>,
StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
>,
> {
let token = self.fetch_oauth_token().await?;
self.set_token(token);
Ok(())
}
pub fn is_authenticated(&self) -> bool {
self.token.borrow().is_some()
}
fn set_token(&self, token: StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>) {
self.token.send_replace(Some(token));
}
pub fn token(&self) -> SharedToken {
self.token.subscribe().into()
}
pub async fn fetch_oauth_token(
&self,
) -> Result<
Token,
RequestTokenError<
HttpClientError<reqwest::Error>,
StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
>,
> {
let token_result = self
.client
.exchange_client_credentials()
.request_async(&self.reqwest)
.await?;
Ok(token_result)
}
pub const fn config(&self) -> &OAuthConfig {
&self.config
}
async fn task_future(self) {
let duration = tokio::time::Duration::from_secs(self.config.oauth_token_refresh_interval);
let mut interval = tokio::time::interval(duration);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
interval.tick().await;
debug!("Refreshing oauth token");
match self.authenticate().await {
Ok(_) => debug!("Successfully refreshed oauth token"),
Err(error) => warn!(
error = %format!("{:#}", eyre!(error)),
"Failed to refresh oauth token"
),
};
}
}
pub fn spawn(self) -> JoinHandle<()> {
tokio::spawn(self.task_future())
}
}
/// Awaiting an [`Authenticator`] runs its token-refresh loop in the current
/// task.
///
/// The loop has no exit, so the future never completes on its own. Calling
/// `into_future()` explicitly lets callers wrap the future (e.g. with
/// [`Instrument`]) before driving or spawning it.
impl IntoFuture for Authenticator {
    type Output = ();
    type IntoFuture = Pin<Box<dyn std::future::Future<Output = ()> + Send>>;

    fn into_future(self) -> Self::IntoFuture {
        let refresh_loop = self.task_future();
        Box::pin(refresh_loop)
    }
}
/// A cloneable, read-only handle to the token maintained by an [`Authenticator`].
///
/// Wraps the receiving half of the authenticator's `watch` channel. Obtain one
/// via [`Authenticator::token`]. Clones observe the same channel.
#[derive(Debug, Clone)]
pub struct SharedToken(watch::Receiver<Option<Token>>);
impl From<watch::Receiver<Option<Token>>> for SharedToken {
fn from(inner: watch::Receiver<Option<Token>>) -> Self {
Self(inner)
}
}
impl SharedToken {
    /// Wait for a token to be available, then return an owned copy of its
    /// access-token secret.
    ///
    /// Works on a clone of the receiver, so `self` is left untouched.
    ///
    /// # Errors
    ///
    /// Returns [`watch::error::RecvError`] if the channel closes before a
    /// token is published. This presumably happens when the owning
    /// [`Authenticator`] is dropped; see tokio's `Receiver::wait_for`.
    pub async fn secret(&self) -> Result<String, watch::error::RecvError> {
        Ok(self
            .clone()
            .token()
            .await?
            .access_token()
            .secret()
            .to_owned())
    }

    /// Wait until a token is present and return a guard borrowing it.
    ///
    /// The returned [`TokenRef`] wraps a `watch::Ref`, which holds a read lock
    /// on the channel's value. Keep it short-lived, since a long-lived borrow
    /// can block the authenticator from publishing a refreshed token.
    ///
    /// # Errors
    ///
    /// Returns [`watch::error::RecvError`] if the channel closes before a
    /// token is published.
    pub async fn token(&mut self) -> Result<TokenRef<'_>, watch::error::RecvError> {
        self.0.wait_for(Option::is_some).await.map(Into::into)
    }

    /// Wait until a token is available, without borrowing it.
    ///
    /// # Errors
    ///
    /// Returns [`watch::error::RecvError`] if the channel closes before a
    /// token is published.
    pub async fn wait(&self) -> Result<(), watch::error::RecvError> {
        self.clone().0.wait_for(Option::is_some).await.map(drop)
    }

    /// Borrow the current token slot, which is `None` before the first
    /// successful authentication.
    ///
    /// Only needs `&self`, because reading the current value does not require
    /// mutable access to the receiver. The returned guard holds a read lock on
    /// the channel's value, so drop it promptly.
    pub fn borrow(&self) -> Ref<'_, Option<Token>> {
        self.0.borrow()
    }

    /// Returns `true` if a token is currently stored.
    ///
    /// This does not check whether the token has expired.
    pub fn is_authenticated(&self) -> bool {
        self.0.borrow().is_some()
    }
}
#[doc(hidden)]
impl SharedToken {
    /// Create a handle whose slot is `None` and whose sender has already
    /// been dropped.
    ///
    /// No token can ever be published to it, so presumably every wait on it
    /// ends in `RecvError` (tokio `watch` semantics; confirm if relied upon).
    pub fn empty() -> Self {
        let (_sender, receiver) = watch::channel(None);
        Self(receiver)
    }
}
/// A borrowed view of the current token, returned by [`SharedToken::token`].
///
/// It wraps a `watch::Ref`, which holds a read lock on the channel's value for
/// as long as it lives, so keep it short-lived. Its accessors assume the slot
/// is `Some` and panic otherwise.
pub struct TokenRef<'a> {
    /// Read guard over the authenticator's token slot.
    inner: Ref<'a, Option<Token>>,
}
impl<'a> From<Ref<'a, Option<Token>>> for TokenRef<'a> {
fn from(inner: Ref<'a, Option<Token>>) -> Self {
Self { inner }
}
}
/// Prints `TokenRef { .. }` without exposing any token contents.
impl fmt::Debug for TokenRef<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("TokenRef");
        out.finish_non_exhaustive()
    }
}
impl TokenRef<'_> {
    /// Borrow the underlying token response.
    ///
    /// # Panics
    ///
    /// Panics if this `TokenRef` wraps an empty slot. [`SharedToken::token`]
    /// only builds one after waiting for `Some`, but the public `From` impl
    /// accepts any guard.
    pub fn inner(&self) -> &Token {
        self.inner
            .as_ref()
            .expect("TokenRef must only wrap a populated token slot")
    }

    /// The access token issued by the server.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`TokenRef::inner`].
    pub fn access_token(&self) -> &AccessToken {
        self.inner().access_token()
    }

    /// The token type reported in the token response.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`TokenRef::inner`].
    pub fn token_type(&self) -> &<Token as TokenResponse>::TokenType {
        self.inner().token_type()
    }

    /// The lifetime reported in the token response, if any.
    ///
    /// This is the value as issued. It is not reduced by the time elapsed
    /// since the token was fetched.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`TokenRef::inner`].
    pub fn expires_in(&self) -> Option<std::time::Duration> {
        self.inner().expires_in()
    }

    /// The refresh token included in the token response, if any.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`TokenRef::inner`].
    pub fn refresh_token(&self) -> Option<&RefreshToken> {
        self.inner().refresh_token()
    }

    /// The scopes included in the token response, if any.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`TokenRef::inner`].
    pub fn scopes(&self) -> Option<&Vec<Scope>> {
        self.inner().scopes()
    }
}