csheap/
lib.rs

//! A binary heap stored in a `Vec`, usable as either a min-heap or a max-heap.
2
/// Expands to the parent index of the node stored at `$index` in the heap's
/// backing array, where the children of node `p` sit at `2 * p + 1` and `2 * p + 2`.
///
/// The argument is evaluated exactly once. For unsigned indices, passing `0`
/// underflows (`0 - 1`), so only call this for non-root nodes.
macro_rules! get_parent {
    ($index:expr) => {{
        let child = $index;
        (child - 1) / 2
    }};
}
8
/// Ordering predicate of a heap, invoked as `comparator(parent, child)`:
/// it returns `true` when `child` belongs above `parent`, i.e. the two must swap.
pub type HeapComparator<T> = fn(&T, &T) -> bool;
10
/// Ordering of a [`Heap`]: which element is kept at the root.
///
/// `Debug`/`Clone`/`Copy`/`Eq` are derived so the selector can be printed,
/// passed around by value and used as a key like any other small enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HeapType {
    /// The smallest element sits at the root; extraction yields ascending order.
    Min,
    /// The largest element sits at the root; extraction yields descending order.
    Max,
}
16
17#[derive(Debug)]
18pub struct Heap<T> {
19    data: Vec<T>,
20    comparator: HeapComparator<T>,
21}
22
23impl<T> Heap<T> where T: PartialOrd + std::fmt::Display {
24    pub fn new(heap_type: HeapType) -> Self {
25        Self {
26            data: Vec::new(),
27            comparator: Heap::get_comparator(heap_type),
28        }
29    }
30
31    #[deprecated]
32    pub fn from_array(heap_type: HeapType, input: Vec<T>) -> Self {
33        Heap::from_vec(heap_type, input)
34    }
35
36    pub fn from_vec(heap_type: HeapType, input: Vec<T>) -> Self {
37        let mut heap = Self {
38            data: input,
39            comparator: Self::get_comparator(heap_type),
40        };
41
42        let size = heap.data.len();
43        if size == 0 {
44            return heap;
45        }
46
47        let mut i: usize = (size / 2) - 1;
48        loop {
49            let mut do_more = heap.heapify(i);
50            while do_more.is_some() {
51                do_more = heap.heapify(do_more.unwrap());
52            }
53            if i != 0 {
54                i = i - 1;
55            } else {
56                break;
57            }
58        }
59
60        heap
61    }
62
63    fn float_down(&mut self, index: usize) {
64        let mut do_more = self.heapify(index);
65        while do_more.is_some() {
66            do_more = self.heapify(do_more.unwrap());
67        }
68    }
69
70    fn float_up(&mut self, index: usize) {
71        let mut _index = get_parent!(index);
72        let mut do_more = self.heapify(_index);
73
74        while _index != 0 && do_more.is_some() {
75            _index = get_parent!(_index);
76            do_more = self.heapify(_index);
77        }
78    }
79
80    pub fn insert(&mut self, value: T) {
81        self.data.push(value);
82        let new_size = self.data.len();
83        if new_size > 1 {
84            self.float_up(new_size - 1);
85        }
86    }
87
88    pub fn root(&self) -> Option<&T> {
89        self.data.get(0)
90    }
91
92    pub fn extract(&mut self) -> Option<T> {
93        if self.data.is_empty() {
94            None
95        } else if self.data.len() == 1 {
96            Some(self.data.remove(0))
97        } else {
98            let last_index = self.data.len() - 1;
99            self.data.swap(0, last_index);
100
101            let result = self.data.remove(last_index);
102
103            self.float_down(0);
104
105            Some(result)
106        }
107    }
108
109    pub fn get(&self, index: usize) -> Option<&T> {
110        self.data.get(index)
111    }
112
113    pub fn raw(&self) -> &Vec<T> {
114        &self.data
115    }
116
117    pub fn collect(&mut self) -> Vec<T> {
118        let mut output: Vec<T> = Vec::new();
119        for _ in 0..self.data.len() {
120            output.push(self.extract().unwrap());
121        }
122        output
123    }
124
125    fn get_comparator(heap_type: HeapType) -> HeapComparator<T> {
126        match heap_type {
127            HeapType::Min => |a: &T, b: &T| {
128                a > b
129            },
130            HeapType::Max => |a: &T, b: &T| -> bool {
131                a < b
132            }
133        }
134    }
135
136    fn heapify(&mut self, index: usize) -> Option<usize> {
137        let mut affected_index: usize = index;   // set parent as max element.
138
139        let l_index = 2 * index + 1;        // get left child
140        let r_index = 2 * index + 2;        // get right child
141
142        let data = &mut self.data;          // to short
143        let comparator = self.comparator;
144
145        //  determinate max index
146        if !data.get(l_index).is_none() && comparator(&data[affected_index], &data[l_index]) {
147            affected_index = l_index;
148        }
149
150        if !data.get(r_index).is_none() && comparator(&data[affected_index], &data[r_index]) {
151            affected_index = r_index;
152        }
153
154        if affected_index != index {
155            data.swap(index, affected_index);
156            Some(affected_index)
157        } else {
158            None
159        }
160    }
161}
162
163#[cfg(test)]
164mod test {
165    use rand::seq::SliceRandom;
166    use rand::thread_rng;
167
168    use crate::{Heap, HeapType};
169
170    fn random_vec(size: u32) -> Vec<u32> {
171        let mut result: Vec<u32> = (0u32..size).collect();
172        let mut rng = thread_rng();
173
174        result.shuffle(&mut rng);
175        result
176    }
177
178    #[test]
179    fn test_heapify() {
180        let mut heap = Heap::<u32>::new(HeapType::Max);
181
182        let input = vec![1, 3, 2];
183        let expected = [3, 1, 2];
184
185        heap.data = input;
186        heap.heapify(0);
187
188        let output = heap.raw();
189
190        assert_eq!(*output, expected);
191    }
192
193    #[test]
194    fn test_insert_min() {
195        let mut heap = Heap::<u32>::new(HeapType::Min);
196
197        heap.insert(10);
198        assert_eq!(*heap.raw(), vec![10]);
199
200        heap.insert(11);
201        assert_eq!(*heap.raw(), vec![10, 11]);
202
203        heap.insert(9);
204        assert_eq!(*heap.raw(), vec![9, 11, 10]);
205
206        heap.insert(5);
207        assert_eq!(*heap.raw(), vec![5, 9, 10, 11]);
208
209        heap.insert(6);
210        assert_eq!(*heap.raw(), vec![5, 6, 10, 11, 9]);
211
212        let mut expected = vec![5, 6, 10, 11, 9];
213        expected.sort();
214        assert_eq!(expected, heap.collect());
215    }
216
217    #[test]
218    fn test_insert_max() {
219        let mut heap = Heap::<u32>::new(HeapType::Max);
220
221        heap.insert(10);
222        assert_eq!(*heap.raw(), vec![10]);
223
224        heap.insert(11);
225        assert_eq!(*heap.raw(), vec![11, 10]);
226
227        heap.insert(9);
228        assert_eq!(*heap.raw(), vec![11, 10, 9]);
229
230        heap.insert(5);
231        assert_eq!(*heap.raw(), vec![11, 10, 9, 5]);
232
233        heap.insert(6);
234        assert_eq!(*heap.raw(), vec![11, 10, 9, 5, 6]);
235
236        let mut expected = vec![5, 6, 10, 11, 9];
237        expected.sort();
238        expected.reverse();
239        assert_eq!(expected, heap.collect());
240    }
241
242    #[test]
243    fn test_extract_min() {
244        let mut heap = Heap::<u32>::new(HeapType::Min);
245        let mut input: [u32; 4] = [10, 11, 9, 5];
246
247        for i in input {
248            heap.insert(i)
249        }
250
251        input.sort();
252        for i in 0..input.len() {
253            assert_eq!(input[i], heap.extract().unwrap());
254        }
255    }
256
257    #[test]
258    fn test_extract_max() {
259        let mut heap = Heap::<u32>::new(HeapType::Max);
260        let mut input: [u32; 4] = [10, 11, 9, 5];
261
262        for i in input {
263            heap.insert(i)
264        }
265
266        input.sort();
267        input.reverse();
268        for i in 0..input.len() {
269            assert_eq!(input[i], heap.extract().unwrap());
270        }
271    }
272
273    #[test]
274    fn test_from_vec_min() {
275        let input: Vec<u32> = random_vec(100);
276        let mut heap = Heap::<u32>::from_vec(HeapType::Min, input.clone());
277
278        let output = heap.collect();
279        let mut expected = input.clone();
280        expected.sort();
281
282        assert_eq!(expected, output);
283    }
284
285    #[test]
286    fn test_from_vec_edge() {
287        // From empty vector
288        let _ = Heap::<u32>::from_vec(HeapType::Max, Vec::new());
289    }
290
291    #[test]
292    fn test_from_vec_max() {
293        let input: Vec<u32> = random_vec(100);
294        let mut heap = Heap::<u32>::from_vec(HeapType::Max, input.clone());
295
296        let output = heap.collect();
297        let mut expected = input.clone();
298        expected.sort();
299        expected.reverse();
300
301        assert_eq!(expected, output);
302    }
303}