use std::{borrow::Cow, env, time::Duration};
use chrono::Utc;
use hyper::{
body::{aggregate, Buf},
header::{ACCEPT, AUTHORIZATION, USER_AGENT},
Body, Request, Uri,
};
use serde::de::DeserializeOwned;
use thiserror::Error;
use crate::models::{
auth::RedditAuth,
fullname::FullName,
link::{RedditListing, Sort},
};
/// Authenticated client for Reddit's OAuth API.
///
/// Holds the HTTPS connection pool, the application credentials (kept so the
/// token can be refreshed once it expires), and the current OAuth token.
pub struct RedditClient {
    /// HTTPS-capable hyper client used for both token and API requests.
    pub(crate) client: hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
    /// OAuth application secret; re-used when the token is refreshed.
    pub(crate) client_secret: String,
    /// OAuth application id; re-used when the token is refreshed.
    pub(crate) client_id: String,
    // Stored for future use but not read anywhere yet, hence the lint allowance.
    #[allow(dead_code)]
    pub(crate) base_path: Uri,
    /// Current OAuth token; replaced when `expires_at` has passed.
    pub(crate) auth: RedditAuth,
    /// Value sent in the `User-Agent` header of every request.
    pub(crate) user_agent: &'static str,
}
#[derive(Debug, thiserror::Error)]
pub enum MissingEnvVariableError {
#[error("REDDIT_CLIENT_ID environment variable missing")]
RedditClientId,
#[error("REDDIT_CLIENT_SECRET environment variable missing")]
RedditClientSecret,
}
/// Borrowed settings consumed by `RedditClient::from_config`.
pub struct Config<'cfg> {
    /// OAuth application id.
    pub client_id: &'cfg str,
    /// OAuth application secret.
    pub client_secret: &'cfg str,
    /// Application name embedded in the `User-Agent` header.
    pub client_name: &'cfg str,
}
impl RedditClient {
pub async fn from_env(client_name: &str) -> Result<Self, RedditError> {
#[allow(clippy::unwrap_used)]
Self::try_from_env(client_name).await.unwrap()
}
pub async fn try_from_env(
client_name: &str,
) -> Result<Result<Self, RedditError>, MissingEnvVariableError> {
dotenv::dotenv().ok();
#[allow(clippy::disallowed_method)]
let client_id = match env::var("REDDIT_CLIENT_ID") {
Ok(client_id) => client_id,
Err(_) => return Err(MissingEnvVariableError::RedditClientId),
};
#[allow(clippy::disallowed_method)]
let client_secret = match env::var("REDDIT_CLIENT_SECRET") {
Ok(client_secret) => client_secret,
Err(_) => return Err(MissingEnvVariableError::RedditClientSecret),
};
Self::from_config(Config {
client_id: &client_id,
client_secret: &client_secret,
client_name,
})
.await
}
pub async fn from_config(
Config {
client_id,
client_secret,
client_name,
}: Config<'_>,
) -> Result<Result<Self, RedditError>, MissingEnvVariableError> {
let base_path = "https://www.reddit.com".parse().expect("infallible");
let https = hyper_rustls::HttpsConnector::with_native_roots();
let client = hyper::Client::builder().build::<_, hyper::Body>(https);
let auth = match RedditClient::authorize(&client, client_id, client_secret).await {
Ok(auth) => auth,
Err(err) => return Ok(Err(err)),
};
let user_agent: &'static str = {
let version = env!("CARGO_PKG_VERSION");
let user_agent = format!(
"ubuntu:{name}:{version} (by /u/benluelo)",
name = client_name,
version = version
);
Box::leak(Box::new(user_agent))
};
Ok(Ok(Self {
client,
client_secret: client_secret.into(),
client_id: client_id.into(),
base_path,
auth,
user_agent,
}))
}
async fn authorize(
client: &hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
client_id: &str,
client_secret: &str,
) -> Result<RedditAuth, RedditError> {
let auth_path = format!(
"{}/{}",
std::env::temp_dir().display(),
"reddit_client_authorization.json"
);
if let Ok(Ok(auth)) = std::fs::read_to_string(&auth_path)
.as_deref()
.map(serde_json::from_str::<RedditAuth>)
{
if auth.expires_at > Utc::now() {
println!("cache hit for auth");
return Ok(auth);
}
println!("fetching new auth");
}
let response = client
.request(
Request::post("https://www.reddit.com/api/v1/access_token")
.header(
AUTHORIZATION,
dbg!(format!(
"Basic {}",
base64::encode_config(
format!("{}:{}", client_id, client_secret),
base64::URL_SAFE
)
)),
)
.body(hyper::Body::from("grant_type=client_credentials"))
.expect("infallible"),
)
.await?;
let buf = aggregate(response).await?;
let auth = serde_json::from_reader::<_, RedditAuth>(buf.reader())?;
if let Err(why) = std::fs::write(
&auth_path,
serde_json::to_string_pretty(&auth).expect("infallible"),
) {
log::warn!(target: "reddit_client", "Unable to cache auth file: {}", why);
}
Ok(auth)
}
pub(crate) fn auth_header_value(&self) -> String {
format!("Bearer {}", &self.auth.access_token)
}
pub async fn get_subreddit_posts(
&mut self,
subreddit: &str,
sorting: Sort,
after: Option<&FullName>,
) -> Result<RedditListing, RedditError> {
let after: Cow<_> = if let Some(fullname) = after {
format!("after={}", fullname).into()
} else {
"".into()
};
let uri = Uri::builder()
.scheme("https")
.authority("oauth.reddit.com")
.path_and_query(format!(
"/r/{subreddit}/{sorting}.json?{after}",
subreddit = subreddit,
sorting = sorting,
after = &after
))
.build()
.expect("Uri builder shouldn't fail");
self.request(uri).await
}
async fn request<T: DeserializeOwned>(&mut self, uri: Uri) -> Result<T, RedditError> {
self.check_auth().await?;
tokio::time::sleep(Duration::from_secs(1)).await;
let request = self
.base_request(uri)
.body(Body::empty())
.expect("infallible");
let response = self.client.request(request).await?;
let mut buf = aggregate(response).await?.reader();
let mut bytes = Vec::new();
std::io::copy(&mut buf, &mut bytes)?;
let json = String::from_utf8(bytes)?;
let listings = serde_json::from_str(&json)?;
Ok(listings)
}
fn base_request(&self, uri: Uri) -> hyper::http::request::Builder {
Request::get(uri)
.header(AUTHORIZATION, &self.auth_header_value())
.header(ACCEPT, "*/*")
.header(USER_AGENT, self.user_agent)
}
pub(crate) async fn check_auth(&mut self) -> Result<(), RedditError> {
if self.auth.expires_at <= Utc::now() {
self.auth =
RedditClient::authorize(&self.client, &self.client_id, &self.client_secret).await?;
Ok(())
} else {
Ok(())
}
}
pub async fn get_comments(&mut self) -> Result<Vec<RedditListing>, RedditError> {
let uri: Uri = "https://oauth.reddit.com/r/benluelo_testing/comments/qbq1jr/yeet/.json"
.parse()
.expect("");
self.request(uri).await
}
}
#[derive(Debug, Error)]
pub enum RedditError {
#[error("error fetching resource")]
Request(#[from] hyper::Error),
#[error("error deserializing resource")]
Deserialize(#[from] serde_json::Error),
#[error("Payload was not valid UTF8")]
Utf8(#[from] std::string::FromUtf8Error),
#[error("IO error")]
Io(#[from] std::io::Error),
}