use crate::error::{EmbedVecError, Result};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
/// `1 / (2φ)` with `φ = (1 + √5) / 2` (the golden ratio), rounded to `f32`.
const INV_2PHI: f32 = 0.309_017_0;
/// `φ / 2` with `φ = (1 + √5) / 2` (the golden ratio), rounded to `f32`.
const HALF_PHI: f32 = 0.809_016_9;
/// Number of entries in the vertex table built by `generate_h4_vertices`
/// (8 axis vertices + 16 half-integer vertices + 96 golden-ratio vertices).
pub const H4_NUM_VERTICES: usize = 120;
/// Number of consecutive vector components quantized together as one block.
pub const H4_BLOCK_SIZE: usize = 4;
/// Lossy vector codec that replaces every block of `H4_BLOCK_SIZE` floats by
/// the index of a vertex from the table built by `generate_h4_vertices`,
/// plus one shared `f32` scale per vector.
#[derive(Debug, Clone)]
pub struct H4Codec {
/// Length of the vectors accepted by `encode` and produced by `decode`.
dimension: usize,
/// `dimension` divided by `H4_BLOCK_SIZE`, rounded up.
num_blocks: usize,
/// Whether a 4-point Hadamard transform is applied to each block.
use_hadamard: bool,
/// One `+1.0` / `-1.0` multiplier per padded component
/// (`num_blocks * H4_BLOCK_SIZE` entries), drawn from a seeded ChaCha8 RNG.
random_signs: Vec<f32>,
/// Vertex table used as the quantization codebook.
vertices: Vec<[f32; 4]>,
}
impl H4Codec {
    /// Builds a codec for vectors of length `dimension`.
    ///
    /// * `use_hadamard` – apply a 4-point Hadamard transform to every block
    ///   after the random sign flip (and undo it when decoding).
    /// * `random_seed` – seed of the ChaCha8 generator that draws the per-lane
    ///   sign flips; the same seed always yields the same codec.
    pub fn new(dimension: usize, use_hadamard: bool, random_seed: u64) -> Self {
        // Round up so that a trailing partial block is padded to a full one.
        let num_blocks = (dimension + H4_BLOCK_SIZE - 1) / H4_BLOCK_SIZE;
        let padded_dim = num_blocks * H4_BLOCK_SIZE;

        // The draw order is part of the encoding: one boolean per padded lane.
        let mut rng = ChaCha8Rng::seed_from_u64(random_seed);
        let random_signs: Vec<f32> = (0..padded_dim)
            .map(|_| if rand::Rng::gen::<bool>(&mut rng) { 1.0 } else { -1.0 })
            .collect();

        Self {
            dimension,
            num_blocks,
            use_hadamard,
            random_signs,
            vertices: generate_h4_vertices(),
        }
    }

    /// Quantizes `vector` into one vertex index per block plus a shared scale.
    ///
    /// The input is zero-padded to a whole number of blocks, multiplied by the
    /// random signs, optionally Hadamard-transformed per block, divided by its
    /// largest absolute component (or by `1.0` when that is at most `1e-10`),
    /// and each block is snapped to its nearest table vertex.
    ///
    /// # Errors
    ///
    /// Returns `EmbedVecError::DimensionMismatch` when `vector.len()` differs
    /// from the codec dimension.
    pub fn encode(&self, vector: &[f32]) -> Result<H4EncodedVector> {
        if vector.len() != self.dimension {
            return Err(EmbedVecError::DimensionMismatch {
                expected: self.dimension,
                got: vector.len(),
            });
        }

        let mut padded = vec![0.0f32; self.num_blocks * H4_BLOCK_SIZE];
        padded[..vector.len()].copy_from_slice(vector);

        for (value, sign) in padded.iter_mut().zip(&self.random_signs) {
            *value *= *sign;
        }
        if self.use_hadamard {
            for block in padded.chunks_exact_mut(H4_BLOCK_SIZE) {
                hadamard4_inplace(block);
            }
        }

        let max_abs = padded.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        // Avoid dividing by (almost) zero for an all-zero input.
        let scale = if max_abs > 1e-10 { max_abs } else { 1.0 };

        let indices = padded
            .chunks_exact(H4_BLOCK_SIZE)
            .map(|block| {
                let normalized: [f32; 4] = std::array::from_fn(|i| block[i] / scale);
                H4Oracle::nearest_vertex_index(&normalized, &self.vertices)
            })
            .collect();

        Ok(H4EncodedVector { indices, scale })
    }

    /// Reconstructs an approximation of the vector stored in `encoded`.
    ///
    /// Block indices beyond the codec's block count are ignored (their output
    /// would lie past `dimension` anyway). If `encoded` holds fewer indices
    /// than the codec has blocks, the result is correspondingly shorter than
    /// `dimension`.
    ///
    /// # Panics
    ///
    /// Panics if a used block index is not a valid position in the vertex
    /// table (i.e. is `>= H4_NUM_VERTICES`).
    pub fn decode(&self, encoded: &H4EncodedVector) -> Vec<f32> {
        let used_blocks = encoded.indices.len().min(self.num_blocks);
        let mut result = Vec::with_capacity(used_blocks * H4_BLOCK_SIZE);
        for (block, &vertex_index) in encoded.indices.iter().take(self.num_blocks).enumerate() {
            result.extend_from_slice(&self.decode_block(block, vertex_index, encoded.scale));
        }
        // Drop the padding lanes of the final block.
        result.truncate(self.dimension);
        result
    }

    /// Euclidean distance between the raw `query` and the reconstruction of
    /// `encoded`.
    ///
    /// Only the components present in both are compared: the first
    /// `min(query.len(), dimension, 4 * encoded.indices.len())` entries.
    /// Blocks are decoded on the fly, so no intermediate vector is allocated.
    ///
    /// # Panics
    ///
    /// Panics if a block index that takes part in the comparison is not a
    /// valid position in the vertex table.
    pub fn asymmetric_distance(&self, query: &[f32], encoded: &H4EncodedVector) -> f32 {
        // Padding lanes past `dimension` never take part in the comparison.
        let compared = &query[..query.len().min(self.dimension)];

        let mut dist = 0.0f32;
        let blocks = compared.chunks(H4_BLOCK_SIZE).zip(&encoded.indices);
        for (block, (query_chunk, &vertex_index)) in blocks.enumerate() {
            let decoded = self.decode_block(block, vertex_index, encoded.scale);
            // `query_chunk` may be shorter than a block at the tail; zip stops there.
            for (q, d) in query_chunk.iter().zip(decoded.iter()) {
                let diff = q - d;
                dist += diff * diff;
            }
        }
        dist.sqrt()
    }

    /// Size in bytes of one encoded vector: one `u8` index per block plus the
    /// `f32` scale.
    pub fn bytes_per_vector(&self) -> usize {
        self.num_blocks + std::mem::size_of::<f32>()
    }

    /// Number of blocks each vector is split into.
    pub fn num_blocks(&self) -> usize {
        self.num_blocks
    }

    /// `log2(H4_NUM_VERTICES)` spread over the components of one block.
    pub fn bits_per_dim(&self) -> f32 {
        (H4_NUM_VERTICES as f32).log2() / H4_BLOCK_SIZE as f32
    }

    /// Reconstructs block number `block` from its vertex index: scale the
    /// vertex, undo the Hadamard transform if enabled, then undo the sign flip.
    ///
    /// `block` must be smaller than `self.num_blocks`.
    fn decode_block(&self, block: usize, vertex_index: u8, scale: f32) -> [f32; H4_BLOCK_SIZE] {
        let vertex = &self.vertices[usize::from(vertex_index)];
        let mut lanes: [f32; H4_BLOCK_SIZE] = std::array::from_fn(|i| vertex[i] * scale);
        if self.use_hadamard {
            // The 4-point transform is its own inverse.
            hadamard4_inplace(&mut lanes);
        }
        let start = block * H4_BLOCK_SIZE;
        let signs = &self.random_signs[start..start + H4_BLOCK_SIZE];
        for (lane, sign) in lanes.iter_mut().zip(signs) {
            *lane *= *sign;
        }
        lanes
    }
}
/// Compressed form of one vector: a vertex index per block and a shared scale.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct H4EncodedVector {
/// One vertex-table index per block, in block order.
pub indices: Vec<u8>,
/// Factor the selected vertices are multiplied by when decoding.
pub scale: f32,
}
impl H4EncodedVector {
    /// Returns an encoding with no block indices and a scale of `1.0`.
    pub fn empty() -> Self {
        let indices = Vec::new();
        Self { indices, scale: 1.0 }
    }

    /// Payload size in bytes: one byte per block index plus the `f32` scale.
    pub fn size_bytes(&self) -> usize {
        let scale_bytes = std::mem::size_of::<f32>();
        self.indices.len() + scale_bytes
    }
}
/// Stateless brute-force nearest-vertex search over a table of 4-D points.
pub struct H4Oracle;

impl H4Oracle {
    /// Number of table positions a `u8` index can address.
    const U8_ADDRESSABLE: usize = u8::MAX as usize + 1;

    /// Returns the index of the vertex closest to `x` (squared Euclidean
    /// distance; the earliest vertex wins ties).
    ///
    /// Because the result is a `u8`, only the first 256 entries of `vertices`
    /// are considered, so the returned index always names the vertex that was
    /// actually selected instead of a wrapped-around position.
    ///
    /// Returns `0` when `vertices` is empty or when no distance compares
    /// below `f32::MAX` (for example when `x` contains NaN).
    #[inline]
    pub fn nearest_vertex_index(x: &[f32; 4], vertices: &[[f32; 4]]) -> u8 {
        let addressable = &vertices[..vertices.len().min(Self::U8_ADDRESSABLE)];
        // Cannot truncate: the search was limited to positions 0..=255.
        Self::nearest_position(x, addressable) as u8
    }

    /// Returns a copy of the vertex closest to `x`, searching the whole table
    /// regardless of its length.
    ///
    /// # Panics
    ///
    /// Panics if `vertices` is empty.
    #[inline]
    pub fn nearest_vertex(x: &[f32; 4], vertices: &[[f32; 4]]) -> [f32; 4] {
        vertices[Self::nearest_position(x, vertices)]
    }

    /// Linear scan returning the position of the first vertex with the
    /// smallest squared distance to `x`; `0` if nothing beats `f32::MAX`.
    #[inline]
    fn nearest_position(x: &[f32; 4], vertices: &[[f32; 4]]) -> usize {
        let mut best_idx = 0usize;
        let mut best_dist = f32::MAX;
        for (idx, vertex) in vertices.iter().enumerate() {
            let dist = Self::squared_distance_4d(x, vertex);
            // Strict comparison keeps the earliest vertex on ties and skips NaN.
            if dist < best_dist {
                best_dist = dist;
                best_idx = idx;
            }
        }
        best_idx
    }

    /// Squared Euclidean distance between two 4-D points.
    #[inline(always)]
    fn squared_distance_4d(a: &[f32; 4], b: &[f32; 4]) -> f32 {
        let d0 = a[0] - b[0];
        let d1 = a[1] - b[1];
        let d2 = a[2] - b[2];
        let d3 = a[3] - b[3];
        d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3
    }
}
/// Applies the 4-point Hadamard transform, scaled by `1/2`, to `data` in place.
///
/// With that scaling the transform is its own inverse. `data` is expected to
/// hold exactly four values; this is checked in debug builds only.
#[inline]
pub fn hadamard4_inplace(data: &mut [f32]) {
    debug_assert_eq!(data.len(), 4);

    // First butterfly stage: sums and differences of neighbouring pairs.
    let low_sum = data[0] + data[1];
    let low_diff = data[0] - data[1];
    let high_sum = data[2] + data[3];
    let high_diff = data[2] - data[3];

    // Second stage combines the pairs; the factor 0.5 normalizes the result.
    let half = 0.5f32;
    let mixed = [
        (low_sum + high_sum) * half,
        (low_diff + high_diff) * half,
        (low_sum - high_sum) * half,
        (low_diff - high_diff) * half,
    ];
    data[..4].copy_from_slice(&mixed);
}
/// Builds the table of `H4_NUM_VERTICES` unit-length 4-D vertices.
///
/// The order is fixed and is relied on by every stored index:
/// 1. 8 axis vertices `(±1, 0, 0, 0)` – per axis, the negative one first;
/// 2. 16 vertices `(±½, ±½, ±½, ±½)` – bit `i` of the counter negates lane `i`;
/// 3. 96 vertices formed by the even permutations of
///    `(0, ±½, ±INV_2PHI, ±HALF_PHI)` – for each permutation, all 8 sign
///    combinations in counting order.
pub fn generate_h4_vertices() -> Vec<[f32; 4]> {
    // `perm[slot]` names which base coordinate lands in output lane `slot`.
    const EVEN_PERMUTATIONS: [[usize; 4]; 12] = [
        [0, 1, 2, 3],
        [0, 2, 3, 1],
        [0, 3, 1, 2],
        [1, 0, 3, 2],
        [1, 2, 0, 3],
        [1, 3, 2, 0],
        [2, 0, 1, 3],
        [2, 1, 3, 0],
        [2, 3, 0, 1],
        [3, 0, 2, 1],
        [3, 1, 0, 2],
        [3, 2, 1, 0],
    ];

    let mut table: Vec<[f32; 4]> = Vec::with_capacity(H4_NUM_VERTICES);

    // Axis vertices: counter `k` selects axis `k / 2`, even `k` is negative.
    table.extend((0..8usize).map(|k| {
        let mut vertex = [0.0f32; 4];
        vertex[k / 2] = if k % 2 == 0 { -1.0 } else { 1.0 };
        vertex
    }));

    // Half-integer vertices: a set bit flips the matching lane to -0.5.
    table.extend((0u8..16).map(|bits| -> [f32; 4] {
        std::array::from_fn(|lane| if bits & (1 << lane) != 0 { -0.5 } else { 0.5 })
    }));

    // Golden-ratio vertices: sign the three non-zero base coordinates, then
    // scatter them through each even permutation.
    for perm in &EVEN_PERMUTATIONS {
        for sign_bits in 0u8..8 {
            let sign_of = |bit: u8| if sign_bits & (1 << bit) != 0 { -1.0f32 } else { 1.0 };
            let signed_base = [
                0.0f32,
                0.5 * sign_of(0),
                INV_2PHI * sign_of(1),
                HALF_PHI * sign_of(2),
            ];
            table.push(std::array::from_fn(|slot| signed_base[perm[slot]]));
        }
    }

    table
}
/// Returns `true` when every vertex has a squared norm within `1e-4` of `1.0`.
///
/// An empty slice is accepted; a vertex containing NaN is rejected.
pub fn verify_h4_unit_sphere(vertices: &[[f32; 4]]) -> bool {
    for vertex in vertices {
        let squared_norm: f32 = vertex.iter().map(|c| c * c).sum();
        let on_sphere = (squared_norm - 1.0).abs() < 1e-4;
        if !on_sphere {
            return false;
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Squared Euclidean distance between two 4-D points.
    fn squared_gap(a: &[f32; 4], b: &[f32; 4]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
    }

    /// The generator must emit exactly the advertised number of vertices.
    #[test]
    fn test_h4_vertex_count() {
        assert_eq!(
            generate_h4_vertices().len(),
            H4_NUM_VERTICES,
            "Expected 120 vertices"
        );
    }

    /// Every generated vertex must have unit length.
    #[test]
    fn test_h4_vertices_on_unit_sphere() {
        for (i, vertex) in generate_h4_vertices().iter().enumerate() {
            let norm_sq: f32 = vertex.iter().map(|c| c * c).sum();
            assert!(
                (norm_sq - 1.0).abs() < 1e-4,
                "Vertex {} has norm² = {:.6}, expected 1.0",
                i, norm_sq
            );
        }
    }

    /// A query equal to an axis vertex must be matched exactly.
    #[test]
    fn test_h4_oracle_axis_aligned() {
        let table = generate_h4_vertices();
        let target = [1.0f32, 0.0, 0.0, 0.0];
        let picked = table[H4Oracle::nearest_vertex_index(&target, &table) as usize];
        let dist = squared_gap(&picked, &target);
        assert!(dist < 1e-5, "Oracle missed axis-aligned vertex, dist² = {}", dist);
    }

    /// A query equal to a half-integer vertex must be matched exactly.
    #[test]
    fn test_h4_oracle_half_half() {
        let table = generate_h4_vertices();
        let target = [0.5f32, 0.5, 0.5, 0.5];
        let picked = table[H4Oracle::nearest_vertex_index(&target, &table) as usize];
        let dist = squared_gap(&picked, &target);
        assert!(dist < 1e-5, "Oracle missed Type B vertex, dist² = {}", dist);
    }

    /// Applying the transform twice must give back the input.
    #[test]
    fn test_hadamard4_involution() {
        let original = [1.0f32, 2.0, 3.0, 4.0];
        let mut data = original;
        for _ in 0..2 {
            hadamard4_inplace(&mut data);
        }
        for (i, (got, want)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-5,
                "Hadamard4 involution failed at index {}: got {}, expected {}",
                i, got, want
            );
        }
    }

    /// Images of two distinct basis vectors must stay orthogonal.
    #[test]
    fn test_hadamard4_orthogonality() {
        let mut first = [1.0f32, 0.0, 0.0, 0.0];
        let mut second = [0.0f32, 1.0, 0.0, 0.0];
        hadamard4_inplace(&mut first);
        hadamard4_inplace(&mut second);
        let dot: f32 = first.iter().zip(second.iter()).map(|(a, b)| a * b).sum();
        assert!(dot.abs() < 1e-5, "Hadamard4 outputs not orthogonal, dot = {}", dot);
    }

    /// Encoding then decoding keeps the length and a bounded error.
    #[test]
    fn test_h4_codec_encode_decode_roundtrip() {
        const DIM: usize = 768;
        let codec = H4Codec::new(DIM, true, 0xdeadbeef);
        let vector: Vec<f32> = (0..DIM).map(|i| (i as f32 * 0.01).cos()).collect();

        let encoded = codec.encode(&vector).unwrap();
        let decoded = codec.decode(&encoded);
        assert_eq!(decoded.len(), DIM);
        assert_eq!(encoded.indices.len(), DIM / H4_BLOCK_SIZE);

        let squared_error: f32 = vector
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        let mse = squared_error / DIM as f32;
        assert!(mse < 0.5, "H4 codec MSE too high: {}", mse);
    }

    /// Encoding a vector of the wrong length must fail.
    #[test]
    fn test_h4_codec_dimension_mismatch() {
        let codec = H4Codec::new(8, true, 42);
        assert!(codec.encode(&[1.0, 2.0, 3.0]).is_err());
    }

    /// The encoded form must be markedly smaller than raw `f32` storage.
    #[test]
    fn test_h4_codec_memory_efficiency() {
        let codec = H4Codec::new(768, true, 42);
        let raw_bytes = 768 * 4;
        let packed_bytes = codec.bytes_per_vector();
        let ratio = raw_bytes as f32 / packed_bytes as f32;
        assert!(ratio > 3.0, "H4 compression ratio too low: {:.1}×", ratio);
    }

    /// log2(120) / 4 is roughly 1.727 bits per component.
    #[test]
    fn test_h4_bits_per_dim() {
        let bpd = H4Codec::new(4, false, 0).bits_per_dim();
        assert!((bpd - 1.727).abs() < 0.01, "bits_per_dim = {}", bpd);
    }

    /// 16 components give 4 block indices plus 4 bytes of scale.
    #[test]
    fn test_h4_encoded_vector_size() {
        let codec = H4Codec::new(16, false, 0);
        let input = vec![0.5f32; 16];
        let encoded = codec.encode(&input).unwrap();
        assert_eq!(encoded.size_bytes(), 8);
    }
}