use std::collections::btree_map::{BTreeMap, Entry};
use std::collections::{BTreeSet, HashMap};
use std::default::Default;
use std::fmt;
use std::fmt::Display;
use url::form_urlencoded::Serializer;
use graph_error::{AuthorizationFailure, IdentityResult};
use crate::identity::{AsQuery, Prompt, ResponseType};
use crate::strum::IntoEnumIterator;
/// Identifies a single named parameter that can be stored in an `AuthSerializer`.
///
/// Every variant maps to a fixed string key through `AuthParameter::alias`; the
/// per-variant comments below quote that key. The derived `PartialOrd`/`Ord`
/// follow the declaration order of the variants, so reordering them changes
/// comparison results.
#[derive(
Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Serialize, Deserialize, EnumIter,
)]
pub enum AuthParameter {
/// Key `client_id`. Masked in `AuthSerializer`'s `Debug` output unless `log_pii` is set.
ClientId,
/// Key `client_secret`. Masked in `AuthSerializer`'s `Debug` output unless `log_pii` is set.
ClientSecret,
/// Key `redirect_uri`.
RedirectUri,
/// Key `code`.
AuthorizationCode,
/// Key `access_token`. Masked in `AuthSerializer`'s `Debug` output unless `log_pii` is set.
AccessToken,
/// Key `refresh_token`. Masked in `AuthSerializer`'s `Debug` output unless `log_pii` is set.
RefreshToken,
/// Key `response_mode`.
ResponseMode,
/// Key `state`.
State,
/// Key `session_state`.
SessionState,
/// Key `response_type`.
ResponseType,
/// Key `grant_type`.
GrantType,
/// Key `nonce`.
Nonce,
/// Key `prompt`.
Prompt,
/// Key `id_token`. Masked in `AuthSerializer`'s `Debug` output unless `log_pii` is set.
IdToken,
/// Key `resource`.
Resource,
/// Key `domain_hint`.
DomainHint,
/// Key `scope`. `AuthSerializer` keeps scopes in a dedicated set rather than
/// in its parameter map, and several of its methods special-case this variant.
Scope,
/// Key `login_hint`.
LoginHint,
/// Key `client_assertion`.
ClientAssertion,
/// Key `client_assertion_type`.
ClientAssertionType,
/// Key `code_verifier`. Masked in `AuthSerializer`'s `Debug` output unless `log_pii` is set.
CodeVerifier,
/// Key `code_challenge`. Masked in `AuthSerializer`'s `Debug` output unless `log_pii` is set.
CodeChallenge,
/// Key `code_challenge_method`.
CodeChallengeMethod,
/// Key `admin_consent`.
AdminConsent,
/// Key `username`.
Username,
/// Key `password`. Masked in `AuthSerializer`'s `Debug` output unless `log_pii` is set.
Password,
/// Key `device_code`.
DeviceCode,
}
impl AuthParameter {
pub fn alias(self) -> &'static str {
match self {
AuthParameter::ClientId => "client_id",
AuthParameter::ClientSecret => "client_secret",
AuthParameter::RedirectUri => "redirect_uri",
AuthParameter::AuthorizationCode => "code",
AuthParameter::AccessToken => "access_token",
AuthParameter::RefreshToken => "refresh_token",
AuthParameter::ResponseMode => "response_mode",
AuthParameter::ResponseType => "response_type",
AuthParameter::State => "state",
AuthParameter::SessionState => "session_state",
AuthParameter::GrantType => "grant_type",
AuthParameter::Nonce => "nonce",
AuthParameter::Prompt => "prompt",
AuthParameter::IdToken => "id_token",
AuthParameter::Resource => "resource",
AuthParameter::DomainHint => "domain_hint",
AuthParameter::Scope => "scope",
AuthParameter::LoginHint => "login_hint",
AuthParameter::ClientAssertion => "client_assertion",
AuthParameter::ClientAssertionType => "client_assertion_type",
AuthParameter::CodeVerifier => "code_verifier",
AuthParameter::CodeChallenge => "code_challenge",
AuthParameter::CodeChallengeMethod => "code_challenge_method",
AuthParameter::AdminConsent => "admin_consent",
AuthParameter::Username => "username",
AuthParameter::Password => "password",
AuthParameter::DeviceCode => "device_code",
}
}
fn is_debug_redacted(&self) -> bool {
matches!(
self,
AuthParameter::ClientId
| AuthParameter::ClientSecret
| AuthParameter::AccessToken
| AuthParameter::RefreshToken
| AuthParameter::IdToken
| AuthParameter::CodeVerifier
| AuthParameter::CodeChallenge
| AuthParameter::Password
)
}
}
impl Display for AuthParameter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.alias())
}
}
impl AsRef<str> for AuthParameter {
fn as_ref(&self) -> &'static str {
self.alias()
}
}
/// Accumulates scopes and key/value parameters and renders them as a
/// URL-encoded query string (`encode_query`) or a map (`as_credential_map`).
///
/// `Debug` is implemented by hand further down so that sensitive values can be
/// masked; it is intentionally not derived here.
#[derive(Default, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct AuthSerializer {
/// Unique scope strings, stored apart from `parameters`. Being a `BTreeSet`,
/// they are iterated (and therefore joined by `join_scopes`) in sorted order.
scopes: BTreeSet<String>,
/// Parameter values by key. The builder methods in this file insert under
/// the string returned by `AuthParameter::alias`.
parameters: BTreeMap<String, String>,
/// When `true`, the `Debug` impl prints every parameter value verbatim; when
/// `false`, values whose `AuthParameter` is marked as redacted are replaced
/// with `[REDACTED]`. `new()` initializes this to `false`.
log_pii: bool,
}
impl AuthSerializer {
pub fn new() -> AuthSerializer {
AuthSerializer {
scopes: BTreeSet::new(),
parameters: BTreeMap::new(),
log_pii: false,
}
}
pub fn insert<V: ToString>(&mut self, oac: AuthParameter, value: V) -> &mut AuthSerializer {
self.parameters.insert(oac.to_string(), value.to_string());
self
}
pub fn entry_with<V: ToString>(&mut self, oac: AuthParameter, value: V) -> &mut String {
self.parameters
.entry(oac.alias().to_string())
.or_insert_with(|| value.to_string())
}
pub fn entry<V: ToString>(&mut self, oac: AuthParameter) -> Entry<String, String> {
self.parameters.entry(oac.alias().to_string())
}
pub fn get(&self, oac: AuthParameter) -> Option<String> {
self.parameters.get(oac.alias()).cloned()
}
pub fn contains(&self, t: AuthParameter) -> bool {
if t == AuthParameter::Scope {
return !self.scopes.is_empty();
}
self.parameters.contains_key(t.alias())
}
pub fn contains_key(&self, key: &str) -> bool {
self.parameters.contains_key(key)
}
pub fn remove(&mut self, oac: AuthParameter) -> &mut AuthSerializer {
self.parameters.remove(oac.alias());
self
}
pub fn client_id(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::ClientId, value)
}
pub fn state(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::State, value)
}
pub fn client_secret(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::ClientSecret, value)
}
pub fn redirect_uri(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::RedirectUri, value)
}
pub fn authorization_code(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::AuthorizationCode, value)
}
pub fn response_mode(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::ResponseMode, value)
}
pub fn response_type<T: ToString>(&mut self, value: T) -> &mut AuthSerializer {
self.insert(AuthParameter::ResponseType, value)
}
pub fn response_types(
&mut self,
value: std::collections::btree_set::Iter<'_, ResponseType>,
) -> &mut AuthSerializer {
self.insert(AuthParameter::ResponseType, value.as_query())
}
pub fn nonce(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::Nonce, value)
}
pub fn prompt(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::Prompt, value)
}
pub fn prompts(&mut self, value: &[Prompt]) -> &mut AuthSerializer {
self.insert(AuthParameter::Prompt, value.to_vec().as_query())
}
pub fn session_state(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::SessionState, value)
}
pub fn grant_type(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::GrantType, value)
}
pub fn resource(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::Resource, value)
}
pub fn code_verifier(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::CodeVerifier, value)
}
pub fn domain_hint(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::DomainHint, value)
}
pub fn code_challenge(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::CodeChallenge, value)
}
pub fn code_challenge_method(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::CodeChallengeMethod, value)
}
pub fn login_hint(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::LoginHint, value)
}
pub fn client_assertion(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::ClientAssertion, value)
}
pub fn client_assertion_type(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::ClientAssertionType, value)
}
pub fn username(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::Username, value)
}
pub fn password(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::Password, value)
}
pub fn refresh_token(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::RefreshToken, value)
}
pub fn device_code(&mut self, value: &str) -> &mut AuthSerializer {
self.insert(AuthParameter::DeviceCode, value)
}
pub fn add_scope<T: ToString>(&mut self, scope: T) -> &mut AuthSerializer {
self.scopes.insert(scope.to_string());
self
}
pub fn get_scopes(&self) -> &BTreeSet<String> {
&self.scopes
}
pub fn join_scopes(&self, sep: &str) -> String {
self.scopes
.iter()
.map(|s| &**s)
.collect::<Vec<&str>>()
.join(sep)
}
pub fn set_scope<T: ToString, I: IntoIterator<Item = T>>(&mut self, iter: I) -> &mut Self {
self.scopes = iter.into_iter().map(|s| s.to_string()).collect();
self
}
pub fn extend_scopes<T: ToString, I: IntoIterator<Item = T>>(&mut self, iter: I) -> &mut Self {
self.scopes.extend(iter.into_iter().map(|s| s.to_string()));
self
}
pub fn contains_scope<T: ToString>(&self, scope: T) -> bool {
self.scopes.contains(&scope.to_string())
}
}
/// Rendering of the collected parameters into a query string or a map.
impl AuthSerializer {
    /// Resolves the value to emit for `oac`, or `None` when it has none.
    ///
    /// `AuthParameter::Scope` is answered from the scope set (joined with a
    /// single space, `None` when the set is empty); every other parameter is
    /// read from the parameter map. Returning `Option` lets callers that
    /// treat the parameter as optional skip it without building an error.
    fn value_for(&self, oac: AuthParameter) -> Option<String> {
        if oac == AuthParameter::Scope {
            if self.scopes.is_empty() {
                None
            } else {
                Some(self.join_scopes(" "))
            }
        } else {
            self.get(oac)
        }
    }

    /// Returns the `(alias, value)` pair for `oac`.
    ///
    /// # Errors
    ///
    /// Returns `AuthorizationFailure::required(oac)` when the parameter has
    /// no value (for `Scope`: when no scopes were added).
    fn try_as_tuple(&self, oac: &AuthParameter) -> IdentityResult<(String, String)> {
        // The error is built lazily so the success path does no extra work.
        let value = self
            .value_for(*oac)
            .ok_or_else(|| AuthorizationFailure::required(oac))?;
        Ok((oac.alias().to_owned(), value))
    }

    /// Builds a URL-encoded query string: all `required_fields` first, in the
    /// given order, followed by those `optional_fields` that have a value.
    ///
    /// Scopes are emitted under `scope`, joined with a single space. If
    /// `Scope` is listed as optional and the scope set is empty, a `scope`
    /// entry stored directly in the parameter map is emitted instead when one
    /// exists.
    ///
    /// The receiver is `&mut self` for compatibility with existing callers;
    /// nothing is mutated.
    ///
    /// # Errors
    ///
    /// Fails on the first required parameter that has no value.
    pub fn encode_query(
        &mut self,
        optional_fields: Vec<AuthParameter>,
        required_fields: Vec<AuthParameter>,
    ) -> IdentityResult<String> {
        let mut serializer = Serializer::new(String::new());

        for parameter in required_fields {
            if parameter == AuthParameter::Scope {
                if self.scopes.is_empty() {
                    return AuthorizationFailure::result::<String>(parameter.alias());
                }
                serializer.append_pair("scope", &self.join_scopes(" "));
            } else {
                // Lazily constructed: no error value is created when the
                // parameter is present.
                let value = self
                    .get(parameter)
                    .ok_or_else(|| AuthorizationFailure::required(parameter))?;
                serializer.append_pair(parameter.alias(), &value);
            }
        }

        for parameter in optional_fields {
            if parameter == AuthParameter::Scope && !self.scopes.is_empty() {
                serializer.append_pair("scope", &self.join_scopes(" "));
            } else if let Some(value) = self.get(parameter) {
                serializer.append_pair(parameter.alias(), &value);
            }
        }

        Ok(serializer.finish())
    }

    /// Builds an `alias -> value` map containing every entry of
    /// `required_fields` plus those `optional_fields` that have a value.
    ///
    /// The receiver is `&mut self` for compatibility with existing callers;
    /// nothing is mutated.
    ///
    /// # Errors
    ///
    /// Fails on the first required parameter that has no value. Missing
    /// optional parameters are skipped silently.
    pub fn as_credential_map(
        &mut self,
        optional_fields: Vec<AuthParameter>,
        required_fields: Vec<AuthParameter>,
    ) -> IdentityResult<HashMap<String, String>> {
        let mut credentials = required_fields
            .iter()
            .map(|oac| self.try_as_tuple(oac))
            .collect::<IdentityResult<HashMap<String, String>>>()?;

        // Optional entries are added afterwards, so on a duplicate key the
        // optional entry is the one that ends up in the map.
        credentials.extend(optional_fields.iter().filter_map(|oac| {
            self.value_for(*oac)
                .map(|value| (oac.alias().to_owned(), value))
        }));

        Ok(credentials)
    }
}
/// Bulk insertion: each `(parameter, value)` pair is passed to `insert` in
/// iteration order, so a later pair for the same parameter wins.
impl<V: ToString> Extend<(AuthParameter, V)> for AuthSerializer {
    fn extend<I: IntoIterator<Item = (AuthParameter, V)>>(&mut self, iter: I) {
        for (parameter, value) in iter {
            self.insert(parameter, value);
        }
    }
}
/// Hand-written `Debug` that masks sensitive parameter values.
///
/// Output has the shape
/// `AuthSerializer { credentials: {..}, scopes: {..} }`, where `credentials`
/// is the parameter map. Unless `log_pii` is `true`, the value of every
/// parameter whose key matches an `AuthParameter` flagged by
/// `is_debug_redacted` is replaced with `[REDACTED]`. Keys that match no
/// flagged parameter, and the scopes, are always printed as stored.
impl fmt::Debug for AuthSerializer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let credentials: BTreeMap<&str, &str> = self
            .parameters
            .iter()
            .map(|(key, value)| {
                // Only scan the parameter list when masking is enabled.
                let masked = !self.log_pii
                    && AuthParameter::iter()
                        .any(|oac| oac.is_debug_redacted() && oac.alias() == key.as_str());
                let shown = if masked { "[REDACTED]" } else { value.as_str() };
                (key.as_str(), shown)
            })
            .collect();

        // The label matches the actual type name; it previously read
        // "OAuthSerializer", which names no type in this file.
        f.debug_struct("AuthSerializer")
            .field("credentials", &credentials)
            .field("scopes", &self.scopes)
            .finish()
    }
}