#![allow(clippy::cast_precision_loss)]
use std::mem::size_of;
use crate::{
Key, PgmError, Result, Segment,
build::{build_lut, build_segments},
consts::MIN_EPSILON,
};
/// A learned index that maps keys to approximate positions in a sorted sequence using
/// per-segment linear models.
///
/// The index stores only the model, not the keys: the lookup methods (`find`, `find_key`)
/// take a callback that reads the key stored at a given position.
///
/// NOTE: all fields are `pub` (and the struct is decodable when the `bitcode` feature is on),
/// so the invariants described below are expected but not enforced by the type. Instances
/// built through `Pgm::new` are the supported path.
#[cfg_attr(feature = "bitcode", derive(bitcode::Encode, bitcode::Decode))]
#[derive(Clone, Debug)]
pub struct Pgm<K: Key> {
    /// Half-width of the position window searched around each prediction; validated
    /// against `MIN_EPSILON` in `Pgm::new`.
    pub epsilon: usize,
    /// Segments produced by `build_segments`; each carries a linear model
    /// (`slope`/`intercept`), key bounds (`min_key`/`max_key`) and a position range
    /// (`start_idx..end_idx`). Segment lookup assumes they are ordered by key.
    pub segments: Vec<Segment<K>>,
    /// Lookup table from a key bucket to a starting segment index for segment lookup.
    pub lut: Vec<u32>,
    /// Multiplier in the bucket formula `bucket = (key - min_key) * scale`.
    pub scale: f64,
    /// Offset in the bucket formula above (presumably the smallest indexed key as `f64`;
    /// it comes from `build_lut`).
    pub min_key: f64,
    /// Number of keys in the slice the index was built from.
    pub len: usize,
}
impl<K: Key> Pgm<K> {
    /// Builds an index over `data`.
    ///
    /// `data` must be sorted in non-decreasing order. With `check_sorted == true` this
    /// precondition is verified up front with an O(n) scan; with `false` it is the
    /// caller's responsibility and is not checked here.
    ///
    /// `epsilon` is the half-width of the position window searched around each
    /// prediction by [`Pgm::predict_range`], [`Pgm::find`] and [`Pgm::find_key`].
    ///
    /// # Errors
    ///
    /// - [`PgmError::InvalidEpsilon`] if `epsilon` is below `MIN_EPSILON`.
    /// - [`PgmError::EmptyData`] if `data` is empty.
    /// - [`PgmError::NotSorted`] if `check_sorted` is set and `data` is not sorted.
    pub fn new(data: &[K], epsilon: usize, check_sorted: bool) -> Result<Self> {
        if epsilon < MIN_EPSILON {
            return Err(PgmError::InvalidEpsilon {
                provided: epsilon,
                min: MIN_EPSILON,
            });
        }
        if data.is_empty() {
            return Err(PgmError::EmptyData);
        }
        if check_sorted && !is_sorted(data) {
            return Err(PgmError::NotSorted);
        }
        let segments = build_segments(data, epsilon);
        let (lut, scale, min_key) = build_lut(data, &segments);
        Ok(Self {
            epsilon,
            segments,
            lut,
            scale,
            min_key,
            len: data.len(),
        })
    }

    /// Number of keys the index was built over.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the index covers no keys.
    ///
    /// `Pgm::new` rejects empty input, so this can only be `true` for an instance
    /// assembled by hand through the public fields.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Number of linear segments in the model.
    #[inline]
    #[must_use]
    pub fn segment_count(&self) -> usize {
        self.segments.len()
    }

    /// Mean number of keys per segment.
    ///
    /// The divisor is clamped to at least 1, so a segment-less index yields `len` instead
    /// of dividing by zero.
    #[inline]
    #[must_use]
    pub fn avg_segment_size(&self) -> f64 {
        self.len as f64 / self.segments.len().max(1) as f64
    }

    /// Approximate size in bytes of the segment and LUT buffers.
    ///
    /// Computed from the vectors' lengths (not their capacities) and excludes the `Pgm`
    /// struct itself.
    #[inline]
    #[must_use]
    pub fn mem_usage(&self) -> usize {
        self.segments.len() * size_of::<Segment<K>>() + self.lut.len() * size_of::<u32>()
    }

    /// Predicted position of `key`.
    ///
    /// The covering segment's linear model is evaluated, rounded to the nearest index and
    /// clamped into that segment's position range. The result is approximate; use
    /// [`Pgm::predict_range`] / [`Pgm::find_key`] for an exact search.
    ///
    /// # Panics
    ///
    /// Panics if the index has no segments (not produced by `Pgm::new` for valid input).
    #[inline]
    #[must_use]
    pub fn predict(&self, key: K) -> usize {
        let seg = self.find_seg(key);
        predict_in_seg(seg, key.as_f64())
    }

    /// Half-open position range `[start, end)` to search for `key`.
    ///
    /// The window spans the prediction widened by `epsilon` on each side. It is clipped to
    /// the position range of the segment covering `key`. The upper bound uses saturating
    /// arithmetic, so a very large `epsilon` cannot overflow.
    ///
    /// # Panics
    ///
    /// Panics if the index has no segments (not produced by `Pgm::new` for valid input).
    #[inline]
    #[must_use]
    pub fn predict_range(&self, key: K) -> (usize, usize) {
        let seg = self.find_seg(key);
        let pred = predict_in_seg(seg, key.as_f64());
        let start = pred
            .saturating_sub(self.epsilon)
            .max(seg.start_idx as usize);
        // `pred + epsilon + 1` would overflow for epsilon near usize::MAX (no upper bound is
        // enforced in `new`); saturate instead of panicking/wrapping.
        let end = pred
            .saturating_add(self.epsilon)
            .saturating_add(1)
            .min(seg.end_idx as usize);
        (start, end)
    }

    /// Lower-bound search for `key`, comparing stored keys as byte strings.
    ///
    /// Only the [`Pgm::predict_range`] window `[lo, hi)` is searched. The result is the
    /// first position in that window whose stored bytes are not less than `key`'s bytes,
    /// or `hi` if every position compares less. `get_key(i)` returns the bytes of the key
    /// stored at position `i`. A `None` is treated as "not less than `key`".
    ///
    /// # Panics
    ///
    /// Panics if the index has no segments (not produced by `Pgm::new` for valid input).
    #[inline]
    #[must_use]
    pub fn find<'a, Q, F>(&self, key: &Q, get_key: F) -> usize
    where
        Q: crate::ToKey<K> + ?Sized,
        F: Fn(usize) -> Option<&'a [u8]>,
    {
        let (lo, hi) = self.predict_range(key.to_key());
        let key_bytes = key.as_bytes();
        Self::lower_bound(lo, hi, |i| matches!(get_key(i), Some(mk) if mk < key_bytes))
    }

    /// Lower-bound search for `key` using the key type's own ordering.
    ///
    /// Only the [`Pgm::predict_range`] window `[lo, hi)` is searched. The result is the
    /// first position in that window whose stored key is not less than `key`, or `hi` if
    /// every position compares less. `get_key(i)` returns the key stored at position `i`.
    /// A `None` is treated as "not less than `key`".
    ///
    /// # Panics
    ///
    /// Panics if the index has no segments (not produced by `Pgm::new` for valid input).
    #[inline]
    #[must_use]
    pub fn find_key<F>(&self, key: K, get_key: F) -> usize
    where
        F: Fn(usize) -> Option<K>,
    {
        let (lo, hi) = self.predict_range(key);
        Self::lower_bound(lo, hi, |i| matches!(get_key(i), Some(k) if k < key))
    }

    /// Binary search over positions `lo..hi` for the first index where `is_before` is false.
    ///
    /// Assumes `is_before` holds on a prefix of the range and fails on the rest. The
    /// result is `hi` if it holds everywhere.
    #[inline]
    fn lower_bound(mut lo: usize, mut hi: usize, is_before: impl Fn(usize) -> bool) -> usize {
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if is_before(mid) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        lo
    }

    /// Locates the segment responsible for `key`.
    ///
    /// The LUT supplies a starting guess. The guess is corrected by walking forward while
    /// `key` exceeds the segment's `max_key`, then backward while `key` is below the
    /// segment's `min_key`. All indexing is bounds-checked: the fields are public and may be
    /// deserialized, so LUT contents are clamped rather than trusted.
    ///
    /// # Panics
    ///
    /// Panics if the index has no segments.
    #[inline]
    fn find_seg(&self, key: K) -> &Segment<K> {
        let segments = &self.segments;
        let last = segments
            .len()
            .checked_sub(1)
            .expect("Pgm index has no segments");
        if last == 0 {
            return &segments[0];
        }
        let mut idx = match self.lut.len().checked_sub(1) {
            // No LUT: start at the first segment and let the forward walk find the target.
            None => 0,
            Some(lut_max) => {
                let pos = (key.as_f64() - self.min_key) * self.scale;
                // Float-to-int `as` saturates (NaN and negatives become 0), so only the
                // upper end needs clamping.
                let bin = (pos as usize).min(lut_max);
                (self.lut[bin] as usize).min(last)
            }
        };
        while idx < last && key > segments[idx].max_key {
            idx += 1;
        }
        while idx > 0 && key < segments[idx].min_key {
            idx -= 1;
        }
        &segments[idx]
    }
}
/// Evaluates `seg`'s linear model at `key_f64` and rounds to the nearest position.
///
/// The result is clamped to the segment's index range `[start_idx, end_idx - 1]`. A
/// degenerate segment with `end_idx <= start_idx` yields `start_idx`, instead of
/// underflowing `end_idx - 1` or panicking inside `clamp` (which requires `min <= max`).
#[inline]
fn predict_in_seg(seg: &Segment<impl Key>, key_f64: f64) -> usize {
    // Adding 0.5 before the truncating cast rounds non-negative values to nearest;
    // negative and NaN results saturate to 0 in the cast and are clamped up to `lo`.
    let pos = seg.slope.mul_add(key_f64, seg.intercept) + 0.5;
    let lo = seg.start_idx as usize;
    let hi = (seg.end_idx as usize).saturating_sub(1).max(lo);
    (pos as usize).clamp(lo, hi)
}
/// Reports whether `data` is in non-decreasing order, i.e. every element is less than or
/// equal to the one that follows it. Empty and single-element slices count as sorted.
#[inline]
fn is_sorted<K: Ord>(data: &[K]) -> bool {
    let successors = data.iter().skip(1);
    data.iter()
        .zip(successors)
        .all(|(prev, next)| prev <= next)
}