use std::{borrow::Cow, marker::PhantomData};
use oxide_auth::{
endpoint::{WebResponse, QueryParameter, NormalizedParameter},
code_grant::authorization::{Error as AuthorizationError, Request as AuthorizationRequest},
};
use crate::code_grant::authorization::{
authorization_code, Endpoint as AuthorizationEndpoint, Extension, Pending,
};
use super::*;
use url::Url;
/// Drives the authorization (code grant) flow for an endpoint.
///
/// Construct it with `AuthorizationFlow::prepare`, then run `execute` once per request.
pub struct AuthorizationFlow<E: Endpoint<R>, R: WebRequest> {
    /// The endpoint, wrapped so it can act as the code grant's authorization endpoint.
    endpoint: WrappedAuthorization<E, R>,
}
/// Adapter that exposes a generic `Endpoint` as the code grant's authorization endpoint.
struct WrappedAuthorization<E, R>
where
    E: Endpoint<R>,
    R: WebRequest,
{
    /// The user-provided endpoint.
    inner: E,
    /// Extension handed out when the endpoint provides no authorization extension.
    extension_fallback: (),
    /// Ties the request type `R` to this adapter without storing one.
    r_type: PhantomData<R>,
}
/// Snapshot of a web request's query, as consumed by the authorization code grant.
///
/// If the query could not be read, `error` holds the failure and `query` is left at
/// its default value.
#[derive(Clone)]
pub struct WrappedRequest<R: WebRequest> {
    /// Owned copy of the request's query parameters.
    query: NormalizedParameter,
    /// The error raised while reading the query, if any.
    error: Option<R::Error>,
}
/// A request whose negotiation succeeded and which still needs the owner's consent.
struct AuthorizationPending<'a, E, R>
where
    E: Endpoint<R> + Send + 'a,
    R: WebRequest + 'a,
{
    /// The wrapped endpoint, borrowed from the owning `AuthorizationFlow`.
    endpoint: &'a mut WrappedAuthorization<E, R>,
    /// The negotiated, not yet decided, grant.
    pending: Pending,
    /// The original request, returned to the caller once the flow finishes.
    request: R,
}
/// The state of an authorization request after negotiation, ready to be finished.
struct AuthorizationPartial<'a, E, R>
where
    E: Endpoint<R> + Send + 'a,
    R: WebRequest + 'a,
{
    /// Which stage negotiation left the request in.
    inner: AuthorizationPartialInner<'a, E, R>,
    // NOTE(review): `execute` only ever stores `None` here and nothing reads it;
    // looks like an unused hook — confirm before removing.
    _with_request: Option<Box<dyn FnOnce(R) + Send>>,
}
/// Possible outcomes of negotiating an authorization request.
enum AuthorizationPartialInner<'a, E, R>
where
    E: Endpoint<R> + Send + 'a,
    R: WebRequest + 'a,
{
    /// Negotiation succeeded; owner consent still has to be checked.
    Pending { pending: AuthorizationPending<'a, E, R> },
    /// Negotiation failed and an error response was produced for it.
    Failed { request: R, response: R::Response },
    /// Negotiation failed and no response could be produced; the endpoint error is kept.
    Error { request: R, error: E::Error },
}
impl<E, R> AuthorizationFlow<E, R>
where
    E: Endpoint<R> + Send + Sync,
    R: WebRequest + Send + Sync,
    <R as WebRequest>::Error: Send + Sync,
{
    /// Check that `endpoint` supports the authorization flow and wrap it.
    ///
    /// The wrapped endpoint unwraps its registrar, authorizer and owner solicitor when
    /// they are needed, so all three must be present.
    ///
    /// # Errors
    ///
    /// Returns `endpoint.error(OAuthError::PrimitiveError)` if the endpoint provides no
    /// registrar, no authorizer, or no owner solicitor. Checking the solicitor here
    /// reports a misconfigured endpoint at construction time. Otherwise it would panic
    /// on the first request whose negotiation succeeds.
    pub fn prepare(mut endpoint: E) -> Result<Self, E::Error> {
        // Short-circuits in the same order as the individual checks would run.
        let supported = endpoint.registrar().is_some()
            && endpoint.authorizer_mut().is_some()
            && endpoint.owner_solicitor().is_some();
        if !supported {
            return Err(endpoint.error(OAuthError::PrimitiveError));
        }

        Ok(AuthorizationFlow {
            endpoint: WrappedAuthorization {
                inner: endpoint,
                extension_fallback: (),
                r_type: PhantomData,
            },
        })
    }

    /// Process one authorization request and produce the response to send back.
    ///
    /// First the request is negotiated. On failure, the error is turned into a response
    /// (for example an error redirect). On success, the owner solicitor is asked for
    /// consent, and the grant is authorized, denied, or left in progress with the
    /// solicitor's response.
    ///
    /// # Errors
    ///
    /// Returns the endpoint's error when no response can be produced for the request.
    pub async fn execute(&mut self, mut request: R) -> Result<R::Response, E::Error> {
        let negotiated =
            authorization_code(&mut self.endpoint, &WrappedRequest::new(&mut request)).await;

        let inner = match negotiated {
            Err(err) => match authorization_error(&mut self.endpoint.inner, &mut request, err) {
                Ok(response) => AuthorizationPartialInner::Failed { request, response },
                Err(error) => AuthorizationPartialInner::Error { request, error },
            },
            Ok(negotiated) => AuthorizationPartialInner::Pending {
                pending: AuthorizationPending {
                    endpoint: &mut self.endpoint,
                    pending: negotiated,
                    request,
                },
            },
        };

        let partial = AuthorizationPartial {
            inner,
            _with_request: None,
        };

        partial.finish().await
    }
}
impl<'a, E, R> AuthorizationPartial<'a, E, R>
where
E: Endpoint<R> + Send,
R: WebRequest + Send,
{
pub async fn finish(self) -> Result<R::Response, E::Error> {
let (_request, result) = match self.inner {
AuthorizationPartialInner::Pending { pending } => pending.finish().await,
AuthorizationPartialInner::Failed { request, response } => (request, Ok(response)),
AuthorizationPartialInner::Error { request, error } => (request, Err(error)),
};
result
}
}
fn authorization_error<E, R>(
endpoint: &mut E, request: &mut R, error: AuthorizationError,
) -> Result<R::Response, E::Error>
where
E: Endpoint<R>,
R: WebRequest,
{
match error {
AuthorizationError::Ignore => Err(endpoint.error(OAuthError::DenySilently)),
AuthorizationError::Redirect(mut target) => {
let mut response =
endpoint.response(request, Template::new_redirect(Some(target.description())))?;
response
.redirect(target.into())
.map_err(|err| endpoint.web_error(err))?;
Ok(response)
}
AuthorizationError::PrimitiveError => Err(endpoint.error(OAuthError::PrimitiveError)),
}
}
impl<'a, E, R> AuthorizationPending<'a, E, R>
where
E: Endpoint<R> + Send,
R: WebRequest + Send,
{
async fn finish(mut self) -> (R, Result<R::Response, E::Error>) {
let checked = self
.endpoint
.owner_solicitor()
.check_consent(&mut self.request, self.pending.as_solicitation())
.await;
match checked {
OwnerConsent::Denied => self.deny(),
OwnerConsent::InProgress(resp) => self.in_progress(resp),
OwnerConsent::Authorized(who) => self.authorize(who).await,
OwnerConsent::Error(err) => (self.request, Err(self.endpoint.inner.web_error(err))),
}
}
fn in_progress(self, response: R::Response) -> (R, Result<R::Response, E::Error>) {
(self.request, Ok(response))
}
fn deny(mut self) -> (R, Result<R::Response, E::Error>) {
let result = self.pending.deny();
let result = Self::convert_result(result, &mut self.endpoint.inner, &mut self.request);
(self.request, result)
}
async fn authorize(mut self, who: String) -> (R, Result<R::Response, E::Error>) {
let result = self.pending.authorize(self.endpoint, who.into()).await;
let result = Self::convert_result(result, &mut self.endpoint.inner, &mut self.request);
(self.request, result)
}
fn convert_result(
result: Result<Url, AuthorizationError>, endpoint: &mut E, request: &mut R,
) -> Result<R::Response, E::Error> {
match result {
Ok(url) => {
let mut response = endpoint.response(request, Template::new_redirect(None))?;
response.redirect(url).map_err(|err| endpoint.web_error(err))?;
Ok(response)
}
Err(err) => authorization_error(endpoint, request, err),
}
}
}
impl<E, R> WrappedAuthorization<E, R>
where
    E: Endpoint<R>,
    R: WebRequest,
{
    /// The owner solicitor of the wrapped endpoint, used to ask for consent.
    ///
    /// # Panics
    ///
    /// Panics if the wrapped endpoint does not provide an owner solicitor.
    fn owner_solicitor(&mut self) -> &mut (dyn OwnerSolicitor<R> + Send) {
        self.inner
            .owner_solicitor()
            .expect("authorization flow requires the endpoint to provide an owner solicitor")
    }
}
impl<E, R> AuthorizationEndpoint for WrappedAuthorization<E, R>
where
    E: Endpoint<R>,
    R: WebRequest,
{
    /// The wrapped endpoint's registrar.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint stops providing a registrar after `AuthorizationFlow::prepare`
    /// verified that it has one.
    fn registrar(&self) -> &(dyn Registrar + Sync) {
        self.inner
            .registrar()
            .expect("registrar presence is checked by AuthorizationFlow::prepare")
    }

    /// The wrapped endpoint's authorizer.
    ///
    /// # Panics
    ///
    /// Panics if the endpoint stops providing an authorizer after `AuthorizationFlow::prepare`
    /// verified that it has one.
    fn authorizer(&mut self) -> &mut (dyn Authorizer + Send) {
        self.inner
            .authorizer_mut()
            .expect("authorizer presence is checked by AuthorizationFlow::prepare")
    }

    /// The endpoint's authorization extension, or `()` when it provides none.
    fn extension(&mut self) -> &mut (dyn Extension + Send) {
        self.inner
            .extension()
            .and_then(super::Extension::authorization)
            .unwrap_or(&mut self.extension_fallback)
    }
}
impl<R> WrappedRequest<R>
where
    R: WebRequest,
{
    /// Capture the query of `request` for the authorization code grant.
    ///
    /// Never fails. If the query cannot be read, the error is stored and the wrapper
    /// reports itself as invalid through `AuthorizationRequest::valid`.
    pub fn new(request: &mut R) -> Self {
        Self::new_or_fail(request).unwrap_or_else(Self::from_err)
    }

    /// Read the request's query and take an owned copy of it.
    fn new_or_fail(request: &mut R) -> Result<Self, R::Error> {
        Ok(WrappedRequest {
            query: request.query()?.into_owned(),
            error: None,
        })
    }

    /// Build a wrapper that carries `err` and a default-constructed query.
    fn from_err(err: R::Error) -> Self {
        WrappedRequest {
            query: Default::default(),
            error: Some(err),
        }
    }
}
impl<R> WrappedRequest<R>
where
    R: WebRequest,
{
    /// Look up `key` in the captured query via `unique_value`.
    fn query_value(&self, key: &str) -> Option<Cow<str>> {
        self.query.unique_value(key)
    }
}

impl<R> AuthorizationRequest for WrappedRequest<R>
where
    R: WebRequest,
{
    /// The request is usable only if its query could be read.
    fn valid(&self) -> bool {
        self.error.is_none()
    }

    fn client_id(&self) -> Option<Cow<str>> {
        self.query_value("client_id")
    }

    fn scope(&self) -> Option<Cow<str>> {
        self.query_value("scope")
    }

    fn redirect_uri(&self) -> Option<Cow<str>> {
        self.query_value("redirect_uri")
    }

    fn state(&self) -> Option<Cow<str>> {
        self.query_value("state")
    }

    fn response_type(&self) -> Option<Cow<str>> {
        self.query_value("response_type")
    }

    /// Extension parameters are read from the same query as the standard ones.
    fn extension(&self, key: &str) -> Option<Cow<str>> {
        self.query_value(key)
    }
}