roux_stream/lib.rs
/*
Copyright (c) 2021 Florian Brucker (www.florianbrucker.de)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

/*!
A streaming API for the [`roux`] Reddit client.

Reddit's API does not provide "firehose"-style streaming of new posts and
comments. Instead, the endpoints for retrieving the latest posts and comments
have to be polled regularly. This crate automates that task and provides
streams for a subreddit's posts (submissions) and comments.

See [`stream_submissions`] and [`stream_comments`] for details.

# Logging

This module uses the logging infrastructure provided by the [`log`] crate.
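
For example, with a separate logger implementation such as `env_logger` (not
a dependency of this crate; shown here only as an illustration), debug output
can be enabled like this:

```ignore
// Hypothetical setup using the `env_logger` crate; any logger that
// implements the `log` facade works. Run with `RUST_LOG=debug` to see
// this crate's per-poll statistics.
env_logger::init();
```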
*/

#![warn(missing_docs)]

use async_trait::async_trait;
use futures::channel::mpsc;
use futures::Stream;
use futures::{Sink, SinkExt};
use log::{debug, warn};
use roux::{
    comment::CommentData,
    response::{BasicThing, Listing},
    submission::SubmissionData,
    util::RouxError,
    Subreddit,
};
use std::error::Error;
use std::fmt::Display;
use std::marker::Unpin;
use std::{collections::HashSet, time::Duration};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio::time::error::Elapsed;
use tokio::time::sleep;
use tokio_retry::RetryIf;

/**
The [`roux`] APIs for submissions and comments are slightly different. We use
the [`Puller`] trait as the common interface to which we then adapt those APIs.
This allows us to implement our core logic (e.g. retries and duplicate
filtering) once without caring about the differences between submissions and
comments. In addition, this makes it easier to test the core logic because
we can provide a mock implementation.
*/
#[async_trait]
trait Puller<Data, E: Error> {
    // The "real" implementations of this function (for pulling
    // submissions and comments from Reddit) would not need `self` to
    // be `mut` here (because there the state change happens externally,
    // i.e. within Reddit). However, writing good tests is much easier
    // if `self` is mutable here.
    async fn pull(&mut self) -> Result<BasicThing<Listing<BasicThing<Data>>>, E>;
    fn get_id(&self, data: &Data) -> String;
    fn get_items_name(&self) -> String;
    fn get_source_name(&self) -> String;
}

struct SubredditPuller {
    subreddit: Subreddit,
}

// How many items to fetch per request
const LIMIT: u32 = 100;

#[async_trait]
impl Puller<SubmissionData, RouxError> for SubredditPuller {
    async fn pull(
        &mut self,
    ) -> Result<BasicThing<Listing<BasicThing<SubmissionData>>>, RouxError> {
        self.subreddit.latest(LIMIT, None).await
    }

    fn get_id(&self, data: &SubmissionData) -> String {
        data.id.clone()
    }

    fn get_items_name(&self) -> String {
        "submissions".to_owned()
    }

    fn get_source_name(&self) -> String {
        format!("r/{}", self.subreddit.name)
    }
}

#[async_trait]
impl Puller<CommentData, RouxError> for SubredditPuller {
    async fn pull(
        &mut self,
    ) -> Result<BasicThing<Listing<BasicThing<CommentData>>>, RouxError> {
        self.subreddit.latest_comments(None, Some(LIMIT)).await
    }

    fn get_id(&self, data: &CommentData) -> String {
        data.id.as_ref().cloned().unwrap()
    }

    fn get_items_name(&self) -> String {
        "comments".to_owned()
    }

    fn get_source_name(&self) -> String {
        format!("r/{}", self.subreddit.name)
    }
}

/**
Error that occurs when pulling new data from Reddit fails.
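
When consuming one of this crate's streams, both variants can be handled by
matching on the stream items. A minimal sketch (the `stream` variable is
illustrative; it is obtained as in the example for [`stream_submissions`]):

```ignore
while let Some(result) = stream.next().await {
    match result {
        Ok(submission) => println!("{}", submission.title),
        Err(StreamError::TimeoutError(_)) => eprintln!("request timed out"),
        Err(StreamError::SourceError(err)) => eprintln!("Reddit error: {}", err),
    }
}
```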
*/
#[derive(Debug, PartialEq)]
pub enum StreamError<E> {
    /**
    Returned when pulling new data timed out.
    */
    TimeoutError(Elapsed),

    /**
    Returned when [`roux`] reported an error while pulling new data.
    */
    SourceError(E),
}

impl<E> Display for StreamError<E>
where
    E: Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StreamError::TimeoutError(err) => err.fmt(f),
            StreamError::SourceError(err) => err.fmt(f),
        }
    }
}

impl<E> Error for StreamError<E> where E: std::fmt::Debug + Display {}

/**
Pull new items from Reddit and push them into a sink.

This function contains the core of the streaming logic. It performs the
following steps in an endless loop:

1. Pull the latest items (submissions or comments) from Reddit, retrying
   that operation if necessary according to `retry_strategy`. If `timeout`
   is given and pulling the items takes longer than that, abort and yield
   an error (see step 3).
2. Filter out already seen items using their ID.
3. Push the new items (or an error if pulling failed) into `sink`.
4. Sleep for `sleep_time`.
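
Schematically (a simplified sketch of the loop above, not the literal
implementation; `retry`, `timeout`, and `seen` stand in for the real
retry, timeout, and duplicate-filtering logic):

```ignore
loop {
    // Step 1: pull with retries and an optional timeout.
    let latest = retry(|| timeout(puller.pull())).await;
    match latest {
        Ok(items) => {
            for item in items {
                // Step 2: skip items whose ID was seen in the last batch.
                if !seen(&item) {
                    // Step 3: forward new items to the consumer.
                    sink.send(Ok(item)).await?;
                }
            }
        }
        // Step 3: forward errors to the consumer as well.
        Err(err) => sink.send(Err(err)).await?,
    }
    // Step 4: wait before polling again.
    sleep(sleep_time).await;
}
```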
*/
async fn pull_into_sink<S, R, Data, E>(
    puller: &mut (dyn Puller<Data, E> + Send + Sync),
    sleep_time: Duration,
    retry_strategy: R,
    timeout: Option<Duration>,
    mut sink: S,
) -> Result<(), S::Error>
where
    S: Sink<Result<Data, StreamError<E>>> + Unpin,
    R: IntoIterator<Item = Duration> + Clone,
    E: Error,
{
    let items_name = puller.get_items_name();
    let source_name = puller.get_source_name();
    let mut seen_ids: HashSet<String> = HashSet::new();

    /*
    Because `puller.pull` takes a mutable reference, we need to wrap it in
    a mutex to be able to pass it as a callback to `RetryIf::spawn`.
    */
    let puller_mutex = Mutex::new(puller);

    loop {
        debug!("Fetching latest {} from {}", items_name, source_name);
        let latest = RetryIf::spawn(
            retry_strategy.clone(),
            || async {
                let mut puller = puller_mutex.lock().await;

                // TODO: There is probably a nicer way to write those matches

                if let Some(timeout_duration) = timeout {
                    let timeout_result =
                        tokio::time::timeout(timeout_duration, puller.pull()).await;
                    match timeout_result {
                        Err(timeout_err) => Err::<BasicThing<Listing<BasicThing<Data>>>, _>(
                            StreamError::TimeoutError(timeout_err),
                        ),
                        Ok(timeout_ok) => match timeout_ok {
                            Err(puller_err) => Err(StreamError::SourceError(puller_err)),
                            Ok(pull_ok) => Ok(pull_ok),
                        },
                    }
                } else {
                    match puller.pull().await {
                        Err(puller_err) => Err(StreamError::SourceError(puller_err)),
                        Ok(pull_ok) => Ok(pull_ok),
                    }
                }
            },
            |error: &StreamError<E>| {
                debug!(
                    "Error while fetching the latest {} from {}: {}",
                    items_name, source_name, error,
                );
                true
            },
        )
        .await;
        match latest {
            Ok(latest_items) => {
                let latest_items = latest_items.data.children.into_iter().map(|item| item.data);
                let mut latest_ids: HashSet<String> = HashSet::new();

                let mut num_new = 0;
                let puller = puller_mutex.lock().await;
                for item in latest_items {
                    let id = puller.get_id(&item);
                    latest_ids.insert(id.clone());
                    if !seen_ids.contains(&id) {
                        num_new += 1;
                        sink.send(Ok(item)).await?;
                    }
                }

                debug!(
                    "Got {} new {} for {} (out of {})",
                    num_new, items_name, source_name, LIMIT
                );
                if num_new == latest_ids.len() && !seen_ids.is_empty() {
                    warn!(
                        "All received {} for {} were new, try a shorter sleep_time",
                        items_name, source_name
                    );
                }

                seen_ids = latest_ids;
            }
            Err(error) => {
                // Forward the error through the stream
                warn!(
                    "Error while fetching the latest {} from {}: {}",
                    items_name, source_name, error,
                );
                sink.send(Err(error)).await?;
            }
        }

        sleep(sleep_time).await;
    }
}

/**
Spawn a task that pulls items and puts them into a stream.

Depending on `T`, this function will either stream submissions or comments.
*/
fn stream_items<R, I, T>(
    subreddit: &Subreddit,
    sleep_time: Duration,
    retry_strategy: R,
    timeout: Option<Duration>,
) -> (
    impl Stream<Item = Result<T, StreamError<RouxError>>>,
    JoinHandle<Result<(), mpsc::SendError>>,
)
where
    R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
    I: Iterator<Item = Duration> + Send + Sync + 'static,
    SubredditPuller: Puller<T, RouxError>,
    T: Send + 'static,
{
    let (sink, stream) = mpsc::unbounded();
    // We need an owned instance (or at least a statically bound
    // reference) for tokio::spawn. Since Subreddit isn't Clone,
    // we simply create a new instance.
    let subreddit = Subreddit::new(subreddit.name.as_str());
    let join_handle = tokio::spawn(async move {
        pull_into_sink(
            &mut SubredditPuller { subreddit },
            sleep_time,
            retry_strategy,
            timeout,
            sink,
        )
        .await
    });
    (stream, join_handle)
}

/**
Stream new submissions in a subreddit.

Creates a separate tokio task that regularly polls the subreddit for new
submissions. Previously unseen submissions are sent into the returned
stream.

Returns a tuple `(stream, join_handle)` where `stream` is the
[`Stream`](futures::Stream) from which the submissions can be read, and
`join_handle` is the [`JoinHandle`](tokio::task::JoinHandle) for the
polling task.

`sleep_time` controls the interval between calls to the Reddit API; the
appropriate value depends on how much traffic the subreddit has. Each call
fetches the 100 latest items (the maximum number allowed by Reddit). A
warning is logged if none of those items was seen in the previous call,
because that means new content was probably missed and a smaller
`sleep_time` should be chosen. Enable debug logging for more statistics.

If `timeout` is not `None` then calls to the Reddit API that take longer
than `timeout` are aborted with a [`StreamError::TimeoutError`].

If an error occurs while fetching the latest submissions from Reddit then
fetching is retried according to `retry_strategy` (see [`tokio_retry`] for
details). If one of the retries succeeds then normal operation is resumed.
If `retry_strategy` is finite and the last retry fails then its error is
sent into the stream; afterwards, normal operation is resumed.

The spawned task runs indefinitely unless an error is encountered when
sending data into the stream (for example because the receiver is dropped).
In that case the task stops and the error is returned via `join_handle`.

See also [`stream_comments`].


# Example

The following example prints new submissions to
[r/AskReddit](https://reddit.com/r/AskReddit) in an endless loop.

```
use futures::StreamExt;
use roux::Subreddit;
use roux_stream::stream_submissions;
use std::time::Duration;
use tokio_retry::strategy::ExponentialBackoff;

#[tokio::main]
async fn main() {
    let subreddit = Subreddit::new("AskReddit");

    // How often to retry when pulling the data from Reddit fails and
    // how long to wait between retries. See the docs of `tokio_retry`
    // for details.
    let retry_strategy = ExponentialBackoff::from_millis(5).factor(100).take(3);

    let (mut stream, join_handle) = stream_submissions(
        &subreddit,
        Duration::from_secs(60),
        retry_strategy,
        Some(Duration::from_secs(10)),
    );

    while let Some(submission) = stream.next().await {
        // `submission` is an `Err` if getting the latest submissions
        // from Reddit failed even after retrying.
        let submission = submission.unwrap();
        println!("\"{}\" by {}", submission.title, submission.author);
        # // An endless loop doesn't work well with doctests, so in that
        # // case we abort the task and exit the loop directly.
        # join_handle.abort();
        # break;
    }
    # // Aborting the task will make the join handle return an error. Let's
    # // make sure it's the right one.
    # let join_result = join_handle.await;
    # assert!(join_result.is_err());
    # assert!(join_result.err().unwrap().is_cancelled());
    # // Now we need to make sure that the remaining code in the example
    # // still works, so we create a fake `join_handle` for it to work
    # // with.
    # let join_handle = async { Some(Some(())) };

    // In case there was an error sending the submissions through the
    // stream, `join_handle` will report it.
    join_handle.await.unwrap().unwrap();
}
```
*/
pub fn stream_submissions<R, I>(
    subreddit: &Subreddit,
    sleep_time: Duration,
    retry_strategy: R,
    timeout: Option<Duration>,
) -> (
    impl Stream<Item = Result<SubmissionData, StreamError<RouxError>>>,
    JoinHandle<Result<(), mpsc::SendError>>,
)
where
    R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
    I: Iterator<Item = Duration> + Send + Sync + 'static,
{
    stream_items(subreddit, sleep_time, retry_strategy, timeout)
}

/**
Stream new comments in a subreddit.

Creates a separate tokio task that regularly polls the subreddit for new
comments. Previously unseen comments are sent into the returned stream.

Returns a tuple `(stream, join_handle)` where `stream` is the
[`Stream`](futures::Stream) from which the comments can be read, and
`join_handle` is the [`JoinHandle`](tokio::task::JoinHandle) for the
polling task.

`sleep_time` controls the interval between calls to the Reddit API; the
appropriate value depends on how much traffic the subreddit has. Each call
fetches the 100 latest items (the maximum number allowed by Reddit). A
warning is logged if none of those items was seen in the previous call,
because that means new content was probably missed and a smaller
`sleep_time` should be chosen. Enable debug logging for more statistics.

If `timeout` is not `None` then calls to the Reddit API that take longer
than `timeout` are aborted with a [`StreamError::TimeoutError`].

If an error occurs while fetching the latest comments from Reddit then
fetching is retried according to `retry_strategy` (see [`tokio_retry`] for
details). If one of the retries succeeds then normal operation is resumed.
If `retry_strategy` is finite and the last retry fails then its error is
sent into the stream; afterwards, normal operation is resumed.

The spawned task runs indefinitely unless an error is encountered when
sending data into the stream (for example because the receiver is dropped).
In that case the task stops and the error is returned via `join_handle`.

See also [`stream_submissions`].


# Example

The following example prints new comments to
[r/AskReddit](https://reddit.com/r/AskReddit) in an endless loop.

```
use futures::StreamExt;
use roux::Subreddit;
use roux_stream::stream_comments;
use std::time::Duration;
use tokio_retry::strategy::ExponentialBackoff;

#[tokio::main]
async fn main() {
    let subreddit = Subreddit::new("AskReddit");

    // How often to retry when pulling the data from Reddit fails and
    // how long to wait between retries. See the docs of `tokio_retry`
    // for details.
    let retry_strategy = ExponentialBackoff::from_millis(5).factor(100).take(3);

    let (mut stream, join_handle) = stream_comments(
        &subreddit,
        Duration::from_secs(10),
        retry_strategy,
        Some(Duration::from_secs(10)),
    );

    while let Some(comment) = stream.next().await {
        // `comment` is an `Err` if getting the latest comments
        // from Reddit failed even after retrying.
        let comment = comment.unwrap();
        println!(
            "{}{} (by u/{})",
            comment.link_url.unwrap(),
            comment.id.unwrap(),
            comment.author.unwrap()
        );
        # // An endless loop doesn't work well with doctests, so in that
        # // case we abort the task and exit the loop directly.
        # join_handle.abort();
        # break;
    }
    # // Aborting the task will make the join handle return an error. Let's
    # // make sure it's the right one.
    # let join_result = join_handle.await;
    # assert!(join_result.is_err());
    # assert!(join_result.err().unwrap().is_cancelled());
    # // Now we need to make sure that the remaining code in the example
    # // still works, so we create a fake `join_handle` for it to work
    # // with.
    # let join_handle = async { Some(Some(())) };

    // In case there was an error sending the comments through the
    // stream, `join_handle` will report it.
    join_handle.await.unwrap().unwrap();
}
```
*/
pub fn stream_comments<R, I>(
    subreddit: &Subreddit,
    sleep_time: Duration,
    retry_strategy: R,
    timeout: Option<Duration>,
) -> (
    impl Stream<Item = Result<CommentData, StreamError<RouxError>>>,
    JoinHandle<Result<(), mpsc::SendError>>,
)
where
    R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
    I: Iterator<Item = Duration> + Send + Sync + 'static,
{
    stream_items(subreddit, sleep_time, retry_strategy, timeout)
}

#[cfg(test)]
mod tests {
    use super::{pull_into_sink, Puller, StreamError};
    use async_trait::async_trait;
    use futures::{channel::mpsc, StreamExt};
    use log::{Level, LevelFilter};
    use logtest::Logger;
    use roux::response::{BasicThing, Listing};
    use std::{error::Error, fmt::Display, time::Duration};
    use tokio::{sync::RwLock, time::sleep};

    /*
    Any test case that checks the logging output must run in isolation,
    so that the log output of other test cases does not disturb it. We
    use an `RwLock` to achieve that: tests that do log checking take a
    write lock, while the other test cases take a read lock.
    */
    static LOCK: RwLock<()> = RwLock::const_new(());

    #[derive(Debug, PartialEq)]
    struct MockSourceError(String);

    impl Display for MockSourceError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    impl Error for MockSourceError {}

    struct MockPuller {
        iter: Box<dyn Iterator<Item = Vec<String>> + Sync + Send>,
    }

    impl MockPuller {
        fn new(batches: Vec<Vec<&str>>) -> Self {
            MockPuller {
                iter: Box::new(
                    batches
                        .iter()
                        .map(|batch| batch.iter().map(|item| item.to_string()).collect())
                        .collect::<Vec<Vec<String>>>()
                        .into_iter(),
                ),
            }
        }
    }

    #[async_trait]
    impl Puller<String, MockSourceError> for MockPuller {
        /*
        Each call to `pull` returns the next batch of items. If a batch
        consists of a single String that begins with "error" then an Err
        is returned instead of an Ok. If a batch consists of a single
        String that begins with "sleep" then the function sleeps for 1s
        before returning.
        */
        async fn pull(
            &mut self,
        ) -> Result<BasicThing<Listing<BasicThing<String>>>, MockSourceError> {
            let children;
            if let Some(items) = self.iter.next() {
                match items.as_slice() {
                    [item] if item.starts_with("error") => {
                        return Err(MockSourceError(item.clone()));
                    }
                    _ => {
                        if items.len() == 1 && items.get(0).unwrap().starts_with("sleep") {
                            sleep(Duration::from_secs(1)).await;
                        }
                        children = items
                            .iter()
                            .map(|item| BasicThing {
                                kind: Some("mock".to_owned()),
                                data: item.clone(),
                            })
                            .collect();
                    }
                }
            } else {
                children = vec![];
            }

            let listing = Listing {
                modhash: None,
                dist: None,
                after: None,
                before: None,
                children,
            };
            let result = BasicThing {
                kind: Some("listing".to_owned()),
                data: listing,
            };
            Ok(result)
        }

        fn get_id(&self, data: &String) -> String {
            data.clone()
        }

        fn get_items_name(&self) -> String {
            "MockItems".to_owned()
        }

        fn get_source_name(&self) -> String {
            "MockSource".to_owned()
        }
    }

    async fn check<R, I>(
        responses: Vec<Vec<&str>>,
        retry_strategy: R,
        timeout: Option<Duration>,
        expected: Vec<Result<&str, StreamError<MockSourceError>>>,
    ) where
        R: IntoIterator<IntoIter = I, Item = Duration> + Clone + Send + Sync + 'static,
        I: Iterator<Item = Duration> + Send + Sync + 'static,
    {
        let mut mock_puller = MockPuller::new(responses);
        let (sink, stream) = mpsc::unbounded();
        tokio::spawn(async move {
            pull_into_sink(
                &mut mock_puller,
                Duration::from_millis(1),
                retry_strategy,
                timeout,
                sink,
            )
            .await
        });
        let items = stream.take(expected.len()).collect::<Vec<_>>().await;
        assert_eq!(
            items,
            expected
                .into_iter()
                .map(|result| result.map(|ok_value| ok_value.to_string()))
                .collect::<Vec<_>>()
        );
    }

    #[tokio::test]
    async fn test_simple_pull() {
        let _lock = LOCK.read().await;
        check(vec![vec!["hello"]], vec![], None, vec![Ok("hello")]).await;
    }

    #[tokio::test]
    async fn test_duplicate_filtering() {
        let _lock = LOCK.read().await;
        check(
            vec![vec!["a", "b", "c"], vec!["b", "c", "d"], vec!["d", "e"]],
            vec![],
            None,
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d"), Ok("e")],
        )
        .await;
    }

    #[tokio::test]
    async fn test_success_after_retry() {
        let _lock = LOCK.read().await;
        check(
            vec![
                vec!["a", "b", "c"],
                vec!["error1"],
                vec!["error2"],
                vec!["b", "c", "d"],
            ],
            vec![Duration::from_millis(1), Duration::from_millis(1)],
            None,
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
    }

    #[tokio::test]
    async fn test_failure_after_retry() {
        let _lock = LOCK.read().await;
        check(
            vec![
                vec!["a", "b", "c"],
                vec!["error1"],
                vec!["error2"],
                vec!["b", "c", "d"],
            ],
            vec![Duration::from_millis(1)],
            None,
            vec![
                Ok("a"),
                Ok("b"),
                Ok("c"),
                Err(StreamError::SourceError(MockSourceError(
                    "error2".to_owned(),
                ))),
                Ok("d"),
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_warning_if_all_items_are_unseen() {
        let _lock = LOCK.write().await; // exclusive lock
        let mut logger = Logger::start();
        log::set_max_level(LevelFilter::Warn);
        check(
            vec![vec!["a", "b"], vec!["c", "d"]],
            vec![],
            None,
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;

        let num_records = logger.len();
        if num_records != 1 {
            println!();
            println!("{} LOG MESSAGES:", logger.len());
            while let Some(record) = logger.pop() {
                println!("[{}] {}", record.level(), record.args());
            }
            println!();
            panic!("Expected 1 log message, got {}", num_records);
        }

        let record = logger.pop().unwrap();
        assert_eq!(record.level(), Level::Warn);
        assert_eq!(
            record.args(),
            "All received MockItems for MockSource were new, try a shorter sleep_time",
        );
    }

    #[tokio::test]
    async fn test_sink_error_when_sending_new_item() {
        let _lock = LOCK.read().await;
        let mut mock_puller = MockPuller::new(vec![vec!["a"]]);
        let (sink, stream) = mpsc::unbounded();
        drop(stream); // drop receiver so that sending fails
        let join_handle = tokio::spawn(async move {
            pull_into_sink(
                &mut mock_puller,
                Duration::from_millis(1),
                vec![],
                None,
                sink,
            )
            .await
        });
        let result = join_handle.await.unwrap();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sink_error_when_sending_error() {
        let _lock = LOCK.read().await;
        let mut mock_puller = MockPuller::new(vec![vec!["error"]]);
        let (sink, stream) = mpsc::unbounded();
        drop(stream); // drop receiver so that sending fails
        let join_handle = tokio::spawn(async move {
            pull_into_sink(
                &mut mock_puller,
                Duration::from_millis(1),
                vec![],
                None,
                sink,
            )
            .await
        });
        let result = join_handle.await.unwrap();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_timeout_ok() {
        let _lock = LOCK.read().await;
        check(
            vec![vec!["a", "b", "c"], vec!["b", "c", "d"]],
            vec![],
            Some(Duration::from_secs(1)),
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
    }

    #[tokio::test]
    async fn test_timeout_error() {
        let _lock = LOCK.read().await;

        let timeout = Duration::from_millis(100);

        // There is probably a less stupid way of constructing
        // an instance of Elapsed...
        let elapsed = tokio::time::timeout(timeout, sleep(Duration::from_secs(1)))
            .await
            .unwrap_err();

        check(
            vec![vec!["a", "b", "c"], vec!["sleep"], vec!["b", "c", "d"]],
            vec![],
            Some(timeout),
            vec![
                Ok("a"),
                Ok("b"),
                Ok("c"),
                Err(StreamError::TimeoutError(elapsed)),
                Ok("d"),
            ],
        )
        .await;
    }

    #[tokio::test]
    async fn test_timeout_retry() {
        let _lock = LOCK.read().await;

        check(
            vec![vec!["a", "b", "c"], vec!["sleep"], vec!["b", "c", "d"]],
            vec![Duration::from_millis(1)],
            Some(Duration::from_millis(100)),
            vec![Ok("a"), Ok("b"), Ok("c"), Ok("d")],
        )
        .await;
    }
}