use serde::{Deserialize, Serialize};
/// Outcome of looking up a key in a learned index.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum LookupResult {
    /// The key has a stored correction entry; this is its recorded position.
    Exact(usize),
    /// Inclusive window `low..=high` of positions that must contain the key if it is present.
    /// Keys that lie within the index bounds but were never indexed also get a window.
    Range { low: usize, high: usize },
    /// The key lies outside the indexed key bounds, or the index is empty.
    NotFound,
}
/// Learned index over a sorted array of `u64` keys.
///
/// A single linear model maps a (normalized) key to a predicted position in the array.
/// Lookups return a window of `±max_error` positions around the prediction. Keys whose
/// prediction error exceeded `correction_threshold` when they were added get an exact
/// position instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedSparseIndex {
    /// Slope of the linear model, applied to the normalized key (see `normalize_key`).
    slope: f64,
    /// Intercept of the linear model, in positions.
    intercept: f64,
    /// Largest |actual - predicted| position error recorded over the indexed keys.
    max_error: usize,
    /// `(key, position)` overrides for keys whose error exceeded `correction_threshold`.
    /// Kept sorted by key because lookups and inserts binary-search it.
    corrections: Vec<(u64, usize)>,
    /// Lower key bound; lookups below it return `NotFound` (first key when sorted input).
    min_key: u64,
    /// Upper key bound; lookups above it return `NotFound` (last key when sorted input).
    max_key: u64,
    /// `max_key - min_key` as `f64`; the denominator used when normalizing keys.
    key_range: f64,
    /// Number of keys the model covers.
    num_keys: usize,
    /// Prediction error (in positions) above which a key gets a `corrections` entry.
    correction_threshold: usize,
}
impl LearnedSparseIndex {
    /// Prediction error (in positions) above which `build` records an exact correction.
    const DEFAULT_CORRECTION_THRESHOLD: usize = 64;

    /// Returns an index over zero keys; every lookup yields [`LookupResult::NotFound`].
    pub fn empty() -> Self {
        Self {
            slope: 0.0,
            intercept: 0.0,
            max_error: 0,
            corrections: Vec::new(),
            min_key: 0,
            max_key: 0,
            key_range: 0.0,
            num_keys: 0,
            correction_threshold: Self::DEFAULT_CORRECTION_THRESHOLD,
        }
    }

    /// Maps `key` onto the regression x-axis by scaling `[min_key, min_key + key_range]`
    /// to `[0, num_keys - 1]`, so x lives on the same scale as the position it predicts.
    ///
    /// Keys below `min_key` clamp to 0. A zero `key_range` (no two distinct keys) maps every
    /// key to 0. Subtracting `min_key` in integer arithmetic before converting to `f64`
    /// keeps keys near `u64::MAX` from losing their low bits.
    #[inline]
    fn normalize(key: u64, min_key: u64, key_range: f64, num_keys: usize) -> f64 {
        if key_range == 0.0 {
            return 0.0;
        }
        let offset = key.saturating_sub(min_key) as f64;
        (offset / key_range) * num_keys.saturating_sub(1) as f64
    }

    /// [`Self::normalize`] using this index's own bounds and key count.
    #[inline]
    fn normalize_key(&self, key: u64) -> f64 {
        Self::normalize(key, self.min_key, self.key_range, self.num_keys)
    }

    /// Model prediction for `key`, rounded to a position. May fall outside `0..num_keys`.
    fn predict(&self, key: u64) -> isize {
        (self.slope * self.normalize_key(key) + self.intercept).round() as isize
    }

    /// `|actual - predicted|`, computed in `i128` so an extreme prediction cannot overflow.
    fn position_error(actual: usize, predicted: isize) -> usize {
        let diff = (actual as i128 - predicted as i128).unsigned_abs();
        usize::try_from(diff).unwrap_or(usize::MAX)
    }

    /// Builds an index over `keys` with the default correction threshold.
    ///
    /// `keys` must be sorted ascending; see [`Self::build_with_threshold`].
    pub fn build(keys: &[u64]) -> Self {
        Self::build_with_threshold(keys, Self::DEFAULT_CORRECTION_THRESHOLD)
    }

    /// Builds an index over `keys`. Every key whose prediction is off by more than
    /// `correction_threshold` positions gets an exact correction.
    ///
    /// `keys` must be sorted ascending (duplicates allowed). Positions reported by lookups
    /// are indices into this slice. The model is a least-squares line from normalized key to
    /// position. `max_error` is the largest error over all keys, corrected ones included.
    /// A single key needs no special case: its `key_range` is 0, so the regression yields a
    /// flat line at position 0.
    pub fn build_with_threshold(keys: &[u64], correction_threshold: usize) -> Self {
        let (Some(&min_key), Some(&max_key)) = (keys.first(), keys.last()) else {
            // Keep the caller's threshold so later `insert`s honour it.
            return Self {
                correction_threshold,
                ..Self::empty()
            };
        };
        let n = keys.len();
        let key_range = max_key.saturating_sub(min_key) as f64;
        let (slope, intercept) = Self::linear_regression_normalized(keys, min_key, key_range, n);
        let mut index = Self {
            slope,
            intercept,
            max_error: 0,
            corrections: Vec::new(),
            min_key,
            max_key,
            key_range,
            num_keys: n,
            correction_threshold,
        };
        for (actual_pos, &key) in keys.iter().enumerate() {
            let error = Self::position_error(actual_pos, index.predict(key));
            index.max_error = index.max_error.max(error);
            if error > correction_threshold {
                // `keys` is sorted, so pushing in order keeps `corrections` sorted by key.
                index.corrections.push((key, actual_pos));
            }
        }
        index
    }

    /// Looks up `key` using the model's own error bound.
    ///
    /// Returns `NotFound` outside `[min_key, max_key]`, and `Exact` for corrected keys.
    /// Otherwise it returns an inclusive window that contains the key's position if the key
    /// is indexed.
    pub fn lookup(&self, key: u64) -> LookupResult {
        self.lookup_with_error(key, self.max_error)
    }

    /// Like [`Self::lookup`], but with a caller-chosen window half-width `max_error`.
    ///
    /// A half-width smaller than the model's own `max_error` may produce a window that
    /// misses the key. Both window ends are always clamped into `0..num_keys`, so the result
    /// is a valid, non-inverted range.
    pub fn lookup_with_error(&self, key: u64, max_error: usize) -> LookupResult {
        if self.num_keys == 0 || key < self.min_key || key > self.max_key {
            return LookupResult::NotFound;
        }
        if let Ok(idx) = self.corrections.binary_search_by_key(&key, |&(k, _)| k) {
            return LookupResult::Exact(self.corrections[idx].1);
        }
        // Widen to i128: neither `predicted ± max_error` can overflow, and clamping both ends
        // stops a prediction outside the array from wrapping `high` or inverting the range.
        let predicted = self.predict(key) as i128;
        let half_width = max_error as i128;
        let last = (self.num_keys - 1) as i128;
        let low = (predicted - half_width).clamp(0, last) as usize;
        let high = (predicted + half_width).clamp(0, last) as usize;
        LookupResult::Range { low, high }
    }

    /// Returns a snapshot of the model parameters and error figures.
    pub fn stats(&self) -> LearnedIndexStats {
        LearnedIndexStats {
            num_keys: self.num_keys,
            max_error: self.max_error,
            num_corrections: self.corrections.len(),
            slope: self.slope,
            intercept: self.intercept,
            correction_ratio: if self.num_keys > 0 {
                self.corrections.len() as f64 / self.num_keys as f64
            } else {
                0.0
            },
        }
    }

    /// Heuristic fitness check. Returns true when `max_error <= 128` and fewer than 5% of
    /// keys need corrections. An empty index counts as having no corrections.
    pub fn is_efficient(&self) -> bool {
        let low_error = self.max_error <= 128;
        let low_corrections =
            self.num_keys == 0 || (self.corrections.len() as f64 / self.num_keys as f64) < 0.05;
        low_error && low_corrections
    }

    /// Approximate memory footprint: the struct itself plus the heap buffer of `corrections`
    /// (its allocated capacity, using the real `(u64, usize)` element layout).
    pub fn memory_bytes(&self) -> usize {
        std::mem::size_of::<Self>()
            + self.corrections.capacity() * std::mem::size_of::<(u64, usize)>()
    }

    /// Ordinary least-squares fit of position (`y = i`) against normalized key (`x`).
    ///
    /// Returns `(slope, intercept)`. When all `x` coincide (denominator ~ 0), it returns a
    /// flat line at the mean position.
    fn linear_regression_normalized(
        keys: &[u64],
        min_key: u64,
        key_range: f64,
        n: usize,
    ) -> (f64, f64) {
        let n_f64 = n as f64;
        let mut sum_x: f64 = 0.0;
        let mut sum_y: f64 = 0.0;
        let mut sum_xy: f64 = 0.0;
        let mut sum_xx: f64 = 0.0;
        for (i, &key) in keys.iter().enumerate() {
            let x = Self::normalize(key, min_key, key_range, n);
            let y = i as f64;
            sum_x += x;
            sum_y += y;
            sum_xy += x * y;
            sum_xx += x * x;
        }
        let denominator = n_f64 * sum_xx - sum_x * sum_x;
        if denominator.abs() < f64::EPSILON {
            return (0.0, sum_y / n_f64);
        }
        let slope = (n_f64 * sum_xy - sum_x * sum_y) / denominator;
        let intercept = (sum_y - slope * sum_x) / n_f64;
        (slope, intercept)
    }

    /// Least-squares fit on raw keys converted straight to `f64`.
    ///
    /// The index itself does not use this; `linear_regression_normalized` is used instead.
    /// Raw keys above 2^53 cannot be represented exactly in `f64`, and the
    /// `n * sum_xx - sum_x^2` term cancels badly for large keys. The function is retained
    /// for reference, hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn linear_regression(keys: &[u64]) -> (f64, f64) {
        let n = keys.len() as f64;
        let mut sum_x: f64 = 0.0;
        let mut sum_y: f64 = 0.0;
        let mut sum_xy: f64 = 0.0;
        let mut sum_xx: f64 = 0.0;
        for (i, &key) in keys.iter().enumerate() {
            let x = key as f64;
            let y = i as f64;
            sum_x += x;
            sum_y += y;
            sum_xy += x * y;
            sum_xx += x * x;
        }
        let denominator = n * sum_xx - sum_x * sum_x;
        if denominator.abs() < f64::EPSILON {
            return (0.0, sum_y / n);
        }
        let slope = (n * sum_xy - sum_x * sum_y) / denominator;
        let intercept = (sum_y - slope * sum_x) / n;
        (slope, intercept)
    }

    /// Records that `key` was inserted at `position` of the underlying sorted key array.
    ///
    /// `keys` must be the full sorted key array *after* the insertion. It is only read when
    /// the index rebuilds itself. The incremental path keeps existing lookups valid:
    ///
    /// * The fitted line is re-expressed for the widened key bounds and the new key count.
    ///   Predictions for existing keys therefore stay put, up to floating-point rounding,
    ///   instead of drifting with the normalization.
    /// * Correction positions at or after `position` shift right by one.
    /// * `max_error` grows by one when the insert shifts existing keys (`position < num_keys`).
    /// * The new key's own error is folded into `max_error`, and it gets a correction if the
    ///   error exceeds `correction_threshold`.
    ///
    /// Returns `true` if the index was rebuilt from `keys` because corrections exceeded 10%
    /// of the keys.
    pub fn insert(&mut self, key: u64, position: usize, keys: &[u64]) -> bool {
        let old_n = self.num_keys;
        let new_n = old_n + 1;
        // An empty index's bounds are placeholders (0, 0), so start fresh from `key`.
        let (new_min, new_max) = if old_n == 0 {
            (key, key)
        } else {
            (self.min_key.min(key), self.max_key.max(key))
        };
        let new_range = new_max.saturating_sub(new_min) as f64;

        // In raw-key space the model is `pos ~ raw_slope * (k - min_key) + intercept`.
        // Re-derive slope/intercept so the same line holds under the new normalization.
        let raw_slope = if self.key_range == 0.0 || old_n < 2 {
            0.0
        } else {
            self.slope * (old_n - 1) as f64 / self.key_range
        };
        let min_shift = self.min_key.saturating_sub(new_min) as f64;
        self.intercept -= raw_slope * min_shift;
        self.slope = if new_range == 0.0 {
            0.0
        } else {
            raw_slope * new_range / (new_n - 1) as f64
        };
        self.min_key = new_min;
        self.max_key = new_max;
        self.key_range = new_range;
        self.num_keys = new_n;

        // Every existing key at or after `position` moved one slot to the right.
        if position < old_n {
            self.max_error = self.max_error.saturating_add(1);
        }
        for (_, pos) in self.corrections.iter_mut() {
            if *pos >= position {
                *pos += 1;
            }
        }

        let error = Self::position_error(position, self.predict(key));
        self.max_error = self.max_error.max(error);
        if error > self.correction_threshold {
            match self.corrections.binary_search_by_key(&key, |&(k, _)| k) {
                Ok(idx) => self.corrections[idx] = (key, position),
                Err(idx) => self.corrections.insert(idx, (key, position)),
            }
        }
        if self.corrections.len() > self.num_keys / 10 {
            *self = Self::build_with_threshold(keys, self.correction_threshold);
            return true;
        }
        false
    }
}
/// Snapshot of a [`LearnedSparseIndex`]'s model parameters and error figures.
#[derive(Debug, Clone)]
pub struct LearnedIndexStats {
    /// Number of keys covered by the index.
    pub num_keys: usize,
    /// Largest recorded position error of the model.
    pub max_error: usize,
    /// Number of keys stored as exact corrections.
    pub num_corrections: usize,
    /// Model slope (applied to normalized keys).
    pub slope: f64,
    /// Model intercept, in positions.
    pub intercept: f64,
    /// `num_corrections / num_keys`, or `0.0` for an empty index.
    pub correction_ratio: f64,
}
/// Learned index that splits a sorted key array into equal-sized chunks and fits one
/// [`LearnedSparseIndex`] per chunk, so each linear model approximates only a local stretch
/// of the key distribution.
#[derive(Debug, Clone)]
pub struct PiecewiseLearnedIndex {
    /// First key of each segment, ascending; used to route a key to its segment.
    boundaries: Vec<u64>,
    /// One model per segment, built over that segment's keys (segment-local positions).
    segments: Vec<LearnedSparseIndex>,
    /// Position in the full key array of each segment's first key. It is added to
    /// segment-local results so callers receive positions into the original `keys`.
    offsets: Vec<usize>,
}

impl PiecewiseLearnedIndex {
    /// Builds an index over `keys` using at most `max_segments` segments.
    ///
    /// `keys` must be sorted ascending. Keys are split into chunks of
    /// `ceil(len / max_segments)` keys each; the last chunk may be shorter. An empty `keys`
    /// or `max_segments == 0` yields an index for which every lookup is `NotFound`.
    pub fn build(keys: &[u64], max_segments: usize) -> Self {
        if keys.is_empty() || max_segments == 0 {
            return Self {
                boundaries: Vec::new(),
                segments: Vec::new(),
                offsets: Vec::new(),
            };
        }
        let segment_size = keys.len().div_ceil(max_segments);
        // Size buffers by the real segment count (never more than `keys.len()`), so a huge
        // `max_segments` cannot trigger an enormous or overflowing allocation.
        let num_segments = keys.len().div_ceil(segment_size);
        let mut boundaries = Vec::with_capacity(num_segments);
        let mut segments = Vec::with_capacity(num_segments);
        let mut offsets = Vec::with_capacity(num_segments);
        // `chunks` never yields an empty slice, so `chunk[0]` is always valid.
        for (seg, chunk) in keys.chunks(segment_size).enumerate() {
            boundaries.push(chunk[0]);
            segments.push(LearnedSparseIndex::build(chunk));
            offsets.push(seg * segment_size);
        }
        Self {
            boundaries,
            segments,
            offsets,
        }
    }

    /// Index of the last segment whose first key is `<= key`. Returns `None` when `key`
    /// precedes every segment or there are no segments.
    fn find_segment(&self, key: u64) -> Option<usize> {
        self.boundaries
            .partition_point(|&first| first <= key)
            .checked_sub(1)
    }

    /// Looks up `key` and reports positions into the full key array passed to `build`.
    ///
    /// Returns `NotFound` if `key` precedes the first segment or falls outside its segment's
    /// key bounds.
    pub fn lookup(&self, key: u64) -> LookupResult {
        let Some(seg) = self.find_segment(key) else {
            return LookupResult::NotFound;
        };
        let offset = self.offsets[seg];
        match self.segments[seg].lookup(key) {
            LookupResult::Exact(pos) => LookupResult::Exact(offset + pos),
            LookupResult::Range { low, high } => LookupResult::Range {
                low: offset + low,
                high: offset + high,
            },
            LookupResult::NotFound => LookupResult::NotFound,
        }
    }

    /// Aggregates per-segment statistics.
    pub fn stats(&self) -> PiecewiseStats {
        let segment_stats: Vec<_> = self.segments.iter().map(|s| s.stats()).collect();
        let total_keys: usize = segment_stats.iter().map(|s| s.num_keys).sum();
        let max_error = segment_stats.iter().map(|s| s.max_error).max().unwrap_or(0);
        let total_corrections: usize = segment_stats.iter().map(|s| s.num_corrections).sum();
        PiecewiseStats {
            num_segments: self.segments.len(),
            total_keys,
            max_error,
            total_corrections,
            avg_segment_size: if self.segments.is_empty() {
                0.0
            } else {
                total_keys as f64 / self.segments.len() as f64
            },
        }
    }
}
/// Aggregate statistics over all segments of a [`PiecewiseLearnedIndex`].
#[derive(Debug, Clone)]
pub struct PiecewiseStats {
    /// Number of segments (one linear model each).
    pub num_segments: usize,
    /// Sum of keys across all segments.
    pub total_keys: usize,
    /// Largest per-segment `max_error` (0 when there are no segments).
    pub max_error: usize,
    /// Sum of correction entries across all segments.
    pub total_corrections: usize,
    /// `total_keys / num_segments`, or `0.0` when there are no segments.
    pub avg_segment_size: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that looking up each `keys[i]` yields either `Exact(i)` or a window covering `i`.
    fn assert_each_key_located(index: &LearnedSparseIndex, keys: &[u64]) {
        for (i, &key) in keys.iter().enumerate() {
            match index.lookup(key) {
                LookupResult::Range { low, high } => assert!(
                    (low..=high).contains(&i),
                    "Key {} at position {} should be in range [{}, {}]",
                    key,
                    i,
                    low,
                    high
                ),
                LookupResult::Exact(pos) => assert_eq!(pos, i, "Exact position should match"),
                LookupResult::NotFound => panic!("Key {} should be found", key),
            }
        }
    }

    #[test]
    fn test_empty_index() {
        let index = LearnedSparseIndex::build(&[]);
        assert_eq!(index.stats().num_keys, 0);
        assert_eq!(index.lookup(42), LookupResult::NotFound);
    }

    #[test]
    fn test_single_key() {
        let index = LearnedSparseIndex::build(&[100]);
        assert_eq!(index.lookup(100), LookupResult::Range { low: 0, high: 0 });
        for outside in [50, 150] {
            assert_eq!(index.lookup(outside), LookupResult::NotFound);
        }
    }

    #[test]
    fn test_sequential_keys() {
        let keys: Vec<u64> = (0..1000).collect();
        let index = LearnedSparseIndex::build(&keys);
        let LearnedIndexStats {
            max_error,
            num_corrections,
            ..
        } = index.stats();
        assert!(max_error <= 1, "Sequential keys should have near-zero error");
        assert!(num_corrections == 0, "No corrections needed for linear data");
        if let LookupResult::Range { low, high } = index.lookup(500) {
            assert!((low..=high).contains(&500), "Key 500 should be in range");
            assert!(high - low <= 2, "Range should be very tight");
        }
    }

    #[test]
    fn test_timestamp_like_keys() {
        // Strictly increasing keys starting at 1_704_067_200 with irregular gaps of 1..=10.
        let keys: Vec<u64> =
            std::iter::successors(Some(1_704_067_200u64), |&ts| Some(ts + 1 + ts % 10))
                .take(10_000)
                .collect();
        let index = LearnedSparseIndex::build(&keys);
        assert!(
            index.is_efficient(),
            "Timestamp data should be efficiently indexable"
        );
        for &key in &keys[..100] {
            assert!(
                !matches!(index.lookup(key), LookupResult::NotFound),
                "Existing key should be found"
            );
        }
    }

    #[test]
    fn test_sparse_keys() {
        let keys = [1u64, 100, 10_000, 1_000_000, 100_000_000];
        let index = LearnedSparseIndex::build(&keys);
        for (i, &key) in keys.iter().enumerate() {
            match index.lookup(key) {
                LookupResult::Exact(pos) => assert_eq!(pos, i),
                LookupResult::Range { low, high } => assert!(
                    (low..=high).contains(&i),
                    "Key {} should be in range [{}, {}]",
                    key,
                    low,
                    high
                ),
                LookupResult::NotFound => panic!("Key {} should be found", key),
            }
        }
    }

    #[test]
    fn test_out_of_bounds() {
        let keys: Vec<u64> = (100..200).collect();
        let index = LearnedSparseIndex::build(&keys);
        for probe in [50, 250] {
            assert_eq!(index.lookup(probe), LookupResult::NotFound);
        }
    }

    #[test]
    fn test_piecewise_index() {
        // Three regimes with very different densities.
        let keys: Vec<u64> = (0..1000)
            .chain((100_000..101_000).step_by(10))
            .chain(1_000_000..1_001_000)
            .collect();
        let piecewise = PiecewiseLearnedIndex::build(&keys, 3);
        assert_eq!(piecewise.stats().num_segments, 3);
        for probe in [500, 100_500, 1_000_500] {
            assert!(!matches!(piecewise.lookup(probe), LookupResult::NotFound));
        }
    }

    #[test]
    fn test_memory_efficiency() {
        let keys: Vec<u64> = (0..100_000).collect();
        let lsi_bytes = LearnedSparseIndex::build(&keys).memory_bytes();
        let btree_bytes = std::mem::size_of_val(keys.as_slice());
        assert!(
            lsi_bytes < btree_bytes,
            "LSI ({} bytes) should use less memory than keys alone ({} bytes)",
            lsi_bytes,
            btree_bytes
        );
    }

    #[test]
    fn test_correction_threshold() {
        let mut keys: Vec<u64> = (0..100).map(|x| x * 10).collect();
        keys.push(5000);
        keys.sort_unstable();
        let corrections_at = |threshold: usize| {
            LearnedSparseIndex::build_with_threshold(&keys, threshold)
                .stats()
                .num_corrections
        };
        assert!(
            corrections_at(10) >= corrections_at(1000),
            "Lower threshold should produce more or equal corrections"
        );
    }

    #[test]
    fn test_large_key_normalization() {
        let base = u64::MAX - 1000;
        let keys: Vec<u64> = (0..100).map(|i| base + i * 10).collect();
        let index = LearnedSparseIndex::build(&keys);
        assert!(index.max_error < 10, "Error should be small for linear data");
        assert_each_key_located(&index, &keys);
    }

    #[test]
    fn test_full_range_keys() {
        let keys: [u64; 10] = [
            0,
            1_000_000,
            1_000_000_000,
            1_000_000_000_000,
            1_000_000_000_000_000,
            u64::MAX / 2,
            u64::MAX - 1000,
            u64::MAX - 100,
            u64::MAX - 10,
            u64::MAX - 1,
        ];
        let index = LearnedSparseIndex::build(&keys);
        assert_each_key_located(&index, &keys);
    }

    #[test]
    fn test_timestamp_keys() {
        let base_ts: u64 = 1_700_000_000_000_000;
        let keys: Vec<u64> = (0..1000).map(|i| base_ts + i * 1000).collect();
        let index = LearnedSparseIndex::build(&keys);
        assert!(
            index.max_error <= 1,
            "Error for sequential timestamps should be ≤ 1, got {}",
            index.max_error
        );
        assert!(
            index.is_efficient(),
            "Sequential timestamp data should be efficient"
        );
    }

    #[test]
    fn test_normalization_precision() {
        let index = LearnedSparseIndex {
            slope: 1.0,
            intercept: 0.0,
            max_error: 0,
            corrections: Vec::new(),
            min_key: 0,
            max_key: 99,
            key_range: 99.0,
            num_keys: 100,
            correction_threshold: 64,
        };
        let within = |key: u64, expected: f64, tolerance: f64| {
            (index.normalize_key(key) - expected).abs() < tolerance
        };
        assert!(within(0, 0.0, f64::EPSILON));
        assert!(within(99, 99.0, f64::EPSILON));
        assert!(within(49, 49.0, 0.5));
    }
}