mod codebooks;
mod rotation;
use candle_core::{Device, Result, Tensor};
use super::config::{CacheConfig, QUANT_BLOCK_SIZE};
/// Device-resident tensors and host-side scalars computed once at cache
/// construction (see [`GpuPrecomputed::new`]) and reused for every
/// quantisation / dequantisation step.
pub struct GpuPrecomputed {
    /// Forward rotation matrix, built per quantisation block by
    /// `rotation::build_rotation_matrices`.
    pub rotation_fwd: Tensor,
    /// Inverse rotation matrix paired with `rotation_fwd`.
    pub rotation_inv: Tensor,
    /// Codebook centroids for regular blocks (from `codebooks::build_codebooks`).
    pub centroids: Tensor,
    /// Decision boundaries paired with `centroids`.
    pub boundaries: Tensor,
    /// Codebook centroids dedicated to outlier blocks.
    pub outlier_centroids: Tensor,
    /// Decision boundaries paired with `outlier_centroids`.
    pub outlier_boundaries: Tensor,
    /// Maximum of `outlier_centroids`, cached host-side as an `f64`.
    pub outlier_outer_centroid: f64,
    /// Shape `(1, num_blocks)`: `-1.0` for the leading outlier blocks,
    /// `+1.0` for the rest (see `build_scale_sign_tensor`).
    pub scale_sign_tensor: Tensor,
    /// `(head_dim, head_dim)` Rademacher projection matrix; `None` when the
    /// QJL feature is disabled in the config.
    pub qjl_rademacher: Option<Tensor>,
}
impl GpuPrecomputed {
    /// Builds every device-side lookup tensor required by the quantised cache
    /// for `config`, allocating on `device`.
    ///
    /// # Errors
    /// - `config.bits` is outside the supported `3..=4` range.
    /// - On CUDA devices (with the `cuda` feature enabled), `config.head_dim`
    ///   exceeds the kernel's thread-block limit, since the kernel launches
    ///   `head_dim` threads per block.
    /// - Any failure from building the rotation matrices, the codebooks, or
    ///   the derived tensors.
    pub fn new(config: &CacheConfig, device: &Device) -> Result<Self> {
        // Codebooks are only defined for 3- and 4-bit quantisation.
        if !(3..=4).contains(&config.bits) {
            return Err(super::cache_err(format!(
                "unsupported bits={}, expected 3 or 4",
                config.bits
            )));
        }
        #[cfg(feature = "cuda")]
        {
            const CUDA_MAX_THREADS_PER_BLOCK: usize = 1024;
            // Fail fast at construction time rather than at kernel launch.
            if device.is_cuda() && config.head_dim > CUDA_MAX_THREADS_PER_BLOCK {
                return Err(super::cache_err(format!(
                    "head_dim {} exceeds the CUDA kernel thread-block limit \
                     ({CUDA_MAX_THREADS_PER_BLOCK}); launching a kernel with \
                     head_dim threads per block would exceed the device maximum.",
                    config.head_dim
                )));
            }
        }
        let block_dim = QUANT_BLOCK_SIZE;
        // One bit is set aside from the total budget; the remainder is handed
        // to the codebook builder. NOTE(review): "polar" naming suggests a
        // sign/magnitude split — confirm against codebooks::build_codebooks.
        let polar_bits = config.bits - 1;
        let (rotation_fwd, rotation_inv) =
            rotation::build_rotation_matrices(block_dim, device)?;
        let (centroids, boundaries, outlier_centroids, outlier_boundaries) =
            codebooks::build_codebooks(polar_bits, block_dim, config.norm_mode, device)?;
        // Cache the largest outlier centroid host-side so later steps avoid a
        // device->host sync. Assumes outlier_centroids is a 1-D f32 tensor, so
        // max over dim 0 yields a rank-0 tensor — TODO confirm in codebooks.
        let outlier_outer_centroid = outlier_centroids.max(0)?.to_scalar::<f32>()? as f64;
        let scale_sign_tensor =
            build_scale_sign_tensor(config.head_dim, config.outlier_blocks, device)?;
        // The Rademacher projection is only materialised when QJL is enabled.
        let qjl_rademacher = if config.qjl_enabled() {
            Some(build_rademacher_matrix(config.head_dim, device)?)
        } else {
            None
        };
        Ok(Self {
            rotation_fwd,
            rotation_inv,
            centroids,
            boundaries,
            outlier_centroids,
            outlier_boundaries,
            outlier_outer_centroid,
            scale_sign_tensor,
            qjl_rademacher,
        })
    }
}
/// Builds the `(1, num_blocks)` sign row used to scale quantisation blocks:
/// the leading `outlier_blocks` entries are `-1.0`, all others `+1.0`.
fn build_scale_sign_tensor(
    head_dim: usize,
    outlier_blocks: usize,
    device: &Device,
) -> Result<Tensor> {
    // Assumes head_dim is a multiple of QUANT_BLOCK_SIZE; any remainder is
    // silently truncated by this division — TODO confirm callers guarantee it.
    let num_blocks = head_dim / QUANT_BLOCK_SIZE;
    // Clamp so a config requesting more outlier blocks than exist stays valid.
    let flipped = outlier_blocks.min(num_blocks);
    let signs: Vec<f32> = (0..num_blocks)
        .map(|block| if block < flipped { -1.0 } else { 1.0 })
        .collect();
    Tensor::from_vec(signs, (1, num_blocks), device)
}
/// Assembles the seeded `(head_dim, head_dim)` Rademacher (±1) projection
/// matrix row by row, uploading the result to `device`.
fn build_rademacher_matrix(head_dim: usize, device: &Device) -> Result<Tensor> {
    use super::config::DEFAULT_QJL_SEED;
    // Each row is generated deterministically from the shared seed and its
    // row index, then all rows are concatenated in row-major order.
    let entries: Vec<_> = (0..head_dim)
        .flat_map(|row| crate::qjl::generate_rademacher_row(head_dim, DEFAULT_QJL_SEED, row))
        .collect();
    Tensor::from_vec(entries, (head_dim, head_dim), device)
}