use std::pin::Pin;
use std::sync::Arc;
use futures::Stream;
/// A heap-allocated, pinned, type-erased stream yielding `T` values.
///
/// The trait object is `Send` and `'static`, so it owns everything it needs
/// and may be moved across threads.
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send + 'static>>;
/// A re-iterable source of values.
///
/// Each call to [`Channel::iter`] invokes the stored factory and returns a
/// brand-new stream, so the same `Channel` can be consumed any number of times.
///
/// Cloning a `Channel` is cheap: it only bumps the reference count of the
/// shared factory, and every clone invokes that same factory.
pub struct Channel<T> {
    // Shared behind an `Arc` so clones reuse the factory instead of copying it.
    factory: Arc<dyn Fn() -> BoxStream<T> + Send + Sync>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, but only the `Arc` is cloned here, so any `T` works.
impl<T> Clone for Channel<T> {
    fn clone(&self) -> Self {
        Self {
            factory: Arc::clone(&self.factory),
        }
    }
}

// The factory is an opaque closure with no `Debug` impl, so only the type
// name is printed.
impl<T> std::fmt::Debug for Channel<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Channel").finish_non_exhaustive()
    }
}
impl<T> Channel<T> {
    /// Creates a channel from `factory`, which must build a fresh stream on
    /// every call.
    ///
    /// The factory is stored behind an `Arc`, so it is never copied.
    pub fn of<F>(factory: F) -> Self
    where
        F: Fn() -> BoxStream<T> + Send + Sync + 'static,
    {
        let shared = Arc::new(factory);
        Self { factory: shared }
    }

    /// Runs the factory once and returns the stream it produced.
    ///
    /// Every call runs the factory again, so each returned stream is
    /// independent of earlier ones.
    pub fn iter(&self) -> BoxStream<T> {
        let make = &self.factory;
        make()
    }
}
/// Builds a [`Channel`] whose streams yield the elements of a vector, in order.
///
/// `make` is called again on every [`Channel::iter`] call, so it may return
/// different contents each time (for example, from captured state).
pub fn channel_of_vec<T, F>(make: F) -> Channel<T>
where
    T: Send + 'static,
    F: Fn() -> Vec<T> + Send + Sync + 'static,
{
    // The annotated return type lets the concrete stream coerce into the
    // boxed trait object without an explicit `as` cast.
    let factory = move || -> BoxStream<T> {
        let items = make();
        Box::pin(futures::stream::iter(items))
    };
    Channel::of(factory)
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Starts one fresh iteration of `channel` and collects it into a `Vec`.
    async fn drain<T>(channel: &Channel<T>) -> Vec<T> {
        channel.iter().collect::<Vec<T>>().await
    }

    #[tokio::test]
    async fn iterating_twice_reruns_the_factory_and_yields_full_sequence_each_time() {
        let runs = Arc::new(AtomicUsize::new(0));
        let channel: Channel<i32> = {
            let runs = Arc::clone(&runs);
            channel_of_vec(move || {
                runs.fetch_add(1, Ordering::SeqCst);
                vec![1, 2, 3]
            })
        };

        assert_eq!(drain(&channel).await, vec![1, 2, 3]);
        assert_eq!(
            runs.load(Ordering::SeqCst),
            1,
            "factory ran once after first iter"
        );

        assert_eq!(drain(&channel).await, vec![1, 2, 3]);
        assert_eq!(
            runs.load(Ordering::SeqCst),
            2,
            "factory RE-RAN on the second iteration (re-iterable contract)"
        );
    }

    #[tokio::test]
    async fn channel_is_cloneable_and_clones_share_the_same_factory() {
        let runs = Arc::new(AtomicUsize::new(0));
        let channel: Channel<u8> = {
            let runs = Arc::clone(&runs);
            channel_of_vec(move || {
                runs.fetch_add(1, Ordering::SeqCst);
                vec![7]
            })
        };
        let clone = channel.clone();

        let from_original = drain(&channel).await;
        let from_clone = drain(&clone).await;
        assert_eq!(from_original, vec![7]);
        assert_eq!(from_clone, vec![7]);
        // Both handles drive the one shared factory, which therefore ran twice.
        assert_eq!(runs.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn factory_closure_can_capture_per_iteration_state() {
        let counter = Arc::new(AtomicUsize::new(0));
        let channel: Channel<usize> = {
            let counter = Arc::clone(&counter);
            channel_of_vec(move || {
                let base = counter.fetch_add(1, Ordering::SeqCst);
                vec![base, base + 1]
            })
        };

        let first = drain(&channel).await;
        let second = drain(&channel).await;
        assert_eq!(first, vec![0, 1]);
        assert_eq!(second, vec![1, 2]);
    }
}