1use base64::prelude::*;
2use std::time::Duration;
3
4use query_string_builder::QueryString;
5use reqwest::{StatusCode, header::HeaderMap};
6use serde::{Deserialize, Serialize};
7use sha2::Digest;
8
9use crate::{error::Error, execute_retry, make_url};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TokenResult {
13 pub access_token: String,
14 pub refresh_token: String,
15 pub expires_in: u64,
16 pub scope: String,
17 pub token_type: String,
18}
19
/// Value of the `response_type` query parameter of the authorization request.
enum ResponseType {
    /// Rendered as `code`.
    Code,
    /// Rendered as `token`.
    #[allow(unused)]
    Token,
}

impl std::fmt::Display for ResponseType {
    /// Writes the wire representation of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire = match self {
            Self::Code => "code",
            Self::Token => "token",
        };
        f.write_str(wire)
    }
}
34
/// Value of the `code_challenge_method` query parameter of the authorization
/// request.
enum CodeChallengeMethod {
    /// Rendered as `S256`.
    S256,
    /// Rendered as `plain`.
    #[allow(unused)]
    Plain,
}

impl std::fmt::Display for CodeChallengeMethod {
    /// Writes the wire representation of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire = match self {
            Self::S256 => "S256",
            Self::Plain => "plain",
        };
        f.write_str(wire)
    }
}
49
50pub(crate) struct PkceS256 {
51 pub code_challenge: String,
52 pub code_verifier: String,
53}
54
55impl PkceS256 {
56 pub fn new() -> Self {
57 let size = 32;
58 let random_bytes: Vec<u8> = (0..size).map(|_| rand::random::<u8>()).collect();
59 let code_verifier = BASE64_URL_SAFE_NO_PAD.encode(&random_bytes);
60 let code_challenge = {
61 let hash = sha2::Sha256::digest(code_verifier.as_bytes());
62 BASE64_URL_SAFE_NO_PAD.encode(hash)
63 };
64 Self {
65 code_challenge,
66 code_verifier,
67 }
68 }
69}
70
71impl Default for PkceS256 {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77#[allow(clippy::too_many_arguments)]
78fn authorize_url(
79 url: &str,
80 response_type: ResponseType,
81 client_id: &str,
82 redirect_uri: &str,
83 scopes: &str,
84 state: &str,
85 code_challenge: &str,
86 code_challenge_method: CodeChallengeMethod,
87) -> String {
88 let qs = QueryString::dynamic()
89 .with_value("response_type", response_type.to_string())
90 .with_value("client_id", client_id)
91 .with_value("redirect_uri", redirect_uri)
92 .with_value("scope", scopes)
93 .with_value("state", state)
94 .with_value("code_challenge", code_challenge)
95 .with_value("code_challenge_method", code_challenge_method.to_string());
96 format!("{}{}", url, qs)
97}
98
99#[allow(clippy::too_many_arguments)]
100pub(crate) async fn token(
101 url: &str,
102 client_id: &str,
103 client_secret: &str,
104 redirect_uri: &str,
105 code: &str,
106 code_verifier: &str,
107 grant_type: &str,
108 timeout: Duration,
109 try_count: usize,
110 retry_duration: Duration,
111) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
112 let params = [
113 ("grant_type", grant_type),
114 ("code", code),
115 ("redirect_uri", redirect_uri),
116 ("client_id", client_id),
117 ("code_verifier", code_verifier),
118 ];
119
120 let client = reqwest::Client::new();
121
122 execute_retry(
123 || {
124 client
125 .post(url)
126 .form(¶ms)
127 .basic_auth(client_id, Some(client_secret))
128 .timeout(timeout)
129 },
130 try_count,
131 retry_duration,
132 )
133 .await
134}
135
136pub async fn refresh_token(
137 client_id: &str,
138 client_secret: &str,
139 refresh_token: &str,
140 timeout: Duration,
141 try_count: usize,
142 retry_duration: Duration,
143 prefix_url: Option<String>,
144) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
145 let url = &make_url(URL_POSTFIX, X_TOKEN_URL_PREFIX, &prefix_url);
146 let params = [
147 ("grant_type", "refresh_token"),
148 ("refresh_token", refresh_token),
149 ("client_id", client_id),
150 ];
151
152 let client = reqwest::Client::new();
153
154 execute_retry(
155 || {
156 client
157 .post(url)
158 .form(¶ms)
159 .basic_auth(client_id, Some(client_secret))
160 .timeout(timeout)
161 },
162 try_count,
163 retry_duration,
164 )
165 .await
166}
167
/// OAuth 2.0 scopes that can be requested from X.
///
/// The `Display` implementation yields the scope string sent on the wire
/// (for example `tweet.read`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XScope {
    TweetRead,
    TweetWrite,
    TweetModerateWrite,
    UsersEmail,
    UsersRead,
    FollowsRead,
    FollowsWrite,
    OfflineAccess,
    SpaceRead,
    MuteRead,
    MuteWrite,
    LikeRead,
    LikeWrite,
    ListRead,
    ListWrite,
    BlockRead,
    BlockWrite,
    BookmarkRead,
    BookmarkWrite,
    DmRead,
    DmWrite,
    MediaWrite,
}

impl XScope {
    /// Returns every scope, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::TweetRead,
            Self::TweetWrite,
            Self::TweetModerateWrite,
            Self::UsersEmail,
            Self::UsersRead,
            Self::FollowsRead,
            Self::FollowsWrite,
            Self::OfflineAccess,
            Self::SpaceRead,
            Self::MuteRead,
            Self::MuteWrite,
            Self::LikeRead,
            Self::LikeWrite,
            Self::ListRead,
            Self::ListWrite,
            Self::BlockRead,
            Self::BlockWrite,
            Self::BookmarkRead,
            Self::BookmarkWrite,
            Self::DmRead,
            Self::DmWrite,
            Self::MediaWrite,
        ]
    }

    /// Joins the wire names of `scopes` with single spaces.
    ///
    /// An empty slice yields an empty string.
    pub fn scopes_to_string(scopes: &[XScope]) -> String {
        scopes
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(" ")
    }
}

impl std::fmt::Display for XScope {
    /// Writes the scope's wire name.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name = match self {
            Self::TweetRead => "tweet.read",
            Self::TweetWrite => "tweet.write",
            Self::TweetModerateWrite => "tweet.moderate.write",
            Self::UsersEmail => "users.email",
            Self::UsersRead => "users.read",
            Self::FollowsRead => "follows.read",
            Self::FollowsWrite => "follows.write",
            Self::OfflineAccess => "offline.access",
            Self::SpaceRead => "space.read",
            Self::MuteRead => "mute.read",
            Self::MuteWrite => "mute.write",
            Self::LikeRead => "like.read",
            Self::LikeWrite => "like.write",
            Self::ListRead => "list.read",
            Self::ListWrite => "list.write",
            Self::BlockRead => "block.read",
            Self::BlockWrite => "block.write",
            Self::BookmarkRead => "bookmark.read",
            Self::BookmarkWrite => "bookmark.write",
            Self::DmRead => "dm.read",
            Self::DmWrite => "dm.write",
            Self::MediaWrite => "media.write",
        };
        f.write_str(name)
    }
}
258
/// Base URL of the authorization page; `XClient::authorize_url` appends the
/// query string to it.
pub const X_AUTHORIZE_URL: &str = "https://x.com/i/oauth2/authorize";

// NOTE(review): despite the name, this holds the scheme and host of the API.
// It is passed as the first argument to `make_url` — presumably the default
// used when no `prefix_url` override is supplied; confirm against `make_url`.
const URL_POSTFIX: &str = "https://api.x.com";
/// Path of the token endpoint, passed as the second argument to `make_url`.
///
/// NOTE(review): despite the name, this is the path portion of the URL.
pub const X_TOKEN_URL_PREFIX: &str = "/2/oauth2/token";
263
264pub struct XClient {
265 client_id: String,
266 client_secret: String,
267 redirect_uri: String,
268 scopes: Vec<XScope>,
269 try_count: usize,
270 retry_duration: Duration,
271 timeout: Duration,
272 prefix_url: Option<String>,
273}
274
275impl XClient {
276 pub fn new(
277 client_id: &str,
278 client_secret: &str,
279 redirect_uri: &str,
280 scopes: Vec<XScope>,
281 ) -> Self {
282 Self::new_with_token_options(
283 client_id,
284 client_secret,
285 redirect_uri,
286 scopes,
287 3,
288 Duration::from_millis(100),
289 Duration::from_secs(10),
290 None,
291 )
292 }
293
294 #[allow(clippy::too_many_arguments)]
295 pub fn new_with_token_options(
296 client_id: &str,
297 client_secret: &str,
298 redirect_uri: &str,
299 scopes: Vec<XScope>,
300 try_count: usize,
301 retry_duration: Duration,
302 timeout: Duration,
303 prefix_url: Option<String>,
304 ) -> Self {
305 Self {
306 client_id: client_id.to_string(),
307 client_secret: client_secret.to_string(),
308 redirect_uri: redirect_uri.to_string(),
309 scopes,
310 try_count,
311 retry_duration,
312 timeout,
313 prefix_url,
314 }
315 }
316
317 pub fn authorize_url(&self, state: &str) -> (String, String) {
318 let pkce = PkceS256::new();
319
320 let scopes_str = XScope::scopes_to_string(&self.scopes);
321 (
322 authorize_url(
323 X_AUTHORIZE_URL,
324 ResponseType::Code,
325 &self.client_id,
326 &self.redirect_uri,
327 &scopes_str,
328 state,
329 &pkce.code_challenge,
330 CodeChallengeMethod::S256,
331 ),
332 pkce.code_verifier,
333 )
334 }
335
336 pub async fn token(
337 &self,
338 code: &str,
339 code_verifier: &str,
340 ) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
341 let (token_json, status_code, headers) = token(
342 &make_url(URL_POSTFIX, X_TOKEN_URL_PREFIX, &self.prefix_url),
343 &self.client_id,
344 &self.client_secret,
345 &self.redirect_uri,
346 code,
347 code_verifier,
348 "authorization_code",
349 self.timeout,
350 self.try_count,
351 self.retry_duration,
352 )
353 .await?;
354 Ok((token_json, status_code, headers))
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads `key` from the environment, falling back to `fallback` so the
    /// tests also run where no real credentials are configured.
    fn env_or(key: &str, fallback: &str) -> String {
        std::env::var(key).unwrap_or_else(|_| fallback.to_string())
    }

    // Building the URL performs no I/O, so no async runtime is needed.
    #[test]
    fn test_x_authorize() {
        let client_id = env_or("CLIENT_ID", "test_client_id");
        let client_secret = env_or("CLIENT_SECRET", "test_client_secret");
        let redirect_url = env_or("REDIRECT_URL", "http://localhost:8000/callback");
        let state = "test_state";
        let x_client = XClient::new(&client_id, &client_secret, &redirect_url, XScope::all());
        let (auth_url, code_verifier) = x_client.authorize_url(state);
        println!("Authorize URL: {}", auth_url);
        println!("Code Verifier: {}", code_verifier);

        assert!(auth_url.starts_with(X_AUTHORIZE_URL));
        assert!(auth_url.contains("response_type=code"));
        assert!(auth_url.contains("code_challenge_method=S256"));
        // 32 random bytes encode to 43 unpadded base64 characters.
        assert_eq!(code_verifier.len(), 43);
        // Only the derived challenge may appear in the URL, never the verifier.
        assert!(!auth_url.contains(&code_verifier));
    }

    #[test]
    fn test_authorize() {
        let client_id = env_or("CLIENT_ID", "test_client_id");
        let redirect_url = env_or("REDIRECT_URL", "http://localhost:8000/callback");
        let state = "test_state";
        let scopes = XScope::scopes_to_string(&XScope::all());
        let code_challenge = "test_code_challenge";
        let res = authorize_url(
            X_AUTHORIZE_URL,
            ResponseType::Code,
            &client_id,
            &redirect_url,
            &scopes,
            state,
            code_challenge,
            CodeChallengeMethod::Plain,
        );
        println!("res: {}", res);

        assert!(res.starts_with(X_AUTHORIZE_URL));
        assert!(res.contains("response_type=code"));
        assert!(res.contains("code_challenge_method=plain"));
    }
}