use crate::models::{AsAttrs, Comment, Post};
use crate::{Filter, SortType};
use async_stream::stream;
use chrono::{DateTime, Duration, Utc};
use futures::stream::{self, select_all, Stream, StreamExt};
use governor::{Quota, RateLimiter};
use once_cell::sync::OnceCell;
use reqwest::{IntoUrl, Url};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::num::NonZeroU32;
use std::ops::Div;
use std::pin::Pin;
/// Concrete type of the process-wide, unkeyed rate limiter used to throttle
/// all requests against the PushShift API (see `rate_limiter`).
type PSRateLimiter = RateLimiter<
    governor::state::NotKeyed,
    governor::state::InMemoryState,
    governor::clock::DefaultClock,
    governor::middleware::NoOpMiddleware,
>;
/// Number of items requested per page; a shorter page marks the end of a stream.
const BATCH_SIZE: i64 = 50;
/// Target number of results per date bucket when a large, date-sorted query
/// is split into parallel sub-queries (see `chunked`).
const DESIRED_BUCKET_VOLUME: i64 = 25;
/// Returns the lazily-initialized, process-wide rate limiter shared by every
/// `Client` instance.
fn rate_limiter() -> &'static PSRateLimiter {
    static PS_RATE_LIMITER: OnceCell<PSRateLimiter> = OnceCell::new();
    PS_RATE_LIMITER.get_or_init(|| {
        // One request per second across the whole process.
        let quota = Quota::per_second(NonZeroU32::new(1).unwrap());
        RateLimiter::direct(quota)
    })
}
/// Subset of the `metadata` object of a PushShift response that we care about.
#[derive(Deserialize, Debug)]
struct PushShiftMetadata {
    // Total number of results matching the query (used to decide whether to
    // split the query into date buckets; see `Client::get_date_bounds`).
    total_results: i64,
}
/// Envelope shape of every PushShift JSON response.
#[derive(Deserialize, Debug)]
struct PushShiftResponse<T> {
    // Items in this page, deserialized to the requested model type.
    data: Vec<T>,
    // Only populated when the request was sent with `metadata: true`.
    metadata: Option<PushShiftMetadata>,
}
/// Query-string parameters attached to each PushShift request.
#[derive(Clone, Serialize)]
struct PushShiftQueryParams<'a> {
    // User-supplied filter; its fields are flattened into the query string.
    #[serde(flatten)]
    inner: &'a Filter,
    // Sort direction ("asc"/"desc"); `None` leaves the API default.
    sort: Option<&'static str>,
    // Maximum number of items to return for this request.
    limit: i64,
    // Whether to ask the API for the metadata object (total result count).
    metadata: bool,
}
/// Client for the PushShift Reddit search API.
#[derive(Clone)]
pub struct Client {
    // HTTP client used for every request.
    client: reqwest::Client,
    // Process-wide rate limiter; shared by all `Client` instances.
    limiter: &'static PSRateLimiter,
}
impl Client {
pub fn new() -> Self {
Self::with_client(reqwest::Client::new())
}
pub fn with_client(client: reqwest::Client) -> Self {
Self {
client,
limiter: rate_limiter(),
}
}
pub async fn get_comments(&self, filter: Filter) -> Pin<Box<dyn Stream<Item = Comment> + '_>> {
let url = Url::parse("https://api.pushshift.io/reddit/comment/search/").unwrap();
self._stream(url, filter).await
}
pub async fn get_posts(&self, filter: Filter) -> Pin<Box<dyn Stream<Item = Post> + '_>> {
let url = Url::parse("https://api.pushshift.io/reddit/submission/search/").unwrap();
self._stream(url, filter).await
}
async fn _stream<T: 'static + DeserializeOwned + AsAttrs>(
&self,
url: Url,
filter: Filter,
) -> Pin<Box<dyn Stream<Item = T> + '_>> {
if matches!(filter.sort_type, SortType::CreatedDate) {
if let Some((total, oldest, newest)) =
self.get_date_bounds::<Post>(url.clone(), &filter).await
{
return Box::pin(
select_all(chunked(total, oldest, newest).map(|(l, r)| {
Box::pin(self.paginated(url.clone(), filter.clone().before(r).after(l)))
}))
.flat_map(stream::iter),
);
}
}
Box::pin(self.paginated(url, filter).flat_map(stream::iter))
}
async fn _get<T: DeserializeOwned>(
&self,
url: Url,
params: PushShiftQueryParams<'_>,
) -> Option<PushShiftResponse<T>> {
self.limiter.until_ready().await;
let response = self.client.get(url).query(¶ms).send().await;
if let Ok(response) = response {
if let Ok(parsed_response) = response.json::<PushShiftResponse<T>>().await {
return Some(parsed_response);
}
}
None
}
async fn get_date_bounds<'a, T: DeserializeOwned + AsAttrs>(
&self,
url: Url,
params: &Filter,
) -> Option<(i64, DateTime<Utc>, DateTime<Utc>)> {
let newest: PushShiftResponse<T> = self
._get(
url.clone(),
PushShiftQueryParams {
inner: params,
sort: Some("desc"),
limit: 1,
metadata: true,
},
)
.await?;
let total_results = if let Some(metadata) = &newest.metadata {
metadata.total_results
} else {
return None;
};
if total_results <= BATCH_SIZE {
return None;
}
let oldest: PushShiftResponse<T> = self
._get(
url,
PushShiftQueryParams {
inner: params,
sort: Some("asc"),
limit: 1,
metadata: false,
},
)
.await?;
Some((
total_results,
oldest.data[0].attrs().date,
newest.data[0].attrs().date,
))
}
fn paginated<T, U>(&self, url: U, mut params: Filter) -> impl Stream<Item = Vec<T>> + '_
where
T: 'static + DeserializeOwned + AsAttrs,
U: IntoUrl,
{
let url = url.into_url().unwrap();
stream! {
loop {
let inner_params = PushShiftQueryParams {
inner: ¶ms,
sort: None,
limit: BATCH_SIZE,
metadata: false,
};
if let Some(parsed_response) = self._get::<T>(url.clone(), inner_params).await {
if let Some(last_content) = parsed_response.data.last() {
params = params.before(last_content.attrs().date.clone());
} else {
break;
}
let should_break = parsed_response.data.len() < BATCH_SIZE as usize;
yield parsed_response.data;
if should_break {
break;
} else {
continue;
}
}
break;
}
}
}
}
impl Default for Client {
fn default() -> Self {
Self::new()
}
}
/// Splits the date span [`oldest`, `newest`] into contiguous `(after, before)`
/// pairs sized so each should hold roughly `DESIRED_BUCKET_VOLUME` results,
/// capped at 201 buckets in total.
fn chunked(
    total: i64,
    oldest: DateTime<Utc>,
    newest: DateTime<Utc>,
) -> impl Iterator<Item = (DateTime<Utc>, DateTime<Utc>)> {
    // Number of *extra* buckets beyond the first, capped at 200.
    let bucket_count = (total / DESIRED_BUCKET_VOLUME).min(200);
    // Width of each bucket in whole seconds.
    let width_secs = (newest - oldest).div((bucket_count + 1) as i32).num_seconds();
    (0..=bucket_count).map(move |idx| {
        // Buckets after the first start one second past the previous bucket's
        // end so adjacent ranges don't overlap.
        let lower = match idx {
            0 => oldest,
            _ => oldest + Duration::seconds((idx * width_secs) + 1),
        };
        // The final bucket is pinned to `newest` to absorb rounding slack.
        let upper = if idx == bucket_count {
            newest
        } else {
            oldest + Duration::seconds((idx + 1) * width_secs)
        };
        (lower, upper)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Compile-time guarantee that `Client` can be shared across threads.
    #[test]
    fn test_client_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Client>();
    }
}