1#![feature(stmt_expr_attributes)]
2
3use serde_derive::{Deserialize, Serialize};
6use std::time::Duration;
7use thiserror::Error;
8use tokio::time;
9
10#[derive(Error, Debug)]
11pub enum DeviceFlowError {
12 #[error(transparent)]
13 RequestError(#[from] reqwest::Error),
14 #[error("Request failed with status code: {}", .0)]
15 RequestFailureError(reqwest::StatusCode),
16 #[error("Authorization request expired")]
17 AuthRequestExpired,
18 #[error("Expired access token")]
19 ExpiredTokenError,
20 #[error("Bad refresh token")]
23 BadRefreshToken,
24 #[error("Unverified user email")]
25 UnverifiedUserEmail,
26 #[error("Slow down")]
27 SlowDown,
28 #[error("Authorization pending")]
29 AuthorizationPending,
30 #[error("Could not deserialize response")]
31 DeserializationError(String),
32 #[error("Device flow disabled")]
33 DeviceFlowDisabled,
34 #[error("Incorrect client credentials")]
35 IncorrectClientCredentials,
36 #[error("Incorrect device code")]
37 IncorrectDeviceCode,
38 #[error("Access denied")]
39 AccessDenied,
40 #[error("Unsupported grant type")]
41 UnsupportedGrantType,
42 #[error("Refresh token not found")]
43 RefreshTokenNotFound,
44 #[error("This error should be unreachable")]
45 UnreachableError,
46}
47
48#[derive(Serialize, Deserialize, Debug)]
49struct VerificationParams {
50 device_code: String,
51 user_code: String,
52 verification_uri: String,
53 expires_in: u64,
54 interval: u64,
55}
56
57#[derive(Serialize, Deserialize, Debug)]
58struct AnotherResponse {
59 x: i32,
60 y: i32,
61}
62
63#[derive(Serialize, Deserialize, Debug)]
64pub struct Credentials {
65 pub access_token: String,
66 pub expires_in: u64,
67 pub refresh_token: String,
68 pub refresh_token_expires_in: u64,
69 pub scope: String,
70 pub token_type: String,
71}
72
73#[derive(Serialize, Deserialize, Debug)]
74#[serde(untagged)]
75enum GithubAPIResponse {
76 VerificationParams(VerificationParams),
77 Credentials(Credentials),
78 ErrorResponse(GithubAPIErrorResponse),
79}
80
81#[derive(Serialize, Deserialize, Debug)]
82struct GithubAPIErrorResponse {
83 #[serde(flatten)]
84 variant: GithubAPIErrorVariant,
85 error_description: String,
86 error_uri: String,
87}
88
89#[derive(Serialize, Deserialize, Debug)]
90#[serde(tag = "error", rename_all = "snake_case")]
91enum GithubAPIErrorVariant {
92 AuthorizationPending,
93 SlowDown,
94 ExpiredToken,
95 UnsupportedGrantType,
96 BadRefreshToken,
97 UnverifiedUserEmail,
98 IncorrectClientCredentials,
99 IncorrectDeviceCode,
100 AccessDenied,
101 DeviceFlowDisabled,
102}
103
/// Client for the OAuth device authorization flow.
#[derive(Clone, Debug)]
pub struct DeviceFlow {
    /// OAuth application client id sent with every request.
    client_id: String,
    /// Host the endpoint URLs are built from (presumably a bare name such as
    /// `github.com` — confirm against callers).
    host: String,
    /// Scope string sent with the device-code request.
    scopes: String,
}
110
111impl DeviceFlow {
112 pub fn new(client_id: String, host: String, scopes: String) -> Self {
113 Self {
114 client_id,
115 host,
116 scopes,
117 }
118 }
119
120 pub async fn refresh_or_authorize(
121 &self,
122 retrive_refresh_token: impl FnOnce() -> Result<String, DeviceFlowError>,
123 ) -> Result<Credentials, DeviceFlowError> {
124 let authorize_and_verify = || async {
125 let vp = self.verify_device().await?;
126 eprintln!("Please enter the code: {}", vp.user_code);
127 eprintln!("At the following URL in your browser:");
128 eprintln!("{}", vp.verification_uri);
129 self.authorize(&vp).await
130 };
131
132 match retrive_refresh_token() {
133 Ok(token) => match self.refresh(token).await {
134 Ok(credentials) => Ok(credentials),
135 Err(e) => match e {
136 DeviceFlowError::ExpiredTokenError
137 | DeviceFlowError::IncorrectClientCredentials | DeviceFlowError::BadRefreshToken => authorize_and_verify().await,
139 e => Err(e),
140 },
141 },
142 Err(DeviceFlowError::RefreshTokenNotFound) => authorize_and_verify().await,
143 Err(e) => Err(e),
144 }
145 }
146
147 async fn verify_device(&self) -> Result<VerificationParams, DeviceFlowError> {
148 let r = send_request(
150 format!("https:/{}/login/device/code", self.host),
151 format!("client_id={}&scope={}", self.client_id, self.scopes),
152 )
153 .await?;
154
155 use GithubAPIErrorVariant::*;
156 use GithubAPIResponse::*;
157 #[rustfmt::skip]
158 let vp_result = match r {
159 VerificationParams(vp) => Ok(vp),
160 Credentials(_) => Err(DeviceFlowError::UnreachableError),
161 ErrorResponse(e) => match e.variant {
162 IncorrectClientCredentials => Err(DeviceFlowError::IncorrectClientCredentials),
163 DeviceFlowDisabled => Err(DeviceFlowError::DeviceFlowDisabled),
164 _ => Err(DeviceFlowError::UnreachableError),
165 },
166 };
167 vp_result
168 }
169
170 async fn authorize(&self, vp: &VerificationParams) -> Result<Credentials, DeviceFlowError> {
171 let request_url = format!("https:/{}/login/oauth/access_token", self.host);
172 let request_body = format!(
173 "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
174 self.client_id, vp.device_code
175 );
176 let mut interval = vp.interval;
182
183 let time_start = std::time::Instant::now();
184 while time_start.elapsed().as_secs() < vp.expires_in {
185 let r = request_access_token(request_url.clone(), request_body.clone()).await;
186 match r {
187 Ok(credentials) => return Ok(credentials),
188 Err(DeviceFlowError::SlowDown) => interval += 5,
189 Err(DeviceFlowError::AuthorizationPending) => {
190 time::sleep(Duration::from_secs(interval)).await;
191 }
192 r => return r,
193 }
194 }
195
196 Err(DeviceFlowError::AuthRequestExpired)
197 }
198
199 async fn refresh(&self, refresh_token: String) -> Result<Credentials, DeviceFlowError> {
200 let request_url = format!("https:/{}/login/oauth/access_token", self.host);
201 let request_body = format!(
202 "client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token",
203 self.client_id, refresh_token
204 );
205
206 request_access_token(request_url, request_body.to_string()).await
207 }
208}
209
210async fn send_request(
211 url: impl AsRef<str>,
212 body: String,
213) -> Result<GithubAPIResponse, DeviceFlowError> {
214 let client = reqwest::Client::new();
215 let response = client
216 .post(url.as_ref())
217 .header("Accept", "application/json")
218 .body(body)
219 .send()
220 .await?
221 .error_for_status()?;
222
223 let body_bytes = response.bytes().await?;
226 String::from_utf8_lossy(&body_bytes).to_string();
227 if let Ok(body) = serde_json::from_slice::<GithubAPIResponse>(&body_bytes) {
228 return Ok(body);
229 } else {
230 let bytes_as_string: String = String::from_utf8_lossy(&body_bytes).to_string();
231 return Err(DeviceFlowError::DeserializationError(bytes_as_string));
232 }
233}
234
235async fn request_access_token(
236 request_url: String,
237 request_body: String,
238) -> Result<Credentials, DeviceFlowError> {
239 let r = send_request(&request_url, request_body.clone()).await?;
240
241 use GithubAPIResponse::*;
242 match r {
243 Credentials(credentials) => Ok(credentials),
244 VerificationParams(_) => Err(DeviceFlowError::UnreachableError),
245 ErrorResponse(er) => Err(er.variant.into()),
246 }
247}
248
249use GithubAPIErrorVariant::*;
250impl Into<DeviceFlowError> for GithubAPIErrorVariant {
251 fn into(self) -> DeviceFlowError {
252 match self {
253 AuthorizationPending => DeviceFlowError::AuthorizationPending,
254 SlowDown => DeviceFlowError::SlowDown,
255 ExpiredToken => DeviceFlowError::ExpiredTokenError,
256 UnsupportedGrantType => DeviceFlowError::UnsupportedGrantType,
257 IncorrectClientCredentials => DeviceFlowError::IncorrectClientCredentials,
258 IncorrectDeviceCode => DeviceFlowError::IncorrectDeviceCode,
259 AccessDenied => DeviceFlowError::AccessDenied,
260 DeviceFlowDisabled => DeviceFlowError::DeviceFlowDisabled,
261 BadRefreshToken => DeviceFlowError::BadRefreshToken,
262 UnverifiedUserEmail => DeviceFlowError::UnverifiedUserEmail,
263 }
264 }
265}
266
267impl Credentials {
268 pub fn try_to_string(&self) -> Result<String, serde_json::Error> {
269 serde_json::to_string(self)
270 }
271 pub fn try_from_string(s: &str) -> Result<Self, serde_json::Error> {
272 serde_json::from_str(s)
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    // Decoding is synchronous, so plain `#[test]` suffices (no async runtime needed).
    // Each test asserts WHICH untagged variant was chosen, not just that decoding succeeded.

    #[test]
    fn test_decode_credentials() {
        let payload = r#"{
            "access_token":"secret",
            "expires_in":28800,
            "refresh_token":"secret",
            "token_type":"bearer",
            "refresh_token_expires_in":15811200,
            "scope":""}"#;

        match serde_json::from_str::<GithubAPIResponse>(payload).unwrap() {
            GithubAPIResponse::Credentials(creds) => {
                assert_eq!(creds.access_token, "secret");
                assert_eq!(creds.expires_in, 28800);
                assert_eq!(creds.refresh_token, "secret");
                assert_eq!(creds.refresh_token_expires_in, 15811200);
                assert_eq!(creds.token_type, "bearer");
                assert_eq!(creds.scope, "");
            }
            other => panic!("expected Credentials, got {:?}", other),
        }
    }

    #[test]
    fn test_decode_verification_params() {
        let payload = r#"{
            "device_code":"AA",
            "user_code":"user-code",
            "verification_uri":"https://example.com/device",
            "expires_in":1800,
            "interval":5
        }"#;

        match serde_json::from_str::<GithubAPIResponse>(payload).unwrap() {
            GithubAPIResponse::VerificationParams(vp) => {
                assert_eq!(vp.device_code, "AA");
                assert_eq!(vp.user_code, "user-code");
                assert_eq!(vp.verification_uri, "https://example.com/device");
                assert_eq!(vp.expires_in, 1800);
                assert_eq!(vp.interval, 5);
            }
            other => panic!("expected VerificationParams, got {:?}", other),
        }
    }

    #[test]
    fn test_decode_error_response() {
        let payload = r#"{
            "error":"slow_down",
            "error_description":"Too many requests.",
            "error_uri":"https://example.com/errors"
        }"#;

        match serde_json::from_str::<GithubAPIResponse>(payload).unwrap() {
            GithubAPIResponse::ErrorResponse(err) => {
                assert!(matches!(err.variant, GithubAPIErrorVariant::SlowDown));
                assert_eq!(err.error_description, "Too many requests.");
                assert_eq!(err.error_uri, "https://example.com/errors");
            }
            other => panic!("expected ErrorResponse, got {:?}", other),
        }
    }

    #[test]
    fn test_credentials_string_round_trip() {
        let original = Credentials {
            access_token: "a".to_string(),
            expires_in: 1,
            refresh_token: "r".to_string(),
            refresh_token_expires_in: 2,
            scope: "repo".to_string(),
            token_type: "bearer".to_string(),
        };

        let json = original.try_to_string().unwrap();
        let parsed = Credentials::try_from_string(&json).unwrap();

        assert_eq!(parsed.access_token, original.access_token);
        assert_eq!(parsed.expires_in, original.expires_in);
        assert_eq!(parsed.refresh_token, original.refresh_token);
        assert_eq!(parsed.refresh_token_expires_in, original.refresh_token_expires_in);
        assert_eq!(parsed.scope, original.scope);
        assert_eq!(parsed.token_type, original.token_type);
    }
}