use std::sync::OnceLock;
/// Lossy one-byte codec for field lengths, plus BM25 lookup-table builders.
///
/// Lengths `0..=15` are stored exactly. Larger lengths keep only their four most
/// significant bits (an implicit leading 1 followed by a 3-bit mantissa) together
/// with the number of dropped low bits, so decoding an encoded length never
/// yields a value larger than the original.
pub struct FieldnormEncoder;

impl FieldnormEncoder {
    /// Compresses `len` into a single byte.
    ///
    /// Values below 8 map to themselves. Otherwise the byte is
    /// `(exponent << 3) | mantissa`, where `mantissa` is the three bits directly
    /// below the leading 1 bit and `exponent` is the number of discarded low bits
    /// plus one. The largest byte produced is 239 (for `u32::MAX`).
    #[inline]
    pub fn encode(len: u32) -> u8 {
        if len < 8 {
            return len as u8;
        }
        // `len >= 8` has at least four significant bits, so `shift` cannot underflow.
        let num_bits = 32 - len.leading_zeros();
        let shift = num_bits - 4;
        let mantissa = (len >> shift) & 0x07;
        let exponent = shift + 1;
        // `exponent` is at most 29, so the packed value is at most 239 and fits in a byte.
        ((exponent << 3) | mantissa) as u8
    }

    /// Expands a byte produced by [`encode`](Self::encode) back into a length.
    ///
    /// The result is the smallest length that encodes to `byte`. Bytes
    /// `240..=255`, which `encode` never emits, would overflow the shift and
    /// decode to `u32::MAX` instead.
    #[inline]
    pub fn decode(byte: u8) -> u32 {
        let b = u32::from(byte);
        if b < 8 {
            return b;
        }
        let mantissa = b & 0x07;
        let exponent = (b >> 3) - 1;
        if exponent >= 29 {
            // `(mantissa | 8) << 29` would not fit in 32 bits.
            return u32::MAX;
        }
        // Restore the implicit leading 1 above the 3-bit mantissa.
        (mantissa | 8) << exponent
    }

    /// Returns a lazily initialised, process-wide table with
    /// `table[byte] == Self::decode(byte)` for every possible byte.
    pub fn decode_table() -> &'static [u32; 256] {
        static TABLE: OnceLock<[u32; 256]> = OnceLock::new();
        TABLE.get_or_init(|| {
            let mut table = [0u32; 256];
            for (byte, slot) in table.iter_mut().enumerate() {
                *slot = Self::decode(byte as u8);
            }
            table
        })
    }

    /// Precomputes the BM25 length-normalisation term for every fieldnorm byte.
    ///
    /// Entry `byte` holds `k1 * (1 - b + b * decode(byte) / avg_dl)`. When
    /// `avg_dl` is not greater than zero (including NaN) the length ratio is
    /// treated as 0, so every entry becomes `k1 * (1 - b)` instead of dividing
    /// by zero.
    pub fn build_norm_table(avg_dl: f32, k1: f32, b: f32) -> [f32; 256] {
        let inv_avg_dl = if avg_dl > 0.0 { 1.0 / avg_dl } else { 0.0 };
        let mut table = [0.0f32; 256];
        for (byte, slot) in table.iter_mut().enumerate() {
            let dl = Self::decode(byte as u8) as f32;
            *slot = k1 * (1.0 - b + b * dl * inv_avg_dl);
        }
        table
    }

    /// Precomputes BM25 term scores for every `(tf, fieldnorm byte)` pair.
    ///
    /// The result has `max_tf + 1` rows; `table[tf][byte]` holds
    /// `idf * tf * (k1 + 1) / (tf + norm[byte])`, where `norm` comes from
    /// [`build_norm_table`](Self::build_norm_table).
    ///
    /// Row 0 is always exactly `0.0`: a term that does not occur contributes
    /// nothing. Evaluating the formula for `tf = 0` would give `0 / 0 = NaN`
    /// whenever the length norm is zero (for example `b = 1.0` with a
    /// zero-length document, or `k1 = 0`), so that row is never computed.
    pub fn build_score_table(
        avg_dl: f32,
        k1: f32,
        b: f32,
        idf: f32,
        max_tf: u16,
    ) -> Vec<[f32; 256]> {
        let norm_table = Self::build_norm_table(avg_dl, k1, b);
        let k1p1 = k1 + 1.0;
        let mut table = vec![[0.0f32; 256]; usize::from(max_tf) + 1];
        // Skip row 0: it keeps its zero initialisation (see the doc comment).
        for (tf, row) in table.iter_mut().enumerate().skip(1) {
            let tf_f = tf as f32;
            for (score, &len_norm) in row.iter_mut().zip(norm_table.iter()) {
                *score = idf * tf_f * k1p1 / (tf_f + len_norm);
            }
        }
        table
    }

    /// Relative error introduced by an encode/decode round trip of `len`.
    ///
    /// Computed as `(len - decode(encode(len))) / len`; defined as `0.0` for
    /// `len == 0` to avoid dividing by zero.
    #[inline]
    pub fn quantization_error(len: u32) -> f64 {
        if len == 0 {
            return 0.0;
        }
        let decoded = Self::decode(Self::encode(len));
        (f64::from(len) - f64::from(decoded)) / f64::from(len)
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_small_values() {
    // Lengths below 8 are their own encoding and must survive a round trip.
    (0..8u32).for_each(|i| {
        let byte = FieldnormEncoder::encode(i);
        assert_eq!(byte, i as u8, "encode({i}) should be {i}");

        let restored = FieldnormEncoder::decode(byte);
        assert_eq!(restored, i, "decode(encode({i})) should be {i}");
    });
}
#[test]
fn test_exact_values_8_to_15() {
    // Four-bit lengths need no shift, so they must round-trip without loss.
    for i in 8u32..=15 {
        let decoded = FieldnormEncoder::decode(FieldnormEncoder::encode(i));
        assert_eq!(decoded, i, "decode(encode({i})) = {decoded}, expected {i}");
    }
}
#[test]
fn test_known_encodings() {
    // (length, expected byte) pairs, checked in order.
    let encode_cases: [(u32, u8); 7] = [
        (0, 0),
        (1, 1),
        (7, 7),
        (8, 8),
        (15, 15),
        (16, 16),
        (18, 17),
    ];
    for &(len, expected_byte) in &encode_cases {
        assert_eq!(FieldnormEncoder::encode(len), expected_byte);
    }

    // (byte, expected length) pairs, checked in order.
    let decode_cases: [(u8, u32); 4] = [(0, 0), (8, 8), (16, 16), (17, 18)];
    for &(byte, expected_len) in &decode_cases {
        assert_eq!(FieldnormEncoder::decode(byte), expected_len);
    }
}
#[test]
fn test_monotonic_encoding() {
    // Encoded bytes must never decrease as the length grows.
    let mut prev_byte = 0u8;
    let encoded = (0..100_000u32).map(|i| (i, FieldnormEncoder::encode(i)));
    for (i, byte) in encoded {
        assert!(
            prev_byte <= byte,
            "encode({i})={byte} < encode({})={prev_byte} — not monotonic",
            i - 1
        );
        prev_byte = byte;
    }
}
#[test]
fn test_monotonic_decoding() {
    // Decoded lengths must never decrease as the byte value grows.
    let mut prev = 0u32;
    let decoded = (0..=u8::MAX).map(|b| (b, FieldnormEncoder::decode(b)));
    for (b, val) in decoded {
        assert!(
            prev <= val,
            "decode({b})={val} < decode({})={prev} — not monotonic",
            b - 1
        );
        prev = val;
    }
}
#[test]
fn test_roundtrip_lower_bound() {
    // A round trip may lose precision but must never overshoot the input.
    (0..100_000u32).for_each(|i| {
        let byte = FieldnormEncoder::encode(i);
        let decoded = FieldnormEncoder::decode(byte);
        assert!(
            i >= decoded,
            "decode(encode({i})) = {decoded} > {i} — not a lower bound"
        );
    });
}
#[test]
fn test_quantization_error_bounded() {
    // From 16 upwards low bits get dropped; the relative loss must stay under 15%.
    (16..100_000u32).for_each(|i| {
        let error = FieldnormEncoder::quantization_error(i);
        assert!(
            error < 0.15,
            "quantization_error({i}) = {error:.4} exceeds 15%"
        );
    });
}
#[test]
fn test_full_byte_range_used() {
    // Mark every byte value that some length below ten million maps to.
    let mut seen = [false; 256];
    (0..10_000_000u32)
        .map(FieldnormEncoder::encode)
        .for_each(|byte| seen[usize::from(byte)] = true);

    let used = seen.iter().filter(|hit| **hit).count();
    assert!(used >= 150, "only {used} distinct byte values used — expected >=150");

    // The largest representable length pins the top of the byte range.
    let max_byte = FieldnormEncoder::encode(u32::MAX);
    assert_eq!(max_byte, 239, "max byte should be 239 for u32::MAX");
}
#[test]
fn test_large_values() {
    // A large length decodes to something no bigger, and more than half of itself.
    let large = 1_000_000u32;
    let decoded = FieldnormEncoder::decode(FieldnormEncoder::encode(large));
    assert!(decoded <= large);
    assert!(decoded > large / 2, "decoded {decoded} too far from {large}");

    // The extreme input must round-trip to a non-zero value.
    let decoded_max = FieldnormEncoder::decode(FieldnormEncoder::encode(u32::MAX));
    assert!(decoded_max > 0);
}
#[test]
fn test_decode_table() {
    // Every cached entry must agree with a direct decode of the same byte.
    let table = FieldnormEncoder::decode_table();
    for (idx, &entry) in table.iter().enumerate() {
        let b = idx as u8;
        assert_eq!(
            entry,
            FieldnormEncoder::decode(b),
            "decode_table mismatch at byte {b}"
        );
    }
}
#[test]
fn test_norm_table_basic() {
    let (avg_dl, k1, b) = (100.0f32, 1.2f32, 0.75f32);
    let table = FieldnormEncoder::build_norm_table(avg_dl, k1, b);

    // A zero-length document leaves only the k1 * (1 - b) part.
    let zero_len_norm = table[0];
    assert!((zero_len_norm - 0.3).abs() < 1e-6, "table[0] = {}", zero_len_norm);

    // The entry for the average length must match the formula evaluated by hand
    // on the decoded (quantised) length.
    let avg_byte = FieldnormEncoder::encode(100);
    let avg_decoded = FieldnormEncoder::decode(avg_byte);
    let expected = k1 * (1.0 - b + b * avg_decoded as f32 / avg_dl);
    let actual = table[avg_byte as usize];
    assert!(
        (actual - expected).abs() < 1e-5,
        "table[avg_byte] = {}, expected {expected}",
        actual
    );

    // No entry may be negative.
    for (i, v) in table.iter().copied().enumerate() {
        assert!(v >= 0.0, "norm_table[{i}] = {v} is negative");
    }

    // Entries must be non-decreasing in the byte value.
    for (prev_idx, pair) in table.windows(2).enumerate() {
        let i = prev_idx + 1;
        assert!(
            pair[1] >= pair[0],
            "norm_table[{i}]={} < norm_table[{}]={}",
            pair[1],
            prev_idx,
            pair[0]
        );
    }
}
#[test]
fn test_norm_table_avg_dl_zero() {
    // An average length of zero must not leak infinities or NaNs into the table.
    let table = FieldnormEncoder::build_norm_table(0.0, 1.2, 0.75);
    table.iter().enumerate().for_each(|(i, v)| {
        assert!(v.is_finite(), "norm_table[{i}] = {v} is not finite");
    });
}
#[test]
fn test_score_table() {
    let (avg_dl, k1, b, idf) = (100.0f32, 1.2f32, 0.75f32, 5.0f32);
    let max_tf: u16 = 10;
    let table = FieldnormEncoder::build_score_table(avg_dl, k1, b, idf, max_tf);

    // One row per term frequency in 0..=max_tf.
    assert_eq!(table.len(), 11);

    // A term frequency of zero scores nothing, whatever the length.
    for (byte, &score) in table[0].iter().enumerate() {
        assert_eq!(score, 0.0, "score(tf=0, byte={byte}) should be 0");
    }

    // Scores for a single occurrence are never negative.
    (0..256usize).for_each(|byte| {
        assert!(table[1][byte] >= 0.0);
    });

    // For a fixed length, the score never drops when tf goes up by one.
    for byte in 0..256usize {
        for (offset, rows) in table[1..].windows(2).enumerate() {
            let tf = offset + 1;
            assert!(
                rows[1][byte] >= rows[0][byte],
                "score(tf={}, byte={byte}) < score(tf={tf}, byte={byte})",
                tf + 1
            );
        }
    }
}
#[test]
fn test_bm25_scoring_workflow() {
    // Index time: each document length is stored as a single fieldnorm byte.
    let doc_lengths = vec![50u32, 100, 150, 200, 300];
    let mut fieldnorm_bytes: Vec<u8> = Vec::with_capacity(doc_lengths.len());
    for &len in &doc_lengths {
        fieldnorm_bytes.push(FieldnormEncoder::encode(len));
    }

    // Query time: build the norm table once from the collection statistics.
    let total_len: u32 = doc_lengths.iter().sum();
    let avg_dl: f32 = total_len as f32 / doc_lengths.len() as f32;
    let (k1, b, idf) = (1.2f32, 0.75f32, 3.5f32);
    let norm_table = FieldnormEncoder::build_norm_table(avg_dl, k1, b);
    let k1p1 = k1 + 1.0;

    // Score every document for a term that occurs twice in it.
    let tf = 2.0f32;
    let mut scores: Vec<f32> = Vec::with_capacity(fieldnorm_bytes.len());
    for &byte in &fieldnorm_bytes {
        let len_norm = norm_table[usize::from(byte)];
        scores.push(idf * tf * k1p1 / (tf + len_norm));
    }

    assert!(scores[0] > scores[4], "shorter doc should score higher");
    for (i, s) in scores.iter().copied().enumerate() {
        assert!(s > 0.0 && s.is_finite(), "score[{i}] = {s}");
    }
}
// Timing harness for the codec: measures `encode`, `decode`, and lookups in the
// cached decode table over ten million iterations each and prints the three
// durations to stderr. Nothing is asserted about the timings themselves.
// Only compiled when debug assertions are disabled (see the `cfg` below).
#[test]
#[cfg(not(debug_assertions))]
fn test_encode_decode_performance() {
use std::time::Instant;
let n = 10_000_000u32;
// Phase 1: encode every value in 0..n, accumulating the bytes.
let start = Instant::now();
let mut sum = 0u64;
for i in 0..n {
sum += FieldnormEncoder::encode(i) as u64;
}
let encode_time = start.elapsed();
// NOTE(review): presumably the sum checks exist so the loop results are
// observed and the work cannot be discarded by the optimizer — confirm.
assert!(sum > 0);
// Phase 2: decode the low byte of each counter value directly.
let start = Instant::now();
let mut sum2 = 0u64;
for i in 0..n {
sum2 += FieldnormEncoder::decode((i & 0xFF) as u8) as u64;
}
let decode_time = start.elapsed();
assert!(sum2 > 0);
// Phase 3: the same byte sequence, resolved through the precomputed table.
// The table is fetched before the timer starts, so its one-time
// initialisation is not part of the measurement.
let table = FieldnormEncoder::decode_table();
let start = Instant::now();
let mut sum3 = 0u64;
for i in 0..n {
sum3 += table[(i & 0xFF) as usize] as u64;
}
let table_time = start.elapsed();
assert!(sum3 > 0);
eprintln!(
"FieldnormEncoder perf ({n} ops): encode={:?}, decode={:?}, table_lookup={:?}",
encode_time, decode_time, table_time
);
}
// Timing and memory comparison of two ways to obtain the BM25 length norm for
// one million documents: (a) a `Vec<u8>` of fieldnorm bytes indexed into the
// precomputed norm table, and (b) lengths read from a `UintVecMin0` with the
// norm computed in floating point per document. Results are printed to stderr;
// the only assertion is that both accumulated sums are positive.
// Only compiled when debug assertions are disabled (see the `cfg` below).
#[test]
#[cfg(not(debug_assertions))]
fn test_norm_table_vs_float_math() {
use std::time::Instant;
use crate::containers::UintVecMin0;
let n = 1_000_000usize;
let k1 = 1.2f32;
let b = 0.75f32;
let avg_dl = 150.0f32;
// Synthetic lengths cycling through 1..=500.
let doc_lengths: Vec<u32> = (0..n).map(|i| (i % 500 + 1) as u32).collect();
let fieldnorm_bytes: Vec<u8> = doc_lengths.iter().map(|&l| FieldnormEncoder::encode(l)).collect();
let max_val = *doc_lengths.iter().max().unwrap() as usize;
// NOTE(review): `UintVecMin0` is defined elsewhere in the crate; presumably
// `new(len, max_value)` sizes a packed integer vector — confirm against its docs.
let mut uint_vec = UintVecMin0::new(n, max_val);
for (i, &dl) in doc_lengths.iter().enumerate() {
uint_vec.set(i, dl as usize);
}
let norm_table = FieldnormEncoder::build_norm_table(avg_dl, k1, b);
// Variant (a): byte lookup + table lookup per document.
let start = Instant::now();
let mut sum1 = 0.0f64;
for i in 0..n {
let tf = 2.0f32;
let len_norm = norm_table[fieldnorm_bytes[i] as usize];
// 2.2 is presumably k1 + 1 written out for k1 = 1.2.
sum1 += (tf * 2.2 / (tf + len_norm)) as f64;
}
let table_time = start.elapsed();
// Variant (b): read the stored length and evaluate the norm formula each time.
// This path uses the unquantised lengths, so sum2 is not expected to equal sum1.
let start = Instant::now();
let mut sum2 = 0.0f64;
for i in 0..n {
let tf = 2.0f32;
let dl = uint_vec.get(i) as f32;
let len_norm = k1 * (1.0 - b + b * dl / avg_dl);
sum2 += (tf * 2.2 / (tf + len_norm)) as f64;
}
let uint_time = start.elapsed();
// `.max(1)` guards the speedup ratio against a zero-nanosecond measurement.
eprintln!(
"BM25 scoring ({n} docs): Vec<u8>+table={:?}, UintVecMin0+float={:?}, speedup={:.1}x",
table_time, uint_time,
uint_time.as_nanos() as f64 / table_time.as_nanos().max(1) as f64
);
// One byte per document for the fieldnorm vector; the other figure is whatever
// `UintVecMin0::mem_size` reports.
let fieldnorm_mem = fieldnorm_bytes.len(); let uint_mem = uint_vec.mem_size();
eprintln!(
"Memory: Vec<u8>={} bytes ({:.2} B/doc), UintVecMin0={} bytes ({:.2} B/doc)",
fieldnorm_mem, fieldnorm_mem as f64 / n as f64,
uint_mem, uint_mem as f64 / n as f64
);
assert!(sum1 > 0.0 && sum2 > 0.0);
}
}