small_vec2/drain_filter.rs
1use std::{ptr, slice};
2
3use crate::SmallVec;
4
5/// An iterator which uses a closure to determine if an element should be removed.
6///
7/// # Example
8///
9/// ```
10/// # use small_vec2::{SmallVec, DrainFilter};
11/// let mut v: SmallVec<_, 3> = SmallVec::from(vec![0, 1, 2]);
12/// let iter: DrainFilter<_, _, 3> = v.drain_filter(|x| *x % 2 == 0);
13/// ```
14#[derive(Debug)]
15pub struct DrainFilter<'a, T, F, const N: usize>
16where
17 F: FnMut(&mut T) -> bool,
18{
19 pub(crate) vec: &'a mut SmallVec<T, N>,
20 /// The index of the item that will be inspected by the next call to `next`.
21 pub(crate) idx: usize,
22 /// The number of items that have been drained (removed) thus far.
23 pub(crate) del: usize,
24 /// The original length of `vec` prior to draining.
25 pub(crate) old_len: usize,
26 /// The filter test predicate.
27 pub(crate) pred: F,
28 /// A flag that indicates a panic has occurred in the filter test predicate.
29 /// This is used as a hint in the drop implementation to prevent consumption
30 /// of the remainder of the `DrainFilter`. Any unprocessed items will be
31 /// backshifted in the `vec`, but no further items will be dropped or
32 /// tested by the filter predicate.
33 pub(crate) panic_flag: bool,
34}
35
36impl<'a, T, F, const N: usize> DrainFilter<'a, T, F, N>
37where
38 F: FnMut(&mut T) -> bool,
39{
40 pub fn new(
41 vec: &'a mut SmallVec<T, N>,
42 idx: usize,
43 del: usize,
44 old_len: usize,
45 pred: F,
46 panic_flag: bool,
47 ) -> Self {
48 Self {
49 vec,
50 idx,
51 del,
52 old_len,
53 pred,
54 panic_flag,
55 }
56 }
57}
58
59impl<T, F, const N: usize> Iterator for DrainFilter<'_, T, F, N>
60where
61 F: FnMut(&mut T) -> bool,
62{
63 type Item = T;
64
65 fn next(&mut self) -> Option<T> {
66 unsafe {
67 while self.idx < self.old_len {
68 let i = self.idx;
69 let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len);
70 self.panic_flag = true;
71 let drained = (self.pred)(&mut v[i]);
72 self.panic_flag = false;
73 // Update the index *after* the predicate is called. If the index
74 // is updated prior and the predicate panics, the element at this
75 // index would be leaked.
76 self.idx += 1;
77 if drained {
78 self.del += 1;
79 return Some(ptr::read(&v[i]));
80 } else if self.del > 0 {
81 let del = self.del;
82 let src: *const T = &v[i];
83 let dst: *mut T = &mut v[i - del];
84 ptr::copy_nonoverlapping(src, dst, 1);
85 }
86 }
87 None
88 }
89 }
90
91 fn size_hint(&self) -> (usize, Option<usize>) {
92 (0, Some(self.old_len - self.idx))
93 }
94}
95
96impl<T, F, const N: usize> Drop for DrainFilter<'_, T, F, N>
97where
98 F: FnMut(&mut T) -> bool,
99{
100 fn drop(&mut self) {
101 struct BackshiftOnDrop<'a, 'b, T, F, const N: usize>
102 where
103 F: FnMut(&mut T) -> bool,
104 {
105 drain: &'b mut DrainFilter<'a, T, F, N>,
106 }
107
108 impl<'a, 'b, T, F, const N: usize> Drop for BackshiftOnDrop<'a, 'b, T, F, N>
109 where
110 F: FnMut(&mut T) -> bool,
111 {
112 fn drop(&mut self) {
113 unsafe {
114 if self.drain.idx < self.drain.old_len && self.drain.del > 0 {
115 // This is a pretty messed up state, and there isn't really an
116 // obviously right thing to do. We don't want to keep trying
117 // to execute `pred`, so we just backshift all the unprocessed
118 // elements and tell the vec that they still exist. The backshift
119 // is required to prevent a double-drop of the last successfully
120 // drained item prior to a panic in the predicate.
121 let ptr = self.drain.vec.as_mut_ptr();
122 let src = ptr.add(self.drain.idx);
123 let dst = src.sub(self.drain.del);
124 let tail_len = self.drain.old_len - self.drain.idx;
125 src.copy_to(dst, tail_len);
126 }
127 self.drain.vec.set_len(self.drain.old_len - self.drain.del);
128 }
129 }
130 }
131
132 let backshift = BackshiftOnDrop { drain: self };
133
134 // Attempt to consume any remaining elements if the filter predicate
135 // has not yet panicked. We'll backshift any remaining elements
136 // whether we've already panicked or if the consumption here panics.
137 if !backshift.drain.panic_flag {
138 backshift.drain.for_each(drop);
139 }
140 }
141}