1use std::fmt::{self, Display, Formatter};
2use std::future::Future;
3use std::io::Write;
4use std::marker::PhantomData;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use bytes::Bytes;
9use futures_core::ready;
10use http::header::{HeaderValue, AUTHORIZATION, CONTENT_TYPE};
11use http::Uri;
12use http_body::Body;
13use oauth_credentials::Credentials;
14use pin_project_lite::pin_project;
15use serde::{de, Deserialize};
16
17use crate::error::Error;
18use crate::response::RawResponseFuture;
19use crate::traits::{HttpService, HttpTryFuture};
20
21use self::private::AuthDeserialize;
22
23#[derive(Clone, Debug)]
24#[non_exhaustive]
25pub struct AccessToken {
26 pub credentials: Credentials<Box<str>>,
27 pub user_id: i64,
28 pub screen_name: Box<str>,
29}
30
31#[derive(Clone, Debug, Deserialize)]
32#[non_exhaustive]
33pub struct AccessToken2 {
34 pub access_token: Box<str>,
35}
36
37pin_project! {
38 pub struct AuthFuture<T, F: HttpTryFuture> {
39 #[pin]
40 inner: RawResponseFuture<F>,
41 marker: PhantomData<fn() -> T>,
42 }
43}
44
/// Access level that can be requested through the `x_auth_access_type` parameter of
/// [`request_token`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[non_exhaustive]
pub enum AuthAccessType {
    /// Read-only access; rendered as `read`.
    Read,
    /// Read and write access; rendered as `write`.
    Write,
}
51
52impl<T: AuthDeserialize, F: HttpTryFuture> AuthFuture<T, F> {
53 pub(crate) fn new(response: F) -> Self {
54 AuthFuture {
55 inner: RawResponseFuture::new(response),
56 marker: PhantomData,
57 }
58 }
59}
60
61impl<T: AuthDeserialize, F: HttpTryFuture> Future for AuthFuture<T, F> {
62 #[allow(clippy::type_complexity)]
63 type Output = Result<T, Error<F::Error, <F::Body as Body>::Error>>;
64
65 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
66 let res = ready!(self.project().inner.poll(cx)?);
67 let data = T::deserialize(res.data).ok_or(Error::Unexpected)?;
68 Poll::Ready(Ok(data))
69 }
70}
71
72impl AuthAccessType {
73 pub fn as_str(&self) -> &'static str {
74 match *self {
75 AuthAccessType::Read => "read",
76 AuthAccessType::Write => "write",
77 }
78 }
79}
80
81impl AsRef<str> for AuthAccessType {
82 fn as_ref(&self) -> &str {
83 self.as_str()
84 }
85}
86
87impl Display for AuthAccessType {
88 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
89 f.write_str(self.as_ref())
90 }
91}
92
93impl AuthDeserialize for Credentials<Box<str>> {
94 fn deserialize(input: Bytes) -> Option<Self> {
95 serde_urlencoded::from_bytes(&input).ok()
96 }
97}
98
99impl AuthDeserialize for AccessToken {
100 fn deserialize(input: Bytes) -> Option<Self> {
101 serde_urlencoded::from_bytes(&input).ok()
102 }
103}
104
105impl AuthDeserialize for AccessToken2 {
106 fn deserialize(input: Bytes) -> Option<Self> {
107 serde_json::from_slice(&input).ok()
108 }
109}
110
111pub fn request_token<C, S, B>(
112 client_credentials: &Credentials<C>,
113 callback: &str,
114 x_auth_access_type: Option<AuthAccessType>,
115 client: &mut S,
116) -> AuthFuture<Credentials<Box<str>>, S::Future>
117where
118 C: AsRef<str>,
119 S: HttpService<B>,
120 B: Default,
121{
122 let req = request_token_request(client_credentials.as_ref(), callback, x_auth_access_type);
123 AuthFuture::new(client.call(req.map(|()| Default::default())))
124}
125
126fn request_token_request(
127 client_credentials: Credentials<&str>,
128 callback: &str,
129 x_auth_access_type: Option<AuthAccessType>,
130) -> http::Request<()> {
131 const URI: &str = "https://api.twitter.com/oauth/request_token";
132
133 #[derive(oauth::Request)]
134 struct RequestToken {
135 x_auth_access_type: Option<AuthAccessType>,
136 }
137
138 let authorization = oauth::Builder::<_, _>::new(client_credentials, oauth::HmacSha1)
139 .callback(callback)
140 .post(URI, &RequestToken { x_auth_access_type });
141
142 http::Request::post(Uri::from_static(URI))
143 .header(AUTHORIZATION, authorization)
144 .body(())
145 .unwrap()
146}
147
148pub fn access_token<C, T, S, B>(
149 client_credentials: &Credentials<C>,
150 temporary_credentials: &Credentials<T>,
151 oauth_verifier: &str,
152 client: &mut S,
153) -> AuthFuture<AccessToken, S::Future>
154where
155 C: AsRef<str>,
156 T: AsRef<str>,
157 S: HttpService<B>,
158 B: Default,
159{
160 let req = access_token_request(
161 client_credentials.as_ref(),
162 temporary_credentials.as_ref(),
163 oauth_verifier,
164 );
165 AuthFuture::new(client.call(req.map(|()| Default::default())))
166}
167
168fn access_token_request(
169 client_credentials: Credentials<&str>,
170 temporary_credentials: Credentials<&str>,
171 oauth_verifier: &str,
172) -> http::Request<()> {
173 const URI: &str = "https://api.twitter.com/oauth/access_token";
174
175 let authorization = oauth::Builder::new(client_credentials, oauth::HmacSha1)
176 .token(temporary_credentials)
177 .verifier(oauth_verifier)
178 .post(URI, &());
179
180 http::Request::post(Uri::from_static(URI))
181 .header(AUTHORIZATION, authorization)
182 .body(())
183 .unwrap()
184}
185
186impl<'de> Deserialize<'de> for AccessToken {
187 fn deserialize<D: de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
188 #[derive(Deserialize)]
189 struct AccessToken {
190 oauth_token: Box<str>,
191 oauth_token_secret: Box<str>,
192 user_id: i64,
193 screen_name: Box<str>,
194 }
195 AccessToken::deserialize(d).map(|t| Self {
196 credentials: Credentials::new(t.oauth_token, t.oauth_token_secret),
197 user_id: t.user_id,
198 screen_name: t.screen_name,
199 })
200 }
201}
202
203pub fn token<C, S, B>(
204 client_credentials: &Credentials<C>,
205 client: &mut S,
206) -> AuthFuture<AccessToken2, S::Future>
207where
208 C: AsRef<str>,
209 S: HttpService<B>,
210 B: From<&'static [u8]>,
211{
212 let req = token_request(client_credentials.as_ref());
213 AuthFuture::new(client.call(req.map(Into::into)))
214}
215
216fn token_request(client_credentials: Credentials<&str>) -> http::Request<&'static [u8]> {
217 const URI: &str = "https://api.twitter.com/oauth2/token";
218
219 let authorization = basic_auth(client_credentials);
220
221 let application_www_form_urlencoded =
222 HeaderValue::from_static("application/x-www-form-urlencoded");
223 http::Request::post(Uri::from_static(URI))
224 .header(AUTHORIZATION, authorization)
225 .header(CONTENT_TYPE, application_www_form_urlencoded)
226 .body(&b"grant_type=client_credentials"[..])
227 .unwrap()
228}
229
230fn basic_auth(credentials: Credentials<&str>) -> String {
231 let b64len = (credentials.identifier.len() + credentials.secret.len()) / 3 * 4 + 4;
232 let mut authorization = String::with_capacity("Basic ".len() + b64len);
233
234 authorization.push_str("Basic ");
235
236 let mut enc = base64::write::EncoderStringWriter::from(authorization, base64::STANDARD);
237 enc.write_all(credentials.identifier.as_bytes()).unwrap();
238 enc.write_all(b":").unwrap();
239 enc.write_all(credentials.secret.as_bytes()).unwrap();
240
241 enc.into_inner()
242}
243
244mod private {
245 use bytes::Bytes;
246
247 pub trait AuthDeserialize: Sized {
248 fn deserialize(input: Bytes) -> Option<Self>;
249 }
250}