github_device_oauth/
lib.rs

1#![feature(stmt_expr_attributes)]
2
3// More info: https://docs.github.com/en/apps/creating-github-apps/authenticating-with-a-github-app/generating-a-user-access-token-for-a-github-app#using-the-device-flow-to-generate-a-user-access-token
4
5use serde_derive::{Deserialize, Serialize};
6use std::time::Duration;
7use thiserror::Error;
8use tokio::time;
9
10#[derive(Error, Debug)]
11pub enum DeviceFlowError {
12    #[error(transparent)]
13    RequestError(#[from] reqwest::Error),
14    #[error("Request failed with status code: {}", .0)]
15    RequestFailureError(reqwest::StatusCode),
16    #[error("Authorization request expired")]
17    AuthRequestExpired,
18    #[error("Expired access token")]
19    ExpiredTokenError,
20    // We want to show the erroneous response in the error message
21    // thus we do not use #[from] here
22    #[error("Bad refresh token")]
23    BadRefreshToken,
24    #[error("Unverified user email")]
25    UnverifiedUserEmail,
26    #[error("Slow down")]
27    SlowDown,
28    #[error("Authorization pending")]
29    AuthorizationPending,
30    #[error("Could not deserialize response")]
31    DeserializationError(String),
32    #[error("Device flow disabled")]
33    DeviceFlowDisabled,
34    #[error("Incorrect client credentials")]
35    IncorrectClientCredentials,
36    #[error("Incorrect device code")]
37    IncorrectDeviceCode,
38    #[error("Access denied")]
39    AccessDenied,
40    #[error("Unsupported grant type")]
41    UnsupportedGrantType,
42    #[error("Refresh token not found")]
43    RefreshTokenNotFound,
44    #[error("This error should be unreachable")]
45    UnreachableError,
46}
47
48#[derive(Serialize, Deserialize, Debug)]
49struct VerificationParams {
50    device_code: String,
51    user_code: String,
52    verification_uri: String,
53    expires_in: u64,
54    interval: u64,
55}
56
57#[derive(Serialize, Deserialize, Debug)]
58struct AnotherResponse {
59    x: i32,
60    y: i32,
61}
62
63#[derive(Serialize, Deserialize, Debug)]
64pub struct Credentials {
65    pub access_token: String,
66    pub expires_in: u64,
67    pub refresh_token: String,
68    pub refresh_token_expires_in: u64,
69    pub scope: String,
70    pub token_type: String,
71}
72
73#[derive(Serialize, Deserialize, Debug)]
74#[serde(untagged)]
75enum GithubAPIResponse {
76    VerificationParams(VerificationParams),
77    Credentials(Credentials),
78    ErrorResponse(GithubAPIErrorResponse),
79}
80
81#[derive(Serialize, Deserialize, Debug)]
82struct GithubAPIErrorResponse {
83    #[serde(flatten)]
84    variant: GithubAPIErrorVariant,
85    error_description: String,
86    error_uri: String,
87}
88
89#[derive(Serialize, Deserialize, Debug)]
90#[serde(tag = "error", rename_all = "snake_case")]
91enum GithubAPIErrorVariant {
92    AuthorizationPending,
93    SlowDown,
94    ExpiredToken,
95    UnsupportedGrantType,
96    BadRefreshToken,
97    UnverifiedUserEmail,
98    IncorrectClientCredentials,
99    IncorrectDeviceCode,
100    AccessDenied,
101    DeviceFlowDisabled,
102}
103
104#[derive(Debug, Clone)]
105pub struct DeviceFlow {
106    client_id: String,
107    host: String,
108    scopes: String,
109}
110
111impl DeviceFlow {
112    pub fn new(client_id: String, host: String, scopes: String) -> Self {
113        Self {
114            client_id,
115            host,
116            scopes,
117        }
118    }
119
120    pub async fn refresh_or_authorize(
121        &self,
122        retrive_refresh_token: impl FnOnce() -> Result<String, DeviceFlowError>,
123    ) -> Result<Credentials, DeviceFlowError> {
124        let authorize_and_verify = || async {
125            let vp = self.verify_device().await?;
126            eprintln!("Please enter the code: {}", vp.user_code);
127            eprintln!("At the following URL in your browser:");
128            eprintln!("{}", vp.verification_uri);
129            self.authorize(&vp).await
130        };
131
132        match retrive_refresh_token() {
133            Ok(token) => match self.refresh(token).await {
134                Ok(credentials) => Ok(credentials),
135                Err(e) => match e {
136                    DeviceFlowError::ExpiredTokenError
137                    | DeviceFlowError::IncorrectClientCredentials // Will be returned when the refresh token has been replaced with a new one
138                    | DeviceFlowError::BadRefreshToken => authorize_and_verify().await,
139                    e => Err(e),
140                },
141            },
142            Err(DeviceFlowError::RefreshTokenNotFound) => authorize_and_verify().await,
143            Err(e) => Err(e),
144        }
145    }
146
147    async fn verify_device(&self) -> Result<VerificationParams, DeviceFlowError> {
148        // TODO use serde to build request body
149        let r = send_request(
150            format!("https:/{}/login/device/code", self.host),
151            format!("client_id={}&scope={}", self.client_id, self.scopes),
152        )
153        .await?;
154
155        use GithubAPIErrorVariant::*;
156        use GithubAPIResponse::*;
157        #[rustfmt::skip]
158        let vp_result = match r {
159            VerificationParams(vp) => Ok(vp),
160            Credentials(_) => Err(DeviceFlowError::UnreachableError),
161            ErrorResponse(e) => match e.variant {
162                IncorrectClientCredentials => Err(DeviceFlowError::IncorrectClientCredentials),
163                DeviceFlowDisabled => Err(DeviceFlowError::DeviceFlowDisabled),
164                _ => Err(DeviceFlowError::UnreachableError),
165            },
166        };
167        vp_result
168    }
169
170    async fn authorize(&self, vp: &VerificationParams) -> Result<Credentials, DeviceFlowError> {
171        let request_url = format!("https:/{}/login/oauth/access_token", self.host);
172        let request_body = format!(
173            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
174            self.client_id, vp.device_code
175        );
176        /*
177         * Do not poll this endpoint at a higher frequency than the frequency indicated by interval. If
178         * you do, you will hit the rate limit and receive a slow_down error. The slow_down error
179         * response adds 5 seconds to the last interval.
180         */
181        let mut interval = vp.interval;
182
183        let time_start = std::time::Instant::now();
184        while time_start.elapsed().as_secs() < vp.expires_in {
185            let r = request_access_token(request_url.clone(), request_body.clone()).await;
186            match r {
187                Ok(credentials) => return Ok(credentials),
188                Err(DeviceFlowError::SlowDown) => interval += 5,
189                Err(DeviceFlowError::AuthorizationPending) => {
190                    time::sleep(Duration::from_secs(interval)).await;
191                }
192                r => return r,
193            }
194        }
195
196        Err(DeviceFlowError::AuthRequestExpired)
197    }
198
199    async fn refresh(&self, refresh_token: String) -> Result<Credentials, DeviceFlowError> {
200        let request_url = format!("https:/{}/login/oauth/access_token", self.host);
201        let request_body = format!(
202            "client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token",
203            self.client_id, refresh_token
204        );
205
206        request_access_token(request_url, request_body.to_string()).await
207    }
208}
209
210async fn send_request(
211    url: impl AsRef<str>,
212    body: String,
213) -> Result<GithubAPIResponse, DeviceFlowError> {
214    let client = reqwest::Client::new();
215    let response = client
216        .post(url.as_ref())
217        .header("Accept", "application/json")
218        .body(body)
219        .send()
220        .await?
221        .error_for_status()?;
222
223    // Try to deserialize to a [`GithubApiResponse`] enum if that fails, dump the response body as
224    // a string in the error message
225    let body_bytes = response.bytes().await?;
226    String::from_utf8_lossy(&body_bytes).to_string();
227    if let Ok(body) = serde_json::from_slice::<GithubAPIResponse>(&body_bytes) {
228        return Ok(body);
229    } else {
230        let bytes_as_string: String = String::from_utf8_lossy(&body_bytes).to_string();
231        return Err(DeviceFlowError::DeserializationError(bytes_as_string));
232    }
233}
234
235async fn request_access_token(
236    request_url: String,
237    request_body: String,
238) -> Result<Credentials, DeviceFlowError> {
239    let r = send_request(&request_url, request_body.clone()).await?;
240
241    use GithubAPIResponse::*;
242    match r {
243        Credentials(credentials) => Ok(credentials),
244        VerificationParams(_) => Err(DeviceFlowError::UnreachableError),
245        ErrorResponse(er) => Err(er.variant.into()),
246    }
247}
248
249use GithubAPIErrorVariant::*;
250impl Into<DeviceFlowError> for GithubAPIErrorVariant {
251    fn into(self) -> DeviceFlowError {
252        match self {
253            AuthorizationPending => DeviceFlowError::AuthorizationPending,
254            SlowDown => DeviceFlowError::SlowDown,
255            ExpiredToken => DeviceFlowError::ExpiredTokenError,
256            UnsupportedGrantType => DeviceFlowError::UnsupportedGrantType,
257            IncorrectClientCredentials => DeviceFlowError::IncorrectClientCredentials,
258            IncorrectDeviceCode => DeviceFlowError::IncorrectDeviceCode,
259            AccessDenied => DeviceFlowError::AccessDenied,
260            DeviceFlowDisabled => DeviceFlowError::DeviceFlowDisabled,
261            BadRefreshToken => DeviceFlowError::BadRefreshToken,
262            UnverifiedUserEmail => DeviceFlowError::UnverifiedUserEmail,
263        }
264    }
265}
266
267impl Credentials {
268    pub fn try_to_string(&self) -> Result<String, serde_json::Error> {
269        serde_json::to_string(self)
270    }
271    pub fn try_from_string(s: &str) -> Result<Self, serde_json::Error> {
272        serde_json::from_str(s)
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    // Decoding is synchronous, so plain `#[test]` suffices (no tokio runtime).

    #[test]
    fn test_decode_credentials() {
        let payload = r#"{
            "access_token":"secret",
            "expires_in":28800,
            "refresh_token":"secret",
            "token_type":"bearer",
            "refresh_token_expires_in":15811200,
            "scope":""}"#;

        let response = serde_json::from_str::<GithubAPIResponse>(payload).unwrap();
        match response {
            GithubAPIResponse::Credentials(c) => {
                assert_eq!(c.expires_in, 28800);
                assert_eq!(c.token_type, "bearer");
            }
            other => panic!("expected Credentials, got {other:?}"),
        }
    }

    #[test]
    fn test_decode_verification_params() {
        let payload = r#"{
        "device_code":"AA",
        "user_code":"user-code",
        "verification_uri":"https://example.com/device",
        "expires_in":1800,
        "interval":5
        }"#;

        let response = serde_json::from_str::<GithubAPIResponse>(payload).unwrap();
        match response {
            GithubAPIResponse::VerificationParams(vp) => {
                assert_eq!(vp.user_code, "user-code");
                assert_eq!(vp.interval, 5);
            }
            other => panic!("expected VerificationParams, got {other:?}"),
        }
    }

    #[test]
    fn test_decode_error_response() {
        let payload = r#"{
        "error":"authorization_pending",
        "error_description":"The authorization request is still pending.",
        "error_uri":"https://docs.github.com"
        }"#;

        let response = serde_json::from_str::<GithubAPIResponse>(payload).unwrap();
        assert!(matches!(
            response,
            GithubAPIResponse::ErrorResponse(GithubAPIErrorResponse {
                variant: GithubAPIErrorVariant::AuthorizationPending,
                ..
            })
        ));
    }

    #[test]
    fn test_error_variant_conversion() {
        let err: DeviceFlowError = GithubAPIErrorVariant::SlowDown.into();
        assert!(matches!(err, DeviceFlowError::SlowDown));
    }
}