1use crate::error::Result;
5use crate::{
6 auth_server_builder, AuthAction, AuthConfig, AuthUrlProvider, AuthUser, GenericAuthAction,
7};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use serde_with::{
12 formats::{CommaSeparator, SpaceSeparator},
13 serde_as, StringWithSeparator,
14};
15use std::collections::HashMap;
16
17pub struct AuthorizationServer {
18 config: AuthConfig,
19}
20
21auth_server_builder!();
22
23impl AuthUrlProvider for AuthorizationServer {
24 type AuthRequest = AuthRequest;
25 type TokenRequest = GetTokenRequest;
26 type UserInfoRequest = GetUserInfoRequest;
27
28 fn authorize_url(request: Self::AuthRequest) -> Result<String> {
29 let query = serde_urlencoded::to_string(request)?;
30 Ok(format!(
31 "https://twitter.com/i/oauth2/authorize?response_type=code&{query}"
32 ))
33 }
34
35 fn access_token_url(request: Self::TokenRequest) -> Result<String> {
36 let query = serde_urlencoded::to_string(request)?;
37 Ok(format!(
38 "https://api.x.com/2/oauth2/token?grant_type=authorization_code&{query}"
39 ))
40 }
41
42 fn user_info_url(_request: Self::UserInfoRequest) -> Result<String> {
43 Ok(format!("https://api.x.com/2/users/me"))
44 }
45}
46
47#[async_trait]
48impl AuthAction for AuthorizationServer {
49 type AuthCallback = AuthCallback;
50 type AuthToken = TokenResponse;
51 type AuthUser = UserInfoResponse;
52
53 async fn get_access_token(&self, callback: Self::AuthCallback) -> Result<Self::AuthToken> {
54 let AuthConfig {
55 client_id,
56 redirect_uri,
57 ..
58 } = &self.config;
59 let access_token_url = Self::access_token_url(GetTokenRequest {
60 client_id: client_id.to_string(),
61 code: callback.code,
62 redirect_uri: redirect_uri.to_string(),
63 code_verifier: "aaa".to_string(),
64 })?;
65 Ok(reqwest::get(access_token_url).await?.json().await?)
66 }
67
68 async fn get_user_info(&self, token: Self::AuthToken) -> Result<Self::AuthUser> {
69 let user_info_url = Self::user_info_url(GetUserInfoRequest {
70 user_fields: [
71 "created_at",
72 "description",
73 "entities",
74 "id",
75 "location",
76 "most_recent_tweet_id",
77 "name",
78 "pinned_tweet_id",
79 "profile_image_url",
80 "protected",
81 "public_metrics",
82 "url",
83 "username",
84 "verified",
85 "verified_type",
86 "withheld",
87 ]
88 .map(|s| s.to_string())
89 .to_vec(),
90 ..Default::default()
91 })?;
92 Ok(reqwest::Client::default()
93 .get(user_info_url)
94 .bearer_auth(token.access_token)
95 .send()
96 .await?
97 .json()
98 .await?)
99 }
100}
101
102#[async_trait]
103impl GenericAuthAction for AuthorizationServer {
104 async fn authorize<S: Into<String> + Send>(&self, state: S) -> Result<String> {
105 let AuthConfig {
106 client_id,
107 redirect_uri,
108 scope,
109 ..
110 } = &self.config;
111 Self::authorize_url(AuthRequest {
112 client_id: client_id.to_string(),
113 redirect_uri: redirect_uri.to_string(),
114 state: state.into(),
115 scope: scope
116 .clone()
117 .or_else(|| Some(vec!["tweet.read".into(), "users.read".into()]))
118 .expect("scope is empty"),
119 ..Default::default()
120 })
121 }
122
123 async fn login<S: Into<String> + Send>(&self, callback: S) -> Result<AuthUser> {
124 let callback: AuthCallback = serde_urlencoded::from_str(&callback.into())?;
125 let token = self.get_access_token(callback).await?;
126 let user = self.get_user_info(token.clone()).await?;
127 Ok(AuthUser {
128 user_id: user.id,
129 name: user.name,
130 access_token: token.access_token,
131 refresh_token: token.token_type,
132 expires_in: i64::MAX,
133 extra: user.extra,
134 })
135 }
136}
137
138#[serde_as]
139#[derive(Debug, Default, Serialize, Deserialize)]
140pub struct AuthRequest {
141 client_id: String,
142 redirect_uri: String,
143 #[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
144 scope: Vec<String>,
145 state: String,
146 code_challenge: Option<String>,
148 code_challenge_method: Option<String>,
149}
150
151#[derive(Debug, Serialize, Deserialize)]
152pub struct AuthCallback {
153 code: String,
154 state: String,
155}
156
157#[derive(Debug, Serialize)]
158pub struct GetTokenRequest {
159 client_id: String,
160 code: String,
161 redirect_uri: String,
162 code_verifier: String,
163}
164
165#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
166pub struct TokenResponse {
167 pub access_token: String,
168 pub scope: String,
169 pub token_type: String,
170}
171
172#[serde_as]
173#[derive(Debug, Default, Serialize, Deserialize)]
174pub struct GetUserInfoRequest {
175 expansions: Option<String>,
176 #[serde(rename = "tweet.fields")]
177 #[serde_as(as = "StringWithSeparator::<CommaSeparator, String>")]
178 tweet_fields: Vec<String>,
179 #[serde(rename = "user.fields")]
180 #[serde_as(as = "StringWithSeparator::<CommaSeparator, String>")]
181 user_fields: Vec<String>,
182}
183
184#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
186pub struct UserInfoResponse {
187 pub id: String,
188 pub name: String,
189 #[serde(flatten)]
190 pub extra: HashMap<String, Value>,
191}