use std::borrow::Cow;
use std::result::Result as StdResult;
use url::Url;
use chrono::{Duration, Utc};
use crate::code_grant::error::{AuthorizationError, AuthorizationErrorType};
use crate::primitives::authorizer::Authorizer;
use crate::primitives::registrar::{ClientUrl, ExactUrl, Registrar, RegistrarError, PreGrant};
use crate::primitives::grant::{Extensions, Grant};
use crate::{endpoint::Scope, endpoint::Solicitation, primitives::registrar::BoundClient};
/// Read access to the parameters of an incoming authorization request.
///
/// Every parameter accessor yields `None` when the parameter is absent.
pub trait Request {
    /// Whether the request is well-formed enough to be processed at all.
    ///
    /// `Authorization::new` rejects a request for which this is `false` with `Error::Ignore`.
    fn valid(&self) -> bool;

    /// The `client_id` parameter.
    fn client_id(&self) -> Option<Cow<'_, str>>;

    /// The `scope` parameter, unparsed.
    fn scope(&self) -> Option<Cow<'_, str>>;

    /// The `redirect_uri` parameter, unparsed.
    fn redirect_uri(&self) -> Option<Cow<'_, str>>;

    /// The opaque `state` parameter, echoed back to the client in redirects.
    fn state(&self) -> Option<Cow<'_, str>>;

    /// The `response_type` parameter; this flow only accepts `"code"`.
    fn response_type(&self) -> Option<Cow<'_, str>>;

    /// An additional request parameter looked up by `key`.
    ///
    /// NOTE(review): presumably consumed by `Extension` implementations — confirm.
    fn extension(&self, key: &str) -> Option<Cow<'_, str>>;
}
pub trait Extension {
fn extend(&mut self, request: &dyn Request) -> std::result::Result<Extensions, ()>;
}
impl Extension for () {
fn extend(&mut self, _: &dyn Request) -> std::result::Result<Extensions, ()> {
Ok(Extensions::new())
}
}
/// The primitives the authorization code flow needs from its surrounding endpoint.
pub trait Endpoint {
/// Registrar used by `authorization_code` to bind the client (`bound_redirect`) and to
/// negotiate the scope (`negotiate`).
fn registrar(&self) -> &dyn Registrar;
/// Authorizer used by `Pending::authorize` to issue the authorization code for a grant.
fn authorizer(&mut self) -> &mut dyn Authorizer;
/// Extension consulted by `authorization_code` after the client has been bound.
fn extension(&mut self) -> &mut dyn Extension;
}
/// State machine for a single authorization code request.
///
/// Drive it with `Authorization::advance`, answering each `Output` with the matching `Input`,
/// until it yields `Output::Ok` or `Output::Err`.
pub struct Authorization {
/// Current step of the flow.
state: AuthorizationState,
/// Extension data; set once the `Input::Extended` answer has been received.
extensions: Option<Extensions>,
/// Scope parsed from the request's `scope` parameter (`None` if the parameter was absent).
scope: Option<Scope>,
}
/// Internal step of the `Authorization` state machine.
enum AuthorizationState {
/// Waiting for the client (`client_id` plus the requested redirect uri, if any) to be bound.
Binding {
client_id: String,
redirect_uri: Option<ExactUrl>,
},
/// Client bound; waiting for the extension data.
Extending {
bound_client: BoundClient<'static>,
},
/// Extension data recorded; waiting for the scope negotiation result.
Negotiating {
bound_client: BoundClient<'static>,
},
/// Negotiation done; the request is ready to be decided on.
Pending {
pre_grant: PreGrant,
state: Option<String>,
extensions: Extensions,
},
/// The flow failed; this state is never left again by `advance`.
Err(Error),
}
/// Answer handed to `Authorization::advance` for the machine's previous `Output`.
pub enum Input<'machine> {
/// Answer to `Output::Bind`: the bound client, plus the request whose response type,
/// scope and state are checked next.
Bound {
request: &'machine dyn Request,
bound_client: BoundClient<'static>,
},
/// Answer to `Output::Extend`: the extension data for the grant.
Extended(Extensions),
/// Answer to `Output::Negotiate`: the negotiated pre-grant and the request's `state`.
Negotiated {
pre_grant: PreGrant,
state: Option<String>,
},
/// Not accepted in any state by `advance`: unless the machine already failed, passing it
/// moves the machine into the error state with `Error::PrimitiveError`.
Finished,
/// Changes nothing; `advance` simply repeats the current output.
None,
}
/// What the `Authorization` machine needs next, or its final result.
pub enum Output<'machine> {
/// Bind the client identified by `client_id` (with the redirect uri from the request, if
/// any) through the registrar and answer with `Input::Bound`.
Bind {
client_id: String,
redirect_uri: Option<ExactUrl>,
},
/// Compute the extension data and answer with `Input::Extended`.
Extend,
/// Negotiate the scope for `bound_client` and answer with `Input::Negotiated`. `scope` is
/// the scope parsed from the request, if it had one.
Negotiate {
bound_client: &'machine BoundClient<'static>,
scope: Option<Scope>,
},
/// The request passed every check and is ready to be decided on.
Ok {
pre_grant: PreGrant,
state: Option<String>,
extensions: Extensions,
},
/// The flow failed.
Err(Error),
}
impl Authorization {
pub fn new(request: &dyn Request) -> Self {
Authorization {
state: Self::validate(request).unwrap_or_else(AuthorizationState::Err),
extensions: None,
scope: None,
}
}
pub fn advance<'req>(&mut self, input: Input<'req>) -> Output<'_> {
self.state = match (self.take(), input) {
(current, Input::None) => current,
(
AuthorizationState::Binding { .. },
Input::Bound {
request,
bound_client,
},
) => self
.bound(request, bound_client)
.unwrap_or_else(AuthorizationState::Err),
(AuthorizationState::Extending { bound_client }, Input::Extended(grant_extension)) => {
self.extended(grant_extension, bound_client)
}
(AuthorizationState::Negotiating { .. }, Input::Negotiated { pre_grant, state }) => {
self.negotiated(state, pre_grant)
}
(AuthorizationState::Err(err), _) => AuthorizationState::Err(err),
(_, _) => AuthorizationState::Err(Error::PrimitiveError),
};
self.output()
}
fn output(&self) -> Output<'_> {
match &self.state {
AuthorizationState::Err(err) => Output::Err(err.clone()),
AuthorizationState::Binding {
client_id,
redirect_uri,
} => Output::Bind {
client_id: client_id.to_string(),
redirect_uri: (*redirect_uri).clone(),
},
AuthorizationState::Extending { .. } => Output::Extend,
AuthorizationState::Negotiating { bound_client } => Output::Negotiate {
bound_client: &bound_client,
scope: self.scope.clone(),
},
AuthorizationState::Pending {
pre_grant,
state,
extensions,
} => Output::Ok {
pre_grant: pre_grant.clone(),
state: state.clone(),
extensions: extensions.clone(),
},
}
}
fn bound(
&mut self, request: &dyn Request, bound_client: BoundClient<'static>,
) -> Result<AuthorizationState> {
match request.response_type() {
Some(ref method) if method.as_ref() == "code" => (),
_ => {
let prepared_error = ErrorUrl::with_request(
request,
(*bound_client.redirect_uri).to_url(),
AuthorizationErrorType::UnsupportedResponseType,
);
return Err(Error::Redirect(prepared_error));
}
}
let scope = request.scope();
self.scope = match scope.map(|scope| scope.as_ref().parse()) {
None => None,
Some(Err(_)) => {
let prepared_error = ErrorUrl::with_request(
request,
(*bound_client.redirect_uri).to_url(),
AuthorizationErrorType::InvalidScope,
);
return Err(Error::Redirect(prepared_error));
}
Some(Ok(scope)) => Some(scope),
};
Ok(AuthorizationState::Extending { bound_client })
}
fn extended(
&mut self, grant_extension: Extensions, bound_client: BoundClient<'static>,
) -> AuthorizationState {
self.extensions = Some(grant_extension);
AuthorizationState::Negotiating { bound_client }
}
fn negotiated(&mut self, state: Option<String>, pre_grant: PreGrant) -> AuthorizationState {
AuthorizationState::Pending {
pre_grant,
state,
extensions: self.extensions.clone().expect("Should have extensions by now"),
}
}
fn take(&mut self) -> AuthorizationState {
std::mem::replace(&mut self.state, AuthorizationState::Err(Error::PrimitiveError))
}
fn validate(request: &dyn Request) -> Result<AuthorizationState> {
if !request.valid() {
return Err(Error::Ignore);
};
let client_id = request.client_id().ok_or(Error::Ignore)?;
let redirect_uri: Option<Cow<ExactUrl>> = match request.redirect_uri() {
None => None,
Some(ref uri) => {
let parsed = uri.parse().map_err(|_| Error::Ignore)?;
Some(Cow::Owned(parsed))
}
};
Ok(AuthorizationState::Binding {
client_id: client_id.into_owned(),
redirect_uri: redirect_uri.map(|uri| uri.into_owned()),
})
}
}
pub fn authorization_code(handler: &mut dyn Endpoint, request: &dyn Request) -> self::Result<Pending> {
enum Requested {
None,
Bind {
client_id: String,
redirect_uri: Option<ExactUrl>,
},
Extend,
Negotiate {
client_id: String,
redirect_uri: Url,
scope: Option<Scope>,
},
}
let mut authorization = Authorization::new(request);
let mut requested = Requested::None;
let mut the_redirect_uri = None;
loop {
let input = match requested {
Requested::None => Input::None,
Requested::Bind {
client_id,
redirect_uri,
} => {
let client_url = ClientUrl {
client_id: Cow::Owned(client_id),
redirect_uri: redirect_uri.map(Cow::Owned),
};
let bound_client = match handler.registrar().bound_redirect(client_url) {
Err(RegistrarError::Unspecified) => return Err(Error::Ignore),
Err(RegistrarError::PrimitiveError) => return Err(Error::PrimitiveError),
Ok(pre_grant) => pre_grant,
};
the_redirect_uri = Some(bound_client.redirect_uri.clone().into_owned());
Input::Bound {
request,
bound_client,
}
}
Requested::Extend => {
let grant_extension = match handler.extension().extend(request) {
Ok(extension_data) => extension_data,
Err(()) => {
let prepared_error = ErrorUrl::with_request(
request,
the_redirect_uri.unwrap().into(),
AuthorizationErrorType::InvalidRequest,
);
return Err(Error::Redirect(prepared_error));
}
};
Input::Extended(grant_extension)
}
Requested::Negotiate {
client_id,
redirect_uri,
scope,
} => {
let bound_client = BoundClient {
client_id: Cow::Owned(client_id),
redirect_uri: Cow::Owned(redirect_uri.clone().into()),
};
let pre_grant = handler
.registrar()
.negotiate(bound_client, scope)
.map_err(|err| match err {
RegistrarError::PrimitiveError => Error::PrimitiveError,
RegistrarError::Unspecified => {
let prepared_error = ErrorUrl::with_request(
request,
redirect_uri,
AuthorizationErrorType::InvalidScope,
);
Error::Redirect(prepared_error)
}
})?;
Input::Negotiated {
pre_grant,
state: request.state().map(|s| s.into_owned()),
}
}
};
requested = match authorization.advance(input) {
Output::Bind {
client_id,
redirect_uri,
} => Requested::Bind {
client_id,
redirect_uri,
},
Output::Extend => Requested::Extend,
Output::Negotiate { bound_client, scope } => Requested::Negotiate {
client_id: bound_client.client_id.clone().into_owned(),
redirect_uri: bound_client.redirect_uri.to_url(),
scope,
},
Output::Ok {
pre_grant,
state,
extensions,
} => {
return Ok(Pending {
pre_grant,
state,
extensions,
})
}
Output::Err(e) => return Err(e),
};
}
}
/// An authorization request that passed all checks and awaits the resource owner's decision.
///
/// Produced by `authorization_code`; resolve it with `Pending::authorize` or `Pending::deny`.
pub struct Pending {
/// Client, redirect uri and negotiated scope of the request.
pre_grant: PreGrant,
/// The request's `state` parameter, echoed back in the final redirect.
state: Option<String>,
/// Extension data that will be stored with the grant.
extensions: Extensions,
}
impl Pending {
pub fn as_solicitation(&self) -> Solicitation<'_> {
Solicitation {
grant: Cow::Borrowed(&self.pre_grant),
state: self.state.as_ref().map(|s| Cow::Borrowed(&**s)),
}
}
pub fn deny(self) -> Result<Url> {
let url = self.pre_grant.redirect_uri;
let mut error = AuthorizationError::default();
error.set_type(AuthorizationErrorType::AccessDenied);
let error = ErrorUrl::new_generic(url.into_url(), self.state, error);
Err(Error::Redirect(error))
}
pub fn authorize(self, handler: &mut dyn Endpoint, owner_id: Cow<str>) -> Result<Url> {
let mut url = self.pre_grant.redirect_uri.to_url();
let grant = handler
.authorizer()
.authorize(Grant {
owner_id: owner_id.into_owned(),
client_id: self.pre_grant.client_id,
redirect_uri: self.pre_grant.redirect_uri.into_url(),
scope: self.pre_grant.scope,
until: Utc::now() + Duration::minutes(10),
extensions: self.extensions,
})
.map_err(|()| Error::PrimitiveError)?;
url.query_pairs_mut()
.append_pair("code", grant.as_str())
.extend_pairs(self.state.map(|v| ("state", v)))
.finish();
Ok(url)
}
pub fn pre_grant(&self) -> &PreGrant {
&self.pre_grant
}
}
/// Failure of the authorization code flow.
#[derive(Clone)]
pub enum Error {
/// The request could not be tied to a usable client redirect, so no redirect target is
/// available. It arises when the request is not valid, has no `client_id`, or has an
/// unparsable `redirect_uri`, or when the registrar refuses the binding.
/// NOTE(review): presumably callers answer without redirecting — confirm.
Ignore,
/// The client is known; the contained `ErrorUrl` describes the error redirect.
Redirect(ErrorUrl),
/// A backing primitive (registrar, authorizer) failed, or the state machine was driven with
/// unexpected input.
PrimitiveError,
}
/// An error response destined for a client's redirect uri.
#[derive(Clone)]
pub struct ErrorUrl {
/// The redirect uri, already carrying the `state` parameter when one was given.
base_uri: Url,
/// Error description, appended as query pairs when converted into a `Url`.
error: AuthorizationError,
}
/// Result type of this module's fallible operations, failing with `Error`.
type Result<T> = StdResult<T, Error>;
impl ErrorUrl {
fn new_generic<S>(mut url: Url, state: Option<S>, error: AuthorizationError) -> ErrorUrl
where
S: AsRef<str>,
{
url.query_pairs_mut()
.extend_pairs(state.as_ref().map(|st| ("state", st.as_ref())));
ErrorUrl { base_uri: url, error }
}
pub fn new(url: Url, state: Option<&str>, error: AuthorizationError) -> ErrorUrl {
ErrorUrl::new_generic(url, state, error)
}
pub fn with_request(
request: &dyn Request, redirect_uri: Url, err_type: AuthorizationErrorType,
) -> ErrorUrl {
let mut err = ErrorUrl::new(
redirect_uri,
request.state().as_deref(),
AuthorizationError::default(),
);
err.description().set_type(err_type);
err
}
pub fn description(&mut self) -> &mut AuthorizationError {
&mut self.error
}
}
impl Error {
pub fn description(&mut self) -> Option<&mut AuthorizationError> {
match self {
Error::Ignore => None,
Error::Redirect(inner) => Some(inner.description()),
Error::PrimitiveError => None,
}
}
}
impl Into<Url> for ErrorUrl {
fn into(self) -> Url {
let mut url = self.base_uri;
url.query_pairs_mut().extend_pairs(self.error.into_iter());
url
}
}