Skip to main content

github_device_flow/
lib.rs

1use std::collections::HashMap;
2use std::{fmt, result::Result, thread, time};
3
4use chrono::offset::Utc;
5use chrono::{DateTime, Duration};
6
7mod util;
8
9#[derive(Debug, Default, Clone, serde_derive::Serialize, serde_derive::Deserialize)]
10pub struct Credential {
11    pub token: String,
12    pub expiry: String,
13    pub refresh_token: String,
14}
15
16impl Credential {
17    fn empty() -> Credential {
18        Credential {
19            token: String::new(),
20            expiry: String::new(),
21            refresh_token: String::new(),
22        }
23    }
24
25    pub fn is_expired(&self) -> bool {
26        let exp = match DateTime::parse_from_rfc3339(self.expiry.as_str()) {
27            Ok(time) => time,
28            Err(_) => return false,
29        };
30        let now = Utc::now();
31        now > exp
32    }
33}
34
/// Errors produced while running the device flow or refreshing a token.
#[derive(Clone, Debug)]
pub enum DeviceFlowError {
    /// A transport-level failure; produced by the `From<reqwest::Error>`
    /// conversion with the error's `Debug` text as the message.
    HttpError(String),
    /// Presumably an error reported by GitHub itself (it is constructed
    /// outside this file, e.g. by `util`) — confirm against `util`.
    GitHubError(String),
}
40
41impl fmt::Display for DeviceFlowError {
42    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
43        match self {
44            DeviceFlowError::HttpError(string) => write!(f, "DeviceFlowError: {}", string),
45            DeviceFlowError::GitHubError(string) => write!(f, "DeviceFlowError: {}", string),
46        }
47    }
48}
49
50impl std::error::Error for DeviceFlowError {}
51
52impl From<reqwest::Error> for DeviceFlowError {
53    fn from(e: reqwest::Error) -> Self {
54        DeviceFlowError::HttpError(format!("{:?}", e))
55    }
56}
57
58pub fn authorize(
59    client_id: String,
60    host: Option<String>,
61    scope: Option<String>,
62) -> Result<Credential, DeviceFlowError> {
63    let my_string: String;
64    let thost = match host {
65        Some(string) => {
66            my_string = string;
67            Some(my_string.as_str())
68        }
69        None => None,
70    };
71
72    let binding: String;
73    let tscope = match scope {
74        Some(string) => {
75            binding = string;
76            Some(binding.as_str())
77        }
78        None => None,
79    };
80
81    let mut flow = DeviceFlow::start(client_id.as_str(), thost, tscope)?;
82
83    // eprintln!("res is {:?}", res);
84    eprintln!(
85        "Please visit {} in your browser",
86        flow.verification_uri.clone().unwrap()
87    );
88    eprintln!("And enter code: {}", flow.user_code.clone().unwrap());
89
90    thread::sleep(FIVE_SECONDS);
91
92    flow.poll(20)
93}
94
95pub fn refresh(
96    client_id: &str,
97    refresh_token: &str,
98    host: Option<String>,
99    scope: Option<String>,
100) -> Result<Credential, DeviceFlowError> {
101    let my_string: String;
102    let thost = match host {
103        Some(string) => {
104            my_string = string;
105            Some(my_string.as_str())
106        }
107        None => None,
108    };
109
110    let scope_binding;
111    let tscope = match scope {
112        Some(string) => {
113            scope_binding = string;
114            Some(scope_binding.as_str())
115        }
116        None => None,
117    };
118
119    refresh_access_token(client_id, refresh_token, thost, tscope)
120}
121
122#[derive(Debug, Clone)]
123pub enum DeviceFlowState {
124    Pending,
125    Processing(time::Duration),
126    Success(Credential),
127    Failure(DeviceFlowError),
128}
129
130#[derive(Clone)]
131pub struct DeviceFlow {
132    pub host: String,
133    pub client_id: String,
134    pub scope: String,
135    pub user_code: Option<String>,
136    pub device_code: Option<String>,
137    pub verification_uri: Option<String>,
138    pub state: DeviceFlowState,
139}
140
/// Five-second step used as the polling interval after setup, as the initial
/// wait, and as the increment applied on a `slow_down` response.
const FIVE_SECONDS: time::Duration = time::Duration::from_secs(5);
142
143impl DeviceFlow {
144    pub fn new(client_id: &str, maybe_host: Option<&str>, scope: Option<&str>) -> Self {
145        Self {
146            client_id: String::from(client_id),
147            scope: match scope {
148                Some(string) => String::from(string),
149                None => String::new(),
150            },
151            host: match maybe_host {
152                Some(string) => String::from(string),
153                None => String::from("github.com"),
154            },
155            user_code: None,
156            device_code: None,
157            verification_uri: None,
158            state: DeviceFlowState::Pending,
159        }
160    }
161
162    pub fn start(
163        client_id: &str,
164        maybe_host: Option<&str>,
165        scope: Option<&str>,
166    ) -> Result<DeviceFlow, DeviceFlowError> {
167        let mut flow = DeviceFlow::new(client_id, maybe_host, scope);
168
169        flow.setup();
170
171        match flow.state {
172            DeviceFlowState::Processing(_) => Ok(flow.to_owned()),
173            DeviceFlowState::Failure(err) => Err(err),
174            _ => Err(util::credential_error(
175                "Something truly unexpected happened".into(),
176            )),
177        }
178    }
179
180    pub fn setup(&mut self) {
181        let body = format!("client_id={}&scope={}", &self.client_id, &self.scope);
182        let entry_url = format!("https://{}/login/device/code", &self.host);
183
184        if let Some(res) = util::send_request(self, entry_url, body) {
185            if res.contains_key("error") && res.contains_key("error_description") {
186                self.state = DeviceFlowState::Failure(util::credential_error(
187                    res["error_description"].as_str().unwrap().into(),
188                ))
189            } else if res.contains_key("error") {
190                self.state = DeviceFlowState::Failure(util::credential_error(format!(
191                    "Error response: {:?}",
192                    res["error"].as_str().unwrap()
193                )))
194            } else {
195                self.user_code = Some(String::from(res["user_code"].as_str().unwrap()));
196                self.device_code = Some(String::from(res["device_code"].as_str().unwrap()));
197                self.verification_uri =
198                    Some(String::from(res["verification_uri"].as_str().unwrap()));
199                self.state = DeviceFlowState::Processing(FIVE_SECONDS);
200            }
201        };
202    }
203
204    pub fn poll(&mut self, iterations: u32) -> Result<Credential, DeviceFlowError> {
205        for count in 0..iterations {
206            self.update();
207
208            if let DeviceFlowState::Processing(interval) = self.state {
209                if count == iterations {
210                    return Err(util::credential_error("Max poll iterations reached".into()));
211                }
212
213                thread::sleep(interval);
214            } else {
215                break;
216            }
217        }
218
219        match &self.state {
220            DeviceFlowState::Success(cred) => Ok(cred.to_owned()),
221            DeviceFlowState::Failure(err) => Err(err.to_owned()),
222            _ => Err(util::credential_error(
223                "Unable to fetch credential, sorry :/".into(),
224            )),
225        }
226    }
227
228    pub fn update(&mut self) {
229        let poll_url = format!("https://{}/login/oauth/access_token", self.host);
230        let poll_payload = format!(
231            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
232            self.client_id,
233            &self.device_code.clone().unwrap()
234        );
235
236        if let Some(res) = util::send_request(self, poll_url, poll_payload) {
237            if res.contains_key("error") {
238                match res["error"].as_str().unwrap() {
239                    "authorization_pending" => {}
240                    "slow_down" => {
241                        if let DeviceFlowState::Processing(current_interval) = self.state {
242                            self.state =
243                                DeviceFlowState::Processing(current_interval + FIVE_SECONDS);
244                        };
245                    }
246                    other_reason => {
247                        self.state = DeviceFlowState::Failure(util::credential_error(format!(
248                            "Error checking for token: {}",
249                            other_reason
250                        )));
251                    }
252                }
253            } else {
254                let mut this_credential = Credential::empty();
255                this_credential.token = res["access_token"].as_str().unwrap().to_string();
256
257                if let Some(expires_in) = res.get("expires_in") {
258                    this_credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
259                    this_credential.refresh_token =
260                        res["refresh_token"].as_str().unwrap().to_string();
261                }
262
263                self.state = DeviceFlowState::Success(this_credential);
264            }
265        }
266    }
267}
268
269fn calculate_expiry(expires_in: i64) -> String {
270    let expires_in = Duration::seconds(expires_in);
271    let mut expiry: DateTime<Utc> = Utc::now();
272    expiry = expiry + expires_in;
273    expiry.to_rfc3339()
274}
275
276fn refresh_access_token(
277    client_id: &str,
278    refresh_token: &str,
279    maybe_host: Option<&str>,
280    maybe_scope: Option<&str>,
281) -> Result<Credential, DeviceFlowError> {
282    let host = match maybe_host {
283        Some(string) => string,
284        None => "github.com",
285    };
286    
287    let scope = match maybe_scope {
288        Some(string) => string,
289        None => "",
290    };
291
292    let client = reqwest::blocking::Client::new();
293    let entry_url = format!("https://{}/login/oauth/access_token", &host);
294    let request_body = format!(
295        "client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token&scope={}",
296        &client_id, &refresh_token, &scope
297    );
298
299    let res = client
300        .post(&entry_url)
301        .header("Accept", "application/json")
302        .body(request_body)
303        .send()?
304        .json::<HashMap<String, serde_json::Value>>()?;
305
306    if res.contains_key("error") {
307        Err(util::credential_error(
308            res["error"].as_str().unwrap().into(),
309        ))
310    } else {
311        let mut credential = Credential::empty();
312        // eprintln!("res: {:?}", &res);
313        credential.token = res["access_token"].as_str().unwrap().to_string();
314
315        if let Some(expires_in) = res.get("expires_in") {
316            credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
317            credential.refresh_token = res["refresh_token"].as_str().unwrap().to_string();
318        }
319
320        Ok(credential.clone())
321    }
322}
323
#[cfg(test)]
mod tests {
    use crate::Credential;
    use chrono::offset::Utc;
    use chrono::{DateTime, Duration};

    /// Builds a credential whose expiry is `offset` away from now
    /// (negative offsets produce an expiry in the past).
    fn credential_expiring_in(offset: Duration) -> Credential {
        let expiry: DateTime<Utc> = Utc::now() + offset;
        Credential {
            token: String::from("irrelevant"),
            expiry: expiry.to_rfc3339(),
            refresh_token: String::from("irrelevant"),
        }
    }

    #[test]
    fn credential_expiry_is_expired_returns_false_when_expiry_is_in_the_future() {
        let credential = credential_expiring_in(Duration::seconds(28800));

        assert!(!credential.is_expired());
    }

    #[test]
    fn credential_expiry_is_expired_returns_true_when_expiry_is_in_the_past() {
        let credential = credential_expiring_in(Duration::seconds(-42));

        assert!(credential.is_expired());
    }

    #[test]
    fn credential_expiry_is_expired_returns_false_when_expiry_is_unparseable() {
        let credential = Credential {
            expiry: String::new(),
            ..Credential::default()
        };

        assert!(!credential.is_expired());
    }
}