use arc_swap::ArcSwap;
use cached::proc_macro::cached;
use futures_lite::future::block_on;
use futures_lite::{future::Boxed, FutureExt};
use hyper::client::HttpConnector;
use hyper::header::HeaderValue;
use hyper::{body, body::Buf, header, Body, Client, Method, Request, Response, Uri};
use hyper_rustls::HttpsConnector;
use libflate::gzip;
use log::{error, trace, warn};
use once_cell::sync::Lazy;
use percent_encoding::{percent_encode, CONTROLS};
use serde_json::Value;
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicBool, AtomicU16};
use std::{io, result::Result};
use crate::dbg_msg;
use crate::oauth::{force_refresh_token, token_daemon, Oauth};
use crate::server::RequestExt;
use crate::utils::{format_url, Post};
/// Base URL of Reddit's OAuth API; `reddit_get` (and therefore `json`) targets this host.
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
/// `Host` header value matching [`REDDIT_URL_BASE`].
const REDDIT_URL_BASE_HOST: &str = "oauth.reddit.com";
/// Base URL of Reddit's short-link domain.
const REDDIT_SHORT_URL_BASE: &str = "https://redd.it";
/// `Host` header value matching [`REDDIT_SHORT_URL_BASE`].
const REDDIT_SHORT_URL_BASE_HOST: &str = "redd.it";
/// Base URL of the regular (non-OAuth) Reddit site. `request` treats a redirect whose
/// `Location` is exactly this value as an invalid response.
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
/// `Host` header value matching [`ALTERNATIVE_REDDIT_URL_BASE`].
const ALTERNATIVE_REDDIT_URL_BASE_HOST: &str = "www.reddit.com";
/// HTTPS-only connector using the platform's native root certificates, with HTTP/2 enabled.
pub static HTTPS_CONNECTOR: Lazy<HttpsConnector<HttpConnector>> =
Lazy::new(|| hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http2().build());
/// Shared hyper client built on [`HTTPS_CONNECTOR`]; used by both `stream` and `request`.
pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| Client::builder().build::<_, Body>(HTTPS_CONNECTOR.clone()));
/// Current OAuth client, held in an `ArcSwap` so it can be replaced at runtime.
///
/// First access blocks the current thread on `Oauth::new()` and spawns `token_daemon()`
/// with `tokio::spawn`. That means the first access must happen inside a Tokio runtime,
/// because `tokio::spawn` panics outside one.
pub static OAUTH_CLIENT: Lazy<ArcSwap<Oauth>> = Lazy::new(|| {
let client = block_on(Oauth::new());
tokio::spawn(token_daemon());
ArcSwap::new(client.into())
});
/// Locally tracked number of OAuth requests remaining. Starts at 99. `json` counts each
/// request against it and overwrites it from the `x-ratelimit-remaining` response header.
pub static OAUTH_RATELIMIT_REMAINING: AtomicU16 = AtomicU16::new(99);
/// Whether an OAuth token rollover is in progress. This file only reads it; it is presumably
/// set by the `oauth` module (NOTE(review): confirm). While it is true, `json` does not
/// spawn extra token refreshes when the rate limit is low.
pub static OAUTH_IS_ROLLING_OVER: AtomicBool = AtomicBool::new(false);
/// `(base URL, Host header)` pairs that `canonical_path` sends HEAD requests to, in order.
const URL_PAIRS: [(&str, &str); 2] = [
(ALTERNATIVE_REDDIT_URL_BASE, ALTERNATIVE_REDDIT_URL_BASE_HOST),
(REDDIT_SHORT_URL_BASE, REDDIT_SHORT_URL_BASE_HOST),
];
#[cached(size = 1024, time = 600, result = true)]
#[async_recursion::async_recursion]
pub async fn canonical_path(path: String, tries: i8) -> Result<Option<String>, String> {
if tries == 0 {
return Ok(None);
}
let res = {
let mut res = None;
for (url_base, url_base_host) in URL_PAIRS {
res = reddit_short_head(path.clone(), true, url_base, url_base_host).await.ok();
if let Some(res) = &res {
if !res.status().is_client_error() {
break;
}
}
}
res
};
let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?;
let status = res.status().as_u16();
let policy_error = res.headers().get(header::RETRY_AFTER).is_some();
match status {
200..=299 => Ok(Some(path)),
301 => match res.headers().get(header::LOCATION) {
Some(val) => {
let Ok(original) = val.to_str() else {
return Err("Unable to decode Location header.".to_string());
};
let stripped_uri = original.strip_suffix(".json").unwrap_or(original).split('?').next().unwrap_or_default();
let uri = format_url(stripped_uri);
canonical_path(uri, tries - 1).await
}
None => Ok(None),
},
300..=399 => Ok(None),
429 => Err("Too many requests.".to_string()),
403 if policy_error => Err("Too many requests.".to_string()),
_ => Ok(
res
.headers()
.get(header::LOCATION)
.map(|val| percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string()),
),
}
}
/// Proxy `req` to the URL template `format`.
///
/// The incoming query string is appended after a `?`. Every `{name}` placeholder is then
/// replaced with the route parameter of the same name, and the resulting URL is streamed
/// back to the client.
pub async fn proxy(req: Request<Body>, format: &str) -> Result<Response<Body>, String> {
	let query = req.uri().query().unwrap_or_default();
	let mut target = format!("{format}?{query}");
	let params = req.params();
	for (key, val) in &params {
		let placeholder = format!("{{{key}}}");
		target = target.replace(&placeholder, val);
	}
	stream(&target, &req).await
}
/// Fetch `url` with a GET request and return the upstream response.
///
/// The caching/range headers `Range`, `If-Modified-Since` and `Cache-Control` are copied
/// from `req`. A fixed set of CDN/server-identifying headers is removed from the response
/// before it is returned.
///
/// # Errors
/// Returns a message when `url` cannot be parsed as a URI, the request cannot be built,
/// or the upstream request fails.
async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String> {
	// Request headers copied verbatim from the client's request to the upstream one.
	const FORWARDED: [&str; 3] = ["Range", "If-Modified-Since", "Cache-Control"];
	// Upstream response headers removed before handing the response back.
	const STRIPPED: [&str; 12] = [
		"access-control-expose-headers",
		"server",
		"vary",
		"etag",
		"x-cdn",
		"x-cdn-client-region",
		"x-cdn-name",
		"x-cdn-server-region",
		"x-reddit-cdn",
		"x-reddit-video-features",
		"Nel",
		"Report-To",
	];

	let Ok(target) = url.parse::<Uri>() else {
		return Err("Couldn't parse URL".to_string());
	};
	let builder = FORWARDED.iter().fold(Request::get(target), |builder, &name| match req.headers().get(name) {
		Some(value) => builder.header(name, value),
		None => builder,
	});
	let Ok(upstream_request) = builder.body(Body::empty()) else {
		return Err("Couldn't build empty body in stream".to_string());
	};

	let client: &Lazy<Client<_, Body>> = &CLIENT;
	let mut response = client.request(upstream_request).await.map_err(|e| e.to_string())?;
	let headers = response.headers_mut();
	for name in STRIPPED {
		headers.remove(name);
	}
	Ok(response)
}
/// GET `path` from the OAuth API host ([`REDDIT_URL_BASE`]), following redirects.
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
	let follow_redirects = true;
	request(&Method::GET, path, follow_redirects, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST)
}
/// HEAD `path` on `base_path` (sent with `Host: host`) without following redirects.
/// This lets the caller inspect the status and `Location` of a 3xx response itself.
fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
	let follow_redirects = false;
	request(&Method::HEAD, path, follow_redirects, quarantine, base_path, host)
}
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
let url = format!("{base_path}{path}");
let client: &Lazy<Client<_, Body>> = &CLIENT;
let mut headers: Vec<(String, String)> = vec![
("Host".into(), host.into()),
("Accept-Encoding".into(), if method == Method::GET { "gzip".into() } else { "identity".into() }),
(
"Cookie".into(),
if quarantine {
"_options=%7B%22pref_quarantine_optin%22%3A%20true%2C%20%22pref_gated_sr_optin%22%3A%20true%7D".into()
} else {
"".into()
},
),
];
{
let client = OAUTH_CLIENT.load_full();
for (key, value) in client.headers_map.clone() {
headers.push((key, value));
}
}
fastrand::shuffle(&mut headers);
let mut builder = Request::builder().method(method).uri(&url);
for (key, value) in headers {
builder = builder.header(key, value);
}
let builder = builder.body(Body::empty());
async move {
match builder {
Ok(req) => match client.request(req).await {
Ok(mut response) => {
if response.status().is_redirection() {
if !redirect {
return Ok(response);
};
let location_header = response.headers().get(header::LOCATION);
if location_header == Some(&HeaderValue::from_static(ALTERNATIVE_REDDIT_URL_BASE)) {
return Err("Reddit response was invalid".to_string());
}
return request(
method,
location_header
.map(|val| {
let new_path = percent_encode(val.as_bytes(), CONTROLS)
.to_string()
.trim_start_matches(REDDIT_URL_BASE)
.trim_start_matches(ALTERNATIVE_REDDIT_URL_BASE)
.to_string();
format!("{new_path}{}raw_json=1", if new_path.contains('?') { "&" } else { "?" })
})
.unwrap_or_default()
.to_string(),
true,
quarantine,
base_path,
host,
)
.await;
};
match response.headers().get(header::CONTENT_ENCODING) {
None => Ok(response),
Some(hdr) => {
match hdr.to_str() {
Ok(val) => match val {
"gzip" => {}
"identity" => return Ok(response),
_ => return Err("Reddit response was encoded with an unsupported compressor".to_string()),
},
Err(_) => return Err("Reddit response was invalid".to_string()),
}
let mut decompressed: Vec<u8>;
{
let mut aggregated_body = match body::aggregate(response.body_mut()).await {
Ok(b) => b.reader(),
Err(e) => return Err(e.to_string()),
};
let mut decoder = match gzip::Decoder::new(&mut aggregated_body) {
Ok(decoder) => decoder,
Err(e) => return Err(e.to_string()),
};
decompressed = Vec::<u8>::new();
if let Err(e) = io::copy(&mut decoder, &mut decompressed) {
return Err(e.to_string());
};
}
response.headers_mut().remove(header::CONTENT_ENCODING);
response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into());
*(response.body_mut()) = Body::from(decompressed);
Ok(response)
}
}
}
Err(e) => {
dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e);
Err(e.to_string())
}
},
Err(_) => Err("Post url contains non-ASCII characters".to_string()),
}
}
.boxed()
}
#[cached(size = 100, time = 30, result = true)]
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
let err = |msg: &str, e: String, path: String| -> Result<Value, String> {
Err(format!("{msg}: {e} | {path}"))
};
let current_rate_limit = OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst);
let is_rolling_over = OAUTH_IS_ROLLING_OVER.load(Ordering::SeqCst);
if current_rate_limit < 10 && !is_rolling_over {
warn!("Rate limit {current_rate_limit} is low. Spawning force_refresh_token()");
tokio::spawn(force_refresh_token());
}
OAUTH_RATELIMIT_REMAINING.fetch_sub(1, Ordering::SeqCst);
match reddit_get(path.clone(), quarantine).await {
Ok(response) => {
let status = response.status();
let reset: Option<String> = if let (Some(remaining), Some(reset), Some(used)) = (
response.headers().get("x-ratelimit-remaining").and_then(|val| val.to_str().ok().map(|s| s.to_string())),
response.headers().get("x-ratelimit-reset").and_then(|val| val.to_str().ok().map(|s| s.to_string())),
response.headers().get("x-ratelimit-used").and_then(|val| val.to_str().ok().map(|s| s.to_string())),
) {
trace!(
"Ratelimit remaining: Header says {remaining}, we have {current_rate_limit}. Resets in {reset}. Rollover: {}. Ratelimit used: {used}",
if is_rolling_over { "yes" } else { "no" },
);
if let Ok(val) = remaining.parse::<f32>() {
OAUTH_RATELIMIT_REMAINING.store(val.round() as u16, Ordering::SeqCst);
}
Some(reset)
} else {
None
};
match hyper::body::aggregate(response).await {
Ok(body) => {
let has_remaining = body.has_remaining();
if !has_remaining {
tokio::spawn(force_refresh_token());
return match reset {
Some(val) => Err(format!(
"Reddit rate limit exceeded. Try refreshing in a few seconds.\
Rate limit will reset in: {val}"
)),
None => Err("Reddit rate limit exceeded".to_string()),
};
}
match serde_json::from_reader(body.reader()) {
Ok(value) => {
let json: Value = value;
if let Some(data) = json.get("data") {
if let Some(is_suspended) = data.get("is_suspended").and_then(Value::as_bool) {
if is_suspended {
return Err("suspended".into());
}
}
}
if json["error"].is_i64() {
if json["message"] == "Unauthorized" {
error!("Forcing a token refresh");
let () = force_refresh_token().await;
return Err("OAuth token has expired. Please refresh the page!".to_string());
}
if json["reason"] == "quarantined" {
return Err("quarantined".into());
}
if json["reason"] == "gated" {
return Err("gated".into());
}
if json["reason"] == "private" {
return Err("private".into());
}
if json["reason"] == "banned" {
return Err("banned".into());
}
Err(format!("Reddit error {} \"{}\": {} | {path}", json["error"], json["reason"], json["message"]))
} else {
Ok(json)
}
}
Err(e) => {
error!("Got an invalid response from reddit {e}. Status code: {status}");
if status.is_server_error() {
Err("Reddit is having issues, check if there's an outage".to_string())
} else {
err("Failed to parse page JSON data", e.to_string(), path)
}
}
}
}
Err(e) => err("Failed receiving body from Reddit", e.to_string(), path),
}
}
Err(e) => err("Couldn't send request to Reddit", e, path),
}
}
/// Fetch the hot listing of `/r/{sub}` and report only whether the fetch succeeded.
async fn self_check(sub: &str) -> Result<(), String> {
	let listing_path = format!("/r/{sub}/hot.json?&raw_json=1");
	Post::fetch(&listing_path, true).await.map(|_| ())
}
/// End-to-end check of the rate-limit bookkeeping.
///
/// Fetches `/r/reddit`, forces a token refresh, then fetches `/r/rust`. After each fetch,
/// [`OAUTH_RATELIMIT_REMAINING`] must read exactly 99.
///
/// # Errors
/// Propagates fetch failures, or reports the counter value that did not match.
pub async fn rate_limit_check() -> Result<(), String> {
	// Load the counter once so the reported value is the one that failed the comparison.
	fn expect_budget_99() -> Result<(), String> {
		match OAUTH_RATELIMIT_REMAINING.load(Ordering::SeqCst) {
			99 => Ok(()),
			got => Err(format!("Rate limit check failed: expected 99, got {got}")),
		}
	}

	self_check("reddit").await?;
	expect_budget_99()?;
	force_refresh_token().await;
	self_check("rust").await?;
	expect_budget_99()
}
#[cfg(test)]
use {crate::config::get_setting, sealed_test::prelude::*};
/// Runs the live rate-limit self-check against Reddit.
#[tokio::test(flavor = "multi_thread")]
async fn test_rate_limit_check() {
	let outcome = rate_limit_check().await;
	outcome.unwrap();
}
/// With default subscriptions configured through the environment, the setting is visible
/// and the rate-limit self-check still passes.
#[test]
#[sealed_test(env = [("REDLIB_DEFAULT_SUBSCRIPTIONS", "rust")])]
fn test_default_subscriptions() {
	let runtime = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
	runtime.block_on(async {
		let configured = get_setting("REDLIB_DEFAULT_SUBSCRIPTIONS");
		assert!(configured.is_some());
		rate_limit_check().await.unwrap();
	});
}
/// Popular listing explicitly requested with the global geo filter.
#[cfg(test)]
const POPULAR_URL: &str = "/r/popular/hot.json?&raw_json=1&geo_filter=GLOBAL";
/// The `geo_filter` query parameter is echoed back in the listing data.
#[tokio::test(flavor = "multi_thread")]
async fn test_localization_popular() {
	let listing = json(POPULAR_URL.to_string(), false).await.unwrap();
	let geo_filter = listing["data"]["geo_filter"].as_str().unwrap();
	assert_eq!("GLOBAL", geo_filter);
}
/// A `/s/` share link resolves to the canonical comment permalink.
#[tokio::test(flavor = "multi_thread")]
async fn test_obfuscated_share_link() {
	let share_link = String::from("/r/rust/s/kPgq8WNHRK");
	let expected = Some(String::from("/r/rust/comments/18t5968/why_use_tuple_struct_over_standard_struct/kfbqlbc/"));
	assert_eq!(canonical_path(share_link, 3).await, Ok(expected));
}
#[tokio::test(flavor = "multi_thread")]
async fn test_private_sub() {
let link = json("/r/suicide/about.json?raw_json=1".into(), true).await;
assert!(link.is_err());
assert_eq!(link, Err("private".into()));
}
#[tokio::test(flavor = "multi_thread")]
async fn test_banned_sub() {
let link = json("/r/aaa/about.json?raw_json=1".into(), true).await;
assert!(link.is_err());
assert_eq!(link, Err("banned".into()));
}
#[tokio::test(flavor = "multi_thread")]
async fn test_gated_sub() {
let link = json("/r/drugs/about.json?raw_json=1".into(), false).await;
assert!(link.is_err());
assert_eq!(link, Err("gated".into()));
}