#![warn(missing_docs)]
use log::debug;
use reqwest::Method;
use reqwest::{
self,
header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT},
};
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;
use std::{fs::File, io::prelude::*};
pub mod query_listing;
use query_listing::QueryListingRequest;
pub mod errors;
use errors::ApiError;
pub mod models;
use models::{subreddit::Subreddit, user::User};
/// Rate-limit response headers that `Api::process_response_headers` echoes
/// to the debug log when they are present.
const RATE_LIMIT_HEADER_NAMES: [&str; 3] =
    ["X-Ratelimit-Used", "X-Ratelimit-Remaining", "X-Ratelimit-Reset"];
/// Credentials and client identification used to talk to the Reddit API.
///
/// Usually read from a JSON file with [`Config::load_config`]. The `Debug`
/// representation redacts `password` and `client_secret` so a config can be
/// logged without leaking secrets.
#[derive(Deserialize, PartialEq, Clone)]
#[cfg_attr(test, derive(Default))]
pub struct Config {
    /// Account name sent in the password-grant login form.
    pub username: String,
    /// Password of `username`, sent in the password-grant login form.
    pub password: String,
    /// Value sent in the `User-Agent` header of every request.
    pub user_agent: String,
    /// OAuth client id, sent as the basic-auth user when requesting a token.
    pub client_id: String,
    /// OAuth client secret, sent as the basic-auth password when requesting a token.
    pub client_secret: String,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are replaced by a placeholder; everything else is printed as-is.
        f.debug_struct("Config")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("user_agent", &self.user_agent)
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .finish()
    }
}
impl Config {
    /// Reads a JSON configuration file and deserializes it into a [`Config`].
    ///
    /// `path` accepts anything convertible to a [`std::path::Path`]
    /// (`&str`, `String`, `&Path`, `PathBuf`, ...).
    ///
    /// # Errors
    ///
    /// Returns an [`ApiError`] if the file cannot be opened or read, or if its
    /// contents are not a JSON object providing every `Config` field.
    pub fn load_config(path: impl AsRef<std::path::Path>) -> Result<Self, ApiError> {
        let mut file = File::open(path)?;
        let mut contents = String::new();
        file.read_to_string(&mut contents)?;
        let config = serde_json::from_str::<Config>(&contents)?;
        Ok(config)
    }
}
/// Blocking client for the Reddit API.
///
/// Create it with [`Api::new`], then call [`Api::do_login`] to obtain an
/// access token before issuing requests that need authentication.
pub struct Api {
    /// Credentials and user agent used for every request.
    config: Config,
    /// HTTP client reused for all requests made by this instance.
    client: reqwest::Client,
    /// Token stored by `do_login`; `None` until a token has been received.
    access_token: Option<AccessTokenResponse>,
    /// Raw JSON returned by `api/v1/me`, cached by a successful `do_login`.
    pub whoami: Option<Value>,
}
impl Api {
    /// Creates a client for `config` without contacting the server.
    ///
    /// Until [`Api::do_login`] succeeds, requests are sent without an
    /// `Authorization` header.
    pub fn new(config: Config) -> Self {
        debug!("New API object created");
        Api {
            config,
            client: reqwest::Client::new(),
            access_token: None,
            whoami: None,
        }
    }

    /// Logs in with the password grant, stores the access token and caches
    /// the `api/v1/me` response in [`Api::whoami`].
    ///
    /// # Errors
    ///
    /// Fails if the token request cannot be sent, the server answers with a
    /// 4xx/5xx status, the body is not a valid token response, or the
    /// follow-up [`Api::get_whoami`] call fails.
    pub fn do_login(&mut self) -> Result<(), ApiError> {
        #[cfg(not(test))]
        let url = "https://www.reddit.com";
        #[cfg(test)]
        let url = &mockito::server_url();
        debug!("Performing login");
        let mut form = HashMap::new();
        form.insert("grant_type", "password");
        form.insert("username", &self.config.username);
        form.insert("password", &self.config.password);
        let mut resp = self
            .client
            .post(&format!("{}/api/v1/access_token", url))
            .header("User-Agent", self.config.user_agent.clone())
            .basic_auth(&self.config.client_id, Some(&self.config.client_secret))
            .form(&form)
            .send()?;
        debug!("Login response code = {}", resp.status().as_str());
        // Surface HTTP failures directly rather than as a confusing JSON decode error.
        Self::ensure_success(&resp)?;
        let data = resp.json::<AccessTokenResponse>()?;
        // The token is a live credential: log its metadata, never its value.
        debug!(
            "Access token received (type {}, expires_in {}, scope {})",
            data.token_type, data.expires_in, data.scope
        );
        self.access_token = Some(data);
        let whoami = self.get_whoami()?;
        debug!("Returned whoami is {:?}", whoami);
        self.whoami = Some(whoami);
        Ok(())
    }

    /// Fetches the `api/v1/me` description of the authenticated account.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Api::query`] or from decoding the body as JSON.
    pub fn get_whoami(&self) -> Result<Value, ApiError> {
        let mut resp = self.query("GET", "api/v1/me", None, None)?;
        let data: Value = resp.json()?;
        Ok(data)
    }

    /// Returns the `name` field of the cached whoami response.
    ///
    /// `None` if not logged in yet, or if `name` is missing or not a string.
    pub fn get_username(&self) -> Option<String> {
        self.whoami.as_ref()?["name"].as_str().map(str::to_owned)
    }

    /// Builds the headers sent with `query`/`query_listing` requests: the
    /// configured `User-Agent` plus, once a token is stored, `Authorization`.
    ///
    /// # Panics
    ///
    /// Panics if the user agent or the access token contains characters that
    /// are not allowed in an HTTP header value.
    fn get_headers(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        let user_agent = HeaderValue::from_str(&self.config.user_agent)
            .expect("configured user_agent is not a valid HTTP header value");
        headers.insert(USER_AGENT, user_agent);
        if let Some(token) = &self.access_token {
            let auth_header = HeaderValue::from_str(&format!("bearer {}", token.token))
                .expect("access token is not a valid HTTP header value");
            headers.insert(AUTHORIZATION, auth_header);
        }
        headers
    }

    /// Turns an API-relative `path` into an absolute URL on the OAuth host,
    /// replacing the `{username}` placeholder with the logged-in user's name.
    ///
    /// # Panics
    ///
    /// Panics if `path` contains `{username}` and no username is known; the
    /// public methods in this impl call `check_path_placeholders` first.
    fn reformat_path(&self, path: &str) -> String {
        #[cfg(not(test))]
        let url = "https://oauth.reddit.com";
        #[cfg(test)]
        let url = &mockito::server_url();
        let path = if path.contains("{username}") {
            debug!("Replacing 'username' macro");
            let username = self
                .get_username()
                .expect("a '{username}' path requires a successful do_login()");
            path.replace("{username}", &username)
        } else {
            path.to_owned()
        };
        format!("{}/{}", url, path)
    }

    /// Returns an error (instead of letting `reformat_path` panic) when
    /// `path` uses the `{username}` placeholder but no username is known.
    fn check_path_placeholders(&self, path: &str) -> Result<(), ApiError> {
        if path.contains("{username}") && self.get_username().is_none() {
            return Err(ApiError::from(String::from(
                "Path uses '{username}' but no logged-in username is known",
            )));
        }
        Ok(())
    }

    /// Converts a 4xx/5xx response into an `"Error code <status>"` [`ApiError`].
    fn ensure_success(resp: &reqwest::Response) -> Result<(), ApiError> {
        let status = resp.status();
        if status.is_client_error() || status.is_server_error() {
            return Err(ApiError::from(format!("Error code {}", status.as_str())));
        }
        Ok(())
    }

    /// Logs any rate-limit headers present in a response at debug level.
    fn process_response_headers(&self, headers: &HeaderMap) {
        for header_name in &RATE_LIMIT_HEADER_NAMES {
            if let Some(value) = headers.get(*header_name) {
                // Header values come from the server and need not be visible
                // ASCII; fall back to Debug formatting instead of panicking.
                match value.to_str() {
                    Ok(text) => debug!(">> Header {}: {}", header_name, text),
                    Err(_) => debug!(">> Header {}: {:?}", header_name, value),
                }
            }
        }
    }

    /// Sends a request to `path` on the OAuth host and returns the response.
    ///
    /// * `method` – HTTP method name such as `"GET"` or `"POST"`.
    /// * `path` – API path without a leading `/`; `{username}` is replaced by
    ///   the logged-in user's name.
    /// * `query` – optional query-string pairs.
    /// * `form_data` – optional form-encoded request body.
    ///
    /// # Errors
    ///
    /// Fails if `method` is not a valid HTTP method, `path` uses `{username}`
    /// before login, the request cannot be sent, or the server answers with a
    /// 4xx/5xx status (`"Error code <status>"`).
    pub fn query(
        &self,
        method: &str,
        path: &str,
        query: Option<Vec<(&str, &str)>>,
        form_data: Option<HashMap<&str, &str>>,
    ) -> Result<reqwest::Response, ApiError> {
        let method = Method::from_bytes(method.as_bytes())
            .map_err(|_| ApiError::from(format!("Invalid HTTP method '{}'", method)))?;
        self.check_path_placeholders(path)?;
        let url = self.reformat_path(path);
        // Log method/URL/query only: the request builder's Debug output would
        // include the bearer token from the Authorization header.
        debug!("{} {} query={:?}", method, url, query);
        let req = self
            .client
            .request(method, &url)
            .headers(self.get_headers());
        let req = match query {
            Some(q) => req.query(&q),
            None => req,
        };
        let req = match form_data {
            Some(fd) => req.form(&fd),
            None => req,
        };
        let resp = req.send()?;
        // Log rate-limit headers even on failures (e.g. 429) where they matter most.
        self.process_response_headers(resp.headers());
        Self::ensure_success(&resp)?;
        Ok(resp)
    }

    /// Fetches up to `ql.requests` pages of a listing and returns all
    /// children in the order received.
    ///
    /// Each page request carries `limit`, the `after` cursor from the previous
    /// page, the running `count` of items seen, `show=all` when requested and
    /// the caller's extra `params` (when non-empty). Paging stops early once
    /// the server reports no further page (`after` is null or missing).
    ///
    /// # Errors
    ///
    /// Fails if the path uses `{username}` before login, a request cannot be
    /// sent, a page returns a 4xx/5xx status, or a page has no
    /// `data.children` array.
    pub fn query_listing(&self, ql: QueryListingRequest) -> Result<Vec<Value>, ApiError> {
        debug!("Listing request call: {:?}", ql);
        self.check_path_placeholders(&ql.path)?;
        let path = self.reformat_path(&ql.path);
        let base_req = self
            .client
            .request(Method::GET, &path)
            .headers(self.get_headers());
        let mut all_resp: Vec<Value> = Vec::new();
        let mut after = match ql.after {
            Some(a) => a.to_owned(),
            None => String::new(),
        };
        let mut count = ql.count;
        for _ in 0..ql.requests {
            let req = base_req
                .try_clone()
                .expect("a body-less GET request can always be cloned");
            // Only attach the caller's extra parameters when there are any.
            let req = if ql.params.is_empty() {
                req
            } else {
                req.query(ql.params)
            };
            let mut listing_params = vec![("limit", ql.limit.to_string())];
            if !after.is_empty() {
                listing_params.push(("after", after.clone()));
            }
            if count > 0 {
                listing_params.push(("count", count.to_string()));
            }
            if ql.show_all {
                listing_params.push(("show", "all".to_owned()));
            }
            let mut resp = req.query(&listing_params).send()?;
            self.process_response_headers(resp.headers());
            Self::ensure_success(&resp)?;
            let data: Value = resp.json()?;
            let listing = &data["data"];
            let children = listing["children"].as_array().ok_or_else(|| {
                ApiError::from(String::from("Listing response has no data.children array"))
            })?;
            for item in children {
                count += 1;
                all_resp.push(item.clone());
            }
            // A null/missing `after` marks the last page; requesting again
            // without a cursor would restart from the first page.
            match listing["after"].as_str() {
                Some(next) => after = next.to_owned(),
                None => break,
            }
        }
        Ok(all_resp)
    }

    /// Searches subreddit names matching `name` (non-exact search).
    ///
    /// # Errors
    ///
    /// Propagates [`Api::query`] / JSON errors, and fails if the response has
    /// no `names` array.
    pub fn search_for_subreddit(&self, name: &str) -> Result<Vec<Subreddit>, ApiError> {
        let mut resp = self.query(
            "GET",
            "api/search_reddit_names",
            Some(vec![("query", name), ("exact", "false")]),
            None,
        )?;
        let data: Value = resp.json()?;
        let names = data["names"].as_array().ok_or_else(|| {
            ApiError::from(String::from("Search response has no 'names' array"))
        })?;
        Ok(names
            .iter()
            .filter_map(Value::as_str)
            .map(|found| Subreddit {
                api: self,
                name: found.to_owned(),
            })
            .collect())
    }

    /// Returns the subreddit whose name equals `name` exactly (case-sensitive)
    /// among the results of [`Api::search_for_subreddit`].
    ///
    /// # Errors
    ///
    /// Propagates search errors; returns `"Subreddit not found"` when no
    /// result matches.
    pub fn get_subreddit(&self, name: &str) -> Result<Subreddit, ApiError> {
        self.search_for_subreddit(name)?
            .into_iter()
            .find(|sr| sr.name == name)
            .ok_or_else(|| ApiError::from(String::from("Subreddit not found")))
    }

    /// Fetches `user/<name>/about` and wraps the JSON in a [`User`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Api::query`] or from decoding the body as JSON.
    pub fn get_user(&self, name: &str) -> Result<User, ApiError> {
        let mut resp = self.query("GET", &format!("user/{}/about", name), None, None)?;
        let data: Value = resp.json()?;
        Ok(User {
            api: self,
            about: data,
        })
    }
}
/// Body of the `api/v1/access_token` response received by `Api::do_login`.
#[derive(Debug, PartialEq, Deserialize)]
struct AccessTokenResponse {
    /// Bearer token for the `Authorization` header. Deserialized from either
    /// `token` or `access_token` (serde alias).
    #[serde(alias = "access_token")]
    token: String,
    /// Token type string reported by the server.
    token_type: String,
    /// Token lifetime reported by the server — presumably seconds; not
    /// checked anywhere in this file.
    expires_in: u64,
    /// Scope string granted by the server.
    scope: String,
}
#[cfg(test)]
mod tests {
    use super::{AccessTokenResponse, Api, Config, QueryListingRequest};
    use mockito::mock;
    use std::fs::File;
    use std::io::Write;

    /// All-empty configuration (`Config` derives `Default` in test builds).
    fn get_config() -> Config {
        Config::default()
    }

    /// Fresh, not-logged-in client built from `get_config`.
    fn get_api() -> Api {
        Api::new(get_config())
    }

    /// JSON body of a successful `api/v1/access_token` response.
    fn get_sample_atr() -> String {
        String::from(
            r#"{"access_token": "aaaaa", "token_type": "bbbbb",
                "expires_in": 10000, "scope": "ccccc"}"#,
        )
    }

    #[test]
    fn load_config_from_disk() {
        let original_content = r#"{
            "username": "a",
            "password": "b",
            "user_agent": "c",
            "client_id": "d",
            "client_secret": "e"
        }"#;
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("reddit_api-config.json");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{}", original_content).unwrap();
        // Close the handle before reading the file back.
        drop(file);
        let config = Config::load_config(file_path.to_str().unwrap()).unwrap();
        assert_eq!(config.username, "a");
        assert_eq!(config.password, "b");
        assert_eq!(config.user_agent, "c");
        assert_eq!(config.client_id, "d");
        assert_eq!(config.client_secret, "e");
    }

    #[test]
    fn access_token_response_deserialize() {
        let atr: AccessTokenResponse = serde_json::from_str(&get_sample_atr()).unwrap();
        assert_eq!(atr.token, "aaaaa");
        assert_eq!(atr.token_type, "bbbbb");
        assert_eq!(atr.expires_in, 10000);
        assert_eq!(atr.scope, "ccccc");
    }

    #[test]
    fn new_api() {
        let api = get_api();
        assert_eq!(api.config, get_config());
        assert_eq!(api.access_token, None);
        assert_eq!(api.whoami, None);
    }

    #[test]
    fn do_login() {
        let token_mock = mock("POST", "/api/v1/access_token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(get_sample_atr())
            .create();
        let me_mock = mock("GET", "/api/v1/me")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"name":"test-name"}"#)
            .create();
        let mut api = get_api();
        api.do_login().unwrap();
        assert_eq!(api.get_username().unwrap(), "test-name");
        token_mock.assert();
        me_mock.assert();
    }

    #[test]
    fn query_listing() {
        let body = r#"{"data": {"kind": "Listing", "after": "t3_ccccc", "children": [
            {"data": {"id": "aaaaa"}, "kind": "t3"},
            {"data": {"id": "bbbbb"}, "kind": "t3"},
            {"data": {"id": "ccccc"}, "kind": "t3"}]}}"#;
        let listing_mock = mock("GET", "/some/endpoint?limit=3&show=all")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create();
        let ql = QueryListingRequest::new("some/endpoint", 3, 1);
        let values = get_api().query_listing(ql).unwrap();
        assert_eq!(values.len(), 3);
        listing_mock.assert();
    }

    #[test]
    fn search_for_subreddit() {
        let search_mock = mock("GET", "/api/search_reddit_names?query=rust&exact=false")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"names":["rust1","rust2","rust3"]}"#)
            .create();
        let api = get_api();
        let srs = api.search_for_subreddit("rust").unwrap();
        assert_eq!(srs.len(), 3);
        search_mock.assert();
    }

    #[test]
    fn get_subreddit() {
        let search_mock = mock("GET", "/api/search_reddit_names?query=rust1&exact=false")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"names":["rust","rust1","rust2"]}"#)
            .create();
        let api = get_api();
        let sr = api.get_subreddit("rust1").unwrap();
        assert_eq!(sr.name, "rust1");
        search_mock.assert();
    }
}