smartcore/algorithm/sort/
heap_select.rs1use std::cmp::Ordering;
5use std::fmt::Debug;
6
/// Bounded selection that retains the `k` smallest elements offered to it.
///
/// Elements are fed in with [`HeapSelection::add`]. The first `k` elements are
/// buffered as-is; when the `k`-th arrives the buffer is sorted in descending
/// order, which also makes it a valid max-heap. From then on a new element is
/// kept only if it compares strictly less than the current root (the largest
/// retained element), which it replaces.
///
/// Heap index layout (as used by `sift_down`): node `0` has the single child
/// `1`, and every node `i >= 1` has children `2 * i` and `2 * i + 1`.
#[derive(Debug)]
pub struct HeapSelection<T: PartialOrd + Debug> {
    /// Maximum number of elements retained.
    k: usize,
    /// Number of calls to `add` so far, including discarded elements.
    n: usize,
    /// `true` only while `heap` is sorted in descending order, so `heap[0]`
    /// is its maximum and `peek` can skip the linear scan.
    sorted: bool,
    /// Retained elements (at most `k` of them).
    heap: Vec<T>,
}

impl<T: PartialOrd + Debug> HeapSelection<T> {
    /// Creates an empty selection that retains at most `k` elements.
    pub fn with_capacity(k: usize) -> HeapSelection<T> {
        HeapSelection {
            k,
            n: 0,
            sorted: false,
            heap: Vec::with_capacity(k),
        }
    }

    /// Offers `element` to the selection.
    ///
    /// While fewer than `k` elements have been added, `element` is always
    /// kept, and the buffer is sorted once it holds `k` elements. Afterwards
    /// `element` replaces the root only if `partial_cmp` reports it as
    /// strictly less, so an incomparable value (e.g. `NaN`) is discarded.
    /// With `k == 0` nothing is retained; only the counter advances.
    pub fn add(&mut self, element: T) {
        // Conservatively disable the sorted fast path of `peek`; `sort`
        // re-enables it when the buffer fills up.
        self.sorted = false;
        if self.n < self.k {
            self.heap.push(element);
            self.n += 1;
            if self.n == self.k {
                self.sort();
            }
        } else {
            self.n += 1;
            // `k == 0` means the heap is empty: there is no root to compare
            // against, and indexing it would panic.
            if self.k > 0 && element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) {
                self.heap[0] = element;
                self.sift_down(0, self.k - 1);
            }
        }
    }

    /// Restores the max-heap order over the whole buffer, e.g. after the root
    /// was changed through [`HeapSelection::peek_mut`].
    pub fn heapify(&mut self) {
        let len = self.heap.len();
        if len <= 1 {
            return;
        }
        let last = len - 1;
        // Node `i` has at least one child iff `2 * i <= last` (see
        // `sift_down`), so every node up to `last / 2` must be sifted,
        // deepest first, for the result to be a valid heap.
        for i in (0..=last / 2).rev() {
            self.sift_down(i, last);
        }
    }

    /// Returns a reference to the largest retained element.
    ///
    /// # Panics
    ///
    /// Panics if nothing has been retained yet, and may panic if the scan
    /// compares two elements that cannot be ordered (e.g. `NaN`).
    pub fn peek(&self) -> &T {
        if self.sorted {
            &self.heap[0]
        } else {
            self.heap
                .iter()
                .max_by(|a, b| {
                    a.partial_cmp(b)
                        .expect("HeapSelection::peek: elements are not comparable")
                })
                .expect("HeapSelection::peek called on an empty selection")
        }
    }

    /// Returns a mutable reference to the root slot `heap[0]`.
    ///
    /// Before the buffer is full this is simply the first element added; once
    /// full (and while the heap order is intact) it is the largest retained
    /// element. Changing it may break the heap order; call
    /// [`HeapSelection::heapify`] afterwards to restore it.
    ///
    /// # Panics
    ///
    /// Panics if nothing has been retained yet.
    pub fn peek_mut(&mut self) -> &mut T {
        // The caller may overwrite the root, so the buffer can no longer be
        // assumed sorted; `peek` falls back to a full scan.
        self.sorted = false;
        &mut self.heap[0]
    }

    /// Consumes the selection and returns the retained elements in their
    /// internal (heap) order.
    pub fn get(self) -> Vec<T> {
        self.heap
    }

    /// Moves `heap[k]` down until no child of it is larger.
    ///
    /// `n` is the index of the last element that takes part in the heap.
    /// Children of node `i` are `2 * i` and `2 * i + 1`; for the root this
    /// computes `0` (itself) and `1`, so the root effectively has the single
    /// child `1`.
    fn sift_down(&mut self, k: usize, n: usize) {
        let mut parent = k;
        while 2 * parent <= n {
            let mut child = 2 * parent;
            // Pick the larger of the two children when the second exists.
            if child < n
                && self.heap[child].partial_cmp(&self.heap[child + 1]) == Some(Ordering::Less)
            {
                child += 1;
            }
            // `child == parent` only at the root when its real child is not
            // larger. Stopping there also prevents an endless self-swap when
            // the root is not comparable to itself (e.g. `NaN`).
            if child == parent
                || matches!(
                    self.heap[parent].partial_cmp(&self.heap[child]),
                    Some(Ordering::Equal | Ordering::Greater)
                )
            {
                break;
            }
            self.heap.swap(parent, child);
            parent = child;
        }
    }

    /// Sorts the buffer in descending order (which is also a valid max-heap
    /// for `sift_down`) and enables the sorted fast path of `peek`.
    ///
    /// # Panics
    ///
    /// Panics if the sort compares two elements that cannot be ordered.
    fn sort(&mut self) {
        self.sorted = true;
        self.heap.sort_by(|a, b| {
            b.partial_cmp(a)
                .expect("HeapSelection::sort: elements are not comparable")
        });
    }
}
93
94#[cfg(test)]
95mod tests {
96 use super::*;
97
98 #[cfg_attr(
99 all(target_arch = "wasm32", not(target_os = "wasi")),
100 wasm_bindgen_test::wasm_bindgen_test
101 )]
102 #[test]
103 fn with_capacity() {
104 let heap = HeapSelection::<i32>::with_capacity(3);
105 assert_eq!(3, heap.k);
106 }
107
108 #[cfg_attr(
109 all(target_arch = "wasm32", not(target_os = "wasi")),
110 wasm_bindgen_test::wasm_bindgen_test
111 )]
112 #[test]
113 fn test_add() {
114 let mut heap = HeapSelection::with_capacity(3);
115 heap.add(-5);
116 assert_eq!(-5, *heap.peek());
117 heap.add(333);
118 assert_eq!(333, *heap.peek());
119 heap.add(13);
120 heap.add(10);
121 heap.add(2);
122 heap.add(0);
123 heap.add(40);
124 heap.add(30);
125 assert_eq!(8, heap.n);
126 assert_eq!(vec![2, 0, -5], heap.get());
127 }
128
129 #[cfg_attr(
130 all(target_arch = "wasm32", not(target_os = "wasi")),
131 wasm_bindgen_test::wasm_bindgen_test
132 )]
133 #[test]
134 fn test_add1() {
135 let mut heap = HeapSelection::with_capacity(3);
136 heap.add(f64::INFINITY);
137 heap.add(-5f64);
138 heap.add(4f64);
139 heap.add(-1f64);
140 heap.add(2f64);
141 heap.add(1f64);
142 heap.add(0f64);
143 assert_eq!(7, heap.n);
144 assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
145 }
146
147 #[cfg_attr(
148 all(target_arch = "wasm32", not(target_os = "wasi")),
149 wasm_bindgen_test::wasm_bindgen_test
150 )]
151 #[test]
152 fn test_add2() {
153 let mut heap = HeapSelection::with_capacity(3);
154 heap.add(f64::INFINITY);
155 heap.add(0.0);
156 heap.add(8.4852);
157 heap.add(5.6568);
158 heap.add(2.8284);
159 assert_eq!(5, heap.n);
160 assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
161 }
162
163 #[cfg_attr(
164 all(target_arch = "wasm32", not(target_os = "wasi")),
165 wasm_bindgen_test::wasm_bindgen_test
166 )]
167 #[test]
168 fn test_add_ordered() {
169 let mut heap = HeapSelection::with_capacity(3);
170 heap.add(1.);
171 heap.add(2.);
172 heap.add(3.);
173 heap.add(4.);
174 heap.add(5.);
175 heap.add(6.);
176 assert_eq!(vec![3., 2., 1.], heap.get());
177 }
178}