1
2use std::{result::Result, thread, time, fmt};
3use std::collections::HashMap;
4
5use chrono::{DateTime, Duration};
6use chrono::offset::Utc;
7
8mod util;
9
/// An OAuth access token plus the optional expiry and refresh token that came with it.
///
/// `expiry` and `refresh_token` are left as empty strings when the token response
/// carried no `expires_in`; otherwise `expiry` holds an RFC 3339 timestamp.
#[derive(Debug, Default, Clone, serde_derive::Serialize)]
pub struct Credential {
    /// The `access_token` returned by the token endpoint.
    pub token: String,
    /// RFC 3339 timestamp after which the token is considered expired (empty if unknown).
    pub expiry: String,
    /// Refresh token sent by the server alongside `expires_in` (empty if none was issued).
    pub refresh_token: String,
}
16
17impl Credential {
18 fn empty() -> Credential {
19 Credential {
20 token: String::new(),
21 expiry: String::new(),
22 refresh_token: String::new(),
23 }
24 }
25
26 pub fn is_expired(&self) -> bool {
27 let exp = match DateTime::parse_from_rfc3339(self.expiry.as_str()) {
28 Ok(time) => time,
29 Err(_) => return false
30 };
31 let now = Utc::now();
32 now > exp
33 }
34}
35
36
/// Errors produced while running the device flow or refreshing a token.
#[derive(Debug, Clone)]
pub enum DeviceFlowError {
    /// A failure reported by the HTTP client, carrying its `Debug` rendering.
    HttpError(String),
    /// A failure described by a message rather than a transport error.
    /// NOTE(review): presumably what `util::credential_error` builds — confirm.
    GitHubError(String),
}

impl fmt::Display for DeviceFlowError {
    /// Renders either variant as `DeviceFlowError: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match self {
            DeviceFlowError::HttpError(message) | DeviceFlowError::GitHubError(message) => message,
        };
        write!(f, "DeviceFlowError: {}", message)
    }
}

impl std::error::Error for DeviceFlowError {}
53
54impl From<reqwest::Error> for DeviceFlowError {
55 fn from(e: reqwest::Error) -> Self {
56 DeviceFlowError::HttpError(format!("{:?}", e))
57 }
58}
59
60pub fn authorize(client_id: String, host: Option<String>) -> Result<Credential, DeviceFlowError> {
61 let my_string: String;
62 let thost = match host {
63 Some(string) => {
64 my_string = string;
65 Some(my_string.as_str())
66 },
67 None => None
68 };
69
70 let mut flow = DeviceFlow::start(client_id.as_str(), thost)?;
71
72 eprintln!("Please visit {} in your browser", flow.verification_uri.clone().unwrap());
74 eprintln!("And enter code: {}", flow.user_code.clone().unwrap());
75
76 thread::sleep(FIVE_SECONDS);
77
78 flow.poll(20)
79}
80
81pub fn refresh(client_id: &str, refresh_token: &str, host: Option<String>) -> Result<Credential, DeviceFlowError> {
82 let my_string: String;
83 let thost = match host {
84 Some(string) => {
85 my_string = string;
86 Some(my_string.as_str())
87 },
88 None => None
89 };
90
91 refresh_access_token(client_id, refresh_token, thost)
92}
93
/// Progress of a [`DeviceFlow`].
#[derive(Debug, Clone)]
pub enum DeviceFlowState {
    /// Freshly created; no device code has been obtained yet.
    Pending,
    /// Waiting for the user to authorize; the payload is the delay between token polls.
    Processing(time::Duration),
    /// A token response was received and converted into this credential.
    Success(Credential),
    /// The flow stopped with the contained error.
    Failure(DeviceFlowError)
}
101
/// State for one run of the OAuth device authorization flow against a GitHub host.
// NOTE(review): unlike the other public types here this does not derive `Debug`;
// consider adding it if nothing else in the crate implements it manually.
#[derive(Clone)]
pub struct DeviceFlow {
    /// Host used to build the endpoint URLs (`github.com` unless overridden).
    pub host: String,
    /// OAuth application client ID sent with every request.
    pub client_id: String,
    /// Code the user types in at `verification_uri`; filled in by `setup`.
    pub user_code: Option<String>,
    /// Code sent back when polling for the access token; filled in by `setup`.
    pub device_code: Option<String>,
    /// URL the user is asked to visit; filled in by `setup`.
    pub verification_uri: Option<String>,
    /// Current progress of the flow.
    pub state: DeviceFlowState,
}
111
/// Default wait between token polls, also used for the initial pause and `slow_down` back-off.
const FIVE_SECONDS: time::Duration = time::Duration::from_secs(5);
113
114impl DeviceFlow {
115 pub fn new(client_id: &str, maybe_host: Option<&str>) -> Self {
116 Self{
117 client_id: String::from(client_id),
118 host: match maybe_host {
119 Some(string) => String::from(string),
120 None => String::from("github.com")
121 },
122 user_code: None,
123 device_code: None,
124 verification_uri: None,
125 state: DeviceFlowState::Pending,
126 }
127 }
128
129 pub fn start(client_id: &str, maybe_host: Option<&str>) -> Result<DeviceFlow, DeviceFlowError> {
130 let mut flow = DeviceFlow::new(client_id, maybe_host);
131
132 flow.setup();
133
134 match flow.state {
135 DeviceFlowState::Processing(_) => Ok(flow.to_owned()),
136 DeviceFlowState::Failure(err) => Err(err),
137 _ => Err(util::credential_error("Something truly unexpected happened".into()))
138 }
139 }
140
141 pub fn setup(&mut self) {
142 let body = format!("client_id={}", &self.client_id);
143 let entry_url = format!("https://{}/login/device/code", &self.host);
144
145 if let Some(res) = util::send_request(self, entry_url, body) {
146 if res.contains_key("error") && res.contains_key("error_description"){
147 self.state = DeviceFlowState::Failure(util::credential_error(res["error_description"].as_str().unwrap().into()))
148 } else if res.contains_key("error") {
149 self.state = DeviceFlowState::Failure(util::credential_error(format!("Error response: {:?}", res["error"].as_str().unwrap())))
150 } else {
151 self.user_code = Some(String::from(res["user_code"].as_str().unwrap()));
152 self.device_code = Some(String::from(res["device_code"].as_str().unwrap()));
153 self.verification_uri = Some(String::from(res["verification_uri"].as_str().unwrap()));
154 self.state = DeviceFlowState::Processing(FIVE_SECONDS);
155 }
156 };
157 }
158
159 pub fn poll(&mut self, iterations: u32) -> Result<Credential, DeviceFlowError> {
160 for count in 0..iterations {
161 self.update();
162
163 if let DeviceFlowState::Processing(interval) = self.state {
164 if count == iterations {
165 return Err(util::credential_error("Max poll iterations reached".into()))
166 }
167
168 thread::sleep(interval);
169 } else {
170 break
171 }
172 };
173
174 match &self.state {
175 DeviceFlowState::Success(cred) => Ok(cred.to_owned()),
176 DeviceFlowState::Failure(err) => Err(err.to_owned()),
177 _ => Err(util::credential_error("Unable to fetch credential, sorry :/".into()))
178 }
179 }
180
181 pub fn update(&mut self) {
182 let poll_url = format!("https://{}/login/oauth/access_token", self.host);
183 let poll_payload = format!("client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
184 self.client_id,
185 &self.device_code.clone().unwrap()
186 );
187
188 if let Some(res) = util::send_request(self, poll_url, poll_payload) {
189 if res.contains_key("error") {
190 match res["error"].as_str().unwrap() {
191 "authorization_pending" => {},
192 "slow_down" => {
193 if let DeviceFlowState::Processing(current_interval) = self.state {
194 self.state = DeviceFlowState::Processing(current_interval + FIVE_SECONDS);
195 };
196 },
197 other_reason => {
198 self.state = DeviceFlowState::Failure(
199 util::credential_error(format!("Error checking for token: {}", other_reason))
200 );
201 },
202 }
203 } else {
204 let mut this_credential = Credential::empty();
205 this_credential.token = res["access_token"].as_str().unwrap().to_string();
206
207 if let Some(expires_in) = res.get("expires_in") {
208 this_credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
209 this_credential.refresh_token = res["refresh_token"].as_str().unwrap().to_string();
210 }
211
212 self.state = DeviceFlowState::Success(this_credential);
213 }
214 }
215 }
216}
217
218fn calculate_expiry(expires_in: i64) -> String {
219 let expires_in = Duration::seconds(expires_in);
220 let mut expiry: DateTime<Utc> = Utc::now();
221 expiry = expiry + expires_in;
222 expiry.to_rfc3339()
223}
224
225fn refresh_access_token(client_id: &str, refresh_token: &str, maybe_host: Option<&str>) -> Result<Credential, DeviceFlowError> {
226 let host = match maybe_host {
227 Some(string) => string,
228 None => "github.com"
229 };
230
231 let client = reqwest::blocking::Client::new();
232 let entry_url = format!("https://{}/login/oauth/access_token", &host);
233 let request_body = format!("client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token",
234 &client_id, &refresh_token);
235
236 let res = client.post(&entry_url)
237 .header("Accept", "application/json")
238 .body(request_body)
239 .send()?
240 .json::<HashMap<String, serde_json::Value>>()?;
241
242 if res.contains_key("error") {
243 Err(util::credential_error(res["error"].as_str().unwrap().into()))
244 } else {
245 let mut credential = Credential::empty();
246 credential.token = res["access_token"].as_str().unwrap().to_string();
248
249 if let Some(expires_in) = res.get("expires_in") {
250 credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
251 credential.refresh_token = res["refresh_token"].as_str().unwrap().to_string();
252 }
253
254 Ok(credential.clone())
255 }
256}
257
#[cfg(test)]
mod tests {
    use crate::Credential;
    use chrono::offset::Utc;
    use chrono::{DateTime, Duration};

    /// Builds a credential whose `expiry` is `offset_secs` seconds from now
    /// (negative values place it in the past).
    fn credential_expiring_in(offset_secs: i64) -> Credential {
        let expiry: DateTime<Utc> = Utc::now() + Duration::seconds(offset_secs);

        Credential {
            token: String::from("irrelevant"),
            expiry: expiry.to_rfc3339(),
            refresh_token: String::from("irrelevant"),
        }
    }

    #[test]
    fn credential_expiry_is_expired_returns_false_when_expiry_is_in_the_future() {
        let credential = credential_expiring_in(28800);

        assert!(!credential.is_expired());
    }

    #[test]
    fn credential_expiry_is_expired_returns_true_when_expiry_is_in_the_past() {
        let credential = credential_expiring_in(-42);

        assert!(credential.is_expired());
    }

    #[test]
    fn credential_expiry_is_expired_returns_false_when_expiry_is_not_a_timestamp() {
        // Credentials built without `expires_in` keep an empty `expiry` string.
        let credential = Credential {
            token: String::from("irrelevant"),
            expiry: String::new(),
            refresh_token: String::new(),
        };

        assert!(!credential.is_expired());
    }
}