small_vec2/
drain_filter.rs

1use std::{ptr, slice};
2
3use crate::SmallVec;
4
5/// An iterator which uses a closure to determine if an element should be removed.
6///
7/// # Example
8///
9/// ```
10/// # use small_vec2::{SmallVec, DrainFilter};
11/// let mut v: SmallVec<_, 3> = SmallVec::from(vec![0, 1, 2]);
12/// let iter: DrainFilter<_, _, 3> = v.drain_filter(|x| *x % 2 == 0);
13/// ```
14#[derive(Debug)]
15pub struct DrainFilter<'a, T, F, const N: usize>
16where
17    F: FnMut(&mut T) -> bool,
18{
19    pub(crate) vec: &'a mut SmallVec<T, N>,
20    /// The index of the item that will be inspected by the next call to `next`.
21    pub(crate) idx: usize,
22    /// The number of items that have been drained (removed) thus far.
23    pub(crate) del: usize,
24    /// The original length of `vec` prior to draining.
25    pub(crate) old_len: usize,
26    /// The filter test predicate.
27    pub(crate) pred: F,
28    /// A flag that indicates a panic has occurred in the filter test predicate.
29    /// This is used as a hint in the drop implementation to prevent consumption
30    /// of the remainder of the `DrainFilter`. Any unprocessed items will be
31    /// backshifted in the `vec`, but no further items will be dropped or
32    /// tested by the filter predicate.
33    pub(crate) panic_flag: bool,
34}
35
36impl<'a, T, F, const N: usize> DrainFilter<'a, T, F, N>
37where
38    F: FnMut(&mut T) -> bool,
39{
40    pub fn new(
41        vec: &'a mut SmallVec<T, N>,
42        idx: usize,
43        del: usize,
44        old_len: usize,
45        pred: F,
46        panic_flag: bool,
47    ) -> Self {
48        Self {
49            vec,
50            idx,
51            del,
52            old_len,
53            pred,
54            panic_flag,
55        }
56    }
57}
58
59impl<T, F, const N: usize> Iterator for DrainFilter<'_, T, F, N>
60where
61    F: FnMut(&mut T) -> bool,
62{
63    type Item = T;
64
65    fn next(&mut self) -> Option<T> {
66        unsafe {
67            while self.idx < self.old_len {
68                let i = self.idx;
69                let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len);
70                self.panic_flag = true;
71                let drained = (self.pred)(&mut v[i]);
72                self.panic_flag = false;
73                // Update the index *after* the predicate is called. If the index
74                // is updated prior and the predicate panics, the element at this
75                // index would be leaked.
76                self.idx += 1;
77                if drained {
78                    self.del += 1;
79                    return Some(ptr::read(&v[i]));
80                } else if self.del > 0 {
81                    let del = self.del;
82                    let src: *const T = &v[i];
83                    let dst: *mut T = &mut v[i - del];
84                    ptr::copy_nonoverlapping(src, dst, 1);
85                }
86            }
87            None
88        }
89    }
90
91    fn size_hint(&self) -> (usize, Option<usize>) {
92        (0, Some(self.old_len - self.idx))
93    }
94}
95
96impl<T, F, const N: usize> Drop for DrainFilter<'_, T, F, N>
97where
98    F: FnMut(&mut T) -> bool,
99{
100    fn drop(&mut self) {
101        struct BackshiftOnDrop<'a, 'b, T, F, const N: usize>
102        where
103            F: FnMut(&mut T) -> bool,
104        {
105            drain: &'b mut DrainFilter<'a, T, F, N>,
106        }
107
108        impl<'a, 'b, T, F, const N: usize> Drop for BackshiftOnDrop<'a, 'b, T, F, N>
109        where
110            F: FnMut(&mut T) -> bool,
111        {
112            fn drop(&mut self) {
113                unsafe {
114                    if self.drain.idx < self.drain.old_len && self.drain.del > 0 {
115                        // This is a pretty messed up state, and there isn't really an
116                        // obviously right thing to do. We don't want to keep trying
117                        // to execute `pred`, so we just backshift all the unprocessed
118                        // elements and tell the vec that they still exist. The backshift
119                        // is required to prevent a double-drop of the last successfully
120                        // drained item prior to a panic in the predicate.
121                        let ptr = self.drain.vec.as_mut_ptr();
122                        let src = ptr.add(self.drain.idx);
123                        let dst = src.sub(self.drain.del);
124                        let tail_len = self.drain.old_len - self.drain.idx;
125                        src.copy_to(dst, tail_len);
126                    }
127                    self.drain.vec.set_len(self.drain.old_len - self.drain.del);
128                }
129            }
130        }
131
132        let backshift = BackshiftOnDrop { drain: self };
133
134        // Attempt to consume any remaining elements if the filter predicate
135        // has not yet panicked. We'll backshift any remaining elements
136        // whether we've already panicked or if the consumption here panics.
137        if !backshift.drain.panic_flag {
138            backshift.drain.for_each(drop);
139        }
140    }
141}