use std::{collections::HashSet, fmt::Debug, time::Duration};
use chrono::{DateTime, Utc};
use oauth2::{
basic::BasicTokenType, AccessToken, CsrfToken, PkceCodeVerifier, RefreshToken, TokenResponse,
};
use serde::{Deserialize, Serialize};
/// Marker trait implemented only by [`Token`] and [`Unauthenticated`].
///
/// It is sealed through `private::Sealed`, which cannot be named outside this module, so no other
/// type can implement it. Presumably used as a type-state parameter distinguishing "has a token"
/// from "not yet authenticated" — confirm against callers.
pub trait AuthenticationState: private::Sealed {}
impl AuthenticationState for Token {}
impl AuthenticationState for Unauthenticated {}
/// Sealed marker trait for the OAuth2 flow variants defined in this module.
///
/// Implemented only by [`AuthCodeFlow`], [`AuthCodePkceFlow`], [`ClientCredsFlow`] and
/// [`UnknownFlow`]. It requires `Debug`, so every flow can be printed; the two authorization-code
/// flows implement `Debug` by hand so that their secrets are redacted.
pub trait AuthFlow: private::Sealed + Debug {}
impl AuthFlow for AuthCodeFlow {}
impl AuthFlow for AuthCodePkceFlow {}
impl AuthFlow for ClientCredsFlow {}
impl AuthFlow for UnknownFlow {}
impl Debug for AuthCodeFlow {
    /// Prints the flow with its CSRF token masked so the secret never reaches logs.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[redacted]";
        let mut builder = formatter.debug_struct("AuthCodeFlow");
        builder.field("csrf_token", &REDACTED);
        builder.finish()
    }
}
impl Debug for AuthCodePkceFlow {
    /// Prints the flow with both the CSRF token and the PKCE verifier masked.
    ///
    /// The verifier is shown as redacted even when it is `None`, so the output never reveals
    /// whether a verifier is still held.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[redacted]";
        let mut builder = formatter.debug_struct("AuthCodePkceFlow");
        for secret_field in ["csrf_token", "pkce_verifier"] {
            builder.field(secret_field, &REDACTED);
        }
        builder.finish()
    }
}
/// Sealed marker trait implemented by [`AuthCodeFlow`], [`AuthCodePkceFlow`] and [`UnknownFlow`].
///
/// [`ClientCredsFlow`] deliberately does not implement it. NOTE(review): presumably this gates
/// operations that require user authorization (which the client-credentials flow lacks) —
/// confirm against the code that uses this bound.
pub trait Authorised: private::Sealed {}
impl Authorised for AuthCodeFlow {}
impl Authorised for AuthCodePkceFlow {}
impl Authorised for UnknownFlow {}
// Sealing module: `Sealed` is a supertrait of `AuthenticationState`, `AuthFlow` and
// `Authorised`. The module is private, so code outside this module cannot name `Sealed`
// and therefore cannot implement those public traits for its own types.
mod private {
pub trait Sealed {}
impl Sealed for super::Token {}
impl Sealed for super::Unauthenticated {}
impl Sealed for super::AuthCodeFlow {}
impl Sealed for super::AuthCodePkceFlow {}
impl Sealed for super::ClientCredsFlow {}
impl Sealed for super::UnknownFlow {}
}
/// A deduplicated set of OAuth2 scopes.
///
/// Build one with `Scopes::new` or `Scopes::from` from any iterable of string-like values.
/// Because the backing store is a `HashSet`, iteration order is unspecified.
#[derive(Clone, Debug, Default)]
pub struct Scopes(pub(crate) HashSet<oauth2::Scope>);
impl<I> From<I> for Scopes
where
I: IntoIterator,
I::Item: Into<String>,
{
fn from(value: I) -> Self {
let scopes = value
.into_iter()
.map(|i| oauth2::Scope::new(i.into()))
.collect();
Self(scopes)
}
}
impl Scopes {
pub fn new<I>(scopes: I) -> Self
where
I: IntoIterator,
I::Item: Into<String>,
{
Self::from(scopes)
}
fn inner_vec(self) -> Vec<oauth2::Scope> {
self.0.into_iter().collect()
}
}
/// An OAuth2 access token plus the bookkeeping needed to tell when it expires.
///
/// Serializable via serde and usable wherever an `oauth2::TokenResponse<BasicTokenType>`
/// is expected (see the `TokenResponse` impl for this type).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Token {
/// The access token; its secret is exposed through `Token::secret`.
pub(crate) access_token: AccessToken,
/// Optional refresh token; its secret is exposed through `Token::refresh_secret`.
pub(crate) refresh_token: Option<RefreshToken>,
/// Token lifetime in seconds, counted from `created_at`.
///
/// NOTE(review): this field has no `#[serde(default)]`, so a response without `expires_in`
/// fails to deserialize. RFC 6749 §5.1 marks it RECOMMENDED rather than REQUIRED; confirm
/// this strictness is intended.
pub expires_in: u64,
/// When the token was issued/recorded. Falls back to `Utc::now()` when absent during
/// deserialization.
#[serde(default = "Utc::now")]
pub created_at: DateTime<Utc>,
/// `created_at + expires_in`, as computed by `Token::new` and `Token::set_timestamps`.
///
/// Excluded from serde in both directions (`#[serde(skip)]`). A freshly deserialized token
/// therefore holds `DateTime::<Utc>::default()` here (the Unix epoch, per chrono's `Default`
/// impl) until `set_timestamps` runs, and `is_expired` reports `true` in the meantime.
#[serde(skip)]
pub expires_at: DateTime<Utc>,
/// Token type, matched case-insensitively when deserializing. `Token::new` always uses
/// `Bearer`.
#[serde(deserialize_with = "oauth2::helpers::deserialize_untagged_enum_case_insensitive")]
pub(crate) token_type: BasicTokenType,
/// Granted scopes, mapped to the space-delimited `scope` field. It is omitted from serialized
/// output when `None`, and defaults to `None` when absent on input.
#[serde(rename = "scope")]
#[serde(deserialize_with = "oauth2::helpers::deserialize_space_delimited_vec")]
#[serde(serialize_with = "oauth2::helpers::serialize_space_delimited_vec")]
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub(crate) scopes: Option<Vec<oauth2::Scope>>,
}
// Zero-sized marker that implements `AuthenticationState` alongside `Token`; presumably it
// represents the state before a token has been obtained (confirm with callers). Its rustdoc
// text is pulled in verbatim from docs/internal_implementation_details.md.
#[doc = include_str!("docs/internal_implementation_details.md")]
#[derive(Clone, Copy, Debug)]
pub struct Unauthenticated;
/// Authorization-code flow state, holding the CSRF token to check on the redirect.
///
/// `Debug` is implemented manually so the CSRF token is printed as `[redacted]`.
pub struct AuthCodeFlow {
/// CSRF token; NOTE(review): presumably compared against the `state` returned by the
/// authorization server — confirm at the call site.
pub(crate) csrf_token: CsrfToken,
}
/// Authorization-code-with-PKCE flow state.
///
/// `Debug` is implemented manually so both secrets are printed as `[redacted]`.
pub struct AuthCodePkceFlow {
/// CSRF token for the authorization request.
pub(crate) csrf_token: CsrfToken,
/// PKCE code verifier. It is an `Option`, presumably so the verifier can be moved out
/// (e.g. with `take`) when the code is exchanged — confirm at the call site.
pub(crate) pkce_verifier: Option<PkceCodeVerifier>,
}
/// Zero-sized marker for the client-credentials flow.
///
/// Unlike the other flows, it does not implement `Authorised`.
#[derive(Clone, Copy, Debug)]
pub struct ClientCredsFlow;
/// Zero-sized marker for a flow that is not tracked at the type level.
///
/// It implements both `AuthFlow` and `Authorised`. NOTE(review): presumably used when a client
/// is built directly from an existing token — confirm with callers.
#[derive(Clone, Copy, Debug)]
pub struct UnknownFlow;
impl Token {
pub fn new(
access_token: impl Into<String>,
refresh_token: Option<&str>,
created_at: DateTime<Utc>,
expires_in: u64,
scopes: Option<Scopes>,
) -> Self {
let access_token = AccessToken::new(access_token.into());
let refresh_token = refresh_token.map(|t| RefreshToken::new(t.to_owned()));
let expires_at =
created_at + chrono::Duration::seconds(i64::try_from(expires_in).unwrap_or(i64::MAX));
let scopes = scopes.map(|s| s.inner_vec());
Self {
access_token,
refresh_token,
expires_in,
created_at,
expires_at,
token_type: BasicTokenType::Bearer,
scopes,
}
}
pub fn secret(&self) -> &str {
self.access_token.secret()
}
pub fn refresh_secret(&self) -> Option<&str> {
self.refresh_token.as_ref().map(|t| t.secret().as_str())
}
pub(crate) fn set_timestamps(self) -> Self {
let created_at = Utc::now();
let expires_at = created_at
+ chrono::Duration::seconds(i64::try_from(self.expires_in).unwrap_or(i64::MAX));
Self {
created_at,
expires_at,
..self
}
}
pub fn is_expired(&self) -> bool {
Utc::now() >= self.expires_at
}
pub fn is_refreshable(&self) -> bool {
self.refresh_token.is_some()
}
}
/// Lets a [`Token`] stand in wherever `oauth2` expects a `TokenResponse<BasicTokenType>`.
impl TokenResponse<BasicTokenType> for Token {
    /// Borrows the access token.
    fn access_token(&self) -> &AccessToken {
        let Self { access_token, .. } = self;
        access_token
    }

    /// Borrows the token type.
    fn token_type(&self) -> &BasicTokenType {
        let Self { token_type, .. } = self;
        token_type
    }

    /// Always reports the stored lifetime, converted from seconds.
    fn expires_in(&self) -> Option<Duration> {
        let lifetime = Duration::from_secs(self.expires_in);
        Some(lifetime)
    }

    /// Borrows the refresh token, if there is one.
    fn refresh_token(&self) -> Option<&RefreshToken> {
        Option::as_ref(&self.refresh_token)
    }

    /// Borrows the granted scopes, if any were recorded.
    fn scopes(&self) -> Option<&Vec<oauth2::Scope>> {
        Option::as_ref(&self.scopes)
    }
}