lsph/map/
mod.rs

1mod nn;
2mod table;
3
4use crate::{
5    error::*,
6    geometry::{distance::*, Point},
7    hasher::*,
8    map::{nn::*, table::*},
9    models::Model,
10};
11use core::{fmt::Debug, iter::Sum, mem};
12use num_traits::{
13    cast::{AsPrimitive, FromPrimitive},
14    float::Float,
15};
16use std::collections::BinaryHeap;
17
/// Number of buckets an empty table is grown to the first time it is resized;
/// later resizes double the current bucket count instead.
const INITIAL_NBUCKETS: usize = 1_usize;
20
21/// LearnedHashMap takes a model instead of an hasher for hashing indexes in the table.
22///
23/// Default Model for the LearndedHashMap is Linear regression.
24/// In order to build a ordered HashMap, we need to make sure that the model is **monotonic**.
25#[derive(Debug, Clone)]
26pub struct LearnedHashMap<M, F> {
27    hasher: LearnedHasher<M>,
28    table: Table<Point<F>>,
29    items: usize,
30}
31
32/// Default for the LearndedHashMap.
33impl<M, F> Default for LearnedHashMap<M, F>
34where
35    F: Float,
36    M: Model<F = F> + Default,
37{
38    fn default() -> Self {
39        Self {
40            hasher: LearnedHasher::<M>::new(),
41            table: Table::new(),
42            items: 0,
43        }
44    }
45}
46
47impl<M, F> LearnedHashMap<M, F>
48where
49    F: Float + Default + AsPrimitive<u64> + FromPrimitive + Debug + Sum,
50    M: Model<F = F> + Default + Clone,
51{
52    /// Returns a default LearnedHashMap with Model and Float type.
53    ///
54    /// # Examples
55    ///
56    /// ```
57    /// use lsph::{LearnedHashMap, LinearModel};
58    /// let map = LearnedHashMap::<LinearModel<f64>, f64>::new();
59    /// ```
60    #[inline]
61    pub fn new() -> Self {
62        Self::default()
63    }
64
65    /// Returns a default LearnedHashMap with Model and Float type.
66    ///
67    /// # Arguments
68    /// * `hasher` - A LearnedHasher with model
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
74    /// let map = LearnedHashMap::<LinearModel<f64>, f64>::with_hasher(LearnedHasher::new());
75    /// ```
76    #[inline]
77    pub fn with_hasher(hasher: LearnedHasher<M>) -> Self {
78        Self {
79            hasher,
80            table: Table::new(),
81            items: 0,
82        }
83    }
84
85    /// Returns a default LearnedHashMap with Model and Float type.
86    ///
87    /// # Arguments
88    /// * `capacity` - A predefined capacity size for the LearnedHashMap
89    ///
90    /// # Examples
91    ///
92    /// ```
93    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
94    /// let map = LearnedHashMap::<LinearModel<f64>, f64>::with_capacity(10usize);
95    /// ```
96    #[inline]
97    pub fn with_capacity(capacity: usize) -> Self {
98        Self {
99            hasher: Default::default(),
100            table: Table::with_capacity(capacity),
101            items: 0,
102        }
103    }
104
105    /// Returns a default LearnedHashMap with Model and Float type
106    ///
107    /// # Arguments
108    /// * `data` - A Vec<[F; 2]> of 2d points for the map
109    ///
110    /// # Examples
111    ///
112    /// ```
113    /// use lsph::{LearnedHashMap, LinearModel};
114    /// let data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
115    /// let map = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&data);
116    /// ```
117    #[inline]
118    pub fn with_data(data: &[[F; 2]]) -> Result<(Self, Vec<Point<F>>), Error> {
119        use crate::helper::convert_to_points;
120        let mut map = LearnedHashMap::with_capacity(data.len());
121        let mut ps = convert_to_points(data).unwrap();
122        match map.batch_insert(&mut ps) {
123            Ok(()) => Ok((map, ps)),
124            Err(err) => Err(err),
125        }
126    }
127
128    /// Returns Option<Point<F>>  with given point data.
129    ///
130    /// # Arguments
131    /// * `p` - A array slice containing two points for querying
132    ///
133    /// # Examples
134    ///
135    /// ```
136    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
137    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
138    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
139    ///
140    /// assert_eq!(map.get(&[1., 1.]).is_some(), true);
141    /// ```
142    #[inline]
143    pub fn get(&mut self, p: &[F; 2]) -> Option<&Point<F>> {
144        let hash = make_hash_point(&mut self.hasher, p) as usize;
145        if hash > self.table.capacity() {
146            return None;
147        }
148        self.find_by_hash(hash, p)
149    }
150
151    /// Returns Option<Point<F>> by hash index, if it exists in the map.
152    ///
153    /// # Arguments
154    /// * `hash` - An usize hash value
155    /// * `p` - A array slice containing two points for querying
156    ///
157    /// # Examples
158    ///
159    /// ```
160    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
161    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
162    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
163    ///
164    /// assert_eq!(map.find_by_hash(0, &[1., 1.]).is_some(), true);
165    /// assert_eq!(map.find_by_hash(1, &[1., 1.]).is_none(), true);
166    /// ```
167    pub fn find_by_hash(&self, hash: usize, p: &[F; 2]) -> Option<&Point<F>> {
168        self.table[hash]
169            .iter()
170            .find(|&ep| ep.x == p[0] && ep.y == p[1])
171    }
172
173    /// Returns bool.
174    ///
175    /// # Arguments
176    /// * `p` - A array slice containing two points for querying
177    ///
178    /// # Examples
179    ///
180    /// ```
181    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
182    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
183    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
184    ///
185    /// assert_eq!(map.contains_points(&[1., 1.]), true);
186    /// assert_eq!(map.contains_points(&[0., 1.]), false);
187    /// ```
188    #[inline]
189    pub fn contains_points(&mut self, p: &[F; 2]) -> bool {
190        self.get(p).is_some()
191    }
192
193    /// Returns Option<Point<F>> if the map contains a point and successful remove it from the map.
194    ///
195    /// # Arguments
196    /// * `p` - A Point data
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
202    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
203    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
204    ///
205    /// let p = points[0];
206    /// assert_eq!(map.remove(&p).unwrap(), p);
207    /// ```
208    #[inline]
209    pub fn remove(&mut self, p: &Point<F>) -> Option<Point<F>> {
210        let hash = make_hash_point(&mut self.hasher, &[p.x, p.y]);
211        self.items -= 1;
212        self.table.remove_entry(hash, *p)
213    }
214
215    /// Returns usize length.
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
221    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
222    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
223    ///
224    /// assert_eq!(map.len(), 4);
225    /// ```
226    #[inline]
227    pub fn len(&self) -> usize {
228        self.table.len()
229    }
230
231    /// Returns usize number of items.
232    ///
233    /// # Examples
234    ///
235    /// ```
236    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
237    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
238    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
239    ///
240    /// assert_eq!(map.items(), 4);
241    /// ```
242    #[inline]
243    pub fn items(&self) -> usize {
244        self.items
245    }
246
247    /// Returns bool if the map is empty.
248    ///
249    /// # Examples
250    ///
251    /// ```
252    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
253    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
254    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
255    ///
256    /// assert_eq!(map.is_empty(), false);
257    /// ```
258    #[inline]
259    pub fn is_empty(&self) -> bool {
260        self.items == 0
261    }
262
263    /// Resize the map if needed, it will initialize the map to the INITIAL_NBUCKETS, otherwise it will double the capacity if table is not empty.
264    fn resize(&mut self) {
265        let target_size = match self.table.len() {
266            0 => INITIAL_NBUCKETS,
267            n => 2 * n,
268        };
269        self.resize_with_capacity(target_size);
270    }
271
272    /// Resize the map if needed, it will resize the map to desired capacity.
273    #[inline]
274    fn resize_with_capacity(&mut self, target_size: usize) {
275        let mut new_table = Table::with_capacity(target_size);
276        new_table.extend((0..target_size).map(|_| Bucket::new()));
277
278        for p in self.table.iter_mut().flat_map(|bucket| bucket.drain(..)) {
279            let hash = make_hash_point(&mut self.hasher, &[p.x, p.y]) as usize;
280            new_table[hash].push(p);
281        }
282
283        self.table = new_table;
284    }
285
286    /// Rehash the map.
287    #[inline]
288    fn rehash(&mut self) -> Result<(), Error> {
289        let mut old_data = Vec::with_capacity(self.items());
290        for p in self.table.iter_mut().flat_map(|bucket| bucket.drain(..)) {
291            old_data.push(p);
292        }
293        self.batch_insert(&mut old_data)
294    }
295
296    /// Inner function for insert a single point into the map
297    #[inline]
298    fn insert_inner(&mut self, p: Point<F>) -> Option<Point<F>> {
299        // Resize if the table is empty or 3/4 size of the table is full
300        if self.table.is_empty() || self.items() > 3 * self.table.len() / 4 {
301            self.resize();
302        }
303        let hash = make_hash_point::<M, F>(&mut self.hasher, &[p.x, p.y]);
304        self.insert_with_axis(p, hash)
305    }
306
307    /// Sequencial insert a point into the map.
308    ///
309    /// # Arguments
310    /// * `p` - A Point<F> with float number
311    ///
312    /// # Examples
313    ///
314    /// ```
315    /// use lsph::{LearnedHashMap, LinearModel, Point};
316    /// let a: Point<f64> = Point::new(0., 1.);
317    /// let b: Point<f64> = Point::new(1., 0.);
318
319    /// let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
320    /// map.insert(a);
321    /// map.insert(b);
322
323    /// assert_eq!(map.items(), 2);
324    /// assert_eq!(map.get(&[0., 1.]).unwrap(), &a);
325    /// assert_eq!(map.get(&[1., 0.]).unwrap(), &b);
326    /// ```
327    pub fn insert(&mut self, p: Point<F>) -> Option<Point<F>> {
328        // Resize if the table is empty or 3/4 size of the table is full
329        if self.table.is_empty() || self.items() > 3 * self.table.len() / 4 {
330            self.resize();
331        }
332
333        let hash = make_hash_point::<M, F>(&mut self.hasher, &[p.x, p.y]);
334        // resize if hash index is larger or equal to the table capacity
335        if hash >= self.table.capacity() as u64 {
336            self.resize_with_capacity(hash as usize * 2);
337            self.insert_with_axis(p, hash);
338            match self.rehash() {
339                Ok(_) => None,
340                Err(err) => {
341                    eprintln!("{:?}", err);
342                    None
343                }
344            }
345        } else {
346            self.insert_with_axis(p, hash)
347        }
348    }
349
350    /// Insert a point into the map along the given axis.
351    ///
352    /// # Arguments
353    /// * `p_value` - A float number represent the key of a 2d point
354    #[inline]
355    fn insert_with_axis(&mut self, p: Point<F>, hash: u64) -> Option<Point<F>> {
356        let mut insert_index = 0;
357        let bucket_index = self.table.bucket(hash);
358        let bucket = &mut self.table[bucket_index];
359        if self.hasher.sort_by_x() {
360            // Get index from the hasher
361            for ep in bucket.iter_mut() {
362                if ep == &mut p.clone() {
363                    return Some(mem::replace(ep, p));
364                }
365                if ep.y < p.y() {
366                    insert_index += 1;
367                }
368            }
369        } else {
370            for ep in bucket.iter_mut() {
371                if ep == &mut p.clone() {
372                    return Some(mem::replace(ep, p));
373                }
374                if ep.x < p.x() {
375                    insert_index += 1;
376                }
377            }
378        }
379        bucket.insert(insert_index, p);
380        self.items += 1;
381        None
382    }
383
384    /// Fit the input data into the model of the hasher. Returns Error if error occurred during
385    /// model fitting.
386    ///
387    /// # Arguments
388    ///
389    /// * `xs` - A list of tuple of floating number
390    /// * `ys` - A list of tuple of floating number
391    #[inline]
392    pub fn model_fit(&mut self, xs: &[F], ys: &[F]) -> Result<(), Error> {
393        self.hasher.model.fit(xs, ys)
394    }
395
396    /// Fit the input data into the model of the hasher. Returns Error if error occurred during
397    /// model fitting.
398    ///
399    /// # Arguments
400    /// * `data` - A list of tuple of floating number
401    #[inline]
402    pub fn model_fit_tuple(&mut self, data: &[(F, F)]) -> Result<(), Error> {
403        self.hasher.model.fit_tuple(data)
404    }
405
406    /// Inner function for batch insert
407    #[inline]
408    fn batch_insert_inner(&mut self, ps: &[Point<F>]) {
409        // Allocate table capacity before insert
410        let n = ps.len();
411        self.resize_with_capacity(n);
412        for p in ps.iter() {
413            self.insert_inner(*p);
414        }
415    }
416
417    /// Batch insert a batch of 2d data into the map.
418    ///
419    /// # Arguments
420    /// * `ps` - A list of point number
421    ///
422    /// # Examples
423    ///
424    /// ```
425    /// use lsph::{LearnedHashMap, LinearModel};
426    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
427    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
428    ///
429    /// assert_eq!(map.get(&[1., 1.]).is_some(), true);
430    /// ```
431    #[inline]
432    pub fn batch_insert(&mut self, ps: &mut [Point<F>]) -> Result<(), Error> {
433        // Select suitable axis for training
434        use crate::geometry::Axis;
435        use crate::models::Trainer;
436
437        // Loading data into trainer
438        if let Ok(trainer) = Trainer::with_points(ps) {
439            trainer.train(&mut self.hasher.model).unwrap();
440            let axis = trainer.axis();
441            match axis {
442                Axis::X => self.hasher.set_sort_by_x(true),
443                _ => self.hasher.set_sort_by_x(false),
444            };
445
446            // Fit the data into model
447            self.model_fit(trainer.train_x(), trainer.train_y())
448                .unwrap();
449            // Batch insert into the map
450            self.batch_insert_inner(ps);
451        }
452        Ok(())
453    }
454
455    /// Range search finds all points for a given 2d range
456    /// Returns all the points within the given range
457    /// ```text
458    ///      |                    top right
459    ///      |        .-----------*
460    ///      |        | .   .     |
461    ///      |        |  .  .  .  |
462    ///      |        |       .   |
463    ///   bottom left *-----------.
464    ///      |
465    ///      |        |           |
466    ///      |________v___________v________
467    ///              left       right
468    ///              hash       hash
469    /// ```
470    /// # Arguments
471    ///
472    /// * `bottom_left` - A tuple containing a pair of points that represent the bottom left of the
473    /// range.
474    ///
475    /// * `top_right` - A tuple containing a pair of points that represent the top right of the
476    /// range.
477    #[inline]
478    pub fn range_search(
479        &mut self,
480        bottom_left: &[F; 2],
481        top_right: &[F; 2],
482    ) -> Option<Vec<Point<F>>> {
483        let mut right_hash = make_hash_point(&mut self.hasher, top_right) as usize;
484        if right_hash > self.table.capacity() {
485            right_hash = self.table.capacity() as usize - 1;
486        }
487        let left_hash = make_hash_point(&mut self.hasher, bottom_left) as usize;
488        if left_hash > self.table.capacity() || left_hash > right_hash {
489            return None;
490        }
491        let mut result: Vec<Point<F>> = Vec::new();
492        for i in left_hash..=right_hash {
493            let bucket = &self.table[i];
494            for item in bucket.iter() {
495                if item.x >= bottom_left[0]
496                    && item.x <= top_right[0]
497                    && item.y >= bottom_left[1]
498                    && item.y <= top_right[1]
499                {
500                    result.push(*item);
501                }
502            }
503        }
504        if result.is_empty() {
505            return None;
506        }
507        Some(result)
508    }
509
510    /// Returns Option<Vec<Point<F>>> if points are found in the map with given range
511    ///
512    /// # Arguments
513    /// * `query_point` - A Point data for querying
514    /// * `radius` - A radius value
515    ///
516    /// # Examples
517    ///
518    /// ```
519    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
520    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
521    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
522    /// assert_eq!(map.range_search(&[0., 0.], &[3., 3.]).is_some(), true);
523    /// ```
524    #[inline]
525    pub fn radius_range(&mut self, query_point: &[F; 2], radius: F) -> Option<Vec<Point<F>>> {
526        self.range_search(
527            &[query_point[0] - radius, query_point[1] - radius],
528            &[query_point[0] + radius, query_point[1] + radius],
529        )
530    }
531
532    /// Find the local minimum distance between query points and cadidates neighbors, then store
533    /// the cadidates neighbors in the min_heap.
534    ///
535    ///
536    /// # Arguments
537    /// * `heap` - mutable borrow of an BinaryHeap
538    /// * `local_hash` - A hash index of local bucket
539    /// * `query_point` - A Point data
540    /// * `min_d` - minimum distance
541    /// * `nearest_neighbor` - mutable borrow of an point data, which is the nearest neighbor at
542    /// search index bucket
543    #[inline]
544    fn local_min_heap(
545        &self,
546        heap: &mut BinaryHeap<NearestNeighborState<F>>,
547        local_hash: u64,
548        query_point: &[F; 2],
549        min_d: &mut F,
550        nearest_neighbor: &mut Point<F>,
551    ) {
552        let bucket = &self.table[local_hash as usize];
553        if !bucket.is_empty() {
554            for p in bucket.iter() {
555                let d = Euclidean::distance(query_point, &[p.x, p.y]);
556                heap.push(NearestNeighborState {
557                    distance: d,
558                    point: *p,
559                });
560            }
561        }
562        match heap.pop() {
563            Some(v) => {
564                let local_min_d = v.distance;
565                // Update the nearest neighbour and minimum distance
566                if &local_min_d < min_d {
567                    *nearest_neighbor = v.point;
568                    *min_d = local_min_d;
569                }
570            }
571            None => (),
572        }
573    }
574
575    /// Calculates the horizontal distance between query_point and bucket at index with given hash.
576    ///
577    /// # Arguments
578    /// * `hash` - A hash index of the bucket
579    /// * `query_point` - A Point data
580    #[inline]
581    fn horizontal_distance(&mut self, query_point: &[F; 2], hash: u64) -> F {
582        let x = unhash(&mut self.hasher, hash);
583        match self.hasher.sort_by_x() {
584            true => Euclidean::distance(&[query_point[0], F::zero()], &[x, F::zero()]),
585            false => Euclidean::distance(&[query_point[1], F::zero()], &[x, F::zero()]),
586        }
587    }
588
589    /// Nearest neighbor search for the closest point for given query point
590    /// Returns the closest point
591    ///```text
592    ///      |
593    ///      |            .
594    ///      |         .  |
595    ///      |         |. |  *  . <- nearest neighbor
596    ///      |         || |  | .|
597    ///      |  expand <--------> expand
598    ///      |  left         |     right
599    ///      |               |
600    ///      |_______________v_____________
601    ///                    query
602    ///                    point
603    ///```
604    /// # Arguments
605    ///
606    /// * `query_point` - A tuple containing a pair of points for querying
607    ///
608    /// # Examples
609    ///
610    /// ```
611    /// use lsph::{LearnedHashMap, LinearModel, LearnedHasher};
612    /// let point_data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
613    /// let (mut map, points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&point_data).unwrap();
614    /// assert_eq!(map.nearest_neighbor(&[2., 1.]).is_some(), true);
615    /// ```
616    #[inline]
617    pub fn nearest_neighbor(&mut self, query_point: &[F; 2]) -> Option<Point<F>> {
618        let mut hash = make_hash_point(&mut self.hasher, query_point);
619        let max_capacity = self.table.capacity() as u64;
620
621        // if hash out of max bound, still search right most bucket
622        if hash > max_capacity {
623            hash = max_capacity - 1;
624        }
625
626        let mut heap = BinaryHeap::new();
627        let mut min_d = F::max_value();
628        let mut nearest_neighbor = Point::default();
629
630        // Searching at current hash index
631        self.local_min_heap(
632            &mut heap,
633            hash,
634            query_point,
635            &mut min_d,
636            &mut nearest_neighbor,
637        );
638
639        // Measure left horizontal distance from current bucket to left hash bucket
640        // left hash must >= 0
641        let mut left_hash = hash.saturating_sub(1);
642        // Unhash the left_hash, then calculate the vertical distance between
643        // left hash point and query point
644        let mut left_hash_d = self.horizontal_distance(query_point, left_hash);
645
646        // Iterate over left
647        while left_hash_d < min_d {
648            self.local_min_heap(
649                &mut heap,
650                left_hash,
651                query_point,
652                &mut min_d,
653                &mut nearest_neighbor,
654            );
655
656            // break before update
657            if left_hash == 0 {
658                break;
659            }
660
661            // Update next right side bucket distance
662            left_hash = left_hash.saturating_sub(1);
663            left_hash_d = self.horizontal_distance(query_point, left_hash);
664        }
665
666        // Measure right vertical distance from current bucket to right hash bucket
667        let mut right_hash = hash + 1;
668        // Unhash the right_hash, then calculate the vertical distance between
669        // right hash point and query point
670        let mut right_hash_d = self.horizontal_distance(query_point, right_hash);
671
672        // Iterate over right
673        while right_hash_d < min_d {
674            self.local_min_heap(
675                &mut heap,
676                right_hash,
677                query_point,
678                &mut min_d,
679                &mut nearest_neighbor,
680            );
681
682            // Move to next right bucket
683            right_hash += 1;
684
685            // break after update
686            if right_hash == self.table.capacity() as u64 {
687                break;
688            }
689            // Update next right side bucket distance
690            right_hash_d = self.horizontal_distance(query_point, right_hash);
691        }
692
693        Some(nearest_neighbor)
694    }
695}
696
697pub struct Iter<'a, M, F>
698where
699    F: Float,
700    M: Model<F = F> + Default + Clone,
701{
702    map: &'a LearnedHashMap<M, F>,
703    bucket: usize,
704    at: usize,
705}
706
707impl<'a, M, F> Iterator for Iter<'a, M, F>
708where
709    F: Float,
710    M: Model<F = F> + Default + Clone,
711{
712    type Item = &'a Point<F>;
713    fn next(&mut self) -> Option<Self::Item> {
714        loop {
715            match self.map.table.get(self.bucket) {
716                Some(bucket) => {
717                    match bucket.get(self.at) {
718                        Some(p) => {
719                            // move along self.at and self.bucket
720                            self.at += 1;
721                            break Some(p);
722                        }
723                        None => {
724                            self.bucket += 1;
725                            self.at = 0;
726                            continue;
727                        }
728                    }
729                }
730                None => break None,
731            }
732        }
733    }
734}
735
736impl<'a, M, F> IntoIterator for &'a LearnedHashMap<M, F>
737where
738    F: Float,
739    M: Model<F = F> + Default + Clone,
740{
741    type Item = &'a Point<F>;
742    type IntoIter = Iter<'a, M, F>;
743    fn into_iter(self) -> Self::IntoIter {
744        Iter {
745            map: self,
746            bucket: 0,
747            at: 0,
748        }
749    }
750}
751
752pub struct IntoIter<M, F>
753where
754    F: Float,
755    M: Model<F = F> + Default + Clone,
756{
757    map: LearnedHashMap<M, F>,
758    bucket: usize,
759}
760
761impl<M, F> Iterator for IntoIter<M, F>
762where
763    F: Float,
764    M: Model<F = F> + Default + Clone,
765{
766    type Item = Point<F>;
767    fn next(&mut self) -> Option<Self::Item> {
768        loop {
769            match self.map.table.get_mut(self.bucket) {
770                Some(bucket) => match bucket.pop() {
771                    Some(x) => break Some(x),
772                    None => {
773                        self.bucket += 1;
774                        continue;
775                    }
776                },
777                None => break None,
778            }
779        }
780    }
781}
782
783impl<M, F> IntoIterator for LearnedHashMap<M, F>
784where
785    F: Float,
786    M: Model<F = F> + Default + Clone,
787{
788    type Item = Point<F>;
789    type IntoIter = IntoIter<M, F>;
790    fn into_iter(self) -> Self::IntoIter {
791        IntoIter {
792            map: self,
793            bucket: 0,
794        }
795    }
796}
797
#[cfg(test)]
mod tests {
    // Unit tests for insertion, lookup, training, range search and
    // nearest-neighbour search on `LearnedHashMap`.
    use super::*;
    use crate::geometry::Point;
    use crate::models::LinearModel;
    use crate::test_utilities::*;

    #[test]
    fn insert() {
        let a: Point<f64> = Point::new(0., 1.);
        let b: Point<f64> = Point::new(1., 0.);

        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.insert(a);
        map.insert(b);

        assert_eq!(map.items(), 2);
        assert_eq!(map.get(&[0., 1.]).unwrap(), &a);
        assert_eq!(map.get(&[1., 0.]).unwrap(), &b);
    }

    #[test]
    fn insert_repeated() {
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        let a: Point<f64> = Point::new(0., 1.);
        let b: Point<f64> = Point::new(1., 0.);
        let res = map.insert(a);
        assert_eq!(map.items(), 1);
        assert_eq!(res, None);

        let res = map.insert(b);
        assert_eq!(map.items(), 2);
        assert_eq!(res, None);
    }

    #[test]
    fn with_data() {
        let data = vec![[1., 1.], [2., 1.], [3., 2.], [4., 4.]];
        let (mut map, _points) = LearnedHashMap::<LinearModel<f64>, f64>::with_data(&data).unwrap();
        assert!(map.get(&[1., 1.]).is_some());
    }

    #[test]
    fn fit_batch_insert() {
        let mut data: Vec<Point<f64>> = vec![
            Point::new(1., 1.),
            Point::new(3., 1.),
            Point::new(2., 1.),
            Point::new(3., 2.),
            Point::new(5., 1.),
        ];
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.batch_insert(&mut data).unwrap();

        assert_delta!(1.02272, map.hasher.model.coefficient, 0.00001);
        assert_delta!(-0.86363, map.hasher.model.intercept, 0.00001);
        assert_eq!(Some(&Point::new(1., 1.)), map.get(&[1., 1.]));
        assert_eq!(Some(&Point::new(3., 1.)), map.get(&[3., 1.]));
        assert_eq!(Some(&Point::new(5., 1.)), map.get(&[5., 1.]));

        assert_eq!(None, map.get(&[5., 2.]));
        assert_eq!(None, map.get(&[2., 2.]));
        assert_eq!(None, map.get(&[50., 10.]));
        assert_eq!(None, map.get(&[500., 100.]));
    }

    #[test]
    fn insert_after_batch_insert() {
        let mut data: Vec<Point<f64>> = vec![
            Point::new(1., 1.),
            Point::new(3., 1.),
            Point::new(2., 1.),
            Point::new(3., 2.),
            Point::new(5., 1.),
        ];
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.batch_insert(&mut data).unwrap();

        // `Point` is `Copy`, so it can be inserted and still compared afterwards.
        let a: Point<f64> = Point::new(10., 10.);
        map.insert(a);
        assert_eq!(Some(&a), map.get(&[10., 10.]));

        let b: Point<f64> = Point::new(100., 100.);
        map.insert(b);
        assert_eq!(Some(&b), map.get(&[100., 100.]));
        assert_eq!(None, map.get(&[100., 101.]));
    }

    #[test]
    fn range_search() {
        let mut data: Vec<Point<f64>> = vec![
            Point::new(1., 1.),
            Point::new(2., 2.),
            Point::new(3., 3.),
            Point::new(4., 4.),
            Point::new(5., 5.),
        ];
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.batch_insert(&mut data).unwrap();

        let found: Vec<Point<f64>> =
            vec![Point::new(1., 1.), Point::new(2., 2.), Point::new(3., 3.)];

        assert_eq!(Some(found), map.range_search(&[1., 1.], &[3.5, 3.]));

        let found: Vec<Point<f64>> = vec![Point::new(1., 1.)];

        assert_eq!(Some(found), map.range_search(&[1., 1.], &[3., 1.]));
        assert_eq!(None, map.range_search(&[4., 2.], &[5., 3.]));
    }

    #[test]
    fn test_nearest_neighbor() {
        let points = create_random_point_type_points(1000, SEED_1);
        let mut map = LearnedHashMap::<LinearModel<f64>, f64>::new();
        map.batch_insert(&mut points.clone()).unwrap();

        // Compare the map's answer against a brute-force scan for each sample.
        let sample_points = create_random_point_type_points(100, SEED_2);
        for sample_point in &sample_points {
            let mut nearest = None;
            let mut closest_dist = ::core::f64::INFINITY;
            for point in &points {
                let new_dist = Euclidean::distance_point(&point, &sample_point);
                if new_dist < closest_dist {
                    closest_dist = new_dist;
                    nearest = Some(point);
                }
            }
            let map_nearest = map
                .nearest_neighbor(&[sample_point.x, sample_point.y])
                .unwrap();
            assert_eq!(nearest.unwrap(), &map_nearest);
        }
    }
}