Skip to main content

twapi_oauth2/
oauth2.rs

1use base64::prelude::*;
2use std::time::Duration;
3
4use query_string_builder::QueryString;
5use reqwest::{StatusCode, header::HeaderMap};
6use serde::{Deserialize, Serialize};
7use sha2::Digest;
8
9use crate::{error::Error, execute_retry, make_url};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TokenResult {
13    pub access_token: String,
14    pub refresh_token: String,
15    pub expires_in: u64,
16    pub scope: String,
17    pub token_type: String,
18}
19
/// OAuth 2.0 `response_type` values accepted by the authorize endpoint.
enum ResponseType {
    /// Authorization-code flow; rendered as `code`.
    Code,
    /// Rendered as `token`. The visible code never constructs this variant,
    /// which is why the `unused` lint is silenced for it.
    #[allow(unused)]
    Token,
}

impl std::fmt::Display for ResponseType {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let wire_value = match self {
            ResponseType::Code => "code",
            ResponseType::Token => "token",
        };
        f.write_str(wire_value)
    }
}
34
/// PKCE `code_challenge_method` values (RFC 7636 §4.3).
enum CodeChallengeMethod {
    /// SHA-256 transform of the verifier; rendered as `S256`.
    S256,
    /// Verifier sent as-is; rendered as `plain`. Only the tests in this file use
    /// it, so the `unused` lint is silenced for non-test builds.
    #[allow(unused)]
    Plain,
}

impl std::fmt::Display for CodeChallengeMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(match self {
            CodeChallengeMethod::S256 => "S256",
            CodeChallengeMethod::Plain => "plain",
        })
    }
}
49
50pub(crate) struct PkceS256 {
51    pub code_challenge: String,
52    pub code_verifier: String,
53}
54
55impl PkceS256 {
56    pub fn new() -> Self {
57        let size = 32;
58        let random_bytes: Vec<u8> = (0..size).map(|_| rand::random::<u8>()).collect();
59        let code_verifier = BASE64_URL_SAFE_NO_PAD.encode(&random_bytes);
60        let code_challenge = {
61            let hash = sha2::Sha256::digest(code_verifier.as_bytes());
62            BASE64_URL_SAFE_NO_PAD.encode(hash)
63        };
64        Self {
65            code_challenge,
66            code_verifier,
67        }
68    }
69}
70
71impl Default for PkceS256 {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77#[allow(clippy::too_many_arguments)]
78fn authorize_url(
79    url: &str,
80    response_type: ResponseType,
81    client_id: &str,
82    redirect_uri: &str,
83    scopes: &str,
84    state: &str,
85    code_challenge: &str,
86    code_challenge_method: CodeChallengeMethod,
87) -> String {
88    let qs = QueryString::dynamic()
89        .with_value("response_type", response_type.to_string())
90        .with_value("client_id", client_id)
91        .with_value("redirect_uri", redirect_uri)
92        .with_value("scope", scopes)
93        .with_value("state", state)
94        .with_value("code_challenge", code_challenge)
95        .with_value("code_challenge_method", code_challenge_method.to_string());
96    format!("{}{}", url, qs)
97}
98
99#[allow(clippy::too_many_arguments)]
100pub(crate) async fn token(
101    url: &str,
102    client_id: &str,
103    client_secret: &str,
104    redirect_uri: &str,
105    code: &str,
106    code_verifier: &str,
107    grant_type: &str,
108    timeout: Duration,
109    try_count: usize,
110    retry_duration: Duration,
111) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
112    let params = [
113        ("grant_type", grant_type),
114        ("code", code),
115        ("redirect_uri", redirect_uri),
116        ("client_id", client_id),
117        ("code_verifier", code_verifier),
118    ];
119
120    let client = reqwest::Client::new();
121
122    execute_retry(
123        || {
124            client
125                .post(url)
126                .form(&params)
127                .basic_auth(client_id, Some(client_secret))
128                .timeout(timeout)
129        },
130        try_count,
131        retry_duration,
132    )
133    .await
134}
135
136pub async fn refresh_token(
137    client_id: &str,
138    client_secret: &str,
139    refresh_token: &str,
140    timeout: Duration,
141    try_count: usize,
142    retry_duration: Duration,
143    prefix_url: Option<String>,
144) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
145    let url = &make_url(URL_POSTFIX, X_TOKEN_URL_PREFIX, &prefix_url);
146    let params = [
147        ("grant_type", "refresh_token"),
148        ("refresh_token", refresh_token),
149        ("client_id", client_id),
150    ];
151
152    let client = reqwest::Client::new();
153
154    execute_retry(
155        || {
156            client
157                .post(url)
158                .form(&params)
159                .basic_auth(client_id, Some(client_secret))
160                .timeout(timeout)
161        },
162        try_count,
163        retry_duration,
164    )
165    .await
166}
167
/// OAuth 2.0 scopes that can be requested from X.
///
/// `Display` (and [`XScope::as_str`]) render the wire name, e.g. `tweet.read`.
/// [`XScope::scopes_to_string`] joins several of them for the `scope` query
/// parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum XScope {
    TweetRead,
    TweetWrite,
    TweetModerateWrite,
    UsersEmail,
    UsersRead,
    FollowsRead,
    FollowsWrite,
    OfflineAccess,
    SpaceRead,
    MuteRead,
    MuteWrite,
    LikeRead,
    LikeWrite,
    ListRead,
    ListWrite,
    BlockRead,
    BlockWrite,
    BookmarkRead,
    BookmarkWrite,
    DmRead,
    DmWrite,
    MediaWrite,
}

impl XScope {
    /// Every scope defined by this enum, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::TweetRead,
            Self::TweetWrite,
            Self::TweetModerateWrite,
            Self::UsersEmail,
            Self::UsersRead,
            Self::FollowsRead,
            Self::FollowsWrite,
            Self::OfflineAccess,
            Self::SpaceRead,
            Self::MuteRead,
            Self::MuteWrite,
            Self::LikeRead,
            Self::LikeWrite,
            Self::ListRead,
            Self::ListWrite,
            Self::BlockRead,
            Self::BlockWrite,
            Self::BookmarkRead,
            Self::BookmarkWrite,
            Self::DmRead,
            Self::DmWrite,
            Self::MediaWrite,
        ]
    }

    /// The scope's wire name, identical to its `Display` output.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::TweetRead => "tweet.read",
            Self::TweetWrite => "tweet.write",
            Self::TweetModerateWrite => "tweet.moderate.write",
            Self::UsersEmail => "users.email",
            Self::UsersRead => "users.read",
            Self::FollowsRead => "follows.read",
            Self::FollowsWrite => "follows.write",
            Self::OfflineAccess => "offline.access",
            Self::SpaceRead => "space.read",
            Self::MuteRead => "mute.read",
            Self::MuteWrite => "mute.write",
            Self::LikeRead => "like.read",
            Self::LikeWrite => "like.write",
            Self::ListRead => "list.read",
            Self::ListWrite => "list.write",
            Self::BlockRead => "block.read",
            Self::BlockWrite => "block.write",
            Self::BookmarkRead => "bookmark.read",
            Self::BookmarkWrite => "bookmark.write",
            Self::DmRead => "dm.read",
            Self::DmWrite => "dm.write",
            Self::MediaWrite => "media.write",
        }
    }

    /// Joins the wire names of `scopes` with single spaces, the separator RFC 6749
    /// §3.3 uses for the `scope` parameter. An empty slice yields `""`.
    pub fn scopes_to_string(scopes: &[XScope]) -> String {
        // Appends `&'static str` names directly instead of allocating a `String`
        // per scope and collecting them into a temporary `Vec` first.
        let mut joined = String::new();
        for (index, scope) in scopes.iter().enumerate() {
            if index > 0 {
                joined.push(' ');
            }
            joined.push_str(scope.as_str());
        }
        joined
    }
}

impl std::fmt::Display for XScope {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
258
/// Default API origin used when building the token endpoint URL.
///
/// Despite its name, this is the leading part of the URL; [`X_TOKEN_URL_PREFIX`]
/// supplies the path. Both are passed to the crate's `make_url` together with an
/// optional override. NOTE(review): how `make_url` combines them is defined
/// elsewhere; confirm there.
const URL_POSTFIX: &str = "https://api.x.com";

/// Path of the OAuth 2.0 token endpoint, relative to the API origin.
pub const X_TOKEN_URL_PREFIX: &str = "/2/oauth2/token";

/// User-facing OAuth 2.0 authorization endpoint on x.com.
pub const X_AUTHORIZE_URL: &str = "https://x.com/i/oauth2/authorize";
263
264pub struct XClient {
265    client_id: String,
266    client_secret: String,
267    redirect_uri: String,
268    scopes: Vec<XScope>,
269    try_count: usize,
270    retry_duration: Duration,
271    timeout: Duration,
272    prefix_url: Option<String>,
273}
274
275impl XClient {
276    pub fn new(
277        client_id: &str,
278        client_secret: &str,
279        redirect_uri: &str,
280        scopes: Vec<XScope>,
281    ) -> Self {
282        Self::new_with_token_options(
283            client_id,
284            client_secret,
285            redirect_uri,
286            scopes,
287            3,
288            Duration::from_millis(100),
289            Duration::from_secs(10),
290            None,
291        )
292    }
293
294    #[allow(clippy::too_many_arguments)]
295    pub fn new_with_token_options(
296        client_id: &str,
297        client_secret: &str,
298        redirect_uri: &str,
299        scopes: Vec<XScope>,
300        try_count: usize,
301        retry_duration: Duration,
302        timeout: Duration,
303        prefix_url: Option<String>,
304    ) -> Self {
305        Self {
306            client_id: client_id.to_string(),
307            client_secret: client_secret.to_string(),
308            redirect_uri: redirect_uri.to_string(),
309            scopes,
310            try_count,
311            retry_duration,
312            timeout,
313            prefix_url,
314        }
315    }
316
317    pub fn authorize_url(&self, state: &str) -> (String, String) {
318        let pkce = PkceS256::new();
319
320        let scopes_str = XScope::scopes_to_string(&self.scopes);
321        (
322            authorize_url(
323                X_AUTHORIZE_URL,
324                ResponseType::Code,
325                &self.client_id,
326                &self.redirect_uri,
327                &scopes_str,
328                state,
329                &pkce.code_challenge,
330                CodeChallengeMethod::S256,
331            ),
332            pkce.code_verifier,
333        )
334    }
335
336    pub async fn token(
337        &self,
338        code: &str,
339        code_verifier: &str,
340    ) -> Result<(TokenResult, StatusCode, HeaderMap), Error> {
341        let (token_json, status_code, headers) = token(
342            &make_url(URL_POSTFIX, X_TOKEN_URL_PREFIX, &self.prefix_url),
343            &self.client_id,
344            &self.client_secret,
345            &self.redirect_uri,
346            code,
347            code_verifier,
348            "authorization_code",
349            self.timeout,
350            self.try_count,
351            self.retry_duration,
352        )
353        .await?;
354        Ok((token_json, status_code, headers))
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads `key` from the environment, falling back to `fallback`, so the tests
    /// also run (and assert) without real credentials.
    fn env_or(key: &str, fallback: &str) -> String {
        std::env::var(key).unwrap_or_else(|_| fallback.to_string())
    }

    // With real credentials the printed URL can be opened in a browser:
    // CLIENT_ID=xxx CLIENT_SECRET=xxx REDIRECT_URL=http://localhost:8000/callback cargo test test_x_authorize -- --nocapture
    #[test]
    fn test_x_authorize() {
        let client_id = env_or("CLIENT_ID", "test_client_id");
        let client_secret = env_or("CLIENT_SECRET", "test_client_secret");
        let redirect_url = env_or("REDIRECT_URL", "http://localhost:8000/callback");
        let state = "test_state";
        let x_client = XClient::new(&client_id, &client_secret, &redirect_url, XScope::all());
        let (auth_url, code_verifier) = x_client.authorize_url(state);
        println!("Authorize URL: {}", auth_url);
        println!("Code Verifier: {}", code_verifier);

        assert!(auth_url.starts_with(X_AUTHORIZE_URL));
        assert!(auth_url.contains("response_type=code"));
        assert!(auth_url.contains("state=test_state"));
        assert!(auth_url.contains("code_challenge_method=S256"));
        // 32 random bytes, base64url without padding -> 43 characters.
        assert_eq!(code_verifier.len(), 43);
    }

    // CLIENT_ID=xxx REDIRECT_URL=http://localhost:8000/callback cargo test test_authorize -- --nocapture
    #[test]
    fn test_authorize() {
        let client_id = env_or("CLIENT_ID", "test_client_id");
        let redirect_url = env_or("REDIRECT_URL", "http://localhost:8000/callback");
        let state = "test_state";
        let scopes = XScope::scopes_to_string(&XScope::all());
        let code_challenge = "test_code_challenge";
        let res = authorize_url(
            X_AUTHORIZE_URL,
            ResponseType::Code,
            &client_id,
            &redirect_url,
            &scopes,
            state,
            code_challenge,
            CodeChallengeMethod::Plain,
        );
        println!("res: {}", res);

        assert!(res.starts_with(X_AUTHORIZE_URL));
        assert!(res.contains("response_type=code"));
        assert!(res.contains("state=test_state"));
        assert!(res.contains("code_challenge=test_code_challenge"));
        assert!(res.contains("code_challenge_method=plain"));
    }
}