pgm_extra/index/external/
cached.rs

1//! Cached PGM-Index with hot-key lookup optimization.
2//!
3//! Wraps a multi-level index with a small cache for frequently accessed keys.
4
5use core::ops::RangeBounds;
6
7use crate::error::Error;
8use crate::index::Static;
9use crate::index::key::Indexable;
10use crate::util::ApproxPos;
11use crate::util::cache::{FastHash, HotCache};
12use crate::util::range::range_to_indices;
13
/// A PGM-Index wrapper with a small hot-key cache.
///
/// This struct uses interior mutability (`Cell`) to update the cache on read operations.
/// Therefore, it is `!Sync` and cannot be shared across threads without synchronization
/// (e.g., `Mutex<Cached>`).
///
/// Note: When serialized with serde, only the inner index is saved. The cache is
/// recreated empty on deserialization.
///
/// # Example
///
/// ```
/// use pgm_extra::index::external::Cached;
///
/// let keys: Vec<u64> = (0..10000).collect();
/// let index = Cached::new(&keys, 64, 4).unwrap();
///
/// // First lookup - cache miss, populates cache
/// assert!(index.contains(&keys, &5000));
///
/// // Second lookup - cache hit
/// assert!(index.contains(&keys, &5000));
/// ```
#[derive(Debug)]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound = "T::Key: serde::Serialize + serde::de::DeserializeOwned")
)]
pub struct Cached<T: Indexable>
where
    T::Key: FastHash + core::default::Default,
{
    // The underlying multi-level PGM-Index that answers all queries.
    inner: Static<T>,
    // Hot-key cache mapping a key to the exact data position it was last
    // found at. Skipped by both rkyv and serde so it is rebuilt empty on
    // deserialization (see the struct-level note above).
    #[cfg_attr(feature = "rkyv", rkyv(with = rkyv::with::Skip))]
    #[cfg_attr(feature = "serde", serde(skip, default))]
    cache: HotCache<T::Key>,
}
56
57impl<T: Indexable> Cached<T>
58where
59    T::Key: Ord + FastHash + core::default::Default,
60{
61    /// Build a new cached PGM-Index from sorted data.
62    pub fn new(data: &[T], epsilon: usize, epsilon_recursive: usize) -> Result<Self, Error> {
63        let inner = Static::new(data, epsilon, epsilon_recursive)?;
64        Ok(Self {
65            inner,
66            cache: HotCache::new(),
67        })
68    }
69
70    /// Wrap an existing index with a cache.
71    pub fn from_index(index: Static<T>) -> Self {
72        Self {
73            inner: index,
74            cache: HotCache::new(),
75        }
76    }
77
78    /// Get an approximate position for the given value.
79    #[inline]
80    pub fn search(&self, value: &T) -> ApproxPos {
81        self.inner.search(value)
82    }
83
84    /// Find the first position where `data[pos] >= value`.
85    #[inline]
86    pub fn lower_bound(&self, data: &[T], value: &T) -> usize
87    where
88        T: Ord,
89    {
90        let key = value.index_key();
91
92        if let Some(pos) = self.cache.lookup(&key)
93            && pos < data.len()
94            && data[pos] == *value
95        {
96            return pos;
97        }
98
99        let result = self.inner.lower_bound(data, value);
100
101        if result < data.len() && data[result] == *value {
102            self.cache.insert(key, result);
103        }
104
105        result
106    }
107
108    /// Find the first position where `data[pos] > value`.
109    #[inline]
110    pub fn upper_bound(&self, data: &[T], value: &T) -> usize
111    where
112        T: Ord,
113    {
114        self.inner.upper_bound(data, value)
115    }
116
117    /// Check if the value exists in the data.
118    #[inline]
119    pub fn contains(&self, data: &[T], value: &T) -> bool
120    where
121        T: Ord,
122    {
123        let key = value.index_key();
124
125        if let Some(pos) = self.cache.lookup(&key)
126            && pos < data.len()
127            && data[pos] == *value
128        {
129            return true;
130        }
131
132        let result = self.inner.contains(data, value);
133
134        if result {
135            let pos = self.inner.lower_bound(data, value);
136            self.cache.insert(key, pos);
137        }
138
139        result
140    }
141
142    #[inline]
143    pub fn len(&self) -> usize {
144        self.inner.len()
145    }
146
147    #[inline]
148    pub fn is_empty(&self) -> bool {
149        self.inner.is_empty()
150    }
151
152    #[inline]
153    pub fn segments_count(&self) -> usize {
154        self.inner.segments_count()
155    }
156
157    #[inline]
158    pub fn height(&self) -> usize {
159        self.inner.height()
160    }
161
162    #[inline]
163    pub fn epsilon(&self) -> usize {
164        self.inner.epsilon()
165    }
166
167    #[inline]
168    pub fn epsilon_recursive(&self) -> usize {
169        self.inner.epsilon_recursive()
170    }
171
172    pub fn size_in_bytes(&self) -> usize {
173        self.inner.size_in_bytes() + core::mem::size_of::<HotCache<T::Key>>()
174    }
175
176    /// Clear all cached entries.
177    pub fn clear_cache(&self) {
178        self.cache.clear();
179    }
180
181    /// Get a reference to the inner index.
182    pub fn inner(&self) -> &Static<T> {
183        &self.inner
184    }
185
186    /// Consume and return the inner index.
187    pub fn into_inner(self) -> Static<T> {
188        self.inner
189    }
190
191    /// Returns the (start, end) indices for iterating over data in the given range.
192    #[inline]
193    pub fn range_indices<R>(&self, data: &[T], range: R) -> (usize, usize)
194    where
195        T: Ord,
196        R: RangeBounds<T>,
197    {
198        range_to_indices(
199            range,
200            data.len(),
201            |v| self.lower_bound(data, v),
202            |v| self.upper_bound(data, v),
203        )
204    }
205
206    /// Returns an iterator over data in the given range.
207    #[inline]
208    pub fn range<'a, R>(&self, data: &'a [T], range: R) -> impl DoubleEndedIterator<Item = &'a T>
209    where
210        T: Ord,
211        R: RangeBounds<T>,
212    {
213        let (start, end) = self.range_indices(data, range);
214        data[start..end].iter()
215    }
216}
217
218impl<T: Indexable> From<Static<T>> for Cached<T>
219where
220    T::Key: Ord + FastHash + core::default::Default,
221{
222    fn from(index: Static<T>) -> Self {
223        Self::from_index(index)
224    }
225}
226
227impl<T: Indexable> From<Cached<T>> for Static<T>
228where
229    T::Key: Ord + FastHash + core::default::Default,
230{
231    fn from(cached: Cached<T>) -> Self {
232        cached.into_inner()
233    }
234}
235
236impl<T: Indexable> crate::index::External<T> for Cached<T>
237where
238    T::Key: Ord + crate::util::cache::FastHash + core::default::Default,
239{
240    #[inline]
241    fn search(&self, value: &T) -> ApproxPos {
242        self.search(value)
243    }
244
245    #[inline]
246    fn lower_bound(&self, data: &[T], value: &T) -> usize
247    where
248        T: Ord,
249    {
250        self.lower_bound(data, value)
251    }
252
253    #[inline]
254    fn upper_bound(&self, data: &[T], value: &T) -> usize
255    where
256        T: Ord,
257    {
258        self.upper_bound(data, value)
259    }
260
261    #[inline]
262    fn contains(&self, data: &[T], value: &T) -> bool
263    where
264        T: Ord,
265    {
266        self.contains(data, value)
267    }
268
269    #[inline]
270    fn len(&self) -> usize {
271        self.len()
272    }
273
274    #[inline]
275    fn segments_count(&self) -> usize {
276        self.segments_count()
277    }
278
279    #[inline]
280    fn epsilon(&self) -> usize {
281        self.epsilon()
282    }
283
284    #[inline]
285    fn size_in_bytes(&self) -> usize {
286        self.size_in_bytes()
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    #[test]
    fn test_cached_index_basic() {
        let keys: Vec<u64> = (0..10000).collect();
        let index = Cached::new(&keys, 64, 4).unwrap();

        assert_eq!(index.len(), 10000);
        assert!(!index.is_empty());
    }

    #[test]
    fn test_cached_index_hit() {
        let keys: Vec<u64> = (0..1000).collect();
        let index = Cached::new(&keys, 64, 4).unwrap();

        let key = 500u64;
        // First lookup is a cache miss, second should be served from cache;
        // both must agree.
        let pos1 = index.lower_bound(&keys, &key);
        assert_eq!(pos1, 500);

        let pos2 = index.lower_bound(&keys, &key);
        assert_eq!(pos2, 500);
    }

    #[test]
    fn test_cached_contains() {
        let keys: Vec<u64> = (0..100).map(|i| i * 2).collect();
        let index = Cached::new(&keys, 8, 4).unwrap();

        assert!(index.contains(&keys, &0));
        assert!(index.contains(&keys, &100));

        // Repeated query exercises the cache-hit path.
        assert!(index.contains(&keys, &0));

        assert!(!index.contains(&keys, &1));
        assert!(!index.contains(&keys, &99));
    }

    #[test]
    fn test_cached_clear() {
        let keys: Vec<u64> = (0..100).collect();
        let index = Cached::new(&keys, 16, 4).unwrap();

        // Populate the cache, clear it, then verify lookups still succeed
        // and return correct results (they fall through to the inner index).
        assert_eq!(index.lower_bound(&keys, &50), 50);
        index.clear_cache();
        assert_eq!(index.lower_bound(&keys, &50), 50);
        assert!(index.contains(&keys, &50));
    }

    #[test]
    fn test_from_conversions() {
        let keys: Vec<u64> = (0..100).collect();
        let index = Cached::new(&keys, 16, 4).unwrap();

        // Round-trip through both `From` impls preserves the inner index.
        let inner: Static<u64> = index.into();
        let cached: Cached<u64> = inner.into();
        assert_eq!(cached.len(), 100);
        assert!(cached.contains(&keys, &42));
    }
}