use core::ops::RangeBounds;
use crate::error::Error;
use crate::index::Static;
use crate::index::key::Indexable;
use crate::util::ApproxPos;
use crate::util::cache::{FastHash, HotCache};
use crate::util::range::range_to_indices;
/// A [`Static`] learned index augmented with a small hot-key cache
/// ([`HotCache`]) that short-circuits repeated lookups of the same key.
///
/// The cache is purely an acceleration structure: it is skipped by both
/// `rkyv` and `serde` (de)serialization and is reconstructed empty on load
/// (see the `Skip` / `serde(skip, default)` attributes on the field).
#[derive(Debug)]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "serde",
serde(bound = "T::Key: serde::Serialize + serde::de::DeserializeOwned")
)]
pub struct Cached<T: Indexable>
where
T::Key: FastHash + core::default::Default,
{
// The underlying immutable learned index; all queries delegate here on a
// cache miss.
inner: Static<T>,
// Hot-key cache mapping `T::Key` to a verified position. Never
// serialized; always starts empty after deserialization.
#[cfg_attr(feature = "rkyv", rkyv(with = rkyv::with::Skip))]
#[cfg_attr(feature = "serde", serde(skip, default))]
cache: HotCache<T::Key>,
}
impl<T: Indexable> Cached<T>
where
    T::Key: Ord + FastHash + core::default::Default,
{
    /// Builds the underlying [`Static`] index over `data` with the given
    /// epsilon parameters and attaches an empty hot-key cache.
    ///
    /// # Errors
    ///
    /// Propagates any [`Error`] produced by [`Static::new`].
    pub fn new(data: &[T], epsilon: usize, epsilon_recursive: usize) -> Result<Self, Error> {
        let inner = Static::new(data, epsilon, epsilon_recursive)?;
        Ok(Self {
            inner,
            cache: HotCache::new(),
        })
    }

    /// Wraps an existing [`Static`] index with a fresh (empty) cache.
    pub fn from_index(index: Static<T>) -> Self {
        Self {
            inner: index,
            cache: HotCache::new(),
        }
    }

    /// Returns the approximate position range for `value`.
    ///
    /// The cache is not consulted here: an [`ApproxPos`] cannot be
    /// re-validated against the data slice, so only exact-position queries
    /// (`lower_bound`, `contains`) use it.
    #[inline]
    pub fn search(&self, value: &T) -> ApproxPos {
        self.inner.search(value)
    }

    /// Index of the first element in `data` that is `>= value`.
    ///
    /// A cached position is used only after re-validating it against `data`
    /// (`pos < data.len() && data[pos] == *value`), so a caller that passes
    /// a slice different from the one the cache was populated against
    /// degrades to a normal index lookup instead of getting a stale answer.
    /// Exact matches found by the index are inserted into the cache.
    #[inline]
    pub fn lower_bound(&self, data: &[T], value: &T) -> usize
    where
        T: Ord,
    {
        let key = value.index_key();
        if let Some(pos) = self.cache.lookup(&key)
            && pos < data.len()
            && data[pos] == *value
        {
            return pos;
        }
        let result = self.inner.lower_bound(data, value);
        if result < data.len() && data[result] == *value {
            self.cache.insert(key, result);
        }
        result
    }

    /// Index of the first element in `data` that is `> value`.
    ///
    /// Not cached: the cache stores lower-bound positions, which are not
    /// valid upper-bound answers when duplicates are present.
    #[inline]
    pub fn upper_bound(&self, data: &[T], value: &T) -> usize
    where
        T: Ord,
    {
        self.inner.upper_bound(data, value)
    }

    /// Returns `true` if `value` occurs in `data`.
    ///
    /// On a cache miss this performs a single `lower_bound` traversal and
    /// derives membership from it. (The previous implementation called
    /// `inner.contains` and then `inner.lower_bound` to populate the cache,
    /// i.e. two full index traversals on every positive miss.)
    #[inline]
    pub fn contains(&self, data: &[T], value: &T) -> bool
    where
        T: Ord,
    {
        let key = value.index_key();
        if let Some(pos) = self.cache.lookup(&key)
            && pos < data.len()
            && data[pos] == *value
        {
            return true;
        }
        // One traversal yields both the membership answer and the position
        // worth caching.
        let pos = self.inner.lower_bound(data, value);
        if pos < data.len() && data[pos] == *value {
            self.cache.insert(key, pos);
            true
        } else {
            false
        }
    }

    /// Number of elements the index was built over.
    #[inline]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` if the index was built over an empty slice.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Number of segments in the underlying index.
    #[inline]
    pub fn segments_count(&self) -> usize {
        self.inner.segments_count()
    }

    /// Height of the underlying index structure.
    #[inline]
    pub fn height(&self) -> usize {
        self.inner.height()
    }

    /// Epsilon (maximum error bound) of the leaf level.
    #[inline]
    pub fn epsilon(&self) -> usize {
        self.inner.epsilon()
    }

    /// Epsilon of the recursive (upper) levels.
    #[inline]
    pub fn epsilon_recursive(&self) -> usize {
        self.inner.epsilon_recursive()
    }

    /// Approximate in-memory footprint in bytes.
    ///
    /// NOTE(review): the cache contribution is `size_of::<HotCache<_>>()`,
    /// i.e. only the inline struct bytes — any heap storage owned by the
    /// cache is not counted. Treat this as a lower-bound estimate.
    pub fn size_in_bytes(&self) -> usize {
        self.inner.size_in_bytes() + core::mem::size_of::<HotCache<T::Key>>()
    }

    /// Empties the hot-key cache. Subsequent queries fall through to the
    /// underlying index (and repopulate the cache).
    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    /// Borrows the underlying [`Static`] index.
    pub fn inner(&self) -> &Static<T> {
        &self.inner
    }

    /// Consumes `self`, returning the underlying [`Static`] index and
    /// discarding the cache.
    pub fn into_inner(self) -> Static<T> {
        self.inner
    }

    /// Resolves `range` to a half-open `(start, end)` index pair into
    /// `data`, using the (cached) `lower_bound` for the start bound and
    /// `upper_bound` for the end bound.
    #[inline]
    pub fn range_indices<R>(&self, data: &[T], range: R) -> (usize, usize)
    where
        T: Ord,
        R: RangeBounds<T>,
    {
        range_to_indices(
            range,
            data.len(),
            |v| self.lower_bound(data, v),
            |v| self.upper_bound(data, v),
        )
    }

    /// Iterates over the elements of `data` that fall inside `range`.
    #[inline]
    pub fn range<'a, R>(&self, data: &'a [T], range: R) -> impl DoubleEndedIterator<Item = &'a T>
    where
        T: Ord,
        R: RangeBounds<T>,
    {
        let (start, end) = self.range_indices(data, range);
        data[start..end].iter()
    }
}
impl<T: Indexable> From<Static<T>> for Cached<T>
where
T::Key: Ord + FastHash + core::default::Default,
{
fn from(index: Static<T>) -> Self {
Self::from_index(index)
}
}
impl<T: Indexable> From<Cached<T>> for Static<T>
where
T::Key: Ord + FastHash + core::default::Default,
{
fn from(cached: Cached<T>) -> Self {
cached.into_inner()
}
}
impl<T: Indexable> crate::index::External<T> for Cached<T>
where
T::Key: Ord + crate::util::cache::FastHash + core::default::Default,
{
#[inline]
fn search(&self, value: &T) -> ApproxPos {
self.search(value)
}
#[inline]
fn lower_bound(&self, data: &[T], value: &T) -> usize
where
T: Ord,
{
self.lower_bound(data, value)
}
#[inline]
fn upper_bound(&self, data: &[T], value: &T) -> usize
where
T: Ord,
{
self.upper_bound(data, value)
}
#[inline]
fn contains(&self, data: &[T], value: &T) -> bool
where
T: Ord,
{
self.contains(data, value)
}
#[inline]
fn len(&self) -> usize {
self.len()
}
#[inline]
fn segments_count(&self) -> usize {
self.segments_count()
}
#[inline]
fn epsilon(&self) -> usize {
self.epsilon()
}
#[inline]
fn size_in_bytes(&self) -> usize {
self.size_in_bytes()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    /// Construction over a dense key range reports the right length.
    #[test]
    fn test_cached_index_basic() {
        let keys: Vec<u64> = (0..10000).collect();
        let index = Cached::new(&keys, 64, 4).unwrap();
        assert_eq!(index.len(), 10000);
        assert!(!index.is_empty());
    }

    /// A repeated lookup (second call served from the cache) must agree
    /// with the first (cold) lookup.
    #[test]
    fn test_cached_index_hit() {
        let keys: Vec<u64> = (0..1000).collect();
        let index = Cached::new(&keys, 64, 4).unwrap();
        let key = 500u64;
        let pos1 = index.lower_bound(&keys, &key);
        assert_eq!(pos1, 500);
        let pos2 = index.lower_bound(&keys, &key);
        assert_eq!(pos2, 500);
    }

    /// Membership queries: present keys (including repeated, cache-served
    /// queries) and absent keys interleaved between the even entries.
    #[test]
    fn test_cached_contains() {
        let keys: Vec<u64> = (0..100).map(|i| i * 2).collect();
        let index = Cached::new(&keys, 8, 4).unwrap();
        assert!(index.contains(&keys, &0));
        assert!(index.contains(&keys, &100));
        assert!(index.contains(&keys, &0));
        assert!(!index.contains(&keys, &1));
        assert!(!index.contains(&keys, &99));
    }

    /// Clearing the cache must not change query results: previously this
    /// test only checked that `clear_cache` didn't panic.
    #[test]
    fn test_cached_clear() {
        let keys: Vec<u64> = (0..100).collect();
        let index = Cached::new(&keys, 16, 4).unwrap();
        let before = index.lower_bound(&keys, &50);
        assert_eq!(before, 50);
        index.clear_cache();
        // Post-clear queries fall through to the index and must agree.
        assert_eq!(index.lower_bound(&keys, &50), before);
        assert!(index.contains(&keys, &50));
    }
}