use crate::error::{Error, Result};
use super::index::LshIndex;
/// Plain-data configuration for constructing an [`LshIndex`].
///
/// Holds the banding parameters that [`LshIndexBuilder::build`] and
/// [`LshIndexBuilder::try_build`] forward to the index constructor.
#[derive(Debug, Clone, Copy)]
pub struct LshIndexBuilder {
    /// Number of bands.
    pub bands: usize,
    /// Number of rows per band; `bands * rows` is expected to equal the
    /// index's `H` parameter when building.
    pub rows: usize,
}
impl LshIndexBuilder {
#[must_use]
pub fn new(bands: usize, rows: usize) -> Self {
Self { bands, rows }
}
pub fn for_threshold(threshold: f32, h: usize) -> Result<Self> {
if !(threshold > 0.0 && threshold < 1.0) {
return Err(Error::Config(alloc::format!(
"threshold must be in (0.0, 1.0); got {threshold}"
)));
}
if h == 0 {
return Err(Error::Config("H cannot be zero".into()));
}
let mut best = None;
let mut best_err = f32::INFINITY;
for b in 1..=h {
if h % b != 0 {
continue;
}
let r = h / b;
let err = fp_rate(threshold, b, r) + fn_rate(threshold, b, r);
if err < best_err {
best_err = err;
best = Some((b, r));
}
}
match best {
Some((bands, rows)) => Ok(Self { bands, rows }),
None => Err(Error::Config(alloc::format!(
"no factor pair found for H={h}"
))),
}
}
pub fn build<const H: usize>(self) -> LshIndex<H> {
self.try_build()
.expect("bands * rows must equal H; use try_build for a Result")
}
pub fn try_build<const H: usize>(self) -> Result<LshIndex<H>> {
LshIndex::with_bands_rows(self.bands, self.rows)
}
}
/// Collision probability `1 - (1 - t^r)^b` for similarity `t` with `b` bands
/// of `r` rows each.
///
/// The inner term is clamped to `[0.0, 1.0]` so rounding (or a `t` slightly
/// outside the unit interval) cannot push the result out of range.
///
/// `powi` takes an `i32` exponent. A plain `as i32` cast wraps counts above
/// `i32::MAX` into negative exponents, which turns the power into a
/// reciprocal and can even yield a negative "probability". The exponents are
/// therefore saturated at `i32::MAX` instead.
#[inline]
fn prob_match(t: f32, b: usize, r: usize) -> f32 {
    // Saturating conversion: after `min` the value always fits in an `i32`.
    let rows_exp = r.min(i32::MAX as usize) as i32;
    let bands_exp = b.min(i32::MAX as usize) as i32;

    // Probability that a single band does NOT match.
    let band_miss = (1.0 - t.powi(rows_exp)).clamp(0.0, 1.0);
    // At least one of the `b` bands matches.
    1.0 - band_miss.powi(bands_exp)
}
/// Integral of `prob_match` over similarities from `0.0` up to `threshold`.
///
/// `for_threshold` uses this as the false-positive half of its error score:
/// the collision probability accumulated below the threshold.
#[inline]
fn fp_rate(threshold: f32, b: usize, r: usize) -> f32 {
    let collides = |similarity: f32| prob_match(similarity, b, r);
    integrate(0.0, threshold, collides)
}
/// Integral of `1 - prob_match` over similarities from `threshold` up to `1.0`.
///
/// `for_threshold` uses this as the false-negative half of its error score:
/// the miss probability accumulated above the threshold.
#[inline]
fn fn_rate(threshold: f32, b: usize, r: usize) -> f32 {
    let misses = |similarity: f32| 1.0 - prob_match(similarity, b, r);
    integrate(threshold, 1.0, misses)
}
/// Approximates the integral of `f` over `[lo, hi]` with the composite
/// trapezoid rule on 200 equal sub-intervals.
///
/// Returns `0.0` without calling `f` when `hi <= lo`. Otherwise `f` is
/// evaluated at `lo`, then `hi`, then at each interior sample in ascending
/// order.
fn integrate<F>(lo: f32, hi: f32, mut f: F) -> f32
where
    F: FnMut(f32) -> f32,
{
    const N: usize = 200;

    if hi <= lo {
        return 0.0;
    }

    let step = (hi - lo) / N as f32;
    // The two endpoints carry half weight; interior samples carry full weight.
    let endpoints = 0.5 * (f(lo) + f(hi));
    let weighted_total = (1..N).fold(endpoints, |acc, k| acc + f(lo + k as f32 * step));
    weighted_total * step
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores both parameters exactly as given.
    #[test]
    fn builder_new_round_trip() {
        let LshIndexBuilder { bands, rows } = LshIndexBuilder::new(16, 8);
        assert_eq!(bands, 16);
        assert_eq!(rows, 8);
    }

    /// The chosen pair always multiplies back to the requested `h`.
    #[test]
    fn for_threshold_finds_factor_pair() {
        let chosen = LshIndexBuilder::for_threshold(0.7, 128).unwrap();
        assert_eq!(chosen.bands * chosen.rows, 128);
    }

    /// Thresholds on or outside the boundaries of `(0, 1)` are rejected.
    #[test]
    fn for_threshold_rejects_extremes() {
        for bad_threshold in [0.0_f32, 1.0, -0.1, 1.1] {
            assert!(LshIndexBuilder::for_threshold(bad_threshold, 128).is_err());
        }
    }

    /// A zero `h` is a configuration error.
    #[test]
    fn for_threshold_rejects_zero_h() {
        let outcome = LshIndexBuilder::for_threshold(0.5, 0);
        assert!(outcome.is_err());
    }

    /// Raising the threshold never reduces the number of rows per band.
    #[test]
    fn higher_threshold_picks_more_rows() {
        let b_low = LshIndexBuilder::for_threshold(0.3, 128).unwrap();
        let b_high = LshIndexBuilder::for_threshold(0.9, 128).unwrap();
        assert!(b_high.rows >= b_low.rows, "{b_high:?} vs {b_low:?}");
    }

    /// 16 * 8 != 64, so building an `LshIndex<64>` must fail with `Config`.
    #[test]
    fn try_build_rejects_mismatched_h() {
        let outcome = LshIndexBuilder::new(16, 8).try_build::<64>();
        assert!(matches!(outcome, Err(Error::Config(_))));
    }

    /// 16 * 8 == 128, so building an `LshIndex<128>` succeeds.
    #[test]
    fn try_build_succeeds_when_h_matches() {
        let outcome = LshIndexBuilder::new(16, 8).try_build::<128>();
        assert!(outcome.is_ok());
    }

    /// Zero similarity never collides.
    #[test]
    fn prob_match_is_zero_at_zero_jaccard() {
        let p = prob_match(0.0, 16, 8);
        assert!(p.abs() < 1e-6, "got {p}");
    }

    /// Full similarity always collides.
    #[test]
    fn prob_match_is_one_at_one_jaccard() {
        let p = prob_match(1.0, 16, 8);
        assert!((p - 1.0).abs() < 1e-6, "got {p}");
    }

    /// The collision probability strictly increases with similarity.
    #[test]
    fn prob_match_is_monotone() {
        let low = prob_match(0.3, 16, 8);
        let mid = prob_match(0.5, 16, 8);
        let high = prob_match(0.8, 16, 8);
        assert!(low < mid && mid < high);
    }
}