#![warn(missing_docs)]
use async_trait::async_trait;
use futures::channel::mpsc;
use futures::Stream;
use futures::{Sink, SinkExt};
use log::{debug, warn};
use roux::{
response::{BasicThing, Listing},
submission::SubmissionData,
comment::CommentData,
util::RouxError,
Subreddit,
};
use std::error::Error;
use std::fmt::Display;
use std::marker::Unpin;
use std::{collections::HashSet, time::Duration};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio::time::error::Elapsed;
use tokio::time::sleep;
use tokio_retry::RetryIf;
/// Abstraction over a source that can be polled for the latest batch of
/// items (e.g. a subreddit's newest submissions or comments).
#[async_trait]
trait Puller<Data, E: Error> {
    /// Fetch the latest batch of items from the source.
    async fn pull(&mut self) -> Result<BasicThing<Listing<BasicThing<Data>>>, E>;
    /// Extract the unique id used for de-duplication from an item.
    fn get_id(&self, data: &Data) -> String;
    /// Human-readable name for the item kind (e.g. "submissions"); used in log messages.
    fn get_items_name(&self) -> String;
    /// Human-readable name for the source (e.g. "r/rust"); used in log messages.
    fn get_source_name(&self) -> String;
}
/// A `Puller` backed by a subreddit handle; the two `Puller` impls below
/// pull submissions and comments respectively from the same struct.
struct SubredditPuller {
    subreddit: Subreddit,
}

/// Number of items requested per poll (passed as the limit to the roux calls).
const LIMIT: u32 = 100;
/// Pulls the newest submissions from the wrapped subreddit.
#[async_trait]
impl Puller<SubmissionData, RouxError> for SubredditPuller {
    async fn pull(
        &mut self,
    ) -> Result<BasicThing<Listing<BasicThing<SubmissionData>>>, RouxError> {
        self.subreddit.latest(LIMIT, None).await
    }

    fn get_source_name(&self) -> String {
        format!("r/{}", self.subreddit.name)
    }

    fn get_items_name(&self) -> String {
        String::from("submissions")
    }

    fn get_id(&self, data: &SubmissionData) -> String {
        // Submission ids are always present; copy one out for the caller.
        data.id.to_owned()
    }
}
/// Pulls the newest comments from the wrapped subreddit.
#[async_trait]
impl Puller<CommentData, RouxError> for SubredditPuller {
    async fn pull(
        &mut self,
    ) -> Result<BasicThing<Listing<BasicThing<CommentData>>>, RouxError> {
        self.subreddit.latest_comments(None, Some(LIMIT)).await
    }

    fn get_id(&self, data: &CommentData) -> String {
        // `CommentData::id` is an Option in roux; presumably it is always set
        // for comments returned by the listing endpoint (the previous code
        // unwrapped unconditionally) — state the invariant in the panic
        // message instead of a bare `unwrap()`, and clone directly rather
        // than the roundabout `as_ref().cloned()`.
        data.id
            .clone()
            .expect("comment listing returned a comment without an id")
    }

    fn get_items_name(&self) -> String {
        "comments".to_owned()
    }

    fn get_source_name(&self) -> String {
        format!("r/{}", self.subreddit.name)
    }
}
/// Error produced while streaming items from a source.
///
/// Carries either a timeout (when a timeout duration was configured for each
/// pull) or the underlying source's own error type `E`.
#[derive(Debug, PartialEq)]
pub enum StreamError<E> {
    /// A single pull attempt did not complete within the configured timeout.
    TimeoutError(Elapsed),
    /// The underlying source returned an error.
    SourceError(E),
}
impl<E> Display for StreamError<E>
where
    E: Display,
{
    /// Delegate formatting to whichever error is wrapped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner: &dyn Display = match self {
            StreamError::TimeoutError(inner) => inner,
            StreamError::SourceError(inner) => inner,
        };
        write!(f, "{}", inner)
    }
}
// `StreamError` is a full `std::error::Error` whenever the wrapped source
// error is `Debug + Display` (default-provided methods are sufficient).
impl<E> Error for StreamError<E> where E: std::fmt::Debug + Display {}
/// Repeatedly pull batches of items from `puller`, forwarding previously
/// unseen items into `sink`, and sleeping `sleep_time` between polls.
///
/// Each poll is retried according to `retry_strategy`; if `timeout` is set,
/// every individual pull attempt is bounded by it. When retries are
/// exhausted, the final error is forwarded into the sink as an `Err` item
/// and polling continues. The loop runs forever; the function only returns
/// if the sink rejects a send (e.g. the receiving end was dropped), in which
/// case the sink's error is propagated.
async fn pull_into_sink<S, R, Data, E>(
    puller: &mut (dyn Puller<Data, E> + Send + Sync),
    sleep_time: Duration,
    retry_strategy: R,
    timeout: Option<Duration>,
    mut sink: S,
) -> Result<(), S::Error>
where
    S: Sink<Result<Data, StreamError<E>>> + Unpin,
    R: IntoIterator<Item = Duration> + Clone,
    E: Error,
{
    let items_name = puller.get_items_name();
    let source_name = puller.get_source_name();
    // Ids seen in the previous batch, used to filter out duplicates.
    let mut seen_ids: HashSet<String> = HashSet::new();
    // The retry closure may run multiple times and needs access to `puller`;
    // the Mutex makes the exclusive reference shareable across attempts.
    let puller_mutex = Mutex::new(puller);
    loop {
        debug!("Fetching latest {} from {}", items_name, source_name);
        let latest = RetryIf::spawn(
            retry_strategy.clone(),
            || async {
                let mut puller = puller_mutex.lock().await;
                if let Some(timeout_duration) = timeout {
                    // Bound this pull attempt by the configured timeout.
                    let timeout_result =
                        tokio::time::timeout(timeout_duration, puller.pull()).await;
                    match timeout_result {
                        Err(timeout_err) => Err::<BasicThing<Listing<BasicThing<Data>>>, _>(
                            StreamError::TimeoutError(timeout_err),
                        ),
                        Ok(timeout_ok) => match timeout_ok {
                            Err(puller_err) => Err(StreamError::SourceError(puller_err)),
                            Ok(pull_ok) => Ok(pull_ok),
                        },
                    }
                } else {
                    match puller.pull().await {
                        Err(puller_err) => Err(StreamError::SourceError(puller_err)),
                        Ok(pull_ok) => Ok(pull_ok),
                    }
                }
            },
            // Retry predicate: log each failed attempt at debug level and
            // always retry (until the strategy runs out of delays).
            |error: &StreamError<E>| {
                debug!(
                    "Error while fetching the latest {} from {}: {}",
                    items_name, source_name, error,
                );
                true
            },
        )
        .await;
        match latest {
            Ok(latest_items) => {
                let latest_items = latest_items.data.children.into_iter().map(|item| item.data);
                let mut latest_ids: HashSet<String> = HashSet::new();
                let mut num_new = 0;
                let puller = puller_mutex.lock().await;
                for item in latest_items {
                    let id = puller.get_id(&item);
                    latest_ids.insert(id.clone());
                    // Forward only items that were not in the previous batch.
                    if !seen_ids.contains(&id) {
                        num_new += 1;
                        sink.send(Ok(item)).await?;
                    }
                }
                debug!(
                    "Got {} new {} for {} (out of {})",
                    num_new, items_name, source_name, LIMIT
                );
                // If every item in a non-first poll was new, items may have
                // been missed between polls — suggest polling faster.
                if num_new == latest_ids.len() && !seen_ids.is_empty() {
                    warn!(
                        "All received {} for {} were new, try a shorter sleep_time",
                        items_name, source_name
                    );
                }
                // Only the most recent batch is remembered for de-duplication.
                seen_ids = latest_ids;
            }
            Err(error) => {
                // Retries exhausted: surface the final error to the consumer.
                warn!(
                    "Error while fetching the latest {} from {}: {}",
                    items_name, source_name, error,
                );
                sink.send(Err(error)).await?;
            }
        }
        sleep(sleep_time).await;
    }
}
/// Spawn a background task that polls `subreddit` via `pull_into_sink` and
/// return the stream of results together with the task's join handle.
fn stream_items<R, I, T>(
    subreddit: &Subreddit,
    sleep_time: Duration,
    retry_strategy: R,
    timeout: Option<Duration>,
) -> (
    impl Stream<Item = Result<T, StreamError<RouxError>>>,
    JoinHandle<Result<(), mpsc::SendError>>,
)
where
    R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
    I: Iterator<Item = Duration> + Send + Sync + 'static,
    SubredditPuller: Puller<T, RouxError>,
    T: Send + 'static,
{
    // Re-create the subreddit handle by name so the spawned task owns it
    // ('static lifetime) instead of borrowing from the caller.
    let owned_subreddit = Subreddit::new(subreddit.name.as_str());
    let (sender, receiver) = mpsc::unbounded();
    let handle = tokio::spawn(async move {
        let mut puller = SubredditPuller {
            subreddit: owned_subreddit,
        };
        pull_into_sink(&mut puller, sleep_time, retry_strategy, timeout, sender).await
    });
    (receiver, handle)
}
/// Stream the latest submissions from `subreddit`.
///
/// Polls every `sleep_time`, retrying failed pulls according to
/// `retry_strategy` and, if `timeout` is given, bounding each pull attempt
/// by it. Returns the stream of new submissions (or stream errors) and the
/// join handle of the background polling task.
pub fn stream_submissions<R, I>(
    subreddit: &Subreddit,
    sleep_time: Duration,
    retry_strategy: R,
    timeout: Option<Duration>,
) -> (
    impl Stream<Item = Result<SubmissionData, StreamError<RouxError>>>,
    JoinHandle<Result<(), mpsc::SendError>>,
)
where
    R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
    I: Iterator<Item = Duration> + Send + Sync + 'static,
{
    stream_items(subreddit, sleep_time, retry_strategy, timeout)
}
/// Stream the latest comments from `subreddit`.
///
/// Polls every `sleep_time`, retrying failed pulls according to
/// `retry_strategy` and, if `timeout` is given, bounding each pull attempt
/// by it. Returns the stream of new comments (or stream errors) and the
/// join handle of the background polling task.
pub fn stream_comments<R, I>(
    subreddit: &Subreddit,
    sleep_time: Duration,
    retry_strategy: R,
    timeout: Option<Duration>,
) -> (
    impl Stream<Item = Result<CommentData, StreamError<RouxError>>>,
    JoinHandle<Result<(), mpsc::SendError>>,
)
where
    R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
    I: Iterator<Item = Duration> + Send + Sync + 'static,
{
    stream_items(subreddit, sleep_time, retry_strategy, timeout)
}
#[cfg(test)]
mod tests {
    use super::{pull_into_sink, Puller, StreamError};
    use async_trait::async_trait;
    use futures::{channel::mpsc, StreamExt};
    use log::{Level, LevelFilter};
    use logtest::Logger;
    use roux::response::{BasicThing, Listing};
    use std::{error::Error, fmt::Display, time::Duration};
    use tokio::{sync::RwLock, time::sleep};

    // The logger-inspecting test takes the write lock so it runs exclusively
    // (the global logger is shared state); all other tests share read locks.
    static LOCK: RwLock<()> = RwLock::const_new(());

    // Error type produced by `MockPuller` for scripted "error*" batches.
    #[derive(Debug, PartialEq)]
    struct MockSourceError(String);

    impl Display for MockSourceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl Error for MockSourceError {}

    // A scripted `Puller`: each call to `pull` consumes the next batch from
    // the iterator. Batch contents control the simulated behavior (see
    // `pull` below).
    struct MockPuller {
        iter: Box<dyn Iterator<Item = Vec<String>> + Sync + Send>,
    }

    impl MockPuller {
        fn new(batches: Vec<Vec<&str>>) -> Self {
            MockPuller {
                iter: Box::new(
                    batches
                        .iter()
                        .map(|batch| batch.iter().map(|item| item.to_string()).collect())
                        .collect::<Vec<Vec<String>>>()
                        .into_iter(),
                ),
            }
        }
    }

    #[async_trait]
    impl Puller<String, MockSourceError> for MockPuller {
        // Returns the next scripted batch. A single-item batch starting with
        // "error" simulates a source failure; a single-item batch starting
        // with "sleep" simulates a slow source (for the timeout tests).
        // Once the script is exhausted, empty listings are returned forever.
        async fn pull(
            &mut self,
        ) -> Result<BasicThing<Listing<BasicThing<String>>>, MockSourceError> {
            let children;
            if let Some(items) = self.iter.next() {
                match items.as_slice() {
                    [item] if item.starts_with("error") => {
                        return Err(MockSourceError(item.clone()));
                    }
                    _ => {
                        if items.len() == 1 && items.get(0).unwrap().starts_with("sleep") {
                            sleep(Duration::from_secs(1)).await;
                        }
                        children = items
                            .iter()
                            .map(|item| BasicThing {
                                kind: Some("mock".to_owned()),
                                data: item.clone(),
                            })
                            .collect();
                    }
                }
            } else {
                children = vec![];
            }
            let listing = Listing {
                modhash: None,
                dist: None,
                after: None,
                before: None,
                children: children,
            };
            let result = BasicThing {
                kind: Some("listing".to_owned()),
                data: listing,
            };
            Ok(result)
        }

        // Items are plain strings; the string itself is the id.
        fn get_id(&self, data: &String) -> String {
            data.clone()
        }

        fn get_items_name(&self) -> String {
            "MockItems".to_owned()
        }

        fn get_source_name(&self) -> String {
            "MockSource".to_owned()
        }
    }

    // Drive `pull_into_sink` over the scripted `responses` and assert the
    // resulting stream yields exactly `expected` (the stream is truncated to
    // `expected.len()` items, since the puller loops forever).
    async fn check<R, I>(
        responses: Vec<Vec<&str>>,
        retry_strategy: R,
        timeout: Option<Duration>,
        expected: Vec<Result<&str, StreamError<MockSourceError>>>,
    ) where
        R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
        I: Iterator<Item = Duration> + Send + Sync + 'static,
    {
        let mut mock_puller = MockPuller::new(responses);
        let (sink, stream) = mpsc::unbounded();
        tokio::spawn(async move {
            pull_into_sink(
                &mut mock_puller,
                Duration::from_millis(1),
                retry_strategy,
                timeout,
                sink,
            )
            .await
        });
        let items = stream.take(expected.len()).collect::<Vec<_>>().await;
        assert_eq!(
            items,
            expected
                .into_iter()
                .map(|result| result.map(|ok_value| ok_value.to_string()))
                .collect::<Vec<_>>()
        );
    }

    #[tokio::test]
    async fn test_simple_pull() {
        let _lock = LOCK.read().await;
        check(vec![vec!["hello"]], vec![], None, vec![Ok("hello")]).await;
    }

    #[tokio::test]
    async fn test_duplicate_filtering() {
        let _lock = LOCK.read().await;
        // Items repeated from the immediately preceding batch are dropped.
        check(
            vec![vec!["a", "b", "c"], vec!["b", "c", "d"], vec!["d", "e"]],
            vec![],
            None,
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d"), Ok("e")],
        )
        .await;
    }

    #[tokio::test]
    async fn test_success_after_retry() {
        let _lock = LOCK.read().await;
        // Two retries configured and two consecutive errors: the errors are
        // absorbed by the retry logic and never reach the stream.
        check(
            vec![
                vec!["a", "b", "c"],
                vec!["error1"],
                vec!["error2"],
                vec!["b", "c", "d"],
            ],
            vec![Duration::from_millis(1), Duration::from_millis(1)],
            None,
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
    }

    #[tokio::test]
    async fn test_failure_after_retry() {
        let _lock = LOCK.read().await;
        // Only one retry configured: the second consecutive error exhausts
        // the retries and is surfaced into the stream as an Err item.
        check(
            vec![
                vec!["a", "b", "c"],
                vec!["error1"],
                vec!["error2"],
                vec!["b", "c", "d"],
            ],
            vec![Duration::from_millis(1)],
            None,
            vec![
                Ok("a"),
                Ok("b"),
                Ok("c"),
                Err(StreamError::SourceError(MockSourceError(
                    "error2".to_owned(),
                ))),
                Ok("d"),
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_warning_if_all_items_are_unseen() {
        // Exclusive lock: this test inspects the global logger output.
        let _lock = LOCK.write().await;
        let mut logger = Logger::start();
        log::set_max_level(LevelFilter::Warn);
        // Second batch shares nothing with the first, so exactly one warning
        // ("try a shorter sleep_time") is expected.
        check(
            vec![vec!["a", "b"], vec!["c", "d"]],
            vec![],
            None,
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
        let num_records = logger.len();
        if num_records != 1 {
            // Dump the unexpected log messages before failing, to help debug.
            println!();
            println!("{} LOG MESSAGES:", logger.len());
            while let Some(record) = logger.pop() {
                println!("[{}] {}", record.level(), record.args());
            }
            println!();
            assert!(false, "Expected 1 log message, got {}", num_records);
        }
        let record = logger.pop().unwrap();
        assert_eq!(record.level(), Level::Warn);
        assert_eq!(
            record.args(),
            "All received MockItems for MockSource were new, try a shorter sleep_time",
        );
    }

    #[tokio::test]
    async fn test_sink_error_when_sending_new_item() {
        let _lock = LOCK.read().await;
        let mut mock_puller = MockPuller::new(vec![vec!["a"]]);
        let (sink, stream) = mpsc::unbounded();
        // Dropping the receiver makes the next successful send fail, which
        // should terminate pull_into_sink with the sink's error.
        drop(stream);
        let join_handle = tokio::spawn(async move {
            pull_into_sink(
                &mut mock_puller,
                Duration::from_millis(1),
                vec![],
                None,
                sink,
            )
            .await
        });
        let result = join_handle.await.unwrap();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sink_error_when_sending_error() {
        let _lock = LOCK.read().await;
        let mut mock_puller = MockPuller::new(vec![vec!["error"]]);
        let (sink, stream) = mpsc::unbounded();
        // Same as above, but the failing send is the forwarded Err item.
        drop(stream);
        let join_handle = tokio::spawn(async move {
            pull_into_sink(
                &mut mock_puller,
                Duration::from_millis(1),
                vec![],
                None,
                sink,
            )
            .await
        });
        let result = join_handle.await.unwrap();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_timeout_ok() {
        let _lock = LOCK.read().await;
        // Fast batches comfortably within the timeout: no timeout errors.
        check(
            vec![vec!["a", "b", "c"], vec!["b", "c", "d"]],
            vec![],
            Some(Duration::from_secs(1)),
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
    }

    #[tokio::test]
    async fn test_timeout_error() {
        let _lock = LOCK.read().await;
        let timeout = Duration::from_millis(100);
        // Manufacture a real `Elapsed` value to compare against, since it
        // cannot be constructed directly.
        let elapsed = tokio::time::timeout(timeout.clone(), sleep(Duration::from_secs(1)))
            .await
            .unwrap_err();
        // The "sleep" batch takes 1s against a 100ms timeout and, with no
        // retries, surfaces as a TimeoutError in the stream.
        check(
            vec![vec!["a", "b", "c"], vec!["sleep"], vec!["b", "c", "d"]],
            vec![],
            Some(timeout),
            vec![
                Ok("a"),
                Ok("b"),
                Ok("c"),
                Err(StreamError::TimeoutError(elapsed)),
                Ok("d"),
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_timeout_retry() {
        let _lock = LOCK.read().await;
        // The timed-out "sleep" batch is absorbed by the single retry.
        check(
            vec![vec!["a", "b", "c"], vec!["sleep"], vec!["b", "c", "d"]],
            vec![Duration::from_millis(1)],
            Some(Duration::from_millis(100)),
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
    }
}