smartcore/algorithm/sort/
heap_select.rs

//! # Heap Selection Algorithm
//!
//! The goal is to find the k smallest elements in a list or array.
4use std::cmp::Ordering;
5use std::fmt::Debug;
6
/// Bounded selection structure that retains the `k` smallest elements offered to it.
///
/// Once `k` elements have been added, `heap` is kept as a max-heap so that the
/// largest retained element sits in slot 0 and can be evicted cheaply.
///
/// The heap uses an unusual slot layout (see `sift_down`): slot 0 is the root and
/// has a single child, slot 1; every slot `i >= 1` has children `2 * i` and
/// `2 * i + 1`.
#[derive(Debug)]
pub struct HeapSelection<T: PartialOrd + Debug> {
    /// Maximum number of elements retained.
    k: usize,
    /// Total number of elements offered through `add` so far (retained or not).
    n: usize,
    /// `true` only while `heap` is known to be sorted in descending order,
    /// i.e. right after the `k`-th element was added and before any other change.
    sorted: bool,
    /// The retained elements.
    heap: Vec<T>,
}

impl<T: PartialOrd + Debug> HeapSelection<T> {
    /// Creates an empty selection that retains at most `k` elements.
    pub fn with_capacity(k: usize) -> HeapSelection<T> {
        HeapSelection {
            k,
            n: 0,
            sorted: false,
            heap: Vec::new(),
        }
    }

    /// Offers `element` to the selection.
    ///
    /// The first `k` elements are always retained. Afterwards an element is
    /// retained only if it is strictly smaller than the largest retained
    /// element, which it then replaces. With `k == 0` nothing is ever retained,
    /// but the element is still counted.
    ///
    /// # Panics
    ///
    /// Panics when the `k`-th element is added and two retained elements are
    /// not comparable (`partial_cmp` returns `None`, e.g. a NaN float).
    pub fn add(&mut self, element: T) {
        self.sorted = false;
        if self.k == 0 {
            // Nothing can be retained; without this guard `self.heap[0]` below
            // would index an empty vector.
            self.n += 1;
            return;
        }
        if self.n < self.k {
            self.heap.push(element);
            self.n += 1;
            if self.n == self.k {
                // A descending sort satisfies the heap invariant for every slot.
                self.sort();
            }
        } else {
            self.n += 1;
            if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) {
                self.heap[0] = element;
                self.sift_down(0, self.k - 1);
            }
        }
    }

    /// Restores the max-heap invariant over all currently retained elements.
    ///
    /// Intended to be called after the root was modified through `peek_mut`,
    /// but it rebuilds the heap from any element order.
    pub fn heapify(&mut self) {
        let n = self.heap.len();
        if n <= 1 {
            return;
        }
        // With children of slot `i >= 1` at `2 * i` and `2 * i + 1`, the last
        // slot that has a child is `(n - 1) / 2`; starting at `n / 2 - 1` would
        // skip that parent whenever `n` is odd.
        for i in (0..=(n - 1) / 2).rev() {
            self.sift_down(i, n - 1);
        }
    }

    /// Returns a reference to the largest retained element.
    ///
    /// # Panics
    ///
    /// Panics if no element is retained, or if two retained elements are not
    /// comparable (`partial_cmp` returns `None`).
    pub fn peek(&self) -> &T {
        if self.sorted {
            &self.heap[0]
        } else {
            self.heap
                .iter()
                .max_by(|a, b| {
                    a.partial_cmp(b)
                        .expect("HeapSelection::peek: elements are not comparable")
                })
                .expect("HeapSelection::peek: no element is retained")
        }
    }

    /// Returns a mutable reference to the element stored in slot 0.
    ///
    /// Slot 0 holds the largest retained element only while the heap invariant
    /// holds. After changing the value, call `heapify` to restore it.
    ///
    /// # Panics
    ///
    /// Panics if no element is retained.
    pub fn peek_mut(&mut self) -> &mut T {
        // The caller may overwrite the root, so the "sorted descending"
        // shortcut used by `peek` can no longer be trusted.
        self.sorted = false;
        &mut self.heap[0]
    }

    /// Consumes the selection and returns the retained elements in their
    /// internal storage order.
    pub fn get(self) -> Vec<T> {
        self.heap
    }

    /// Moves the element in slot `k` down until neither child is larger.
    ///
    /// `n` is the index of the last slot taking part (inclusive). Slot 0 has
    /// the single child 1; slot `i >= 1` has children `2 * i` and `2 * i + 1`.
    fn sift_down(&mut self, k: usize, n: usize) {
        let mut kk = k;
        while 2 * kk <= n {
            // Pick the larger child of `kk`.
            let mut j = 2 * kk;
            if j < n && self.heap[j].partial_cmp(&self.heap[j + 1]) == Some(Ordering::Less) {
                j += 1;
            }
            // `j == kk` happens only at slot 0 when it is not smaller than its
            // child: nothing to do. Stopping here also guarantees termination
            // when slot 0 is not comparable with itself (e.g. NaN).
            if j == kk {
                break;
            }
            if matches!(
                self.heap[kk].partial_cmp(&self.heap[j]),
                Some(Ordering::Equal) | Some(Ordering::Greater)
            ) {
                break;
            }
            self.heap.swap(kk, j);
            kk = j;
        }
    }

    /// Sorts the retained elements in descending order and records that fact.
    ///
    /// Panics if two elements are not comparable.
    fn sort(&mut self) {
        self.sorted = true;
        self.heap.sort_by(|a, b| b.partial_cmp(a).unwrap());
    }
}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// The requested capacity is stored as `k`.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn with_capacity() {
        let selection: HeapSelection<i32> = HeapSelection::with_capacity(3);
        assert_eq!(3, selection.k);
    }

    /// Integers: `peek` follows the maximum while filling, and only the three
    /// smallest values survive.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn test_add() {
        let mut selection = HeapSelection::with_capacity(3);
        selection.add(-5);
        assert_eq!(-5, *selection.peek());
        selection.add(333);
        assert_eq!(333, *selection.peek());
        for value in [13, 10, 2, 0, 40, 30] {
            selection.add(value);
        }
        assert_eq!(8, selection.n);
        assert_eq!(vec![2, 0, -5], selection.get());
    }

    /// Floats starting with infinity, which must be evicted.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn test_add1() {
        let mut selection = HeapSelection::with_capacity(3);
        for value in [f64::INFINITY, -5f64, 4f64, -1f64, 2f64, 1f64, 0f64] {
            selection.add(value);
        }
        assert_eq!(7, selection.n);
        assert_eq!(vec![0f64, -1f64, -5f64], selection.get());
    }

    /// Floats with fractional values.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn test_add2() {
        let mut selection = HeapSelection::with_capacity(3);
        for value in [f64::INFINITY, 0.0, 8.4852, 5.6568, 2.8284] {
            selection.add(value);
        }
        assert_eq!(5, selection.n);
        assert_eq!(vec![5.6568, 2.8284, 0.0], selection.get());
    }

    /// Ascending input: nothing after the first three values is retained.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn test_add_ordered() {
        let mut selection = HeapSelection::with_capacity(3);
        for value in [1., 2., 3., 4., 5., 6.] {
            selection.add(value);
        }
        assert_eq!(vec![3., 2., 1.], selection.get());
    }
}