1use std::{
46 collections::VecDeque,
47 sync::{
48 atomic::{AtomicUsize, Ordering},
49 Arc, RwLock,
50 },
51};
52
/// A value pulled from the source iterator that is still held in the shared
/// buffer, together with the number of handles currently positioned on it.
struct BufferItem<T> {
    /// The buffered item itself; handed out by clone or by move.
    value: T,
    /// How many handles currently sit on exactly this item.
    ref_count: AtomicUsize,
}
57
58struct Shared<I: Iterator> {
59 iter: Option<I>,
60 buffer: VecDeque<BufferItem<I::Item>>,
61 next_item_ref_count: AtomicUsize,
62 num_items_dropped: usize,
63}
64
/// Result of trying to serve one handle's next item with only shared access.
#[derive(Debug)]
enum Outcome<T> {
    /// Final answer: `Some(item)` was produced, or `None` because the source
    /// is exhausted and nothing is buffered for this handle.
    Ready(Option<T>),
    /// The handle sits on the not-yet-pulled item; exclusive access is needed
    /// to pull it from the source.
    PastTheBuffer,
    /// The handle is the only one on the oldest buffered item, so that item
    /// can be moved out of the buffer (needs exclusive access).
    TakeTail,
    /// The handle received a clone of the oldest buffered item and was the
    /// last one positioned on it; the buffer front must be trimmed under
    /// exclusive access.
    DropTail(T),
}
79
impl<I> Shared<I>
where
    I: Iterator,
    I::Item: Clone,
{
    /// Translates an absolute stream position into an index into `buffer`.
    ///
    /// The result lies in `0..=buffer.len()` (checked in debug builds only);
    /// the value `buffer.len()` stands for the item not yet pulled from `iter`.
    fn offset(&self, pos: usize) -> usize {
        debug_assert!(pos >= self.num_items_dropped);
        let offset = pos - self.num_items_dropped;
        debug_assert!(offset <= self.buffer.len());
        offset
    }

    /// Registers one more handle positioned at `offset`.
    ///
    /// `offset == buffer.len()` refers to the not-yet-pulled item, whose count
    /// lives in `next_item_ref_count`. Only an atomic increment is performed,
    /// so `&self` (i.e. holding the read lock) is enough.
    fn inc_ref_count(&self, offset: usize) {
        let count = if offset == self.buffer.len() {
            &self.next_item_ref_count
        } else {
            &self.buffer[offset].ref_count
        };
        count.fetch_add(1, Ordering::Relaxed);
    }

    /// Unregisters one handle positioned at `offset`.
    ///
    /// Returns `true` if that handle was the last one on that slot.
    fn dec_ref_count(&self, offset: usize) -> bool {
        let count = if offset == self.buffer.len() {
            &self.next_item_ref_count
        } else {
            &self.buffer[offset].ref_count
        };
        count.fetch_sub(1, Ordering::Relaxed) == 1
    }

    /// Moves one handle from `offset` to `offset + 1`.
    ///
    /// Returns `true` if it was the last handle at `offset`.
    fn advance_ref_count(&self, offset: usize) -> bool {
        self.inc_ref_count(offset + 1);
        self.dec_ref_count(offset)
    }

    /// Tries to produce the item at `offset` for one handle using only shared
    /// access, advancing that handle's ref count when an item is produced.
    ///
    /// - `offset == buffer.len()`: nothing is buffered for this handle, so the
    ///   result is `Ready(None)` if the source is exhausted, else `PastTheBuffer`
    ///   (pulling requires `&mut self`).
    /// - `offset > 0`: the item is not at the buffer front, so it is cloned and
    ///   left in place; it is only removed once it reaches the front with a
    ///   zero count (see `drop_tail`).
    /// - front item with a count of 1: the caller is the only handle there, so
    ///   `TakeTail` lets `take` move the value out instead of cloning it.
    /// - otherwise the front item is cloned; if the decrement found the caller
    ///   to be the last handle there (possible when other handles leave the
    ///   slot between the load and the decrement), `DropTail` asks `take` to
    ///   trim the front.
    fn try_take(&self, offset: usize) -> Outcome<I::Item> {
        if offset == self.buffer.len() {
            if self.iter.is_some() {
                // The item still has to be pulled from the source.
                Outcome::PastTheBuffer
            } else {
                Outcome::Ready(None)
            }
        } else if offset > 0 {
            let value = self.buffer[offset].value.clone();
            self.advance_ref_count(offset);
            Outcome::Ready(Some(value))
        } else if self.buffer[0].ref_count.load(Ordering::Relaxed) == 1 {
            // Sole handle on the oldest item: it can be moved out under the
            // write lock; the ref count is advanced there, not here.
            Outcome::TakeTail
        } else {
            let value = self.buffer[0].value.clone();
            let was_last = self.advance_ref_count(0);
            if was_last {
                Outcome::DropTail(value)
            } else {
                Outcome::Ready(Some(value))
            }
        }
    }

    /// Pulls one item from `iter` for the handle positioned at `buffer.len()`.
    ///
    /// Requires exclusive access. `take` only calls this after `try_take`
    /// returned `PastTheBuffer`, which implies `iter.is_some()`.
    ///
    /// # Panics
    ///
    /// Panics if `iter` is already `None`.
    fn pull_next_item(&mut self) -> Option<I::Item> {
        let iter = self.iter.as_mut().expect("iter should not be none here");
        let value = match iter.next() {
            Some(value) => value,
            None => {
                // Remember exhaustion so later `try_take` calls answer `Ready(None)`.
                self.iter = None;
                return None;
            }
        };
        if self.buffer.is_empty() && *self.next_item_ref_count.get_mut() == 1 {
            // The caller is the only handle waiting for this item and nothing
            // older is buffered: hand it out without buffering. The caller
            // stays the only handle on the following unpulled item, so
            // `next_item_ref_count` remains 1.
            self.num_items_dropped += 1;
            return Some(value);
        }
        // Every other handle that was waiting (the old count minus the caller)
        // now sits on the newly buffered item; the caller moves on to the next
        // unpulled item, so that count becomes 1.
        let new_item_ref_count = std::mem::replace(self.next_item_ref_count.get_mut(), 1) - 1;
        let new_item = BufferItem {
            value: value.clone(),
            ref_count: AtomicUsize::new(new_item_ref_count),
        };
        self.buffer.push_back(new_item);
        Some(value)
    }

    /// Pops items from the buffer front while no handle is positioned on
    /// them, keeping `num_items_dropped` in sync. Stops at the first item
    /// that still has a handle.
    fn drop_tail(&mut self) {
        while let Some(buffer_item) = self.buffer.front_mut() {
            if *buffer_item.ref_count.get_mut() > 0 {
                break;
            }
            self.buffer.pop_front();
            self.num_items_dropped += 1;
        }
    }

    /// Produces the item at absolute position `pos` for one handle and moves
    /// that handle's ref count forward.
    ///
    /// First tries under the read lock (cheap and parallel). Only if the
    /// outcome needs mutation (pulling from the source, moving the front
    /// item out, or trimming the front) does it take the write lock.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    fn take(this: &RwLock<Self>, pos: usize) -> Option<I::Item> {
        let mut outcome;
        let mut offset;
        {
            // Read-locked attempt; the guard is released at the end of this scope.
            let shared = this.read().unwrap();
            offset = shared.offset(pos);
            outcome = shared.try_take(offset);
        };
        if let Outcome::Ready(item) = outcome {
            return item;
        }

        let mut shared = this.write().unwrap();
        // No lock was held between the two phases: another handle may have
        // pulled the item meanwhile, so re-evaluate under the write lock.
        if let Outcome::PastTheBuffer = outcome {
            offset = shared.offset(pos);
            outcome = shared.try_take(offset);
        }

        match outcome {
            Outcome::Ready(item) => item,
            Outcome::PastTheBuffer => shared.pull_next_item(),
            Outcome::TakeTail => {
                // The caller was the only handle on the front item. Handles only
                // move forward and `drop_tail` keeps items with a non-zero count,
                // so this is presumably still the caller's item; the count
                // assertion below checks it in debug builds.
                debug_assert_eq!(offset, 0);
                shared.advance_ref_count(0);
                let mut buffer_item = shared
                    .buffer
                    .pop_front()
                    .expect("the buffer should not be empty here");
                debug_assert_eq!(*buffer_item.ref_count.get_mut(), 0);
                shared.num_items_dropped += 1;
                Some(buffer_item.value)
            }
            Outcome::DropTail(item) => {
                // The ref count was already advanced by `try_take`; only the
                // now-unreferenced front needs trimming.
                debug_assert_eq!(offset, 0);
                shared.drop_tail();
                Some(item)
            }
        }
    }
}
229
230pub struct Tee<I>
235where
236 I: Iterator,
237 I::Item: Clone,
238{
239 shared: Arc<RwLock<Shared<I>>>,
240 pos: usize,
241}
242
243impl<I> Tee<I>
244where
245 I: Iterator,
246 I::Item: Clone,
247{
248 pub fn new(iter: I) -> Self {
250 let shared = Shared {
251 iter: Some(iter),
252 buffer: VecDeque::new(),
253 next_item_ref_count: AtomicUsize::new(1),
254 num_items_dropped: 0,
255 };
256 Tee {
257 shared: Arc::new(RwLock::new(shared)),
258 pos: 0,
259 }
260 }
261}
262
263impl<I> Clone for Tee<I>
264where
265 I: Iterator,
266 I::Item: Clone,
267{
268 fn clone(&self) -> Self {
269 {
270 let shared = self.shared.read().unwrap();
271 let offset = shared.offset(self.pos);
272 shared.inc_ref_count(offset);
273 }
274 Tee {
275 shared: self.shared.clone(),
276 pos: self.pos,
277 }
278 }
279}
280
281impl<I> Drop for Tee<I>
282where
283 I: Iterator,
284 I::Item: Clone,
285{
286 fn drop(&mut self) {
287 let need_to_drop;
288
289 if let Ok(shared) = self.shared.read() {
290 let offset = shared.offset(self.pos);
291 let was_last = shared.dec_ref_count(offset);
292 need_to_drop = offset == 0 && was_last;
293 } else {
294 return;
297 }
298 if !need_to_drop {
299 return;
300 }
301 if let Ok(mut shared) = self.shared.write() {
302 shared.drop_tail();
303 }
304 }
305}
306
307impl<I> Iterator for Tee<I>
308where
309 I: Iterator,
310 I::Item: Clone,
311{
312 type Item = I::Item;
313
314 fn next(&mut self) -> Option<Self::Item> {
315 let item = Shared::take(&self.shared, self.pos);
316 if item.is_some() {
317 self.pos += 1;
318 }
319 item
320 }
321
322 fn size_hint(&self) -> (usize, Option<usize>) {
323 let shared = self.shared.read().unwrap();
324 let total_buffered = shared.num_items_dropped + shared.buffer.len();
325 let more_in_buffer = total_buffered - self.pos;
326 let (iter_min, iter_max) = match &shared.iter {
327 Some(iter) => iter.size_hint(),
328 None => (0, Some(0)),
329 };
330 (
331 more_in_buffer + iter_min,
332 iter_max.map(|im| more_in_buffer + im),
333 )
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::Tee;
    use std::{fmt::Debug, thread};

    fn make_string_iter() -> impl Iterator<Item = String> {
        (0..1024).map(|i| i.to_string())
    }

    /// Asserts that both iterators yield exactly the same items, in order,
    /// and end at the same time.
    fn assert_iter_eq<I1, I2>(i1: I1, mut i2: I2)
    where
        I1: Iterator,
        I2: Iterator<Item = I1::Item>,
        I1::Item: PartialEq + Debug,
    {
        for item1 in i1 {
            assert_eq!(item1, i2.next().unwrap());
        }
        assert!(i2.next().is_none());
    }

    #[test]
    fn just_one_tee() {
        let tee = Tee::new(make_string_iter());
        assert_iter_eq(tee, make_string_iter());
    }

    #[test]
    fn two_tees() {
        let tee1 = Tee::new(make_string_iter());
        let tee2 = tee1.clone();
        assert_iter_eq(tee1, make_string_iter());
        assert_iter_eq(tee2, make_string_iter());
    }

    #[test]
    fn two_tees_parallel() {
        let tee1 = Tee::new(make_string_iter());
        let tee2 = tee1.clone();
        let t1 = thread::spawn(|| assert_iter_eq(tee1, make_string_iter()));
        let t2 = thread::spawn(|| assert_iter_eq(tee2, make_string_iter()));
        t1.join().unwrap();
        t2.join().unwrap();
    }

    #[test]
    fn ten_tees_parallel() {
        let tee = Tee::new(make_string_iter());
        let mut threads = vec![];
        for tee in vec![tee; 10] {
            let t = thread::spawn(|| assert_iter_eq(tee, make_string_iter()));
            threads.push(t);
        }
        for t in threads {
            t.join().unwrap();
        }
    }

    #[test]
    fn drop_in_the_middle() {
        let tee = Tee::new(make_string_iter());
        let mut threads = vec![];
        for (i, tee) in vec![tee; 10].into_iter().enumerate() {
            let t = thread::spawn(move || assert_iter_eq(tee.take(i), make_string_iter().take(i)));
            threads.push(t);
        }
        for t in threads {
            t.join().unwrap();
        }
    }

    #[test]
    fn clone_in_the_middle() {
        let mut tee1 = Tee::new(make_string_iter());
        assert_iter_eq(tee1.by_ref().take(10), make_string_iter().take(10));
        let tee2 = tee1.clone();

        assert_iter_eq(tee1, make_string_iter().skip(10));
        assert_iter_eq(tee2, make_string_iter().skip(10));
    }

    #[test]
    fn size_hint_counts_buffered_items() {
        let mut tee1 = Tee::new(make_string_iter());
        let tee2 = tee1.clone();
        for _ in 0..3 {
            tee1.next();
        }
        // tee1 only sees what is left in the source; tee2 also sees the three
        // items buffered on its behalf.
        assert_eq!(tee1.size_hint(), (1021, Some(1021)));
        assert_eq!(tee2.size_hint(), (1024, Some(1024)));
    }

    #[test]
    fn stays_exhausted() {
        let mut tee = Tee::new(0..2);
        assert_eq!(tee.next(), Some(0));
        assert_eq!(tee.next(), Some(1));
        assert_eq!(tee.next(), None);
        assert_eq!(tee.next(), None);
        let mut late_clone = tee.clone();
        assert_eq!(late_clone.next(), None);
        assert_eq!(tee.size_hint(), (0, Some(0)));
    }
}