twitter_client/
auth.rs

1use std::fmt::{self, Display, Formatter};
2use std::future::Future;
3use std::io::Write;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use bytes::Bytes;
9use futures_core::ready;
10use http::header::{HeaderValue, AUTHORIZATION, CONTENT_TYPE};
11use http::Uri;
12use http_body::Body;
13use oauth_credentials::Credentials;
14use pin_project_lite::pin_project;
15use serde::{de, Deserialize};
16
17use crate::error::Error;
18use crate::response::RawResponseFuture;
19use crate::traits::{HttpService, HttpTryFuture};
20
21use self::private::AuthDeserialize;
22
23#[derive(Clone, Debug)]
24#[non_exhaustive]
25pub struct AccessToken {
26    pub credentials: Credentials<Box<str>>,
27    pub user_id: i64,
28    pub screen_name: Box<str>,
29}
30
31#[derive(Clone, Debug, Deserialize)]
32#[non_exhaustive]
33pub struct AccessToken2 {
34    pub access_token: Box<str>,
35}
36
pin_project! {
    /// Future returned by [`request_token`], [`access_token`] and [`token`].
    ///
    /// Wraps a raw HTTP response future and, once it yields the response data, parses
    /// that data into a `T` via the crate-private `AuthDeserialize` trait.
    pub struct AuthFuture<T, F: HttpTryFuture> {
        // The underlying response future; structurally pinned so it is polled in place.
        #[pin]
        inner: RawResponseFuture<F>,
        // Records the output type `T` without storing a value. `fn() -> T` is always
        // `Send + Sync + Unpin`, so these auto traits do not depend on `T`.
        marker: PhantomData<fn() -> T>,
    }
}
44
/// Access level an application may request for a user's account.
///
/// Passed as the optional `x_auth_access_type` parameter of [`request_token`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthAccessType {
    /// Read-only access (`"read"`).
    Read,
    /// Read and write access (`"write"`).
    Write,
}
51
52impl<T: AuthDeserialize, F: HttpTryFuture> AuthFuture<T, F> {
53    pub(crate) fn new(response: F) -> Self {
54        AuthFuture {
55            inner: RawResponseFuture::new(response),
56            marker: PhantomData,
57        }
58    }
59}
60
61impl<T: AuthDeserialize, F: HttpTryFuture> Future for AuthFuture<T, F> {
62    #[allow(clippy::type_complexity)]
63    type Output = Result<T, Error<F::Error, <F::Body as Body>::Error>>;
64
65    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
66        let res = ready!(self.project().inner.poll(cx)?);
67        let data = T::deserialize(res.data).ok_or(Error::Unexpected)?;
68        Poll::Ready(Ok(data))
69    }
70}
71
72impl AuthAccessType {
73    pub fn as_str(&self) -> &'static str {
74        match *self {
75            AuthAccessType::Read => "read",
76            AuthAccessType::Write => "write",
77        }
78    }
79}
80
81impl AsRef<str> for AuthAccessType {
82    fn as_ref(&self) -> &str {
83        self.as_str()
84    }
85}
86
87impl Display for AuthAccessType {
88    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
89        f.write_str(self.as_ref())
90    }
91}
92
93impl AuthDeserialize for Credentials<Box<str>> {
94    fn deserialize(input: Bytes) -> Option<Self> {
95        serde_urlencoded::from_bytes(&input).ok()
96    }
97}
98
99impl AuthDeserialize for AccessToken {
100    fn deserialize(input: Bytes) -> Option<Self> {
101        serde_urlencoded::from_bytes(&input).ok()
102    }
103}
104
105impl AuthDeserialize for AccessToken2 {
106    fn deserialize(input: Bytes) -> Option<Self> {
107        serde_json::from_slice(&input).ok()
108    }
109}
110
111pub fn request_token<C, S, B>(
112    client_credentials: &Credentials<C>,
113    callback: &str,
114    x_auth_access_type: Option<AuthAccessType>,
115    client: &mut S,
116) -> AuthFuture<Credentials<Box<str>>, S::Future>
117where
118    C: AsRef<str>,
119    S: HttpService<B>,
120    B: Default,
121{
122    let req = request_token_request(client_credentials.as_ref(), callback, x_auth_access_type);
123    AuthFuture::new(client.call(req.map(|()| Default::default())))
124}
125
126fn request_token_request(
127    client_credentials: Credentials<&str>,
128    callback: &str,
129    x_auth_access_type: Option<AuthAccessType>,
130) -> http::Request<()> {
131    const URI: &str = "https://api.twitter.com/oauth/request_token";
132
133    #[derive(oauth::Request)]
134    struct RequestToken {
135        x_auth_access_type: Option<AuthAccessType>,
136    }
137
138    let authorization = oauth::Builder::<_, _>::new(client_credentials, oauth::HmacSha1)
139        .callback(callback)
140        .post(URI, &RequestToken { x_auth_access_type });
141
142    http::Request::post(Uri::from_static(URI))
143        .header(AUTHORIZATION, authorization)
144        .body(())
145        .unwrap()
146}
147
148pub fn access_token<C, T, S, B>(
149    client_credentials: &Credentials<C>,
150    temporary_credentials: &Credentials<T>,
151    oauth_verifier: &str,
152    client: &mut S,
153) -> AuthFuture<AccessToken, S::Future>
154where
155    C: AsRef<str>,
156    T: AsRef<str>,
157    S: HttpService<B>,
158    B: Default,
159{
160    let req = access_token_request(
161        client_credentials.as_ref(),
162        temporary_credentials.as_ref(),
163        oauth_verifier,
164    );
165    AuthFuture::new(client.call(req.map(|()| Default::default())))
166}
167
168fn access_token_request(
169    client_credentials: Credentials<&str>,
170    temporary_credentials: Credentials<&str>,
171    oauth_verifier: &str,
172) -> http::Request<()> {
173    const URI: &str = "https://api.twitter.com/oauth/access_token";
174
175    let authorization = oauth::Builder::new(client_credentials, oauth::HmacSha1)
176        .token(temporary_credentials)
177        .verifier(oauth_verifier)
178        .post(URI, &());
179
180    http::Request::post(Uri::from_static(URI))
181        .header(AUTHORIZATION, authorization)
182        .body(())
183        .unwrap()
184}
185
186impl<'de> Deserialize<'de> for AccessToken {
187    fn deserialize<D: de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
188        #[derive(Deserialize)]
189        struct AccessToken {
190            oauth_token: Box<str>,
191            oauth_token_secret: Box<str>,
192            user_id: i64,
193            screen_name: Box<str>,
194        }
195        AccessToken::deserialize(d).map(|t| Self {
196            credentials: Credentials::new(t.oauth_token, t.oauth_token_secret),
197            user_id: t.user_id,
198            screen_name: t.screen_name,
199        })
200    }
201}
202
203pub fn token<C, S, B>(
204    client_credentials: &Credentials<C>,
205    client: &mut S,
206) -> AuthFuture<AccessToken2, S::Future>
207where
208    C: AsRef<str>,
209    S: HttpService<B>,
210    B: From<&'static [u8]>,
211{
212    let req = token_request(client_credentials.as_ref());
213    AuthFuture::new(client.call(req.map(Into::into)))
214}
215
216fn token_request(client_credentials: Credentials<&str>) -> http::Request<&'static [u8]> {
217    const URI: &str = "https://api.twitter.com/oauth2/token";
218
219    let authorization = basic_auth(client_credentials);
220
221    let application_www_form_urlencoded =
222        HeaderValue::from_static("application/x-www-form-urlencoded");
223    http::Request::post(Uri::from_static(URI))
224        .header(AUTHORIZATION, authorization)
225        .header(CONTENT_TYPE, application_www_form_urlencoded)
226        .body(&b"grant_type=client_credentials"[..])
227        .unwrap()
228}
229
230fn basic_auth(credentials: Credentials<&str>) -> String {
231    let b64len = (credentials.identifier.len() + credentials.secret.len()) / 3 * 4 + 4;
232    let mut authorization = String::with_capacity("Basic ".len() + b64len);
233
234    authorization.push_str("Basic ");
235
236    let mut enc = base64::write::EncoderStringWriter::from(authorization, base64::STANDARD);
237    enc.write_all(credentials.identifier.as_bytes()).unwrap();
238    enc.write_all(b":").unwrap();
239    enc.write_all(credentials.secret.as_bytes()).unwrap();
240
241    enc.into_inner()
242}
243
244mod private {
245    use bytes::Bytes;
246
247    pub trait AuthDeserialize: Sized {
248        fn deserialize(input: Bytes) -> Option<Self>;
249    }
250}