ann_dataset/data/in_memory_dataset.rs

use crate::data::AnnDataset;
use crate::io::Hdf5File;
use crate::{Hdf5Serialization, PointSet, QuerySet};
use anyhow::{anyhow, Result};
use hdf5::{File, Group, H5Type};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::fmt::Formatter;

const QUERY_SETS: &str = "query_sets";
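
/// An `AnnDataset` that holds the data points and all labeled query sets in
/// memory, and can be written to or read from a single HDF5 file.
///
/// A rough usage sketch (marked `ignore`; the crate-root re-exports and the
/// dense-only `PointSet` construction are assumptions based on the tests
/// below):
///
/// ```ignore
/// use ann_dataset::{AnnDataset, Hdf5File, InMemoryAnnDataset, PointSet, QuerySet};
/// use ndarray::Array2;
/// use ndarray_rand::{rand_distr::Uniform, RandomExt};
///
/// // Build small random dense point and query sets.
/// let points = PointSet::new(Some(Array2::random((100, 16), Uniform::new(0.0_f32, 1.0))), None).unwrap();
/// let queries = PointSet::new(Some(Array2::random((10, 16), Uniform::new(0.0_f32, 1.0))), None).unwrap();
///
/// let mut dataset = InMemoryAnnDataset::create(points);
/// dataset.add_train_query_set(QuerySet::new(queries));
///
/// // Round-trip through an HDF5 file.
/// dataset.write("ann-dataset.hdf5").unwrap();
/// let dataset = InMemoryAnnDataset::<f32>::read("ann-dataset.hdf5").unwrap();
/// ```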
#[derive(Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct InMemoryAnnDataset<DataType: Clone> {
    data_points: PointSet<DataType>,
    query_sets: HashMap<String, QuerySet<DataType>>,
}

impl<DataType: Clone> InMemoryAnnDataset<DataType> {
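    /// Wraps the given point set into a dataset with no query sets attached.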
    pub fn create(data_points: PointSet<DataType>) -> InMemoryAnnDataset<DataType> {
        InMemoryAnnDataset {
            data_points,
            query_sets: HashMap::new(),
        }
    }
}
impl<DataType: Clone> AnnDataset<DataType> for InMemoryAnnDataset<DataType> {
    fn get_data_points(&self) -> &PointSet<DataType> {
        &self.data_points
    }

    fn get_data_points_mut(&mut self) -> &mut PointSet<DataType> {
        &mut self.data_points
    }

    fn select(&self, ids: &[usize]) -> PointSet<DataType> {
        self.data_points.select(ids)
    }

    fn add_query_set(&mut self, label: &str, query_set: QuerySet<DataType>) {
        self.query_sets.insert(label.to_string(), query_set);
    }

    fn get_query_set(&self, label: &str) -> Result<&QuerySet<DataType>> {
        match self.query_sets.get(label) {
            None => Err(anyhow!("Query set {} does not exist", label)),
            Some(set) => Ok(set),
        }
    }
}
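
// HDF5 layout: the point set is serialized directly into the target group, and
// each query set is written to a child group at `query_sets/<label>`.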
impl<DataType: Clone + H5Type> Hdf5Serialization for InMemoryAnnDataset<DataType> {
    type Object = InMemoryAnnDataset<DataType>;

    fn add_to(&self, group: &mut Group) -> Result<()> {
        self.data_points.add_to(group)?;

        let query_group = group.create_group(QUERY_SETS)?;
        self.query_sets.iter().try_for_each(|entry| {
            let mut grp = query_group.create_group(entry.0)?;
            entry.1.add_to(&mut grp)?;
            anyhow::Ok(())
        })?;
        Ok(())
    }

    fn read_from(group: &Group) -> Result<Self::Object> {
        let data_points = PointSet::<DataType>::read_from(group)?;

        let mut query_sets: HashMap<String, QuerySet<DataType>> = HashMap::new();
        let query_group = group.group(QUERY_SETS)?;
        query_group.groups()?.iter().try_for_each(|grp| {
            // The group name is a full path (e.g., "/query_sets/<label>");
            // the query set's label is its last path component.
            let name = grp.name();
            let name = name.split('/').last().unwrap();
            let query_set = QuerySet::<DataType>::read_from(grp)?;
            query_sets.insert(name.to_string(), query_set);
            anyhow::Ok(())
        })?;

        Ok(InMemoryAnnDataset {
            data_points,
            query_sets,
        })
    }

    fn label() -> String {
        "ann-dataset".to_string()
    }
}
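
// File-level helpers: the whole dataset is (de)serialized under the root ("/")
// group of an HDF5 file at the given path.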
impl<DataType: Clone + H5Type> Hdf5File for InMemoryAnnDataset<DataType> {
    type Object = InMemoryAnnDataset<DataType>;

    fn write(&self, path: &str) -> Result<()> {
        let file = File::create(path)?;
        let mut root = file.group("/")?;
        Hdf5Serialization::add_to(self, &mut root)?;
        file.close()?;
        Ok(())
    }

    fn read(path: &str) -> Result<Self::Object> {
        let hdf5_dataset = File::open(path)?;
        let root = hdf5_dataset.group("/")?;
        <InMemoryAnnDataset<DataType> as Hdf5Serialization>::read_from(&root)
    }
}
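
// Human-readable summary: the point set, followed by one line per labeled
// query set.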
impl<DataType: Clone> fmt::Display for InMemoryAnnDataset<DataType> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Point Set: {}\n{}",
            self.data_points,
            self.query_sets
                .iter()
                .map(|entry| format!("{}: {}", entry.0, entry.1))
                .collect::<Vec<_>>()
                .join("\n")
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::data::in_memory_dataset::InMemoryAnnDataset;
    use crate::data::AnnDataset;
    use crate::{Hdf5File, PointSet, QuerySet};
    use ndarray::Array2;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use sprs::{CsMat, TriMat};
    use tempdir::TempDir;

    fn sample_data_points() -> PointSet<f32> {
        let dense_set = Array2::random((4, 10), Uniform::new(0.0, 1.0));

        let mut sparse_set = TriMat::new((4, 4));
        sparse_set.add_triplet(0, 0, 3.0_f32);
        sparse_set.add_triplet(1, 2, 2.0);
        sparse_set.add_triplet(3, 0, -2.0);
        let sparse_set: CsMat<_> = sparse_set.to_csr();

        PointSet::new(Some(dense_set), Some(sparse_set)).unwrap()
    }

    #[test]
    fn test_create() {
        let data_points = sample_data_points();
        let dataset = InMemoryAnnDataset::<f32>::create(data_points.clone());
        let copy = dataset.get_data_points();
        assert_eq!(&data_points, copy);
    }

    #[test]
    fn test_query_points() {
        let data_points = sample_data_points();
        let mut dataset = InMemoryAnnDataset::<f32>::create(data_points.clone());

        assert!(dataset.get_train_query_set().is_err());
        assert!(dataset.get_validation_query_set().is_err());
        assert!(dataset.get_test_query_set().is_err());

        let query_points = sample_data_points();
        dataset.add_train_query_set(QuerySet::new(query_points.clone()));
        assert!(dataset.get_train_query_set().is_ok());
        let copy = dataset.get_train_query_set().unwrap();
        assert_eq!(&query_points, copy.get_points());

        // Adding another query set under the same label replaces the existing one.
        let query_points = sample_data_points();
        dataset.add_train_query_set(QuerySet::new(query_points.clone()));
        assert!(dataset.get_train_query_set().is_ok());
        let copy = dataset.get_train_query_set().unwrap();
        assert_eq!(&query_points, copy.get_points());
    }

    #[test]
    fn test_write() {
        let data_points = sample_data_points();
        let mut dataset = InMemoryAnnDataset::<f32>::create(data_points.clone());
        let query_points = sample_data_points();
        dataset.add_train_query_set(QuerySet::new(query_points.clone()));

        let dir = TempDir::new("test_write").unwrap();
        let path = dir.path().join("ann-dataset.hdf5");
        let path = path.to_str().unwrap();

        let result = dataset.write(path);
        assert!(result.is_ok());

        let dataset = InMemoryAnnDataset::<f32>::read(path);
        assert!(dataset.is_ok());
        let dataset = dataset.unwrap();

        assert_eq!(&data_points, dataset.get_data_points());
        assert_eq!(
            &query_points,
            dataset.get_train_query_set().unwrap().get_points()
        );
    }
}