pgm_extra/index/external/cached.rs

//! Cached PGM-Index with hot-key lookup optimization.
//!
//! Wraps a multi-level index with a small cache for frequently accessed keys.

use core::ops::RangeBounds;

use crate::error::Error;
use crate::index::Static;
use crate::index::key::Indexable;
use crate::util::ApproxPos;
use crate::util::cache::{FastHash, HotCache};
use crate::util::range::range_to_indices;

/// A PGM-Index wrapper with a small hot-key cache.
///
/// This struct uses interior mutability (`Cell`) to update the cache on read operations.
/// Therefore, it is `!Sync` and cannot be shared across threads without synchronization
/// (e.g., `Mutex<Cached>`).
///
/// Note: When serialized with serde, only the inner index is saved. The cache is
/// recreated empty on deserialization.
///
/// # Example
///
/// ```
/// use pgm_extra::index::external::Cached;
///
/// let keys: Vec<u64> = (0..10000).collect();
/// let index = Cached::new(&keys, 64, 4).unwrap();
///
/// // First lookup - cache miss, populates cache
/// assert!(index.contains(&keys, &5000));
///
/// // Second lookup - cache hit
/// assert!(index.contains(&keys, &5000));
/// ```
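///
/// A minimal sketch of sharing across threads behind a `Mutex`, as suggested above
/// (illustrative only; it assumes `std::sync::Mutex` is available, which holds for
/// doctests since they link `std`):
///
/// ```
/// use std::sync::Mutex;
/// use pgm_extra::index::external::Cached;
///
/// let keys: Vec<u64> = (0..1000).collect();
/// let index = Mutex::new(Cached::new(&keys, 64, 4).unwrap());
///
/// // Every lookup takes the lock; the cache is updated behind it.
/// assert!(index.lock().unwrap().contains(&keys, &42));
/// ```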
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(bound = "T::Key: serde::Serialize + serde::de::DeserializeOwned")
)]
pub struct Cached<T: Indexable>
where
    T::Key: FastHash + core::default::Default,
{
    inner: Static<T>,
    #[cfg_attr(feature = "serde", serde(skip, default = "HotCache::new"))]
    cache: HotCache<T::Key>,
}

impl<T: Indexable> Cached<T>
where
    T::Key: Ord + FastHash + core::default::Default,
{
    /// Build a new cached PGM-Index from sorted data.
    pub fn new(data: &[T], epsilon: usize, epsilon_recursive: usize) -> Result<Self, Error> {
        let inner = Static::new(data, epsilon, epsilon_recursive)?;
        Ok(Self {
            inner,
            cache: HotCache::new(),
        })
    }

    /// Wrap an existing index with a cache.
    pub fn from_index(index: Static<T>) -> Self {
        Self {
            inner: index,
            cache: HotCache::new(),
        }
    }

    /// Get an approximate position for the given value.
    #[inline]
    pub fn search(&self, value: &T) -> ApproxPos {
        self.inner.search(value)
    }

    /// Find the first position where `data[pos] >= value`.
    #[inline]
    pub fn lower_bound(&self, data: &[T], value: &T) -> usize
    where
        T: Ord,
    {
        let key = value.index_key();

        // Fast path: trust the cached position only if it still points at the value.
        if let Some(pos) = self.cache.lookup(&key)
            && pos < data.len()
            && data[pos] == *value
        {
            return pos;
        }

        let result = self.inner.lower_bound(data, value);

        // Cache only exact matches, so repeated lookups of this key skip the index walk.
        if result < data.len() && data[result] == *value {
            self.cache.insert(key, result);
        }

        result
    }

    /// Find the first position where `data[pos] > value`.
    #[inline]
    pub fn upper_bound(&self, data: &[T], value: &T) -> usize
    where
        T: Ord,
    {
        self.inner.upper_bound(data, value)
    }

    /// Check if the value exists in the data.
    #[inline]
    pub fn contains(&self, data: &[T], value: &T) -> bool
    where
        T: Ord,
    {
        let key = value.index_key();

        // Fast path: trust the cached position only if it still points at the value.
        if let Some(pos) = self.cache.lookup(&key)
            && pos < data.len()
            && data[pos] == *value
        {
            return true;
        }

        // A single lower_bound both answers the query and yields the position to cache,
        // avoiding a second search through the inner index on a hit.
        let pos = self.inner.lower_bound(data, value);
        let found = pos < data.len() && data[pos] == *value;

        if found {
            self.cache.insert(key, pos);
        }

        found
    }

    /// Number of elements the index was built over.
    #[inline]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` if the index was built over an empty slice.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Number of segments in the inner index.
    #[inline]
    pub fn segments_count(&self) -> usize {
        self.inner.segments_count()
    }

    /// Height (number of levels) of the inner index.
    #[inline]
    pub fn height(&self) -> usize {
        self.inner.height()
    }

    /// The `epsilon` parameter the inner index was built with.
    #[inline]
    pub fn epsilon(&self) -> usize {
        self.inner.epsilon()
    }

    /// The `epsilon_recursive` parameter the inner index was built with.
    #[inline]
    pub fn epsilon_recursive(&self) -> usize {
        self.inner.epsilon_recursive()
    }

    /// Approximate size in bytes of the inner index plus the cache.
    pub fn size_in_bytes(&self) -> usize {
        self.inner.size_in_bytes() + core::mem::size_of::<HotCache<T::Key>>()
    }

    /// Clear all cached entries.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    /// Get a reference to the inner index.
    pub fn inner(&self) -> &Static<T> {
        &self.inner
    }

    /// Consume and return the inner index.
    pub fn into_inner(self) -> Static<T> {
        self.inner
    }

    /// Returns the (start, end) indices for iterating over data in the given range.
    #[inline]
    pub fn range_indices<R>(&self, data: &[T], range: R) -> (usize, usize)
    where
        T: Ord,
        R: RangeBounds<T>,
    {
        range_to_indices(
            range,
            data.len(),
            |v| self.lower_bound(data, v),
            |v| self.upper_bound(data, v),
        )
    }

    /// Returns an iterator over data in the given range.
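    ///
    /// A minimal usage sketch (assuming the usual half-open semantics of
    /// `range_to_indices` for a `start..end` range):
    ///
    /// ```
    /// use pgm_extra::index::external::Cached;
    ///
    /// let keys: Vec<u64> = (0..100).collect();
    /// let index = Cached::new(&keys, 16, 4).unwrap();
    ///
    /// let hits: Vec<u64> = index.range(&keys, 10..15).copied().collect();
    /// assert_eq!(hits, [10, 11, 12, 13, 14]);
    /// ```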
    #[inline]
    pub fn range<'a, R>(&self, data: &'a [T], range: R) -> impl DoubleEndedIterator<Item = &'a T>
    where
        T: Ord,
        R: RangeBounds<T>,
    {
        let (start, end) = self.range_indices(data, range);
        data[start..end].iter()
    }
}

impl<T: Indexable> From<Static<T>> for Cached<T>
where
    T::Key: Ord + FastHash + core::default::Default,
{
    fn from(index: Static<T>) -> Self {
        Self::from_index(index)
    }
}

impl<T: Indexable> From<Cached<T>> for Static<T>
where
    T::Key: Ord + FastHash + core::default::Default,
{
    fn from(cached: Cached<T>) -> Self {
        cached.into_inner()
    }
}

impl<T: Indexable> crate::index::External<T> for Cached<T>
where
    T::Key: Ord + crate::util::cache::FastHash + core::default::Default,
{
    #[inline]
    fn search(&self, value: &T) -> ApproxPos {
        self.search(value)
    }

    #[inline]
    fn lower_bound(&self, data: &[T], value: &T) -> usize
    where
        T: Ord,
    {
        self.lower_bound(data, value)
    }

    #[inline]
    fn upper_bound(&self, data: &[T], value: &T) -> usize
    where
        T: Ord,
    {
        self.upper_bound(data, value)
    }

    #[inline]
    fn contains(&self, data: &[T], value: &T) -> bool
    where
        T: Ord,
    {
        self.contains(data, value)
    }

    #[inline]
    fn len(&self) -> usize {
        self.len()
    }

    #[inline]
    fn segments_count(&self) -> usize {
        self.segments_count()
    }

    #[inline]
    fn epsilon(&self) -> usize {
        self.epsilon()
    }

    #[inline]
    fn size_in_bytes(&self) -> usize {
        self.size_in_bytes()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    #[test]
    fn test_cached_index_basic() {
        let keys: Vec<u64> = (0..10000).collect();
        let index = Cached::new(&keys, 64, 4).unwrap();

        assert_eq!(index.len(), 10000);
        assert!(!index.is_empty());
    }

    #[test]
    fn test_cached_index_hit() {
        let keys: Vec<u64> = (0..1000).collect();
        let index = Cached::new(&keys, 64, 4).unwrap();

        let key = 500u64;
        let pos1 = index.lower_bound(&keys, &key);
        assert_eq!(pos1, 500);

        // The second lookup should be served from the cache and agree with the first.
        let pos2 = index.lower_bound(&keys, &key);
        assert_eq!(pos2, 500);
    }

    #[test]
    fn test_cached_contains() {
        let keys: Vec<u64> = (0..100).map(|i| i * 2).collect();
        let index = Cached::new(&keys, 8, 4).unwrap();

        assert!(index.contains(&keys, &0));
        assert!(index.contains(&keys, &100));

        // Repeated lookup exercises the cache-hit path.
        assert!(index.contains(&keys, &0));

        assert!(!index.contains(&keys, &1));
        assert!(!index.contains(&keys, &99));
    }
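
    // A sketch of extra coverage for `upper_bound` and the range helpers,
    // assuming the half-open semantics of `range_to_indices` for `start..end`.
    #[test]
    fn test_cached_upper_bound_and_range() {
        // keys = 0, 2, 4, ..., 198
        let keys: Vec<u64> = (0..100).map(|i| i * 2).collect();
        let index = Cached::new(&keys, 8, 4).unwrap();

        // First position with data[pos] > 10 is index 6 (value 12).
        assert_eq!(index.upper_bound(&keys, &10), 6);
        // First position with data[pos] >= 11 is also index 6.
        assert_eq!(index.lower_bound(&keys, &11), 6);

        // Dense keys: the half-open range 10..20 maps directly to indices (10, 20).
        let dense: Vec<u64> = (0..100).collect();
        let dense_index = Cached::new(&dense, 16, 4).unwrap();
        assert_eq!(dense_index.range_indices(&dense, 10..20), (10, 20));

        let hits: Vec<u64> = dense_index.range(&dense, 10..15).copied().collect();
        assert_eq!(hits, [10, 11, 12, 13, 14]);
    }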

    #[test]
    fn test_cached_clear() {
        let keys: Vec<u64> = (0..100).collect();
        let index = Cached::new(&keys, 16, 4).unwrap();

        let _ = index.lower_bound(&keys, &50);
        index.clear_cache();

        // Lookups still work after the cache is emptied.
        assert_eq!(index.lower_bound(&keys, &50), 50);
    }
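
    // A sketch of round-trip coverage for the `From` conversions between
    // `Static` and `Cached` defined above.
    #[test]
    fn test_cached_from_conversions() {
        let keys: Vec<u64> = (0..500).collect();

        // Wrap an existing Static index with a cache via `From`.
        let inner = Static::new(&keys, 32, 4).unwrap();
        let cached = Cached::from(inner);
        assert!(cached.contains(&keys, &250));

        // Unwrap back to the inner index; the cache is simply dropped.
        let inner: Static<u64> = cached.into();
        assert_eq!(inner.len(), 500);
    }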
}