vec_drain_where/
lib.rs

1//! Provides an alternative implementation for `Vec::drain_filter`.
2//!
3//! Import `VecDrainWhereExt` to extend `Vec` with an
4//! `e_drain_where` method which drains all elements where
5//! a predicate indicates it. The `e_` prefix is to prevent
6//! name collision/confusion as `drain_filter` might be
//! stabilized as `drain_where`. Also, in contrast to
//! `drain_filter`, this implementation doesn't run to
//! completion when dropped, which allows stopping the draining
//! from the outside (through combinators or a `for` loop `break`),
//! and it is not prone to double panics/panics on drop.
12#[cfg(test)]
13extern crate quickcheck;
14
15use std::{isize, ptr, mem};
16
/// Extension trait adding `e_drain_where` to `Vec`.
pub trait VecDrainWhereExt<Item> {
    /// Drains all elements from the vector for which the predicate
    /// returns `true`.
    ///
    /// The elements are visited in order and the predicate is called
    /// at most once per element. Elements which are not drained stay
    /// in the vector and keep their relative order.
    ///
    /// Note that dropping the iterator early will stop the process
    /// of draining. So, for example, if you add a combinator to the
    /// drain iterator which short circuits (e.g. `any`/`all`) this
    /// will stop draining once short circuiting is hit. So use it
    /// with care.
    ///
    /// # Leak Behavior
    ///
    /// For safety reasons the length of the original vector
    /// is set to 0 while the drain iterator lives. The length is
    /// restored when the iterator is dropped, so if the iterator is
    /// leaked (e.g. through `mem::forget`) the vector stays empty
    /// and its former elements are leaked.
    ///
    /// # Panics
    ///
    /// Panics if the vector contains more than `isize::MAX` elements.
    ///
    /// # Panic/Drop Behavior
    ///
    /// When the iterator is dropped due to a panic in
    /// the predicate, the element it panicked on is leaked,
    /// but all elements which have already been decided
    /// to not be drained and those which have not yet been
    /// decided about will still be in the vector safely.
    /// I.e. if the panic also causes the vector to drop,
    /// they are dropped normally; if not, the vector can
    /// still be used normally.
    ///
    /// # Tip: non iterator short circuiting `all`/`any`
    ///
    /// Instead of `iter.any(pred)` use
    /// `iter.fold(false, |s,i| s|pred(i))`.
    ///
    /// Instead of `iter.all(pred)` use
    /// `iter.fold(true, |s,i| s&pred(i))`.
    ///
    /// And if it is fine to not call `pred` once
    /// it's found/has been shown to not hold, but it's
    /// still required to run the iterator to the end
    /// in the normal case, replace the `|` with `||`
    /// and the `&` with `&&`.
    fn e_drain_where<F>(&mut self, predicate: F)
        -> VecDrainWhere<Item, F>
        where F: FnMut(&mut Item) -> bool;
}

impl<Item> VecDrainWhereExt<Item> for Vec<Item> {
    fn e_drain_where<F>(&mut self, predicate: F)
        -> VecDrainWhere<Item, F>
        where F: FnMut(&mut Item) -> bool
    {
        let len = self.len();

        // The iterator turns element indices into `isize` pointer offsets,
        // so every index has to fit into an `isize`.
        if len > isize::MAX as usize {
            panic!("can not handle more then isize::MAX elements");
        }

        // Leak amplification for safety: while the iterator lives the vector
        // claims to be empty, so leaking the iterator can only leak elements
        // but never expose moved-out slots.
        // SAFETY: 0 is never larger than the capacity and a length of 0
        // does not require any element to be initialized.
        unsafe { self.set_len(0) }

        VecDrainWhere {
            pos: 0,
            gap_pos: 0,
            end: len,
            predicate,
            self_ref: self
        }
    }
}

/// Iterator for draining a vector conditionally.
///
/// The iterator tracks element indices into the buffer of the borrowed
/// vector and always upholds `gap_pos <= pos <= end`:
///
/// - `0..gap_pos` holds the elements which were decided to be kept,
/// - `gap_pos..pos` is the gap of slots whose values were moved out,
/// - `pos..end` holds the elements which have not been visited yet.
#[must_use]
#[derive(Debug)]
pub struct VecDrainWhere<'a, Item: 'a, Pred> {
    /// Index of the next element handed to the predicate.
    pos: usize,
    /// Index of the slot the next kept element is moved to.
    gap_pos: usize,
    /// Length the vector had when draining started.
    end: usize,
    predicate: Pred,
    /// The drained vector; its length is 0 until the iterator is dropped.
    self_ref: &'a mut Vec<Item>
}

impl<'a, I: 'a, P> Iterator for VecDrainWhere<'a, I, P>
    where P: FnMut(&mut I) -> bool
{
    type Item = I;

    /// Returns the next element the predicate selects for draining.
    ///
    /// Elements the predicate rejects on the way are moved to the front
    /// of the gap so that the kept elements stay contiguous.
    fn next(&mut self) -> Option<Self::Item> {
        let base = self.self_ref.as_mut_ptr();
        while self.pos < self.end {
            let idx = self.pos;
            // Advance before calling the predicate: if it panics the
            // element at `idx` is skipped (leaked) by `drop` instead of
            // being handed out again in a possibly broken state.
            self.pos += 1;
            // SAFETY: `idx < end`, and `end` is the original length of the
            // vector, so the slot lies inside the allocation and (being in
            // `pos..end`) still holds an initialized element. The vector is
            // exclusively borrowed, so no one else can access the buffer.
            // `gap_pos <= idx` makes the write target valid as well and the
            // copy is only done if both slots are distinct. A drained
            // element is never read again as `pos` already moved past it.
            unsafe {
                let current = base.offset(idx as isize);
                if (self.predicate)(&mut *current) {
                    return Some(ptr::read(current));
                }
                if self.gap_pos < idx {
                    ptr::copy_nonoverlapping(
                        current, base.offset(self.gap_pos as isize), 1);
                }
            }
            self.gap_pos += 1;
        }
        None
    }

    /// At most all not yet visited elements can still be drained.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(self.end - self.pos))
    }
}

impl<'a, I: 'a, P> Drop for VecDrainWhere<'a, I, P> {
    /// If the iterator was run to completion this will
    /// set the len to the new len after drop. I.e. it
    /// will undo the leak amplification.
    ///
    /// If the iterator is dropped before completion this
    /// will move the remaining elements to the (single)
    /// gap (still) left from draining elements and then
    /// sets the new length.
    ///
    /// If the iterator is dropped because the called
    /// predicate panicked, the element it panicked on
    /// is _leaked_. This is because it's simply too easy
    /// to leave the `&mut T` value in an illegal state
    /// likely to panic on drop or even behave unsafely
    /// (though it surely shouldn't behave this way).
    fn drop(&mut self) {
        let remaining = self.end - self.pos;
        // SAFETY: `gap_pos <= pos <= end` and `end` is the original length,
        // so both ranges lie inside the allocation. `pos..end` only holds
        // initialized elements which are moved (possibly overlapping, hence
        // `ptr::copy`) directly behind the kept elements in `0..gap_pos`.
        // Afterwards exactly the first `gap_pos + remaining` slots are
        // initialized.
        unsafe {
            if remaining > 0 && self.gap_pos < self.pos {
                let base = self.self_ref.as_mut_ptr();
                ptr::copy(
                    base.offset(self.pos as isize),
                    base.offset(self.gap_pos as isize),
                    remaining
                );
            }
            self.self_ref.set_len(self.gap_pos + remaining);
        }
    }
}
177
178
#[cfg(test)]
mod tests {
    use quickcheck::TestResult;
    // This import is not unused at all (the nested modules pull it in
    // through `use super::*`), so it being reported as such is a rustc
    // bug (is in the bug tracker).
    #[allow(unused_imports)]
    use super::VecDrainWhereExt;

    mod check_with_mask {
        use super::*;

        /// Drains `0..mask.len()` using `mask` as predicate result
        /// (`true` means "drain the element at this position").
        ///
        /// Checks that the predicate sees the elements in order and once
        /// each, and that the vector afterwards contains exactly the
        /// elements whose mask entry is `false`, in their original order.
        fn cmp_with_mask(mask: Vec<bool>) -> TestResult {
            let mut data = (0..mask.len()).collect::<Vec<_>>();
            let data2 = data.clone();
            // every `false` entry is an element which has to stay
            let new_len = mask.iter().filter(|&&drain| !drain).count();
            let mut mask_iter = mask.clone().into_iter();
            let mut last_el: Option<usize> = None;

            let mut failed = None;
            data.e_drain_where(|el| {
                // elements have to be passed to the predicate in order
                if let Some(lel) = last_el {
                    if lel + 1 != *el {
                        failed = Some(TestResult::error(
                            format!("unexpected element (exp {}, got {})", lel + 1, el)));
                    }
                }
                last_el = Some(*el);

                if let Some(mask) = mask_iter.next() {
                    mask
                } else {
                    failed = Some(TestResult::error("called predicate to often"));
                    false
                }
            }).for_each(drop);

            if let Some(f) = failed {
                return f;
            }

            if new_len != data.len() {
                return TestResult::error(format!(
                    "rem count: {}, found count: {} - {:?} | {:?}",
                    new_len, data.len(), data, mask
                ))
            }

            // the elements which must have been kept are those whose mask
            // entry is `false`
            let expected = data2.iter().zip(mask.iter())
                    .filter(|&(_d, drained)| !*drained)
                    .map(|(d, _p)| *d)
                    .collect::<Vec<_>>();

            if expected != data {
                return TestResult::error("unexpected data");
            }
            TestResult::passed()
        }

        #[test]
        fn qc_cmp_with_mask() {
            ::quickcheck::quickcheck(cmp_with_mask as fn(Vec<bool>) -> TestResult);
        }


        #[test]
        fn fix_divide_byte_len_by_size_of() {
            let res = cmp_with_mask(vec![false]);
            assert!(!res.is_error(), "{:?}", res)
        }

        #[test]
        fn fix_update_last_el_in_test() {
            let res = cmp_with_mask(vec![false, false, false]);
            assert!(!res.is_error(), "{:?}", res)
        }
    }

    mod check_with_panic {
        use super::*;

        /// Drains `0..mask.len()` where each mask entry is
        /// `(drain this element, panic in the predicate)`.
        ///
        /// Checks that a panic happens exactly if one was requested and
        /// that afterwards the vector has the expected length: kept
        /// elements plus everything behind the element the predicate
        /// panicked on (which itself is leaked).
        fn panic_situations(mask: Vec<(bool, bool)>) -> TestResult {
            let mut data = (0..mask.len()).collect::<Vec<_>>();
            let mut mask_iter = mask.clone().into_iter();
            let mut fail = None;
            let mut expect_panic = false;
            let expected_len = mask.iter()
                .fold(0, |sum, &(msk, pnk)| {
                    // everything after the panic stays in the vector
                    if expect_panic { sum + 1 }
                    // the element panicked on is leaked
                    else if pnk { expect_panic=true; sum }
                    // drained
                    else if msk { sum }
                    // kept
                    else { sum + 1}
                });

            let res = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
                data.e_drain_where(|_item| {
                    let (mask, do_panic) = mask_iter.next()
                        .unwrap_or_else(|| {
                            fail = Some(TestResult::error("unexpected no more masks"));
                            (false, false)
                        });

                    if do_panic {
                        panic!("-- yes panic --");
                    }
                    mask
                }).for_each(drop);
            }));

            if let Some(failure) = fail {
                return failure;
            }

            if expect_panic {
                if res.is_ok() {
                    return TestResult::error(format!(
                        "unexpectedly no panic? exp {}, len {}, ({:?})",
                        expected_len, mask.len(), mask
                    ))
                }
            } else {
                if res.is_err() {
                    return TestResult::error(format!(
                        "unexpectedly error? exp {}, len {}, ({:?})",
                        expected_len, mask.len(), mask
                    ))
                }
            }

            if data.len() != expected_len {
                return TestResult::error(format!(
                    "unexpected resulting len {}, exp {} ({:?} - {:?})",
                    data.len(), expected_len, data, mask
                ));
            }

            TestResult::passed()
        }


        #[test]
        fn qc_panic_test() {
            ::quickcheck::quickcheck(panic_situations as fn(Vec<(bool,bool)>) -> TestResult)
        }

        #[test]
        fn fix_messed_up_test() {
            let res = panic_situations(vec![(true, false)]);
            assert!(!res.is_error(), "{:?}", res);
        }
    }

}