Skip to main content

diskann_disk/data_model/
cache.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann::{ANNError, ANNResult};
7use diskann_providers::{
8    common::AlignedBoxWithSlice,
9    model::{
10        graph::{graph_data_model::AdjacencyList, traits::GraphDataType},
11        FP_VECTOR_MEM_ALIGN,
12    },
13};
14use hashbrown::{hash_map::Entry::Occupied, HashMap};
15
16pub struct Cache<Data: GraphDataType<VectorIdType = u32>> {
17    // Maintains the mapping of vector_id to index in the global cached nodes list.
18    mapping: HashMap<Data::VectorIdType, usize>,
19
20    // Aligned buffer to store the vectors of cached nodes.
21    vectors: AlignedBoxWithSlice<Data::VectorDataType>,
22
23    // The cached adjacency lists.
24    adjacency_lists: Vec<AdjacencyList<Data::VectorIdType>>,
25
26    // The cached associated data list.
27    associated_data: Vec<Data::AssociatedDataType>,
28
29    // The dimension of the vectors in the cache.
30    dimension: usize,
31
32    // The capacity of the cache.
33    capacity: usize,
34}
35
36impl<Data> Cache<Data>
37where
38    Data: GraphDataType<VectorIdType = u32>,
39{
40    // Creates a new cache with the specified dimension and capacity.
41    pub fn new(dimension: usize, capacity: usize) -> ANNResult<Self> {
42        Ok(Self {
43            mapping: HashMap::new(),
44            vectors: AlignedBoxWithSlice::new(capacity * dimension, FP_VECTOR_MEM_ALIGN)?,
45            adjacency_lists: Vec::with_capacity(capacity),
46            associated_data: Vec::with_capacity(capacity),
47            dimension,
48            capacity,
49        })
50    }
51
52    // Returns `true` if the cache contains the `vector_id`, otherwise `false`.
53    pub fn contains(&self, vector_id: &Data::VectorIdType) -> bool {
54        self.mapping.contains_key(vector_id)
55    }
56
57    // Returns the vector associated with the `vector_id`, if it exists in the cache otherwise `Option::None`.
58    pub fn get_vector(&self, vector_id: &Data::VectorIdType) -> Option<&[Data::VectorDataType]> {
59        if let Some(idx) = self.mapping.get(vector_id) {
60            Some(&self.vectors[idx * self.dimension..(idx + 1) * self.dimension])
61        } else {
62            Option::None
63        }
64    }
65
66    // Returns the adjacency list associated with the `vector_id``, if it exists in the cache otherwise `Option::None`.
67    pub fn get_adjacency_list(
68        &self,
69        vector_id: &Data::VectorIdType,
70    ) -> Option<&AdjacencyList<Data::VectorIdType>> {
71        if let Some(idx) = self.mapping.get(vector_id) {
72            Some(&self.adjacency_lists[*idx])
73        } else {
74            Option::None
75        }
76    }
77
78    // Returns the associated data associated with the `vector_id`, if it exists in the cache otherwise `Option::None`.
79    pub fn get_associated_data(
80        &self,
81        vector_id: &Data::VectorIdType,
82    ) -> Option<&Data::AssociatedDataType> {
83        if let Some(idx) = self.mapping.get(vector_id) {
84            Some(&self.associated_data[*idx])
85        } else {
86            Option::None
87        }
88    }
89
90    // Inserts a new node in the cache, if the node already exists in the cache, it updates the node.
91    // If the cache is full, it returns an error.
92    pub fn insert(
93        &mut self,
94        vector_id: &Data::VectorIdType,
95        vector: &[Data::VectorDataType],
96        adjacency_list: AdjacencyList<Data::VectorIdType>,
97        associated_data: Data::AssociatedDataType,
98    ) -> ANNResult<()> {
99        if self.dimension != vector.len() {
100            return ANNResult::Err(ANNError::log_index_error(
101                "Vector dimension does not match the dimension set in cache.",
102            ));
103        }
104
105        if let Occupied(occupied_entry) = self.mapping.entry(*vector_id) {
106            let idx = *occupied_entry.get();
107            self.copy_to_cache(idx, vector, adjacency_list, associated_data);
108            return ANNResult::Ok(());
109        }
110
111        if self.mapping.len() >= self.capacity {
112            return ANNResult::Err(ANNError::log_index_error(
113                "Cache is full, cannot insert more nodes",
114            ));
115        }
116
117        let idx = self.mapping.len();
118        self.mapping.insert(*vector_id, idx);
119        self.copy_to_cache(idx, vector, adjacency_list, associated_data);
120        ANNResult::Ok(())
121    }
122
123    // Returns `true` if the cache is empty, otherwise `false`.
124    pub fn is_empty(&self) -> bool {
125        self.mapping.is_empty()
126    }
127
128    // Returns the number of nodes in the cache.
129    pub fn len(&self) -> usize {
130        self.mapping.len()
131    }
132
133    fn copy_to_cache(
134        &mut self,
135        idx: usize,
136        vector: &[Data::VectorDataType],
137        adjacency_list: AdjacencyList<Data::VectorIdType>,
138        associated_data: Data::AssociatedDataType,
139    ) {
140        self.vectors[idx * self.dimension..(idx + 1) * self.dimension].copy_from_slice(vector);
141        self.adjacency_lists.push(adjacency_list);
142        self.associated_data.push(associated_data);
143    }
144}
145
/// Selects how nodes are cached.
///
/// NOTE(review): variant semantics below are inferred from the names; confirm
/// against the code that consumes this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachingStrategy {
    /// Presumably disables caching.
    None,
    /// Presumably a static cache filled with nodes found by a BFS traversal;
    /// the payload looks like the number of nodes to cache — TODO confirm.
    StaticCacheWithBfsNodes(usize),
}
151
#[cfg(test)]
mod tests {
    use diskann_providers::{
        model::graph::graph_data_model::AdjacencyList,
        test_utils::graph_data_type_utils::GraphDataF32VectorUnitData,
    };
    use rstest::rstest;

    use crate::data_model::Cache;

    // Cache specialization exercised by every test in this module.
    type TestCache = Cache<GraphDataF32VectorUnitData>;

    // Builds an empty cache for 10-dimensional vectors with room for `capacity` nodes.
    fn new_cache(capacity: usize) -> TestCache {
        TestCache::new(/*dimension=*/ 10, capacity).unwrap()
    }

    // Inserts a filler node (id 99) so the node under test does not land in slot 0.
    fn insert_a_random_node(cache: &mut TestCache) {
        let filler_id = 99;
        let filler_vector = vec![9.0; 10];
        let filler_neighbors = AdjacencyList::from(vec![20, 30, 40]);
        cache
            .insert(&filler_id, &filler_vector, filler_neighbors, ())
            .unwrap();
    }

    #[rstest]
    fn test_contains() {
        let mut cache = new_cache(2);
        insert_a_random_node(&mut cache);

        let present_id = 1;
        let values = vec![1.0; 10];
        let neighbors = AdjacencyList::from(vec![2, 3, 4]);
        cache.insert(&present_id, &values, neighbors, ()).unwrap();

        assert!(cache.contains(&present_id));

        let absent_id = 2;
        assert!(!cache.contains(&absent_id));
    }

    #[rstest]
    fn test_get_vector() {
        let mut cache = new_cache(2);
        insert_a_random_node(&mut cache);

        let present_id = 1;
        let values = vec![1.0; 10];
        let neighbors = AdjacencyList::from(vec![2, 3, 4]);
        cache.insert(&present_id, &values, neighbors, ()).unwrap();

        let stored = cache.get_vector(&present_id).unwrap();
        assert_eq!(stored, values.as_slice());

        let absent_id = 2;
        assert!(cache.get_vector(&absent_id).is_none());
    }

    #[rstest]
    fn test_get_adjacency_list() {
        let mut cache = new_cache(2);
        insert_a_random_node(&mut cache);

        let present_id = 1;
        let values = vec![1.0; 10];
        let neighbors = AdjacencyList::from(vec![2, 3, 4]);
        cache
            .insert(&present_id, &values, neighbors.clone(), ())
            .unwrap();

        let stored = cache.get_adjacency_list(&present_id).unwrap();
        assert_eq!(*stored, neighbors);

        let absent_id = 2;
        assert!(cache.get_adjacency_list(&absent_id).is_none());
    }

    #[rstest]
    fn test_get_associated_data() {
        let mut cache = new_cache(2);
        insert_a_random_node(&mut cache);

        let present_id = 1;
        let values = vec![1.0; 10];
        let neighbors = AdjacencyList::from(vec![2, 3, 4]);
        cache.insert(&present_id, &values, neighbors, ()).unwrap();

        assert!(cache.get_associated_data(&present_id).is_some());

        let absent_id = 2;
        assert!(cache.get_associated_data(&absent_id).is_none());
    }

    #[rstest]
    fn test_insert() {
        let mut cache = new_cache(2);
        insert_a_random_node(&mut cache);

        let present_id = 1;
        let values = vec![1.0; 10];
        let neighbors = AdjacencyList::from(vec![2, 3, 4]);

        // A fresh id is stored.
        cache
            .insert(&present_id, &values, neighbors.clone(), ())
            .unwrap();
        assert!(cache.contains(&present_id));

        // Re-inserting the same id replaces its vector.
        let replacement = vec![2.0; 10];
        cache
            .insert(&present_id, &replacement, neighbors.clone(), ())
            .unwrap();
        assert_eq!(
            cache.get_vector(&present_id).unwrap(),
            replacement.as_slice()
        );

        // Both slots are taken, so a new id is rejected.
        let overflow_id = 2;
        let overflow_result = cache.insert(&overflow_id, &values, neighbors.clone(), ());
        assert!(overflow_result.is_err());

        // A vector of the wrong dimension is rejected.
        let wrong_dimension_vector = vec![1.0; 11];
        let mismatch_result = cache.insert(&present_id, &wrong_dimension_vector, neighbors, ());
        assert!(mismatch_result.is_err());
    }

    #[rstest]
    fn test_is_empty() {
        let mut cache = new_cache(1);
        assert!(cache.is_empty());

        insert_a_random_node(&mut cache);
        assert!(!cache.is_empty());
    }

    #[rstest]
    fn test_len() {
        let mut cache = new_cache(5);
        assert_eq!(cache.len(), 0);

        let values = vec![1.0; 10];
        let neighbors = AdjacencyList::from(vec![2, 3, 4]);

        let first_id = 1;
        cache
            .insert(&first_id, &values, neighbors.clone(), ())
            .unwrap();
        let second_id = 2;
        cache
            .insert(&second_id, &values, neighbors.clone(), ())
            .unwrap();
        let third_id = 3;
        cache.insert(&third_id, &values, neighbors, ()).unwrap();

        assert_eq!(cache.len(), 3);
    }
}