use std::{
collections::VecDeque,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, RwLock,
},
};
struct BufferItem<T> {
value: T,
ref_count: AtomicUsize,
}
struct Shared<I: Iterator> {
iter: Option<I>,
buffer: VecDeque<BufferItem<I::Item>>,
next_item_ref_count: AtomicUsize,
num_items_dropped: usize,
}
#[derive(Debug)]
enum Outcome<T> {
Ready(Option<T>),
PastTheBuffer,
TakeTail,
DropTail(T),
}
impl<I> Shared<I>
where
I: Iterator,
I::Item: Clone,
{
fn offset(&self, pos: usize) -> usize {
debug_assert!(pos >= self.num_items_dropped);
let offset = pos - self.num_items_dropped;
debug_assert!(offset <= self.buffer.len());
offset
}
fn inc_ref_count(&self, offset: usize) {
let count = if offset == self.buffer.len() {
&self.next_item_ref_count
} else {
&self.buffer[offset].ref_count
};
count.fetch_add(1, Ordering::Relaxed);
}
fn dec_ref_count(&self, offset: usize) -> bool {
let count = if offset == self.buffer.len() {
&self.next_item_ref_count
} else {
&self.buffer[offset].ref_count
};
count.fetch_sub(1, Ordering::Relaxed) == 1
}
fn advance_ref_count(&self, offset: usize) -> bool {
self.inc_ref_count(offset + 1);
self.dec_ref_count(offset)
}
fn try_take(&self, offset: usize) -> Outcome<I::Item> {
if offset == self.buffer.len() {
if self.iter.is_some() {
Outcome::PastTheBuffer
} else {
Outcome::Ready(None)
}
} else if offset > 0 {
let value = self.buffer[offset].value.clone();
self.advance_ref_count(offset);
Outcome::Ready(Some(value))
} else if self.buffer[0].ref_count.load(Ordering::Relaxed) == 1 {
Outcome::TakeTail
} else {
let value = self.buffer[0].value.clone();
let was_last = self.advance_ref_count(0);
if was_last {
Outcome::DropTail(value)
} else {
Outcome::Ready(Some(value))
}
}
}
fn pull_next_item(&mut self) -> Option<I::Item> {
let iter = self.iter.as_mut().expect("iter should not be none here");
let value = match iter.next() {
Some(value) => value,
None => {
self.iter = None;
return None;
}
};
if self.buffer.is_empty() && *self.next_item_ref_count.get_mut() == 1 {
self.num_items_dropped += 1;
return Some(value);
}
let new_item_ref_count = std::mem::replace(self.next_item_ref_count.get_mut(), 1) - 1;
let new_item = BufferItem {
value: value.clone(),
ref_count: AtomicUsize::new(new_item_ref_count),
};
self.buffer.push_back(new_item);
Some(value)
}
fn drop_tail(&mut self) {
while let Some(buffer_item) = self.buffer.front_mut() {
if *buffer_item.ref_count.get_mut() > 0 {
break;
}
self.buffer.pop_front();
self.num_items_dropped += 1;
}
}
fn take(this: &RwLock<Self>, pos: usize) -> Option<I::Item> {
let mut outcome;
let mut offset;
{
let shared = this.read().unwrap();
offset = shared.offset(pos);
outcome = shared.try_take(offset);
};
if let Outcome::Ready(item) = outcome {
return item;
}
let mut shared = this.write().unwrap();
if let Outcome::PastTheBuffer = outcome {
offset = shared.offset(pos);
outcome = shared.try_take(offset);
}
match outcome {
Outcome::Ready(item) => item,
Outcome::PastTheBuffer => shared.pull_next_item(),
Outcome::TakeTail => {
debug_assert_eq!(offset, 0);
shared.advance_ref_count(0);
let mut buffer_item = shared
.buffer
.pop_front()
.expect("the buffer should not be empty here");
debug_assert_eq!(*buffer_item.ref_count.get_mut(), 0);
shared.num_items_dropped += 1;
Some(buffer_item.value)
}
Outcome::DropTail(item) => {
debug_assert_eq!(offset, 0);
shared.drop_tail();
Some(item)
}
}
}
}
pub struct Tee<I>
where
I: Iterator,
I::Item: Clone,
{
shared: Arc<RwLock<Shared<I>>>,
pos: usize,
}
impl<I> Tee<I>
where
I: Iterator,
I::Item: Clone,
{
pub fn new(iter: I) -> Self {
let shared = Shared {
iter: Some(iter),
buffer: VecDeque::new(),
next_item_ref_count: AtomicUsize::new(1),
num_items_dropped: 0,
};
Tee {
shared: Arc::new(RwLock::new(shared)),
pos: 0,
}
}
}
impl<I> Clone for Tee<I>
where
I: Iterator,
I::Item: Clone,
{
fn clone(&self) -> Self {
{
let shared = self.shared.read().unwrap();
let offset = shared.offset(self.pos);
shared.inc_ref_count(offset);
}
Tee {
shared: self.shared.clone(),
pos: self.pos,
}
}
}
impl<I> Drop for Tee<I>
where
I: Iterator,
I::Item: Clone,
{
fn drop(&mut self) {
let need_to_drop;
if let Ok(shared) = self.shared.read() {
let offset = shared.offset(self.pos);
let was_last = shared.dec_ref_count(offset);
need_to_drop = offset == 0 && was_last;
} else {
return;
}
if !need_to_drop {
return;
}
if let Ok(mut shared) = self.shared.write() {
shared.drop_tail();
}
}
}
impl<I> Iterator for Tee<I>
where
I: Iterator,
I::Item: Clone,
{
type Item = I::Item;
fn next(&mut self) -> Option<Self::Item> {
let item = Shared::take(&self.shared, self.pos);
if item.is_some() {
self.pos += 1;
}
item
}
fn size_hint(&self) -> (usize, Option<usize>) {
let shared = self.shared.read().unwrap();
let total_buffered = shared.num_items_dropped + shared.buffer.len();
let more_in_buffer = total_buffered - self.pos;
let (iter_min, iter_max) = match &shared.iter {
Some(iter) => iter.size_hint(),
None => (0, Some(0)),
};
(
more_in_buffer + iter_min,
iter_max.map(|im| more_in_buffer + im),
)
}
}
#[cfg(test)]
mod tests {
use super::Tee;
use std::{fmt::Debug, thread};
fn make_string_iter() -> impl Iterator<Item = String> {
(0..1024).map(|i| i.to_string())
}
fn assert_iter_eq<I1, I2>(mut i1: I1, mut i2: I2)
where
I1: Iterator,
I2: Iterator<Item = I1::Item>,
I1::Item: PartialEq + Debug,
{
while let Some(item1) = i1.next() {
assert_eq!(item1, i2.next().unwrap());
}
assert!(i2.next().is_none());
}
#[test]
fn just_one_tee() {
let tee = Tee::new(make_string_iter());
assert_iter_eq(tee, make_string_iter());
}
#[test]
fn two_tees() {
let tee1 = Tee::new(make_string_iter());
let tee2 = tee1.clone();
assert_iter_eq(tee1, make_string_iter());
assert_iter_eq(tee2, make_string_iter());
}
#[test]
fn two_tees_parallel() {
let tee1 = Tee::new(make_string_iter());
let tee2 = tee1.clone();
let t1 = thread::spawn(|| assert_iter_eq(tee1, make_string_iter()));
let t2 = thread::spawn(|| assert_iter_eq(tee2, make_string_iter()));
t1.join().unwrap();
t2.join().unwrap();
}
#[test]
fn ten_tees_parallel() {
let tee = Tee::new(make_string_iter());
let mut threads = vec![];
for tee in vec![tee; 10] {
let t = thread::spawn(|| assert_iter_eq(tee, make_string_iter()));
threads.push(t);
}
for t in threads {
t.join().unwrap();
}
}
#[test]
fn drop_in_the_middle() {
let tee = Tee::new(make_string_iter());
let mut threads = vec![];
for (i, tee) in vec![tee; 10].into_iter().enumerate() {
let t = thread::spawn(move || assert_iter_eq(tee.take(i), make_string_iter().take(i)));
threads.push(t);
}
for t in threads {
t.join().unwrap();
}
}
#[test]
fn clone_in_the_middle() {
let mut tee1 = Tee::new(make_string_iter());
assert_iter_eq(
tee1.by_ref().take(10),
make_string_iter().take(10)
);
let tee2 = tee1.clone();
assert_iter_eq(
tee1,
make_string_iter().skip(10)
);
assert_iter_eq(
tee2,
make_string_iter().skip(10)
);
}
}