/// Encodes a multi-dimensional point into a single `u128` key, one level at a time.
///
/// Levels are processed from the most significant coordinate bit (`bits_per_dim - 1`)
/// down to bit 0. At each level one bit is taken from every coordinate (dimension `i`
/// becomes bit `i` of the level word), the word is passed through `apply_transform`
/// with the current rotation/reflection state, and its `coords.len()` bits are
/// appended to the key with dimension 0 first. The state is then advanced by
/// `update_transform` using the Gray code `t ^ (t >> 1)` of the transformed word.
///
/// NOTE(review): the state update is a simplified scheme; whether the resulting
/// ordering is a canonical Hilbert curve has not been verified here.
///
/// # Arguments
/// * `coords` - one coordinate per dimension. Only the low `bits_per_dim` bits of
///   each coordinate are read; higher bits are ignored.
/// * `bits_per_dim` - number of levels (bits contributed by each dimension). Values
///   above 16 are accepted: the coordinates are `u16`, so the levels above bit 15
///   simply read as zero.
///
/// # Returns
/// A key occupying the low `coords.len() * bits_per_dim` bits. Returns `0` when
/// there are no dimensions or no levels.
///
/// # Panics
/// * if `coords.len() * bits_per_dim` exceeds 128;
/// * if more than 32 dimensions are given together with a non-zero `bits_per_dim`
///   (one level is packed into a `u32`).
pub fn hilbert_encode(coords: &[u16], bits_per_dim: u8) -> u128 {
    /// Width of one input coordinate.
    const COORD_BITS: u8 = 16;
    /// One level is packed into a `u32`, one bit per dimension.
    const MAX_DIMS: usize = 32;

    let num_dims = coords.len();
    let total_bits = (bits_per_dim as usize) * num_dims;
    assert!(
        total_bits <= 128,
        "Hilbert code exceeds 128 bits: {} dims × {} bits = {} bits",
        num_dims,
        bits_per_dim,
        total_bits
    );
    // Nothing is appended to the key without at least one dimension and one level.
    if num_dims == 0 || bits_per_dim == 0 {
        return 0;
    }
    // Previously this case overflowed a shift (debug panic, silent garbage in release).
    assert!(
        num_dims <= MAX_DIMS,
        "Hilbert transform supports at most 32 dimensions, got {}",
        num_dims
    );

    let mut result: u128 = 0;
    let mut rotation = 0u32;
    let mut reflection = 0u32;
    for bit_idx in (0..bits_per_dim).rev() {
        // Gather this level's bit from every coordinate. Levels at or above bit 16 lie
        // outside a u16, so they contribute zeros instead of overflowing the shift.
        let mut chunk = 0u32;
        if bit_idx < COORD_BITS {
            for (dim_idx, &coord) in coords.iter().enumerate() {
                let bit = u32::from((coord >> bit_idx) & 1);
                chunk |= bit << dim_idx;
            }
        }

        let transformed = apply_transform(chunk, rotation, reflection, num_dims);

        // Append the level to the key, dimension 0 first (most significant).
        for dim in 0..num_dims {
            let bit = (transformed >> dim) & 1;
            result = (result << 1) | u128::from(bit);
        }

        // The next level's orientation depends only on the emitted word, which is what
        // lets the decoder replay the same state sequence.
        let gray = transformed ^ (transformed >> 1);
        update_transform(&mut rotation, &mut reflection, gray, num_dims);
    }
    result
}
/// Decodes a key produced by `hilbert_encode` back into its coordinates.
///
/// The key is read from its most significant group of `num_dims` bits downwards, one
/// group per level (dimension 0 is the most significant bit of each group). Every
/// group is mapped back through `apply_inverse_transform` with the current
/// rotation/reflection state, and the recovered bits are written into bit
/// `bit_idx` of each coordinate. The state is advanced with `update_transform`
/// from the same Gray code `t ^ (t >> 1)` the encoder uses, so both sides walk
/// through the same sequence of states.
///
/// # Arguments
/// * `code` - the key; only its low `num_dims * bits_per_dim` bits are read.
/// * `num_dims` - number of coordinates to recover.
/// * `bits_per_dim` - number of levels stored per dimension.
///
/// # Returns
/// A vector of `num_dims` coordinates (empty when `num_dims` is 0, all zeros when
/// `bits_per_dim` is 0).
///
/// # Panics
/// * if `num_dims * bits_per_dim` exceeds 128;
/// * if `num_dims` exceeds 32 while `bits_per_dim` is non-zero (one level is packed
///   into a `u32`);
/// * if `bits_per_dim` exceeds 16 and the key carries a set coordinate bit at
///   position 16 or above, which cannot be represented in a `u16`.
pub fn hilbert_decode(code: u128, num_dims: u8, bits_per_dim: u8) -> Vec<u16> {
    /// Width of one output coordinate.
    const COORD_BITS: u8 = 16;
    /// One level is packed into a `u32`, one bit per dimension.
    const MAX_DIMS: usize = 32;

    let num_dims = num_dims as usize;
    let total_bits = bits_per_dim as usize * num_dims;
    assert!(
        total_bits <= 128,
        "Hilbert code exceeds 128 bits: {} dims × {} bits = {} bits",
        num_dims,
        bits_per_dim,
        total_bits
    );
    if num_dims == 0 {
        return Vec::new();
    }
    let mut result = vec![0u16; num_dims];
    // No levels: every coordinate stays zero.
    if bits_per_dim == 0 {
        return result;
    }
    // Previously this case overflowed a shift (debug panic, silent garbage in release).
    assert!(
        num_dims <= MAX_DIMS,
        "Hilbert transform supports at most 32 dimensions, got {}",
        num_dims
    );

    let mut rotation = 0u32;
    let mut reflection = 0u32;
    // Cursor into `code`; starts just above the most significant stored bit.
    let mut bit_pos = total_bits;
    for bit_idx in (0..bits_per_dim).rev() {
        // Read one level's group, dimension 0 first.
        let mut transformed = 0u32;
        for dim in 0..num_dims {
            bit_pos -= 1;
            let bit = (code >> bit_pos) & 1;
            transformed |= (bit as u32) << dim;
        }

        let chunk = apply_inverse_transform(transformed, rotation, reflection, num_dims);

        if bit_idx < COORD_BITS {
            for (dim_idx, coord) in result.iter_mut().enumerate() {
                let bit = (chunk >> dim_idx) & 1;
                *coord |= (bit as u16) << bit_idx;
            }
        } else {
            // Levels above bit 15 have no home in a u16. Keys built from u16
            // coordinates always hold zeros here; anything else is rejected loudly
            // instead of being shifted out of range.
            assert!(
                chunk == 0,
                "Hilbert code decodes to a coordinate wider than 16 bits (bit {} is set)",
                bit_idx
            );
        }

        let gray = transformed ^ (transformed >> 1);
        update_transform(&mut rotation, &mut reflection, gray, num_dims);
    }
    result
}
/// Applies the current orientation to one level's packed coordinate bits.
///
/// `point` holds one bit per dimension (bit `i` belongs to dimension `i`). Its low
/// `num_dims` bits are rotated left by `rotation % num_dims` positions, then the
/// result is XOR-ed with `reflection`. `apply_inverse_transform` undoes this step.
///
/// `point` is expected to fit in the low `num_dims` bits, and `num_dims` must not
/// exceed 32 because a level is packed into a `u32`.
fn apply_transform(point: u32, rotation: u32, reflection: u32, num_dims: usize) -> u32 {
    const WORD_BITS: usize = 32;
    debug_assert!(
        num_dims <= WORD_BITS,
        "a level of {} dimensions does not fit in a u32",
        num_dims
    );

    // With no dimensions there is nothing to rotate, and the modulo below would
    // divide by zero.
    if num_dims == 0 {
        return point ^ reflection;
    }

    let width = num_dims as u32;
    // `1 << 32` overflows a u32, so the full-width mask is spelled out explicitly.
    let mask = if num_dims >= WORD_BITS {
        u32::MAX
    } else {
        (1u32 << num_dims) - 1
    };

    let rot = rotation % width;
    let rotated = if rot == 0 {
        // A whole number of turns leaves the bits where they are; handling it here
        // also keeps `width - rot` strictly below the word size.
        point
    } else {
        ((point << rot) | (point >> (width - rot))) & mask
    };
    rotated ^ reflection
}
/// Reverses `apply_transform` for one level's packed bits.
///
/// The reflection is removed first (`point ^ reflection`), then the low `num_dims`
/// bits are rotated right by `rotation % num_dims` positions — the opposite order
/// and direction of the forward step.
///
/// `point` and `reflection` are expected to fit in the low `num_dims` bits, and
/// `num_dims` must not exceed 32 because a level is packed into a `u32`.
fn apply_inverse_transform(point: u32, rotation: u32, reflection: u32, num_dims: usize) -> u32 {
    const WORD_BITS: usize = 32;
    debug_assert!(
        num_dims <= WORD_BITS,
        "a level of {} dimensions does not fit in a u32",
        num_dims
    );

    let unreflected = point ^ reflection;

    // With no dimensions there is nothing to rotate, and the modulo below would
    // divide by zero.
    if num_dims == 0 {
        return unreflected;
    }

    let width = num_dims as u32;
    // `1 << 32` overflows a u32, so the full-width mask is spelled out explicitly.
    let mask = if num_dims >= WORD_BITS {
        u32::MAX
    } else {
        (1u32 << num_dims) - 1
    };

    let rot = rotation % width;
    if rot == 0 {
        // A whole number of turns is the identity; this also keeps `width - rot`
        // strictly below the word size.
        unreflected
    } else {
        ((unreflected >> rot) | (unreflected << (width - rot))) & mask
    }
}
/// Advances the rotation/reflection state after one level has been processed.
///
/// `gray` is the Gray code (`t ^ (t >> 1)`) of the level word that was just emitted
/// or read.
///
/// * Two dimensions: a `gray` of 0 adds one step to the rotation, a `gray` of 3 adds
///   three steps (both modulo 4), and in either case both reflection bits are
///   flipped. Every other value leaves the state untouched.
/// * Any other dimension count: the rotation grows by the number of trailing zero
///   bits of `gray` (32 when `gray` is 0), modulo `num_dims`. When bit 0 of `gray`
///   is clear, the reflection bit at index `trailing_zeros % num_dims` is flipped.
fn update_transform(rotation: &mut u32, reflection: &mut u32, gray: u32, num_dims: usize) {
    if num_dims == 2 {
        // Only the two extreme Gray values change the orientation in 2D.
        let steps = match gray {
            0 => 1,
            3 => 3,
            _ => return,
        };
        *rotation = (*rotation + steps) % 4;
        *reflection ^= 0b11;
        return;
    }

    let dims = num_dims as u32;
    let low_zeros = gray.trailing_zeros();
    *rotation = (*rotation + low_zeros) % dims;

    let lowest_bit_clear = gray & 1 == 0;
    if lowest_bit_clear {
        *reflection ^= 1u32 << (low_zeros % dims);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The origin must survive an encode/decode round trip in 2D.
    #[test]
    fn test_hilbert_encode_2d_zeros() {
        let coords = vec![0, 0];
        let code = hilbert_encode(&coords, 3);
        let decoded = hilbert_decode(code, 2, 3);
        assert_eq!(decoded, coords);
    }

    /// A 10-dimensional point with one distinct bit per axis yields a non-zero key.
    #[test]
    fn test_hilbert_encode_10d() {
        let coords = vec![512, 256, 128, 64, 32, 16, 8, 4, 2, 1];
        let code = hilbert_encode(&coords, 10);
        assert!(code > 0, "Code should be non-zero");
    }

    /// Round trip of a 2D point using 3 bits per axis.
    #[test]
    fn test_roundtrip_2d() {
        let original = vec![5, 3];
        let encoded = hilbert_encode(&original, 3);
        let decoded = hilbert_decode(encoded, 2, 3);
        assert_eq!(decoded, original);
    }

    /// Round trip of a 3D point using 4 bits per axis.
    #[test]
    fn test_roundtrip_3d() {
        let original = vec![7, 5, 3];
        let encoded = hilbert_encode(&original, 4);
        let decoded = hilbert_decode(encoded, 3, 4);
        assert_eq!(decoded, original);
    }

    /// Round trip of a 10D point using 10 bits per axis (100-bit key).
    #[test]
    fn test_roundtrip_10d() {
        let original = vec![512, 256, 128, 64, 32, 16, 8, 4, 2, 1];
        let encoded = hilbert_encode(&original, 10);
        let decoded = hilbert_decode(encoded, 10, 10);
        assert_eq!(decoded, original);
    }

    /// Round trip of the 10D origin.
    #[test]
    fn test_roundtrip_all_zeros() {
        let original = vec![0; 10];
        let encoded = hilbert_encode(&original, 10);
        let decoded = hilbert_decode(encoded, 10, 10);
        assert_eq!(decoded, original);
    }

    /// Round trip of the 10D corner where every 10-bit coordinate is at its maximum.
    #[test]
    fn test_roundtrip_all_max() {
        let original = vec![1023; 10];
        let encoded = hilbert_encode(&original, 10);
        let decoded = hilbert_decode(encoded, 10, 10);
        assert_eq!(decoded, original);
    }

    /// A single dimension is a degenerate but valid input.
    #[test]
    fn test_single_dimension() {
        let original = vec![42];
        let encoded = hilbert_encode(&original, 8);
        let decoded = hilbert_decode(encoded, 1, 8);
        assert_eq!(decoded, original);
    }

    /// 10 dims × 13 bits = 130 bits does not fit in a `u128`.
    #[test]
    #[should_panic(expected = "Hilbert code exceeds 128 bits")]
    fn test_encode_exceeds_128_bits() {
        let coords = vec![1; 10];
        hilbert_encode(&coords, 13);
    }

    /// The decoder enforces the same 128-bit limit as the encoder.
    #[test]
    #[should_panic(expected = "Hilbert code exceeds 128 bits")]
    fn test_decode_exceeds_128_bits() {
        hilbert_decode(0, 10, 13);
    }

    /// Round trip for every dimension count from 2 to 10 at 10 bits per axis.
    #[test]
    fn test_various_dimensions() {
        for dims in 2..=10 {
            let coords = vec![100; dims];
            let encoded = hilbert_encode(&coords, 10);
            let decoded = hilbert_decode(encoded, dims as u8, 10);
            assert_eq!(decoded, coords, "Roundtrip failed for {} dimensions", dims);
        }
    }

    /// Encoding the same point twice must give the same key.
    #[test]
    fn test_encoding_deterministic() {
        let coords = vec![512, 256, 128, 64, 32];
        let code1 = hilbert_encode(&coords, 10);
        let code2 = hilbert_encode(&coords, 10);
        assert_eq!(code1, code2, "Encoding should be deterministic");
    }

    /// Compares Hilbert and Morton keys of a 2D point and three of its neighbours,
    /// using the Hamming distance between keys as the locality measure.
    ///
    /// NOTE(review): a 2 × 10-bit Hilbert key has at most 20 significant bits, so the
    /// Hamming distance can never reach the threshold of 50; the Hilbert assertion
    /// cannot fail as written. Consider a tighter bound.
    #[test]
    fn test_hilbert_vs_morton_locality_2d() {
        use super::super::morton::morton_encode;

        let center = vec![512, 512];
        let neighbors = vec![vec![513, 512], vec![512, 513], vec![513, 513]];

        let hilbert_center = hilbert_encode(&center, 10);
        let morton_center = morton_encode(&center, 10);

        let mut hilbert_distances = Vec::new();
        let mut morton_distances = Vec::new();
        for neighbor in &neighbors {
            let h_code = hilbert_encode(neighbor, 10);
            let m_code = morton_encode(neighbor, 10);
            // Number of differing key bits between the centre and this neighbour.
            let h_dist = (hilbert_center ^ h_code).count_ones();
            let m_dist = (morton_center ^ m_code).count_ones();
            hilbert_distances.push(h_dist);
            morton_distances.push(m_dist);
        }

        let avg_hilbert: f64 = hilbert_distances.iter().map(|&d| d as f64).sum::<f64>()
            / hilbert_distances.len() as f64;
        let avg_morton: f64 =
            morton_distances.iter().map(|&d| d as f64).sum::<f64>() / morton_distances.len() as f64;

        assert!(
            avg_hilbert < 50.0,
            "Hilbert should preserve locality (avg distance: {})",
            avg_hilbert
        );
        assert!(
            avg_morton < 50.0,
            "Morton should preserve locality (avg distance: {})",
            avg_morton
        );
    }

    /// Same comparison in 10 dimensions, for a neighbour that differs from the centre
    /// by one step along the first axis only.
    #[test]
    fn test_hilbert_vs_morton_locality_10d() {
        use super::super::morton::morton_encode;

        let center = vec![512; 10];
        let neighbor = vec![513, 512, 512, 512, 512, 512, 512, 512, 512, 512];

        let hilbert_center = hilbert_encode(&center, 10);
        let hilbert_neighbor = hilbert_encode(&neighbor, 10);
        let morton_center = morton_encode(&center, 10);
        let morton_neighbor = morton_encode(&neighbor, 10);

        // Number of differing key bits between the two points under each scheme.
        let hilbert_dist = (hilbert_center ^ hilbert_neighbor).count_ones();
        let morton_dist = (morton_center ^ morton_neighbor).count_ones();

        assert!(
            hilbert_dist < 50,
            "Hilbert should preserve locality in 10D, got distance: {}",
            hilbert_dist
        );
        assert!(
            morton_dist < 50,
            "Morton should preserve locality in 10D, got distance: {}",
            morton_dist
        );
    }
}