use serde::{Deserialize, Serialize};
use crate::backend::ExecutionBackend;
use crate::bitpack::{pack_signs, unpack_signs, BitPackedVector};
use crate::error::{Result, TurboQuantError};
use crate::turboquant_mse::TurboQuantMSE;
use crate::turboquant_prod::TurboQuantProd;
use crate::utils::inner_product;
/// A batch of MSE-quantized vectors held in bit-packed form.
///
/// Every row packs `dim` quantization indices at `bit_width` bits apiece;
/// all rows share the same `dim` and `bit_width` (enforced by
/// `validate_layout`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQuantizedMSE {
    // One packed row of `dim` quantization indices per vector.
    pub(crate) packed_indices: Vec<BitPackedVector>,
    // Bits per quantization index; must lie in 1..=8.
    pub bit_width: u8,
    // Dimensionality that every row must match; must be non-zero.
    pub dim: usize,
}
impl BatchQuantizedMSE {
pub fn validate_layout(&self) -> Result<()> {
if self.dim == 0 {
return Err(TurboQuantError::InvalidDimension(self.dim));
}
if !(1..=8).contains(&self.bit_width) {
return Err(TurboQuantError::InvalidBitWidth(self.bit_width));
}
for (row_index, packed) in self.packed_indices.iter().enumerate() {
if packed.count() != self.dim {
return Err(TurboQuantError::LengthMismatch {
context: format!("MSE batch row {row_index} packed index count"),
expected: self.dim,
got: packed.count(),
});
}
if packed.bit_width() != self.bit_width {
return Err(TurboQuantError::BitWidthMismatch {
expected: self.bit_width,
got: packed.bit_width(),
});
}
}
Ok(())
}
pub fn len(&self) -> usize {
self.packed_indices.len()
}
pub fn is_empty(&self) -> bool {
self.packed_indices.is_empty()
}
pub fn total_bytes(&self) -> usize {
self.packed_indices.iter().map(|p| p.byte_len()).sum()
}
pub fn bytes_per_vector(&self) -> f64 {
if self.packed_indices.is_empty() {
return 0.0;
}
self.total_bytes() as f64 / self.packed_indices.len() as f64
}
pub fn compression_ratio(&self) -> f64 {
let uncompressed = self.packed_indices.len() * self.dim * 4;
if self.total_bytes() == 0 {
return 0.0;
}
uncompressed as f64 / self.total_bytes() as f64
}
pub fn get_packed(&self, index: usize) -> Option<&BitPackedVector> {
self.packed_indices.get(index)
}
pub fn unpack_indices(&self, index: usize) -> Option<Vec<u8>> {
self.packed_indices.get(index).map(|p| p.unpack())
}
pub fn extend(&mut self, other: &BatchQuantizedMSE) -> Result<()> {
if self.dim != other.dim {
return Err(TurboQuantError::DimensionMismatch {
expected: self.dim,
got: other.dim,
});
}
if self.bit_width != other.bit_width {
return Err(TurboQuantError::BitWidthMismatch {
expected: self.bit_width,
got: other.bit_width,
});
}
self.packed_indices.extend_from_slice(&other.packed_indices);
Ok(())
}
pub fn drain_front(&mut self, n: usize) {
let n = n.min(self.packed_indices.len());
self.packed_indices.drain(..n);
}
}
/// A batch of product-quantized vectors: per row, a bit-packed MSE index
/// vector, a bit-packed QJL sign vector, and a residual norm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQuantizedProd {
    // One row of `dim` MSE indices, packed at `bit_width - 1` bits each.
    pub(crate) packed_mse_indices: Vec<BitPackedVector>,
    // One row of `dim` sign bits, packed into `dim.div_ceil(8)` bytes.
    pub(crate) packed_qjl_signs: Vec<Vec<u8>>,
    // Residual norm per row; must be finite and non-negative.
    pub(crate) residual_norms: Vec<f64>,
    // Total bits per component; must lie in 2..=8. One bit goes to the QJL
    // sign, the remaining `bit_width - 1` bits to the MSE index.
    pub bit_width: u8,
    // Dimensionality that every row must match; must be non-zero.
    pub dim: usize,
}
impl BatchQuantizedProd {
    /// Verify that the three per-row collections agree in length, that each
    /// row matches the batch `dim` and expected packed widths, and that
    /// every residual norm is finite and non-negative.
    ///
    /// # Errors
    /// The first inconsistency found, as `InvalidDimension`,
    /// `InvalidBitWidth`, `LengthMismatch`, `BitWidthMismatch`, or
    /// `InvalidValue`.
    pub fn validate_layout(&self) -> Result<()> {
        if self.dim == 0 {
            return Err(TurboQuantError::InvalidDimension(self.dim));
        }
        if !(2..=8).contains(&self.bit_width) {
            return Err(TurboQuantError::InvalidBitWidth(self.bit_width));
        }
        let row_count = self.packed_mse_indices.len();
        if self.packed_qjl_signs.len() != row_count {
            return Err(TurboQuantError::LengthMismatch {
                context: "Prod batch QJL row count".into(),
                expected: row_count,
                got: self.packed_qjl_signs.len(),
            });
        }
        if self.residual_norms.len() != row_count {
            return Err(TurboQuantError::LengthMismatch {
                context: "Prod batch residual norm count".into(),
                expected: row_count,
                got: self.residual_norms.len(),
            });
        }
        // One bit of each component is spent on the QJL sign; the rest hold
        // the MSE index.
        let expected_mse_bit_width = self.bit_width - 1;
        let expected_sign_bytes = self.dim.div_ceil(8);
        for (row_index, packed) in self.packed_mse_indices.iter().enumerate() {
            let count = packed.count();
            if count != self.dim {
                return Err(TurboQuantError::LengthMismatch {
                    context: format!("Prod batch row {row_index} packed MSE index count"),
                    expected: self.dim,
                    got: count,
                });
            }
            let width = packed.bit_width();
            if width != expected_mse_bit_width {
                return Err(TurboQuantError::BitWidthMismatch {
                    expected: expected_mse_bit_width,
                    got: width,
                });
            }
        }
        for (row_index, signs) in self.packed_qjl_signs.iter().enumerate() {
            if signs.len() != expected_sign_bytes {
                return Err(TurboQuantError::LengthMismatch {
                    context: format!("Prod batch row {row_index} packed QJL byte length"),
                    expected: expected_sign_bytes,
                    got: signs.len(),
                });
            }
        }
        for (row_index, norm) in self.residual_norms.iter().enumerate() {
            if !norm.is_finite() || *norm < 0.0 {
                return Err(TurboQuantError::InvalidValue {
                    context: format!("Prod batch residual norm at row {row_index}"),
                    value: *norm,
                });
            }
        }
        Ok(())
    }

    /// Number of vectors stored in the batch.
    pub fn len(&self) -> usize {
        self.packed_mse_indices.len()
    }

    /// True when the batch holds no vectors.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total packed payload size in bytes across all rows.
    ///
    /// NOTE(review): residual norms are held in memory as `f64` (8 bytes
    /// each) but are accounted here at 4 bytes apiece — presumably assuming
    /// an f32 serialized form; confirm against the serialization path.
    pub fn total_bytes(&self) -> usize {
        let mse_bytes: usize = self.packed_mse_indices.iter().map(|p| p.byte_len()).sum();
        let qjl_bytes: usize = self.packed_qjl_signs.iter().map(Vec::len).sum();
        let norm_bytes = self.residual_norms.len() * 4;
        mse_bytes + qjl_bytes + norm_bytes
    }

    /// Ratio of a 4-bytes-per-component baseline (presumably f32 storage)
    /// to the packed size; 0.0 when nothing is stored.
    pub fn compression_ratio(&self) -> f64 {
        let total = self.total_bytes();
        if total == 0 {
            return 0.0;
        }
        let uncompressed = self.packed_mse_indices.len() * self.dim * 4;
        uncompressed as f64 / total as f64
    }

    /// Unpack the MSE index row at `index` into one `u8` per component.
    pub fn unpack_mse_indices(&self, index: usize) -> Option<Vec<u8>> {
        self.packed_mse_indices
            .get(index)
            .map(BitPackedVector::unpack)
    }

    /// Unpack the QJL sign row at `index` into one `bool` per component.
    pub fn unpack_qjl_signs(&self, index: usize) -> Option<Vec<bool>> {
        self.packed_qjl_signs
            .get(index)
            .map(|bytes| unpack_signs(bytes, self.dim))
    }

    /// Residual norm recorded for row `index`, if present.
    pub fn residual_norm(&self, index: usize) -> Option<f64> {
        self.residual_norms.get(index).copied()
    }

    /// Append every row of `other` to this batch.
    ///
    /// # Errors
    /// `DimensionMismatch` or `BitWidthMismatch` when the two batches do not
    /// share the same parameters; on error `self` is left untouched.
    pub fn extend(&mut self, other: &BatchQuantizedProd) -> Result<()> {
        if self.dim != other.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: other.dim,
            });
        }
        if self.bit_width != other.bit_width {
            return Err(TurboQuantError::BitWidthMismatch {
                expected: self.bit_width,
                got: other.bit_width,
            });
        }
        self.packed_mse_indices
            .extend_from_slice(&other.packed_mse_indices);
        self.packed_qjl_signs
            .extend_from_slice(&other.packed_qjl_signs);
        self.residual_norms.extend_from_slice(&other.residual_norms);
        Ok(())
    }

    /// Drop up to `n` rows from the front of all three per-row collections,
    /// keeping them in lockstep.
    pub fn drain_front(&mut self, n: usize) {
        let take = n.min(self.packed_mse_indices.len());
        self.packed_mse_indices.drain(..take);
        self.packed_qjl_signs.drain(..take);
        self.residual_norms.drain(..take);
    }
}
/// Quantize `vectors` with `quantizer` on the default execution backend.
///
/// Convenience wrapper around [`batch_quantize_mse_with_backend`].
pub fn batch_quantize_mse(
    quantizer: &TurboQuantMSE,
    vectors: &[Vec<f64>],
) -> Result<BatchQuantizedMSE> {
    let backend = ExecutionBackend::default();
    batch_quantize_mse_with_backend(backend, quantizer, vectors)
}
/// Quantize `vectors` with `quantizer`, dispatching to the wgpu path when
/// the `gpu` feature is compiled in and `backend` selects it; otherwise
/// falls through to the CPU implementation.
pub fn batch_quantize_mse_with_backend(
    backend: ExecutionBackend,
    quantizer: &TurboQuantMSE,
    vectors: &[Vec<f64>],
) -> Result<BatchQuantizedMSE> {
    // Without the `gpu` feature the backend choice is moot; discard it to
    // silence the unused-variable warning.
    #[cfg(not(feature = "gpu"))]
    let _ = backend;
    #[cfg(feature = "gpu")]
    if matches!(backend, ExecutionBackend::Wgpu) {
        return crate::gpu::batch_quantize_mse_wgpu(quantizer, vectors);
    }
    batch_quantize_mse_cpu(quantizer, vectors)
}
/// CPU path: quantize each vector and bit-pack its index row, stopping at
/// the first failure.
fn batch_quantize_mse_cpu(
    quantizer: &TurboQuantMSE,
    vectors: &[Vec<f64>],
) -> Result<BatchQuantizedMSE> {
    let packed_indices = vectors
        .iter()
        .map(|v| {
            let q = quantizer.quantize(v)?;
            BitPackedVector::pack(&q.indices, q.bit_width)
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(BatchQuantizedMSE {
        packed_indices,
        bit_width: quantizer.bit_width,
        dim: quantizer.dim,
    })
}
/// Dequantize every row of `batch` on the default execution backend.
///
/// Convenience wrapper around [`batch_dequantize_mse_with_backend`].
pub fn batch_dequantize_mse(
    quantizer: &TurboQuantMSE,
    batch: &BatchQuantizedMSE,
) -> Result<Vec<Vec<f64>>> {
    let backend = ExecutionBackend::default();
    batch_dequantize_mse_with_backend(backend, quantizer, batch)
}
/// Dequantize every row of `batch`, dispatching to the wgpu path when the
/// `gpu` feature is compiled in and `backend` selects it.
///
/// # Errors
/// Fails if the batch layout is inconsistent or a row fails to dequantize.
pub fn batch_dequantize_mse_with_backend(
    backend: ExecutionBackend,
    quantizer: &TurboQuantMSE,
    batch: &BatchQuantizedMSE,
) -> Result<Vec<Vec<f64>>> {
    // Reject malformed batches before touching any row.
    batch.validate_layout()?;
    // Without the `gpu` feature the backend choice is moot; discard it to
    // silence the unused-variable warning.
    #[cfg(not(feature = "gpu"))]
    let _ = backend;
    #[cfg(feature = "gpu")]
    if matches!(backend, ExecutionBackend::Wgpu) {
        return crate::gpu::batch_dequantize_mse_wgpu(quantizer, batch);
    }
    batch_dequantize_mse_cpu(quantizer, batch)
}
/// CPU path: unpack each row and run it back through the quantizer's
/// dequantizer.
fn batch_dequantize_mse_cpu(
    quantizer: &TurboQuantMSE,
    batch: &BatchQuantizedMSE,
) -> Result<Vec<Vec<f64>>> {
    (0..batch.len())
        .map(|i| {
            // `i` is always in range here; the error arm is defensive only.
            let indices = batch.unpack_indices(i).ok_or_else(|| {
                TurboQuantError::Internal(format!("batch index {} out of range", i))
            })?;
            let qv = crate::turboquant_mse::QuantizedVector {
                indices,
                bit_width: batch.bit_width,
                dim: batch.dim,
            };
            quantizer.dequantize(&qv)
        })
        .collect()
}
/// Product-quantize `vectors` with `quantizer` on the default backend.
///
/// Convenience wrapper around [`batch_quantize_prod_with_backend`].
pub fn batch_quantize_prod(
    quantizer: &TurboQuantProd,
    vectors: &[Vec<f64>],
) -> Result<BatchQuantizedProd> {
    let backend = ExecutionBackend::default();
    batch_quantize_prod_with_backend(backend, quantizer, vectors)
}
/// Product-quantize `vectors`, dispatching to the wgpu path when the `gpu`
/// feature is compiled in and `backend` selects it; otherwise runs on CPU.
pub fn batch_quantize_prod_with_backend(
    backend: ExecutionBackend,
    quantizer: &TurboQuantProd,
    vectors: &[Vec<f64>],
) -> Result<BatchQuantizedProd> {
    // Without the `gpu` feature the backend choice is moot; discard it to
    // silence the unused-variable warning.
    #[cfg(not(feature = "gpu"))]
    let _ = backend;
    #[cfg(feature = "gpu")]
    if matches!(backend, ExecutionBackend::Wgpu) {
        return crate::gpu::batch_quantize_prod_wgpu(quantizer, vectors);
    }
    batch_quantize_prod_cpu(quantizer, vectors)
}
/// CPU path: quantize each vector, packing the MSE indices at
/// `bit_width - 1` bits each and the QJL signs one bit apiece, and record
/// the residual norm per row.
fn batch_quantize_prod_cpu(
    quantizer: &TurboQuantProd,
    vectors: &[Vec<f64>],
) -> Result<BatchQuantizedProd> {
    let n = vectors.len();
    let mut packed_mse_indices = Vec::with_capacity(n);
    let mut packed_qjl_signs = Vec::with_capacity(n);
    let mut residual_norms = Vec::with_capacity(n);
    for vector in vectors {
        let q = quantizer.quantize(vector)?;
        packed_mse_indices.push(BitPackedVector::pack(&q.mse_indices, q.bit_width - 1)?);
        packed_qjl_signs.push(pack_signs(&q.qjl_signs));
        residual_norms.push(q.residual_norm);
    }
    Ok(BatchQuantizedProd {
        packed_mse_indices,
        packed_qjl_signs,
        residual_norms,
        bit_width: quantizer.bit_width,
        dim: quantizer.dim,
    })
}
/// Estimate the inner product of `query` with every row of `batch` on the
/// default execution backend.
///
/// Convenience wrapper around [`batch_estimate_inner_products_with_backend`].
pub fn batch_estimate_inner_products(
    quantizer: &TurboQuantProd,
    batch: &BatchQuantizedProd,
    query: &[f64],
) -> Result<Vec<f64>> {
    let backend = ExecutionBackend::default();
    batch_estimate_inner_products_with_backend(backend, quantizer, batch, query)
}
/// Estimate the inner product of `query` with every row of `batch`,
/// dispatching to the wgpu path when the `gpu` feature is compiled in and
/// `backend` selects it.
///
/// # Errors
/// Fails when the batch layout is inconsistent, when
/// `query.len() != batch.dim`, or when an underlying estimate fails.
pub fn batch_estimate_inner_products_with_backend(
    backend: ExecutionBackend,
    quantizer: &TurboQuantProd,
    batch: &BatchQuantizedProd,
    query: &[f64],
) -> Result<Vec<f64>> {
    // Reject malformed batches before touching any row.
    batch.validate_layout()?;
    // Without the `gpu` feature the backend choice is moot; discard it to
    // silence the unused-variable warning.
    #[cfg(not(feature = "gpu"))]
    let _ = backend;
    if query.len() != batch.dim {
        return Err(TurboQuantError::DimensionMismatch {
            expected: batch.dim,
            got: query.len(),
        });
    }
    #[cfg(feature = "gpu")]
    if matches!(backend, ExecutionBackend::Wgpu) {
        return crate::gpu::batch_estimate_inner_products_wgpu(quantizer, batch, query);
    }
    batch_estimate_inner_products_cpu(quantizer, batch, query)
}
/// CPU path: rebuild each row's `ProdQuantized` view and ask the quantizer
/// for its inner-product estimate against `query`.
///
/// Callers reach this only after `validate_layout`, so the three per-row
/// collections are the same length.
fn batch_estimate_inner_products_cpu(
    quantizer: &TurboQuantProd,
    batch: &BatchQuantizedProd,
    query: &[f64],
) -> Result<Vec<f64>> {
    batch
        .packed_mse_indices
        .iter()
        .zip(&batch.packed_qjl_signs)
        .zip(&batch.residual_norms)
        .map(|((mse_packed, sign_bytes), &residual_norm)| {
            let pq = crate::turboquant_prod::ProdQuantized {
                mse_indices: mse_packed.unpack(),
                qjl_signs: unpack_signs(sign_bytes, batch.dim),
                residual_norm,
                bit_width: batch.bit_width,
                dim: batch.dim,
            };
            quantizer.estimate_inner_product(&pq, query)
        })
        .collect()
}
/// Compute attention scores of `query` against quantized `keys` on the
/// default execution backend.
///
/// Convenience wrapper around [`batch_attention_scores_mse_with_backend`].
pub fn batch_attention_scores_mse(
    quantizer: &TurboQuantMSE,
    keys: &BatchQuantizedMSE,
    query: &[f64],
) -> Result<Vec<f64>> {
    let backend = ExecutionBackend::default();
    batch_attention_scores_mse_with_backend(backend, quantizer, keys, query)
}
/// Compute attention scores of `query` against quantized `keys`: each score
/// is the inner product of `query` with a dequantized key. Dispatches to
/// the wgpu path when the `gpu` feature is compiled in and `backend`
/// selects it.
///
/// # Errors
/// Fails when the key batch layout is inconsistent, when
/// `query.len() != keys.dim`, or when dequantization fails.
pub fn batch_attention_scores_mse_with_backend(
    backend: ExecutionBackend,
    quantizer: &TurboQuantMSE,
    keys: &BatchQuantizedMSE,
    query: &[f64],
) -> Result<Vec<f64>> {
    // Reject malformed batches before touching any row.
    keys.validate_layout()?;
    // Without the `gpu` feature the backend choice is moot; discard it to
    // silence the unused-variable warning.
    #[cfg(not(feature = "gpu"))]
    let _ = backend;
    if query.len() != keys.dim {
        return Err(TurboQuantError::DimensionMismatch {
            expected: keys.dim,
            got: query.len(),
        });
    }
    #[cfg(feature = "gpu")]
    if matches!(backend, ExecutionBackend::Wgpu) {
        return crate::gpu::batch_attention_scores_mse_wgpu(quantizer, keys, query);
    }
    // CPU fallback: reconstruct all keys, then score each against the query.
    let reconstructed = batch_dequantize_mse_cpu(quantizer, keys)?;
    let scores: Vec<f64> = reconstructed
        .iter()
        .map(|k| inner_product(k, query))
        .collect();
    Ok(scores)
}
/// Mean of `quantizer.actual_mse` over `vectors`; 0.0 for an empty slice.
pub fn batch_mse(quantizer: &TurboQuantMSE, vectors: &[Vec<f64>]) -> Result<f64> {
    if vectors.is_empty() {
        return Ok(0.0);
    }
    let total: f64 = vectors
        .iter()
        .map(|v| quantizer.actual_mse(v))
        .sum::<Result<f64>>()?;
    Ok(total / vectors.len() as f64)
}
/// Mean absolute error between true and estimated inner products of `query`
/// with each vector, after a quantize round-trip; 0.0 for an empty slice.
pub fn batch_ip_error(
    quantizer: &TurboQuantProd,
    vectors: &[Vec<f64>],
    query: &[f64],
) -> Result<f64> {
    if vectors.is_empty() {
        return Ok(0.0);
    }
    let batch = batch_quantize_prod(quantizer, vectors)?;
    let estimated = batch_estimate_inner_products(quantizer, &batch, query)?;
    let total_err: f64 = vectors
        .iter()
        .zip(&estimated)
        .map(|(v, est)| (inner_product(v, query) - est).abs())
        .sum();
    Ok(total_err / vectors.len() as f64)
}
/// Size and compression summary for a quantized batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStats {
    // Number of vectors in the batch.
    pub count: usize,
    // Dimensionality of each vector.
    pub dim: usize,
    // Bits per component of the quantized representation.
    pub bit_width: u8,
    // Packed payload size in bytes.
    pub total_bytes: usize,
    // Baseline size at 4 bytes per component (presumably f32 storage).
    pub uncompressed_bytes: usize,
    // uncompressed_bytes / total_bytes; 0.0 for an empty batch.
    pub compression_ratio: f64,
    // Average packed bytes per vector; 0.0 for an empty batch.
    pub bytes_per_vector: f64,
}
impl BatchStats {
    /// Summarize an MSE-quantized batch.
    pub fn from_mse_batch(batch: &BatchQuantizedMSE) -> Self {
        Self {
            count: batch.len(),
            dim: batch.dim,
            bit_width: batch.bit_width,
            total_bytes: batch.total_bytes(),
            uncompressed_bytes: batch.len() * batch.dim * 4,
            compression_ratio: batch.compression_ratio(),
            bytes_per_vector: batch.bytes_per_vector(),
        }
    }

    /// Summarize a product-quantized batch.
    pub fn from_prod_batch(batch: &BatchQuantizedProd) -> Self {
        let count = batch.len();
        let total_bytes = batch.total_bytes();
        Self {
            count,
            dim: batch.dim,
            bit_width: batch.bit_width,
            total_bytes,
            uncompressed_bytes: count * batch.dim * 4,
            compression_ratio: batch.compression_ratio(),
            // `max(1)` keeps the division defined (0.0) for an empty batch.
            bytes_per_vector: total_bytes as f64 / count.max(1) as f64,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering quantize/dequantize round-trips, inner-product
    //! estimation, batch bookkeeping (extend/drain/stats), and layout
    //! validation failure paths.
    use super::*;
    use crate::utils::normalize;

    /// Deterministically generate `count` unit vectors of dimension `dim`
    /// from standard-normal samples seeded with `seed`.
    fn random_unit_vectors(dim: usize, count: usize, seed: u64) -> Vec<Vec<f64>> {
        use rand::SeedableRng;
        use rand_distr::{Distribution, Normal};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();
        (0..count)
            .map(|_| {
                let raw: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
                normalize(&raw).unwrap()
            })
            .collect()
    }

    // Quantize then dequantize; each reconstruction should stay close.
    #[test]
    fn test_batch_quantize_mse_roundtrip() {
        let dim = 64;
        let vectors = random_unit_vectors(dim, 10, 42);
        let tq = TurboQuantMSE::new(dim, 4, 42).unwrap();
        let batch = batch_quantize_mse(&tq, &vectors).unwrap();
        assert_eq!(batch.len(), 10);
        assert_eq!(batch.dim, dim);
        assert_eq!(batch.bit_width, 4);
        let reconstructed = batch_dequantize_mse(&tq, &batch).unwrap();
        assert_eq!(reconstructed.len(), 10);
        for (orig, recon) in vectors.iter().zip(reconstructed.iter()) {
            // Per-vector mean squared reconstruction error.
            let mse: f64 = orig
                .iter()
                .zip(recon.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum::<f64>()
                / dim as f64;
            assert!(mse < 0.1, "MSE {} too large", mse);
        }
    }

    // Prod quantization: estimated inner products should track true ones.
    #[test]
    fn test_batch_quantize_prod() {
        let dim = 64;
        let vectors = random_unit_vectors(dim, 8, 7);
        let query = &random_unit_vectors(dim, 1, 99)[0];
        let tq = TurboQuantProd::new(dim, 3, 7).unwrap();
        let batch = batch_quantize_prod(&tq, &vectors).unwrap();
        assert_eq!(batch.len(), 8);
        let scores = batch_estimate_inner_products(&tq, &batch, query).unwrap();
        assert_eq!(scores.len(), 8);
        for (i, &score) in scores.iter().enumerate() {
            let true_ip = inner_product(&vectors[i], query);
            assert!(
                (true_ip - score).abs() < 0.5,
                "vector {}: true={}, est={}",
                i,
                true_ip,
                score
            );
        }
    }

    // 4-bit packing of a 4-byte baseline should give a ~8x ratio.
    #[test]
    fn test_batch_compression_ratio() {
        let dim = 128;
        let vectors = random_unit_vectors(dim, 20, 1);
        let tq = TurboQuantMSE::new(dim, 4, 1).unwrap();
        let batch = batch_quantize_mse(&tq, &vectors).unwrap();
        let stats = BatchStats::from_mse_batch(&batch);
        assert!(
            (stats.compression_ratio - 8.0).abs() < 0.1,
            "ratio={}",
            stats.compression_ratio
        );
        assert_eq!(stats.count, 20);
        // 128 dims * 4 bits = 64 bytes per vector, 20 vectors.
        assert_eq!(stats.total_bytes, 20 * 64);
    }

    // Smoke test: one score per quantized key.
    #[test]
    fn test_batch_attention_scores_mse() {
        let dim = 64;
        let keys = random_unit_vectors(dim, 5, 42);
        let query = &random_unit_vectors(dim, 1, 99)[0];
        let tq = TurboQuantMSE::new(dim, 4, 42).unwrap();
        let batch = batch_quantize_mse(&tq, &keys).unwrap();
        let scores = batch_attention_scores_mse(&tq, &batch, query).unwrap();
        assert_eq!(scores.len(), 5);
    }

    // Average MSE should be positive and near the quantizer's bound.
    #[test]
    fn test_batch_mse_stat() {
        let dim = 128;
        let vectors = random_unit_vectors(dim, 10, 3);
        let tq = TurboQuantMSE::new(dim, 4, 3).unwrap();
        let avg_mse = batch_mse(&tq, &vectors).unwrap();
        let bound = tq.distortion_bound();
        assert!(avg_mse > 0.0, "MSE should be positive");
        assert!(
            avg_mse < bound * 10.0,
            "MSE {} too far from bound {}",
            avg_mse,
            bound
        );
    }

    // An empty input slice yields an empty, zero-byte batch.
    #[test]
    fn test_empty_batch() {
        let tq = TurboQuantMSE::new(64, 4, 1).unwrap();
        let batch = batch_quantize_mse(&tq, &[]).unwrap();
        assert!(batch.is_empty());
        assert_eq!(batch.total_bytes(), 0);
    }

    // Average inner-product error after a prod round-trip stays bounded.
    #[test]
    fn test_batch_ip_error() {
        let dim = 64;
        let vectors = random_unit_vectors(dim, 10, 5);
        let query = &random_unit_vectors(dim, 1, 50)[0];
        let tq = TurboQuantProd::new(dim, 3, 5).unwrap();
        let avg_err = batch_ip_error(&tq, &vectors, query).unwrap();
        assert!(avg_err < 0.5, "avg IP error {} too large", avg_err);
    }

    // Extending an MSE batch with mismatched dimensions must fail.
    #[test]
    fn test_batch_mse_extend_dimension_mismatch() {
        let tq64 = TurboQuantMSE::new(64, 4, 1).unwrap();
        let vecs64 = random_unit_vectors(64, 3, 1);
        let mut batch64 = batch_quantize_mse(&tq64, &vecs64).unwrap();
        let tq32 = TurboQuantMSE::new(32, 4, 1).unwrap();
        let vecs32 = random_unit_vectors(32, 3, 2);
        let batch32 = batch_quantize_mse(&tq32, &vecs32).unwrap();
        assert!(batch64.extend(&batch32).is_err());
    }

    // Extending a prod batch with mismatched dimensions must fail.
    #[test]
    fn test_batch_prod_extend_dimension_mismatch() {
        let tq64 = TurboQuantProd::new(64, 3, 1).unwrap();
        let vecs64 = random_unit_vectors(64, 3, 1);
        let mut batch64 = batch_quantize_prod(&tq64, &vecs64).unwrap();
        let tq32 = TurboQuantProd::new(32, 3, 1).unwrap();
        let vecs32 = random_unit_vectors(32, 3, 2);
        let batch32 = batch_quantize_prod(&tq32, &vecs32).unwrap();
        assert!(batch64.extend(&batch32).is_err());
    }

    // drain_front removes from the front and clamps oversized requests.
    #[test]
    fn test_batch_drain_front() {
        let dim = 64;
        let vecs = random_unit_vectors(dim, 10, 1);
        let tq = TurboQuantMSE::new(dim, 4, 1).unwrap();
        let mut batch = batch_quantize_mse(&tq, &vecs).unwrap();
        assert_eq!(batch.len(), 10);
        batch.drain_front(3);
        assert_eq!(batch.len(), 7);
        batch.drain_front(100);
        assert_eq!(batch.len(), 0);
        assert!(batch.is_empty());
    }

    // Prod drain_front keeps all three per-row collections in lockstep.
    #[test]
    fn test_batch_prod_drain_front() {
        let dim = 64;
        let vecs = random_unit_vectors(dim, 8, 1);
        let tq = TurboQuantProd::new(dim, 3, 1).unwrap();
        let mut batch = batch_quantize_prod(&tq, &vecs).unwrap();
        assert_eq!(batch.len(), 8);
        batch.drain_front(5);
        assert_eq!(batch.len(), 3);
    }

    // Empty inputs give zero error statistics rather than failing.
    #[test]
    fn test_batch_empty_mse_and_ip_error() {
        let tq_mse = TurboQuantMSE::new(64, 4, 1).unwrap();
        assert_eq!(batch_mse(&tq_mse, &[]).unwrap(), 0.0);
        let tq_prod = TurboQuantProd::new(64, 3, 1).unwrap();
        let query = &random_unit_vectors(64, 1, 1)[0];
        assert_eq!(batch_ip_error(&tq_prod, &[], query).unwrap(), 0.0);
    }

    // A query of the wrong dimensionality is rejected.
    #[test]
    fn test_batch_attention_scores_dim_mismatch() {
        let dim = 64;
        let vecs = random_unit_vectors(dim, 5, 1);
        let tq = TurboQuantMSE::new(dim, 4, 1).unwrap();
        let batch = batch_quantize_mse(&tq, &vecs).unwrap();
        let bad_query = vec![0.0; 32];
        assert!(batch_attention_scores_mse(&tq, &batch, &bad_query).is_err());
    }

    // Prod batch stats report positive sizes and real compression.
    #[test]
    fn test_batch_stats_prod() {
        let dim = 64;
        let vecs = random_unit_vectors(dim, 5, 1);
        let tq = TurboQuantProd::new(dim, 3, 1).unwrap();
        let batch = batch_quantize_prod(&tq, &vecs).unwrap();
        let stats = BatchStats::from_prod_batch(&batch);
        assert_eq!(stats.count, 5);
        assert_eq!(stats.dim, 64);
        assert!(stats.total_bytes > 0);
        assert!(stats.compression_ratio > 1.0);
    }

    // Extending an MSE batch with a different bit width must fail.
    #[test]
    fn test_batch_mse_extend_bit_width_mismatch() {
        let tq2 = TurboQuantMSE::new(64, 2, 1).unwrap();
        let tq4 = TurboQuantMSE::new(64, 4, 1).unwrap();
        let vecs = random_unit_vectors(64, 2, 1);
        let mut batch2 = batch_quantize_mse(&tq2, &vecs).unwrap();
        let batch4 = batch_quantize_mse(&tq4, &vecs).unwrap();
        assert!(matches!(
            batch2.extend(&batch4),
            Err(TurboQuantError::BitWidthMismatch { .. })
        ));
    }

    // Extending a prod batch with a different bit width must fail.
    #[test]
    fn test_batch_prod_extend_bit_width_mismatch() {
        let tq3 = TurboQuantProd::new(64, 3, 1).unwrap();
        let tq4 = TurboQuantProd::new(64, 4, 1).unwrap();
        let vecs = random_unit_vectors(64, 2, 1);
        let mut batch3 = batch_quantize_prod(&tq3, &vecs).unwrap();
        let batch4 = batch_quantize_prod(&tq4, &vecs).unwrap();
        assert!(matches!(
            batch3.extend(&batch4),
            Err(TurboQuantError::BitWidthMismatch { .. })
        ));
    }

    // Extending with matching parameters concatenates and round-trips.
    #[test]
    fn test_batch_mse_extend_same_params() {
        let dim = 64;
        let vecs1 = random_unit_vectors(dim, 5, 1);
        let vecs2 = random_unit_vectors(dim, 3, 2);
        let tq = TurboQuantMSE::new(dim, 4, 1).unwrap();
        let mut batch1 = batch_quantize_mse(&tq, &vecs1).unwrap();
        let batch2 = batch_quantize_mse(&tq, &vecs2).unwrap();
        batch1.extend(&batch2).unwrap();
        assert_eq!(batch1.len(), 8);
        let recon = batch_dequantize_mse(&tq, &batch1).unwrap();
        assert_eq!(recon.len(), 8);
    }

    // Inner-product estimation rejects a wrong-dimension query.
    #[test]
    fn test_batch_ip_dim_mismatch() {
        let vecs = random_unit_vectors(64, 5, 1);
        let tq = TurboQuantProd::new(64, 3, 1).unwrap();
        let batch = batch_quantize_prod(&tq, &vecs).unwrap();
        let bad_query = vec![0.0; 32];
        assert!(batch_estimate_inner_products(&tq, &batch, &bad_query).is_err());
    }

    // A tampered batch-level bit_width is caught by validate_layout.
    #[test]
    fn test_batch_dequantize_rejects_packed_bit_width_mismatch() {
        let dim = 32;
        let tq = TurboQuantMSE::new(dim, 4, 1).unwrap();
        let vectors = random_unit_vectors(dim, 2, 1);
        let mut batch = batch_quantize_mse(&tq, &vectors).unwrap();
        batch.bit_width = 3;
        assert!(matches!(
            batch_dequantize_mse(&tq, &batch),
            Err(TurboQuantError::BitWidthMismatch { .. })
        ));
    }

    // Dropping a QJL row breaks the row-count invariant and is rejected.
    #[test]
    fn test_batch_prod_rejects_missing_qjl_rows() {
        let dim = 32;
        let tq = TurboQuantProd::new(dim, 3, 1).unwrap();
        let vectors = random_unit_vectors(dim, 2, 1);
        let query = &random_unit_vectors(dim, 1, 99)[0];
        let mut batch = batch_quantize_prod(&tq, &vectors).unwrap();
        batch.packed_qjl_signs.pop();
        assert!(matches!(
            batch_estimate_inner_products(&tq, &batch, query),
            Err(TurboQuantError::LengthMismatch { .. })
        ));
    }

    // A truncated sign buffer breaks the byte-length invariant.
    #[test]
    fn test_batch_prod_rejects_short_sign_buffer() {
        let dim = 32;
        let tq = TurboQuantProd::new(dim, 3, 1).unwrap();
        let vectors = random_unit_vectors(dim, 2, 1);
        let query = &random_unit_vectors(dim, 1, 99)[0];
        let mut batch = batch_quantize_prod(&tq, &vectors).unwrap();
        batch.packed_qjl_signs[0].pop();
        assert!(matches!(
            batch_estimate_inner_products(&tq, &batch, query),
            Err(TurboQuantError::LengthMismatch { .. })
        ));
    }
}