1use rand::{self, Rng};
62use std;
63use std::cell::{Cell, RefCell};
64use std::collections::HashMap;
65use std::sync::{Arc, Mutex};
66use std::thread;
67use std::time::{Duration, Instant};
68
69use base64;
70use failure::Error;
71use futures::future::ok;
72use futures::sync::oneshot::{self, Sender};
73use futures::Future;
74use hyper::header::{self, HeaderValue};
75use hyper::server::Server;
76use hyper::service::{MakeService, Service};
77use hyper::{Body, Error as HyperError, Method, Request, Response};
78use open;
79use url::{self, Url};
80
81use errors::RedditError;
82use net::body_from_map;
83use net::Connection;
84
85pub type ResponseGenFn = (Fn(&Result<String, InstalledAppError>) -> Response<Body>) + Send + Sync;
87
88type CodeSender = Arc<Mutex<Option<Sender<Result<String, InstalledAppError>>>>>;
89
/// Credentials and current access token for one of the supported reddit OAuth2 app types.
#[derive(Debug, Clone)]
pub enum OAuth {
    /// A "script" app, authenticated with the password grant (see `OAuth::create_script`).
    Script {
        /// Client id of the app.
        id: String,
        /// Client secret of the app.
        secret: String,
        /// Account username sent with the password grant.
        username: String,
        /// Account password sent with the password grant.
        password: String,
        /// Access token returned by the token endpoint.
        token: String,
    },
    /// An "installed app" authorized in the user's browser (see
    /// `OAuth::create_installed_app`). Its token fields use interior mutability so that
    /// `OAuth::refresh(&self, ..)` can replace them in place.
    InstalledApp {
        /// Client id of the app (sent with an empty secret in the Basic auth header).
        id: String,
        /// Redirect URI passed to reddit during authorization.
        redirect: String,
        /// Current access token; overwritten by `refresh`.
        token: RefCell<String>,
        /// Token exchanged by `refresh` for a new access token; `None` makes `refresh` fail.
        refresh_token: RefCell<Option<String>>,
        /// Moment the current access token expires, computed from the response's
        /// `expires_in` (seconds) when the token was obtained.
        expire_instant: Cell<Option<Instant>>,
    },
}
122
123impl OAuth {
124 pub fn refresh(&self, conn: &Connection) -> Result<(), Error> {
126 match *self {
127 OAuth::Script { .. } => Ok(()),
128 OAuth::InstalledApp {
129 ref id,
130 redirect: ref _redirect,
131 ref token,
132 ref refresh_token,
133 ref expire_instant,
134 } => {
135 let old_refresh_token = if let Some(ref refresh_token) = *refresh_token.borrow() { refresh_token.clone() } else { return Err(RedditError::AuthError.into()) };
136 let mut params: HashMap<&str, &str> = HashMap::new();
138 params.insert("grant_type", "refresh_token");
139 params.insert("refresh_token", &old_refresh_token);
140
141 let mut tokenreq = Request::builder().method(Method::POST).uri("https://www.reddit.com/api/v1/access_token/.json").body(body_from_map(¶ms)).unwrap();
143 tokenreq.headers_mut().insert(header::AUTHORIZATION, HeaderValue::from_str(&format!("Basic {}", { base64::encode(&format!("{}:", id)) })).unwrap());
145
146 let response = conn.run_request(tokenreq)?;
148
149 if let (Some(expires_in), Some(new_token), Some(scope)) = (response.get("expires_in"), response.get("access_token"), response.get("scope")) {
150 let expires_in = expires_in.as_u64().unwrap();
151 let new_token = new_token.as_str().unwrap();
152 let _scope = scope.as_str().unwrap();
153 *token.borrow_mut() = new_token.to_string();
154 expire_instant.set(Some(Instant::now() + Duration::new(expires_in.to_string().parse::<u64>().unwrap(), 0)));
155
156 Ok(())
157 } else {
158 Err(Error::from(RedditError::AuthError))
159 }
160 }
161 }
162 }
163
164 pub fn create_script(conn: &Connection, id: &str, secret: &str, username: &str, password: &str) -> Result<OAuth, Error> {
172 let mut params: HashMap<&str, &str> = HashMap::new();
174 params.insert("grant_type", "password");
175 params.insert("username", &username);
176 params.insert("password", &password);
177
178 let mut tokenreq = Request::builder().method(Method::POST).uri("https://ssl.reddit.com/api/v1/access_token/.json").body(body_from_map(¶ms)).unwrap();
180 tokenreq.headers_mut().insert(header::AUTHORIZATION, HeaderValue::from_str(&format!("Basic {}", { base64::encode(&format!("{}:{}", id, secret)) })).unwrap());
182
183 let response = conn.run_request(tokenreq)?;
185
186 if let Some(token) = response.get("access_token") {
187 let token = token.as_str().unwrap().to_string();
188 Ok(OAuth::Script {
189 id: id.to_string(),
190 secret: secret.to_string(),
191 username: username.to_string(),
192 password: password.to_string(),
193 token,
194 })
195 } else {
196 Err(RedditError::AuthError.into())
197 }
198 }
199
200 pub fn create_installed_app<I: Into<Option<Arc<ResponseGenFn>>>>(conn: &Connection, id: &str, redirect: &str, response_gen: I, scopes: &Scopes) -> Result<OAuth, Error> {
213 let response_gen = response_gen.into();
214 let state = rand::thread_rng().gen_ascii_chars().take(16).collect::<String>();
216
217 let scopes = &scopes.to_string();
218 let browser_uri = format!(
219 "https://www.reddit.com/api/v1/authorize?client_id={}&response_type=code&\
220 state={}&redirect_uri={}&duration=permanent&scope={}",
221 id, state, redirect, scopes
222 );
223
224 let state_rc = Arc::new(state);
225
226 thread::spawn(move || {
228 open::that(browser_uri).expect("Failed to open browser");
229 });
230
231 let (code_sender, code_reciever) = oneshot::channel::<Result<String, InstalledAppError>>();
234
235 let redirect_url = Url::parse(&redirect)?;
237 let main_redirect = format!("{}:{}", redirect_url.host_str().unwrap_or("127.0.0.1"), redirect_url.port().unwrap_or(7878).to_string());
238
239 let response_gen = if let Some(ref response_gen) = response_gen {
241 Arc::clone(response_gen)
242 } else {
243 Arc::new(|res: &Result<String, InstalledAppError>| -> Response<Body> {
244 match res {
245 Ok(_) => Response::new("Successfully got the code".into()),
246 Err(e) => Response::new(format!("{}", e).into()),
247 }
248 })
249 };
250
251 let server = Server::bind(&main_redirect.as_str().parse()?).serve(MakeInstalledAppService {
254 code_sender: Arc::new(Mutex::new(Some(code_sender))),
255 state: Arc::clone(&state_rc),
256 response_gen: Arc::clone(&response_gen),
257 });
258
259 let code: Arc<Mutex<Result<String, InstalledAppError>>> = Arc::new(Mutex::new(Err(InstalledAppError::NeverRecieved)));
261 let code_clone = Arc::clone(&code);
262
263 let finish = code_reciever.then(move |new_code| {
265 let code = code_clone;
266 if let Ok(new_code) = new_code {
267 match new_code {
268 Ok(new_code) => {
269 *code.lock().unwrap() = Ok(new_code);
270 Ok(())
271 }
272 Err(e) => {
273 *code.lock().unwrap() = Err(e);
274 Err(())
275 }
276 }
277 } else {
278 Err(())
279 }
280 });
281
282 let graceful = server.with_graceful_shutdown(finish).map_err(|e| eprintln!("Server failed: {}", e));
283
284 hyper::rt::run(graceful);
286
287 let code = match *code.lock().unwrap() {
289 Ok(ref new_code) => new_code.clone(),
290 Err(ref e) => return Err(e.clone().into()),
291 };
292
293 let mut params: HashMap<&str, &str> = HashMap::new();
295 params.insert("grant_type", "authorization_code");
296 params.insert("code", &code);
297 params.insert("redirect_uri", &redirect);
298
299 let mut tokenreq = Request::builder().method(Method::POST).uri("https://ssl.reddit.com/api/v1/access_token/.json").body(body_from_map(¶ms)).unwrap();
301 tokenreq.headers_mut().insert(header::AUTHORIZATION, HeaderValue::from_str(&format!("Basic {}", base64::encode(&format!("{}:", id)))).unwrap());
303
304 let response = conn.run_request(tokenreq)?;
306
307 if let (Some(expires_in), Some(token), Some(refresh_token), Some(scope)) = (response.get("expires_in"), response.get("access_token"), response.get("refresh_token"), response.get("scope")) {
308 let expires_in = expires_in.as_u64().unwrap();
309 let token = token.as_str().unwrap();
310 let refresh_token = refresh_token.as_str().unwrap();
311 let _scope = scope.as_str().unwrap();
312 Ok(OAuth::InstalledApp {
313 id: id.to_string(),
314 redirect: redirect.to_string(),
315 token: RefCell::new(token.to_string()),
316 refresh_token: RefCell::new(Some(refresh_token.to_string())),
317 expire_instant: Cell::new(Some(Instant::now() + Duration::new(expires_in.to_string().parse::<u64>().unwrap(), 0))),
318 })
319 } else {
320 Err(Error::from(RedditError::AuthError))
321 }
322 }
323}
324
/// The OAuth scopes requested when authorizing an installed app.
///
/// Each `true` field requests the reddit scope of the same name. Formatting a `Scopes`
/// (via `Display` / `to_string()`) yields the enabled scope names in field order, joined
/// by commas with no leading or trailing separator; no enabled scopes give `""`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Scopes {
    pub identity: bool,
    pub edit: bool,
    pub flair: bool,
    pub history: bool,
    pub modconfig: bool,
    pub modflair: bool,
    pub modlog: bool,
    pub modposts: bool,
    pub modwiki: bool,
    pub mysubreddits: bool,
    pub privatemessages: bool,
    pub read: bool,
    pub report: bool,
    pub save: bool,
    pub submit: bool,
    pub subscribe: bool,
    pub vote: bool,
    pub wikiedit: bool,
    pub wikiread: bool,
    pub account: bool,
}

impl Scopes {
    /// Returns a `Scopes` with every scope disabled.
    pub fn empty() -> Scopes {
        Scopes::default()
    }

    /// Returns a `Scopes` with every scope enabled.
    pub fn all() -> Scopes {
        Scopes {
            identity: true,
            edit: true,
            flair: true,
            history: true,
            modconfig: true,
            modflair: true,
            modlog: true,
            modposts: true,
            modwiki: true,
            mysubreddits: true,
            privatemessages: true,
            read: true,
            report: true,
            save: true,
            submit: true,
            subscribe: true,
            vote: true,
            wikiedit: true,
            wikiread: true,
            account: true,
        }
    }

    /// Pairs each scope's flag with the name sent to reddit, in field order.
    fn flags(&self) -> [(bool, &'static str); 20] {
        [
            (self.identity, "identity"),
            (self.edit, "edit"),
            (self.flair, "flair"),
            (self.history, "history"),
            (self.modconfig, "modconfig"),
            (self.modflair, "modflair"),
            (self.modlog, "modlog"),
            (self.modposts, "modposts"),
            (self.modwiki, "modwiki"),
            (self.mysubreddits, "mysubreddits"),
            (self.privatemessages, "privatemessages"),
            (self.read, "read"),
            (self.report, "report"),
            (self.save, "save"),
            (self.submit, "submit"),
            (self.subscribe, "subscribe"),
            (self.vote, "vote"),
            (self.wikiedit, "wikiedit"),
            (self.wikiread, "wikiread"),
            (self.account, "account"),
        ]
    }
}

impl std::fmt::Display for Scopes {
    /// Writes the enabled scope names joined by `,` (joining avoids a stray leading
    /// comma when `identity` is disabled).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let enabled: Vec<&str> = self
            .flags()
            .iter()
            .filter(|&&(on, _)| on)
            .map(|&(_, name)| name)
            .collect();
        f.write_str(&enabled.join(","))
    }
}
492
493#[derive(Debug, Fail, Clone)]
495pub enum InstalledAppError {
496 #[fail(display = "Got an unknown error: {}", msg)]
498 Error {
499 msg: String,
501 },
502 #[fail(display = "The states did not match")]
504 MismatchedState,
505 #[fail(display = "A code was already recieved")]
507 AlreadyRecieved,
508 #[fail(display = "No message was ever recieved")]
510 NeverRecieved,
511}
512
513struct MakeInstalledAppService {
514 code_sender: CodeSender,
515 state: Arc<String>,
516 response_gen: Arc<ResponseGenFn>,
517}
518
519impl<Ctx> MakeService<Ctx> for MakeInstalledAppService {
520 type ReqBody = Body;
521 type ResBody = Body;
522 type Error = hyper::Error;
523 type Service = InstalledAppService;
524 type Future = Box<Future<Item = Self::Service, Error = Self::MakeError> + Send + Sync>;
525 type MakeError = Box<dyn std::error::Error + Send + Sync>;
526
527 fn make_service(&mut self, _ctx: Ctx) -> Self::Future {
528 Box::new(futures::future::ok(InstalledAppService {
529 code_sender: Arc::clone(&self.code_sender),
530 state: Arc::clone(&self.state),
531 response_gen: Arc::clone(&self.response_gen),
532 }))
533 }
534}
535
536struct InstalledAppService {
540 code_sender: CodeSender,
541 state: Arc<String>,
542 response_gen: Arc<ResponseGenFn>,
543}
544
545impl Service for InstalledAppService {
546 type ReqBody = Body;
547 type ResBody = Body;
548 type Error = HyperError;
549 type Future = Box<Future<Item = Response<Self::ResBody>, Error = Self::Error> + Send>;
550
551 fn call(&mut self, req: Request<Self::ReqBody>) -> Self::Future {
552 let query_str = req.uri().path_and_query().unwrap().as_str();
554 let query_str = &query_str[2..query_str.len()];
555 let params: HashMap<_, _> = url::form_urlencoded::parse(query_str.as_bytes()).collect();
556
557 fn create_res(gen: &ResponseGenFn, res: &Result<String, InstalledAppError>, sender: &CodeSender) -> <InstalledAppService as Service>::Future {
560 let mut sender = sender.lock().unwrap();
561 let sender = if let Some(sender) = sender.take() {
562 sender
563 } else {
564 return Box::new(ok(gen(&Err(InstalledAppError::AlreadyRecieved))));
565 };
566 let resp = match sender.send(res.clone()) {
567 Ok(_) => gen(&res),
568 Err(_) => gen(&Err(InstalledAppError::AlreadyRecieved)),
569 };
570 Box::new(ok(resp))
571 }
572
573 if params.contains_key("error") {
575 warn!("Got failed authorization. Error was {}", ¶ms["error"]);
576 let err = InstalledAppError::Error { msg: params["error"].to_string() };
577 create_res(&*self.response_gen, &Err(err.clone()), &self.code_sender)
578 } else {
579 let state = if let Some(state) = params.get("state") {
581 state
582 } else {
583 return create_res(&*self.response_gen, &Err(InstalledAppError::MismatchedState), &self.code_sender);
585 };
586 if *state != *self.state {
588 error!("State didn't match. Got state \"{}\", needed state \"{}\"", state, self.state);
589 create_res(&*self.response_gen, &Err(InstalledAppError::MismatchedState), &self.code_sender)
590 } else {
591 let code = ¶ms["code"];
593 create_res(&*self.response_gen, &Ok(code.clone().into()), &self.code_sender)
594 }
595 }
596 }
597}
598
/// Extension for `RefCell<Option<T>>` that moves the contained value out.
trait RefCellExt<T> {
    /// Takes the value out of the cell, leaving `None`; returns `None` if it was empty.
    ///
    /// # Panics
    ///
    /// Panics (like `RefCell::borrow_mut`) if the cell holds a value while it is borrowed.
    fn pop(&self) -> Option<T>;
}

impl<T> RefCellExt<T> for RefCell<Option<T>> {
    fn pop(&self) -> Option<T> {
        // Check with a shared borrow first so an empty cell never needs a mutable borrow.
        if self.borrow().is_none() {
            return None;
        }
        self.borrow_mut().take()
    }
}