oauth2-twitch 0.1.1

OAuth 2.0 Twitch
Documentation
use std::error;

use oauth2_client::{
    authorization_code_grant::provider_ext::{
        AccessTokenResponseErrorBody, AccessTokenResponseSuccessfulBody,
    },
    oauth2_core::re_exports::AccessTokenResponseErrorBodyError,
    re_exports::{
        serde_json, thiserror, Body, ClientId, ClientSecret, Deserialize, Map, RedirectUri,
        Response, SerdeJsonError, Serialize, Url, UrlParseError, Value,
    },
    Provider, ProviderExtAuthorizationCodeGrant,
};

use crate::{TwitchScope, AUTHORIZATION_URL, TOKEN_URL};

/// Twitch OAuth 2.0 provider configuration used with the authorization code
/// grant: client credentials, redirect URI and the two endpoint URLs.
#[derive(Debug, Clone)]
pub struct TwitchProviderForWebServerApps {
    client_id: ClientId,
    client_secret: ClientSecret,
    redirect_uri: RedirectUri,
    /// When `Some(true)`, `force_verify=true` is appended to the authorization
    /// request query; `None` and `Some(false)` add nothing.
    pub force_verify: Option<bool>,
    // Endpoint URLs, parsed once in `new` from `TOKEN_URL` / `AUTHORIZATION_URL`.
    token_endpoint_url: Url,
    authorization_endpoint_url: Url,
}

impl TwitchProviderForWebServerApps {
    pub fn new(
        client_id: ClientId,
        client_secret: ClientSecret,
        redirect_uri: RedirectUri,
    ) -> Result<Self, UrlParseError> {
        Ok(Self {
            client_id,
            client_secret,
            redirect_uri,
            force_verify: None,
            token_endpoint_url: TOKEN_URL.parse()?,
            authorization_endpoint_url: AUTHORIZATION_URL.parse()?,
        })
    }

    pub fn configure<F>(mut self, mut f: F) -> Self
    where
        F: FnMut(&mut Self),
    {
        f(&mut self);
        self
    }
}
/// Exposes the stored credentials and token endpoint through the generic
/// `Provider` trait, with `TwitchScope` as the scope type.
impl Provider for TwitchProviderForWebServerApps {
    type Scope = TwitchScope;

    /// Returns the client id given to `new`; always `Some` for this provider.
    fn client_id(&self) -> Option<&ClientId> {
        Some(&self.client_id)
    }

    /// Returns the client secret given to `new`; always `Some` for this provider.
    fn client_secret(&self) -> Option<&ClientSecret> {
        Some(&self.client_secret)
    }

    /// Returns the token endpoint URL parsed from `TOKEN_URL` in `new`.
    fn token_endpoint_url(&self) -> &Url {
        &self.token_endpoint_url
    }
}
impl ProviderExtAuthorizationCodeGrant for TwitchProviderForWebServerApps {
    fn redirect_uri(&self) -> Option<&RedirectUri> {
        Some(&self.redirect_uri)
    }

    fn scopes_default(&self) -> Option<Vec<<Self as Provider>::Scope>> {
        Some(vec![TwitchScope::UserReadEmail])
    }

    fn authorization_endpoint_url(&self) -> &Url {
        &self.authorization_endpoint_url
    }

    fn authorization_request_query_extra(&self) -> Option<Map<String, Value>> {
        let mut map = Map::new();

        if let Some(force_verify) = &self.force_verify {
            if *force_verify {
                map.insert("force_verify".to_owned(), Value::String(true.to_string()));
            }
        }

        if map.is_empty() {
            None
        } else {
            Some(map)
        }
    }

    #[allow(clippy::type_complexity)]
    fn access_token_response_parsing(
        &self,
        response: &Response<Body>,
    ) -> Option<
        Result<
            Result<
                AccessTokenResponseSuccessfulBody<<Self as Provider>::Scope>,
                AccessTokenResponseErrorBody,
            >,
            Box<dyn error::Error + Send + Sync + 'static>,
        >,
    > {
        fn doing(
            response: &Response<Body>,
        ) -> Result<
            Result<
                AccessTokenResponseSuccessfulBody<
                    <TwitchProviderForWebServerApps as Provider>::Scope,
                >,
                TwitchAccessTokenResponseErrorBody,
            >,
            Box<dyn error::Error + Send + Sync + 'static>,
        > {
            if response.status().is_success() {
                let map = serde_json::from_slice::<Map<String, Value>>(response.body())
                    .map_err(AccessTokenResponseParsingError::DeResponseBodyFailed)?;
                if !map.contains_key("errcode") {
                    let body = serde_json::from_slice::<
                        AccessTokenResponseSuccessfulBody<
                            <TwitchProviderForWebServerApps as Provider>::Scope,
                        >,
                    >(response.body())
                    .map_err(AccessTokenResponseParsingError::DeResponseBodyFailed)?;

                    return Ok(Ok(body));
                }
            }

            let body =
                serde_json::from_slice::<TwitchAccessTokenResponseErrorBody>(response.body())
                    .map_err(AccessTokenResponseParsingError::DeResponseBodyFailed)?;
            Ok(Err(body))
        }

        Some(doing(response).map(|ret| ret.map(Into::into).map_err(Into::into)))
    }
}

/// Error payload deserialized from a token response that is not treated as
/// a successful token body, e.g.
/// `{"status":400, "message":"Invalid authorization code"}`.
///
/// Derives `Debug` and `Clone` so that callers can log or duplicate the
/// value, matching the other public types in this file.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TwitchAccessTokenResponseErrorBody {
    /// Numeric `status` field of the error JSON.
    pub status: usize,
    /// Text of the `message` field of the error JSON.
    pub message: String,
}
impl From<TwitchAccessTokenResponseErrorBody> for AccessTokenResponseErrorBody {
    fn from(body: TwitchAccessTokenResponseErrorBody) -> Self {
        let error = if body.message.to_ascii_lowercase().contains("invalid") {
            AccessTokenResponseErrorBodyError::InvalidRequest
        } else {
            AccessTokenResponseErrorBodyError::Other(body.status.to_string())
        };

        Self::new(error, Some(body.message), None)
    }
}

/// Errors raised while parsing an access token response.
#[derive(thiserror::Error, Debug)]
pub enum AccessTokenResponseParsingError {
    /// The response body could not be deserialized from JSON; wraps the
    /// underlying `SerdeJsonError`.
    #[error("DeResponseBodyFailed {0}")]
    DeResponseBodyFailed(SerdeJsonError),
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::error;

    use oauth2_client::{
        authorization_code_grant::{AccessTokenEndpoint, AuthorizationEndpoint},
        re_exports::{http::StatusCode, Endpoint as _},
    };

    /// Builds the provider fixture shared by every test in this module.
    fn make_provider() -> Result<TwitchProviderForWebServerApps, Box<dyn error::Error>> {
        let redirect_uri = RedirectUri::new("https://client.example.com/cb")?;
        let provider = TwitchProviderForWebServerApps::new(
            "CLIENT_ID".to_owned(),
            "CLIENT_SECRET".to_owned(),
            redirect_uri,
        )?;
        Ok(provider)
    }

    /// The rendered authorization URI carries the standard parameters plus
    /// `force_verify=true` when that option is enabled.
    #[test]
    fn authorization_request() -> Result<(), Box<dyn error::Error>> {
        let provider = make_provider()?.configure(|x| {
            x.force_verify = Some(true);
        });

        let rendered = AuthorizationEndpoint::new(&provider, vec![TwitchScope::UserReadEmail])
            .configure(|x| x.state = Some("STATE".to_owned()))
            .render_request()?;

        assert_eq!(rendered.uri(), "https://id.twitch.tv/oauth2/authorize?response_type=code&client_id=CLIENT_ID&redirect_uri=https%3A%2F%2Fclient.example.com%2Fcb&scope=user%3Aread%3Aemail&state=STATE&force_verify=true");

        Ok(())
    }

    /// The token request body is the form-encoded authorization code grant.
    #[test]
    fn access_token_request() -> Result<(), Box<dyn error::Error>> {
        let provider = make_provider()?;

        let rendered = AccessTokenEndpoint::new(&provider, "CODE".to_owned()).render_request()?;

        assert_eq!(rendered.body(), b"grant_type=authorization_code&code=CODE&redirect_uri=https%3A%2F%2Fclient.example.com%2Fcb&client_id=CLIENT_ID&client_secret=CLIENT_SECRET");

        Ok(())
    }

    /// Covers both parsing paths: a 400 error payload and a successful
    /// payload loaded from the fixture file.
    #[test]
    fn access_token_response() -> Result<(), Box<dyn error::Error>> {
        let provider = make_provider()?;

        // Error payload returned with a 400 status.
        let error_json = r#"{"status":400, "message":"Invalid authorization code"}"#;
        let error_response = Response::builder()
            .status(StatusCode::BAD_REQUEST)
            .body(error_json.as_bytes().to_vec())?;
        let error_ret =
            AccessTokenEndpoint::new(&provider, "CODE".to_owned()).parse_response(error_response)?;

        let error_body = match error_ret {
            Err(body) => body,
            Ok(body) => panic!("{:?}", body),
        };
        assert_eq!(
            error_body.error,
            AccessTokenResponseErrorBodyError::InvalidRequest
        );
        let description = error_body.error_description.unwrap();
        assert!(description.contains("Invalid authorization code"));

        // Successful payload taken from the JSON fixture.
        let success_json = include_str!(
            "../tests/response_body_json_files/access_token_with_authorization_code_grant.json"
        );
        let success_response = Response::builder().body(success_json.as_bytes().to_vec())?;
        let success_ret = AccessTokenEndpoint::new(&provider, "CODE".to_owned())
            .parse_response(success_response)?;

        let success_body = match success_ret {
            Ok(body) => body,
            Err(body) => panic!("{:?}", body),
        };
        assert!(success_body.extra().unwrap().is_empty());

        Ok(())
    }
}