use std::io::{Read, Write};
use crate::error::{FastTextError, Result};
use crate::model::MinstdRng;
use crate::utils;
use crate::vector::Vector;
/// Bits per sub-quantizer code; each code is stored in one `u8`.
pub const NBITS: i32 = 8;
/// Number of centroids in every sub-quantizer codebook (`2^NBITS`).
pub const KSUB: i32 = 2_i32.pow(NBITS as u32);
/// Training-sample budget per codebook entry.
const MAX_POINTS_PER_CLUSTER: i32 = 256;
/// Maximum number of rows handed to k-means for any one sub-quantizer.
pub const MAX_POINTS: i32 = MAX_POINTS_PER_CLUSTER * KSUB;
/// Number of E-step/M-step rounds per k-means run.
pub const NITER: i32 = 25;
/// Seed for the quantizer's training RNG.
const SEED: u64 = 1_234;
/// Offset used to push apart the two halves of a split cluster.
const EPS: f32 = 1.0e-7;
/// Squared Euclidean (L2) distance between two slices of equal length.
///
/// Debug builds assert the lengths match; release builds only use the
/// overlapping prefix.
#[inline]
fn dist_l2(x: &[f32], y: &[f32]) -> f32 {
    debug_assert_eq!(x.len(), y.len());
    let squared_diffs = x.iter().zip(y).map(|(a, b)| {
        let delta = a - b;
        delta * delta
    });
    squared_diffs.sum::<f32>()
}
/// Finds the centroid nearest to `x` (squared L2) among the `KSUB` centroids
/// stored back-to-back in `centroids`, each `d` values wide.
///
/// Writes the winning index to `code` and returns its squared distance. A
/// candidate replaces the current best only when strictly closer, so ties
/// resolve to the lowest index and a NaN distance never displaces the best.
///
/// # Panics
/// Panics if `centroids` holds fewer than `KSUB * d` values.
fn assign_centroid(x: &[f32], centroids: &[f32], code: &mut u8, d: usize) -> f32 {
    let mut best = 0usize;
    let mut best_dist = dist_l2(x, &centroids[..d]);
    for j in 1..KSUB as usize {
        let dist = dist_l2(x, &centroids[j * d..(j + 1) * d]);
        if dist < best_dist {
            best = j;
            best_dist = dist;
        }
    }
    // Codes are single bytes; this relies on KSUB <= 256.
    *code = best as u8;
    best_dist
}
/// E-step of k-means: assigns each of the `n` points in `x` (row-major, `d`
/// values per point) to its nearest centroid and stores that index in `codes`.
fn estep(x: &[f32], centroids: &[f32], codes: &mut [u8], d: usize, n: usize) {
    for (i, slot) in codes[..n].iter_mut().enumerate() {
        let point = &x[i * d..(i + 1) * d];
        assign_centroid(point, centroids, slot, d);
    }
}
/// M-step of k-means: moves every centroid to the mean of the points assigned
/// to it in `codes`, then re-seeds any cluster that received no points.
///
/// An empty cluster `k` is re-seeded by walking clusters cyclically from 0 and
/// accepting cluster `m` with a probability that grows with `m`'s point count
/// (one `rng.uniform_real()` draw per test). `k` becomes a copy of `m`. The two
/// are then pushed apart by `EPS` in opposite directions on every coordinate,
/// with the direction alternating by coordinate parity. Finally, `m`'s point
/// count is split between them.
fn mstep(
    rng: &mut MinstdRng,
    x: &[f32],
    centroids: &mut [f32],
    codes: &[u8],
    d: usize,
    n: usize,
) {
    let ksub = KSUB as usize;
    let mut counts = vec![0i32; ksub];

    // Sum every point into the centroid it was assigned to.
    centroids[..d * ksub].fill(0.0);
    for (i, &code) in codes[..n].iter().enumerate() {
        let k = code as usize;
        let point = &x[i * d..(i + 1) * d];
        for (acc, &v) in centroids[k * d..(k + 1) * d].iter_mut().zip(point) {
            *acc += v;
        }
        counts[k] += 1;
    }

    // Turn sums into means; empty clusters stay zeroed until re-seeded below.
    for (k, &count) in counts.iter().enumerate() {
        if count != 0 {
            let z = count as f32;
            centroids[k * d..(k + 1) * d].iter_mut().for_each(|c| *c /= z);
        }
    }

    // Re-seed each empty cluster by splitting a randomly chosen populated one.
    let spare = n as f64 - ksub as f64;
    for k in 0..ksub {
        if counts[k] != 0 {
            continue;
        }
        let mut m = 0usize;
        while rng.uniform_real() * spare >= f64::from(counts[m] - 1) {
            m = (m + 1) % ksub;
        }
        centroids.copy_within(m * d..(m + 1) * d, k * d);
        for j in 0..d {
            let offset = if j % 2 == 1 { EPS } else { -EPS };
            centroids[k * d + j] += offset;
            centroids[m * d + j] -= offset;
        }
        counts[k] = counts[m] / 2;
        counts[m] -= counts[k];
    }
}
/// Runs `NITER` rounds of Lloyd's k-means on `n` points of width `d` stored
/// row-major in `x`, writing the `KSUB` resulting centroids into `centroids`.
///
/// The codebook is seeded with the rows named by the first `KSUB` entries of a
/// random permutation, so `n` must be at least `KSUB`.
fn kmeans(rng: &mut MinstdRng, x: &[f32], centroids: &mut [f32], n: usize, d: usize) {
    let mut order: Vec<i32> = (0..n as i32).collect();
    rng.shuffle(&mut order);
    for (slot, &row) in order[..KSUB as usize].iter().enumerate() {
        let start = row as usize * d;
        centroids[slot * d..(slot + 1) * d].copy_from_slice(&x[start..start + d]);
    }

    let mut assignment = vec![0u8; n];
    for _round in 0..NITER {
        estep(x, centroids, &mut assignment, d, n);
        mstep(rng, x, centroids, &assignment, d, n);
    }
}
/// Product quantizer.
///
/// Splits a `dim`-dimensional vector into `nsubq` contiguous sub-vectors and
/// encodes each one as the `u8` index of its nearest centroid in a per-sub-vector
/// codebook of `KSUB` entries.
///
/// Every sub-vector is `dsub` wide except the last, which is `lastdsub` wide. The
/// last one covers the remainder when `dim` is not a multiple of `dsub`.
#[derive(Debug, Clone)]
pub struct ProductQuantizer {
    /// Full vector dimension.
    pub dim: i32,
    /// Number of sub-quantizers.
    pub nsubq: i32,
    /// Width of every sub-vector except the last.
    pub dsub: i32,
    /// Width of the last sub-vector.
    pub lastdsub: i32,
    /// Flattened codebooks, `dim * KSUB` values in total.
    ///
    /// Codebook `m` starts at `m * KSUB * dsub`. Within it, centroid `i` occupies
    /// `i * d .. (i + 1) * d`, where `d` is the width of sub-quantizer `m`.
    pub centroids: Vec<f32>,
    /// Random source used by training (row shuffles and empty-cluster re-seeding).
    rng: MinstdRng,
}

impl ProductQuantizer {
    /// Creates an untrained quantizer for `dim`-dimensional vectors cut into
    /// sub-vectors of width `dsub`; all centroids start at zero.
    ///
    /// # Panics
    /// Panics if `dsub` is zero.
    pub fn new(dim: i32, dsub: i32) -> Self {
        let (nsubq, lastdsub) = match dim % dsub {
            0 => (dim / dsub, dsub),
            rem => (dim / dsub + 1, rem),
        };
        ProductQuantizer {
            dim,
            nsubq,
            dsub,
            lastdsub,
            // Sized in usize so large `dim` values cannot overflow an i32 product.
            centroids: vec![0.0f32; dim as usize * KSUB as usize],
            rng: MinstdRng::new(SEED),
        }
    }

    /// Width of sub-quantizer `m`: `lastdsub` for the final one, `dsub` otherwise.
    #[inline]
    fn sub_dim(&self, m: usize) -> usize {
        // `m + 1 == nsubq` instead of `m == nsubq - 1` avoids underflow when nsubq == 0.
        if m + 1 == self.nsubq as usize {
            self.lastdsub as usize
        } else {
            self.dsub as usize
        }
    }

    /// Offset of codebook `m` inside `centroids`.
    #[inline]
    fn codebook_start(&self, m: usize) -> usize {
        m * KSUB as usize * self.dsub as usize
    }

    /// `(offset, width)` of centroid `i` of sub-quantizer `m` inside `centroids`.
    #[inline]
    fn centroid_range(&self, m: usize, i: u8) -> (usize, usize) {
        let d = self.sub_dim(m);
        (self.codebook_start(m) + usize::from(i) * d, d)
    }

    /// Centroid `i` of sub-quantizer `m`.
    ///
    /// # Panics
    /// Panics if the centroid lies outside `centroids` (e.g. `m >= nsubq`).
    #[inline]
    pub fn get_centroids(&self, m: usize, i: u8) -> &[f32] {
        let (off, d) = self.centroid_range(m, i);
        &self.centroids[off..off + d]
    }

    /// Mutable view of centroid `i` of sub-quantizer `m`.
    ///
    /// # Panics
    /// Panics if the centroid lies outside `centroids` (e.g. `m >= nsubq`).
    #[inline]
    pub fn get_centroids_mut(&mut self, m: usize, i: u8) -> &mut [f32] {
        let (off, d) = self.centroid_range(m, i);
        &mut self.centroids[off..off + d]
    }

    /// Trains every codebook on `n` row-major vectors stored in `x` (`n * dim` values).
    ///
    /// With at least `KSUB` rows, each codebook is fit with k-means on up to
    /// `MAX_POINTS` rows. When `n` exceeds that, a freshly shuffled subset is used
    /// for each sub-quantizer. With fewer rows, codebook entry `j` is simply row
    /// `j % n`, so every training row is reproduced exactly. `n <= 0` is a no-op.
    ///
    /// # Panics
    /// Panics if `x` holds fewer than `n * dim` values.
    pub fn train(&mut self, n: i32, x: &[f32]) {
        if n <= 0 {
            return;
        }
        let n = n as usize;
        let ksub = KSUB as usize;
        let dim = self.dim as usize;
        let dsub = self.dsub as usize;
        let nsubq = self.nsubq as usize;
        // Checked up front: when subsampling, which rows get read depends on the
        // shuffle, so a short `x` would otherwise panic only on some runs.
        assert!(
            x.len() >= n * dim,
            "ProductQuantizer::train: need {} values for {} rows, got {}",
            n * dim,
            n,
            x.len()
        );

        if n < ksub {
            // Too few rows for k-means: tile the rows cyclically over each codebook.
            for m in 0..nsubq {
                let d = self.sub_dim(m);
                let cstart = self.codebook_start(m);
                for j in 0..ksub {
                    let src = (j % n) * dim + m * dsub;
                    self.centroids[cstart + j * d..cstart + (j + 1) * d]
                        .copy_from_slice(&x[src..src + d]);
                }
            }
            return;
        }

        let np = n.min(MAX_POINTS as usize);
        let mut perm: Vec<i32> = (0..n as i32).collect();
        let mut xslice = vec![0.0f32; np * dsub.max(self.lastdsub as usize)];
        for m in 0..nsubq {
            let d = self.sub_dim(m);
            // Only reshuffle when not every row fits in the training sample.
            if np != n {
                self.rng.shuffle(&mut perm);
            }
            // Gather sub-vector `m` of each sampled row into a contiguous buffer.
            for (j, &row) in perm[..np].iter().enumerate() {
                let src = row as usize * dim + m * dsub;
                xslice[j * d..(j + 1) * d].copy_from_slice(&x[src..src + d]);
            }
            let cstart = self.codebook_start(m);
            kmeans(
                &mut self.rng,
                &xslice[..np * d],
                &mut self.centroids[cstart..cstart + KSUB as usize * d],
                np,
                d,
            );
        }
    }

    /// Encodes one `dim`-dimensional vector `x` into `nsubq` centroid indices in `code`.
    ///
    /// # Panics
    /// Panics if `x` is shorter than `dim` or `code` is shorter than `nsubq`.
    pub fn compute_code(&self, x: &[f32], code: &mut [u8]) {
        let dsub = self.dsub as usize;
        for m in 0..self.nsubq as usize {
            let d = self.sub_dim(m);
            let cstart = self.codebook_start(m);
            assign_centroid(
                &x[m * dsub..m * dsub + d],
                &self.centroids[cstart..cstart + KSUB as usize * d],
                &mut code[m],
                d,
            );
        }
    }

    /// Encodes `n` consecutive `dim`-wide rows of `x` into `codes`, `nsubq` bytes per row.
    ///
    /// # Panics
    /// Panics if `x` or `codes` is too short for `n` rows.
    pub fn compute_codes(&self, x: &[f32], codes: &mut [u8], n: i32) {
        let dim = self.dim as usize;
        let nsubq = self.nsubq as usize;
        for i in 0..n as usize {
            self.compute_code(
                &x[i * dim..(i + 1) * dim],
                &mut codes[i * nsubq..(i + 1) * nsubq],
            );
        }
    }

    /// Returns `alpha` times the dot product of `x` with the vector decoded from
    /// row `t` of `codes` (each row is `nsubq` bytes).
    pub fn mulcode(&self, x: &Vector, codes: &[u8], t: i32, alpha: f32) -> f32 {
        let nsubq = self.nsubq as usize;
        let dsub = self.dsub as usize;
        let t = t as usize;
        let code = &codes[nsubq * t..nsubq * (t + 1)];
        let mut res = 0.0f32;
        for (m, &ci) in code.iter().enumerate() {
            let base = m * dsub;
            for (n, &c) in self.get_centroids(m, ci).iter().enumerate() {
                res += x[base + n] * c;
            }
        }
        res * alpha
    }

    /// Adds `alpha` times the vector decoded from row `t` of `codes` into `x`.
    pub fn addcode(&self, x: &mut Vector, codes: &[u8], t: i32, alpha: f32) {
        let nsubq = self.nsubq as usize;
        let dsub = self.dsub as usize;
        let t = t as usize;
        let code = &codes[nsubq * t..nsubq * (t + 1)];
        for (m, &ci) in code.iter().enumerate() {
            let base = m * dsub;
            for (n, &c) in self.get_centroids(m, ci).iter().enumerate() {
                x[base + n] += alpha * c;
            }
        }
    }

    /// Writes the header (`dim`, `nsubq`, `dsub`, `lastdsub` via `utils::write_i32`)
    /// followed by every centroid value (via `utils::write_f32`).
    pub fn save<W: Write>(&self, writer: &mut W) -> Result<()> {
        utils::write_i32(writer, self.dim)?;
        utils::write_i32(writer, self.nsubq)?;
        utils::write_i32(writer, self.dsub)?;
        utils::write_i32(writer, self.lastdsub)?;
        for &v in &self.centroids {
            utils::write_f32(writer, v)?;
        }
        Ok(())
    }

    /// Reads a quantizer in the format written by [`ProductQuantizer::save`].
    ///
    /// The training RNG is not stored; it is re-seeded with the default seed.
    ///
    /// # Errors
    /// Returns `FastTextError::InvalidModel` in three cases:
    /// - a header field is negative;
    /// - the sub-vector widths do not add up to `dim`;
    /// - the centroid table size is not representable.
    ///
    /// Read errors from the underlying reader are propagated.
    pub fn load<R: Read>(reader: &mut R) -> Result<Self> {
        let dim = utils::read_i32(reader)?;
        let nsubq = utils::read_i32(reader)?;
        let dsub = utils::read_i32(reader)?;
        let lastdsub = utils::read_i32(reader)?;
        if dim < 0 || nsubq < 0 || dsub < 0 || lastdsub < 0 {
            return Err(FastTextError::InvalidModel(
                "ProductQuantizer: negative dimension field".to_string(),
            ));
        }
        // Centroid lookups stay inside the table only if the sub-vector widths add
        // up to `dim` (always true for quantizers built by `new`). Widened to i64 so
        // the check itself cannot overflow.
        let covered = (i64::from(nsubq) - 1) * i64::from(dsub) + i64::from(lastdsub);
        if nsubq > 0 && covered != i64::from(dim) {
            return Err(FastTextError::InvalidModel(
                "ProductQuantizer: sub-quantizer widths do not add up to dim".to_string(),
            ));
        }
        // Sized in usize: `dim * KSUB` in i32 overflows once dim > i32::MAX / 256.
        let centroids_len = (dim as usize).checked_mul(KSUB as usize).ok_or_else(|| {
            FastTextError::InvalidModel("ProductQuantizer: centroid table too large".to_string())
        })?;
        // Grow as values arrive rather than trusting the header for one huge
        // up-front allocation.
        const MAX_PREALLOC: usize = 1 << 20;
        let mut centroids = Vec::with_capacity(centroids_len.min(MAX_PREALLOC));
        for _ in 0..centroids_len {
            centroids.push(utils::read_f32(reader)?);
        }
        Ok(ProductQuantizer {
            dim,
            nsubq,
            dsub,
            lastdsub,
            centroids,
            rng: MinstdRng::new(SEED),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Builds a `Vector` whose entries are exactly `values`.
    fn vector_from(values: &[f32]) -> Vector {
        let mut v = Vector::new(values.len());
        for (i, &val) in values.iter().enumerate() {
            v[i] = val;
        }
        v
    }

    /// A 4-dim quantizer (two 2-wide sub-quantizers) with hand-placed centroids.
    ///
    /// - Sub-quantizer 0: centroid 0 = (1, 0), centroid 1 = (0, 1).
    /// - Sub-quantizer 1: centroid 0 = (2, 0), centroid 1 = (0, 2).
    /// - Every other centroid is (100, 100), far from all test inputs.
    fn make_known_pq() -> ProductQuantizer {
        let mut pq = ProductQuantizer::new(4, 2);
        let ksub = KSUB as usize;
        let dsub = 2usize;
        pq.centroids.fill(100.0);
        pq.centroids[..2 * dsub].copy_from_slice(&[1.0, 0.0, 0.0, 1.0]);
        let cstart1 = ksub * dsub;
        pq.centroids[cstart1..cstart1 + 2 * dsub].copy_from_slice(&[2.0, 0.0, 0.0, 2.0]);
        pq
    }

    #[test]
    fn test_pq_constants() {
        assert_eq!(NBITS, 8);
        assert_eq!(KSUB, 256);
        assert_eq!(MAX_POINTS, 65536);
        assert_eq!(NITER, 25);
    }

    #[test]
    fn test_pq_dimension_decomposition_even() {
        let pq = ProductQuantizer::new(100, 10);
        assert_eq!(pq.dim, 100);
        assert_eq!(pq.dsub, 10);
        assert_eq!(pq.nsubq, 10);
        assert_eq!(pq.lastdsub, 10);
        assert_eq!(pq.centroids.len(), 100 * 256);
    }

    #[test]
    fn test_pq_dimension_decomposition_odd() {
        let pq = ProductQuantizer::new(11, 5);
        assert_eq!(pq.dim, 11);
        assert_eq!(pq.dsub, 5);
        assert_eq!(pq.nsubq, 3);
        assert_eq!(pq.lastdsub, 1);
        assert_eq!(pq.centroids.len(), 11 * 256);
    }

    #[test]
    fn test_pq_dimension_decomposition_dim_equals_dsub() {
        let pq = ProductQuantizer::new(8, 8);
        assert_eq!(pq.nsubq, 1);
        assert_eq!(pq.lastdsub, 8);
        assert_eq!(pq.centroids.len(), 8 * 256);
    }

    #[test]
    fn test_pq_dimension_decomposition_dsub_one() {
        let pq = ProductQuantizer::new(4, 1);
        assert_eq!(pq.nsubq, 4);
        assert_eq!(pq.lastdsub, 1);
    }

    #[test]
    fn test_pq_compute_code_assigns_nearest_centroid() {
        let pq = make_known_pq();
        let x = vec![0.9f32, 0.1, 0.1, 1.9];
        let mut code = vec![0u8; 2];
        pq.compute_code(&x, &mut code);
        assert_eq!(code[0], 0, "sub-vec 0 should map to centroid 0");
        assert_eq!(code[1], 1, "sub-vec 1 should map to centroid 1");
    }

    #[test]
    fn test_pq_compute_code_second_centroid() {
        let pq = make_known_pq();
        let x = vec![0.1f32, 0.9, 1.9, 0.1];
        let mut code = vec![0u8; 2];
        pq.compute_code(&x, &mut code);
        assert_eq!(code[0], 1);
        assert_eq!(code[1], 0);
    }

    #[test]
    fn test_pq_compute_codes_all_rows() {
        let pq = make_known_pq();
        let data: Vec<f32> = vec![
            0.9, 0.1, 0.1, 1.9, // row 0
            0.1, 0.9, 1.9, 0.1, // row 1
        ];
        let n = 2i32;
        let mut codes = vec![0u8; 2 * pq.nsubq as usize];
        pq.compute_codes(&data, &mut codes, n);
        assert_eq!(codes, vec![0, 1, 1, 0]);
    }

    #[test]
    fn test_pq_mulcode_matches_naive_dot() {
        let pq = make_known_pq();
        let x = vector_from(&[0.9, 0.1, 0.1, 1.9]);
        let codes: Vec<u8> = vec![0, 1];
        let result = pq.mulcode(&x, &codes, 0, 1.0);
        let expected = 0.9f32 * 1.0 + 0.1f32 * 0.0 + 0.1f32 * 0.0 + 1.9f32 * 2.0;
        assert!(
            (result - expected).abs() < 1e-6,
            "mulcode={} expected={}",
            result,
            expected
        );
    }

    #[test]
    fn test_pq_mulcode_with_alpha() {
        let pq = make_known_pq();
        let x = vector_from(&[0.9, 0.1, 0.1, 1.9]);
        let codes: Vec<u8> = vec![0, 1];
        let alpha = 2.5f32;
        let result = pq.mulcode(&x, &codes, 0, alpha);
        let base = 0.9f32 * 1.0 + 0.1f32 * 0.0 + 0.1f32 * 0.0 + 1.9f32 * 2.0;
        let expected = base * alpha;
        assert!(
            (result - expected).abs() < 1e-5,
            "mulcode with alpha={}: got {} expected {}",
            alpha,
            result,
            expected
        );
    }

    #[test]
    fn test_pq_mulcode_multiple_rows() {
        let pq = make_known_pq();
        let x = vector_from(&[0.1, 0.9, 1.9, 0.1]);
        // Two code rows: row 0 = [1, 0], row 1 = [0, 1].
        let codes: Vec<u8> = vec![1, 0, 0, 1];
        // Row 0 decodes to (0, 1 | 2, 0).
        let row0 = pq.mulcode(&x, &codes, 0, 1.0);
        assert!((row0 - (0.9f32 + 3.8f32)).abs() < 1e-6, "row 0: {}", row0);
        // Row 1 decodes to (1, 0 | 0, 2).
        let row1 = pq.mulcode(&x, &codes, 1, 1.0);
        assert!((row1 - (0.1f32 + 0.2f32)).abs() < 1e-6, "row 1: {}", row1);
    }

    #[test]
    fn test_pq_addcode_result() {
        let pq = make_known_pq();
        let mut x = Vector::new(4);
        let codes: Vec<u8> = vec![0, 1];
        pq.addcode(&mut x, &codes, 0, 1.0);
        assert!((x[0] - 1.0).abs() < 1e-7);
        assert!((x[1] - 0.0).abs() < 1e-7);
        assert!((x[2] - 0.0).abs() < 1e-7);
        assert!((x[3] - 2.0).abs() < 1e-7);
    }

    #[test]
    fn test_pq_addcode_with_alpha() {
        let pq = make_known_pq();
        let mut x = Vector::new(4);
        let codes: Vec<u8> = vec![0, 1];
        pq.addcode(&mut x, &codes, 0, 3.0);
        assert!((x[0] - 3.0).abs() < 1e-6);
        assert!((x[1] - 0.0).abs() < 1e-6);
        assert!((x[2] - 0.0).abs() < 1e-6);
        assert!((x[3] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_pq_addcode_accumulates() {
        let pq = make_known_pq();
        let mut x = Vector::new(4);
        x[0] = 5.0;
        x[3] = 1.0;
        let codes: Vec<u8> = vec![0, 1];
        pq.addcode(&mut x, &codes, 0, 1.0);
        assert!((x[0] - 6.0).abs() < 1e-7);
        assert!((x[1] - 0.0).abs() < 1e-7);
        assert!((x[2] - 0.0).abs() < 1e-7);
        assert!((x[3] - 3.0).abs() < 1e-7);
    }

    #[test]
    fn test_pq_mulcode_addcode_consistency() {
        let pq = make_known_pq();
        let x_data = [0.3f32, 0.7, 1.1, 0.5];
        let x = vector_from(&x_data);
        let codes: Vec<u8> = vec![0, 1];
        let mul_result = pq.mulcode(&x, &codes, 0, 1.0);
        let mut recon = Vector::new(4);
        pq.addcode(&mut recon, &codes, 0, 1.0);
        let dot: f32 = x_data
            .iter()
            .zip(recon.data().iter())
            .map(|(&a, &b)| a * b)
            .sum();
        assert!(
            (mul_result - dot).abs() < 1e-6,
            "mulcode={} dot={}",
            mul_result,
            dot
        );
    }

    #[test]
    fn test_pq_save_load_roundtrip() {
        let pq = make_known_pq();
        let mut buf = Vec::new();
        pq.save(&mut buf).expect("save should succeed");
        let mut cursor = Cursor::new(&buf);
        let pq2 = ProductQuantizer::load(&mut cursor).expect("load should succeed");
        assert_eq!(pq2.dim, pq.dim);
        assert_eq!(pq2.nsubq, pq.nsubq);
        assert_eq!(pq2.dsub, pq.dsub);
        assert_eq!(pq2.lastdsub, pq.lastdsub);
        assert_eq!(pq2.centroids.len(), pq.centroids.len());
        for (a, b) in pq.centroids.iter().zip(pq2.centroids.iter()) {
            assert_eq!(a.to_bits(), b.to_bits(), "centroid mismatch");
        }
    }

    #[test]
    fn test_pq_save_load_odd_dim() {
        let mut pq = ProductQuantizer::new(11, 5);
        for (i, v) in pq.centroids.iter_mut().enumerate() {
            *v = i as f32 * 0.001;
        }
        let mut buf = Vec::new();
        pq.save(&mut buf).unwrap();
        let mut cursor = Cursor::new(&buf);
        let pq2 = ProductQuantizer::load(&mut cursor).unwrap();
        assert_eq!(pq2.dim, 11);
        assert_eq!(pq2.nsubq, 3);
        assert_eq!(pq2.dsub, 5);
        assert_eq!(pq2.lastdsub, 1);
        for (a, b) in pq.centroids.iter().zip(pq2.centroids.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    #[test]
    fn test_pq_save_byte_count() {
        let pq = ProductQuantizer::new(4, 2);
        let mut buf = Vec::new();
        pq.save(&mut buf).unwrap();
        // Four 4-byte header fields, then dim * KSUB four-byte centroid values.
        let expected = 4 * 4 + 4 * 256 * 4;
        assert_eq!(buf.len(), expected);
    }

    #[test]
    fn test_pq_save_load_preserves_compute_code() {
        let pq = make_known_pq();
        let mut buf = Vec::new();
        pq.save(&mut buf).unwrap();
        let mut cursor = Cursor::new(&buf);
        let pq2 = ProductQuantizer::load(&mut cursor).unwrap();
        let x = vec![0.9f32, 0.1, 0.1, 1.9];
        let mut code1 = vec![0u8; 2];
        let mut code2 = vec![0u8; 2];
        pq.compute_code(&x, &mut code1);
        pq2.compute_code(&x, &mut code2);
        assert_eq!(code1, code2);
    }

    #[test]
    fn test_pq_train_smoke() {
        let dim = 10i32;
        let dsub = 5i32;
        let mut pq = ProductQuantizer::new(dim, dsub);
        let n = 300i32;
        let mut data = vec![0.0f32; (n as usize) * (dim as usize)];
        for (i, v) in data.iter_mut().enumerate() {
            *v = (i as f32).sin();
        }
        pq.train(n, &data);
        let non_zero = pq.centroids.iter().any(|&v| v != 0.0);
        assert!(non_zero, "centroids should be non-zero after training");
    }

    #[test]
    fn test_pq_train_too_few_rows() {
        let mut pq = ProductQuantizer::new(4, 2);
        // Four all-zero rows (fewer than KSUB): the tiled codebook stays all zero.
        let data = vec![0.0f32; 4 * 4];
        pq.train(4, &data);
        assert!(pq.centroids.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_pq_train_small_n() {
        let dim = 4i32;
        let dsub = 2i32;
        let n = 10i32;
        let dim_u = dim as usize;
        let n_u = n as usize;
        let mut pq = ProductQuantizer::new(dim, dsub);
        let mut data = vec![0.0f32; n_u * dim_u];
        for i in 0..n_u {
            for j in 0..dim_u {
                data[i * dim_u + j] = (i as f32 + 1.0) * (j as f32 + 1.0);
            }
        }
        pq.train(n, &data);
        assert!(
            pq.centroids.iter().any(|&v| v != 0.0),
            "centroids should be non-zero after small-n training"
        );

        let nsubq = pq.nsubq as usize;
        let all_codes: Vec<Vec<u8>> = data
            .chunks(dim_u)
            .map(|xi| {
                let mut code = vec![0u8; nsubq];
                pq.compute_code(xi, &mut code);
                code
            })
            .collect();
        assert_eq!(all_codes.len(), n_u);
        assert_ne!(
            all_codes[0], all_codes[1],
            "distinct training points should get distinct codes"
        );
        for (i, codes_i) in all_codes.iter().enumerate() {
            for (m, &code_val) in codes_i.iter().enumerate() {
                assert_eq!(
                    code_val as usize, i,
                    "training point {} sub-quantizer {} should map to code {}",
                    i, m, i
                );
            }
        }
        for (i, codes_i) in all_codes.iter().enumerate() {
            let mut reconstructed = Vector::new(dim_u);
            pq.addcode(&mut reconstructed, codes_i, 0, 1.0);
            let xi = &data[i * dim_u..(i + 1) * dim_u];
            for j in 0..dim_u {
                assert!(
                    (reconstructed[j] - xi[j]).abs() < 1e-6,
                    "reconstruction mismatch at point {} dim {}: got {} expected {}",
                    i,
                    j,
                    reconstructed[j],
                    xi[j]
                );
            }
        }
    }

    #[test]
    fn test_pq_get_centroids_matches_direct_slice() {
        let pq = make_known_pq();
        let c = pq.get_centroids(0, 0);
        assert_eq!(c.len(), 2);
        assert!((c[0] - 1.0).abs() < 1e-9);
        assert!((c[1] - 0.0).abs() < 1e-9);
        let c = pq.get_centroids(0, 1);
        assert!((c[0] - 0.0).abs() < 1e-9);
        assert!((c[1] - 1.0).abs() < 1e-9);
        let c = pq.get_centroids(1, 1);
        assert!((c[0] - 0.0).abs() < 1e-9);
        assert!((c[1] - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_pq_get_centroids_last_subquantizer_width() {
        // dim 11, dsub 5: sub-quantizers are 5, 5 and 1 wide.
        let mut pq = ProductQuantizer::new(11, 5);
        let ksub = KSUB as usize;
        // Codebook 2 starts at 2 * KSUB * dsub; its centroids are lastdsub (1) wide.
        pq.centroids[2 * ksub * 5 + 3] = 7.0;
        assert_eq!(pq.get_centroids(0, 0).len(), 5);
        assert_eq!(pq.get_centroids(1, 255).len(), 5);
        let last = pq.get_centroids(2, 3);
        assert_eq!(last.len(), 1);
        assert_eq!(last[0], 7.0);
        assert_eq!(pq.get_centroids(2, 255).len(), 1);
    }
}