use std::collections::HashMap;
use std::env;
use std::fmt;
use dotenv::dotenv;
use serde::{Deserialize, Serialize};
use super::callback_server::get_browser_response;
use super::curl_utils::post;
use super::model::token::OAuthToken;
use super::util::convert_map_to_string;
use super::util::generate_random_string;
use super::util::open_browser;
/// Lifetime requested for an authorization, sent to Reddit as the `duration`
/// query parameter of the authorize URL.
// The variants stay lowercase because their names are exactly the strings
// that end up in the `duration` parameter.
#[allow(non_camel_case_types)]
#[derive(PartialEq, Debug)]
pub enum AuthorizationTime {
    /// Requested by default when `authorize_client` receives no duration.
    permanent,
    /// Must be asked for explicitly.
    temporary,
}

impl fmt::Display for AuthorizationTime {
    /// Writes the variant name verbatim; width and padding flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            AuthorizationTime::permanent => "permanent",
            AuthorizationTime::temporary => "temporary",
        };
        f.write_str(text)
    }
}
/// Coarse status of a `RedditOAuth` flow.
#[derive(PartialEq, Clone, Serialize, Deserialize, Debug)]
pub enum OAuthState {
    /// Nothing has been attempted yet; this is the state of a freshly built flow.
    IDLE,
    /// Set by `authorize_client` when the authorization went through.
    AUTHORIZED,
    /// Something failed; the owning `RedditOAuth` keeps the reason in `error_string`.
    ERROR,
}

impl std::fmt::Display for OAuthState {
    /// Writes the variant name verbatim (`IDLE`, `AUTHORIZED` or `ERROR`).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let label = match self {
            OAuthState::IDLE => "IDLE",
            OAuthState::AUTHORIZED => "AUTHORIZED",
            OAuthState::ERROR => "ERROR",
        };
        f.write_str(label)
    }
}
/// Names of the Reddit API OAuth scopes known to this module.
///
/// `Display` prints the variant identifier unchanged, so
/// `RedditApiScope::identity.to_string()` yields `"identity"`.
// The variants stay lowercase because their identifiers are the printed text.
#[allow(non_camel_case_types)]
#[derive(PartialEq, Debug)]
pub enum RedditApiScope {
    identity,
    edit,
    flair,
    history,
    modconfig,
    modflair,
    modlog,
    modposts,
    modwiki,
    mysubreddits,
    privatemessages,
    read,
    report,
    save,
    submit,
    subscribe,
    vote,
    wikiedit,
    wikiread,
}

impl std::fmt::Display for RedditApiScope {
    /// Reuses the derived `Debug` output, which for a field-less enum is the
    /// bare variant name.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedditOAuth {
pub callback_url: String,
pub oauth_state: OAuthState,
pub error_string: Option<String>,
pub state_string: String,
pub client_credentials: RedditClientCredentials,
}
impl RedditOAuth {
pub fn default() -> RedditOAuth {
dotenv().ok();
let callback_url = env::var("REDIRECT_URI").unwrap_or_default();
RedditOAuth {
callback_url: callback_url,
oauth_state: OAuthState::IDLE,
error_string: None,
state_string: generate_random_string(10),
client_credentials: RedditClientCredentials::default(),
}
}
pub fn state_string(mut self, state_string: &str) -> RedditOAuth {
self.state_string = state_string.to_owned();
self
}
pub fn oauth_state(mut self, oauth_state: OAuthState) -> RedditOAuth {
self.oauth_state = oauth_state.clone();
self
}
pub fn error_state(mut self, error_state: &str) -> RedditOAuth {
self.error_string = Some(error_state.to_owned());
self
}
pub fn build(mut self) -> RedditOAuth {
let error_flag = self.callback_url == "" || self.state_string == "";
if error_flag {
self.oauth_state = OAuthState::ERROR;
self.error_string = Some("`callback_url` and `state_string` have to be set".to_owned());
return self;
}
self.oauth_state = OAuthState::IDLE;
self
}
pub fn authorize_client(
&mut self,
scope: &str,
duration: Option<AuthorizationTime>,
) -> Option<OAuthToken> {
let mut duration_string = AuthorizationTime::permanent.to_string();
if let Some(duration) = duration {
duration_string = duration.to_string();
}
let mut params: HashMap<String, String> = HashMap::new();
params.insert("response_type".to_owned(), "code".to_owned());
params.insert("duration".to_owned(), duration_string);
params.insert("scope".to_owned(), scope.to_owned());
params.insert("state".to_owned(), self.state_string.clone());
params.insert(
"client_id".to_owned(),
self.client_credentials.client_id.to_owned(),
);
params.insert("redirect_uri".to_owned(), self.callback_url.clone());
let query_string = convert_map_to_string(¶ms);
let authorize_url = format!("https://www.reddit.com/api/v1/authorize?{}", query_string);
match open_browser(&authorize_url) {
Err(e) => {
self.oauth_state = OAuthState::ERROR;
self.error_string = Some(e);
return None;
}
_ => {}
}
match get_browser_response(120, &self.state_string) {
Ok(code) => {
self.oauth_state = OAuthState::AUTHORIZED;
return self.get_initial_access_token(&code);
}
Err(e) => {
self.error_string = Some(e);
self.oauth_state = OAuthState::ERROR;
return None;
}
}
}
pub fn refresh_token(&mut self, to_refresh: &OAuthToken) -> Option<OAuthToken> {
if to_refresh.refresh_token == "" {
self.error_string = Some("Token not refreshable `refresh_token` is empty".to_string());
self.oauth_state = OAuthState::ERROR;
return None;
}
let base_url = "https://oauth.reddit.com/api/v1/access_token";
let data_string = format!(
"grant_type=refresh_token&refresh_token={}",
to_refresh.refresh_token
);
let data_header = format!(
"Authorization: Basic {}",
self.client_credentials.client_secret
);
let answer = post(base_url, &data_string, &data_header);
let bearer_token: OAuthToken = serde_json::from_str(&answer).unwrap();
return Some(OAuthToken {
access_token: bearer_token.access_token,
token_type: bearer_token.token_type,
expires_in: bearer_token.expires_in,
scope: bearer_token.scope,
refresh_token: to_refresh.refresh_token.to_string(),
});
}
pub fn revoke_token(&mut self, to_revoke: &OAuthToken) -> Result<(), String> {
let base_url = "https://oauth.reddit.com/api/v1/revoke_token";
let data_string = format!(
"token={}&token_type_hint=access_token",
to_revoke.access_token
);
let data_header = format!(
"Authorization: Basic {}",
self.client_credentials.client_secret
);
let answer = post(base_url, &data_string, &data_header);
if answer != "" {
return Err(
"Client credentials sent as HTTP Basic Authorization were invalid".to_string(),
);
}
Ok(())
}
pub fn get_initial_access_token(&self, state: &str) -> Option<OAuthToken> {
let data_field_string = format!(
"grant_type=authorization_code&code={}&redirect_uri={}",
state, self.callback_url
);
let data_header = format!(
"Authorization: Basic {}",
self.client_credentials.client_secret
);
let base_url = "https://www.reddit.com/api/v1/access_token";
let answer = post(base_url, &data_field_string, &data_header);
let bearer_token: OAuthToken = serde_json::from_str(&answer).unwrap();
return Some(bearer_token);
}
}
/// Credentials identifying this application towards Reddit.
///
/// NOTE(review): the derived `Debug` output still contains `client_secret`;
/// avoid formatting the whole struct with `{:?}` in logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedditClientCredentials {
    /// The application's client id; sent as `client_id` in the authorize URL.
    pub client_id: String,
    /// Value placed verbatim after `Authorization: Basic ` in token requests,
    /// so it presumably has to be the already base64-encoded credential.
    pub client_secret: String,
}

impl RedditClientCredentials {
    /// Loads the credentials from the environment.
    ///
    /// Loads a `.env` file if one is found (errors are ignored) and reads
    /// `CLIENT_ID` and `CLIENT_SECRET`; a missing variable yields an empty
    /// string.
    pub fn default() -> RedditClientCredentials {
        dotenv().ok();
        let client_id = env::var("CLIENT_ID").unwrap_or_default();
        let client_secret = env::var("CLIENT_SECRET").unwrap_or_default();
        // Never write the secret itself to the log; record only whether one
        // was found.
        trace!(
            "RedditClientCredentials.default(): client_id:{:?}, client_secret set:{:?}",
            client_id,
            !client_secret.is_empty()
        );
        RedditClientCredentials {
            client_id,
            client_secret,
        }
    }

    /// Replaces the client id.
    pub fn client_id(mut self, client_id: &str) -> RedditClientCredentials {
        self.client_id = client_id.to_owned();
        self
    }

    /// Replaces the client secret.
    pub fn client_secret(mut self, client_secret: &str) -> RedditClientCredentials {
        self.client_secret = client_secret.to_owned();
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `default()` picks up `CLIENT_ID` and `CLIENT_SECRET`.
    ///
    /// NOTE(review): this only passes when the environment (or a `.env` file)
    /// provides exactly these values, and the literals look like real
    /// credentials. Confirm that they are placeholders.
    #[test]
    fn test_client_credentials_default() {
        let RedditClientCredentials {
            client_id,
            client_secret,
        } = RedditClientCredentials::default();
        assert_eq!(client_id, "7tMofTv8Ip3-Ig");
        assert_eq!(client_secret, "N3RNb2ZUdjhJcDMtSWc6");
    }
}