median/
heap.rs

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

//! An implementation of a heap-allocated, efficient O(n) median filter.
6
7use std::fmt;
8
/// A slot of the filter's buffer that doubles as a node of the
/// doubly-linked list embedded in that buffer.
///
/// Links are stored as indices into the buffer rather than as pointers.
#[derive(Clone, PartialEq, Eq)]
struct ListNode<T> {
    /// The stored sample; `None` while the slot holds no sample.
    value: Option<T>,
    /// Buffer index of the node preceding this one in the list.
    previous: usize,
    /// Buffer index of the node following this one in the list.
    next: usize,
}
15
16impl<T> fmt::Debug for ListNode<T>
17where
18    T: fmt::Debug,
19{
20    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
21        write!(f, "@{:?}-{:?}-@{:?}", self.previous, self.value, self.next)
22    }
23}
24
/// An implementation of a median filter with linear complexity.
///
/// While the common naïve implementation of a median filter
/// has a worst-case complexity of `O(n^2)` (due to having to sort the sliding window)
/// the use of a combination of linked list and ring buffer allows for
/// a worst-case complexity of `O(n)`.
#[derive(Clone, Debug)]
pub struct Filter<T> {
    // Ring buffer of samples; each slot is also a node of the embedded
    // linked list, which keeps the samples in sorted order:
    buffer: Vec<ListNode<T>>,
    // Index of the ring-buffer slot that the next consumed value overwrites:
    cursor: usize,
    // Index of the first (smallest) node of the circular sorted list:
    head: usize,
    // Index of the node currently holding the median of the sorted list:
    median: usize,
}
42
43impl<T> Filter<T>
44where
45    T: Clone + PartialOrd,
46{
47    /// Creates a new median filter with a given window size.
48    pub fn new(size: usize) -> Self {
49        let mut buffer = Vec::with_capacity(size);
50        for i in 0..size {
51            buffer.push(ListNode {
52                value: None,
53                previous: (i + size - 1) % size,
54                next: (i + 1) % size,
55            });
56        }
57        Filter {
58            buffer,
59            cursor: 0,
60            head: 0,
61            median: 0,
62        }
63    }
64
65    /// Returns the window size of the filter.
66    #[inline]
67    pub fn len(&self) -> usize {
68        self.buffer.len()
69    }
70
71    /// Returns `true` if the filter has a length of `0`.
72    #[inline]
73    pub fn is_empty(&self) -> usize {
74        self.len()
75    }
76
77    /// Returns the filter buffer's current median value, panicking if empty.
78    #[inline]
79    pub fn median(&self) -> T {
80        assert!(!self.buffer.is_empty());
81
82        unsafe { self.read_median() }
83    }
84
85    /// Returns the filter buffer's current min value, panicking if empty.
86    #[inline]
87    pub fn min(&self) -> T {
88        assert!(!self.buffer.is_empty());
89
90        unsafe { self.read_min() }
91    }
92
93    /// Returns the filter buffer's current max value, panicking if empty.
94    #[inline]
95    pub fn max(&self) -> T {
96        assert!(!self.buffer.is_empty());
97
98        unsafe { self.read_max() }
99    }
100
101    /// Applies a median filter to the consumed value.
102    ///
103    /// # Implementation
104    ///
105    /// The algorithm makes use of a ring buffer of the same size as its filter window.
106    /// Inserting values into the ring buffer appends them to a linked list that is *embedded*
107    /// inside said ring buffer (using relative integer jump offsets as links).
108    ///
109    /// # Example
110    ///
111    /// Given a sequence of values `[3, 2, 4, 6, 5, 1]` and a buffer of size 5,
112    /// the buffer would be filled like this:
113    ///
114    /// ```plain
115    /// new(5)  consume(3)  consume(2)  consume(4)  consume(6)  consume(5)  consume(1)
116    /// ▶︎[ ]      ▷[3]       ┌→[3]       ┌→[3]─┐     ┌→[3]─┐    ▶︎┌→[3]─┐      ▷[1]─┐
117    ///  [ ]      ▶︎[ ]      ▷└─[2]      ▷└─[2] │    ▷└─[2] │    ▷└─[2] │    ▶︎┌─[2]←┘
118    ///  [ ]       [ ]        ▶︎[ ]         [4]←┘     ┌─[4]←┘     ┌─[4]←┘     └→[4]─┐
119    ///  [ ]       [ ]         [ ]        ▶︎[ ]       └→[6]       │ [6]←┐     ┌→[6] │
120    ///  [ ]       [ ]         [ ]         [ ]        ▶︎[ ]       └→[5]─┘     └─[5]←┘
121    /// ```
122    ///
123    /// # Algorithm
124    ///
125    /// 1. **Remove node** at current cursor (`▶︎`) from linked list, if it exists.
126    ///    (by re-wiring its predecessor to its successor).
127    /// 2. **Initialize** `current` and `median` index to first node of linked list (`▷`).
128    /// 3. **Walk through** linked list, **searching** for insertion point.
129    /// 4. **Shift median index** on every other hop (thus ending up in the list's median).
130    /// 5. **Insert value** into ring buffer and linked list respectively.
131    /// 6. **Update index** to linked list's first node, if necessary.
132    /// 7. **Update ring buffer**'s cursor.
133    /// 8. **Return median value**.
134    ///
135    /// (_Based on Phil Ekstrom, Embedded Systems Programming, November 2000._)
136
137    pub fn consume(&mut self, value: T) -> T {
138        // If the current head is about to be overwritten
139        // we need to make sure to have the head point to
140        // the next node after the current head:
141        unsafe {
142            self.move_head_forward();
143        }
144
145        // Remove the node that is about to be overwritten
146        // from the linked list:
147        unsafe {
148            self.remove_node();
149        }
150
151        // Initialize `self.median` pointing
152        // to the first (smallest) node in the sorted list:
153        unsafe {
154            self.initialize_median();
155        }
156
157        // Search for the insertion index in the linked list
158        // in regards to `value` as the insertion index.
159        unsafe {
160            self.insert_value(&value);
161        }
162
163        // Update head to newly inserted node if
164        // cursor's value <= head's value or head is empty:
165        unsafe {
166            self.update_head(&value);
167        }
168
169        // If the filter has an even window size, then shift the median
170        // back one slot, so that it points to the left one
171        // of the middle pair of median values
172        unsafe {
173            self.adjust_median_for_even_length();
174        }
175
176        // Increment and wrap data in pointer:
177        unsafe {
178            self.increment_cursor();
179        }
180
181        // Read node value from buffer at `self.medium`:
182        unsafe { self.read_median() }
183    }
184
185    #[inline]
186    fn should_insert(&self, value: &T, current: usize, index: usize) -> bool {
187        if let Some(ref v) = self.buffer[current].value {
188            (index + 1 == self.len()) || (v >= value)
189        } else {
190            true
191        }
192    }
193
194    #[inline]
195    unsafe fn move_head_forward(&mut self) {
196        if self.cursor == self.head {
197            self.head = self.buffer[self.head].next;
198        }
199    }
200
201    #[inline]
202    unsafe fn remove_node(&mut self) {
203        let (predecessor, successor) = {
204            let node = &self.buffer[self.cursor];
205            (node.previous, node.next)
206        };
207        self.buffer[predecessor].next = successor;
208        self.buffer[self.cursor] = ListNode {
209            previous: usize::max_value(),
210            value: None,
211            next: usize::max_value(),
212        };
213        self.buffer[successor].previous = predecessor;
214    }
215
216    #[inline]
217    unsafe fn initialize_median(&mut self) {
218        self.median = self.head;
219    }
220
221    #[inline]
222    unsafe fn insert_value(&mut self, value: &T) {
223        let mut current = self.head;
224        let buffer_len = self.len();
225        let mut has_inserted = false;
226        for index in 0..buffer_len {
227            if !has_inserted {
228                let should_insert = self.should_insert(value, current, index);
229                if should_insert {
230                    // Insert previously removed node with new value
231                    // into linked list at given insertion index.
232                    self.insert(value, current);
233                    has_inserted = true;
234                }
235            }
236
237            // Shift median on every other element in the list,
238            // so that it ends up in the middle, eventually:
239            self.shift_median(index, current);
240
241            current = self.buffer[current].next;
242        }
243    }
244
245    #[inline]
246    unsafe fn insert(&mut self, value: &T, current: usize) {
247        let successor = current;
248        let predecessor = self.buffer[current].previous;
249        debug_assert!(self.buffer.len() == 1 || current != self.cursor);
250        self.buffer[predecessor].next = self.cursor;
251        self.buffer[self.cursor] = ListNode {
252            previous: predecessor,
253            value: Some(value.clone()),
254            next: successor,
255        };
256        self.buffer[successor].previous = self.cursor;
257    }
258
259    #[inline]
260    unsafe fn shift_median(&mut self, index: usize, current: usize) {
261        if (index & 0b1 == 0b1) && (self.buffer[current].value.is_some()) {
262            self.median = self.buffer[self.median].next;
263        }
264    }
265
266    #[inline]
267    unsafe fn update_head(&mut self, value: &T) {
268        let should_update_head = if let Some(ref head) = self.buffer[self.head].value {
269            value <= head
270        } else {
271            true
272        };
273
274        if should_update_head {
275            self.head = self.cursor;
276            self.median = self.buffer[self.median].previous;
277        }
278    }
279
280    #[inline]
281    unsafe fn adjust_median_for_even_length(&mut self) {
282        if self.len() % 2 == 0 {
283            self.median = self.buffer[self.median].previous;
284        }
285    }
286
287    #[inline]
288    unsafe fn increment_cursor(&mut self) {
289        self.cursor = (self.cursor + 1) % (self.len());
290    }
291
292    #[inline]
293    unsafe fn read_median(&self) -> T {
294        let index = self.median;
295        self.buffer[index].value.clone().unwrap()
296    }
297
298    #[inline]
299    unsafe fn read_min(&self) -> T {
300        let index = self.head;
301        self.buffer[index].value.clone().unwrap()
302    }
303
304    #[inline]
305    unsafe fn read_max(&self) -> T {
306        let index = (self.cursor + self.len() - 1) % (self.len());
307        self.buffer[index].value.clone().unwrap()
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `input` through a fresh filter with window size `size` and
    /// checks that the sequence of returned medians equals `expected`.
    fn check(size: usize, input: &[i32], expected: &[i32]) {
        let mut filter = Filter::new(size);
        let actual: Vec<i32> = input
            .iter()
            .map(|&sample| filter.consume(sample))
            .collect();
        assert_eq!(actual.as_slice(), expected);
    }

    #[test]
    fn single_peak_4() {
        check(
            4,
            &[10, 20, 30, 100, 30, 20, 10],
            &[10, 10, 20, 20, 30, 30, 20],
        );
    }

    #[test]
    fn single_peak_5() {
        check(
            5,
            &[10, 20, 30, 100, 30, 20, 10],
            &[10, 10, 20, 20, 30, 30, 30],
        );
    }

    #[test]
    fn single_valley_4() {
        check(
            4,
            &[90, 80, 70, 10, 70, 80, 90],
            &[90, 80, 80, 70, 70, 70, 70],
        );
    }

    #[test]
    fn single_valley_5() {
        check(
            5,
            &[90, 80, 70, 10, 70, 80, 90],
            &[90, 80, 80, 70, 70, 70, 70],
        );
    }

    #[test]
    fn single_outlier_4() {
        check(
            4,
            &[10, 10, 10, 100, 10, 10, 10],
            &[10, 10, 10, 10, 10, 10, 10],
        );
    }

    #[test]
    fn single_outlier_5() {
        check(
            5,
            &[10, 10, 10, 100, 10, 10, 10],
            &[10, 10, 10, 10, 10, 10, 10],
        );
    }

    #[test]
    fn triple_outlier_4() {
        check(
            4,
            &[10, 10, 100, 100, 100, 10, 10],
            &[10, 10, 10, 10, 100, 100, 10],
        );
    }

    #[test]
    fn triple_outlier_5() {
        check(
            5,
            &[10, 10, 100, 100, 100, 10, 10],
            &[10, 10, 10, 10, 100, 100, 100],
        );
    }

    #[test]
    fn quintuple_outlier_4() {
        check(
            4,
            &[10, 100, 100, 100, 100, 100, 10],
            &[10, 10, 100, 100, 100, 100, 100],
        );
    }

    #[test]
    fn quintuple_outlier_5() {
        check(
            5,
            &[10, 100, 100, 100, 100, 100, 10],
            &[10, 10, 100, 100, 100, 100, 100],
        );
    }

    #[test]
    fn alternating_4() {
        check(
            4,
            &[10, 20, 10, 20, 10, 20, 10],
            &[10, 10, 10, 10, 10, 10, 10],
        );
    }

    #[test]
    fn alternating_5() {
        check(
            5,
            &[10, 20, 10, 20, 10, 20, 10],
            &[10, 10, 10, 10, 10, 20, 10],
        );
    }

    #[test]
    fn ascending_4() {
        check(
            4,
            &[10, 20, 30, 40, 50, 60, 70],
            &[10, 10, 20, 20, 30, 40, 50],
        );
    }

    #[test]
    fn ascending_5() {
        check(
            5,
            &[10, 20, 30, 40, 50, 60, 70],
            &[10, 10, 20, 20, 30, 40, 50],
        );
    }

    #[test]
    fn descending_4() {
        check(
            4,
            &[70, 60, 50, 40, 30, 20, 10],
            &[70, 60, 60, 50, 40, 30, 20],
        );
    }

    #[test]
    fn descending_5() {
        check(
            5,
            &[70, 60, 50, 40, 30, 20, 10],
            &[70, 60, 60, 50, 50, 40, 30],
        );
    }

    #[test]
    fn min_max_median() {
        let mut filter = Filter::new(5);
        for &sample in &[70, 50, 30, 10, 20, 40, 60] {
            filter.consume(sample);
        }
        assert_eq!(filter.min(), 10);
        assert_eq!(filter.max(), 60);
        assert_eq!(filter.median(), 30);
    }
}