// ann_dataset/data/in_memory_dataset.rs

1use crate::data::AnnDataset;
2use crate::io::Hdf5File;
3use crate::{Hdf5Serialization, PointSet, QuerySet};
4use anyhow::{anyhow, Result};
5use hdf5::{File, Group, H5Type};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt;
9use std::fmt::Formatter;
10
/// Name of the HDF5 group that holds the query sets, with one sub-group per query set label.
const QUERY_SETS: &str = "query_sets";
12
13/// An ANN dataset.
14#[derive(Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
15pub struct InMemoryAnnDataset<DataType: Clone> {
16    data_points: PointSet<DataType>,
17    query_sets: HashMap<String, QuerySet<DataType>>,
18}
19
20impl<DataType: Clone> InMemoryAnnDataset<DataType> {
21    /// Creates an `AnnDataset` object.
22    ///
23    /// Here is a simple example:
24    /// ```rust
25    /// use ndarray::Array2;
26    /// use sprs::{CsMat, TriMat};
27    /// use ann_dataset::{InMemoryAnnDataset, PointSet};
28    ///
29    /// let dense = Array2::<f32>::eye(10);
30    /// let mut sparse = TriMat::new((10, 4));
31    /// sparse.add_triplet(0, 0, 3.0_f32);
32    /// sparse.add_triplet(1, 2, 2.0);
33    /// sparse.add_triplet(3, 0, -2.0);
34    /// sparse.add_triplet(9, 2, 3.4);
35    /// let sparse: CsMat<_> = sparse.to_csr();
36    ///
37    /// let data_points = PointSet::new(Some(dense.clone()), Some(sparse.clone()))
38    ///     .expect("Failed to create PointSet.");
39    ///
40    /// let dataset = InMemoryAnnDataset::create(data_points);
41    /// ```
42    pub fn create(data_points: PointSet<DataType>) -> InMemoryAnnDataset<DataType> {
43        InMemoryAnnDataset {
44            data_points,
45            query_sets: HashMap::new(),
46        }
47    }
48}
49
50impl<DataType: Clone> AnnDataset<DataType> for InMemoryAnnDataset<DataType> {
51    fn get_data_points(&self) -> &PointSet<DataType> {
52        &self.data_points
53    }
54
55    fn get_data_points_mut(&mut self) -> &mut PointSet<DataType> {
56        &mut self.data_points
57    }
58
59    fn select(&self, ids: &[usize]) -> PointSet<DataType> {
60        self.data_points.select(ids)
61    }
62
63    /// Adds a new query set to the dataset with the given `label` or replaces one if it already
64    /// exists.
65    ///
66    /// Consider the following example:
67    /// ```rust
68    /// use ndarray::Array2;
69    /// use ann_dataset::{AnnDataset, InMemoryAnnDataset, PointSet, QuerySet};
70    ///
71    /// let dense = Array2::<f32>::eye(10);
72    /// let data_points = PointSet::new(Some(dense.clone()), None)
73    ///     .expect("Failed to create PointSet.");
74    /// let query_points = data_points.clone();
75    ///
76    /// let mut dataset = InMemoryAnnDataset::create(data_points);
77    ///
78    /// let query_set = QuerySet::new(query_points);
79    /// dataset.add_query_set("train", query_set);
80    /// ```
81    fn add_query_set(&mut self, label: &str, query_set: QuerySet<DataType>) {
82        self.query_sets.insert(label.to_string(), query_set);
83    }
84
85    fn get_query_set(&self, label: &str) -> Result<&QuerySet<DataType>> {
86        match self.query_sets.get(label) {
87            None => Err(anyhow!("Query set {} does not exist", label)),
88            Some(set) => Ok(set),
89        }
90    }
91}
92
93impl<DataType: Clone + H5Type> Hdf5Serialization for InMemoryAnnDataset<DataType> {
94    type Object = InMemoryAnnDataset<DataType>;
95
96    fn add_to(&self, group: &mut Group) -> Result<()> {
97        self.data_points.add_to(group)?;
98
99        let query_group = group.create_group(QUERY_SETS)?;
100        self.query_sets.iter().try_for_each(|entry| {
101            let mut grp = query_group.create_group(entry.0)?;
102            entry.1.add_to(&mut grp)?;
103            anyhow::Ok(())
104        })?;
105        Ok(())
106    }
107
108    fn read_from(group: &Group) -> Result<Self::Object> {
109        let data_points = PointSet::<DataType>::read_from(group)?;
110
111        let mut query_sets: HashMap<String, QuerySet<DataType>> = HashMap::new();
112        let query_group = group.group(QUERY_SETS)?;
113        query_group.groups()?.iter().try_for_each(|grp| {
114            let name = grp.name();
115            let name = name.split('/').last().unwrap();
116            let query_set = QuerySet::<DataType>::read_from(grp)?;
117            query_sets.insert(name.to_string(), query_set);
118            anyhow::Ok(())
119        })?;
120
121        Ok(InMemoryAnnDataset {
122            data_points,
123            query_sets,
124        })
125    }
126
127    fn label() -> String {
128        "ann-dataset".to_string()
129    }
130}
131
132impl<DataType: Clone + H5Type> Hdf5File for InMemoryAnnDataset<DataType> {
133    type Object = InMemoryAnnDataset<DataType>;
134
135    fn write(&self, path: &str) -> Result<()> {
136        let file = File::create(path)?;
137        let mut root = file.group("/")?;
138        Hdf5Serialization::add_to(self, &mut root)?;
139        file.close()?;
140        Ok(())
141    }
142
143    fn read(path: &str) -> Result<Self::Object> {
144        let hdf5_dataset = File::open(path)?;
145        let root = hdf5_dataset.group("/")?;
146        <InMemoryAnnDataset<DataType> as Hdf5Serialization>::read_from(&root)
147    }
148}
149
150impl<DataType: Clone> fmt::Display for InMemoryAnnDataset<DataType> {
151    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
152        write!(
153            f,
154            "Point Set: {}\n{}",
155            self.data_points,
156            self.query_sets
157                .iter()
158                .map(|entry| format!("{}: {}", entry.0, entry.1))
159                .collect::<Vec<_>>()
160                .join("\n")
161        )
162    }
163}
164
#[cfg(test)]
mod tests {
    use crate::data::in_memory_dataset::InMemoryAnnDataset;
    use crate::data::AnnDataset;
    use crate::{Hdf5File, PointSet, QuerySet};
    use ndarray::Array2;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use sprs::{CsMat, TriMat};
    use tempdir::TempDir;

    /// Builds a point set of 4 points with a random 10-dimensional dense part and a fixed
    /// 4-dimensional sparse part.
    fn sample_data_points() -> PointSet<f32> {
        let dense = Array2::random((4, 10), Uniform::new(0.0, 1.0));

        let mut triplets = TriMat::new((4, 4));
        for (row, col, value) in [(0, 0, 3.0_f32), (1, 2, 2.0), (3, 0, -2.0)] {
            triplets.add_triplet(row, col, value);
        }
        let sparse: CsMat<_> = triplets.to_csr();

        PointSet::new(Some(dense), Some(sparse)).unwrap()
    }

    #[test]
    fn test_create() {
        let expected = sample_data_points();
        let dataset = InMemoryAnnDataset::<f32>::create(expected.clone());
        assert_eq!(&expected, dataset.get_data_points());
    }

    #[test]
    fn test_query_points() {
        let mut dataset = InMemoryAnnDataset::<f32>::create(sample_data_points());

        // A freshly created dataset has no query sets at all.
        assert!(dataset.get_train_query_set().is_err());
        assert!(dataset.get_validation_query_set().is_err());
        assert!(dataset.get_test_query_set().is_err());

        // Adding a train query set makes it retrievable.
        let first = sample_data_points();
        dataset.add_train_query_set(QuerySet::new(first.clone()));
        assert!(dataset.get_train_query_set().is_ok());
        let stored = dataset.get_train_query_set().unwrap();
        assert_eq!(&first, stored.get_points());

        // Adding it again replaces the existing query set.
        let second = sample_data_points();
        dataset.add_train_query_set(QuerySet::new(second.clone()));
        assert!(dataset.get_train_query_set().is_ok());
        let stored = dataset.get_train_query_set().unwrap();
        assert_eq!(&second, stored.get_points());
    }

    #[test]
    fn test_write() {
        let data_points = sample_data_points();
        let query_points = sample_data_points();

        let mut original = InMemoryAnnDataset::<f32>::create(data_points.clone());
        original.add_train_query_set(QuerySet::new(query_points.clone()));

        let dir = TempDir::new("test_write").unwrap();
        let file_path = dir.path().join("ann-dataset.hdf5");
        let file_path = file_path.to_str().unwrap();

        assert!(original.write(file_path).is_ok());

        // Load the dataset back and check that both vector sets survived the round trip.
        let loaded = InMemoryAnnDataset::<f32>::read(file_path);
        assert!(loaded.is_ok());
        let loaded = loaded.unwrap();

        assert_eq!(&data_points, loaded.get_data_points());
        assert_eq!(
            &query_points,
            loaded.get_train_query_set().unwrap().get_points()
        );
    }
}