1#![feature(is_sorted)]
2
/// Index of the parent of the node stored at `$index` in an array-backed binary heap.
///
/// `$index` must be non-zero: the root has no parent, and `0 - 1` underflows an
/// unsigned index.
macro_rules! get_parent {
    ($index: expr) => {{
        let child = $index;
        (child - 1) / 2
    }};
}
8
/// Ordering predicate used by [`Heap`]: `comparator(parent, child)` returns `true`
/// when `child` must be moved above `parent`.
pub type HeapComparator<T> = fn(&T, &T) -> bool;
10
/// Selects which element a [`Heap`] keeps at its root.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HeapType {
    /// The smallest element is kept at the root.
    Min,
    /// The largest element is kept at the root.
    Max,
}
16
17#[derive(Debug)]
18pub struct Heap<T> {
19 data: Vec<T>,
20 comparator: HeapComparator<T>,
21}
22
23impl<T> Heap<T> where T: PartialOrd + std::fmt::Display {
24 pub fn new(heap_type: HeapType) -> Self {
25 Self {
26 data: Vec::new(),
27 comparator: Heap::get_comparator(heap_type),
28 }
29 }
30
31 #[deprecated]
32 pub fn from_array(heap_type: HeapType, input: Vec<T>) -> Self {
33 Heap::from_vec(heap_type, input)
34 }
35
36 pub fn from_vec(heap_type: HeapType, input: Vec<T>) -> Self {
37 let mut heap = Self {
38 data: input,
39 comparator: Self::get_comparator(heap_type),
40 };
41
42 let size = heap.data.len();
43 if size == 0 {
44 return heap;
45 }
46
47 let mut i: usize = (size / 2) - 1;
48 loop {
49 let mut do_more = heap.heapify(i);
50 while do_more.is_some() {
51 do_more = heap.heapify(do_more.unwrap());
52 }
53 if i != 0 {
54 i = i - 1;
55 } else {
56 break;
57 }
58 }
59
60 heap
61 }
62
63 fn float_down(&mut self, index: usize) {
64 let mut do_more = self.heapify(index);
65 while do_more.is_some() {
66 do_more = self.heapify(do_more.unwrap());
67 }
68 }
69
70 fn float_up(&mut self, index: usize) {
71 let mut _index = get_parent!(index);
72 let mut do_more = self.heapify(_index);
73
74 while _index != 0 && do_more.is_some() {
75 _index = get_parent!(_index);
76 do_more = self.heapify(_index);
77 }
78 }
79
80 pub fn insert(&mut self, value: T) {
81 self.data.push(value);
82 let new_size = self.data.len();
83 if new_size > 1 {
84 self.float_up(new_size - 1);
85 }
86 }
87
88 pub fn root(&self) -> Option<&T> {
89 self.data.get(0)
90 }
91
92 pub fn extract(&mut self) -> Option<T> {
93 if self.data.is_empty() {
94 None
95 } else if self.data.len() == 1 {
96 Some(self.data.remove(0))
97 } else {
98 let last_index = self.data.len() - 1;
99 self.data.swap(0, last_index);
100
101 let result = self.data.remove(last_index);
102
103 self.float_down(0);
104
105 Some(result)
106 }
107 }
108
109 pub fn get(&self, index: usize) -> Option<&T> {
110 self.data.get(index)
111 }
112
113 pub fn raw(&self) -> &Vec<T> {
114 &self.data
115 }
116
117 pub fn collect(&mut self) -> Vec<T> {
118 let mut output: Vec<T> = Vec::new();
119 for _ in 0..self.data.len() {
120 output.push(self.extract().unwrap());
121 }
122 output
123 }
124
125 fn get_comparator(heap_type: HeapType) -> HeapComparator<T> {
126 match heap_type {
127 HeapType::Min => |a: &T, b: &T| {
128 a > b
129 },
130 HeapType::Max => |a: &T, b: &T| -> bool {
131 a < b
132 }
133 }
134 }
135
136 fn heapify(&mut self, index: usize) -> Option<usize> {
137 let mut affected_index: usize = index; let l_index = 2 * index + 1; let r_index = 2 * index + 2; let data = &mut self.data; let comparator = self.comparator;
144
145 if !data.get(l_index).is_none() && comparator(&data[affected_index], &data[l_index]) {
147 affected_index = l_index;
148 }
149
150 if !data.get(r_index).is_none() && comparator(&data[affected_index], &data[r_index]) {
151 affected_index = r_index;
152 }
153
154 if affected_index != index {
155 data.swap(index, affected_index);
156 Some(affected_index)
157 } else {
158 None
159 }
160 }
161}
162
#[cfg(test)]
mod test {
    use rand::seq::SliceRandom;
    use rand::thread_rng;

    use crate::{Heap, HeapType};

    /// Returns the integers `0..size` in a random order.
    fn random_vec(size: u32) -> Vec<u32> {
        let mut result: Vec<u32> = (0..size).collect();
        result.shuffle(&mut thread_rng());
        result
    }

    #[test]
    fn test_heapify() {
        let mut heap = Heap::<u32>::new(HeapType::Max);

        let input = vec![1, 3, 2];
        let expected = [3, 1, 2];

        heap.data = input;
        // One sift-down step at the root promotes the larger child (index 1).
        assert_eq!(heap.heapify(0), Some(1));

        assert_eq!(*heap.raw(), expected);
    }

    #[test]
    fn test_insert_min() {
        let mut heap = Heap::<u32>::new(HeapType::Min);

        heap.insert(10);
        assert_eq!(*heap.raw(), vec![10]);

        heap.insert(11);
        assert_eq!(*heap.raw(), vec![10, 11]);

        heap.insert(9);
        assert_eq!(*heap.raw(), vec![9, 11, 10]);

        heap.insert(5);
        assert_eq!(*heap.raw(), vec![5, 9, 10, 11]);

        heap.insert(6);
        assert_eq!(*heap.raw(), vec![5, 6, 10, 11, 9]);

        let mut expected = vec![5, 6, 10, 11, 9];
        expected.sort();
        assert_eq!(expected, heap.collect());
        // `collect` drains the heap.
        assert!(heap.raw().is_empty());
    }

    #[test]
    fn test_insert_max() {
        let mut heap = Heap::<u32>::new(HeapType::Max);

        heap.insert(10);
        assert_eq!(*heap.raw(), vec![10]);

        heap.insert(11);
        assert_eq!(*heap.raw(), vec![11, 10]);

        heap.insert(9);
        assert_eq!(*heap.raw(), vec![11, 10, 9]);

        heap.insert(5);
        assert_eq!(*heap.raw(), vec![11, 10, 9, 5]);

        heap.insert(6);
        assert_eq!(*heap.raw(), vec![11, 10, 9, 5, 6]);

        let mut expected = vec![5, 6, 10, 11, 9];
        expected.sort();
        expected.reverse();
        assert_eq!(expected, heap.collect());
        assert!(heap.raw().is_empty());
    }

    #[test]
    fn test_extract_min() {
        let mut heap = Heap::<u32>::new(HeapType::Min);
        let mut input: [u32; 4] = [10, 11, 9, 5];

        for value in input {
            heap.insert(value);
        }

        input.sort();
        for expected in input {
            assert_eq!(Some(expected), heap.extract());
        }
        // Once drained, extraction reports emptiness instead of panicking.
        assert_eq!(None, heap.extract());
    }

    #[test]
    fn test_extract_max() {
        let mut heap = Heap::<u32>::new(HeapType::Max);
        let mut input: [u32; 4] = [10, 11, 9, 5];

        for value in input {
            heap.insert(value);
        }

        input.sort();
        input.reverse();
        for expected in input {
            assert_eq!(Some(expected), heap.extract());
        }
        assert_eq!(None, heap.extract());
    }

    #[test]
    fn test_from_vec_min() {
        let input: Vec<u32> = random_vec(100);
        let mut heap = Heap::<u32>::from_vec(HeapType::Min, input.clone());
        assert_eq!(heap.root(), Some(&0));

        let mut expected = input;
        expected.sort();

        assert_eq!(expected, heap.collect());
    }

    #[test]
    fn test_from_vec_edge() {
        let mut heap = Heap::<u32>::from_vec(HeapType::Max, Vec::new());
        assert!(heap.root().is_none());
        assert_eq!(None, heap.extract());
        assert!(heap.collect().is_empty());
    }

    #[test]
    fn test_from_vec_max() {
        let input: Vec<u32> = random_vec(100);
        let mut heap = Heap::<u32>::from_vec(HeapType::Max, input.clone());
        assert_eq!(heap.root(), Some(&99));

        let mut expected = input;
        expected.sort();
        expected.reverse();

        assert_eq!(expected, heap.collect());
    }
}