// nn_rs/nearest_neighbours.rs

use std::collections::HashMap;
use std::fmt::Debug;
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::PathBuf;

use anyhow::Result;
use nalgebra as na;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::metrics::metric_factory;
use crate::types::{Distance, MetricFunction};
12
13#[derive(Deserialize, Serialize, Debug, Clone)]
14pub struct NearestNeighbours {
15    metric: String,
16    vectors: HashMap<String, na::DVector<f64>>,
17}
18
19impl NearestNeighbours {
20    /// Create a new, empty NearestNeighbours struct
21    ///
22    /// ```rust
23    /// use nn_rs::NearestNeighbours;
24    ///
25    /// let metric = String::from("cosine");
26    /// let mut index = NearestNeighbours::new(metric);
27    /// ```
28    ///
29    /// # Parameters
30    /// - metric: distance metric to use. One of "cosine", "euclidean" or "manhattan"
31    ///
32    /// # Return Values
33    /// - NearestNeighbours struct
34    pub fn new(metric: String) -> Result<NearestNeighbours> {
35        let vectors: HashMap<String, na::DVector<f64>> = HashMap::new();
36        Ok(NearestNeighbours { metric, vectors })
37    }
38
39    /// Create a new NearestNeighbours struct from a json
40    ///
41    /// This constructor should be useful for loading vectors from python matrix libraries
42    /// such as torch, tensorflow, jax, numpy etc.
43    ///
44    /// # Parameters
45    /// - metric: distance metric. One of "cosine", "euclidean" or "manhattan"
46    /// - vector_file: a json of {"id_1": [1.0, 2.0, ... ], "id_2": [1.0, 2.0, ... ], ... } format
47    ///
48    /// # Return Values
49    /// - NearestNeighbours struct
50    ///
51    /// Example json format
52    /// ```json
53    /// {
54    ///     "id_1": [1.0, 2.0, ... ],
55    ///     "id_2": [1.0, 2.0, ... ],
56    ///     .
57    ///     .
58    ///     .
59    ///     "id_n": [1.0, 2.0, ... ],
60    /// }
61    /// ```
62    pub fn from_json(metric: String, vectors_file: PathBuf) -> Result<NearestNeighbours> {
63        // TODO: revisit this, it's very scrappy
64        // load the vectors file and parse into vectors
65        let mut vectors: HashMap<String, na::DVector<f64>> = HashMap::new();
66        let input_file = File::open(vectors_file)?;
67        let vector_vectors: Value  = serde_json::from_reader(input_file)?;
68        for (key, value) in vector_vectors.as_object().unwrap(){
69            let n = value.as_array().unwrap();
70            let mut vec_dummy = vec![];
71            for ns in n{
72                vec_dummy.push(ns.to_owned().as_f64().unwrap());
73            }
74            let y = na::DVector::from_vec(vec_dummy);
75            vectors.insert(key.to_owned(), y);
76        }
77        Ok(NearestNeighbours { metric, vectors })
78    }
79
80    /// Load a NearestNeighbours struct from a .nn file
81    /// # Parameters
82    /// - nn_file: .nn file to load
83    /// # Return Values
84    /// - NearestNeighbours struct
85    pub fn load(nn_file: PathBuf) -> Result<NearestNeighbours> {
86        // load the nn file and turn this into a NearestNeighbours struct a parse into vector
87        let input_path = File::open(nn_file)?;
88        let new_nn_struct: NearestNeighbours = serde_json::from_reader(input_path)?;
89        Ok(new_nn_struct)
90    }
91
92    /// Add a new vector to the NearestNeighbour struct
93    ///
94    /// Note: the id should be unique. If it isn't, only the vector associated with the id that was most
95    /// recently added will be kept
96    ///
97    /// # Parameters
98    /// - id: the id of the vector, this can be any string but it must be unique
99    /// - vector: the vector to add
100    /// # Return Values
101    /// - nothing
102    pub fn add_vector(&mut self, id: String, vector: na::DVector<f64>) -> Result<()> {
103        self.vectors.insert(id, vector);
104        Ok(())
105    }
106
107    /// Save the NearestNeighbour struct to a .nn file
108    /// # Parameters
109    /// - output_path: path to save the struct to
110    /// # Return Values
111    /// - nothing
112    pub fn save(&mut self, output_path: PathBuf) -> Result<()> {
113        let output_file = File::create(output_path)?;
114        serde_json::to_writer(output_file, &self)?;
115        Ok(())
116    }
117
118    /// Find the nearest neighbour for a query vector
119    ///
120    /// # Parameters
121    /// - query_vector: vector to find the nearest neighbour to
122    /// - no_neighbours: the number of nearest neighbours to find
123    ///
124    /// # Return Values
125    /// - ids of the nearest neighbours
126    pub fn query_by_vector(
127        &mut self,
128        query_vector: na::DVector<f64>,
129        no_neighbours: usize,
130    ) -> Result<Vec<String>> {
131        // calculate and store the distances between the vectors
132        let mut distances = HashMap::new();
133        let metric: MetricFunction = metric_factory(&self.metric)?;
134        for (index, vector) in &self.vectors {
135            distances.insert(Distance::new(metric(vector, &query_vector)), index);
136        }
137
138        // get the distances and sort to find the smallest distances
139        let mut all_distances: Vec<&Distance> = distances.iter().map(|(dist, _id)| dist).collect();
140        all_distances.sort();
141
142        // find the ids for the smallest distances and return them
143        let mut nns = vec![];
144        for neighbour in all_distances.iter().take(no_neighbours) {
145            let id: &String = distances[neighbour];
146            let id_owned: String = id.to_string();
147            nns.push(id_owned);
148        }
149        Ok(nns)
150    }
151}
152
#[cfg(test)]
mod tests {
    //! Unit tests for `NearestNeighbours`.
    //!
    //! The `from_json` and `load` tests read fixture files from `./tests/data`.

    use super::*;
    use anyhow::Result;
    use na::dvector;
    use std::fs::remove_file;

    /// A freshly created index keeps the metric name and holds no vectors.
    #[test]
    fn test_nearest_neighbours_new() -> Result<()> {
        let index = NearestNeighbours::new(String::from("cosine"))?;
        assert_eq!(index.metric, "cosine");
        assert!(index.vectors.is_empty());
        Ok(())
    }

    /// The json fixture can be parsed into an index using the requested metric.
    #[test]
    fn test_nearest_neighbours_from_json() -> Result<()> {
        let json_path = PathBuf::from("./tests/data/test.json");
        let metric = String::from("cosine");
        let index = NearestNeighbours::from_json(metric, json_path)?;
        assert_eq!(index.metric, "cosine");
        Ok(())
    }

    /// The .nn fixture can be deserialized.
    #[test]
    fn test_nearest_neighbours_load() -> Result<()> {
        let output_file = PathBuf::from(r"./tests/data/test.nn");
        let _index = NearestNeighbours::load(output_file)?;
        Ok(())
    }

    /// Added vectors are stored under their id; re-using an id replaces the vector.
    #[test]
    fn test_nearest_neighbours_add_vector() -> Result<()> {
        let mut index = NearestNeighbours::new(String::from("cosine"))?;
        index.add_vector(String::from("one"), dvector![1.0, 2.0])?;
        assert_eq!(index.vectors.len(), 1);
        assert_eq!(index.vectors.get("one"), Some(&dvector![1.0, 2.0]));

        index.add_vector(String::from("one"), dvector![3.0, 4.0])?;
        assert_eq!(index.vectors.len(), 1);
        assert_eq!(index.vectors.get("one"), Some(&dvector![3.0, 4.0]));
        Ok(())
    }

    /// A saved index can be loaded back with the same metric and vectors.
    #[test]
    fn test_nearest_neighbours_save() -> Result<()> {
        let mut index = NearestNeighbours::new(String::from("cosine"))?;
        index.add_vector(String::from("one"), dvector![1.0, 2.0])?;
        let output_file = PathBuf::from(r"./tests/data/test_save.nn");
        index.save(output_file.clone())?;

        // Read the file back before deleting it, but only inspect the result
        // afterwards so the file is cleaned up even when loading fails.
        let reloaded = NearestNeighbours::load(output_file.clone());
        remove_file(output_file)?;
        let reloaded = reloaded?;

        assert_eq!(reloaded.metric, "cosine");
        assert_eq!(reloaded.vectors.len(), 1);
        assert_eq!(reloaded.vectors.get("one"), Some(&dvector![1.0, 2.0]));
        Ok(())
    }

    /// Queries return ids closest-first and never more ids than are stored.
    #[test]
    fn test_nearest_neighbours_query_by_vector() -> Result<()> {
        let mut index = NearestNeighbours::new(String::from("cosine"))?;
        index.add_vector(String::from("one"), dvector![1.0, 2.0])?;
        index.add_vector(String::from("two"), dvector![9.0, 7.0])?;

        let nearest = index.query_by_vector(dvector![1.0, 2.0], 1)?;
        assert_eq!(nearest, vec!["one"]);

        let both = index.query_by_vector(dvector![1.0, 2.0], 2)?;
        assert_eq!(both, vec!["one", "two"]);

        // Asking for more neighbours than exist yields everything available.
        let capped = index.query_by_vector(dvector![1.0, 2.0], 5)?;
        assert_eq!(capped, vec!["one", "two"]);
        Ok(())
    }
}