// hamming_heap/fixed_heap.rs

/// Keeps the nearest `cap` items pushed into it, bucketed by hamming distance.
///
/// This heap is not intended to be popped. Instead, it maintains the best `cap` items, and when you are
/// done adding items you may fill a slice or iterate over the results. Theoretically this could also allow
/// popping elements in constant time, but that would incur a performance penalty for the highly specialized
/// purpose this serves: hamming-space nearest neighbor searches.
///
/// Before use you must call `set_distances` (or construct with `new_distances`) and `set_capacity`.
/// `set_distances` must be passed the number of possible distances. Keep in mind that two `n`-bit numbers
/// can be anywhere from `0` to `n` bits apart, so there are `n + 1` possible hamming distances. For example,
/// the largest distance between two 128-bit numbers is 128:
///
/// ```
/// assert_eq!((0u128 ^ !0).count_ones(), 128);
/// ```
///
/// So make sure you use `n + 1` as your `distances`, or else you may encounter a runtime panic.
///
/// ```
/// use hamming_heap::FixedHammingHeap;
/// let mut candidates = FixedHammingHeap::new_distances(129);
/// candidates.set_capacity(3);
/// candidates.push((0u128 ^ !0u128).count_ones(), ());
/// ```
#[derive(Clone, Debug)]
pub struct FixedHammingHeap<T> {
    /// Maximum number of items retained.
    cap: usize,
    /// Number of items currently stored across all buckets.
    size: usize,
    /// Highest occupied distance while at capacity; the maximum distance otherwise.
    worst: u32,
    /// `distances[d]` holds the items pushed at distance `d`.
    distances: Vec<Vec<T>>,
}

impl<T> FixedHammingHeap<T> {
    /// Creates an empty heap with no distances and a capacity of `0`.
    ///
    /// Call `set_distances` and `set_capacity` before pushing anything.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty heap already initialized with `distances` distances.
    ///
    /// # Panics
    ///
    /// Panics if `distances` is `0`.
    pub fn new_distances(distances: usize) -> Self {
        let mut s = Self::new();
        s.set_distances(distances);
        s
    }

    /// Sets the capacity of the heap to `cap`.
    ///
    /// Once `cap` items are stored, pushing an item ejects one of the current worst items if the new one is
    /// strictly closer. If the capacity is lowered below the current length, the worst elements are removed
    /// so that `len() == cap`.
    ///
    /// # Panics
    ///
    /// Panics if `cap` is `0`.
    pub fn set_capacity(&mut self, cap: usize) {
        assert_ne!(cap, 0);
        self.set_len(cap);
        self.cap = cap;
        // `worst` is only meaningful at capacity: reset it to the maximum, and if the new capacity is
        // already met, recompute it so it points at the actual worst occupied distance.
        self.worst = self.max_distance();
        if self.size == self.cap {
            self.update_worst();
        }
    }

    /// Removes the worst elements until only `len` remain.
    ///
    /// If `len` is greater than or equal to the current number of elements, this does nothing. Otherwise
    /// elements at the highest distances are removed first (within one distance, the most recently pushed
    /// go first). Because the length then drops below the capacity, insertions are unconditionally
    /// accepted again until `cap` is reached.
    pub fn set_len(&mut self, len: usize) {
        if len == 0 {
            self.clear();
        } else if len < self.size {
            let end = self.end();
            let mut remaining = self.size - len;
            // Walk from the worst (highest) distance towards the best so the nearest items survive.
            for bucket in self.distances[..=end].iter_mut().rev() {
                if bucket.len() >= remaining {
                    // This bucket has enough items: trim just the excess and stop.
                    bucket.truncate(bucket.len() - remaining);
                    break;
                }
                // Not enough here: empty the bucket and keep going.
                remaining -= bucket.len();
                bucket.clear();
            }
            // The heap is now below capacity, so `worst` must be reset to the maximum.
            self.worst = self.max_distance();
            self.size = len;
        }
    }

    /// Returns the number of items currently stored.
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns `true` if no items are stored.
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Removes all items while keeping the allocated memory of each distance bucket.
    ///
    /// # Panics
    ///
    /// Panics if `set_distances` has not been called.
    pub fn clear(&mut self) {
        assert_ne!(
            self.distances.len(),
            0,
            "you must call set_distances() before calling clear()"
        );
        let end = self.end();
        for bucket in &mut self.distances[..=end] {
            bucket.clear();
        }
        self.size = 0;
        self.worst = self.max_distance();
    }

    /// Sets the number of distances and clears the heap.
    ///
    /// This does not preserve the allocated memory, so don't call this on each search; use `clear`
    /// between searches instead.
    ///
    /// If you have a 128-bit number, keep in mind that it has `129` distances because `128` is one of the
    /// possible distances.
    ///
    /// # Panics
    ///
    /// Panics if `distances` is `0`.
    pub fn set_distances(&mut self, distances: usize) {
        assert_ne!(distances, 0, "a FixedHammingHeap needs at least one distance");
        self.distances.clear();
        self.distances.resize_with(distances, Vec::new);
        self.worst = self.max_distance();
        self.size = 0;
    }

    /// Adds `item` at `distance`.
    ///
    /// Below capacity every item is accepted. At capacity the item is only accepted if `distance` is
    /// strictly less than `worst()`, in which case one of the items at the worst distance is evicted.
    ///
    /// Returns `true` if the item was added.
    ///
    /// # Panics
    ///
    /// Panics if the heap is below capacity and `distance` is not less than the number of distances passed
    /// to `set_distances`.
    pub fn push(&mut self, distance: u32, item: T) -> bool {
        if self.at_cap() {
            // SAFETY: `push_at_cap` has no memory-safety requirements; its only precondition, being at
            // capacity, was just checked.
            unsafe { self.push_at_cap(distance, item) }
        } else {
            self.distances[distance as usize].push(item);
            self.size += 1;
            // Once the heap fills up, `worst` has to track the actual worst occupied distance.
            if self.at_cap() {
                self.update_worst();
            }
            true
        }
    }

    /// Fills `s` with the stored items in order of increasing distance and returns the written prefix.
    ///
    /// At most `min(s.len(), self.len())` items are written.
    pub fn fill_slice<'a>(&self, s: &'a mut [T]) -> &'a mut [T]
    where
        T: Clone,
    {
        let total_fill = std::cmp::min(s.len(), self.size);
        let best = self.distances[..=self.end()].iter().flatten();
        // `zip` stops at whichever runs out first: the slice or the stored items.
        for (slot, item) in s.iter_mut().zip(best) {
            slot.clone_from(item);
        }
        &mut s[..total_fill]
    }

    /// Returns the worst distance currently in the heap.
    ///
    /// This is the maximum distance (the worst possible one) until `cap` elements have been inserted.
    pub fn worst(&self) -> u32 {
        self.worst
    }

    /// Returns `true` if the heap holds `cap` items.
    pub fn at_cap(&self) -> bool {
        self.size == self.cap
    }

    /// Iterates over the stored items in best-to-worst order, yielding `(distance, &item)`.
    pub fn iter(&self) -> impl Iterator<Item = (u32, &T)> {
        self.distances[..=self.end()]
            .iter()
            .enumerate()
            .flat_map(|(distance, bucket)| bucket.iter().map(move |item| (distance as u32, item)))
    }

    /// Iterates mutably over the stored items in best-to-worst order, yielding `(distance, &mut item)`.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (u32, &mut T)> {
        let end = self.end();
        self.distances[..=end]
            .iter_mut()
            .enumerate()
            .flat_map(|(distance, bucket)| {
                bucket.iter_mut().map(move |item| (distance as u32, item))
            })
    }

    /// Adds `item` at `distance`, assuming the heap is already at capacity.
    ///
    /// The item is only added if `distance` is strictly less than `worst()`, in which case one of the items
    /// at the worst distance is evicted. Returns `true` if it was added.
    ///
    /// This only skips the capacity check in `push`. Don't use it unless profiling shows that the branch
    /// predictor is struggling with that check.
    ///
    /// # Safety
    ///
    /// Calling this cannot cause undefined behavior, but it must only be called when `at_cap()` returns
    /// `true`. Otherwise the heap's bookkeeping (its length and worst distance) becomes inconsistent.
    pub unsafe fn push_at_cap(&mut self, distance: u32, item: T) -> bool {
        // At capacity every stored item is at or below `worst`, so a strictly closer item always earns a
        // place by evicting one of the worst ones.
        if distance < self.worst {
            self.distances[distance as usize].push(item);
            self.remove_worst();
            true
        } else {
            false
        }
    }

    /// Returns the largest distance this heap can store (the number of distances minus one).
    fn max_distance(&self) -> u32 {
        self.distances.len() as u32 - 1
    }

    /// Returns the inclusive index of the last distance bucket that may hold items.
    fn end(&self) -> usize {
        if self.at_cap() {
            self.worst as usize
        } else {
            self.distances.len() - 1
        }
    }

    /// Recomputes `worst` as the highest non-empty distance at or below the current `worst`, falling back
    /// to the maximum distance if all of those buckets are empty.
    fn update_worst(&mut self) {
        self.worst = self.distances[..=self.worst as usize]
            .iter()
            .rposition(|bucket| !bucket.is_empty())
            .map(|ix| ix as u32)
            .unwrap_or_else(|| self.max_distance());
    }

    /// Removes one item at the worst distance and updates the worst distance.
    fn remove_worst(&mut self) {
        self.distances[self.worst as usize].pop();
        self.update_worst();
    }
}

impl<T> Default for FixedHammingHeap<T> {
    /// An empty heap with no distances and a capacity of `0`.
    fn default() -> Self {
        Self {
            cap: 0,
            size: 0,
            worst: 0,
            distances: vec![],
        }
    }
}

#[cfg(test)]
#[test]
fn test_fixed_heap() {
    let mut candidates: FixedHammingHeap<u32> = FixedHammingHeap::new();
    candidates.set_distances(11);
    candidates.set_capacity(3);
    // Each entry is (distance, item, whether the push is expected to be accepted).
    let pushes = [
        (5, 0, true),
        (4, 1, true),
        (3, 2, true),
        (6, 3, false),
        (7, 4, false),
        (2, 5, true),
        (3, 6, true),
        (10, 7, false),
        (6, 8, false),
        (4, 9, false),
        (1, 10, true),
        (2, 11, true),
    ];
    for &(distance, item, accepted) in &pushes {
        assert_eq!(candidates.push(distance, item), accepted);
    }
    let mut nearest = [0; 3];
    candidates.fill_slice(&mut nearest);
    // Items 5 and 11 share distance 2; sort them so the check ignores their relative order.
    nearest[1..].sort_unstable();
    assert_eq!(nearest, [10, 5, 11]);
}