nn_rs/
nearest_neighbours.rs1use anyhow::Result;
2use nalgebra as na;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::fs::File;
7use std::path::PathBuf;
8use std::fmt::Debug;
9
10use crate::metrics::metric_factory;
11use crate::types::{Distance, MetricFunction};
12
13#[derive(Deserialize, Serialize, Debug, Clone)]
14pub struct NearestNeighbours {
15 metric: String,
16 vectors: HashMap<String, na::DVector<f64>>,
17}
18
19impl NearestNeighbours {
20 pub fn new(metric: String) -> Result<NearestNeighbours> {
35 let vectors: HashMap<String, na::DVector<f64>> = HashMap::new();
36 Ok(NearestNeighbours { metric, vectors })
37 }
38
39 pub fn from_json(metric: String, vectors_file: PathBuf) -> Result<NearestNeighbours> {
63 let mut vectors: HashMap<String, na::DVector<f64>> = HashMap::new();
66 let input_file = File::open(vectors_file)?;
67 let vector_vectors: Value = serde_json::from_reader(input_file)?;
68 for (key, value) in vector_vectors.as_object().unwrap(){
69 let n = value.as_array().unwrap();
70 let mut vec_dummy = vec![];
71 for ns in n{
72 vec_dummy.push(ns.to_owned().as_f64().unwrap());
73 }
74 let y = na::DVector::from_vec(vec_dummy);
75 vectors.insert(key.to_owned(), y);
76 }
77 Ok(NearestNeighbours { metric, vectors })
78 }
79
80 pub fn load(nn_file: PathBuf) -> Result<NearestNeighbours> {
86 let input_path = File::open(nn_file)?;
88 let new_nn_struct: NearestNeighbours = serde_json::from_reader(input_path)?;
89 Ok(new_nn_struct)
90 }
91
92 pub fn add_vector(&mut self, id: String, vector: na::DVector<f64>) -> Result<()> {
103 self.vectors.insert(id, vector);
104 Ok(())
105 }
106
107 pub fn save(&mut self, output_path: PathBuf) -> Result<()> {
113 let output_file = File::create(output_path)?;
114 serde_json::to_writer(output_file, &self)?;
115 Ok(())
116 }
117
118 pub fn query_by_vector(
127 &mut self,
128 query_vector: na::DVector<f64>,
129 no_neighbours: usize,
130 ) -> Result<Vec<String>> {
131 let mut distances = HashMap::new();
133 let metric: MetricFunction = metric_factory(&self.metric)?;
134 for (index, vector) in &self.vectors {
135 distances.insert(Distance::new(metric(vector, &query_vector)), index);
136 }
137
138 let mut all_distances: Vec<&Distance> = distances.iter().map(|(dist, _id)| dist).collect();
140 all_distances.sort();
141
142 let mut nns = vec![];
144 for neighbour in all_distances.iter().take(no_neighbours) {
145 let id: &String = distances[neighbour];
146 let id_owned: String = id.to_string();
147 nns.push(id_owned);
148 }
149 Ok(nns)
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use na::dvector;
    use std::fs::remove_file;

    /// Builds an empty index configured with the "cosine" metric.
    fn cosine_index() -> Result<NearestNeighbours> {
        NearestNeighbours::new(String::from("cosine"))
    }

    /// Constructing an empty index succeeds.
    #[test]
    fn test_nearest_neighbours_new() -> Result<()> {
        cosine_index()?;
        Ok(())
    }

    /// An index can be built from the JSON fixture.
    #[test]
    fn test_nearest_neighbours_from_json() -> Result<()> {
        NearestNeighbours::from_json(
            String::from("cosine"),
            PathBuf::from("./tests/data/test.json"),
        )?;
        Ok(())
    }

    /// A previously saved index fixture can be loaded.
    #[test]
    fn test_nearest_neighbours_load() -> Result<()> {
        NearestNeighbours::load(PathBuf::from(r"./tests/data/test.nn"))?;
        Ok(())
    }

    /// Adding a vector to an empty index succeeds.
    #[test]
    fn test_nearest_neighbours_add_vector() -> Result<()> {
        let mut nn_index = cosine_index()?;
        nn_index.add_vector(String::from("one"), dvector!(1.0, 2.0))?;
        Ok(())
    }

    /// A populated index can be written to disk; the file is removed afterwards.
    #[test]
    fn test_nearest_neighbours_save() -> Result<()> {
        let mut nn_index = cosine_index()?;
        nn_index.add_vector(String::from("one"), dvector!(1.0, 2.0))?;

        let saved_to = PathBuf::from(r"./tests/data/test_save.nn");
        nn_index.save(saved_to.clone())?;
        remove_file(&saved_to)?;
        Ok(())
    }

    /// Querying with a vector identical to "one" returns "one" as the nearest id.
    #[test]
    fn test_nearest_neighbours_query_by_vector() -> Result<()> {
        let mut nn_index = cosine_index()?;
        nn_index.add_vector(String::from("one"), dvector!(1.0, 2.0))?;
        nn_index.add_vector(String::from("two"), dvector!(9.0, 7.0))?;

        let closest = nn_index.query_by_vector(dvector!(1.0, 2.0), 1)?;
        assert_eq!(closest, vec!["one"]);
        Ok(())
    }
}