1use base64::prelude::*;
2use std::time::Duration;
3
4use query_string_builder::QueryString;
5use reqwest::{StatusCode, header::HeaderMap};
6use serde::{Deserialize, Serialize};
7use sha2::Digest;
8
9use crate::{error::Error, execute_retry, make_url};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TokenResult {
13 pub access_token: String,
14 pub refresh_token: String,
15 pub expires_in: u64,
16 pub scope: String,
17 pub token_type: String,
18}
19
/// OAuth 2.0 `response_type` values accepted by the authorize endpoint.
enum ResponseType {
    /// Authorization-code flow; rendered as `code`.
    Code,
    /// Implicit flow; rendered as `token`.
    // Not constructed anywhere in this module — kept so the enum mirrors the spec.
    #[allow(unused)]
    Token,
}

impl std::fmt::Display for ResponseType {
    /// Writes the wire value used in the `response_type` query parameter.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire = match self {
            Self::Code => "code",
            Self::Token => "token",
        };
        f.write_str(wire)
    }
}
34
/// PKCE `code_challenge_method` values (RFC 7636).
enum CodeChallengeMethod {
    /// SHA-256 transform; rendered as `S256`.
    S256,
    /// No transform; rendered as `plain`.
    // Only used by tests in this module, so non-test builds would warn without this.
    #[allow(unused)]
    Plain,
}

impl std::fmt::Display for CodeChallengeMethod {
    /// Writes the wire value used in the `code_challenge_method` query parameter.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire = match self {
            Self::S256 => "S256",
            Self::Plain => "plain",
        };
        f.write_str(wire)
    }
}
49
50pub(crate) struct PkceS256 {
51 pub code_challenge: String,
52 pub code_verifier: String,
53}
54
55impl PkceS256 {
56 pub fn new() -> Self {
57 let size = 32;
58 let random_bytes: Vec<u8> = (0..size).map(|_| rand::random::<u8>()).collect();
59 let code_verifier = BASE64_URL_SAFE_NO_PAD.encode(&random_bytes);
60 let code_challenge = {
61 let hash = sha2::Sha256::digest(code_verifier.as_bytes());
62 BASE64_URL_SAFE_NO_PAD.encode(hash)
63 };
64 Self {
65 code_challenge,
66 code_verifier,
67 }
68 }
69}
70
71impl Default for PkceS256 {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77#[allow(clippy::too_many_arguments)]
78fn authorize_url(
79 url: &str,
80 response_type: ResponseType,
81 client_id: &str,
82 redirect_uri: &str,
83 scopes: &str,
84 state: &str,
85 code_challenge: &str,
86 code_challenge_method: CodeChallengeMethod,
87) -> String {
88 let qs = QueryString::dynamic()
89 .with_value("response_type", response_type.to_string())
90 .with_value("client_id", client_id)
91 .with_value("redirect_uri", redirect_uri)
92 .with_value("scope", scopes)
93 .with_value("state", state)
94 .with_value("code_challenge", code_challenge)
95 .with_value("code_challenge_method", code_challenge_method.to_string());
96 format!("{}{}", url, qs)
97}
98
99#[allow(clippy::too_many_arguments)]
100pub(crate) async fn token(
101 url: &str,
102 client_id: &str,
103 client_secret: &str,
104 redirect_uri: &str,
105 code: &str,
106 code_verifier: &str,
107 grant_type: &str,
108 timeout: Duration,
109 try_count: usize,
110 retry_duration: Duration,
111) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
112 let params = [
113 ("grant_type", grant_type),
114 ("code", code),
115 ("redirect_uri", redirect_uri),
116 ("client_id", client_id),
117 ("code_verifier", code_verifier),
118 ];
119
120 let client = reqwest::Client::new();
121
122 execute_retry(
123 || {
124 client
125 .post(url)
126 .form(¶ms)
127 .basic_auth(client_id, Some(client_secret))
128 .timeout(timeout)
129 },
130 try_count,
131 retry_duration,
132 )
133 .await
134}
135
/// OAuth 2.0 scopes understood by the X API.
///
/// The wire name of each scope is available via [`XScope::as_str`] or `Display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XScope {
    TweetRead,
    TweetWrite,
    TweetModerateWrite,
    UsersEmail,
    UsersRead,
    FollowsRead,
    FollowsWrite,
    OfflineAccess,
    SpaceRead,
    MuteRead,
    MuteWrite,
    LikeRead,
    LikeWrite,
    ListRead,
    ListWrite,
    BlockRead,
    BlockWrite,
    BookmarkRead,
    BookmarkWrite,
    DmRead,
    DmWrite,
    MediaWrite,
}

impl XScope {
    /// Returns every scope variant, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::TweetRead,
            Self::TweetWrite,
            Self::TweetModerateWrite,
            Self::UsersEmail,
            Self::UsersRead,
            Self::FollowsRead,
            Self::FollowsWrite,
            Self::OfflineAccess,
            Self::SpaceRead,
            Self::MuteRead,
            Self::MuteWrite,
            Self::LikeRead,
            Self::LikeWrite,
            Self::ListRead,
            Self::ListWrite,
            Self::BlockRead,
            Self::BlockWrite,
            Self::BookmarkRead,
            Self::BookmarkWrite,
            Self::DmRead,
            Self::DmWrite,
            Self::MediaWrite,
        ]
    }

    /// Returns the scope's wire name as used in the `scope` query parameter.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::TweetRead => "tweet.read",
            Self::TweetWrite => "tweet.write",
            Self::TweetModerateWrite => "tweet.moderate.write",
            Self::UsersEmail => "users.email",
            Self::UsersRead => "users.read",
            Self::FollowsRead => "follows.read",
            Self::FollowsWrite => "follows.write",
            Self::OfflineAccess => "offline.access",
            Self::SpaceRead => "space.read",
            Self::MuteRead => "mute.read",
            Self::MuteWrite => "mute.write",
            Self::LikeRead => "like.read",
            Self::LikeWrite => "like.write",
            Self::ListRead => "list.read",
            Self::ListWrite => "list.write",
            Self::BlockRead => "block.read",
            Self::BlockWrite => "block.write",
            Self::BookmarkRead => "bookmark.read",
            Self::BookmarkWrite => "bookmark.write",
            Self::DmRead => "dm.read",
            Self::DmWrite => "dm.write",
            Self::MediaWrite => "media.write",
        }
    }

    /// Joins the wire names of `scopes` with single spaces, preserving order.
    ///
    /// An empty slice yields an empty string.
    pub fn scopes_to_string(scopes: &[XScope]) -> String {
        scopes
            .iter()
            .map(XScope::as_str)
            .collect::<Vec<&str>>()
            .join(" ")
    }
}

impl std::fmt::Display for XScope {
    /// Writes the scope's wire name (see [`XScope::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
226
/// X's OAuth 2.0 authorization endpoint; `XClient::authorize_url` builds its URL on top of this.
pub const X_AUTHORIZE_URL: &str = "https://x.com/i/oauth2/authorize";

// NOTE(review): despite the name this holds the API base URL (scheme + host); it is the
// first argument `XClient::token` passes to `make_url`. "POSTFIX" looks like a misnomer.
const URL_POSTFIX: &str = "https://api.x.com";
// NOTE(review): named "PREFIX" but holds a path; passed to `make_url` alongside the base URL.
/// Path of X's OAuth 2.0 token endpoint.
pub const X_TOKEN_URL_PREFIX: &str = "/2/oauth2/token";
231
232pub struct XClient {
233 client_id: String,
234 client_secret: String,
235 redirect_uri: String,
236 scopes: Vec<XScope>,
237 try_count: usize,
238 retry_duration: Duration,
239 timeout: Duration,
240 prefix_url: Option<String>,
241}
242
243impl XClient {
244 pub fn new(
245 client_id: &str,
246 client_secret: &str,
247 redirect_uri: &str,
248 scopes: Vec<XScope>,
249 ) -> Self {
250 Self::new_with_token_options(
251 client_id,
252 client_secret,
253 redirect_uri,
254 scopes,
255 3,
256 Duration::from_millis(100),
257 Duration::from_secs(10),
258 None,
259 )
260 }
261
262 #[allow(clippy::too_many_arguments)]
263 pub fn new_with_token_options(
264 client_id: &str,
265 client_secret: &str,
266 redirect_uri: &str,
267 scopes: Vec<XScope>,
268 try_count: usize,
269 retry_duration: Duration,
270 timeout: Duration,
271 prefix_url: Option<String>,
272 ) -> Self {
273 Self {
274 client_id: client_id.to_string(),
275 client_secret: client_secret.to_string(),
276 redirect_uri: redirect_uri.to_string(),
277 scopes,
278 try_count,
279 retry_duration,
280 timeout,
281 prefix_url,
282 }
283 }
284
285 pub fn authorize_url(&self, state: &str) -> (String, String) {
286 let pkce = PkceS256::new();
287
288 let scopes_str = XScope::scopes_to_string(&self.scopes);
289 (
290 authorize_url(
291 X_AUTHORIZE_URL,
292 ResponseType::Code,
293 &self.client_id,
294 &self.redirect_uri,
295 &scopes_str,
296 state,
297 &pkce.code_challenge,
298 CodeChallengeMethod::S256,
299 ),
300 pkce.code_verifier,
301 )
302 }
303
304 pub async fn token(
305 &self,
306 code: &str,
307 code_verifier: &str,
308 ) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
309 let (token_json, status_code, headers) = token(
310 &make_url(URL_POSTFIX, X_TOKEN_URL_PREFIX, &self.prefix_url),
311 &self.client_id,
312 &self.client_secret,
313 &self.redirect_uri,
314 code,
315 code_verifier,
316 "authorization_code",
317 self.timeout,
318 self.try_count,
319 self.retry_duration,
320 )
321 .await?;
322 Ok((token_json, status_code, headers))
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads `key` from the environment, falling back to a placeholder.
    ///
    /// Building an authorization URL performs no network I/O, so these tests
    /// do not need real credentials. Real values are still used when set,
    /// which keeps the printed URL useful for manual runs.
    fn env_or(key: &str, fallback: &str) -> String {
        std::env::var(key).unwrap_or_else(|_| fallback.to_string())
    }

    #[test]
    fn test_x_authorize() {
        let client_id = env_or("CLIENT_ID", "test_client_id");
        let client_secret = env_or("CLIENT_SECRET", "test_client_secret");
        let redirect_url = env_or("REDIRECT_URL", "https://example.com/callback");
        let state = "test_state";
        let x_client = XClient::new(&client_id, &client_secret, &redirect_url, XScope::all());
        let (auth_url, code_verifier) = x_client.authorize_url(state);
        println!("Authorize URL: {}", auth_url);
        println!("Code Verifier: {}", code_verifier);

        assert!(auth_url.starts_with(X_AUTHORIZE_URL));
        assert!(auth_url.contains("response_type=code"));
        assert!(auth_url.contains("code_challenge_method=S256"));
        assert!(auth_url.contains("state=test_state"));
        // 32 random bytes, base64url-encoded without padding, is 43 characters.
        assert_eq!(code_verifier.len(), 43);
    }

    #[test]
    fn test_authorize() {
        let client_id = env_or("CLIENT_ID", "test_client_id");
        let redirect_url = env_or("REDIRECT_URL", "https://example.com/callback");
        let state = "test_state";
        let scopes = XScope::scopes_to_string(&XScope::all());
        let code_challenge = "test_code_challenge";
        let res = authorize_url(
            X_AUTHORIZE_URL,
            ResponseType::Code,
            &client_id,
            &redirect_url,
            &scopes,
            state,
            code_challenge,
            CodeChallengeMethod::Plain,
        );
        println!("res: {}", res);

        assert!(res.starts_with(X_AUTHORIZE_URL));
        assert!(res.contains("response_type=code"));
        assert!(res.contains("state=test_state"));
        assert!(res.contains("code_challenge=test_code_challenge"));
        assert!(res.contains("code_challenge_method=plain"));
    }
}