1use base64::prelude::*;
2use std::time::Duration;
3
4use query_string_builder::QueryString;
5use reqwest::{StatusCode, header::HeaderMap};
6use serde::{Deserialize, Serialize};
7use sha2::Digest;
8
9use crate::{error::Error, execute_retry};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TokenResult {
13 pub access_token: String,
14 pub refresh_token: String,
15 pub expires_in: u64,
16 pub scope: String,
17 pub token_type: String,
18}
19
/// Values of the OAuth 2.0 `response_type` authorization parameter.
enum ResponseType {
    /// Authorization-code flow; the one `XClient` uses.
    Code,
    /// Implicit-grant flow; not constructed anywhere in this module,
    /// hence the `allow(unused)`.
    #[allow(unused)]
    Token,
}

impl std::fmt::Display for ResponseType {
    /// Writes the exact wire value for the `response_type` query parameter.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire = match self {
            ResponseType::Code => "code",
            ResponseType::Token => "token",
        };
        f.write_str(wire)
    }
}
34
/// PKCE `code_challenge_method` values (RFC 7636).
enum CodeChallengeMethod {
    /// Challenge is the base64url-encoded SHA-256 hash of the verifier.
    S256,
    /// Challenge equals the verifier; only referenced from tests here,
    /// hence the `allow(unused)`.
    #[allow(unused)]
    Plain,
}

impl std::fmt::Display for CodeChallengeMethod {
    /// Writes the exact wire value for the `code_challenge_method` parameter.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire = match self {
            CodeChallengeMethod::S256 => "S256",
            CodeChallengeMethod::Plain => "plain",
        };
        f.write_str(wire)
    }
}
49
50pub(crate) struct PkceS256 {
51 pub code_challenge: String,
52 pub code_verifier: String,
53}
54
55impl PkceS256 {
56 pub fn new() -> Self {
57 let size = 32;
58 let random_bytes: Vec<u8> = (0..size).map(|_| rand::random::<u8>()).collect();
59 let code_verifier = BASE64_URL_SAFE_NO_PAD.encode(&random_bytes);
60 let code_challenge = {
61 let hash = sha2::Sha256::digest(code_verifier.as_bytes());
62 BASE64_URL_SAFE_NO_PAD.encode(hash)
63 };
64 Self {
65 code_challenge,
66 code_verifier,
67 }
68 }
69}
70
71impl Default for PkceS256 {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77#[allow(clippy::too_many_arguments)]
78fn authorize_url(
79 url: &str,
80 response_type: ResponseType,
81 client_id: &str,
82 redirect_uri: &str,
83 scopes: &str,
84 state: &str,
85 code_challenge: &str,
86 code_challenge_method: CodeChallengeMethod,
87) -> String {
88 let qs = QueryString::dynamic()
89 .with_value("response_type", response_type.to_string())
90 .with_value("client_id", client_id)
91 .with_value("redirect_uri", redirect_uri)
92 .with_value("scope", scopes)
93 .with_value("state", state)
94 .with_value("code_challenge", code_challenge)
95 .with_value("code_challenge_method", code_challenge_method.to_string());
96 format!("{}{}", url, qs)
97}
98
99#[allow(clippy::too_many_arguments)]
100pub(crate) async fn token(
101 url: &str,
102 client_id: &str,
103 client_secret: &str,
104 redirect_uri: &str,
105 code: &str,
106 code_verifier: &str,
107 grant_type: &str,
108 timeout: Duration,
109 try_count: usize,
110 retry_millis: u64,
111) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
112 let params = [
113 ("grant_type", grant_type),
114 ("code", code),
115 ("redirect_uri", redirect_uri),
116 ("client_id", client_id),
117 ("code_verifier", code_verifier),
118 ];
119
120 let client = reqwest::Client::new();
121
122 execute_retry(
123 || {
124 client
125 .post(url)
126 .form(¶ms)
127 .basic_auth(client_id, Some(client_secret))
128 .timeout(timeout)
129 },
130 try_count,
131 retry_millis,
132 )
133 .await
134}
135
/// OAuth 2.0 scopes that can be requested from X.
///
/// `Display` (and [`XScope::as_str`]) render each variant as the scope
/// string sent in the `scope` query parameter, e.g. `TweetRead` →
/// `tweet.read`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XScope {
    TweetRead,
    TweetWrite,
    TweetModerateWrite,
    UsersEmail,
    UsersRead,
    FollowsRead,
    FollowsWrite,
    OfflineAccess,
    SpaceRead,
    MuteRead,
    MuteWrite,
    LikeRead,
    LikeWrite,
    ListRead,
    ListWrite,
    BlockRead,
    BlockWrite,
    BookmarkRead,
    BookmarkWrite,
    DmRead,
    DmWrite,
    MediaWrite,
}

impl XScope {
    /// Returns every scope, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::TweetRead,
            Self::TweetWrite,
            Self::TweetModerateWrite,
            Self::UsersEmail,
            Self::UsersRead,
            Self::FollowsRead,
            Self::FollowsWrite,
            Self::OfflineAccess,
            Self::SpaceRead,
            Self::MuteRead,
            Self::MuteWrite,
            Self::LikeRead,
            Self::LikeWrite,
            Self::ListRead,
            Self::ListWrite,
            Self::BlockRead,
            Self::BlockWrite,
            Self::BookmarkRead,
            Self::BookmarkWrite,
            Self::DmRead,
            Self::DmWrite,
            Self::MediaWrite,
        ]
    }

    /// Returns the wire name of this scope, e.g. `"tweet.read"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::TweetRead => "tweet.read",
            Self::TweetWrite => "tweet.write",
            Self::TweetModerateWrite => "tweet.moderate.write",
            Self::UsersEmail => "users.email",
            Self::UsersRead => "users.read",
            Self::FollowsRead => "follows.read",
            Self::FollowsWrite => "follows.write",
            Self::OfflineAccess => "offline.access",
            Self::SpaceRead => "space.read",
            Self::MuteRead => "mute.read",
            Self::MuteWrite => "mute.write",
            Self::LikeRead => "like.read",
            Self::LikeWrite => "like.write",
            Self::ListRead => "list.read",
            Self::ListWrite => "list.write",
            Self::BlockRead => "block.read",
            Self::BlockWrite => "block.write",
            Self::BookmarkRead => "bookmark.read",
            Self::BookmarkWrite => "bookmark.write",
            Self::DmRead => "dm.read",
            Self::DmWrite => "dm.write",
            Self::MediaWrite => "media.write",
        }
    }

    /// Joins the wire names of `scopes` with single spaces, the separator
    /// OAuth 2.0 uses in the `scope` parameter. An empty slice yields `""`.
    pub fn scopes_to_string(scopes: &[XScope]) -> String {
        scopes
            .iter()
            .map(XScope::as_str)
            .collect::<Vec<&str>>()
            .join(" ")
    }
}

impl std::fmt::Display for XScope {
    /// Writes the scope's wire name (see [`XScope::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
226
227pub const X_AUTHORIZE_URL: &str = "https://x.com/i/oauth2/authorize";
228pub const X_TOKEN_URL: &str = "https://api.x.com/2/oauth2/token";
229
230pub struct XClient {
231 client_id: String,
232 client_secret: String,
233 redirect_uri: String,
234 scopes: Vec<XScope>,
235 try_count: usize,
236 retry_millis: u64,
237 timeout: Duration,
238}
239
240impl XClient {
241 pub fn new(
242 client_id: &str,
243 client_secret: &str,
244 redirect_uri: &str,
245 scopes: Vec<XScope>,
246 ) -> Self {
247 Self::new_with_token_options(
248 client_id,
249 client_secret,
250 redirect_uri,
251 scopes,
252 3,
253 500,
254 Duration::from_secs(10),
255 )
256 }
257
258 pub fn new_with_token_options(
259 client_id: &str,
260 client_secret: &str,
261 redirect_uri: &str,
262 scopes: Vec<XScope>,
263 try_count: usize,
264 retry_millis: u64,
265 timeout: Duration,
266 ) -> Self {
267 Self {
268 client_id: client_id.to_string(),
269 client_secret: client_secret.to_string(),
270 redirect_uri: redirect_uri.to_string(),
271 scopes,
272 try_count,
273 retry_millis,
274 timeout,
275 }
276 }
277
278 pub fn authorize_url(&self, state: &str) -> (String, String) {
279 let pkce = PkceS256::new();
280
281 let scopes_str = XScope::scopes_to_string(&self.scopes);
282 (
283 authorize_url(
284 X_AUTHORIZE_URL,
285 ResponseType::Code,
286 &self.client_id,
287 &self.redirect_uri,
288 &scopes_str,
289 state,
290 &pkce.code_challenge,
291 CodeChallengeMethod::S256,
292 ),
293 pkce.code_verifier,
294 )
295 }
296
297 pub async fn token(
298 &self,
299 code: &str,
300 code_verifier: &str,
301 ) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
302 let (token_json, status_code, headers) = token(
303 X_TOKEN_URL,
304 &self.client_id,
305 &self.client_secret,
306 &self.redirect_uri,
307 code,
308 code_verifier,
309 "authorization_code",
310 self.timeout,
311 self.try_count,
312 self.retry_millis,
313 )
314 .await?;
315 Ok((token_json, status_code, headers))
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Manual smoke test: prints a real authorize URL and verifier for the
    /// app configured through the environment.
    #[test]
    #[ignore = "requires CLIENT_ID, CLIENT_SECRET and REDIRECT_URL env vars"]
    fn test_x_authorize() {
        let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID must be set");
        let client_secret = std::env::var("CLIENT_SECRET").expect("CLIENT_SECRET must be set");
        let redirect_url = std::env::var("REDIRECT_URL").expect("REDIRECT_URL must be set");
        let state = "test_state";
        let x_client = XClient::new(&client_id, &client_secret, &redirect_url, XScope::all());
        let (auth_url, code_verifier) = x_client.authorize_url(state);
        assert!(auth_url.starts_with(X_AUTHORIZE_URL));
        assert!(!code_verifier.is_empty());
        println!("Authorize URL: {}", auth_url);
        println!("Code Verifier: {}", code_verifier);
    }

    /// The free `authorize_url` builder puts every parameter in the query string.
    #[test]
    fn test_authorize() {
        let scopes = XScope::scopes_to_string(&[XScope::TweetRead, XScope::UsersRead]);
        let res = authorize_url(
            X_AUTHORIZE_URL,
            ResponseType::Code,
            "cid",
            "https://example.com/cb",
            &scopes,
            "xyz",
            "challenge",
            CodeChallengeMethod::Plain,
        );
        assert!(res.starts_with(&format!("{X_AUTHORIZE_URL}?")), "unexpected url: {res}");
        assert!(res.contains("response_type=code"), "unexpected url: {res}");
        assert!(res.contains("client_id=cid"), "unexpected url: {res}");
        assert!(res.contains("scope=tweet.read"), "unexpected url: {res}");
        assert!(res.contains("state=xyz"), "unexpected url: {res}");
        assert!(res.contains("code_challenge=challenge"), "unexpected url: {res}");
        assert!(res.contains("code_challenge_method=plain"), "unexpected url: {res}");
    }
}