use std::borrow::Cow;
use std::fmt;
use rocket::figment::{self, Error, Figment};
/// OAuth 2.0 client configuration: the provider whose endpoints are used,
/// plus the client credentials and optional redirect URI.
///
/// Construct it with `OAuthConfig::new` or `OAuthConfig::from_figment`.
pub struct OAuthConfig {
    /// Endpoint source; boxed so any `Provider` implementation can be stored.
    provider: Box<dyn Provider>,
    /// Client identifier.
    client_id: String,
    /// Client secret. Treat it as a credential.
    client_secret: String,
    /// Redirect URI, if one was configured.
    redirect_uri: Option<String>,
}
impl OAuthConfig {
pub fn new(
provider: impl Provider,
client_id: String,
client_secret: String,
redirect_uri: Option<String>,
) -> OAuthConfig {
OAuthConfig {
provider: Box::new(provider),
client_id,
client_secret,
redirect_uri,
}
}
pub fn from_figment(figment: &Figment, name: &str) -> Result<Self, Error> {
#[derive(serde::Deserialize)]
struct Config {
provider: Option<String>,
auth_uri: Option<String>,
token_uri: Option<String>,
client_id: String,
client_secret: String,
redirect_uri: Option<String>,
}
let conf: Config = figment.extract_inner(&format!("oauth.{}", name))?;
let provider = match (conf.provider, conf.auth_uri, conf.token_uri) {
(Some(provider_name), None, None) => StaticProvider::from_known_name(&provider_name)
.ok_or_else(|| {
figment::error::Kind::InvalidValue(
figment::error::Actual::Str(provider_name),
"one of the predefined 'provider' names".into(),
)
})?,
(None, Some(auth_uri), Some(token_uri)) => StaticProvider {
auth_uri: auth_uri.into(),
token_uri: token_uri.into(),
},
_ => {
return Err("either 'provider' or 'auth_uri'+'token_uri' should be specified, but not both".to_string().into());
}
};
Ok(OAuthConfig::new(
provider,
conf.client_id,
conf.client_secret,
conf.redirect_uri,
))
}
pub fn provider(&self) -> &dyn Provider {
&*self.provider
}
pub fn client_id(&self) -> &str {
&self.client_id
}
pub fn client_secret(&self) -> &str {
&self.client_secret
}
pub fn redirect_uri(&self) -> Option<&str> {
self.redirect_uri.as_deref()
}
}
/// Debug representation of `OAuthConfig`.
///
/// The provider is shown as `..` because `dyn Provider` carries no `Debug`
/// bound. The client secret is replaced by a placeholder so that
/// `{:?}`-formatting or logging a configuration does not leak the credential.
impl fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OAuthConfig")
            // `dyn Provider` is not `Debug`; print a placeholder instead.
            .field("provider", &(..))
            .field("client_id", &self.client_id)
            // Never print the secret itself.
            .field("client_secret", &"<redacted>")
            .field("redirect_uri", &self.redirect_uri)
            .finish()
    }
}
/// Supplies the two endpoint URIs of an OAuth 2.0 provider.
///
/// Both methods return `Cow`, so an implementation may lend out a stored
/// string or build a new one on each call. The `Send + Sync + 'static`
/// bounds allow an implementation to be stored as `Box<dyn Provider>`.
pub trait Provider: Send + Sync + 'static {
    /// Returns the authorization endpoint URI.
    fn auth_uri(&self) -> Cow<'_, str>;

    /// Returns the token endpoint URI.
    fn token_uri(&self) -> Cow<'_, str>;
}
/// A `Provider` whose endpoint URIs are fixed values held in the struct.
///
/// `Cow<'static, str>` lets the predefined constants borrow string literals,
/// while providers built from configuration own their strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StaticProvider {
    /// Authorization endpoint URI.
    pub auth_uri: Cow<'static, str>,
    /// Token endpoint URI.
    pub token_uri: Cow<'static, str>,
}
impl Provider for StaticProvider {
    /// Lends out the stored authorization URI without allocating.
    fn auth_uri(&self) -> Cow<'_, str> {
        let uri: &str = &self.auth_uri;
        Cow::from(uri)
    }

    /// Lends out the stored token URI without allocating.
    fn token_uri(&self) -> Cow<'_, str> {
        let uri: &str = &self.token_uri;
        Cow::from(uri)
    }
}
// Declares the built-in providers. Each `Name: "auth", "token"` entry expands
// to two things. The first is an associated constant `StaticProvider::Name`.
// The second is an entry in `StaticProvider::from_known_name`, which compares
// names ASCII-case-insensitively.
//
// The `@emit` arm is internal. The public arm rewrites every entry into
// `(Name "doc string") : auth, token` so the generated doc text can be
// attached with `#[doc = ...]`.
macro_rules! providers {
    (@emit $(($name:ident $docstr:expr) : $auth:expr, $token:expr),*) => {
        impl StaticProvider {
            $(
                #[doc = $docstr]
                // Constants deliberately keep the provider's own casing
                // (e.g. `GitHub`) instead of SCREAMING_CASE.
                #[allow(non_upper_case_globals)]
                pub const $name: StaticProvider = StaticProvider {
                    auth_uri: Cow::Borrowed($auth),
                    token_uri: Cow::Borrowed($token),
                };
            )*

            /// Returns the first predefined provider whose name equals `name`,
            /// ignoring ASCII case. Returns `None` if no provider matches.
            pub(crate) fn from_known_name(name: &str) -> Option<StaticProvider> {
                let known: &[(&str, StaticProvider)] = &[
                    $((stringify!($name), StaticProvider::$name)),*
                ];
                known
                    .iter()
                    .find(|(label, _)| name.eq_ignore_ascii_case(label))
                    .map(|(_, provider)| provider.clone())
            }
        }
    };
    ($($name:ident : $auth:expr, $token:expr),* $(,)*) => {
        providers!(@emit $(($name concat!("A `Provider` suitable for authorizing users with ", stringify!($name), ".")) : $auth, $token),*);
    };
}
// Built-in providers. Users pick one by name, ASCII-case-insensitively,
// through the `provider` configuration key. Each entry has the form
// `Name: authorization URI, token URI`.
providers! {
    // Discord's documentation uses `discord.com`; `discordapp.com` is the
    // legacy domain and should not be relied on for API/token requests.
    Discord: "https://discord.com/api/oauth2/authorize", "https://discord.com/api/oauth2/token",
    Facebook: "https://www.facebook.com/v3.1/dialog/oauth", "https://graph.facebook.com/v3.1/oauth/access_token",
    GitHub: "https://github.com/login/oauth/authorize", "https://github.com/login/oauth/access_token",
    Google: "https://accounts.google.com/o/oauth2/v2/auth", "https://www.googleapis.com/oauth2/v4/token",
    Microsoft: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", "https://login.microsoftonline.com/common/oauth2/v2.0/token",
    Reddit: "https://www.reddit.com/api/v1/authorize", "https://www.reddit.com/api/v1/access_token",
    Wikimedia: "https://meta.wikimedia.org/w/rest.php/oauth2/authorize", "https://meta.wikimedia.org/w/rest.php/oauth2/access_token",
    Yahoo: "https://api.login.yahoo.com/oauth2/request_auth", "https://api.login.yahoo.com/oauth2/get_token",
}