use ndarray::ArrayView2;
pub fn encode(
vectors: &[f32],
n: usize,
dim: usize,
rotation: &[f32], boundaries: &[f32],
centroids: &[f32],
bit_width: usize,
) -> (Vec<u8>, Vec<f32>) {
let mut norms = vec![0.0f32; n];
let mut unit_flat = vec![0.0f32; n * dim];
for i in 0..n {
let row = &vectors[i * dim..(i + 1) * dim];
let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
norms[i] = norm;
let inv_norm = if norm > 1e-10 { 1.0 / norm } else { 0.0 };
for j in 0..dim {
unit_flat[i * dim + j] = row[j] * inv_norm;
}
}
let unit_mat = ArrayView2::from_shape((n, dim), &unit_flat).unwrap();
let rot_mat = ArrayView2::from_shape((dim, dim), rotation).unwrap();
let rotated_mat = unit_mat.dot(&rot_mat.t());
let rotated = rotated_mat.as_slice().unwrap();
let mut codes = vec![0u8; n * dim];
for b in boundaries {
for idx in 0..n * dim {
if rotated[idx] > *b {
codes[idx] += 1;
}
}
}
let mut scales = vec![0.0f32; n];
for i in 0..n {
let row_start = i * dim;
let mut inner = 0.0f64;
for j in 0..dim {
let r = rotated[row_start + j] as f64;
let c = centroids[codes[row_start + j] as usize] as f64;
inner += r * c;
}
let inner = inner.max(1e-10) as f32;
scales[i] = norms[i] / inner;
}
let packed = pack_codes(&codes, n, dim, bit_width);
(packed, scales)
}
fn pack_codes(codes: &[u8], n: usize, dim: usize, bits: usize) -> Vec<u8> {
let bytes_per_plane = dim / 8;
let bytes_per_row = bits * bytes_per_plane;
let mut packed = vec![0u8; n * bytes_per_row];
for i in 0..n {
for j in 0..dim {
let code = codes[i * dim + j];
let byte_pos = j / 8;
let bit_pos = 7 - (j % 8);
for p in 0..bits {
if code & (1 << p) != 0 {
packed[i * bytes_per_row + p * bytes_per_plane + byte_pos] |= 1 << bit_pos;
}
}
}
}
packed
}