use std;
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::net::ToSocketAddrs;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

use base64;
use failure::Error;
use futures::future::ok;
use futures::sync::oneshot::{self, Sender};
use futures::Future;
use hyper::header::{self, HeaderValue};
use hyper::server::Server;
use hyper::service::{MakeService, Service};
use hyper::{Body, Error as HyperError, Method, Request, Response};
use open;
use rand::{self, Rng};
use url::{self, Url};

use errors::RedditError;
use net::body_from_map;
use net::Connection;
pub type ResponseGenFn = (Fn(&Result<String, InstalledAppError>) -> Response<Body>) + Send + Sync;
type CodeSender = Arc<Mutex<Option<Sender<Result<String, InstalledAppError>>>>>;
/// OAuth credentials used to authenticate requests against Reddit's API.
///
/// NOTE(review): the derived `Debug` prints every field, including `secret`,
/// `password` and the tokens, so `{:?}` output of this type exposes credentials.
#[derive(Debug, Clone)]
pub enum OAuth {
/// Credentials obtained through the password grant (`OAuth::create_script`).
Script {
/// App id, sent as the HTTP Basic user name to the token endpoint.
id: String,
/// App secret, sent as the HTTP Basic password to the token endpoint.
secret: String,
/// Reddit account user name used for the password grant.
username: String,
/// Reddit account password used for the password grant.
password: String,
/// The `access_token` returned by the token endpoint.
token: String,
},
/// Credentials obtained through the authorization-code flow
/// (`OAuth::create_installed_app`). The interior mutability lets
/// `OAuth::refresh` update the token through `&self`.
InstalledApp {
/// App id (sent as `client_id` and as the HTTP Basic user name).
id: String,
/// Redirect URI passed to Reddit during authorization.
redirect: String,
/// Current access token.
token: RefCell<String>,
/// Refresh token; `OAuth::refresh` fails with `AuthError` when this is `None`.
refresh_token: RefCell<Option<String>>,
/// Instant at which `token` expires, computed from the last `expires_in`.
expire_instant: Cell<Option<Instant>>,
},
}
impl OAuth {
pub fn refresh(&self, conn: &Connection) -> Result<(), Error> {
match *self {
OAuth::Script { .. } => Ok(()),
OAuth::InstalledApp {
ref id,
redirect: ref _redirect,
ref token,
ref refresh_token,
ref expire_instant,
} => {
let old_refresh_token = if let Some(ref refresh_token) = *refresh_token.borrow() { refresh_token.clone() } else { return Err(RedditError::AuthError.into()) };
let mut params: HashMap<&str, &str> = HashMap::new();
params.insert("grant_type", "refresh_token");
params.insert("refresh_token", &old_refresh_token);
let mut tokenreq = Request::builder().method(Method::POST).uri("https://www.reddit.com/api/v1/access_token/.json").body(body_from_map(¶ms)).unwrap();
tokenreq.headers_mut().insert(header::AUTHORIZATION, HeaderValue::from_str(&format!("Basic {}", { base64::encode(&format!("{}:", id)) })).unwrap());
let response = conn.run_request(tokenreq)?;
if let (Some(expires_in), Some(new_token), Some(scope)) = (response.get("expires_in"), response.get("access_token"), response.get("scope")) {
let expires_in = expires_in.as_u64().unwrap();
let new_token = new_token.as_str().unwrap();
let _scope = scope.as_str().unwrap();
*token.borrow_mut() = new_token.to_string();
expire_instant.set(Some(Instant::now() + Duration::new(expires_in.to_string().parse::<u64>().unwrap(), 0)));
Ok(())
} else {
Err(Error::from(RedditError::AuthError))
}
}
}
}
pub fn create_script(conn: &Connection, id: &str, secret: &str, username: &str, password: &str) -> Result<OAuth, Error> {
let mut params: HashMap<&str, &str> = HashMap::new();
params.insert("grant_type", "password");
params.insert("username", &username);
params.insert("password", &password);
let mut tokenreq = Request::builder().method(Method::POST).uri("https://ssl.reddit.com/api/v1/access_token/.json").body(body_from_map(¶ms)).unwrap();
tokenreq.headers_mut().insert(header::AUTHORIZATION, HeaderValue::from_str(&format!("Basic {}", { base64::encode(&format!("{}:{}", id, secret)) })).unwrap());
let response = conn.run_request(tokenreq)?;
if let Some(token) = response.get("access_token") {
let token = token.as_str().unwrap().to_string();
Ok(OAuth::Script {
id: id.to_string(),
secret: secret.to_string(),
username: username.to_string(),
password: password.to_string(),
token,
})
} else {
Err(RedditError::AuthError.into())
}
}
pub fn create_installed_app<I: Into<Option<Arc<ResponseGenFn>>>>(conn: &Connection, id: &str, redirect: &str, response_gen: I, scopes: &Scopes) -> Result<OAuth, Error> {
let response_gen = response_gen.into();
let state = rand::thread_rng().gen_ascii_chars().take(16).collect::<String>();
let scopes = &scopes.to_string();
let browser_uri = format!(
"https://www.reddit.com/api/v1/authorize?client_id={}&response_type=code&\
state={}&redirect_uri={}&duration=permanent&scope={}",
id, state, redirect, scopes
);
let state_rc = Arc::new(state);
thread::spawn(move || {
open::that(browser_uri).expect("Failed to open browser");
});
let (code_sender, code_reciever) = oneshot::channel::<Result<String, InstalledAppError>>();
let redirect_url = Url::parse(&redirect)?;
let main_redirect = format!("{}:{}", redirect_url.host_str().unwrap_or("127.0.0.1"), redirect_url.port().unwrap_or(7878).to_string());
let response_gen = if let Some(ref response_gen) = response_gen {
Arc::clone(response_gen)
} else {
Arc::new(|res: &Result<String, InstalledAppError>| -> Response<Body> {
match res {
Ok(_) => Response::new("Successfully got the code".into()),
Err(e) => Response::new(format!("{}", e).into()),
}
})
};
let server = Server::bind(&main_redirect.as_str().parse()?).serve(MakeInstalledAppService {
code_sender: Arc::new(Mutex::new(Some(code_sender))),
state: Arc::clone(&state_rc),
response_gen: Arc::clone(&response_gen),
});
let code: Arc<Mutex<Result<String, InstalledAppError>>> = Arc::new(Mutex::new(Err(InstalledAppError::NeverRecieved)));
let code_clone = Arc::clone(&code);
let finish = code_reciever.then(move |new_code| {
let code = code_clone;
if let Ok(new_code) = new_code {
match new_code {
Ok(new_code) => {
*code.lock().unwrap() = Ok(new_code);
Ok(())
}
Err(e) => {
*code.lock().unwrap() = Err(e);
Err(())
}
}
} else {
Err(())
}
});
let graceful = server.with_graceful_shutdown(finish).map_err(|e| eprintln!("Server failed: {}", e));
hyper::rt::run(graceful);
let code = match *code.lock().unwrap() {
Ok(ref new_code) => new_code.clone(),
Err(ref e) => return Err(e.clone().into()),
};
let mut params: HashMap<&str, &str> = HashMap::new();
params.insert("grant_type", "authorization_code");
params.insert("code", &code);
params.insert("redirect_uri", &redirect);
let mut tokenreq = Request::builder().method(Method::POST).uri("https://ssl.reddit.com/api/v1/access_token/.json").body(body_from_map(¶ms)).unwrap();
tokenreq.headers_mut().insert(header::AUTHORIZATION, HeaderValue::from_str(&format!("Basic {}", base64::encode(&format!("{}:", id)))).unwrap());
let response = conn.run_request(tokenreq)?;
if let (Some(expires_in), Some(token), Some(refresh_token), Some(scope)) = (response.get("expires_in"), response.get("access_token"), response.get("refresh_token"), response.get("scope")) {
let expires_in = expires_in.as_u64().unwrap();
let token = token.as_str().unwrap();
let refresh_token = refresh_token.as_str().unwrap();
let _scope = scope.as_str().unwrap();
Ok(OAuth::InstalledApp {
id: id.to_string(),
redirect: redirect.to_string(),
token: RefCell::new(token.to_string()),
refresh_token: RefCell::new(Some(refresh_token.to_string())),
expire_instant: Cell::new(Some(Instant::now() + Duration::new(expires_in.to_string().parse::<u64>().unwrap(), 0))),
})
} else {
Err(Error::from(RedditError::AuthError))
}
}
}
/// Set of OAuth scopes to request when authorizing an installed app.
///
/// Each field set to `true` adds the scope of the same name to the authorization
/// request. `Scopes::default()` is the same as `Scopes::empty()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Scopes {
    pub identity: bool,
    pub edit: bool,
    pub flair: bool,
    pub history: bool,
    pub modconfig: bool,
    pub modflair: bool,
    pub modlog: bool,
    pub modposts: bool,
    pub modwiki: bool,
    pub mysubreddits: bool,
    pub privatemessages: bool,
    pub read: bool,
    pub report: bool,
    pub save: bool,
    pub submit: bool,
    pub subscribe: bool,
    pub vote: bool,
    pub wikiedit: bool,
    pub wikiread: bool,
    pub account: bool,
}

impl Scopes {
    /// Returns a `Scopes` with every scope disabled.
    pub fn empty() -> Scopes {
        Scopes::default()
    }

    /// Returns a `Scopes` with every scope enabled.
    pub fn all() -> Scopes {
        Scopes {
            identity: true,
            edit: true,
            flair: true,
            history: true,
            modconfig: true,
            modflair: true,
            modlog: true,
            modposts: true,
            modwiki: true,
            mysubreddits: true,
            privatemessages: true,
            read: true,
            report: true,
            save: true,
            submit: true,
            subscribe: true,
            vote: true,
            wikiedit: true,
            wikiread: true,
            account: true,
        }
    }

    /// Formats the enabled scopes as a comma-separated list in field order,
    /// e.g. `"edit,read"`. Returns an empty string when no scope is enabled.
    ///
    /// Joining avoids the leading comma the previous implementation produced
    /// whenever `identity` was disabled.
    fn to_string(&self) -> String {
        let flags = [
            (self.identity, "identity"),
            (self.edit, "edit"),
            (self.flair, "flair"),
            (self.history, "history"),
            (self.modconfig, "modconfig"),
            (self.modflair, "modflair"),
            (self.modlog, "modlog"),
            (self.modposts, "modposts"),
            (self.modwiki, "modwiki"),
            (self.mysubreddits, "mysubreddits"),
            (self.privatemessages, "privatemessages"),
            (self.read, "read"),
            (self.report, "report"),
            (self.save, "save"),
            (self.submit, "submit"),
            (self.subscribe, "subscribe"),
            (self.vote, "vote"),
            (self.wikiedit, "wikiedit"),
            (self.wikiread, "wikiread"),
            (self.account, "account"),
        ];
        flags
            .iter()
            .filter(|&&(enabled, _)| enabled)
            .map(|&(_, name)| name)
            .collect::<Vec<_>>()
            .join(",")
    }
}
/// Errors produced while waiting for the OAuth redirect of an installed app.
///
/// The variant names keep their historical spelling for API compatibility.
#[derive(Debug, Fail, Clone, PartialEq, Eq)]
pub enum InstalledAppError {
    /// A generic error described by `msg`, e.g. the `error` parameter Reddit put
    /// on the redirect.
    #[fail(display = "Got an unknown error: {}", msg)]
    Error {
        /// Human-readable description of the error.
        msg: String,
    },
    /// The redirect's `state` parameter was missing or did not match.
    #[fail(display = "The states did not match")]
    MismatchedState,
    /// A result had already been delivered, or the receiver was gone, when a
    /// redirect arrived.
    #[fail(display = "A code was already received")]
    AlreadyRecieved,
    /// No result was ever delivered by the redirect handler.
    #[fail(display = "No message was ever received")]
    NeverRecieved,
}
struct MakeInstalledAppService {
code_sender: CodeSender,
state: Arc<String>,
response_gen: Arc<ResponseGenFn>,
}
impl<Ctx> MakeService<Ctx> for MakeInstalledAppService {
type ReqBody = Body;
type ResBody = Body;
type Error = hyper::Error;
type Service = InstalledAppService;
type Future = Box<Future<Item = Self::Service, Error = Self::MakeError> + Send + Sync>;
type MakeError = Box<dyn std::error::Error + Send + Sync>;
fn make_service(&mut self, _ctx: Ctx) -> Self::Future {
Box::new(futures::future::ok(InstalledAppService {
code_sender: Arc::clone(&self.code_sender),
state: Arc::clone(&self.state),
response_gen: Arc::clone(&self.response_gen),
}))
}
}
struct InstalledAppService {
code_sender: CodeSender,
state: Arc<String>,
response_gen: Arc<ResponseGenFn>,
}
impl Service for InstalledAppService {
type ReqBody = Body;
type ResBody = Body;
type Error = HyperError;
type Future = Box<Future<Item = Response<Self::ResBody>, Error = Self::Error> + Send>;
fn call(&mut self, req: Request<Self::ReqBody>) -> Self::Future {
let query_str = req.uri().path_and_query().unwrap().as_str();
let query_str = &query_str[2..query_str.len()];
let params: HashMap<_, _> = url::form_urlencoded::parse(query_str.as_bytes()).collect();
fn create_res(gen: &ResponseGenFn, res: &Result<String, InstalledAppError>, sender: &CodeSender) -> <InstalledAppService as Service>::Future {
let mut sender = sender.lock().unwrap();
let sender = if let Some(sender) = sender.take() {
sender
} else {
return Box::new(ok(gen(&Err(InstalledAppError::AlreadyRecieved))));
};
let resp = match sender.send(res.clone()) {
Ok(_) => gen(&res),
Err(_) => gen(&Err(InstalledAppError::AlreadyRecieved)),
};
Box::new(ok(resp))
}
if params.contains_key("error") {
warn!("Got failed authorization. Error was {}", ¶ms["error"]);
let err = InstalledAppError::Error { msg: params["error"].to_string() };
create_res(&*self.response_gen, &Err(err.clone()), &self.code_sender)
} else {
let state = if let Some(state) = params.get("state") {
state
} else {
return create_res(&*self.response_gen, &Err(InstalledAppError::MismatchedState), &self.code_sender);
};
if *state != *self.state {
error!("State didn't match. Got state \"{}\", needed state \"{}\"", state, self.state);
create_res(&*self.response_gen, &Err(InstalledAppError::MismatchedState), &self.code_sender)
} else {
let code = ¶ms["code"];
create_res(&*self.response_gen, &Ok(code.clone().into()), &self.code_sender)
}
}
}
}
/// Extension for moving the value out of a `RefCell<Option<T>>`.
trait RefCellExt<T> {
    /// Moves the contained value out and leaves `None` behind. Returns `None` if
    /// the cell is already empty.
    fn pop(&self) -> Option<T>;
}

impl<T> RefCellExt<T> for RefCell<Option<T>> {
    fn pop(&self) -> Option<T> {
        // Check with a shared borrow first, so an empty cell never takes a mutable
        // borrow. That borrow would panic while the cell is borrowed elsewhere.
        if self.borrow().is_none() {
            return None;
        }
        self.borrow_mut().take()
    }
}