oauth2-client 0.2.0

OAuth 2.0 Client
Documentation
use http_api_client_endpoint::{Body, Endpoint, Request, Response};
use oauth2_core::{
    access_token_request::{
        Body as REQ_Body, BodyWithResourceOwnerPasswordCredentialsGrant,
        CONTENT_TYPE as REQ_CONTENT_TYPE, METHOD as REQ_METHOD,
    },
    access_token_response::{CONTENT_TYPE as RES_CONTENT_TYPE, GENERAL_ERROR_BODY_KEY_ERROR},
    http::{
        header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
        Error as HttpError,
    },
    resource_owner_password_credentials_grant::access_token_response::{
        ErrorBody as RES_ErrorBody, SuccessfulBody as RES_SuccessfulBody,
    },
    serde::{de::DeserializeOwned, Serialize},
    types::{ClientPassword, Scope},
};
use serde_json::{Error as SerdeJsonError, Map, Value};
use serde_urlencoded::ser::Error as SerdeUrlencodedSerError;

use crate::ProviderExtResourceOwnerPasswordCredentialsGrant;

//
//
//
/// Endpoint description for an access token request that carries a resource
/// owner's username and password.
///
/// The endpoint only borrows its provider, so it cannot outlive it.
#[derive(Clone)]
pub struct AccessTokenEndpoint<'a, SCOPE: Scope> {
    /// Provider that supplies the client credentials, the token endpoint URL
    /// and the request-shaping options.
    provider:
        &'a (dyn ProviderExtResourceOwnerPasswordCredentialsGrant<Scope = SCOPE> + Send + Sync),
    /// Scopes to put into the request body; `None` when the caller gave none.
    scopes: Option<Vec<SCOPE>>,
    /// Resource owner username placed in the request body.
    username: String,
    /// Resource owner password placed in the request body.
    password: String,
}
impl<'a, SCOPE: Scope> AccessTokenEndpoint<'a, SCOPE> {
    /// Creates an endpoint bound to `provider`.
    ///
    /// * `provider` - borrowed for the lifetime of the endpoint.
    /// * `scopes` - anything convertible into `Option<Vec<SCOPE>>`; stored as is.
    /// * `username` / `password` - copied into owned `String`s.
    pub fn new(
        provider: &'a (dyn ProviderExtResourceOwnerPasswordCredentialsGrant<Scope = SCOPE>
                 + Send
                 + Sync),
        scopes: impl Into<Option<Vec<SCOPE>>>,
        username: impl AsRef<str>,
        password: impl AsRef<str>,
    ) -> Self {
        let scopes = scopes.into();
        let username = String::from(username.as_ref());
        let password = String::from(password.as_ref());

        Self {
            provider,
            scopes,
            username,
            password,
        }
    }
}

impl<'a, SCOPE> Endpoint for AccessTokenEndpoint<'a, SCOPE>
where
    SCOPE: Scope + Serialize + DeserializeOwned,
{
    type RenderRequestError = AccessTokenEndpointError;

    type ParseResponseOutput = Result<RES_SuccessfulBody<SCOPE>, RES_ErrorBody>;
    type ParseResponseError = AccessTokenEndpointError;

    fn render_request(&self) -> Result<Request<Body>, Self::RenderRequestError> {
        let client_password = ClientPassword::new(
            self.provider
                .client_id()
                .cloned()
                .ok_or(AccessTokenEndpointError::ClientIdMissing)?,
            self.provider
                .client_secret()
                .cloned()
                .ok_or(AccessTokenEndpointError::ClientSecretMissing)?,
        );

        let mut body = if self.provider.client_password_in_request_body() {
            BodyWithResourceOwnerPasswordCredentialsGrant::new_with_client_password(
                &self.username,
                &self.password,
                self.scopes.to_owned().map(Into::into),
                client_password.to_owned(),
            )
        } else {
            BodyWithResourceOwnerPasswordCredentialsGrant::new(
                &self.username,
                &self.password,
                self.scopes.to_owned().map(Into::into),
            )
        };
        if let Some(extra_ret) = self.provider.access_token_request_body_extra(&body) {
            match extra_ret {
                Ok(extra) => {
                    body.set_extra(extra);
                }
                Err(err) => {
                    return Err(AccessTokenEndpointError::MakeRequestBodyExtraFailed(err));
                }
            }
        }

        //
        let body = REQ_Body::<SCOPE>::ResourceOwnerPasswordCredentialsGrant(body);

        let body_str = serde_urlencoded::to_string(body)
            .map_err(AccessTokenEndpointError::SerRequestBodyFailed)?;

        let mut request = Request::builder()
            .method(REQ_METHOD)
            .uri(self.provider.token_endpoint_url().as_str())
            .header(CONTENT_TYPE, REQ_CONTENT_TYPE.to_string())
            .header(ACCEPT, RES_CONTENT_TYPE.to_string());
        if !self.provider.client_password_in_request_body() {
            request = request.header(AUTHORIZATION, client_password.header_authorization());
        }
        let request = request
            .body(body_str.as_bytes().to_vec())
            .map_err(AccessTokenEndpointError::MakeRequestFailed)?;

        Ok(request)
    }

    fn parse_response(
        &self,
        response: Response<Body>,
    ) -> Result<Self::ParseResponseOutput, Self::ParseResponseError> {
        //
        if response.status().is_success() {
            let map = serde_json::from_slice::<Map<String, Value>>(response.body())
                .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
            if !map.contains_key(GENERAL_ERROR_BODY_KEY_ERROR) {
                let body = serde_json::from_slice::<RES_SuccessfulBody<SCOPE>>(response.body())
                    .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;

                return Ok(Ok(body));
            }
        }

        let body = serde_json::from_slice::<RES_ErrorBody>(response.body())
            .map_err(AccessTokenEndpointError::DeResponseBodyFailed)?;
        Ok(Err(body))
    }
}

#[derive(thiserror::Error, Debug)]
pub enum AccessTokenEndpointError {
    #[error("ClientIdMissing")]
    ClientIdMissing,
    #[error("ClientSecretMissing")]
    ClientSecretMissing,
    //
    #[error("MakeRequestBodyExtraFailed {0}")]
    MakeRequestBodyExtraFailed(Box<dyn std::error::Error + Send + Sync>),
    //
    #[error("SerRequestBodyFailed {0}")]
    SerRequestBodyFailed(SerdeUrlencodedSerError),
    #[error("MakeRequestFailed {0}")]
    MakeRequestFailed(HttpError),
    //
    #[error("DeResponseBodyFailed {0}")]
    DeResponseBodyFailed(SerdeJsonError),
}