roux_stream/
lib.rs

/*
Copyright (c) 2021 Florian Brucker (www.florianbrucker.de)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
22
/*!
A streaming API for the [`roux`] Reddit client.

Reddit's API does not provide "firehose"-style streaming of new posts and
comments. Instead, the endpoints for retrieving the latest posts and comments
have to be polled regularly. This crate automates that task and provides streams
for a subreddit's posts (submissions) and comments.

See [`stream_submissions`] and [`stream_comments`] for
details.

# Logging

This module uses the logging infrastructure provided by the [`log`] crate.
*/
38
39#![warn(missing_docs)]
40
41use async_trait::async_trait;
42use futures::channel::mpsc;
43use futures::Stream;
44use futures::{Sink, SinkExt};
45use log::{debug, warn};
46use roux::{
47    response::{BasicThing, Listing},
48    submission::SubmissionData,
49    comment::CommentData,
50    util::RouxError,
51    Subreddit,
52};
53use std::error::Error;
54use std::fmt::Display;
55use std::marker::Unpin;
56use std::{collections::HashSet, time::Duration};
57use tokio::sync::Mutex;
58use tokio::task::JoinHandle;
59use tokio::time::error::Elapsed;
60use tokio::time::sleep;
61use tokio_retry::RetryIf;
62
63/**
64The [`roux`] APIs for submissions and comments are slightly different. We use
65the [`Puller`] trait as the common interface to which we then adapt those APIs.
66This allows us to implement our core logic (e.g. retries and duplicate
67filtering) once without caring about the differences between submissions and
68comments. In addition, this makes it easier to test the core logic because
69we can provide a mock implementation.
70*/
71#[async_trait]
72trait Puller<Data, E: Error> {
73    // The "real" implementations of this function (for pulling
74    // submissions and comments from Reddit) would not need `self` to
75    // be `mut` here (because there the state change happens externally,
76    // i.e. within Reddit). However, writing good tests is much easier
77    // if `self` is mutable here.
78    async fn pull(&mut self) -> Result<BasicThing<Listing<BasicThing<Data>>>, E>;
79    fn get_id(&self, data: &Data) -> String;
80    fn get_items_name(&self) -> String;
81    fn get_source_name(&self) -> String;
82}
83
84struct SubredditPuller {
85    subreddit: Subreddit,
86}
87
88// How many items to fetch per request
89const LIMIT: u32 = 100;
90
91#[async_trait]
92impl Puller<SubmissionData, RouxError> for SubredditPuller {
93    async fn pull(
94        &mut self,
95    ) -> Result<BasicThing<Listing<BasicThing<SubmissionData>>>, RouxError> {
96        self.subreddit.latest(LIMIT, None).await
97    }
98
99    fn get_id(&self, data: &SubmissionData) -> String {
100        data.id.clone()
101    }
102
103    fn get_items_name(&self) -> String {
104        "submissions".to_owned()
105    }
106
107    fn get_source_name(&self) -> String {
108        format!("r/{}", self.subreddit.name)
109    }
110}
111
112#[async_trait]
113impl Puller<CommentData, RouxError> for SubredditPuller {
114    async fn pull(
115        &mut self,
116    ) -> Result<BasicThing<Listing<BasicThing<CommentData>>>, RouxError> {
117        self.subreddit.latest_comments(None, Some(LIMIT)).await
118    }
119
120    fn get_id(&self, data: &CommentData) -> String {
121        data.id.as_ref().cloned().unwrap()
122    }
123
124    fn get_items_name(&self) -> String {
125        "comments".to_owned()
126    }
127
128    fn get_source_name(&self) -> String {
129        format!("r/{}", self.subreddit.name)
130    }
131}
132
133/**
134Error that occurs when pulling new data from Reddit failed.
135 */
136#[derive(Debug, PartialEq)]
137pub enum StreamError<E> {
138    /**
139    Returned when pulling new data timed out.
140     */
141    TimeoutError(Elapsed),
142
143    /**
144    Returned when [`roux`] reported an error while pulling new data.
145    */
146    SourceError(E),
147}
148
149impl<E> Display for StreamError<E>
150where
151    E: Display,
152{
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        match self {
155            StreamError::TimeoutError(err) => err.fmt(f),
156            StreamError::SourceError(err) => err.fmt(f),
157        }
158    }
159}
160
161impl<E> Error for StreamError<E> where E: std::fmt::Debug + Display {}
162
163/**
164Pull new items from Reddit and push them into a sink.
165
166This function contains the core of the streaming logic. It performs the
167following steps in an endless loop:
168
1691. Pull the latest items (submissions or comments) from Reddit, retrying
170   that operation if necessary according to `retry_strategy`. If
171   `timeout` is given and pulling the items from Reddit takes longer
172   then abort and yield an error (see step 3).
1732. Filter out already seen items using their ID.
1743. Push the new items (or an error if pulling failed) into `sink`.
1754. Sleep for `sleep_time`.
176*/
177async fn pull_into_sink<S, R, Data, E>(
178    puller: &mut (dyn Puller<Data, E> + Send + Sync),
179    sleep_time: Duration,
180    retry_strategy: R,
181    timeout: Option<Duration>,
182    mut sink: S,
183) -> Result<(), S::Error>
184where
185    S: Sink<Result<Data, StreamError<E>>> + Unpin,
186    R: IntoIterator<Item = Duration> + Clone,
187    E: Error,
188{
189    let items_name = puller.get_items_name();
190    let source_name = puller.get_source_name();
191    let mut seen_ids: HashSet<String> = HashSet::new();
192
193    /*
194    Because `puller.pull` takes a mutable reference we need wrap it in
195    a mutex to be able to pass it as a callback to `RetryIf::spawn`.
196     */
197    let puller_mutex = Mutex::new(puller);
198
199    loop {
200        debug!("Fetching latest {} from {}", items_name, source_name);
201        let latest = RetryIf::spawn(
202            retry_strategy.clone(),
203            || async {
204                let mut puller = puller_mutex.lock().await;
205
206                // TODO: There is probably a nicer way to write those matches
207
208                if let Some(timeout_duration) = timeout {
209                    let timeout_result =
210                        tokio::time::timeout(timeout_duration, puller.pull()).await;
211                    match timeout_result {
212                        Err(timeout_err) => Err::<BasicThing<Listing<BasicThing<Data>>>, _>(
213                            StreamError::TimeoutError(timeout_err),
214                        ),
215                        Ok(timeout_ok) => match timeout_ok {
216                            Err(puller_err) => Err(StreamError::SourceError(puller_err)),
217                            Ok(pull_ok) => Ok(pull_ok),
218                        },
219                    }
220                } else {
221                    match puller.pull().await {
222                        Err(puller_err) => Err(StreamError::SourceError(puller_err)),
223                        Ok(pull_ok) => Ok(pull_ok),
224                    }
225                }
226            },
227            |error: &StreamError<E>| {
228                debug!(
229                    "Error while fetching the latest {} from {}: {}",
230                    items_name, source_name, error,
231                );
232                true
233            },
234        )
235        .await;
236        match latest {
237            Ok(latest_items) => {
238                let latest_items = latest_items.data.children.into_iter().map(|item| item.data);
239                let mut latest_ids: HashSet<String> = HashSet::new();
240
241                let mut num_new = 0;
242                let puller = puller_mutex.lock().await;
243                for item in latest_items {
244                    let id = puller.get_id(&item);
245                    latest_ids.insert(id.clone());
246                    if !seen_ids.contains(&id) {
247                        num_new += 1;
248                        sink.send(Ok(item)).await?;
249                    }
250                }
251
252                debug!(
253                    "Got {} new {} for {} (out of {})",
254                    num_new, items_name, source_name, LIMIT
255                );
256                if num_new == latest_ids.len() && !seen_ids.is_empty() {
257                    warn!(
258                        "All received {} for {} were new, try a shorter sleep_time",
259                        items_name, source_name
260                    );
261                }
262
263                seen_ids = latest_ids;
264            }
265            Err(error) => {
266                // Forward the error through the stream
267                warn!(
268                    "Error while fetching the latest {} from {}: {}",
269                    items_name, source_name, error,
270                );
271                sink.send(Err(error)).await?;
272            }
273        }
274
275        sleep(sleep_time).await;
276    }
277}
278
279/**
280Spawn a task that pulls items and puts them into a stream.
281
282Depending on `T`, this function will either stream submissions or comments.
283*/
284fn stream_items<R, I, T>(
285    subreddit: &Subreddit,
286    sleep_time: Duration,
287    retry_strategy: R,
288    timeout: Option<Duration>,
289) -> (
290    impl Stream<Item = Result<T, StreamError<RouxError>>>,
291    JoinHandle<Result<(), mpsc::SendError>>,
292)
293where
294    R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
295    I: Iterator<Item = Duration> + Send + Sync + 'static,
296    SubredditPuller: Puller<T, RouxError>,
297    T: Send + 'static,
298{
299    let (sink, stream) = mpsc::unbounded();
300    // We need an owned instance (or at least statically bound
301    // reference) for tokio::spawn. Since Subreddit isn't Clone,
302    // we simply create a new instance.
303    let subreddit = Subreddit::new(subreddit.name.as_str());
304    let join_handle = tokio::spawn(async move {
305        pull_into_sink(
306            &mut SubredditPuller { subreddit },
307            sleep_time,
308            retry_strategy,
309            timeout,
310            sink,
311        )
312        .await
313    });
314    (stream, join_handle)
315}
316
317/**
318Stream new submissions in a subreddit.
319
320Creates a separate tokio task that regularly polls the subreddit for new
321submissions. Previously unseen submissions are sent into the returned
322stream.
323
324Returns a tuple `(stream, join_handle)` where `stream` is the
325[`Stream`](futures::Stream) from which the submissions can be read, and
326`join_handle` is the [`JoinHandle`](tokio::task::JoinHandle) for the
327polling task.
328
329`sleep_time` controls the interval between calls to the Reddit API, and
330depends on how much traffic the subreddit has. Each call fetches the 100
331latest items (the maximum number allowed by Reddit). A warning is logged
332if none of those items has been seen in the previous call: this indicates
333a potential miss of new content and suggests that a smaller `sleep_time`
334should be chosen. Enable debug logging for more statistics.
335
336If `timeout` is not `None` then calls to the Reddit API that take longer
337than `timeout` are aborted with a [`StreamError::TimeoutError`].
338
339If an error occurs while fetching the latest submissions from Reddit then
340fetching is retried according to `retry_strategy` (see [`tokio_retry`] for
341details). If one of the retries succeeds then normal operation is resumed.
342If `retry_strategy` is finite and the last retry fails then its error is
343sent into the stream, afterwards normal operation is resumed.
344
345The spawned task runs indefinitely unless an error is encountered when
346sending data into the stream (for example because the receiver is dropped).
347In that case the task stops and the error is returned via `join_handle`.
348
349See also [`stream_comments`].
350
351
352# Example
353
354The following example prints new submissions to
355[r/AskReddit](https://reddit.com/r/AskReddit) in an endless loop.
356
357```
358use futures::StreamExt;
359use roux::Subreddit;
360use roux_stream::stream_submissions;
361use std::time::Duration;
362use tokio_retry::strategy::ExponentialBackoff;
363
364#[tokio::main]
365async fn main() {
366    let subreddit = Subreddit::new("AskReddit");
367
368    // How often to retry when pulling the data from Reddit fails and
369    // how long to wait between retries. See the docs of `tokio_retry`
370    // for details.
371    let retry_strategy = ExponentialBackoff::from_millis(5).factor(100).take(3);
372
373    let (mut stream, join_handle) = stream_submissions(
374        &subreddit,
375        Duration::from_secs(60),
376        retry_strategy,
377        Some(Duration::from_secs(10)),
378    );
379
380    while let Some(submission) = stream.next().await {
381        // `submission` is an `Err` if getting the latest submissions
382        // from Reddit failed even after retrying.
383        let submission = submission.unwrap();
384        println!("\"{}\" by {}", submission.title, submission.author);
385        # // An endless loop doesn't work well with doctests, so in that
386        # // case we abort the task and exit the loop directly.
387        # join_handle.abort();
388        # break;
389    }
390    # // Aborting the task will make the join handle return an error. Let's
391    # // make sure it's the right one.
392    # let join_result = join_handle.await;
393    # assert!(join_result.is_err());
394    # assert!(join_result.err().unwrap().is_cancelled());
395    # // Now we need to make sure that the remaining code in the example
396    # // still works, so we create a fake `join_handle` for it to work
397    # // with.
398    # let join_handle = async { Some(Some(())) };
399
400    // In case there was an error sending the submissions through the
401    // stream, `join_handle` will report it.
402    join_handle.await.unwrap().unwrap();
403}
404```
405*/
406pub fn stream_submissions<R, I>(
407    subreddit: &Subreddit,
408    sleep_time: Duration,
409    retry_strategy: R,
410    timeout: Option<Duration>,
411) -> (
412    impl Stream<Item = Result<SubmissionData, StreamError<RouxError>>>,
413    JoinHandle<Result<(), mpsc::SendError>>,
414)
415where
416    R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
417    I: Iterator<Item = Duration> + Send + Sync + 'static,
418{
419    stream_items(subreddit, sleep_time, retry_strategy, timeout)
420}
421
422/**
423Stream new comments in a subreddit.
424
425Creates a separate tokio task that regularly polls the subreddit for new
426comments. Previously unseen comments are sent into the returned
427stream.
428
429Returns a tuple `(stream, join_handle)` where `stream` is the
430[`Stream`](futures::Stream) from which the comments can be read, and
431`join_handle` is the [`JoinHandle`](tokio::task::JoinHandle) for the
432polling task.
433
434`sleep_time` controls the interval between calls to the Reddit API, and
435depends on how much traffic the subreddit has. Each call fetches the 100
436latest items (the maximum number allowed by Reddit). A warning is logged
437if none of those items has been seen in the previous call: this indicates
438a potential miss of new content and suggests that a smaller `sleep_time`
439should be chosen. Enable debug logging for more statistics.
440
441If `timeout` is not `None` then calls to the Reddit API that take longer
442than `timeout` are aborted with a [`StreamError::TimeoutError`].
443
444If an error occurs while fetching the latest comments from Reddit then
445fetching is retried according to `retry_strategy` (see [`tokio_retry`] for
446details). If one of the retries succeeds then normal operation is resumed.
447If `retry_strategy` is finite and the last retry fails then its error is
448sent into the stream, afterwards normal operation is resumed.
449
450The spawned task runs indefinitely unless an error is encountered when
451sending data into the stream (for example because the receiver is dropped).
452In that case the task stops and the error is returned via `join_handle`.
453
454See also [`stream_submissions`].
455
456
457# Example
458
459The following example prints new comments to
460[r/AskReddit](https://reddit.com/r/AskReddit) in an endless loop.
461
462```
463use futures::StreamExt;
464use roux::Subreddit;
465use roux_stream::stream_comments;
466use std::time::Duration;
467use tokio_retry::strategy::ExponentialBackoff;
468
469
470#[tokio::main]
471async fn main() {
472    let subreddit = Subreddit::new("AskReddit");
473
474    // How often to retry when pulling the data from Reddit fails and
475    // how long to wait between retries. See the docs of `tokio_retry`
476    // for details.
477    let retry_strategy = ExponentialBackoff::from_millis(5).factor(100).take(3);
478
479    let (mut stream, join_handle) = stream_comments(
480        &subreddit,
481        Duration::from_secs(10),
482        retry_strategy,
483        Some(Duration::from_secs(10)),
484    );
485
486    while let Some(comment) = stream.next().await {
487        // `comment` is an `Err` if getting the latest comments
488        // from Reddit failed even after retrying.
489        let comment = comment.unwrap();
490        println!(
491            "{}{} (by u/{})",
492            comment.link_url.unwrap(),
493            comment.id.unwrap(),
494            comment.author.unwrap()
495        );
496        # // An endless loop doesn't work well with doctests, so in that
497        # // case we abort the task and exit the loop directly.
498        # join_handle.abort();
499        # break;
500    }
501    # // Aborting the task will make the join handle return an error. Let's
502    # // make sure it's the right one.
503    # let join_result = join_handle.await;
504    # assert!(join_result.is_err());
505    # assert!(join_result.err().unwrap().is_cancelled());
506    # // Now we need to make sure that the remaining code in the example
507    # // still works, so we create a fake `join_handle` for it to work
508    # // with.
509    # let join_handle = async { Some(Some(())) };
510
511    // In case there was an error sending the submissions through the
512    // stream, `join_handle` will report it.
513    join_handle.await.unwrap().unwrap();
514}
515```
516*/
517pub fn stream_comments<R, I>(
518    subreddit: &Subreddit,
519    sleep_time: Duration,
520    retry_strategy: R,
521    timeout: Option<Duration>,
522) -> (
523    impl Stream<Item = Result<CommentData, StreamError<RouxError>>>,
524    JoinHandle<Result<(), mpsc::SendError>>,
525)
526where
527    R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
528    I: Iterator<Item = Duration> + Send + Sync + 'static,
529{
530    stream_items(subreddit, sleep_time, retry_strategy, timeout)
531}
532
#[cfg(test)]
mod tests {
    use super::{pull_into_sink, Puller, StreamError};
    use async_trait::async_trait;
    use futures::{channel::mpsc, StreamExt};
    use log::{Level, LevelFilter};
    use logtest::Logger;
    use roux::response::{BasicThing, Listing};
    use std::{error::Error, fmt::Display, time::Duration};
    use tokio::{sync::RwLock, time::sleep};

    /*
    Any test case that checks the logging output must run in isolation,
    so that the log output of other test cases does not disturb it. We
    use an `RwLock` to achieve that: tests that do log checking take a
    write lock, while the other test cases take a read lock.
    */
    static LOCK: RwLock<()> = RwLock::const_new(());

    // Error reported by `MockPuller`; carries the text of the batch item
    // that triggered it.
    #[derive(Debug, PartialEq)]
    struct MockSourceError(String);

    impl Display for MockSourceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl Error for MockSourceError {}

    // `Puller` that replays a fixed sequence of batches, one per call to
    // `pull`.
    struct MockPuller {
        iter: Box<dyn Iterator<Item = Vec<String>> + Sync + Send>,
    }

    impl MockPuller {
        // Create a puller that returns `batches` one after the other.
        fn new(batches: Vec<Vec<&str>>) -> Self {
            let owned_batches: Vec<Vec<String>> = batches
                .iter()
                .map(|batch| batch.iter().map(|item| item.to_string()).collect())
                .collect();
            MockPuller {
                iter: Box::new(owned_batches.into_iter()),
            }
        }
    }

    #[async_trait]
    impl Puller<String, MockSourceError> for MockPuller {
        /*
        Each call to `pull` returns the next batch of items. If a batch
        consists of a single String that begins with "error" then instead
        of an Ok an Err is returned. If a batch consists of a single String
        that begins with "sleep" then the function sleeps for 1s before
        returning. Once all batches are used up, empty listings are
        returned.
        */
        async fn pull(
            &mut self,
        ) -> Result<BasicThing<Listing<BasicThing<String>>>, MockSourceError> {
            let items = self.iter.next().unwrap_or_default();

            // Single-item batches may carry a special instruction
            if let [only] = items.as_slice() {
                if only.starts_with("error") {
                    return Err(MockSourceError(only.clone()));
                }
                if only.starts_with("sleep") {
                    sleep(Duration::from_secs(1)).await;
                }
            }

            let children = items
                .into_iter()
                .map(|item| BasicThing {
                    kind: Some("mock".to_owned()),
                    data: item,
                })
                .collect();

            let listing = Listing {
                modhash: None,
                dist: None,
                after: None,
                before: None,
                children,
            };
            Ok(BasicThing {
                kind: Some("listing".to_owned()),
                data: listing,
            })
        }

        fn get_id(&self, data: &String) -> String {
            data.clone()
        }

        fn get_items_name(&self) -> String {
            "MockItems".to_owned()
        }

        fn get_source_name(&self) -> String {
            "MockSource".to_owned()
        }
    }

    // Run `pull_into_sink` on a mock puller that replays `responses` and
    // compare the first `expected.len()` stream items with `expected`.
    async fn check<R, I>(
        responses: Vec<Vec<&str>>,
        retry_strategy: R,
        timeout: Option<Duration>,
        expected: Vec<Result<&str, StreamError<MockSourceError>>>,
    ) where
        R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
        I: Iterator<Item = Duration> + Send + Sync + 'static,
    {
        let mut mock_puller = MockPuller::new(responses);
        let (sink, stream) = mpsc::unbounded();
        tokio::spawn(async move {
            pull_into_sink(
                &mut mock_puller,
                Duration::from_millis(1),
                retry_strategy,
                timeout,
                sink,
            )
            .await
        });
        let items = stream.take(expected.len()).collect::<Vec<_>>().await;
        assert_eq!(
            items,
            expected
                .into_iter()
                .map(|result| result.map(|ok_value| ok_value.to_string()))
                .collect::<Vec<_>>()
        );
    }

    // Run `pull_into_sink` on a mock puller that replays `responses` with
    // a sink whose receiver is already gone, and return the task's result.
    async fn pull_with_dropped_receiver(
        responses: Vec<Vec<&str>>,
    ) -> Result<(), mpsc::SendError> {
        let mut mock_puller = MockPuller::new(responses);
        let (sink, stream) = mpsc::unbounded();
        drop(stream); // drop receiver so that sending fails
        let join_handle = tokio::spawn(async move {
            pull_into_sink(
                &mut mock_puller,
                Duration::from_millis(1),
                vec![],
                None,
                sink,
            )
            .await
        });
        join_handle.await.unwrap()
    }

    #[tokio::test]
    async fn test_simple_pull() {
        let _lock = LOCK.read().await;
        check(vec![vec!["hello"]], vec![], None, vec![Ok("hello")]).await;
    }

    #[tokio::test]
    async fn test_duplicate_filtering() {
        let _lock = LOCK.read().await;
        check(
            vec![vec!["a", "b", "c"], vec!["b", "c", "d"], vec!["d", "e"]],
            vec![],
            None,
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d"), Ok("e")],
        )
        .await;
    }

    #[tokio::test]
    async fn test_success_after_retry() {
        let _lock = LOCK.read().await;
        check(
            vec![
                vec!["a", "b", "c"],
                vec!["error1"],
                vec!["error2"],
                vec!["b", "c", "d"],
            ],
            vec![Duration::from_millis(1), Duration::from_millis(1)],
            None,
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
    }

    #[tokio::test]
    async fn test_failure_after_retry() {
        let _lock = LOCK.read().await;
        check(
            vec![
                vec!["a", "b", "c"],
                vec!["error1"],
                vec!["error2"],
                vec!["b", "c", "d"],
            ],
            vec![Duration::from_millis(1)],
            None,
            vec![
                Ok("a"),
                Ok("b"),
                Ok("c"),
                Err(StreamError::SourceError(MockSourceError(
                    "error2".to_owned(),
                ))),
                Ok("d"),
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_warning_if_all_items_are_unseen() {
        let _lock = LOCK.write().await; // exclusive lock
        let mut logger = Logger::start();
        log::set_max_level(LevelFilter::Warn);
        check(
            vec![vec!["a", "b"], vec!["c", "d"]],
            vec![],
            None,
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;

        let num_records = logger.len();
        if num_records != 1 {
            // Dump the captured records to make the failure debuggable
            println!();
            println!("{} LOG MESSAGES:", logger.len());
            while let Some(record) = logger.pop() {
                println!("[{}] {}", record.level(), record.args());
            }
            println!();
            panic!("Expected 1 log message, got {}", num_records);
        }

        let record = logger.pop().unwrap();
        assert_eq!(record.level(), Level::Warn);
        assert_eq!(
            record.args(),
            "All received MockItems for MockSource were new, try a shorter sleep_time",
        );
    }

    #[tokio::test]
    async fn test_sink_error_when_sending_new_item() {
        let _lock = LOCK.read().await;
        let result = pull_with_dropped_receiver(vec![vec!["a"]]).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sink_error_when_sending_error() {
        let _lock = LOCK.read().await;
        let result = pull_with_dropped_receiver(vec![vec!["error"]]).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_timeout_ok() {
        let _lock = LOCK.read().await;
        check(
            vec![vec!["a", "b", "c"], vec!["b", "c", "d"]],
            vec![],
            Some(Duration::from_secs(1)),
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
    }

    #[tokio::test]
    async fn test_timeout_error() {
        let _lock = LOCK.read().await;

        let timeout = Duration::from_millis(100);

        // There is probably a less stupid way of constructing
        // an instance of Elapsed...
        let elapsed = tokio::time::timeout(timeout, sleep(Duration::from_secs(1)))
            .await
            .unwrap_err();

        check(
            vec![vec!["a", "b", "c"], vec!["sleep"], vec!["b", "c", "d"]],
            vec![],
            Some(timeout),
            vec![
                Ok("a"),
                Ok("b"),
                Ok("c"),
                Err(StreamError::TimeoutError(elapsed)),
                Ok("d"),
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_timeout_retry() {
        let _lock = LOCK.read().await;

        check(
            vec![vec!["a", "b", "c"], vec!["sleep"], vec!["b", "c", "d"]],
            vec![Duration::from_millis(1)],
            Some(Duration::from_millis(100)),
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
    }
}