use std::sync::{
mpsc::{channel, Receiver, SendError, Sender, TryRecvError},
Arc, Mutex,
};
use slab::Slab;
/// State shared by all `SharedIter` handles of one underlying iterator:
/// the iterator itself plus one channel sender per registered handle.
#[derive(Debug)]
struct SharedIterCore<I: Iterator> {
/// The wrapped iterator; advanced one item at a time by `next`.
iter: I,
/// Sending halves of the per-handle channels, keyed by the slab index
/// that `new_recv` returns and `remove_recv` takes.
sender: Slab<Sender<I::Item>>,
}
impl<I: Iterator> SharedIterCore<I> {
    /// Wraps `iter` with an (initially empty) registry of senders.
    fn new(iter: I) -> Self {
        Self {
            iter,
            sender: Slab::with_capacity(1),
        }
    }

    /// Delivers a copy of `val` to every registered receiver.
    ///
    /// # Errors
    ///
    /// Returns the `SendError` of the first sender whose receiving half has
    /// been dropped. Iteration stops there, so senders visited later do not
    /// receive `val`.
    fn send(&mut self, val: I::Item) -> Result<(), SendError<I::Item>>
    where
        I::Item: Copy,
    {
        for (_, sender) in self.sender.iter() {
            sender.send(val)?;
        }
        Ok(())
    }

    /// Pulls one item from the wrapped iterator and broadcasts it to every
    /// registered receiver. Does nothing when the iterator yields `None`.
    ///
    /// # Panics
    ///
    /// Panics if a sender is still registered although its receiver is gone.
    /// `SharedIter` unregisters its sender in `Drop` before the receiver is
    /// released, so hitting this indicates a broken invariant.
    fn next(&mut self)
    where
        I::Item: Copy,
    {
        if let Some(val) = self.iter.next() {
            self.send(val)
                .expect("a receiver was dropped without unregistering its sender");
        }
    }

    /// Creates a fresh channel, registers its sender and returns the slab key
    /// together with the receiving half.
    ///
    /// The new receiver starts out empty: it only sees items broadcast after
    /// this call.
    fn new_recv(&mut self) -> (usize, Receiver<I::Item>) {
        let (sender, receiver) = channel();
        let id = self.sender.insert(sender);
        (id, receiver)
    }

    /// Unregisters the sender stored under `id`, as previously returned by
    /// `new_recv`.
    fn remove_recv(&mut self, id: usize) {
        self.sender.remove(id);
    }
}
/// A cloneable handle onto one shared underlying iterator.
///
/// Each handle owns the receiving end of its own channel; items pulled from
/// the underlying iterator are sent to every registered handle. A clone is
/// registered with a fresh, empty channel, so it only observes items pulled
/// after the clone was made.
#[derive(Debug)]
pub struct SharedIter<I: Iterator> {
/// Slab key of this handle's sender inside the shared core; used to
/// unregister the sender when the handle is dropped.
id: usize,
/// The shared core (underlying iterator + sender registry) behind a mutex.
inner: Arc<Mutex<SharedIterCore<I>>>,
/// Receiving end of this handle's private channel.
receiver: Receiver<I::Item>,
}
impl<I: Iterator> SharedIter<I> {
    /// Builds the first handle for `iter`: creates the shared core, registers
    /// one receiver on it and then moves the core behind an `Arc<Mutex<_>>`.
    fn new(iter: I) -> Self {
        let mut core = SharedIterCore::new(iter);
        // Register before wrapping so no lock is needed for the first handle.
        let (id, receiver) = core.new_recv();
        let inner = Arc::new(Mutex::new(core));
        Self {
            id,
            inner,
            receiver,
        }
    }
}
impl<I: Iterator> Clone for SharedIter<I> {
    /// Registers a new receiver on the shared core and returns a handle that
    /// points at the same core.
    ///
    /// The new handle's channel starts empty, so it only yields items pulled
    /// from the underlying iterator after this call.
    ///
    /// # Panics
    ///
    /// Panics if the core's mutex is poisoned.
    fn clone(&self) -> Self {
        // Hold the lock only for the registration itself.
        let (id, receiver) = {
            let mut core = self.inner.lock().unwrap();
            core.new_recv()
        };
        let inner = Arc::clone(&self.inner);
        Self {
            inner,
            receiver,
            id,
        }
    }
}
impl<I: Iterator> Iterator for SharedIter<I>
where
    I::Item: Copy,
{
    type Item = I::Item;

    /// Returns the next item for this handle.
    ///
    /// Items already queued on this handle's channel are returned first.
    /// Only when the queue is empty is the underlying iterator advanced
    /// (under the lock), which broadcasts the new item to every handle.
    /// Returns `None` when the underlying iterator yields nothing or the
    /// channel is disconnected.
    ///
    /// # Panics
    ///
    /// Panics if the core's mutex is poisoned, or if broadcasting fails
    /// inside the core.
    fn next(&mut self) -> Option<I::Item> {
        match self.receiver.try_recv() {
            Ok(val) => Some(val),
            Err(TryRecvError::Disconnected) => None,
            Err(TryRecvError::Empty) => {
                let mut core = self.inner.lock().unwrap();
                // Another handle may have advanced the source and filled our
                // queue while we were waiting for the lock. Re-check before
                // pulling, so we never read ahead (or block on a slow source)
                // when an item is already available.
                match self.receiver.try_recv() {
                    Err(TryRecvError::Empty) => {
                        core.next();
                        // With the lock held nobody else can send, so this is
                        // either the item just pulled or `Empty` on exhaustion.
                        self.receiver.try_recv().ok()
                    }
                    queued => queued.ok(),
                }
            }
        }
    }
}
impl<I: Iterator> Drop for SharedIter<I> {
    /// Unregisters this handle's sender from the shared core so later
    /// broadcasts no longer target the (about to be dropped) receiver.
    fn drop(&mut self) {
        // A poisoned mutex must not panic here: handles are commonly dropped
        // while unwinding from the very panic that poisoned the lock, and a
        // second panic during unwinding aborts the process. The registry is
        // still usable, so recover the guard and unregister anyway.
        let mut core = self
            .inner
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        core.remove_recv(self.id);
    }
}
/// Extension trait that adds [`share`](ShareIterator::share) to iterators.
pub trait ShareIterator: Iterator + Sized {
/// Consumes the iterator and returns a [`SharedIter`] handle over it.
///
/// Note that `SharedIter` only implements `Iterator` when the item type
/// is `Copy`.
fn share(self) -> SharedIter<Self>;
}
/// Blanket implementation: every sized iterator can be turned into a
/// `SharedIter` via `.share()`.
impl<I> ShareIterator for I
where
    I: Iterator,
{
    fn share(self) -> SharedIter<I> {
        SharedIter::new(self)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Two handles created before any item is pulled yield the same prefix.
    #[test]
    fn test_iter() {
        let first = (1..20).share();
        let second = first.clone();
        let from_first = first.take(10).collect::<Vec<_>>();
        let from_second = second.take(10).collect::<Vec<_>>();
        assert_eq!(from_first, from_second);
    }

    /// Handles consumed concurrently on several threads all observe the same
    /// sequence as the handle consumed on the main thread.
    #[test]
    fn test_multi_threaded() {
        use std::thread;

        let source = (1..).share();
        // Create every clone up front: a handle only sees items pulled after
        // it was registered, so cloning lazily while threads are already
        // consuming would make the sequences diverge.
        let handles: Vec<_> = (0..5).map(|_| source.clone()).collect();
        let workers: Vec<_> = handles
            .into_iter()
            .map(|handle| thread::spawn(move || handle.take(10).collect::<Vec<_>>()))
            .collect();
        let expected = source.take(10).collect::<Vec<_>>();
        for worker in workers {
            assert_eq!(worker.join().unwrap(), expected);
        }
    }
}