1use std::collections::HashMap;
2use std::{fmt, result::Result, thread, time};
3
4use chrono::offset::Utc;
5use chrono::{DateTime, Duration};
6
7mod util;
8
9#[derive(Debug, Default, Clone, serde_derive::Serialize, serde_derive::Deserialize)]
10pub struct Credential {
11 pub token: String,
12 pub expiry: String,
13 pub refresh_token: String,
14}
15
16impl Credential {
17 fn empty() -> Credential {
18 Credential {
19 token: String::new(),
20 expiry: String::new(),
21 refresh_token: String::new(),
22 }
23 }
24
25 pub fn is_expired(&self) -> bool {
26 let exp = match DateTime::parse_from_rfc3339(self.expiry.as_str()) {
27 Ok(time) => time,
28 Err(_) => return false,
29 };
30 let now = Utc::now();
31 now > exp
32 }
33}
34
/// Errors reported by the device-flow helpers in this module.
///
/// Each variant carries a human-readable description of what went wrong.
#[derive(Debug, Clone)]
pub enum DeviceFlowError {
    /// A failure in the HTTP layer (built from a `reqwest::Error`).
    HttpError(String),
    /// A failure reported while talking to the GitHub endpoints.
    GitHubError(String),
}
40
41impl fmt::Display for DeviceFlowError {
42 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
43 match self {
44 DeviceFlowError::HttpError(string) => write!(f, "DeviceFlowError: {}", string),
45 DeviceFlowError::GitHubError(string) => write!(f, "DeviceFlowError: {}", string),
46 }
47 }
48}
49
50impl std::error::Error for DeviceFlowError {}
51
52impl From<reqwest::Error> for DeviceFlowError {
53 fn from(e: reqwest::Error) -> Self {
54 DeviceFlowError::HttpError(format!("{:?}", e))
55 }
56}
57
58pub fn authorize(
59 client_id: String,
60 host: Option<String>,
61 scope: Option<String>,
62) -> Result<Credential, DeviceFlowError> {
63 let my_string: String;
64 let thost = match host {
65 Some(string) => {
66 my_string = string;
67 Some(my_string.as_str())
68 }
69 None => None,
70 };
71
72 let binding: String;
73 let tscope = match scope {
74 Some(string) => {
75 binding = string;
76 Some(binding.as_str())
77 }
78 None => None,
79 };
80
81 let mut flow = DeviceFlow::start(client_id.as_str(), thost, tscope)?;
82
83 eprintln!(
85 "Please visit {} in your browser",
86 flow.verification_uri.clone().unwrap()
87 );
88 eprintln!("And enter code: {}", flow.user_code.clone().unwrap());
89
90 thread::sleep(FIVE_SECONDS);
91
92 flow.poll(20)
93}
94
95pub fn refresh(
96 client_id: &str,
97 refresh_token: &str,
98 host: Option<String>,
99 scope: Option<String>,
100) -> Result<Credential, DeviceFlowError> {
101 let my_string: String;
102 let thost = match host {
103 Some(string) => {
104 my_string = string;
105 Some(my_string.as_str())
106 }
107 None => None,
108 };
109
110 let scope_binding;
111 let tscope = match scope {
112 Some(string) => {
113 scope_binding = string;
114 Some(scope_binding.as_str())
115 }
116 None => None,
117 };
118
119 refresh_access_token(client_id, refresh_token, thost, tscope)
120}
121
122#[derive(Debug, Clone)]
123pub enum DeviceFlowState {
124 Pending,
125 Processing(time::Duration),
126 Success(Credential),
127 Failure(DeviceFlowError),
128}
129
130#[derive(Clone)]
131pub struct DeviceFlow {
132 pub host: String,
133 pub client_id: String,
134 pub scope: String,
135 pub user_code: Option<String>,
136 pub device_code: Option<String>,
137 pub verification_uri: Option<String>,
138 pub state: DeviceFlowState,
139}
140
/// Base polling interval, also used as the increment applied on `slow_down`.
const FIVE_SECONDS: time::Duration = time::Duration::from_secs(5);
142
143impl DeviceFlow {
144 pub fn new(client_id: &str, maybe_host: Option<&str>, scope: Option<&str>) -> Self {
145 Self {
146 client_id: String::from(client_id),
147 scope: match scope {
148 Some(string) => String::from(string),
149 None => String::new(),
150 },
151 host: match maybe_host {
152 Some(string) => String::from(string),
153 None => String::from("github.com"),
154 },
155 user_code: None,
156 device_code: None,
157 verification_uri: None,
158 state: DeviceFlowState::Pending,
159 }
160 }
161
162 pub fn start(
163 client_id: &str,
164 maybe_host: Option<&str>,
165 scope: Option<&str>,
166 ) -> Result<DeviceFlow, DeviceFlowError> {
167 let mut flow = DeviceFlow::new(client_id, maybe_host, scope);
168
169 flow.setup();
170
171 match flow.state {
172 DeviceFlowState::Processing(_) => Ok(flow.to_owned()),
173 DeviceFlowState::Failure(err) => Err(err),
174 _ => Err(util::credential_error(
175 "Something truly unexpected happened".into(),
176 )),
177 }
178 }
179
180 pub fn setup(&mut self) {
181 let body = format!("client_id={}&scope={}", &self.client_id, &self.scope);
182 let entry_url = format!("https://{}/login/device/code", &self.host);
183
184 if let Some(res) = util::send_request(self, entry_url, body) {
185 if res.contains_key("error") && res.contains_key("error_description") {
186 self.state = DeviceFlowState::Failure(util::credential_error(
187 res["error_description"].as_str().unwrap().into(),
188 ))
189 } else if res.contains_key("error") {
190 self.state = DeviceFlowState::Failure(util::credential_error(format!(
191 "Error response: {:?}",
192 res["error"].as_str().unwrap()
193 )))
194 } else {
195 self.user_code = Some(String::from(res["user_code"].as_str().unwrap()));
196 self.device_code = Some(String::from(res["device_code"].as_str().unwrap()));
197 self.verification_uri =
198 Some(String::from(res["verification_uri"].as_str().unwrap()));
199 self.state = DeviceFlowState::Processing(FIVE_SECONDS);
200 }
201 };
202 }
203
204 pub fn poll(&mut self, iterations: u32) -> Result<Credential, DeviceFlowError> {
205 for count in 0..iterations {
206 self.update();
207
208 if let DeviceFlowState::Processing(interval) = self.state {
209 if count == iterations {
210 return Err(util::credential_error("Max poll iterations reached".into()));
211 }
212
213 thread::sleep(interval);
214 } else {
215 break;
216 }
217 }
218
219 match &self.state {
220 DeviceFlowState::Success(cred) => Ok(cred.to_owned()),
221 DeviceFlowState::Failure(err) => Err(err.to_owned()),
222 _ => Err(util::credential_error(
223 "Unable to fetch credential, sorry :/".into(),
224 )),
225 }
226 }
227
228 pub fn update(&mut self) {
229 let poll_url = format!("https://{}/login/oauth/access_token", self.host);
230 let poll_payload = format!(
231 "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
232 self.client_id,
233 &self.device_code.clone().unwrap()
234 );
235
236 if let Some(res) = util::send_request(self, poll_url, poll_payload) {
237 if res.contains_key("error") {
238 match res["error"].as_str().unwrap() {
239 "authorization_pending" => {}
240 "slow_down" => {
241 if let DeviceFlowState::Processing(current_interval) = self.state {
242 self.state =
243 DeviceFlowState::Processing(current_interval + FIVE_SECONDS);
244 };
245 }
246 other_reason => {
247 self.state = DeviceFlowState::Failure(util::credential_error(format!(
248 "Error checking for token: {}",
249 other_reason
250 )));
251 }
252 }
253 } else {
254 let mut this_credential = Credential::empty();
255 this_credential.token = res["access_token"].as_str().unwrap().to_string();
256
257 if let Some(expires_in) = res.get("expires_in") {
258 this_credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
259 this_credential.refresh_token =
260 res["refresh_token"].as_str().unwrap().to_string();
261 }
262
263 self.state = DeviceFlowState::Success(this_credential);
264 }
265 }
266 }
267}
268
269fn calculate_expiry(expires_in: i64) -> String {
270 let expires_in = Duration::seconds(expires_in);
271 let mut expiry: DateTime<Utc> = Utc::now();
272 expiry = expiry + expires_in;
273 expiry.to_rfc3339()
274}
275
276fn refresh_access_token(
277 client_id: &str,
278 refresh_token: &str,
279 maybe_host: Option<&str>,
280 maybe_scope: Option<&str>,
281) -> Result<Credential, DeviceFlowError> {
282 let host = match maybe_host {
283 Some(string) => string,
284 None => "github.com",
285 };
286
287 let scope = match maybe_scope {
288 Some(string) => string,
289 None => "",
290 };
291
292 let client = reqwest::blocking::Client::new();
293 let entry_url = format!("https://{}/login/oauth/access_token", &host);
294 let request_body = format!(
295 "client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token&scope={}",
296 &client_id, &refresh_token, &scope
297 );
298
299 let res = client
300 .post(&entry_url)
301 .header("Accept", "application/json")
302 .body(request_body)
303 .send()?
304 .json::<HashMap<String, serde_json::Value>>()?;
305
306 if res.contains_key("error") {
307 Err(util::credential_error(
308 res["error"].as_str().unwrap().into(),
309 ))
310 } else {
311 let mut credential = Credential::empty();
312 credential.token = res["access_token"].as_str().unwrap().to_string();
314
315 if let Some(expires_in) = res.get("expires_in") {
316 credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
317 credential.refresh_token = res["refresh_token"].as_str().unwrap().to_string();
318 }
319
320 Ok(credential.clone())
321 }
322}
323
#[cfg(test)]
mod tests {
    use crate::Credential;
    use chrono::offset::Utc;
    use chrono::{DateTime, Duration};

    /// Builds a credential with placeholder tokens that expires at `expiry`.
    fn credential_expiring_at(expiry: DateTime<Utc>) -> Credential {
        Credential {
            token: String::from("irrelevant"),
            expiry: expiry.to_rfc3339(),
            refresh_token: String::from("irrelevant"),
        }
    }

    #[test]
    fn credential_expiry_is_expired_returns_false_when_expiry_is_in_the_future() {
        let credential = credential_expiring_at(Utc::now() + Duration::seconds(28800));

        // An expiry eight hours ahead must not be reported as expired.
        assert!(!credential.is_expired());
    }

    #[test]
    fn credential_expiry_is_expired_returns_true_when_expiry_is_in_the_past() {
        let credential = credential_expiring_at(Utc::now() - Duration::seconds(42));

        assert!(credential.is_expired());
    }
}