twitter-v2 0.1.8

Twitter API v2
Documentation
use super::Authorization;
use crate::error::{Error, Result};
use async_trait::async_trait;
use oauth2::basic::{BasicClient, BasicRequestTokenError, BasicTokenResponse};
use oauth2::{
    AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
    PkceCodeVerifier, RedirectUrl, RefreshToken, RevocationUrl, StandardRevocableToken,
    TokenResponse, TokenUrl,
};
use reqwest::header::HeaderValue;
use reqwest::Request;
use serde::{Deserialize, Serialize};
use std::convert::{TryFrom, TryInto};
use std::future::Future;
use std::sync::Arc;
use strum::{Display, EnumString};
use time::OffsetDateTime;
use tokio::sync::{RwLock, RwLockReadGuard};
use url::Url;

/// OAuth 2.0 scopes that can be requested for, and are reported on, a token.
///
/// Every variant maps to a dotted wire name (for example `tweet.read`). The
/// same name is used by `Display`/`FromStr` (via strum) and by serde.
///
/// `PartialEq`, `Eq` and `Hash` are derived so scopes can be compared, looked
/// up with `slice::contains`, and stored in hash-based collections.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, EnumString, Display, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum Scope {
    /// The `tweet.read` scope.
    #[strum(serialize = "tweet.read")]
    #[serde(rename = "tweet.read")]
    TweetRead,
    /// The `tweet.write` scope.
    #[strum(serialize = "tweet.write")]
    #[serde(rename = "tweet.write")]
    TweetWrite,
    /// The `tweet.moderate.write` scope.
    #[strum(serialize = "tweet.moderate.write")]
    #[serde(rename = "tweet.moderate.write")]
    TweetModerateWrite,
    /// The `users.read` scope.
    #[strum(serialize = "users.read")]
    #[serde(rename = "users.read")]
    UsersRead,
    /// The `follows.read` scope.
    #[strum(serialize = "follows.read")]
    #[serde(rename = "follows.read")]
    FollowsRead,
    /// The `follows.write` scope.
    #[strum(serialize = "follows.write")]
    #[serde(rename = "follows.write")]
    FollowsWrite,
    /// The `offline.access` scope.
    ///
    /// NOTE(review): per Twitter's OAuth 2.0 documentation this scope is what
    /// causes a refresh token to be issued — confirm against the current docs.
    #[strum(serialize = "offline.access")]
    #[serde(rename = "offline.access")]
    OfflineAccess,
    /// The `space.read` scope.
    #[strum(serialize = "space.read")]
    #[serde(rename = "space.read")]
    SpaceRead,
    /// The `mute.read` scope.
    #[strum(serialize = "mute.read")]
    #[serde(rename = "mute.read")]
    MuteRead,
    /// The `mute.write` scope.
    #[strum(serialize = "mute.write")]
    #[serde(rename = "mute.write")]
    MuteWrite,
    /// The `like.read` scope.
    #[strum(serialize = "like.read")]
    #[serde(rename = "like.read")]
    LikeRead,
    /// The `like.write` scope.
    #[strum(serialize = "like.write")]
    #[serde(rename = "like.write")]
    LikeWrite,
    /// The `list.read` scope.
    #[strum(serialize = "list.read")]
    #[serde(rename = "list.read")]
    ListRead,
    /// The `list.write` scope.
    #[strum(serialize = "list.write")]
    #[serde(rename = "list.write")]
    ListWrite,
    /// The `block.read` scope.
    #[strum(serialize = "block.read")]
    #[serde(rename = "block.read")]
    BlockRead,
    /// The `block.write` scope.
    #[strum(serialize = "block.write")]
    #[serde(rename = "block.write")]
    BlockWrite,
    /// The `bookmark.read` scope.
    #[strum(serialize = "bookmark.read")]
    #[serde(rename = "bookmark.read")]
    BookmarkRead,
    /// The `bookmark.write` scope.
    #[strum(serialize = "bookmark.write")]
    #[serde(rename = "bookmark.write")]
    BookmarkWrite,
}

/// Converts a Twitter [`Scope`] into the generic `oauth2` scope type, using
/// the scope's `Display` text as the scope string.
impl From<Scope> for oauth2::Scope {
    fn from(scope: Scope) -> Self {
        let wire_name = scope.to_string();
        Self::new(wire_name)
    }
}

#[derive(Clone, Debug)]
pub struct Oauth2Client(BasicClient);

impl Oauth2Client {
    pub fn new(client_id: impl ToString, client_secret: impl ToString, callback_url: Url) -> Self {
        Self(
            BasicClient::new(
                ClientId::new(client_id.to_string()),
                Some(ClientSecret::new(client_secret.to_string())),
                AuthUrl::from_url("https://twitter.com/i/oauth2/authorize".parse().unwrap()),
                Some(TokenUrl::from_url(
                    "https://api.twitter.com/2/oauth2/token".parse().unwrap(),
                )),
            )
            .set_revocation_uri(RevocationUrl::from_url(
                "https://api.twitter.com/2/oauth2/revoke".parse().unwrap(),
            ))
            .set_redirect_uri(RedirectUrl::from_url(callback_url)),
        )
    }

    pub fn auth_url(
        &self,
        challenge: PkceCodeChallenge,
        scopes: impl IntoIterator<Item = Scope>,
    ) -> (Url, CsrfToken) {
        self.0
            .authorize_url(CsrfToken::new_random)
            .set_pkce_challenge(challenge)
            .add_scopes(scopes.into_iter().map(|s| s.into()))
            .url()
    }

    pub async fn request_token(
        &self,
        code: AuthorizationCode,
        verifier: PkceCodeVerifier,
    ) -> Result<Oauth2Token> {
        let res = self
            .0
            .exchange_code(code)
            .set_pkce_verifier(verifier)
            .request_async(oauth2::reqwest::async_http_client)
            .await?;
        res.try_into()
    }

    pub async fn revoke_token(&self, token: StandardRevocableToken) -> Result<()> {
        Ok(self
            .0
            .revoke_token(token)
            .unwrap()
            .request_async(oauth2::reqwest::async_http_client)
            .await?)
    }

    pub async fn refresh_token(&self, token: &RefreshToken) -> Result<Oauth2Token> {
        self.0
            .exchange_refresh_token(token)
            .request_async(oauth2::reqwest::async_http_client)
            .await?
            .try_into()
    }

    pub async fn refresh_token_if_expired(&self, token: &mut Oauth2Token) -> Result<bool> {
        if token.is_expired() {
            if let Some(refresh_token) = token.refresh_token() {
                *token = self.refresh_token(refresh_token).await?;
                Ok(true)
            } else {
                Err(Error::NoRefreshToken)
            }
        } else {
            Ok(false)
        }
    }
}

/// An OAuth 2.0 token: the access token, an optional refresh token, the
/// expiry instant and the scopes that came with it.
///
/// The expiry is (de)serialized as an RFC 3339 timestamp.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Oauth2Token {
    /// Credential sent as the bearer value of the authorization header.
    access_token: AccessToken,
    /// Token used to obtain a replacement, when one was issued.
    refresh_token: Option<RefreshToken>,
    /// Instant after which the access token counts as expired.
    #[serde(with = "time::serde::rfc3339")]
    expires: OffsetDateTime,
    /// Scopes reported for this token.
    scopes: Vec<Scope>,
}

impl Oauth2Token {
    /// The access token.
    pub fn access_token(&self) -> &AccessToken {
        &self.access_token
    }

    /// The refresh token, if there is one.
    pub fn refresh_token(&self) -> Option<&RefreshToken> {
        self.refresh_token.as_ref()
    }

    /// The instant at which the access token expires.
    pub fn expires(&self) -> OffsetDateTime {
        self.expires
    }

    /// Whether the expiry instant lies strictly before the current UTC time.
    pub fn is_expired(&self) -> bool {
        let now = OffsetDateTime::now_utc();
        now > self.expires
    }

    /// The scopes of this token.
    pub fn scopes(&self) -> &[Scope] {
        self.scopes.as_slice()
    }

    /// The token to hand to the revocation endpoint: the refresh token when
    /// present, otherwise the access token.
    pub fn revokable_token(&self) -> StandardRevocableToken {
        match &self.refresh_token {
            Some(refresh_token) => StandardRevocableToken::RefreshToken(refresh_token.clone()),
            None => StandardRevocableToken::AccessToken(self.access_token.clone()),
        }
    }
}

impl TryFrom<BasicTokenResponse> for Oauth2Token {
    type Error = Error;
    fn try_from(token: BasicTokenResponse) -> Result<Self, Self::Error> {
        Ok(Self {
            access_token: token.access_token().clone(),
            refresh_token: token.refresh_token().cloned(),
            expires: OffsetDateTime::now_utc()
                + token.expires_in().ok_or_else(|| {
                    Error::Oauth2TokenError(BasicRequestTokenError::Other(
                        "Missing expiration".to_string(),
                    ))
                })?,
            scopes: token
                .scopes()
                .ok_or_else(|| {
                    Error::Oauth2TokenError(BasicRequestTokenError::Other(
                        "Missing scopes".to_string(),
                    ))
                })?
                .iter()
                .map(|s| {
                    s.parse().map_err(|err| {
                        Error::Oauth2TokenError(BasicRequestTokenError::Other(format!(
                            "Invalid scope: {}",
                            err
                        )))
                    })
                })
                .collect::<Result<Vec<_>>>()?,
        })
    }
}

#[async_trait]
impl Authorization for Oauth2Token {
    /// Produces the `Bearer <access token>` authorization header value.
    ///
    /// The request itself is not inspected. Fails with
    /// `Error::InvalidAuthorizationHeader` if the formatted text is not a
    /// valid header value.
    async fn header(&self, _request: &Request) -> Result<HeaderValue> {
        let bearer = format!("Bearer {}", self.access_token().secret());
        bearer
            .parse::<HeaderValue>()
            .map_err(Error::InvalidAuthorizationHeader)
    }
}

/// Callback type of a [`RefreshableOauth2Token`] that has no user callback.
pub type NoCallback = fn(Oauth2Token) -> futures::future::Ready<Result<()>>;

/// Default callback: ignores the token and resolves immediately to `Ok(())`.
fn no_op(_: Oauth2Token) -> futures::future::Ready<Result<()>> {
    futures::future::ready(Ok(()))
}

/// An [`Oauth2Token`] bundled with the client needed to refresh it and a
/// callback that is invoked with the new token after each refresh.
///
/// The token sits behind `Arc<RwLock<_>>`, so clones share one token.
#[derive(Clone, Debug)]
pub struct RefreshableOauth2Token<C> {
    /// Client used for refresh and revoke requests.
    oauth_client: Oauth2Client,
    /// The current token, shared between clones.
    token: Arc<RwLock<Oauth2Token>>,
    /// Called with a copy of the token after it has been refreshed.
    callback: C,
}

impl RefreshableOauth2Token<NoCallback> {
    /// Wraps `token` with a callback that does nothing.
    pub fn new(oauth_client: Oauth2Client, token: Oauth2Token) -> Self {
        let token = Arc::new(RwLock::new(token));
        Self {
            oauth_client,
            token,
            callback: no_op,
        }
    }
}

impl<C> RefreshableOauth2Token<C> {
    pub fn with_callback<T>(&self, callback: T) -> RefreshableOauth2Token<T> {
        RefreshableOauth2Token {
            oauth_client: self.oauth_client.clone(),
            token: self.token.clone(),
            callback,
        }
    }
    pub async fn token(&self) -> RwLockReadGuard<'_, Oauth2Token> {
        self.token.read().await
    }

    pub async fn revoke(&self) -> Result<()> {
        self.oauth_client
            .revoke_token(self.token.read().await.revokable_token())
            .await
    }
}

impl<C, F> RefreshableOauth2Token<C>
where
    C: Fn(Oauth2Token) -> F + Send + Sync,
    F: Future<Output = Result<()>>,
{
    pub async fn refresh(&self) -> Result<()> {
        let mut token = self.token.write().await;
        *token = self
            .oauth_client
            .refresh_token(token.refresh_token.as_ref().ok_or(Error::NoRefreshToken)?)
            .await?;
        (self.callback)(token.clone()).await?;
        Ok(())
    }
}

#[async_trait]
impl<C, F> Authorization for RefreshableOauth2Token<C>
where
    C: Fn(Oauth2Token) -> F + Send + Sync,
    F: Future<Output = Result<()>> + Send,
{
    /// Produces the authorization header, refreshing the token first if it
    /// has expired.
    ///
    /// When a refresh happens, the callback is invoked with a copy of the new
    /// token before the header is built.
    ///
    /// # Errors
    ///
    /// Propagates errors from the refresh, from the callback, and from
    /// building the header value.
    async fn header(&self, request: &Request) -> Result<HeaderValue> {
        // Fast path: a still-valid token only needs the shared read lock, so
        // concurrent requests do not serialize on the exclusive lock.
        {
            let current = self.token.read().await;
            if !current.is_expired() {
                return current.header(request).await;
            }
        }

        // Slow path: take the write lock. `refresh_token_if_expired` checks
        // expiry again, so a token already refreshed by another task while we
        // waited for the lock is not refreshed a second time.
        let mut token = self.token.write().await;
        if self
            .oauth_client
            .refresh_token_if_expired(&mut token)
            .await?
        {
            (self.callback)(token.clone()).await?;
        }
        token.header(request).await
    }
}