1use diskann::{ANNError, ANNResult};
7use diskann_providers::{
8 common::AlignedBoxWithSlice,
9 model::{
10 graph::{graph_data_model::AdjacencyList, traits::GraphDataType},
11 FP_VECTOR_MEM_ALIGN,
12 },
13};
14use hashbrown::{hash_map::Entry::Occupied, HashMap};
15
16pub struct Cache<Data: GraphDataType<VectorIdType = u32>> {
17 mapping: HashMap<Data::VectorIdType, usize>,
19
20 vectors: AlignedBoxWithSlice<Data::VectorDataType>,
22
23 adjacency_lists: Vec<AdjacencyList<Data::VectorIdType>>,
25
26 associated_data: Vec<Data::AssociatedDataType>,
28
29 dimension: usize,
31
32 capacity: usize,
34}
35
36impl<Data> Cache<Data>
37where
38 Data: GraphDataType<VectorIdType = u32>,
39{
40 pub fn new(dimension: usize, capacity: usize) -> ANNResult<Self> {
42 Ok(Self {
43 mapping: HashMap::new(),
44 vectors: AlignedBoxWithSlice::new(capacity * dimension, FP_VECTOR_MEM_ALIGN)?,
45 adjacency_lists: Vec::with_capacity(capacity),
46 associated_data: Vec::with_capacity(capacity),
47 dimension,
48 capacity,
49 })
50 }
51
52 pub fn contains(&self, vector_id: &Data::VectorIdType) -> bool {
54 self.mapping.contains_key(vector_id)
55 }
56
57 pub fn get_vector(&self, vector_id: &Data::VectorIdType) -> Option<&[Data::VectorDataType]> {
59 if let Some(idx) = self.mapping.get(vector_id) {
60 Some(&self.vectors[idx * self.dimension..(idx + 1) * self.dimension])
61 } else {
62 Option::None
63 }
64 }
65
66 pub fn get_adjacency_list(
68 &self,
69 vector_id: &Data::VectorIdType,
70 ) -> Option<&AdjacencyList<Data::VectorIdType>> {
71 if let Some(idx) = self.mapping.get(vector_id) {
72 Some(&self.adjacency_lists[*idx])
73 } else {
74 Option::None
75 }
76 }
77
78 pub fn get_associated_data(
80 &self,
81 vector_id: &Data::VectorIdType,
82 ) -> Option<&Data::AssociatedDataType> {
83 if let Some(idx) = self.mapping.get(vector_id) {
84 Some(&self.associated_data[*idx])
85 } else {
86 Option::None
87 }
88 }
89
90 pub fn insert(
93 &mut self,
94 vector_id: &Data::VectorIdType,
95 vector: &[Data::VectorDataType],
96 adjacency_list: AdjacencyList<Data::VectorIdType>,
97 associated_data: Data::AssociatedDataType,
98 ) -> ANNResult<()> {
99 if self.dimension != vector.len() {
100 return ANNResult::Err(ANNError::log_index_error(
101 "Vector dimension does not match the dimension set in cache.",
102 ));
103 }
104
105 if let Occupied(occupied_entry) = self.mapping.entry(*vector_id) {
106 let idx = *occupied_entry.get();
107 self.copy_to_cache(idx, vector, adjacency_list, associated_data);
108 return ANNResult::Ok(());
109 }
110
111 if self.mapping.len() >= self.capacity {
112 return ANNResult::Err(ANNError::log_index_error(
113 "Cache is full, cannot insert more nodes",
114 ));
115 }
116
117 let idx = self.mapping.len();
118 self.mapping.insert(*vector_id, idx);
119 self.copy_to_cache(idx, vector, adjacency_list, associated_data);
120 ANNResult::Ok(())
121 }
122
123 pub fn is_empty(&self) -> bool {
125 self.mapping.is_empty()
126 }
127
128 pub fn len(&self) -> usize {
130 self.mapping.len()
131 }
132
133 fn copy_to_cache(
134 &mut self,
135 idx: usize,
136 vector: &[Data::VectorDataType],
137 adjacency_list: AdjacencyList<Data::VectorIdType>,
138 associated_data: Data::AssociatedDataType,
139 ) {
140 self.vectors[idx * self.dimension..(idx + 1) * self.dimension].copy_from_slice(vector);
141 self.adjacency_lists.push(adjacency_list);
142 self.associated_data.push(associated_data);
143 }
144}
145
/// Selects whether, and how, nodes are cached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachingStrategy {
    /// No caching.
    None,
    /// Static cache populated from BFS-visited nodes; the `usize` is presumably the
    /// number of nodes to cache — NOTE(review): confirm against callers.
    StaticCacheWithBfsNodes(usize),
}
151
#[cfg(test)]
mod tests {
    use diskann_providers::{
        model::graph::graph_data_model::AdjacencyList,
        test_utils::graph_data_type_utils::GraphDataF32VectorUnitData,
    };
    use rstest::rstest;

    use crate::data_model::Cache;

    #[rstest]
    fn test_contains() {
        let mut cache = Cache::<GraphDataF32VectorUnitData>::new(10, 2).unwrap();
        insert_a_random_node(&mut cache);
        let vector_id = 1;
        let vector = vec![1.0; 10];
        let adjacency_list = AdjacencyList::from(vec![2, 3, 4]);
        cache
            .insert(&vector_id, &vector, adjacency_list, ())
            .unwrap();

        assert!(cache.contains(&vector_id));

        let not_exist_vector_id = 2;
        assert!(!cache.contains(&not_exist_vector_id));
    }

    #[rstest]
    fn test_get_vector() {
        let mut cache = Cache::<GraphDataF32VectorUnitData>::new(10, 2).unwrap();
        insert_a_random_node(&mut cache);
        let vector_id = 1;
        let vector = vec![1.0; 10];
        let adjacency_list = AdjacencyList::from(vec![2, 3, 4]);
        cache
            .insert(&vector_id, &vector, adjacency_list, ())
            .unwrap();

        let result = cache.get_vector(&vector_id).unwrap();
        assert_eq!(result, vector.as_slice());

        let not_exist_vector_id = 2;
        assert!(cache.get_vector(&not_exist_vector_id).is_none());
    }

    #[rstest]
    fn test_get_adjacency_list() {
        let mut cache = Cache::<GraphDataF32VectorUnitData>::new(10, 2).unwrap();
        insert_a_random_node(&mut cache);
        let vector_id = 1;
        let vector = vec![1.0; 10];
        let adjacency_list = AdjacencyList::from(vec![2, 3, 4]);
        cache
            .insert(&vector_id, &vector, adjacency_list.clone(), ())
            .unwrap();

        let result = cache.get_adjacency_list(&vector_id).unwrap();
        assert_eq!(*result, adjacency_list);

        let not_exist_vector_id = 2;
        assert!(cache.get_adjacency_list(&not_exist_vector_id).is_none());
    }

    #[rstest]
    fn test_get_associated_data() {
        let mut cache = Cache::<GraphDataF32VectorUnitData>::new(10, 2).unwrap();
        insert_a_random_node(&mut cache);
        let vector_id = 1;
        let vector = vec![1.0; 10];
        let adjacency_list = AdjacencyList::from(vec![2, 3, 4]);
        let associated_data = ();
        cache
            .insert(&vector_id, &vector, adjacency_list, associated_data)
            .unwrap();

        let result = cache.get_associated_data(&vector_id);
        assert!(result.is_some());

        let not_exist_vector_id = 2;
        assert!(cache.get_associated_data(&not_exist_vector_id).is_none());
    }

    #[rstest]
    fn test_insert() {
        let mut cache = Cache::<GraphDataF32VectorUnitData>::new(10, 2).unwrap();
        insert_a_random_node(&mut cache);
        let vector_id = 1;
        let vector = vec![1.0; 10];
        let adjacency_list = AdjacencyList::from(vec![2, 3, 4]);

        // Insert a new node.
        cache
            .insert(&vector_id, &vector, adjacency_list.clone(), ())
            .unwrap();
        assert!(cache.contains(&vector_id));

        // Re-inserting an existing id updates it in place.
        let updated_vector = vec![2.0; 10];
        cache
            .insert(&vector_id, &updated_vector, adjacency_list.clone(), ())
            .unwrap();
        assert_eq!(
            cache.get_vector(&vector_id).unwrap(),
            updated_vector.as_slice()
        );

        // Cache is at capacity (2), so a new id is rejected.
        let vector_id_2 = 2;
        let result = cache.insert(&vector_id_2, &vector, adjacency_list.clone(), ());
        assert!(result.is_err());

        // Wrong dimension is rejected even for an existing id.
        let wrong_dimensions_vector = vec![1.0; 11];
        assert!(cache
            .insert(&vector_id, &wrong_dimensions_vector, adjacency_list, ())
            .is_err());
    }

    #[rstest]
    fn test_is_empty() {
        let mut cache = Cache::<GraphDataF32VectorUnitData>::new(10, 1).unwrap();

        assert!(cache.is_empty());

        insert_a_random_node(&mut cache);

        assert!(!cache.is_empty());
    }

    #[rstest]
    fn test_len() {
        let mut cache = Cache::<GraphDataF32VectorUnitData>::new(10, 5).unwrap();

        assert_eq!(cache.len(), 0);

        let vector_id = 1;
        let vector = vec![1.0; 10];
        let adjacency_list = AdjacencyList::from(vec![2, 3, 4]);
        cache
            .insert(&vector_id, &vector, adjacency_list.clone(), ())
            .unwrap();
        let vector_id_2 = 2;
        cache
            .insert(&vector_id_2, &vector, adjacency_list.clone(), ())
            .unwrap();
        let vector_id_3 = 3;
        cache
            .insert(&vector_id_3, &vector, adjacency_list, ())
            .unwrap();

        assert_eq!(cache.len(), 3);
    }

    /// Inserts a fixed filler node (id 99) so tests exercise a non-empty cache.
    fn insert_a_random_node(cache: &mut Cache<GraphDataF32VectorUnitData>) {
        let vector_id = 99;
        let vector = vec![9.0; 10];
        cache
            .insert(
                &vector_id,
                &vector,
                AdjacencyList::from(vec![20, 30, 40]),
                (),
            )
            .unwrap();
    }
}