iter_tee/
lib.rs

//! Make several clones of an iterator.
//!
//! Each handle to the iterator is represented by an instance of [`Tee`]. A
//! `Tee` is itself an iterator that yields the same sequence of items as the
//! original iterator. A `Tee` can be freely cloned at any point to create
//! more handles to the same underlying iterator. Once cloned, the two `Tee`s
//! are equivalent but independent: both will yield the same remaining items,
//! and each can be advanced on its own.
//!
//! The implementation uses a single ring buffer for storing items already
//! pulled from the underlying iterator, but not yet consumed by all the `Tee`s.
//! The buffer is protected with a [`RwLock`], and [atomics](std::sync::atomic)
//! are used to keep item reference counts.
//!
//! While the implementation tries to be efficient, it will not be as efficient
//! as natively cloning the underlying iterator if it implements [`Clone`].
//!
//! # Examples
//!
//! ```
//! use iter_tee::Tee;
//!
//! // Wrap an iterator in a Tee:
//! let mut tee1 = Tee::new(0..10);
//! // It yields the same items:
//! assert_eq!(tee1.next(), Some(0));
//! assert_eq!(tee1.next(), Some(1));
//! // Create a second Tee:
//! let mut tee2 = tee1.clone();
//! // Both yield the same items:
//! assert_eq!(tee1.next(), Some(2));
//! assert_eq!(tee2.next(), Some(2));
//! // Create a third Tee:
//! let mut tee3 = tee2.clone();
//! // All three yield the same items:
//! assert_eq!(tee1.next(), Some(3));
//! assert_eq!(tee2.next(), Some(3));
//! assert_eq!(tee3.next(), Some(3));
//! // The Tees can be advanced independently:
//! assert_eq!(tee1.next(), Some(4));
//! assert_eq!(tee1.next(), Some(5));
//! assert_eq!(tee2.next(), Some(4));
//! ```
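//!
//! A `Tee` can also be sent to another thread, since the shared state lives
//! behind an `Arc<RwLock<_>>`. A minimal sketch; this assumes the underlying
//! iterator and its items are `Send` and `Sync`, as `0..100` and `i32` are:
//!
//! ```
//! use iter_tee::Tee;
//! use std::thread;
//!
//! let tee1 = Tee::new(0..100);
//! let tee2 = tee1.clone();
//! let t1 = thread::spawn(move || tee1.sum::<i32>());
//! let t2 = thread::spawn(move || tee2.sum::<i32>());
//! // Both threads observe the full sequence:
//! assert_eq!(t1.join().unwrap(), t2.join().unwrap());
//! ```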

use std::{
    collections::VecDeque,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, RwLock,
    },
};

/// An item pulled from the underlying iterator, together with the number of
/// handles that still have to yield it.
struct BufferItem<T> {
    value: T,
    ref_count: AtomicUsize,
}

/// State shared by all `Tee` handles to one underlying iterator.
struct Shared<I: Iterator> {
    /// The underlying iterator; `None` once it has been exhausted.
    iter: Option<I>,
    /// Items already pulled but not yet consumed by every handle.
    buffer: VecDeque<BufferItem<I::Item>>,
    /// Number of handles positioned just past the end of `buffer`.
    next_item_ref_count: AtomicUsize,
    /// Number of items dropped from the front of `buffer` so far.
    num_items_dropped: usize,
}

#[derive(Debug)]
enum Outcome<T> {
    /// The value is ready, and the ref count has already been advanced.
    Ready(Option<T>),
    /// Will need to re-lock for writing and pull the next item from the
    /// iterator. The ref count has not been advanced.
    PastTheBuffer,
    /// Will need to re-lock for writing and take the last item (without cloning
    /// its value). The ref count has not been advanced.
    TakeTail,
    /// The value is ready, but we will need to re-lock for writing and clean up
    /// the tail. The ref count has already been advanced.
    DropTail(T),
}
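
// How the outcomes are used: `Shared::take` first tries under a read lock
// (`try_take`), which suffices for the `Ready` fast path; the other outcomes
// record which write-lock follow-up is needed (pulling from the iterator,
// taking the tail item by value, or trimming the tail).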

impl<I> Shared<I>
where
    I: Iterator,
    I::Item: Clone,
{
    /// Translates an absolute handle position into an index into `buffer`.
    fn offset(&self, pos: usize) -> usize {
        debug_assert!(pos >= self.num_items_dropped);
        let offset = pos - self.num_items_dropped;
        debug_assert!(offset <= self.buffer.len());
        offset
    }
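
    // For example, with `num_items_dropped == 5` and three buffered items,
    // positions 5..=7 map to buffer indices 0..=2, and position 8 (the drop
    // count plus the buffer length) denotes the not-yet-pulled next item.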

    /// Registers one more handle interested in the item at `offset`.
    fn inc_ref_count(&self, offset: usize) {
        let count = if offset == self.buffer.len() {
            &self.next_item_ref_count
        } else {
            &self.buffer[offset].ref_count
        };
        count.fetch_add(1, Ordering::Relaxed);
    }

    /// Unregisters one handle's interest in the item at `offset`. Returns
    /// `true` if it was the last handle interested in that item.
    fn dec_ref_count(&self, offset: usize) -> bool {
        let count = if offset == self.buffer.len() {
            &self.next_item_ref_count
        } else {
            &self.buffer[offset].ref_count
        };
        count.fetch_sub(1, Ordering::Relaxed) == 1
    }

    /// Moves one handle's interest from `offset` to `offset + 1`. Returns
    /// `true` if this handle was the last one interested in `offset`.
    fn advance_ref_count(&self, offset: usize) -> bool {
        self.inc_ref_count(offset + 1);
        self.dec_ref_count(offset)
    }

    fn try_take(&self, offset: usize) -> Outcome<I::Item> {
        if offset == self.buffer.len() {
            // We're past the buffer; we need to pull the next item from the
            // iterator, if there still is one.
            if self.iter.is_some() {
                Outcome::PastTheBuffer
            } else {
                Outcome::Ready(None)
            }
        } else if offset > 0 {
            // Fast path: we're in the middle of the buffer.
            let value = self.buffer[offset].value.clone();
            self.advance_ref_count(offset);
            Outcome::Ready(Some(value))
        } else if self.buffer[0].ref_count.load(Ordering::Relaxed) == 1 {
            // We're the only one still interested in that item;
            // take it without cloning.
            Outcome::TakeTail
        } else {
            let value = self.buffer[0].value.clone();
            let was_last = self.advance_ref_count(0);
            if was_last {
                Outcome::DropTail(value)
            } else {
                Outcome::Ready(Some(value))
            }
        }
    }

    /// Attempts to pull the next item from the iterator.
    ///
    /// Advances the ref count.
    fn pull_next_item(&mut self) -> Option<I::Item> {
        let iter = self.iter.as_mut().expect("iter should not be None here");
        let value = match iter.next() {
            Some(value) => value,
            None => {
                // We have exhausted the underlying iterator; drop it.
                self.iter = None;
                return None;
            }
        };
        if self.buffer.is_empty() && *self.next_item_ref_count.get_mut() == 1 {
            // We're the only consumer out there!
            // Skip the buffering altogether.
            self.num_items_dropped += 1;
            return Some(value);
        }
        // Everyone else waiting here becomes interested in the buffered copy
        // (hence the count minus us); so far, we're the only one interested in
        // the *next* next item.
        let new_item_ref_count = std::mem::replace(self.next_item_ref_count.get_mut(), 1) - 1;
        let new_item = BufferItem {
            value: value.clone(),
            ref_count: AtomicUsize::new(new_item_ref_count),
        };
        self.buffer.push_back(new_item);
        Some(value)
    }

    /// Drops any unused tail of the buffer.
    fn drop_tail(&mut self) {
        while let Some(buffer_item) = self.buffer.front_mut() {
            if *buffer_item.ref_count.get_mut() > 0 {
                break;
            }
            self.buffer.pop_front();
            self.num_items_dropped += 1;
        }
    }

    fn take(this: &RwLock<Self>, pos: usize) -> Option<I::Item> {
        let mut outcome;
        let mut offset;
        // First, lock for reading and see if that's enough.
        {
            let shared = this.read().unwrap();
            offset = shared.offset(pos);
            outcome = shared.try_take(offset);
        }
        if let Outcome::Ready(item) = outcome {
            return item;
        }

        // Now, lock for writing.
        let mut shared = this.write().unwrap();
        // If we were past the buffer, we might be in any situation now.
        // Re-evaluate.
        if let Outcome::PastTheBuffer = outcome {
            offset = shared.offset(pos);
            outcome = shared.try_take(offset);
        }

        match outcome {
            Outcome::Ready(item) => item,
            Outcome::PastTheBuffer => shared.pull_next_item(),
            Outcome::TakeTail => {
                debug_assert_eq!(offset, 0);
                shared.advance_ref_count(0);
                let mut buffer_item = shared
                    .buffer
                    .pop_front()
                    .expect("the buffer should not be empty here");
                debug_assert_eq!(*buffer_item.ref_count.get_mut(), 0);
                shared.num_items_dropped += 1;
                Some(buffer_item.value)
            }
            Outcome::DropTail(item) => {
                debug_assert_eq!(offset, 0);
                shared.drop_tail();
                Some(item)
            }
        }
    }
}

/// Shared iterator handle.
///
/// `Tee`s can be freely cloned at any point to get several independent handles
/// to the same underlying iterator.
pub struct Tee<I>
where
    I: Iterator,
    I::Item: Clone,
{
    shared: Arc<RwLock<Shared<I>>>,
    /// Absolute position of this handle in the item sequence.
    pos: usize,
}

impl<I> Tee<I>
where
    I: Iterator,
    I::Item: Clone,
{
    /// Wraps an iterator into a new `Tee`.
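    ///
    /// A minimal example:
    ///
    /// ```
    /// use iter_tee::Tee;
    ///
    /// let mut tee = Tee::new(vec![1, 2, 3].into_iter());
    /// assert_eq!(tee.next(), Some(1));
    /// ```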
    pub fn new(iter: I) -> Self {
        let shared = Shared {
            iter: Some(iter),
            buffer: VecDeque::new(),
            next_item_ref_count: AtomicUsize::new(1),
            num_items_dropped: 0,
        };
        Tee {
            shared: Arc::new(RwLock::new(shared)),
            pos: 0,
        }
    }
}

impl<I> Clone for Tee<I>
where
    I: Iterator,
    I::Item: Clone,
{
    fn clone(&self) -> Self {
        // Register the new handle's interest in our current position before
        // it comes into existence.
        {
            let shared = self.shared.read().unwrap();
            let offset = shared.offset(self.pos);
            shared.inc_ref_count(offset);
        }
        Tee {
            shared: self.shared.clone(),
            pos: self.pos,
        }
    }
}

impl<I> Drop for Tee<I>
where
    I: Iterator,
    I::Item: Clone,
{
    fn drop(&mut self) {
        let need_to_drop;

        if let Ok(shared) = self.shared.read() {
            let offset = shared.offset(self.pos);
            let was_last = shared.dec_ref_count(offset);
            need_to_drop = offset == 0 && was_last;
        } else {
            // If the lock is poisoned, do not propagate the panic into this
            // thread. It's fine if we leave an extra ref count.
            return;
        }
        if !need_to_drop {
            return;
        }
        if let Ok(mut shared) = self.shared.write() {
            shared.drop_tail();
        }
    }
}

impl<I> Iterator for Tee<I>
where
    I: Iterator,
    I::Item: Clone,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        let item = Shared::take(&self.shared, self.pos);
        // Only advance past items actually yielded; on `None` the position
        // stays put, so further calls keep returning `None`.
        if item.is_some() {
            self.pos += 1;
        }
        item
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let shared = self.shared.read().unwrap();
        let total_buffered = shared.num_items_dropped + shared.buffer.len();
        let more_in_buffer = total_buffered - self.pos;
        let (iter_min, iter_max) = match &shared.iter {
            Some(iter) => iter.size_hint(),
            None => (0, Some(0)),
        };
        (
            more_in_buffer + iter_min,
            iter_max.map(|im| more_in_buffer + im),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::Tee;
    use std::{fmt::Debug, thread};

    fn make_string_iter() -> impl Iterator<Item = String> {
        (0..1024).map(|i| i.to_string())
    }

    fn assert_iter_eq<I1, I2>(mut i1: I1, mut i2: I2)
    where
        I1: Iterator,
        I2: Iterator<Item = I1::Item>,
        I1::Item: PartialEq + Debug,
    {
        while let Some(item1) = i1.next() {
            assert_eq!(item1, i2.next().unwrap());
        }
        assert!(i2.next().is_none());
    }

    #[test]
    fn just_one_tee() {
        let tee = Tee::new(make_string_iter());
        assert_iter_eq(tee, make_string_iter());
    }

    #[test]
    fn two_tees() {
        let tee1 = Tee::new(make_string_iter());
        let tee2 = tee1.clone();
        assert_iter_eq(tee1, make_string_iter());
        assert_iter_eq(tee2, make_string_iter());
    }

    #[test]
    fn two_tees_parallel() {
        let tee1 = Tee::new(make_string_iter());
        let tee2 = tee1.clone();
        let t1 = thread::spawn(|| assert_iter_eq(tee1, make_string_iter()));
        let t2 = thread::spawn(|| assert_iter_eq(tee2, make_string_iter()));
        t1.join().unwrap();
        t2.join().unwrap();
    }

    #[test]
    fn ten_tees_parallel() {
        let tee = Tee::new(make_string_iter());
        let mut threads = vec![];
        for tee in vec![tee; 10] {
            let t = thread::spawn(|| assert_iter_eq(tee, make_string_iter()));
            threads.push(t);
        }
        for t in threads {
            t.join().unwrap();
        }
    }

    #[test]
    fn drop_in_the_middle() {
        let tee = Tee::new(make_string_iter());
        let mut threads = vec![];
        for (i, tee) in vec![tee; 10].into_iter().enumerate() {
            let t = thread::spawn(move || assert_iter_eq(tee.take(i), make_string_iter().take(i)));
            threads.push(t);
        }
        for t in threads {
            t.join().unwrap();
        }
    }

    #[test]
    fn clone_in_the_middle() {
        let mut tee1 = Tee::new(make_string_iter());
        assert_iter_eq(tee1.by_ref().take(10), make_string_iter().take(10));
        let tee2 = tee1.clone();

        assert_iter_eq(tee1, make_string_iter().skip(10));
        assert_iter_eq(tee2, make_string_iter().skip(10));
    }
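
    // A sketch of an extra check for `size_hint`: with a single consumer the
    // implementation skips buffering, so the hint should simply track the
    // underlying iterator's own hint.
    #[test]
    fn size_hint_single_tee() {
        let mut tee = Tee::new(0..10);
        assert_eq!(tee.size_hint(), (10, Some(10)));
        assert_eq!(tee.next(), Some(0));
        assert_eq!(tee.next(), Some(1));
        assert_eq!(tee.size_hint(), (8, Some(8)));
    }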
}