use std::ops::RangeInclusive;
/// Linear model mapping a rowid to a predicted page number, together with a
/// single global search radius derived from calibration residuals.
///
/// A prediction is `slope * rowid + intercept`, converted to a `u32` page
/// number; the page actually holding the row is then expected within
/// `radius` pages of that prediction.
#[derive(Debug, Clone)]
pub struct LearnedRowIdIndex {
/// Slope of the fitted line (pages per rowid unit).
pub slope: f64,
/// Intercept of the fitted line (predicted page at rowid 0).
pub intercept: f64,
/// Absolute prediction errors on the calibration points, sorted ascending.
/// Empty when the index was built from fewer than two observations.
pub residuals_calibration: Vec<f64>,
/// Miscoverage level passed to `fit`, after clamping to `[0.0, 1.0]`.
pub alpha: f64,
// Half-width of the page search window. `u32::MAX` when there were no
// calibration residuals to derive it from. Read it through `radius()`.
radius: u32,
}
impl LearnedRowIdIndex {
    /// Fits a linear rowid → page model and calibrates its search radius.
    ///
    /// The observations are split deterministically by position: indices whose
    /// value modulo 10 is 0, 3 or 7 form the calibration set (about 30%), the
    /// rest are used for the least-squares fit. The radius is then chosen from
    /// the sorted absolute calibration residuals by `conformal_radius`.
    ///
    /// # Arguments
    ///
    /// * `rowid_to_page` - observed `(rowid, page)` pairs.
    /// * `alpha` - target miscoverage level. It is clamped to `[0.0, 1.0]`; a
    ///   NaN value is treated as `0.0`, the most conservative setting.
    ///
    /// # Degenerate inputs
    ///
    /// * No observations: a constant model predicting page 0 with radius
    ///   `u32::MAX`.
    /// * One observation: a constant model predicting that page with radius
    ///   `u32::MAX`.
    /// * Two observations: the line is fitted through both points, so the
    ///   single calibration point is also a fit point and its residual is
    ///   optimistic.
    #[must_use]
    pub fn fit(rowid_to_page: &[(i64, u32)], alpha: f64) -> Self {
        // `f64::clamp` returns NaN for a NaN input. A NaN alpha would later
        // select the smallest calibration residual, i.e. an overconfident
        // radius, so fall back to the most conservative level instead.
        let alpha = if alpha.is_nan() {
            0.0
        } else {
            alpha.clamp(0.0, 1.0)
        };

        match rowid_to_page {
            [] => return Self::uncalibrated(0.0, alpha),
            [(_, page)] => return Self::uncalibrated(f64::from(*page), alpha),
            _ => {}
        }

        let total = rowid_to_page.len();
        let mut fit_pts: Vec<(f64, f64)> = Vec::with_capacity((total * 7) / 10 + 1);
        let mut calib_pts: Vec<(f64, f64)> = Vec::with_capacity((total * 3) / 10 + 1);
        for (idx, &(rowid, page)) in rowid_to_page.iter().enumerate() {
            // The i64 → f64 conversion is lossy above 2^53; accepted here
            // because the model only needs an approximate position.
            let point = (rowid as f64, f64::from(page));
            if matches!(idx % 10, 0 | 3 | 7) {
                calib_pts.push(point);
            } else {
                fit_pts.push(point);
            }
        }

        if fit_pts.len() < 2 {
            // Too few points left for a line: fit on every observation.
            fit_pts = rowid_to_page
                .iter()
                .map(|&(rowid, page)| (rowid as f64, f64::from(page)))
                .collect();
            // Defensive only: index 0 is always a calibration point, so this
            // cannot trigger with the current split pattern.
            if calib_pts.is_empty() {
                calib_pts.clone_from(&fit_pts);
            }
        }

        let (slope, intercept) = fit_linear_regression(&fit_pts);

        let mut residuals: Vec<f64> = calib_pts
            .iter()
            .map(|&(x, y)| (slope.mul_add(x, intercept) - y).abs())
            .collect();
        // `total_cmp` is a genuine total order (NaN sorts last after `abs`),
        // unlike `partial_cmp(..).unwrap_or(Equal)`, which is inconsistent in
        // the presence of NaN and may make the sort panic.
        residuals.sort_unstable_by(f64::total_cmp);

        let radius = conformal_radius(&residuals, alpha);
        Self {
            slope,
            intercept,
            residuals_calibration: residuals,
            alpha,
            radius,
        }
    }

    /// Builds a constant model with no calibration data and an unbounded
    /// (`u32::MAX`) radius.
    fn uncalibrated(intercept: f64, alpha: f64) -> Self {
        Self {
            slope: 0.0,
            intercept,
            residuals_calibration: Vec::new(),
            alpha,
            radius: u32::MAX,
        }
    }

    /// Returns `(center, radius)`: the predicted page for `rowid` and the
    /// global search radius.
    ///
    /// The raw prediction is rounded to the nearest page and saturated to the
    /// `u32` range: values at or below zero (including `-inf`) become 0,
    /// values at or above `u32::MAX` (including `+inf`) become `u32::MAX`,
    /// and NaN becomes 0.
    #[must_use]
    pub fn predict(&self, rowid: i64) -> (u32, u32) {
        let predicted = self.slope.mul_add(rowid as f64, self.intercept);
        let center = if predicted.is_nan() || predicted <= 0.0 {
            0
        } else if predicted >= f64::from(u32::MAX) {
            u32::MAX
        } else {
            // In range and positive, so the cast cannot truncate or wrap.
            predicted.round() as u32
        };
        (center, self.radius)
    }

    /// Returns the inclusive page window `center - radius ..= center + radius`
    /// for `rowid`, saturating at 0 and `u32::MAX`.
    #[must_use]
    pub fn page_range(&self, rowid: i64) -> RangeInclusive<u32> {
        let (center, radius) = self.predict(rowid);
        center.saturating_sub(radius)..=center.saturating_add(radius)
    }

    /// Returns the global search radius, in pages.
    #[must_use]
    pub fn radius(&self) -> u32 {
        self.radius
    }
}
/// Ordinary least-squares fit of `y = slope * x + intercept`.
///
/// Uses two passes: the first computes the coordinate means, the second
/// accumulates the covariance and the variance of `x` around those means.
///
/// Returns `(slope, intercept)`. An empty slice yields `(0.0, 0.0)`. When the
/// accumulated `x` variance is at most `f64::EPSILON` (for example, all `x`
/// equal or a single point), the result is the horizontal line through the
/// mean of `y`.
fn fit_linear_regression(points: &[(f64, f64)]) -> (f64, f64) {
    if points.is_empty() {
        return (0.0, 0.0);
    }
    let count = points.len() as f64;

    // Pass 1: means of both coordinates.
    let mut total_x = 0.0_f64;
    let mut total_y = 0.0_f64;
    for &(x, y) in points {
        total_x += x;
        total_y += y;
    }
    let x_bar = total_x / count;
    let y_bar = total_y / count;

    // Pass 2: centered sums of products.
    let mut covariance = 0.0_f64;
    let mut variance = 0.0_f64;
    for &(x, y) in points {
        let dx = x - x_bar;
        let dy = y - y_bar;
        covariance = dx.mul_add(dy, covariance);
        variance = dx.mul_add(dx, variance);
    }

    if variance <= f64::EPSILON {
        // No usable spread in x: a slope cannot be estimated.
        (0.0, y_bar)
    } else {
        let slope = covariance / variance;
        (slope, slope.mul_add(-x_bar, y_bar))
    }
}
/// Picks the search radius from ascending-sorted absolute residuals.
///
/// The radius is the residual of 1-indexed rank
/// `ceil((1 - alpha) * (n + 1))`, clamped to `1..=n`, rounded up to a whole
/// number of pages and saturated to the `u32` range.
///
/// Every degenerate input resolves to the conservative (wider) answer:
///
/// * an empty slice yields `u32::MAX`;
/// * a NaN `alpha` selects the largest residual;
/// * a NaN or `+inf` residual yields `u32::MAX` rather than 0.
fn conformal_radius(sorted_residuals: &[f64], alpha: f64) -> u32 {
    let n = sorted_residuals.len();
    if n == 0 {
        return u32::MAX;
    }

    let target = ((1.0 - alpha) * (n as f64 + 1.0)).ceil();
    let rank = if target.is_nan() {
        // A NaN would cast to 0 and pick the smallest residual; prefer the
        // largest one instead.
        n
    } else {
        // Float → usize casts saturate, so out-of-range targets are safe.
        (target as usize).clamp(1, n)
    };

    let residual = sorted_residuals[rank - 1];
    if residual.is_nan() {
        // An unknown error must not be reported as a perfect fit.
        return u32::MAX;
    }

    let pages = residual.ceil();
    if pages <= 0.0 {
        0
    } else if pages >= f64::from(u32::MAX) {
        // Also covers +inf.
        u32::MAX
    } else {
        pages as u32
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds `n` observations with rowids `0..n`, packed `rows_per_page`
    /// rows to a page.
    fn synthetic_dense(n: i64, rows_per_page: i64) -> Vec<(i64, u32)> {
        (0..n)
            .map(|r| {
                let page = (r / rows_per_page) as u32;
                (r, page)
            })
            .collect()
    }

    /// Root-mean-square error of the predicted centers against the observed
    /// pages; 0.0 for an empty observation set.
    fn rmse(idx: &LearnedRowIdIndex, obs: &[(i64, u32)]) -> f64 {
        if obs.is_empty() {
            return 0.0;
        }
        let n = obs.len() as f64;
        let sum_sq: f64 = obs
            .iter()
            .map(|&(r, p)| {
                let (center, _) = idx.predict(r);
                let err = f64::from(center) - f64::from(p);
                err * err
            })
            .sum();
        (sum_sq / n).sqrt()
    }

    /// Fraction of `held_out` observations whose true page lies inside the
    /// saturating window `center ± radius` returned by `predict`.
    fn empirical_coverage(idx: &LearnedRowIdIndex, held_out: &[(i64, u32)]) -> f64 {
        let covered = held_out
            .iter()
            .filter(|&&(r, p)| {
                let (center, radius) = idx.predict(r);
                let lo = center.saturating_sub(radius);
                let hi = center.saturating_add(radius);
                (lo..=hi).contains(&p)
            })
            .count();
        covered as f64 / held_out.len() as f64
    }

    #[test]
    fn fit_uniform_dense_is_low_rmse() {
        let obs = synthetic_dense(10_000, 100);
        let idx = LearnedRowIdIndex::fit(&obs, 0.05);
        let rmse = rmse(&idx, &obs);
        assert!(
            rmse < 50.0,
            "RMSE too high for uniform-dense synthetic: {rmse}"
        );
        println!(
            "fit_uniform_dense_is_low_rmse: slope={} intercept={} rmse={} radius={}",
            idx.slope,
            idx.intercept,
            rmse,
            idx.radius(),
        );
    }

    #[test]
    fn conformal_coverage_meets_guarantee() {
        let mut all = synthetic_dense(10_000, 100);
        // Rotate so the train/test split does not fall on a rowid boundary.
        let rot = 12345 % all.len();
        all.rotate_left(rot);
        let split = (all.len() * 7) / 10;
        let (train, test) = all.split_at(split);
        let alpha = 0.1;
        let idx = LearnedRowIdIndex::fit(train, alpha);
        let coverage = empirical_coverage(&idx, test);
        assert!(
            coverage >= (1.0 - alpha) - 0.03,
            "coverage {coverage} < 1 - alpha - slack = {}",
            1.0 - alpha - 0.03
        );
        println!(
            "conformal_coverage_meets_guarantee: alpha={alpha} coverage={coverage} radius={}",
            idx.radius()
        );
    }

    #[test]
    fn empty_input_degenerates_gracefully() {
        let idx = LearnedRowIdIndex::fit(&[], 0.1);
        assert_eq!(idx.slope, 0.0);
        assert_eq!(idx.intercept, 0.0);
        assert!(idx.residuals_calibration.is_empty());
        assert_eq!(idx.radius(), u32::MAX);
        let (center, radius) = idx.predict(42);
        assert_eq!(center, 0);
        assert_eq!(radius, u32::MAX);
    }

    #[test]
    fn single_point_degenerates_gracefully() {
        let idx = LearnedRowIdIndex::fit(&[(100, 7)], 0.1);
        assert_eq!(idx.slope, 0.0);
        assert!((idx.intercept - 7.0).abs() < 1e-9);
        let (center, _) = idx.predict(100);
        assert_eq!(center, 7);
        assert_eq!(idx.radius(), u32::MAX);
    }

    #[test]
    fn all_same_rowid_collapses_to_mean() {
        let obs: Vec<(i64, u32)> = (0_u32..20).map(|page| (42, page)).collect();
        let idx = LearnedRowIdIndex::fit(&obs, 0.1);
        assert_eq!(idx.slope, 0.0);
        // Only the non-calibration positions take part in the fit, so the
        // expected intercept is the mean page over those positions.
        let fit_y_mean: f64 = {
            let ys: Vec<f64> = obs
                .iter()
                .enumerate()
                .filter(|(i, _)| !matches!(i % 10, 0 | 3 | 7))
                .map(|(_, &(_, p))| f64::from(p))
                .collect();
            ys.iter().sum::<f64>() / ys.len() as f64
        };
        assert!(
            (idx.intercept - fit_y_mean).abs() < 1e-9,
            "intercept {} != fit_y_mean {}",
            idx.intercept,
            fit_y_mean
        );
    }

    #[test]
    fn page_range_saturates_at_u32_bounds() {
        let obs = synthetic_dense(100, 10);
        let idx = LearnedRowIdIndex::fit(&obs, 0.1);
        let range_neg = idx.page_range(i64::MIN);
        assert_eq!(*range_neg.start(), 0);
        let range_pos = idx.page_range(i64::MAX);
        assert_eq!(*range_pos.end(), u32::MAX);
    }

    #[test]
    fn page_range_contains_predict_center() {
        let obs = synthetic_dense(1_000, 50);
        let idx = LearnedRowIdIndex::fit(&obs, 0.1);
        for r in (0..1_000).step_by(17) {
            let (center, _) = idx.predict(r);
            let range = idx.page_range(r);
            assert!(
                range.contains(&center),
                "page_range {:?} does not contain predict center {center} for rowid {r}",
                range
            );
        }
    }

    #[test]
    fn predict_radius_is_global_and_page_range_matches_predict() {
        let obs = synthetic_dense(1_000, 50);
        let idx = LearnedRowIdIndex::fit(&obs, 0.1);
        let r = idx.radius();
        for rowid in [0_i64, 137, 499, 999] {
            let (center, pred_radius) = idx.predict(rowid);
            assert_eq!(pred_radius, r);
            let range = idx.page_range(rowid);
            assert_eq!(*range.start(), center.saturating_sub(r));
            assert_eq!(*range.end(), center.saturating_add(r));
        }
    }

    #[test]
    fn alpha_clamped_out_of_range() {
        let obs = synthetic_dense(1_000, 50);
        let a = LearnedRowIdIndex::fit(&obs, -1.0);
        let b = LearnedRowIdIndex::fit(&obs, 2.0);
        assert!(a.alpha >= 0.0 && a.alpha <= 1.0);
        assert!(b.alpha >= 0.0 && b.alpha <= 1.0);
    }

    proptest! {
        #[test]
        fn prop_conformal_coverage_uniform_dense(
            alpha in 0.05f64..0.3,
            seed in 0u64..1_000,
        ) {
            let mut permuted = synthetic_dense(10_000, 100);
            // Fisher-Yates shuffle driven by a 64-bit LCG, so each seed gives
            // a reproducible permutation without an RNG dependency.
            let mut s = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            for i in (1..permuted.len()).rev() {
                s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
                let j = (s >> 33) as usize % (i + 1);
                permuted.swap(i, j);
            }
            let split = (permuted.len() * 7) / 10;
            let (train, test) = permuted.split_at(split);
            let idx = LearnedRowIdIndex::fit(train, alpha);
            let coverage = empirical_coverage(&idx, test);
            let slack = 0.04;
            prop_assert!(
                coverage >= (1.0 - alpha) - slack,
                "coverage {} < 1 - alpha - slack = {} (alpha={}, seed={})",
                coverage,
                1.0 - alpha - slack,
                alpha,
                seed,
            );
        }
    }
}