github_device_flow/
lib.rs

1
2use std::{result::Result, thread, time, fmt};
3use std::collections::HashMap;
4
5use chrono::{DateTime, Duration};
6use chrono::offset::Utc;
7
8mod util;
9
10#[derive(Debug, Default, Clone, serde_derive::Serialize)]
11pub struct Credential {
12    pub token: String,
13    pub expiry: String,
14    pub refresh_token: String,
15}
16
17impl Credential {
18    fn empty() -> Credential {
19        Credential {
20            token: String::new(),
21            expiry: String::new(),
22            refresh_token: String::new(),
23        }
24    }
25
26    pub fn is_expired(&self) -> bool {
27        let exp = match DateTime::parse_from_rfc3339(self.expiry.as_str()) {
28            Ok(time) => time,
29            Err(_) => return false
30        };
31        let now = Utc::now();
32        now > exp
33    }
34}
35
36
/// Errors produced while running the device flow or refreshing a token.
#[derive(Debug, Clone)]
pub enum DeviceFlowError {
    /// An HTTP-level failure; `From<reqwest::Error>` produces this variant.
    HttpError(String),
    /// A message-carrying error not constructed in this file; presumably
    /// reported by GitHub via `util::credential_error` (TODO confirm).
    GitHubError(String),
}

impl fmt::Display for DeviceFlowError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Both variants render identically: a fixed prefix plus their message.
        let message = match self {
            DeviceFlowError::HttpError(msg) | DeviceFlowError::GitHubError(msg) => msg,
        };
        write!(f, "DeviceFlowError: {}", message)
    }
}

impl std::error::Error for DeviceFlowError {}

54impl From<reqwest::Error> for DeviceFlowError {
55    fn from(e: reqwest::Error) -> Self {
56        DeviceFlowError::HttpError(format!("{:?}", e))
57    }
58}
59
60pub fn authorize(client_id: String, host: Option<String>) -> Result<Credential, DeviceFlowError> {
61    let my_string: String;
62    let thost = match host {
63        Some(string) => {
64            my_string = string;
65            Some(my_string.as_str())
66        },
67        None => None
68    };
69
70    let mut flow = DeviceFlow::start(client_id.as_str(), thost)?;
71
72    // eprintln!("res is {:?}", res);
73    eprintln!("Please visit {} in your browser", flow.verification_uri.clone().unwrap());
74    eprintln!("And enter code: {}", flow.user_code.clone().unwrap());
75
76    thread::sleep(FIVE_SECONDS);
77
78    flow.poll(20)
79}
80
81pub fn refresh(client_id: &str, refresh_token: &str, host: Option<String>) -> Result<Credential, DeviceFlowError> {
82    let my_string: String;
83    let thost = match host {
84        Some(string) => {
85            my_string = string;
86            Some(my_string.as_str())
87        },
88        None => None
89    };
90
91    refresh_access_token(client_id, refresh_token, thost)
92}
93
94#[derive(Debug, Clone)]
95pub enum DeviceFlowState {
96    Pending,
97    Processing(time::Duration),
98    Success(Credential),
99    Failure(DeviceFlowError)
100}
101
102#[derive(Clone)]
103pub struct DeviceFlow {
104    pub host: String,
105    pub client_id: String,
106    pub user_code: Option<String>,
107    pub device_code: Option<String>,
108    pub verification_uri: Option<String>,
109    pub state: DeviceFlowState,
110}
111
112const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);
113
114impl DeviceFlow {
115    pub fn new(client_id: &str, maybe_host: Option<&str>) -> Self {
116        Self{
117            client_id: String::from(client_id),
118            host: match maybe_host {
119                Some(string) => String::from(string),
120                None => String::from("github.com")
121            },
122            user_code: None,
123            device_code: None,
124            verification_uri: None,
125            state: DeviceFlowState::Pending,
126        }
127    }
128
129    pub fn start(client_id: &str, maybe_host: Option<&str>) -> Result<DeviceFlow, DeviceFlowError> {
130        let mut flow = DeviceFlow::new(client_id, maybe_host);
131
132        flow.setup();
133
134        match flow.state {
135            DeviceFlowState::Processing(_) => Ok(flow.to_owned()),
136            DeviceFlowState::Failure(err) => Err(err),
137            _ => Err(util::credential_error("Something truly unexpected happened".into()))
138        }
139    }
140
141    pub fn setup(&mut self) {
142        let body = format!("client_id={}", &self.client_id);
143        let entry_url = format!("https://{}/login/device/code", &self.host);
144
145        if let Some(res) = util::send_request(self, entry_url, body) {
146            if res.contains_key("error") && res.contains_key("error_description"){
147                self.state = DeviceFlowState::Failure(util::credential_error(res["error_description"].as_str().unwrap().into()))
148            } else if res.contains_key("error") {
149                 self.state = DeviceFlowState::Failure(util::credential_error(format!("Error response: {:?}", res["error"].as_str().unwrap())))
150            } else {
151                self.user_code = Some(String::from(res["user_code"].as_str().unwrap()));
152                self.device_code = Some(String::from(res["device_code"].as_str().unwrap()));
153                self.verification_uri = Some(String::from(res["verification_uri"].as_str().unwrap()));
154                self.state = DeviceFlowState::Processing(FIVE_SECONDS);
155            }
156        };
157    }
158
159    pub fn poll(&mut self, iterations: u32) -> Result<Credential, DeviceFlowError> {
160        for count in 0..iterations {
161            self.update();
162
163            if let DeviceFlowState::Processing(interval) = self.state {
164                if count == iterations {
165                    return Err(util::credential_error("Max poll iterations reached".into()))
166                }
167
168                thread::sleep(interval);
169            } else {
170                break
171            }
172        };
173
174        match &self.state {
175            DeviceFlowState::Success(cred) => Ok(cred.to_owned()),
176            DeviceFlowState::Failure(err) => Err(err.to_owned()),
177            _ => Err(util::credential_error("Unable to fetch credential, sorry :/".into()))
178        }
179    }
180
181    pub fn update(&mut self) {
182        let poll_url = format!("https://{}/login/oauth/access_token", self.host);
183        let poll_payload = format!("client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
184            self.client_id,
185            &self.device_code.clone().unwrap()
186        );
187
188        if let Some(res) = util::send_request(self, poll_url, poll_payload) {
189            if res.contains_key("error") {
190                match res["error"].as_str().unwrap() {
191                    "authorization_pending" => {},
192                    "slow_down" => {
193                        if let DeviceFlowState::Processing(current_interval) = self.state {
194                            self.state = DeviceFlowState::Processing(current_interval + FIVE_SECONDS);
195                        };
196                    },
197                    other_reason => {
198                        self.state = DeviceFlowState::Failure(
199                            util::credential_error(format!("Error checking for token: {}", other_reason))
200                        );
201                    },
202                }
203            } else {
204                let mut this_credential = Credential::empty();
205                this_credential.token = res["access_token"].as_str().unwrap().to_string();
206
207                if let Some(expires_in) = res.get("expires_in") {
208                    this_credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
209                    this_credential.refresh_token = res["refresh_token"].as_str().unwrap().to_string();
210                }
211
212                self.state = DeviceFlowState::Success(this_credential);
213            }
214        }
215    }
216}
217
218fn calculate_expiry(expires_in: i64) -> String {
219    let expires_in = Duration::seconds(expires_in);
220    let mut expiry: DateTime<Utc> = Utc::now();
221    expiry = expiry + expires_in;
222    expiry.to_rfc3339()
223}
224
225fn refresh_access_token(client_id: &str, refresh_token: &str, maybe_host: Option<&str>) -> Result<Credential, DeviceFlowError> {
226    let host = match maybe_host {
227        Some(string) => string,
228        None => "github.com"
229    };
230
231    let client = reqwest::blocking::Client::new();
232    let entry_url = format!("https://{}/login/oauth/access_token", &host);
233    let request_body = format!("client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token",
234        &client_id, &refresh_token);
235
236    let res = client.post(&entry_url)
237        .header("Accept", "application/json")
238        .body(request_body)
239        .send()?
240        .json::<HashMap<String, serde_json::Value>>()?;
241
242    if res.contains_key("error") {
243        Err(util::credential_error(res["error"].as_str().unwrap().into()))
244    } else {
245        let mut credential = Credential::empty();
246        // eprintln!("res: {:?}", &res);
247        credential.token = res["access_token"].as_str().unwrap().to_string();
248
249        if let Some(expires_in) = res.get("expires_in") {
250            credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
251            credential.refresh_token = res["refresh_token"].as_str().unwrap().to_string();
252        }
253
254        Ok(credential.clone())
255    }
256}
257
#[cfg(test)]
mod tests {
    use crate::{Credential, DeviceFlow, DeviceFlowError};
    use chrono::offset::Utc;
    use chrono::{DateTime, Duration};

    /// Builds a credential whose RFC 3339 expiry lies `offset` away from now.
    fn credential_expiring_in(offset: Duration) -> Credential {
        let expiry: DateTime<Utc> = Utc::now() + offset;
        Credential {
            token: String::from("irrelevant"),
            expiry: expiry.to_rfc3339(),
            refresh_token: String::from("irrelevant"),
        }
    }

    #[test]
    fn credential_expiry_is_expired_returns_false_when_expiry_is_in_the_future() {
        // Eight hours ahead: the token must still be valid.
        let credential = credential_expiring_in(Duration::seconds(28800));
        assert!(!credential.is_expired());
    }

    #[test]
    fn credential_expiry_is_expired_returns_true_when_expiry_is_in_the_past() {
        let credential = credential_expiring_in(Duration::seconds(-42));
        assert!(credential.is_expired());
    }

    #[test]
    fn credential_without_parseable_expiry_is_not_expired() {
        assert!(!Credential::default().is_expired());

        let garbage = Credential {
            expiry: String::from("not a timestamp"),
            ..Credential::default()
        };
        assert!(!garbage.is_expired());
    }

    #[test]
    fn device_flow_new_defaults_host_and_starts_pending() {
        let flow = DeviceFlow::new("client", None);
        assert_eq!(flow.host, "github.com");
        assert!(flow.device_code.is_none());

        let custom = DeviceFlow::new("client", Some("ghe.example.com"));
        assert_eq!(custom.host, "ghe.example.com");
    }

    #[test]
    fn device_flow_error_display_prefixes_message() {
        let err = DeviceFlowError::HttpError(String::from("boom"));
        assert_eq!(err.to_string(), "DeviceFlowError: boom");
    }
}