use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
/// Process-wide number of lookups; `record_lookup` adds one per call.
static LEARNED_INDEX_LOOKUPS_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Process-wide sum of the prediction errors passed to `record_lookup`.
static LEARNED_INDEX_PREDICTION_ERROR_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Process-wide number of trained segments; `LearnedIndex::build` adds the
/// segment count of every index it builds.
static LEARNED_INDEX_SEGMENTS_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Plain-value copy of the learned-index counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct LearnedIndexMetricsSnapshot {
    /// Number of lookups counted.
    pub lookups_total: u64,
    /// Sum of the prediction errors counted.
    pub prediction_error_total: u64,
    /// Number of segments counted.
    pub segments_total: u64,
}

impl fmt::Display for LearnedIndexMetricsSnapshot {
    /// Writes the snapshot as
    /// `lookups=<n> prediction_error_total=<n> segments=<n>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!(
            "lookups={} prediction_error_total={} segments={}",
            self.lookups_total, self.prediction_error_total, self.segments_total,
        ))
    }
}
/// Reads the three learned-index counters and returns them as one value.
///
/// Each counter is loaded on its own with relaxed ordering, so the three
/// numbers are not read as a single atomic unit.
#[must_use]
pub fn learned_index_metrics_snapshot() -> LearnedIndexMetricsSnapshot {
    let lookups_total = LEARNED_INDEX_LOOKUPS_TOTAL.load(Ordering::Relaxed);
    let prediction_error_total = LEARNED_INDEX_PREDICTION_ERROR_TOTAL.load(Ordering::Relaxed);
    let segments_total = LEARNED_INDEX_SEGMENTS_TOTAL.load(Ordering::Relaxed);
    LearnedIndexMetricsSnapshot {
        lookups_total,
        prediction_error_total,
        segments_total,
    }
}
/// Stores zero into each of the three learned-index counters.
pub fn reset_learned_index_metrics() {
    let counters = [
        &LEARNED_INDEX_LOOKUPS_TOTAL,
        &LEARNED_INDEX_PREDICTION_ERROR_TOTAL,
        &LEARNED_INDEX_SEGMENTS_TOTAL,
    ];
    for counter in counters.iter() {
        counter.store(0, Ordering::Relaxed);
    }
}
/// Counts one lookup and adds `error` to the running prediction-error sum.
fn record_lookup(error: usize) {
    let observed = error as u64;
    LEARNED_INDEX_LOOKUPS_TOTAL.fetch_add(1, Ordering::Relaxed);
    LEARNED_INDEX_PREDICTION_ERROR_TOTAL.fetch_add(observed, Ordering::Relaxed);
}
/// One linear piece of the model, mapping keys in `key_lo..=key_hi` to an
/// estimated position.
#[derive(Debug, Clone)]
struct Segment {
    /// Smallest key covered; predictions are measured relative to it.
    key_lo: u64,
    /// Largest key covered.
    key_hi: u64,
    /// Position of the segment's first key in the key array.
    #[allow(dead_code)]
    pos_lo: usize,
    /// Positions advanced per unit of key distance from `key_lo`.
    slope: f64,
    /// Estimated position of `key_lo` itself.
    intercept: f64,
}

impl Segment {
    /// Estimates the position of `key` as
    /// `round(slope * (key - key_lo) + intercept)`, with negative estimates
    /// clamped to zero.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    fn predict(&self, key: u64) -> usize {
        let key_offset = key as f64 - self.key_lo as f64;
        let estimate = self.slope.mul_add(key_offset, self.intercept).round();
        // Float-to-integer `as` casts saturate, so only the lower end needs clamping.
        estimate.max(0.0) as usize
    }
}
/// Build-time settings for a `LearnedIndex`.
#[derive(Debug, Clone, Copy)]
pub struct LearnedIndexConfig {
    /// Error bound handed to training and used as the search radius around a
    /// predicted position during lookup.
    pub max_error: usize,
}

impl Default for LearnedIndexConfig {
    /// Returns a configuration with `max_error` set to 16.
    fn default() -> Self {
        const DEFAULT_MAX_ERROR: usize = 16;
        Self {
            max_error: DEFAULT_MAX_ERROR,
        }
    }
}
/// A sorted key array paired with a piecewise-linear model that predicts where
/// each key sits in the array.
pub struct LearnedIndex {
/// Copy of the sorted keys handed to `build`.
keys: Vec<u64>,
/// Linear segments produced by `train_piecewise_linear` for `keys`.
segments: Vec<Segment>,
/// Error bound from the build configuration; `lookup` searches this many
/// positions on either side of the predicted position.
max_error: usize,
}
impl LearnedIndex {
    /// Builds an index over `keys`, which must be sorted in non-decreasing
    /// order (duplicates are allowed).
    ///
    /// The keys are copied into the index, a piecewise-linear model is trained
    /// with `config.max_error` as its error bound, and the number of trained
    /// segments is added to the process-wide segment counter.
    ///
    /// # Panics
    ///
    /// Panics with "keys must be sorted" if `keys` is not sorted.
    pub fn build(keys: &[u64], config: LearnedIndexConfig) -> Self {
        assert!(keys.windows(2).all(|w| w[0] <= w[1]), "keys must be sorted");
        let segments = train_piecewise_linear(keys, config.max_error);
        LEARNED_INDEX_SEGMENTS_TOTAL.fetch_add(segments.len() as u64, Ordering::Relaxed);
        Self {
            keys: keys.to_vec(),
            segments,
            max_error: config.max_error,
        }
    }

    /// Returns a position at which `key` is stored, or `None` if it is absent.
    ///
    /// The model predicts a position and only the window of `max_error`
    /// positions on either side of it is searched. When the key occurs more
    /// than once, the first occurrence inside that window is returned.
    ///
    /// Every call records one lookup in the process-wide metrics: the distance
    /// between predicted and found position on a hit, zero on a miss.
    pub fn lookup(&self, key: u64) -> Option<usize> {
        // An empty index has no segments, so it takes this miss path as well.
        let Some(seg_idx) = self.find_segment(key) else {
            record_lookup(0);
            return None;
        };
        let predicted = self.segments[seg_idx].predict(key);
        let hi = predicted
            .saturating_add(self.max_error)
            .saturating_add(1)
            .min(self.keys.len());
        // A prediction far past the end would put `lo` above `hi`; clamping
        // turns that into an empty window instead of an invalid slice range.
        let lo = predicted.saturating_sub(self.max_error).min(hi);
        let window = &self.keys[lo..hi];
        // The window is sorted, so the first element that is not smaller than
        // `key` is the only candidate for the first match.
        let offset = window.partition_point(|&candidate| candidate < key);
        let found = (window.get(offset) == Some(&key)).then_some(lo + offset);
        record_lookup(found.map_or(0, |pos| predicted.abs_diff(pos)));
        found
    }

    /// Number of keys stored in the index.
    #[must_use]
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// Whether the index stores no keys.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.keys.is_empty()
    }

    /// Number of linear segments in the trained model.
    #[must_use]
    pub fn num_segments(&self) -> usize {
        self.segments.len()
    }

    /// Error bound the index was built with.
    #[must_use]
    pub fn max_error(&self) -> usize {
        self.max_error
    }

    /// The stored keys, in sorted order.
    #[must_use]
    pub fn keys(&self) -> &[u64] {
        &self.keys
    }

    /// Largest distance between a stored position and the position predicted
    /// for its key by the segment that covers that position; zero for an empty
    /// index.
    ///
    /// Each position is scored against its own segment rather than the one
    /// `find_segment` would pick for the key: when a run of equal keys spans
    /// several segments, a by-key search sends every duplicate to the last of
    /// them and would report errors far beyond what any segment produces.
    #[must_use]
    pub fn max_observed_error(&self) -> usize {
        let mut max_err = 0usize;
        for (seg_idx, seg) in self.segments.iter().enumerate() {
            // Segments cover consecutive position ranges, so one ends where the
            // next begins (or at the end of the key array for the last one).
            let end = self
                .segments
                .get(seg_idx + 1)
                .map_or(self.keys.len(), |next| next.pos_lo);
            for (actual_pos, &key) in (seg.pos_lo..end).zip(&self.keys[seg.pos_lo..end]) {
                max_err = max_err.max(seg.predict(key).abs_diff(actual_pos));
            }
        }
        max_err
    }

    /// Returns the index of the last segment whose `key_lo` is not greater than
    /// `key`, provided `key` also lies at or below that segment's `key_hi`.
    fn find_segment(&self, key: u64) -> Option<usize> {
        let seg_idx = self
            .segments
            .partition_point(|seg| seg.key_lo <= key)
            .checked_sub(1)?;
        (key <= self.segments[seg_idx].key_hi).then_some(seg_idx)
    }
}
impl fmt::Debug for LearnedIndex {
    /// Prints the key count, the segment count, and the error bound instead of
    /// the full key and segment contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let num_keys = self.keys.len();
        let num_segments = self.segments.len();
        let mut out = f.debug_struct("LearnedIndex");
        out.field("num_keys", &num_keys);
        out.field("num_segments", &num_segments);
        out.field("max_error", &self.max_error);
        out.finish()
    }
}
/// Greedily splits `keys` into consecutive segments, each with one linear model
/// that maps a key to its position.
///
/// A segment starts at the first position not yet covered and is extended one
/// key at a time. For every key that differs (as `f64`) from the segment's
/// first key, the range of slopes that keep the prediction within `max_error`
/// positions is intersected with the range kept so far; the segment ends just
/// before the first key that would empty that range. Keys equal to the
/// segment's first key are all predicted at the segment's start, so at most
/// `max_error` of them may follow it. The stored slope is the end-to-end slope
/// clamped into the surviving range, or zero when the segment spans no key
/// distance.
///
/// Returns an empty vector for empty input.
fn train_piecewise_linear(keys: &[u64], max_error: usize) -> Vec<Segment> {
    let tolerance = max_error as f64;
    let mut models = Vec::new();
    let mut first = 0usize;
    while first < keys.len() {
        let origin = keys[first] as f64;
        let mut last = first;
        let mut lower = 0.0f64;
        let mut upper = f64::INFINITY;
        for (offset, &key) in keys[first + 1..].iter().enumerate() {
            // Number of positions between this key and the segment start.
            let steps = offset + 1;
            let run = key as f64 - origin;
            if run == 0.0 {
                if steps > max_error {
                    break;
                }
                last = first + steps;
                continue;
            }
            let rise = steps as f64;
            let tightened_lower = lower.max((rise - tolerance) / run);
            let tightened_upper = upper.min((rise + tolerance) / run);
            if tightened_lower > tightened_upper {
                break;
            }
            lower = tightened_lower;
            upper = tightened_upper;
            last = first + steps;
        }
        let span = keys[last] as f64 - origin;
        let slope = if span > 0.0 {
            ((last - first) as f64 / span).clamp(lower, upper)
        } else {
            0.0
        };
        models.push(Segment {
            key_lo: keys[first],
            key_hi: keys[last],
            pos_lo: first,
            slope,
            intercept: first as f64,
        });
        first = last + 1;
    }
    models
}
// Unit tests for `LearnedIndex`: lookups, segment counts, accessors, and the
// trained error bound.
#[cfg(test)]
mod tests {
use super::*;
// Consecutive keys 0..100: each key is found at its own position and keys past
// the end are reported absent.
#[test]
fn basic_lookup() {
let keys: Vec<u64> = (0..100).collect();
let idx = LearnedIndex::build(&keys, LearnedIndexConfig::default());
assert_eq!(idx.len(), 100);
assert!(!idx.is_empty());
for &k in &keys {
#[allow(clippy::cast_possible_truncation)]
let expected = k as usize;
assert_eq!(idx.lookup(k), Some(expected));
}
assert_eq!(idx.lookup(100), None);
assert_eq!(idx.lookup(999), None);
}
// Evenly spaced keys with `max_error` 1 must need at most five segments and
// every key must still be found.
#[test]
fn uniform_distribution() {
let keys: Vec<u64> = (0..1000).map(|i| i * 10).collect();
let idx = LearnedIndex::build(&keys, LearnedIndexConfig { max_error: 1 });
assert!(
idx.num_segments() <= 5,
"uniform distribution should need few segments, got {}",
idx.num_segments()
);
for (pos, &k) in keys.iter().enumerate() {
assert_eq!(idx.lookup(k), Some(pos));
}
}
// An index built from no keys has no segments and finds nothing.
#[test]
fn empty_index() {
let idx = LearnedIndex::build(&[], LearnedIndexConfig::default());
assert!(idx.is_empty());
assert_eq!(idx.len(), 0);
assert_eq!(idx.num_segments(), 0);
assert_eq!(idx.lookup(42), None);
}
// A single key yields one segment; only that key is found.
#[test]
fn single_key() {
let idx = LearnedIndex::build(&[42], LearnedIndexConfig::default());
assert_eq!(idx.len(), 1);
assert_eq!(idx.num_segments(), 1);
assert_eq!(idx.lookup(42), Some(0));
assert_eq!(idx.lookup(43), None);
}
// Quadratically growing keys: the measured maximum error must not exceed the
// configured bound.
#[test]
fn error_bound_respected() {
let keys: Vec<u64> = (0..500).map(|i| i * i).collect(); let config = LearnedIndexConfig { max_error: 32 };
let idx = LearnedIndex::build(&keys, config);
let max_err = idx.max_observed_error();
assert!(
max_err <= config.max_error,
"max observed error {max_err} exceeds bound {}",
config.max_error
);
}
// The accessors report the keys and error bound given to `build`; the default
// configuration uses a bound of 16.
#[test]
fn max_error_and_keys_accessors_reflect_build_inputs() {
let keys = vec![1_u64, 4, 9, 16, 25, 36];
let idx = LearnedIndex::build(&keys, LearnedIndexConfig { max_error: 8 });
assert_eq!(idx.max_error(), 8);
assert_eq!(idx.keys(), keys.as_slice());
assert_eq!(idx.len(), keys.len());
let def = LearnedIndex::build(&keys, LearnedIndexConfig::default());
assert_eq!(def.max_error(), 16);
}
// Looking up the largest possible key must return `None` without panicking.
#[test]
fn lookup_with_extreme_key_does_not_overflow() {
let keys: Vec<u64> = (0..128).collect();
let idx = LearnedIndex::build(&keys, LearnedIndexConfig::default());
assert_eq!(idx.lookup(u64::MAX), None);
}
// Non-uniform keys: every stored key is found at its position, keys that fall
// in gaps between stored keys are absent, and a clustered key set with a large
// interior gap behaves the same way.
#[test]
fn nonuniform_all_present_findable_and_in_range_gaps_absent() {
let keys: Vec<u64> = (0..300u64).map(|i| i * i).collect();
let idx = LearnedIndex::build(&keys, LearnedIndexConfig { max_error: 16 });
for (pos, &k) in keys.iter().enumerate() {
assert_eq!(
idx.lookup(k),
Some(pos),
"present key {k} must be found at {pos}"
);
}
// From index 2 on, consecutive squares differ by more than one, so `k - 1`
// is never a stored key.
for &k in &keys[2..60] {
assert_eq!(
idx.lookup(k - 1),
None,
"in-gap key {} must be absent",
k - 1
);
}
let clustered = vec![0u64, 1, 2, 1000, 1001, 1002];
let idx2 = LearnedIndex::build(&clustered, LearnedIndexConfig { max_error: 4 });
for (pos, &k) in clustered.iter().enumerate() {
assert_eq!(idx2.lookup(k), Some(pos));
}
assert_eq!(
idx2.lookup(500),
None,
"a key in the large interior gap is absent"
);
}
}