use crate::error::{EmbedVecError, Result};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
/// Number of vectors `generate_e8_roots` must produce; it asserts this count before returning.
pub const E8_NUM_ROOTS: usize = 240;
/// Number of coordinates per quantization block; `E8Codec` zero-pads vectors to a multiple of this.
pub const E8_BLOCK_SIZE: usize = 8;
/// Block-wise vector quantizer: a vector is split into 8-coordinate blocks and each block is
/// snapped to the point chosen by `E8Oracle::nearest_point`.
#[derive(Debug, Clone)]
pub struct E8Codec {
/// Length of the vectors accepted by `encode`.
dimension: usize,
/// Number of `E8_BLOCK_SIZE`-wide blocks needed to cover `dimension` (rounded up).
num_blocks: usize,
/// Nominal bit budget per block; selects the codebook range and drives `bytes_per_vector`.
bits_per_block: u8,
/// Whether the per-block Hadamard transform was requested at construction.
use_hadamard: bool,
/// Block-sized transform; `new` sets it to `Some` exactly when `use_hadamard` is true.
hadamard: Option<HadamardTransform>,
/// One +1.0 / -1.0 factor per padded coordinate, drawn from the seed passed to `new`.
random_signs: Vec<f32>,
/// Set to 1.0 by `new`; `encode` computes its own per-vector scale instead of using this.
scale: f32,
}
impl E8Codec {
    /// Builds a codec for vectors of length `dimension`.
    ///
    /// * `bits_per_block` selects the codebook range used to normalise blocks and the size
    ///   reported by [`bytes_per_vector`](Self::bytes_per_vector).
    /// * `use_hadamard` enables a normalised Hadamard transform on every 8-coordinate block
    ///   before quantization.
    /// * `random_seed` seeds the ChaCha8 generator that draws the per-coordinate sign flips,
    ///   so codecs built with identical arguments use identical signs.
    pub fn new(dimension: usize, bits_per_block: u8, use_hadamard: bool, random_seed: u64) -> Self {
        // Round up so a trailing partial block is still covered (it is zero-padded on encode).
        let num_blocks = (dimension + E8_BLOCK_SIZE - 1) / E8_BLOCK_SIZE;
        let padded_dim = num_blocks * E8_BLOCK_SIZE;

        let mut rng = ChaCha8Rng::seed_from_u64(random_seed);
        let random_signs: Vec<f32> = (0..padded_dim)
            .map(|_| if rand::Rng::gen::<bool>(&mut rng) { 1.0 } else { -1.0 })
            .collect();

        let hadamard = use_hadamard.then(|| HadamardTransform::new(E8_BLOCK_SIZE));

        Self {
            dimension,
            num_blocks,
            bits_per_block,
            use_hadamard,
            hadamard,
            random_signs,
            scale: 1.0,
        }
    }

    /// Quantizes `vector` into one lattice point per 8-coordinate block.
    ///
    /// The vector is zero-padded to a whole number of blocks, multiplied by the codec's
    /// random signs, optionally Hadamard-transformed per block, and divided by a per-vector
    /// scale chosen so that the largest magnitude equals the codebook range. Each scaled
    /// block is then replaced by the point returned by `E8Oracle::nearest_point`.
    ///
    /// # Errors
    ///
    /// Returns `EmbedVecError::DimensionMismatch` when `vector.len()` differs from the
    /// dimension the codec was built for.
    pub fn encode(&self, vector: &[f32]) -> Result<E8EncodedVector> {
        if vector.len() != self.dimension {
            return Err(EmbedVecError::DimensionMismatch {
                expected: self.dimension,
                got: vector.len(),
            });
        }

        let mut padded = vec![0.0f32; self.num_blocks * E8_BLOCK_SIZE];
        padded[..vector.len()].copy_from_slice(vector);

        // `random_signs` has exactly one entry per padded coordinate.
        for (value, sign) in padded.iter_mut().zip(self.random_signs.iter()) {
            *value *= *sign;
        }

        if let Some(ref hadamard) = self.hadamard {
            for block in padded.chunks_exact_mut(E8_BLOCK_SIZE) {
                hadamard.transform_inplace(block);
            }
        }

        let max_abs = padded.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        // An (almost) all-zero vector keeps a unit scale to avoid dividing by ~0.
        let scale = if max_abs > 1e-10 {
            max_abs / self.codebook_range()
        } else {
            1.0
        };

        let points = padded
            .chunks_exact(E8_BLOCK_SIZE)
            .map(|block| {
                // Fixed-size stack buffer: no heap allocation per block.
                let scaled: [f32; E8_BLOCK_SIZE] = std::array::from_fn(|i| block[i] / scale);
                let (is_half, point) = E8Oracle::nearest_point(&scaled);
                // Half-integer points are stored as their integer part; the 0.5 lives in the flag.
                let offset = if is_half { 0.5 } else { 0.0 };
                E8Point {
                    coords: std::array::from_fn(|i| (point[i] - offset).round() as i8),
                    is_half,
                }
            })
            .collect();

        Ok(E8EncodedVector { points, scale })
    }

    /// Reconstructs an approximation of the vector that produced `encoded`.
    ///
    /// Undoes the steps of [`encode`](Self::encode) in reverse order: rescale, inverse
    /// Hadamard transform (when enabled), sign flips, and removal of the padding.
    ///
    /// The result always has exactly `dimension` entries. `encoded` is a public,
    /// deserializable type, so its block count is not trusted: blocks beyond what this codec
    /// uses are ignored, and missing blocks decode as zeros, instead of panicking or
    /// returning a short vector.
    pub fn decode(&self, encoded: &E8EncodedVector) -> Vec<f32> {
        let mut result = Vec::with_capacity(self.num_blocks * E8_BLOCK_SIZE);

        for point in encoded.points.iter().take(self.num_blocks) {
            result.extend(point.to_f32().iter().map(|&coord| coord * encoded.scale));
        }

        if let Some(ref hadamard) = self.hadamard {
            // `result` holds whole blocks only, so there is never a remainder.
            for block in result.chunks_exact_mut(E8_BLOCK_SIZE) {
                hadamard.inverse_transform_inplace(block);
            }
        }

        for (value, sign) in result.iter_mut().zip(self.random_signs.iter()) {
            *value *= *sign;
        }

        // Drops the padding and, for a truncated encoding, fills the missing tail with zeros.
        result.resize(self.dimension, 0.0);
        result
    }

    /// Euclidean distance between an unquantized `query` and the decoded form of `encoded`.
    ///
    /// Coordinates are compared pairwise; if `query` is shorter or longer than the decoded
    /// vector, only the common prefix contributes.
    pub fn asymmetric_distance(&self, query: &[f32], encoded: &E8EncodedVector) -> f32 {
        let decoded = self.decode(encoded);
        query
            .iter()
            .zip(decoded.iter())
            .fold(0.0f32, |acc, (q, d)| {
                let diff = q - d;
                acc + diff * diff
            })
            .sqrt()
    }

    /// Largest coordinate magnitude a block may have after scaling in `encode`.
    fn codebook_range(&self) -> f32 {
        match self.bits_per_block {
            8 => 4.0,
            12 => 6.0,
            // 10 bits and every unrecognised setting share the middle range.
            _ => 5.0,
        }
    }

    /// Nominal storage size of one encoded vector: `bits_per_block` bits per block, rounded
    /// up to whole bytes, plus 4 bytes (presumably the `f32` scale).
    pub fn bytes_per_vector(&self) -> usize {
        let code_bits = self.num_blocks * self.bits_per_block as usize;
        let code_bytes = (code_bits + 7) / 8;
        code_bytes + 4
    }

    /// Number of 8-coordinate blocks this codec splits a vector into.
    pub fn num_blocks(&self) -> usize {
        self.num_blocks
    }
}
/// Output of `E8Codec::encode`: one quantized point per 8-coordinate block plus the scale
/// shared by all blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct E8EncodedVector {
/// Quantized blocks, in the order the blocks appear in the vector.
pub points: Vec<E8Point>,
/// Factor that `E8Codec::decode` multiplies every decoded coordinate by.
pub scale: f32,
}
/// One quantized block: eight integers, optionally shifted by one half on every coordinate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct E8Point {
/// Integer part of each coordinate; `to_f32` adds 0.5 to every entry when `is_half` is set.
pub coords: [i8; 8],
/// True when every real coordinate is the stored integer plus 0.5.
pub is_half: bool,
}
impl E8Point {
    /// Expands the point to real coordinates: each stored integer, plus 0.5 when `is_half`.
    pub fn to_f32(&self) -> [f32; 8] {
        let offset = if self.is_half { 0.5 } else { 0.0 };
        std::array::from_fn(|i| self.coords[i] as f32 + offset)
    }

    /// Builds an integer point from coordinates that are expected to be whole numbers.
    ///
    /// Each coordinate is rounded to the nearest integer before the cast. A bare `as i8`
    /// truncates toward zero, so an input a hair below an integer (e.g. `0.99999994`) would
    /// lose a whole unit. Rounding also matches how `E8Codec::encode` stores coordinates.
    pub fn from_d8(coords: &[f32; 8]) -> Self {
        Self {
            coords: std::array::from_fn(|i| coords[i].round() as i8),
            is_half: false,
        }
    }

    /// Builds a half-integer point from coordinates expected to be of the form `k + 0.5`.
    ///
    /// The 0.5 offset is removed and the remainder rounded to the nearest integer, for the
    /// same reason as in [`from_d8`](Self::from_d8).
    pub fn from_d8_half(coords: &[f32; 8]) -> Self {
        Self {
            coords: std::array::from_fn(|i| (coords[i] - 0.5).round() as i8),
            is_half: true,
        }
    }
}
impl E8EncodedVector {
pub fn empty() -> Self {
Self {
points: Vec::new(),
scale: 1.0,
}
}
pub fn size_bytes(&self) -> usize {
self.points.len() * 9 + 4
}
}
/// Normalised Walsh-Hadamard transform for slices of one fixed, power-of-two length.
#[derive(Debug, Clone)]
pub struct HadamardTransform {
    /// Required slice length.
    size: usize,
    /// `1 / sqrt(size)`, applied after the butterfly passes.
    norm_factor: f32,
}

impl HadamardTransform {
    /// Creates a transform for slices of exactly `size` elements.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not a power of two.
    pub fn new(size: usize) -> Self {
        assert!(size.is_power_of_two(), "Hadamard size must be power of 2");
        let norm_factor = 1.0 / (size as f32).sqrt();
        Self { size, norm_factor }
    }

    /// Applies the normalised transform to `data` in place.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the size given to [`new`](Self::new).
    pub fn transform_inplace(&self, data: &mut [f32]) {
        assert_eq!(data.len(), self.size);
        self.fwht_inplace(data);
        data.iter_mut().for_each(|value| *value *= self.norm_factor);
    }

    /// Undoes [`transform_inplace`](Self::transform_inplace).
    ///
    /// With the `1 / sqrt(size)` normalisation the transform is its own inverse, so this
    /// performs exactly the same computation.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the size given to [`new`](Self::new).
    pub fn inverse_transform_inplace(&self, data: &mut [f32]) {
        self.transform_inplace(data);
    }

    /// Unnormalised butterfly passes; callers guarantee a power-of-two length.
    fn fwht_inplace(&self, data: &mut [f32]) {
        let len = data.len();
        let mut half = 1;
        while half < len {
            // Each group of `2 * half` elements is combined as (low + high, low - high).
            for group in data.chunks_exact_mut(half * 2) {
                let (low, high) = group.split_at_mut(half);
                for (a, b) in low.iter_mut().zip(high.iter_mut()) {
                    let (x, y) = (*a, *b);
                    *a = x + y;
                    *b = x - y;
                }
            }
            half *= 2;
        }
    }
}
/// Stateless nearest-point search over two cosets: integer vectors with an even coordinate
/// sum (D8), and the same set shifted by 0.5 on every coordinate.
pub struct E8Oracle;

impl E8Oracle {
    /// Returns the closest candidate to `x` together with a flag telling which coset it came
    /// from (`true` = the half-integer coset). Ties go to the integer coset.
    ///
    /// # Panics
    ///
    /// Panics if `x` does not have exactly 8 elements.
    pub fn nearest_point(x: &[f32]) -> (bool, [f32; 8]) {
        assert_eq!(x.len(), 8);
        let (d8_point, d8_dist) = Self::nearest_d8(x);
        let (half_point, half_dist) = Self::nearest_d8_half(x);
        if d8_dist <= half_dist {
            (false, d8_point)
        } else {
            (true, half_point)
        }
    }

    /// Nearest integer vector with an even coordinate sum, and its squared distance to `x`.
    ///
    /// Every coordinate is rounded to the nearest integer. If the rounded sum is odd, exactly
    /// one coordinate has to be rounded the other way. The cheapest one to move is the
    /// coordinate with the LARGEST rounding error, since it was closest to being rounded the
    /// other way already.
    #[inline]
    fn nearest_d8(x: &[f32]) -> ([f32; 8], f32) {
        let mut rounded: [f32; 8] = std::array::from_fn(|i| x[i].round());

        // XOR of the low bits gives the parity of the sum without any risk of i32 overflow.
        let parity = rounded.iter().fold(0i32, |acc, &r| acc ^ ((r as i32) & 1));

        if parity != 0 {
            let mut worst = 0;
            let mut worst_err = (x[0] - rounded[0]).abs();
            for i in 1..8 {
                let err = (x[i] - rounded[i]).abs();
                if err > worst_err {
                    worst_err = err;
                    worst = i;
                }
            }
            // Step to the integer on the other side of x[worst].
            if x[worst] - rounded[worst] > 0.0 {
                rounded[worst] += 1.0;
            } else {
                rounded[worst] -= 1.0;
            }
        }

        let dist = Self::squared_distance(x, &rounded);
        (rounded, dist)
    }

    /// Nearest point of the half-integer coset: shift by -0.5, solve in D8, shift back.
    #[inline]
    fn nearest_d8_half(x: &[f32]) -> ([f32; 8], f32) {
        let shifted: [f32; 8] = std::array::from_fn(|i| x[i] - 0.5);
        let (d8_point, _) = Self::nearest_d8(&shifted);
        let result: [f32; 8] = std::array::from_fn(|i| d8_point[i] + 0.5);
        let dist = Self::squared_distance(x, &result);
        (result, dist)
    }

    /// Squared Euclidean distance between `a` and `b`, accumulated coordinate by coordinate.
    fn squared_distance(a: &[f32], b: &[f32; 8]) -> f32 {
        a.iter().zip(b.iter()).fold(0.0f32, |acc, (p, q)| {
            let diff = p - q;
            acc + diff * diff
        })
    }

    /// Expands a 16-bit code into eight coordinates.
    ///
    /// Bit 15 selects the half-integer coset. Coordinate `i` is read from the two bits at
    /// position `2 * i` of the low 15 bits and mapped to `field - 1`. Because bit 15 is
    /// masked off, the last coordinate only sees bit 14 and is therefore -1 or 0. If the
    /// coordinate sum is odd, coordinate 0 is incremented to make it even.
    pub fn decode_point(code: u16) -> [f32; 8] {
        let is_d8_half = (code & 0x8000) != 0;
        let base_code = code & 0x7FFF;

        let mut point: [f32; 8] = std::array::from_fn(|i| {
            let field = (base_code >> (i * 2)) & 0x3;
            (field as i32 - 1) as f32
        });

        let sum: i32 = point.iter().map(|&v| v as i32).sum();
        if sum % 2 != 0 {
            point[0] += 1.0;
        }

        if is_d8_half {
            for v in point.iter_mut() {
                *v += 0.5;
            }
        }
        point
    }
}
/// Enumerates the `E8_NUM_ROOTS` root vectors.
///
/// First come the vectors with exactly two non-zero entries, each +1 or -1 (every pair of
/// positions, every sign combination). Then come the vectors whose entries are all +0.5 or
/// -0.5 with an even number of minus signs, in increasing order of their sign bitmask.
///
/// # Panics
///
/// Panics if the number of generated vectors differs from `E8_NUM_ROOTS`.
pub fn generate_e8_roots() -> Vec<[f32; 8]> {
    const SIGNS: [f32; 2] = [-1.0, 1.0];
    let mut roots: Vec<[f32; 8]> = Vec::with_capacity(E8_NUM_ROOTS);

    // Integer roots: two signed unit entries at positions first < second.
    for first in 0..8 {
        for second in (first + 1)..8 {
            for &sign_first in &SIGNS {
                for &sign_second in &SIGNS {
                    let mut root = [0.0f32; 8];
                    root[first] = sign_first;
                    root[second] = sign_second;
                    roots.push(root);
                }
            }
        }
    }

    // Half-integer roots: bit `k` of the mask set means coordinate `k` is negative.
    roots.extend(
        (0u8..=255)
            .filter(|mask| mask.count_ones() % 2 == 0)
            .map(|mask| -> [f32; 8] {
                std::array::from_fn(|bit| if (mask >> bit) & 1 == 1 { -0.5 } else { 0.5 })
            }),
    );

    assert_eq!(roots.len(), E8_NUM_ROOTS);
    roots
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward followed by inverse Hadamard transform must reproduce the input.
    #[test]
    fn test_hadamard_transform() {
        let transform = HadamardTransform::new(8);
        let original = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let mut data = original;

        transform.transform_inplace(&mut data);
        transform.inverse_transform_inplace(&mut data);

        for (got, want) in data.iter().zip(original.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    /// The origin must be returned for an all-zero input.
    #[test]
    fn test_e8_oracle_zero() {
        let (_is_half, point) = E8Oracle::nearest_point(&[0.0f32; 8]);
        let norm_sq = point.iter().fold(0.0f32, |acc, v| acc + v * v);
        assert!(norm_sq < 1e-5);
    }

    /// An input that already has two unit entries and zeros elsewhere must be returned unchanged.
    #[test]
    fn test_e8_oracle_unit() {
        let input = [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let (_is_half, point) = E8Oracle::nearest_point(&input);

        for coord in point[..2].iter() {
            assert!((coord - 1.0).abs() < 1e-5);
        }
        for coord in point[2..].iter() {
            assert!(coord.abs() < 1e-5);
        }
    }

    /// Encoding then decoding a smooth 768-dimensional vector keeps the error bounded.
    #[test]
    fn test_e8_codec_roundtrip() {
        const DIM: usize = 768;
        let codec = E8Codec::new(DIM, 10, true, 42);
        let vector: Vec<f32> = (0..DIM).map(|i| (i as f32 * 0.01).sin()).collect();

        let encoded = codec.encode(&vector).unwrap();
        let decoded = codec.decode(&encoded);
        assert_eq!(decoded.len(), 768);

        let squared_error: f32 = vector
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>();
        let mse = squared_error / 768.0;
        assert!(mse < 0.5, "MSE too high: {}", mse);
    }

    /// All 240 roots are produced and each has squared norm 2.
    #[test]
    fn test_e8_roots_count() {
        let roots = generate_e8_roots();
        assert_eq!(roots.len(), 240);

        roots.iter().for_each(|root| {
            let norm_sq: f32 = root.iter().map(|x| x * x).sum();
            assert!((norm_sq - 2.0).abs() < 1e-5, "Root norm^2 = {}", norm_sq);
        });
    }

    /// The nominal encoded size must be more than 3x smaller than raw `f32` storage.
    #[test]
    fn test_codec_memory_savings() {
        let codec = E8Codec::new(768, 10, true, 42);
        let f32_bytes = 768 * 4;
        let e8_bytes = codec.bytes_per_vector();

        let ratio = f32_bytes as f32 / e8_bytes as f32;
        assert!(ratio > 3.0, "Compression ratio too low: {}", ratio);
    }
}