#![allow(clippy::cast_precision_loss)]
use std::mem::size_of;
use crate::{
Key, PGMStats, Segment,
error::{PGMError, Result},
pgm::{
build::{build_lut, build_segments},
consts::MIN_EPSILON,
search::predict,
},
};
/// A learned index over a vector of keys.
///
/// The index owns the keys themselves plus two auxiliary structures built
/// from them at construction time: a list of segments (from `build_segments`)
/// and a lookup table (from `build_lut`) used to pick a starting segment
/// for a key.
#[cfg_attr(feature = "bitcode", derive(bitcode::Encode, bitcode::Decode))]
#[derive(Clone, Debug)]
pub struct PGMIndex<K: Key> {
/// Search radius used by `get`: the binary search covers at most `epsilon`
/// slots on each side of the predicted position. Also passed to
/// `build_segments` when the index is built.
epsilon: usize,
/// The indexed keys. `load` can optionally verify that they are in
/// non-decreasing order.
data: Vec<K>,
/// Segments built from `data`. Lookups read each segment's `min_key`,
/// `max_key`, `start_idx` and `end_idx`.
segments: Vec<Segment<K>>,
/// Lookup table produced by `build_lut`; each entry is used as an index
/// into `segments` and serves as the starting guess for a key's segment.
lut: Vec<usize>,
/// Multiplier that maps `key.as_f64() - min_key` to a bin of `lut`.
scale: f64,
/// Offset subtracted from a key (as `f64`) before scaling.
/// NOTE(review): presumably the smallest indexed key; confirm against `build_lut`.
min_key: f64,
}
impl<K: Key> PGMIndex<K> {
/// Builds an index over `data` with the given error bound `epsilon`.
///
/// When `check_sorted` is `true`, the keys are verified to be in
/// non-decreasing order before anything is built. When it is `false` the
/// data is used as-is, so the caller is responsible for its ordering.
///
/// # Errors
///
/// Checks run in this order:
/// - [`PGMError::InvalidEpsilon`] if `epsilon` is below `MIN_EPSILON`.
/// - [`PGMError::EmptyData`] if `data` is empty.
/// - [`PGMError::NotSorted`] if `check_sorted` is set and a key is smaller
///   than its predecessor.
pub fn load(data: Vec<K>, epsilon: usize, check_sorted: bool) -> Result<Self> {
    if epsilon < MIN_EPSILON {
        let too_small = PGMError::InvalidEpsilon {
            provided: epsilon,
            min: MIN_EPSILON,
        };
        return Err(too_small);
    }

    if data.is_empty() {
        return Err(PGMError::EmptyData);
    }

    // The sortedness scan is linear in the data size, so it only runs on request.
    let out_of_order = check_sorted && !is_sorted_fast(&data);
    if out_of_order {
        return Err(PGMError::NotSorted);
    }

    // The lookup table is derived from the segments, so they are built first.
    let segments = build_segments(&data, epsilon);
    let (lut, scale, min_key) = build_lut(&data, &segments);

    let index = Self {
        epsilon,
        data,
        segments,
        lut,
        scale,
        min_key,
    };
    Ok(index)
}
/// Returns a snapshot of the index's size figures: number of segments,
/// average number of keys per segment, and approximate memory footprint.
#[inline]
#[must_use]
pub fn stats(&self) -> PGMStats {
    // Reuse the public accessors so the figures stay consistent with them.
    let segments = self.segment_count();
    let avg_segment_size = self.avg_segment_size();
    let memory_bytes = self.mem_usage();
    PGMStats {
        segments,
        avg_segment_size,
        memory_bytes,
    }
}
#[inline]
#[must_use]
pub fn segment_count(&self) -> usize {
self.segments.len()
}
/// Returns the average number of keys per segment.
///
/// The segment count is floored at 1, so an index without segments yields
/// the key count rather than a division by zero.
#[inline]
#[must_use]
pub fn avg_segment_size(&self) -> f64 {
    let key_count = self.data.len() as f64;
    let divisor = self.segments.len().max(1) as f64;
    key_count / divisor
}
/// Returns the approximate number of bytes held by the index's buffers.
///
/// Counts `len * size_of::<element>()` for the keys, the segments and the
/// lookup table. Spare `Vec` capacity and the struct's own inline fields
/// are not included.
#[inline]
#[must_use]
pub fn mem_usage(&self) -> usize {
    let key_bytes = self.data.len() * size_of::<K>();
    let segment_bytes = self.segments.len() * size_of::<Segment<K>>();
    let lut_bytes = self.lut.len() * size_of::<usize>();
    key_bytes + segment_bytes + lut_bytes
}
#[inline]
#[must_use]
pub fn memory_usage(&self) -> usize {
self.mem_usage()
}
/// Returns the indexed keys as a slice, in the order they are stored.
#[inline]
#[must_use]
pub fn data(&self) -> &[K] {
    self.data.as_slice()
}
#[inline]
#[must_use]
pub fn epsilon(&self) -> usize {
self.epsilon
}
/// Looks up `key` and returns its position in the indexed data, or `None`
/// if the key is not present.
///
/// The lookup has three steps:
/// 1. locate the segment for `key` via `find_seg`;
/// 2. reject the key if it lies outside that segment's
///    `[min_key, max_key]` range;
/// 3. binary-search a window of at most `2 * epsilon + 1` slots around the
///    segment's predicted position, clamped to the segment's
///    `start_idx..end_idx` range.
///
/// If the data contains duplicates of `key`, the returned position is
/// whichever match the binary search lands on, not necessarily the first.
#[inline]
#[must_use]
pub fn get(&self, key: K) -> Option<usize> {
    // Segment location is shared with `predict_pos` through `find_seg`.
    let seg = self.find_seg(key);
    if key < seg.min_key || key > seg.max_key {
        return None;
    }

    let predicted = predict(seg, key.as_f64());
    // Search window: [predicted - epsilon, predicted + epsilon], kept inside
    // the segment. `end` is exclusive, hence the extra `+ 1`; the additions
    // saturate so an extreme `epsilon` cannot overflow.
    let start = predicted.saturating_sub(self.epsilon).max(seg.start_idx);
    let end = predicted
        .saturating_add(self.epsilon)
        .saturating_add(1)
        .min(seg.end_idx);

    // Checked slicing: the bounds depend on the model's prediction and on
    // segment metadata, so an inverted or out-of-range window is treated as
    // "not found" instead of being read unchecked.
    let window = self.data.get(start..end)?;
    window.binary_search(&key).ok().map(|offset| start + offset)
}
/// Lazily looks up every key yielded by `keys`.
///
/// Returns an iterator producing, for each input key in order, the same
/// result [`get`](Self::get) would return. No lookup runs until the
/// iterator is advanced.
#[inline]
pub fn get_many<'a, I>(&'a self, keys: I) -> impl Iterator<Item = Option<usize>> + 'a
where
    I: IntoIterator<Item = K> + 'a,
    <I as IntoIterator>::IntoIter: 'a,
{
    let index = self;
    let pending = keys.into_iter();
    pending.map(move |key| index.get(key))
}
/// Counts how many of the given keys are present in the index.
///
/// Each key is looked up with [`get`](Self::get); a key that appears several
/// times in `keys` is counted once per appearance.
#[inline]
pub fn count_hits<I>(&self, keys: I) -> usize
where
    I: IntoIterator<Item = K>,
{
    keys.into_iter()
        .filter_map(|key| self.get(key))
        .count()
}
/// Returns the position the index's model predicts for `key`.
///
/// Unlike [`get`](Self::get), this does not check that `key` lies inside the
/// chosen segment's key range and does not search the data; it only
/// evaluates the prediction for the segment `find_seg` selects.
#[inline]
#[must_use]
pub fn predict_pos(&self, key: K) -> usize {
    predict(self.find_seg(key), key.as_f64())
}
/// Selects the segment to use for `key`.
///
/// With at most one segment, that segment is returned directly. Otherwise
/// the key is mapped to a lookup-table bin, the bin's entry is taken as a
/// starting segment index, and the index is moved right while the key is
/// above the current segment's `max_key`, then left while the key is below
/// the current segment's `min_key`.
///
/// A segment is always returned, even when `key` falls outside every
/// segment's key range; callers that need containment must check it.
///
/// NOTE(review): the unchecked accesses below rely on invariants set up at
/// construction. An index obtained through the optional `bitcode` decoding
/// does not pass through `load`, so those invariants are not revalidated.
#[inline]
fn find_seg(&self, key: K) -> &Segment<K> {
    let segments = self.segments.as_slice();
    let seg_count = segments.len();

    if seg_count <= 1 {
        // SAFETY: requires at least one segment. `load` rejects empty data;
        // assumes `build_segments` yields a segment for non-empty input
        // (TODO confirm).
        return unsafe { segments.get_unchecked(0) };
    }

    // Map the key to a lookup-table bin. The float-to-int cast saturates, and
    // the match below clamps the result into `0..=last_bin`.
    let last_bin = (self.lut.len() - 1) as isize;
    let raw_bin = ((key.as_f64() - self.min_key) * self.scale) as isize;
    let bin = match raw_bin {
        b if b < 0 => 0,
        b if b > last_bin => last_bin as usize,
        b => b as usize,
    };

    // SAFETY: `bin` is clamped to `0..=lut.len() - 1`, which is in bounds as
    // long as the table is non-empty (assumed from `build_lut`, TODO confirm).
    let mut idx = unsafe { *self.lut.get_unchecked(bin) };

    // Move right while the key lies beyond the current segment.
    loop {
        // SAFETY: the starting `idx` comes from the lookup table, which is
        // assumed to hold valid segment indices (TODO confirm against
        // `build_lut`); it only grows while `idx + 1 < seg_count`.
        let current = unsafe { segments.get_unchecked(idx) };
        if idx + 1 >= seg_count || key <= current.max_key {
            break;
        }
        idx += 1;
    }

    // Move left while the key lies before the current segment.
    loop {
        // SAFETY: `idx` was in bounds after the forward walk and only shrinks
        // here, stopping at 0.
        let current = unsafe { segments.get_unchecked(idx) };
        if idx == 0 || key >= current.min_key {
            return current;
        }
        idx -= 1;
    }
}
}
/// Returns `true` if `data` is in non-decreasing order.
///
/// Equal neighbours are allowed. Empty and single-element slices count as
/// sorted because they have no adjacent pair to compare.
#[inline]
fn is_sorted_fast<K: Ord>(data: &[K]) -> bool {
    data.windows(2).all(|pair| match pair {
        [prev, next] => prev <= next,
        // `windows(2)` only yields slices of length two.
        _ => true,
    })
}