pgm_extra/index/external/
cached.rs1use core::ops::RangeBounds;
6
7use crate::error::Error;
8use crate::index::Static;
9use crate::index::key::Indexable;
10use crate::util::ApproxPos;
11use crate::util::cache::{FastHash, HotCache};
12use crate::util::range::range_to_indices;
13
/// A [`Static`] index paired with a [`HotCache`] that remembers where
/// previously found values live.
///
/// The cache is keyed by `T::Key` (obtained via `index_key()`) and stores
/// positions into the caller-supplied `data` slice. `lower_bound` and
/// `contains` consult it before searching and record exact matches. Every
/// other query is forwarded straight to the inner [`Static`] index.
///
/// The cache is mutated through `&self` (see `clear_cache`), so `HotCache`
/// must provide interior mutability.
/// NOTE(review): the thread-safety of that interior mutability is not visible
/// here; confirm before sharing a `Cached` across threads.
#[derive(Debug)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound = "T::Key: serde::Serialize + serde::de::DeserializeOwned")
)]
pub struct Cached<T: Indexable>
where
    T::Key: FastHash + core::default::Default,
{
    /// The wrapped index; it answers every query the cache cannot.
    inner: Static<T>,
    /// Key -> position memo. It is skipped by both `rkyv` and `serde`, so a
    /// deserialized `Cached` starts with a `Default`-constructed cache rather
    /// than the one that was live at serialization time.
    #[cfg_attr(feature = "rkyv", rkyv(with = rkyv::with::Skip))]
    #[cfg_attr(feature = "serde", serde(skip, default))]
    cache: HotCache<T::Key>,
}
56
57impl<T: Indexable> Cached<T>
58where
59 T::Key: Ord + FastHash + core::default::Default,
60{
61 pub fn new(data: &[T], epsilon: usize, epsilon_recursive: usize) -> Result<Self, Error> {
63 let inner = Static::new(data, epsilon, epsilon_recursive)?;
64 Ok(Self {
65 inner,
66 cache: HotCache::new(),
67 })
68 }
69
70 pub fn from_index(index: Static<T>) -> Self {
72 Self {
73 inner: index,
74 cache: HotCache::new(),
75 }
76 }
77
78 #[inline]
80 pub fn search(&self, value: &T) -> ApproxPos {
81 self.inner.search(value)
82 }
83
84 #[inline]
86 pub fn lower_bound(&self, data: &[T], value: &T) -> usize
87 where
88 T: Ord,
89 {
90 let key = value.index_key();
91
92 if let Some(pos) = self.cache.lookup(&key)
93 && pos < data.len()
94 && data[pos] == *value
95 {
96 return pos;
97 }
98
99 let result = self.inner.lower_bound(data, value);
100
101 if result < data.len() && data[result] == *value {
102 self.cache.insert(key, result);
103 }
104
105 result
106 }
107
108 #[inline]
110 pub fn upper_bound(&self, data: &[T], value: &T) -> usize
111 where
112 T: Ord,
113 {
114 self.inner.upper_bound(data, value)
115 }
116
117 #[inline]
119 pub fn contains(&self, data: &[T], value: &T) -> bool
120 where
121 T: Ord,
122 {
123 let key = value.index_key();
124
125 if let Some(pos) = self.cache.lookup(&key)
126 && pos < data.len()
127 && data[pos] == *value
128 {
129 return true;
130 }
131
132 let result = self.inner.contains(data, value);
133
134 if result {
135 let pos = self.inner.lower_bound(data, value);
136 self.cache.insert(key, pos);
137 }
138
139 result
140 }
141
142 #[inline]
143 pub fn len(&self) -> usize {
144 self.inner.len()
145 }
146
147 #[inline]
148 pub fn is_empty(&self) -> bool {
149 self.inner.is_empty()
150 }
151
152 #[inline]
153 pub fn segments_count(&self) -> usize {
154 self.inner.segments_count()
155 }
156
157 #[inline]
158 pub fn height(&self) -> usize {
159 self.inner.height()
160 }
161
162 #[inline]
163 pub fn epsilon(&self) -> usize {
164 self.inner.epsilon()
165 }
166
167 #[inline]
168 pub fn epsilon_recursive(&self) -> usize {
169 self.inner.epsilon_recursive()
170 }
171
172 pub fn size_in_bytes(&self) -> usize {
173 self.inner.size_in_bytes() + core::mem::size_of::<HotCache<T::Key>>()
174 }
175
176 pub fn clear_cache(&self) {
178 self.cache.clear();
179 }
180
181 pub fn inner(&self) -> &Static<T> {
183 &self.inner
184 }
185
186 pub fn into_inner(self) -> Static<T> {
188 self.inner
189 }
190
191 #[inline]
193 pub fn range_indices<R>(&self, data: &[T], range: R) -> (usize, usize)
194 where
195 T: Ord,
196 R: RangeBounds<T>,
197 {
198 range_to_indices(
199 range,
200 data.len(),
201 |v| self.lower_bound(data, v),
202 |v| self.upper_bound(data, v),
203 )
204 }
205
206 #[inline]
208 pub fn range<'a, R>(&self, data: &'a [T], range: R) -> impl DoubleEndedIterator<Item = &'a T>
209 where
210 T: Ord,
211 R: RangeBounds<T>,
212 {
213 let (start, end) = self.range_indices(data, range);
214 data[start..end].iter()
215 }
216}
217
218impl<T: Indexable> From<Static<T>> for Cached<T>
219where
220 T::Key: Ord + FastHash + core::default::Default,
221{
222 fn from(index: Static<T>) -> Self {
223 Self::from_index(index)
224 }
225}
226
227impl<T: Indexable> From<Cached<T>> for Static<T>
228where
229 T::Key: Ord + FastHash + core::default::Default,
230{
231 fn from(cached: Cached<T>) -> Self {
232 cached.into_inner()
233 }
234}
235
236impl<T: Indexable> crate::index::External<T> for Cached<T>
237where
238 T::Key: Ord + crate::util::cache::FastHash + core::default::Default,
239{
240 #[inline]
241 fn search(&self, value: &T) -> ApproxPos {
242 self.search(value)
243 }
244
245 #[inline]
246 fn lower_bound(&self, data: &[T], value: &T) -> usize
247 where
248 T: Ord,
249 {
250 self.lower_bound(data, value)
251 }
252
253 #[inline]
254 fn upper_bound(&self, data: &[T], value: &T) -> usize
255 where
256 T: Ord,
257 {
258 self.upper_bound(data, value)
259 }
260
261 #[inline]
262 fn contains(&self, data: &[T], value: &T) -> bool
263 where
264 T: Ord,
265 {
266 self.contains(data, value)
267 }
268
269 #[inline]
270 fn len(&self) -> usize {
271 self.len()
272 }
273
274 #[inline]
275 fn segments_count(&self) -> usize {
276 self.segments_count()
277 }
278
279 #[inline]
280 fn epsilon(&self) -> usize {
281 self.epsilon()
282 }
283
284 #[inline]
285 fn size_in_bytes(&self) -> usize {
286 self.size_in_bytes()
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    #[test]
    fn test_cached_index_basic() {
        let keys: Vec<u64> = (0..10000).collect();
        let index = Cached::new(&keys, 64, 4).unwrap();

        assert_eq!(index.len(), 10000);
        assert!(!index.is_empty());
    }

    #[test]
    fn test_cached_index_hit() {
        let keys: Vec<u64> = (0..1000).collect();
        let index = Cached::new(&keys, 64, 4).unwrap();

        let key = 500u64;
        // The first call misses and populates the cache; the second should hit it.
        assert_eq!(index.lower_bound(&keys, &key), 500);
        assert_eq!(index.lower_bound(&keys, &key), 500);
    }

    #[test]
    fn test_cached_lower_bound_absent_values() {
        let keys: Vec<u64> = (0..100).map(|i| i * 2).collect();
        let index = Cached::new(&keys, 8, 4).unwrap();

        // Absent values are never memoized; repeated calls must stay consistent.
        for _ in 0..2 {
            assert_eq!(index.lower_bound(&keys, &0), 0);
            assert_eq!(index.lower_bound(&keys, &1), 1);
            assert_eq!(index.lower_bound(&keys, &199), 100);
            assert_eq!(index.lower_bound(&keys, &1000), 100);
        }
    }

    #[test]
    fn test_cached_upper_bound() {
        let keys: Vec<u64> = (0..100).collect();
        let index = Cached::new(&keys, 16, 4).unwrap();

        assert_eq!(index.upper_bound(&keys, &50), 51);
        assert_eq!(index.upper_bound(&keys, &99), 100);
    }

    #[test]
    fn test_cached_contains() {
        let keys: Vec<u64> = (0..100).map(|i| i * 2).collect();
        let index = Cached::new(&keys, 8, 4).unwrap();

        assert!(index.contains(&keys, &0));
        assert!(index.contains(&keys, &100));
        assert!(index.contains(&keys, &198));

        // Repeated lookups are served from the cache and must agree.
        assert!(index.contains(&keys, &0));
        assert!(index.contains(&keys, &198));

        assert!(!index.contains(&keys, &1));
        assert!(!index.contains(&keys, &99));
        assert!(!index.contains(&keys, &199));
        assert!(!index.contains(&keys, &1000));
    }

    #[test]
    fn test_cached_clear() {
        let keys: Vec<u64> = (0..100).collect();
        let index = Cached::new(&keys, 16, 4).unwrap();

        assert_eq!(index.lower_bound(&keys, &50), 50);
        index.clear_cache();

        // Clearing only drops memoized positions; results must be unchanged.
        assert_eq!(index.lower_bound(&keys, &50), 50);
        assert!(index.contains(&keys, &50));
    }

    #[test]
    fn test_cached_range() {
        let keys: Vec<u64> = (0..100).collect();
        let index = Cached::new(&keys, 16, 4).unwrap();

        assert_eq!(index.range_indices(&keys, 10..20), (10, 20));

        let got: Vec<u64> = index.range(&keys, 10..=20).copied().collect();
        let expected: Vec<u64> = (10..=20).collect();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_cached_conversions() {
        let keys: Vec<u64> = (0..100).collect();
        let index = Cached::new(&keys, 16, 4).unwrap();
        let segments = index.segments_count();

        let inner: Static<u64> = index.into();
        assert_eq!(inner.len(), 100);

        let back = Cached::from(inner);
        assert_eq!(back.segments_count(), segments);
        assert_eq!(back.lower_bound(&keys, &42), 42);
    }
}