use burn::backend::{wgpu::WgpuDevice, Wgpu};
use burn::tensor::{Int, Tensor, TensorData};
use crate::batch::{BatchQuantizedMSE, BatchQuantizedProd};
use crate::bitpack::{pack_signs, BitPackedVector};
use crate::error::{Result, TurboQuantError};
use crate::qjl::QJLQuantized;
use crate::turboquant_mse::TurboQuantMSE;
use crate::turboquant_prod::TurboQuantProd;
use crate::utils::validate_unit_vector;
// Concrete burn backend: WGPU with f32 floats and i32 ints.
type GpuBackend = Wgpu<f32, i32>;
type FloatTensor<const D: usize> = Tensor<GpuBackend, D>;
type IntTensor<const D: usize> = Tensor<GpuBackend, D, Int>;
// One row of codebook indices (one u8 per component) per input vector.
type QuantizedIndexRows = Vec<Vec<u8>>;
// One dequantized (reconstructed) vector per input vector.
type ReconstructedRows = Vec<Vec<f64>>;
type QuantizedWithReconstruction = (QuantizedIndexRows, ReconstructedRows);
/// GPU (WGPU) batch runner for the MSE quantization stage: holds the
/// quantizer's rotation matrices and scalar codebook uploaded as device
/// tensors so repeated batch calls avoid re-uploading them.
pub struct WgpuMseBatchRunner {
    // Target WGPU device; every tensor below lives on it.
    device: WgpuDevice,
    // dim x dim rotation matrix R (uploaded row-major).
    rotation: FloatTensor<2>,
    // Transpose of R, used to rotate inputs into codebook space.
    rotation_t: FloatTensor<2>,
    // Scalar quantization boundaries (len = boundary_count).
    boundaries: FloatTensor<1>,
    // Centroid value per quantization level (len = level_count).
    centroids: FloatTensor<1>,
    // Vector dimensionality shared by all batches.
    dim: usize,
    // Number of codebook boundaries.
    boundary_count: usize,
    // Number of codebook centroids (quantization levels).
    level_count: usize,
}
impl WgpuMseBatchRunner {
    /// Build a runner for `quantizer` on the default WGPU device.
    pub fn new(quantizer: &TurboQuantMSE) -> Result<Self> {
        Self::with_device(quantizer, WgpuDevice::default())
    }

    /// Upload the quantizer's rotation matrix, its transpose, and the scalar
    /// codebook (boundaries + centroids) to `device` as f32 tensors.
    pub fn with_device(quantizer: &TurboQuantMSE, device: WgpuDevice) -> Result<Self> {
        let matrix = quantizer.rotation().matrix();
        let dim = quantizer.dim;
        let rotation = tensor_2d(&device, flatten_matrix_rows(matrix), dim, dim);
        let rotation_t = tensor_2d(&device, flatten_matrix_transpose_rows(matrix), dim, dim);
        // Codebook values are f64 on the CPU side; narrow to f32 for the GPU.
        let boundaries = tensor_1d(
            &device,
            quantizer
                .codebook()
                .boundaries
                .iter()
                .map(|value| *value as f32)
                .collect(),
        );
        let centroids = tensor_1d(
            &device,
            quantizer
                .codebook()
                .centroids
                .iter()
                .map(|value| *value as f32)
                .collect(),
        );
        Ok(Self {
            device,
            rotation,
            rotation_t,
            boundaries,
            centroids,
            dim,
            boundary_count: quantizer.codebook().boundaries.len(),
            level_count: quantizer.codebook().centroids.len(),
        })
    }

    /// Quantize a batch of vectors, returning one row of codebook indices
    /// (one u8 per component) per input vector.
    pub fn quantize_batch(
        &self,
        quantizer: &TurboQuantMSE,
        vectors: &[Vec<f64>],
    ) -> Result<Vec<Vec<u8>>> {
        self.quantize_with_reconstruction(quantizer, vectors)
            .map(|(indices, _)| indices)
    }

    /// Quantize a batch and also return each vector's dequantized
    /// reconstruction (used by the two-stage product quantizer to compute
    /// residuals without a second GPU round trip).
    pub fn quantize_with_reconstruction(
        &self,
        quantizer: &TurboQuantMSE,
        vectors: &[Vec<f64>],
    ) -> Result<QuantizedWithReconstruction> {
        validate_vectors(quantizer, vectors)?;
        if vectors.is_empty() {
            return Ok((Vec::new(), Vec::new()));
        }
        let batch_len = vectors.len();
        let input = tensor_2d(&self.device, flatten_vectors(vectors)?, batch_len, self.dim);
        // Rotate into codebook space: X * R^T.
        let rotated = input.matmul(self.rotation_t.clone());
        // Scalar-quantize every rotated coordinate against the boundaries.
        let indices = self.quantize_rotated(rotated.clone(), batch_len)?;
        // Map indices back to centroid values, then rotate back: X_hat * R.
        let rotated_hat = self.lookup_centroids(indices.clone(), batch_len)?;
        let reconstructed = rotated_hat.matmul(self.rotation.clone());
        Ok((
            tensor_to_u8_rows(indices, self.dim)?,
            tensor_to_f64_rows(reconstructed, self.dim)?,
        ))
    }

    /// Dequantize a packed batch after checking that it matches `quantizer`'s
    /// dimensionality and bit width.
    pub fn dequantize_batch(
        &self,
        quantizer: &TurboQuantMSE,
        batch: &BatchQuantizedMSE,
    ) -> Result<Vec<Vec<f64>>> {
        if batch.dim != quantizer.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: quantizer.dim,
                got: batch.dim,
            });
        }
        if batch.bit_width != quantizer.bit_width {
            return Err(TurboQuantError::BitWidthMismatch {
                expected: quantizer.bit_width,
                got: batch.bit_width,
            });
        }
        let indices = batch
            .packed_indices
            .iter()
            .map(BitPackedVector::unpack)
            .collect::<Vec<_>>();
        self.dequantize_from_indices(&indices)
    }

    /// Reconstruct vectors from unpacked index rows: centroid lookup on the
    /// GPU followed by rotation back into the original basis.
    pub fn dequantize_from_indices(&self, indices: &[Vec<u8>]) -> Result<Vec<Vec<f64>>> {
        if indices.is_empty() {
            return Ok(Vec::new());
        }
        let batch_len = indices.len();
        let packed = flatten_index_rows(indices, self.dim)?;
        let indices = int_tensor_2d(&self.device, packed, batch_len, self.dim);
        let rotated_hat = self.lookup_centroids(indices, batch_len)?;
        let reconstructed = rotated_hat.matmul(self.rotation.clone());
        tensor_to_f64_rows(reconstructed, self.dim)
    }

    /// Inner products between the dequantized `keys` and `query`, after
    /// validating layout, dimensionality, and bit width.
    pub fn attention_scores(
        &self,
        quantizer: &TurboQuantMSE,
        keys: &BatchQuantizedMSE,
        query: &[f64],
    ) -> Result<Vec<f64>> {
        keys.validate_layout()?;
        if keys.dim != quantizer.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: quantizer.dim,
                got: keys.dim,
            });
        }
        if keys.bit_width != quantizer.bit_width {
            return Err(TurboQuantError::BitWidthMismatch {
                expected: quantizer.bit_width,
                got: keys.bit_width,
            });
        }
        if query.len() != quantizer.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: quantizer.dim,
                got: query.len(),
            });
        }
        let indices = keys
            .packed_indices
            .iter()
            .map(BitPackedVector::unpack)
            .collect::<Vec<_>>();
        self.mse_stage_scores_from_indices(&indices, query)
    }

    /// Score every key row against `query` in rotated space with one GEMM.
    ///
    /// NOTE(review): this dots the rotated reconstruction with the rotated
    /// query, which equals the original-basis inner product only if the
    /// rotation is orthonormal — presumably guaranteed by the rotation type;
    /// confirm against the CPU implementation.
    pub fn mse_stage_scores_from_indices(
        &self,
        indices: &[Vec<u8>],
        query: &[f64],
    ) -> Result<Vec<f64>> {
        if indices.is_empty() {
            return Ok(Vec::new());
        }
        if query.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: query.len(),
            });
        }
        let batch_len = indices.len();
        let packed = flatten_index_rows(indices, self.dim)?;
        let indices = int_tensor_2d(&self.device, packed, batch_len, self.dim);
        let rotated_hat = self.lookup_centroids(indices, batch_len)?;
        // Rotate the query once, then score all keys in a single matmul:
        // [batch, dim] x [dim, 1] -> [batch, 1].
        let rotated_query = tensor_2d(
            &self.device,
            query.iter().map(|value| *value as f32).collect(),
            1,
            self.dim,
        )
        .matmul(self.rotation_t.clone())
        .reshape([self.dim, 1]);
        let scores = rotated_hat.matmul(rotated_query).squeeze::<1>();
        tensor_to_f64_vec(scores)
    }

    /// Bucket every rotated coordinate: the codebook index is the count of
    /// boundaries the value is >= (branch-free scalar quantization).
    fn quantize_rotated(&self, rotated: FloatTensor<2>, batch_len: usize) -> Result<IntTensor<2>> {
        if self.boundary_count == 0 {
            // No boundaries means a single-level codebook: every index is 0.
            return Ok(IntTensor::zeros([batch_len, self.dim], &self.device));
        }
        // Broadcast-compare [batch, dim, 1] against [batch, dim, boundaries].
        let rotated = rotated.reshape([batch_len, self.dim, 1]);
        let boundaries = self
            .boundaries
            .clone()
            .reshape([1, 1, self.boundary_count])
            .repeat_dim(0, batch_len)
            .repeat_dim(1, self.dim);
        // Counting crossed boundaries along the last axis yields the index.
        let counts = rotated.greater_equal(boundaries).int().sum_dim(2);
        Ok(counts.squeeze::<2>())
    }

    /// Gather centroid values for a batch of index rows, expressed as a
    /// one-hot matmul ([batch*dim, levels] x [levels, 1]) — presumably so the
    /// lookup runs as a plain GEMM on the WGPU backend.
    fn lookup_centroids(&self, indices: IntTensor<2>, batch_len: usize) -> Result<FloatTensor<2>> {
        let flat_indices = indices.reshape([batch_len * self.dim]).float();
        let one_hot: FloatTensor<2> = flat_indices.one_hot(self.level_count);
        let centroids = self.centroids.clone().reshape([self.level_count, 1]);
        let rotated = one_hot.matmul(centroids).reshape([batch_len, self.dim]);
        Ok(rotated)
    }
}
/// Quantize `vectors` on the GPU with `quantizer` and bit-pack each row of
/// codebook indices into a `BatchQuantizedMSE`.
pub fn batch_quantize_mse_wgpu(
    quantizer: &TurboQuantMSE,
    vectors: &[Vec<f64>],
) -> Result<BatchQuantizedMSE> {
    let runner = WgpuMseBatchRunner::new(quantizer)?;
    let index_rows = runner.quantize_batch(quantizer, vectors)?;
    // Pack each row at the quantizer's bit width, failing fast on error.
    let mut packed_indices = Vec::with_capacity(index_rows.len());
    for row in &index_rows {
        packed_indices.push(BitPackedVector::pack(row, quantizer.bit_width)?);
    }
    Ok(BatchQuantizedMSE {
        packed_indices,
        bit_width: quantizer.bit_width,
        dim: quantizer.dim,
    })
}
/// Dequantize a packed MSE batch on the GPU, returning one f64 vector per row.
pub fn batch_dequantize_mse_wgpu(
    quantizer: &TurboQuantMSE,
    batch: &BatchQuantizedMSE,
) -> Result<Vec<Vec<f64>>> {
    WgpuMseBatchRunner::new(quantizer)?.dequantize_batch(quantizer, batch)
}
/// Two-stage product quantization on the GPU: an MSE stage (bit_width - 1
/// bits per component) plus a one-bit QJL sign stage on the residual.
///
/// Fix: iterate vectors, stage-1 reconstructions, and index rows together
/// instead of indexing `mse_indices` by the running length of the output
/// vector (`&mse_indices[packed_mse_indices.len()]`), which silently coupled
/// the lookup to push order and would break if the loop body ever changed.
pub fn batch_quantize_prod_wgpu(
    quantizer: &TurboQuantProd,
    vectors: &[Vec<f64>],
) -> Result<BatchQuantizedProd> {
    let runner = WgpuMseBatchRunner::new(quantizer.mse_stage())?;
    // One index row and one reconstruction per input vector.
    let (mse_indices, stage1_recon) =
        runner.quantize_with_reconstruction(quantizer.mse_stage(), vectors)?;
    let mut packed_mse_indices = Vec::with_capacity(vectors.len());
    let mut packed_qjl_signs = Vec::with_capacity(vectors.len());
    let mut residual_norms = Vec::with_capacity(vectors.len());
    for ((vector, stage1), row) in vectors.iter().zip(&stage1_recon).zip(&mse_indices) {
        // Residual left for the QJL sign stage after stage-1 reconstruction.
        let residual: Vec<f64> = vector
            .iter()
            .zip(stage1.iter())
            .map(|(left, right)| left - right)
            .collect();
        let q_qjl: QJLQuantized = quantizer.qjl_stage().quantize(&residual)?;
        // One bit of the budget is reserved for the QJL sign stage, so the
        // MSE indices are packed at bit_width - 1.
        packed_mse_indices.push(BitPackedVector::pack(row, quantizer.bit_width - 1)?);
        packed_qjl_signs.push(pack_signs(&q_qjl.signs));
        residual_norms.push(q_qjl.residual_norm);
    }
    Ok(BatchQuantizedProd {
        packed_mse_indices,
        packed_qjl_signs,
        residual_norms,
        bit_width: quantizer.bit_width,
        dim: quantizer.dim,
    })
}
/// Estimate ⟨key, query⟩ for every key in a two-stage product-quantized
/// batch: GPU MSE-stage scores plus a per-row QJL residual correction.
pub fn batch_estimate_inner_products_wgpu(
    quantizer: &TurboQuantProd,
    batch: &BatchQuantizedProd,
    query: &[f64],
) -> Result<Vec<f64>> {
    // Validation order mirrors the CPU path: layout, dim, bit width, query.
    batch.validate_layout()?;
    if batch.dim != quantizer.dim {
        return Err(TurboQuantError::DimensionMismatch {
            expected: quantizer.dim,
            got: batch.dim,
        });
    }
    if batch.bit_width != quantizer.bit_width {
        return Err(TurboQuantError::BitWidthMismatch {
            expected: quantizer.bit_width,
            got: batch.bit_width,
        });
    }
    if query.len() != quantizer.dim {
        return Err(TurboQuantError::DimensionMismatch {
            expected: quantizer.dim,
            got: query.len(),
        });
    }
    let runner = WgpuMseBatchRunner::new(quantizer.mse_stage())?;
    let unpacked = batch
        .packed_mse_indices
        .iter()
        .map(BitPackedVector::unpack)
        .collect::<Vec<_>>();
    let mut scores = runner.mse_stage_scores_from_indices(&unpacked, query)?;
    // Add the QJL residual correction to each MSE-stage score in place.
    for row in 0..scores.len() {
        let qjl_signs = crate::bitpack::unpack_signs(&batch.packed_qjl_signs[row], batch.dim);
        let correction = quantizer.qjl_stage().estimate_inner_product_parts(
            &qjl_signs,
            batch.residual_norms[row],
            batch.dim,
            query,
        )?;
        scores[row] += correction;
    }
    Ok(scores)
}
/// Attention scores ⟨key_hat, query⟩ for a packed MSE batch on the GPU.
pub fn batch_attention_scores_mse_wgpu(
    quantizer: &TurboQuantMSE,
    keys: &BatchQuantizedMSE,
    query: &[f64],
) -> Result<Vec<f64>> {
    WgpuMseBatchRunner::new(quantizer)?.attention_scores(quantizer, keys, query)
}
/// Check that every vector matches the quantizer's dimensionality and is a
/// unit vector; fails on the first offender.
fn validate_vectors(quantizer: &TurboQuantMSE, vectors: &[Vec<f64>]) -> Result<()> {
    vectors.iter().try_for_each(|vector| {
        if vector.len() != quantizer.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: quantizer.dim,
                got: vector.len(),
            });
        }
        validate_unit_vector(vector, "WGPU batch quantization input")?;
        Ok(())
    })
}
/// Flatten f64 vectors into one row-major f32 buffer, checking that all rows
/// share the length of the first one (empty input yields an empty buffer).
fn flatten_vectors(vectors: &[Vec<f64>]) -> Result<Vec<f32>> {
    let dim = match vectors.first() {
        Some(first) => first.len(),
        None => 0,
    };
    let mut flat = Vec::with_capacity(vectors.len() * dim);
    for vector in vectors {
        if vector.len() != dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "WGPU batch vector length".into(),
                expected: dim,
                got: vector.len(),
            });
        }
        for &value in vector {
            flat.push(value as f32);
        }
    }
    Ok(flat)
}
/// Flatten u8 index rows into one row-major i32 buffer, checking that every
/// row has exactly `dim` entries.
fn flatten_index_rows(indices: &[Vec<u8>], dim: usize) -> Result<Vec<i32>> {
    let mut flat = Vec::with_capacity(indices.len() * dim);
    for row in indices {
        if row.len() != dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "WGPU index row length".into(),
                expected: dim,
                got: row.len(),
            });
        }
        for &index in row {
            flat.push(i32::from(index));
        }
    }
    Ok(flat)
}
/// Upload a flat f32 buffer as a rank-1 tensor on `device`.
fn tensor_1d(device: &WgpuDevice, data: Vec<f32>) -> FloatTensor<1> {
    let shape = [data.len()];
    Tensor::<GpuBackend, 1>::from_data(TensorData::new(data, shape), device)
}
/// Upload a row-major f32 buffer as a rows x cols tensor on `device`.
fn tensor_2d(device: &WgpuDevice, data: Vec<f32>, rows: usize, cols: usize) -> FloatTensor<2> {
    let payload = TensorData::new(data, [rows, cols]);
    Tensor::<GpuBackend, 2>::from_data(payload, device)
}
/// Upload a row-major i32 buffer as a rows x cols integer tensor on `device`.
fn int_tensor_2d(device: &WgpuDevice, data: Vec<i32>, rows: usize, cols: usize) -> IntTensor<2> {
    let payload = TensorData::new(data, [rows, cols]);
    Tensor::<GpuBackend, 2, Int>::from_data(payload, device)
}
/// Download an i32 tensor and split it into `dim`-wide rows of u8 indices.
/// The `as u8` cast truncates; codebook indices are expected to fit in u8.
fn tensor_to_u8_rows(tensor: IntTensor<2>, dim: usize) -> Result<Vec<Vec<u8>>> {
    let flat = tensor
        .into_data()
        .into_vec::<i32>()
        .map_err(|error| TurboQuantError::Internal(format!("{error:?}")))?;
    let mut rows = Vec::new();
    for chunk in flat.chunks(dim) {
        rows.push(chunk.iter().map(|&index| index as u8).collect());
    }
    Ok(rows)
}
/// Download an f32 tensor and split it into `dim`-wide rows widened to f64.
fn tensor_to_f64_rows(tensor: FloatTensor<2>, dim: usize) -> Result<Vec<Vec<f64>>> {
    let flat = tensor
        .into_data()
        .into_vec::<f32>()
        .map_err(|error| TurboQuantError::Internal(format!("{error:?}")))?;
    let mut rows = Vec::new();
    for chunk in flat.chunks(dim) {
        rows.push(chunk.iter().map(|&value| f64::from(value)).collect());
    }
    Ok(rows)
}
fn tensor_to_f64_vec(tensor: FloatTensor<1>) -> Result<Vec<f64>> {
Ok(tensor
.into_data()
.into_vec::<f32>()
.map_err(|error| TurboQuantError::Internal(format!("{error:?}")))?
.into_iter()
.map(|value| value as f64)
.collect())
}
/// Flatten a matrix row-major into f32 (DMatrix storage is column-major, so
/// this deliberately walks coordinates rather than the raw buffer).
fn flatten_matrix_rows(matrix: &nalgebra::DMatrix<f64>) -> Vec<f32> {
    let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
    (0..nrows)
        .flat_map(|row| (0..ncols).map(move |col| matrix[(row, col)] as f32))
        .collect()
}
/// Flatten the matrix transpose row-major into f32. The swapped (col, row)
/// index runs against the same nrows/ncols bounds, so this assumes a square
/// matrix — rotation matrices here are dim x dim.
fn flatten_matrix_transpose_rows(matrix: &nalgebra::DMatrix<f64>) -> Vec<f32> {
    let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
    (0..nrows)
        .flat_map(|row| (0..ncols).map(move |col| matrix[(col, row)] as f32))
        .collect()
}
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use super::{
        batch_attention_scores_mse_wgpu, batch_dequantize_mse_wgpu,
        batch_estimate_inner_products_wgpu, batch_quantize_mse_wgpu, batch_quantize_prod_wgpu,
    };
    use crate::batch::{
        batch_attention_scores_mse, batch_dequantize_mse, batch_estimate_inner_products,
        batch_quantize_mse, batch_quantize_prod,
    };
    use crate::error::TurboQuantError;
    use crate::turboquant_mse::TurboQuantMSE;
    use crate::turboquant_prod::TurboQuantProd;
    use crate::utils::normalize;

    /// Deterministic batch of normalized Gaussian vectors (seeded RNG).
    fn random_unit_vectors(dim: usize, count: usize, seed: u64) -> Vec<Vec<f64>> {
        use rand::SeedableRng;
        use rand_distr::{Distribution, Normal};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();
        (0..count)
            .map(|_| {
                let raw: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
                normalize(&raw).unwrap()
            })
            .collect()
    }

    /// GPU MSE quantize/dequantize/score must match the CPU path within f32
    /// round-trip tolerance.
    #[test]
    fn wgpu_mse_roundtrip_matches_cpu() {
        let dim = 32;
        let quantizer = TurboQuantMSE::new(dim, 4, 42).unwrap();
        let vectors = random_unit_vectors(dim, 6, 7);
        let query = &random_unit_vectors(dim, 1, 99)[0];
        let cpu_batch = batch_quantize_mse(&quantizer, &vectors).unwrap();
        let gpu_batch = batch_quantize_mse_wgpu(&quantizer, &vectors).unwrap();
        let cpu_recon = batch_dequantize_mse(&quantizer, &cpu_batch).unwrap();
        let gpu_recon = batch_dequantize_mse_wgpu(&quantizer, &gpu_batch).unwrap();
        let cpu_scores = batch_attention_scores_mse(&quantizer, &cpu_batch, query).unwrap();
        let gpu_scores = batch_attention_scores_mse_wgpu(&quantizer, &gpu_batch, query).unwrap();
        assert_eq!(gpu_batch.len(), cpu_batch.len());
        assert_eq!(gpu_recon.len(), cpu_recon.len());
        assert_eq!(gpu_scores.len(), cpu_scores.len());
        // Tolerance 1e-4 accounts for the f64 -> f32 -> f64 GPU round trip.
        for (cpu_row, gpu_row) in cpu_recon.iter().zip(gpu_recon.iter()) {
            for (cpu_value, gpu_value) in cpu_row.iter().zip(gpu_row.iter()) {
                assert_abs_diff_eq!(cpu_value, gpu_value, epsilon = 1e-4);
            }
        }
        for (cpu_score, gpu_score) in cpu_scores.iter().zip(gpu_scores.iter()) {
            assert_abs_diff_eq!(cpu_score, gpu_score, epsilon = 1e-4);
        }
    }

    /// GPU two-stage (prod) inner-product estimates must match the CPU path.
    #[test]
    fn wgpu_prod_scores_match_cpu() {
        let dim = 32;
        let quantizer = TurboQuantProd::new(dim, 4, 7).unwrap();
        let vectors = random_unit_vectors(dim, 6, 3);
        let query = &random_unit_vectors(dim, 1, 21)[0];
        let cpu_batch = batch_quantize_prod(&quantizer, &vectors).unwrap();
        let gpu_batch = batch_quantize_prod_wgpu(&quantizer, &vectors).unwrap();
        let cpu_scores = batch_estimate_inner_products(&quantizer, &cpu_batch, query).unwrap();
        let gpu_scores = batch_estimate_inner_products_wgpu(&quantizer, &gpu_batch, query).unwrap();
        assert_eq!(gpu_batch.len(), cpu_batch.len());
        assert_eq!(gpu_scores.len(), cpu_scores.len());
        for (cpu_score, gpu_score) in cpu_scores.iter().zip(gpu_scores.iter()) {
            assert_abs_diff_eq!(cpu_score, gpu_score, epsilon = 1e-4);
        }
    }

    /// A corrupted sign buffer must be rejected by layout validation.
    #[test]
    fn wgpu_prod_rejects_invalid_batch_layout() {
        let dim = 32;
        let quantizer = TurboQuantProd::new(dim, 4, 7).unwrap();
        let vectors = random_unit_vectors(dim, 2, 3);
        let query = &random_unit_vectors(dim, 1, 21)[0];
        let mut batch = batch_quantize_prod_wgpu(&quantizer, &vectors).unwrap();
        // Truncate one packed sign row so the layout check must fire.
        batch.packed_qjl_signs[0].pop();
        assert!(matches!(
            batch_estimate_inner_products_wgpu(&quantizer, &batch, query),
            Err(TurboQuantError::LengthMismatch { .. })
        ));
    }
}