use core::f32;
use std::path::Path;
use faer::{Col, Mat, Row};
use log::debug;
use serde::{Deserialize, Serialize};
use crate::consts::{DEFAULT_X_DOT_PRODUCT, EPSILON, SCALAR, THETA_LOG_DIM};
use crate::metrics::METRICS;
use crate::rerank::new_re_ranker;
use crate::utils::{
asymmetric_binary_dot_product, gen_random_bias, gen_random_qr_orthogonal,
kmeans_nearest_cluster, l2_squared_distance, matrix_from_fvecs, min_max_residual, project,
read_u64_vecs, read_vecs, scalar_quantize, vector_binarize_one, vector_binarize_query,
vector_binarize_u64, write_matrix, write_vecs,
};
/// Per-vector scalars precomputed at build time and used to turn the binary
/// inner product into an estimated squared distance at query time.
///
/// The field order is part of the `#[repr(C)]` layout and of the
/// four-float record layout, so it must not be changed.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
pub struct Factor {
    /// `-2 / sqrt(dim) * ||x - c|| / <x_c, sign(x_c)>_normalized`; scales the
    /// binary inner-product term of the distance estimate.
    pub factor_ip: f32,
    /// `factor_ip` multiplied by the sum of the residual's ±1 signs; scaled by
    /// the query's lower bound in the distance estimate.
    pub factor_ppc: f32,
    /// Error term that is multiplied by `||y - c||` and subtracted from the
    /// distance estimate.
    pub error_bound: f32,
    /// Squared L2 distance between the rotated vector and its assigned centroid.
    pub center_distance_square: f32,
}
impl Factor {
    /// Flattens the factor into the four-float record
    /// `[factor_ip, factor_ppc, error_bound, center_distance_square]`.
    fn into_vec(self) -> Vec<f32> {
        let Self {
            factor_ip,
            factor_ppc,
            error_bound,
            center_distance_square,
        } = self;
        [factor_ip, factor_ppc, error_bound, center_distance_square].to_vec()
    }
}
impl From<Vec<f32>> for Factor {
    /// Rebuilds a [`Factor`] from the four-float record
    /// `[factor_ip, factor_ppc, error_bound, center_distance_square]`.
    ///
    /// # Panics
    /// Panics if `f32s` does not contain exactly four values.
    fn from(f32s: Vec<f32>) -> Self {
        assert_eq!(f32s.len(), 4);
        match *f32s.as_slice() {
            [factor_ip, factor_ppc, error_bound, center_distance_square] => Self {
                factor_ip,
                factor_ppc,
                error_bound,
                center_distance_square,
            },
            // The length was asserted above, so no other shape can reach here.
            _ => unreachable!(),
        }
    }
}
/// A RaBitQ index: clustered, randomly rotated vectors with 1-bit codes and
/// per-vector correction factors, plus the raw vectors for re-ranking.
///
/// All per-vector arrays (`base` columns, `x_binary_vec`, `factors`) are stored
/// in cluster order; `map_ids` maps that order back to the original row ids.
/// Field order is kept as-is because it determines the serialized/Debug order.
#[derive(Deserialize, Serialize, Debug)]
pub struct RaBitQ {
    /// Vector dimension after zero-padding to a multiple of 64.
    dim: u32,
    /// Un-rotated (padded) vectors, one per column (`dim x n`), in cluster order.
    base: Mat<f32>,
    /// Random orthogonal rotation matrix (`dim x dim`).
    orthogonal: Mat<f32>,
    /// Rotated centroids, one per column.
    centroids: Mat<f32>,
    /// Random bias used when scalar-quantizing the query residual.
    rand_bias: Vec<f32>,
    /// Cluster `c` occupies positions `offsets[c]..offsets[c + 1]`.
    offsets: Vec<u32>,
    /// `map_ids[j]` is the original row id of the `j`-th vector in cluster order.
    map_ids: Vec<u32>,
    /// Concatenated binary codes, `dim / 64` words per vector, in cluster order.
    x_binary_vec: Vec<u64>,
    /// Per-vector correction factors, in cluster order.
    factors: Vec<Factor>,
}
impl RaBitQ {
    /// Loads an index previously written by [`RaBitQ::dump_to_json`].
    ///
    /// # Panics
    /// Panics if the file cannot be read or does not deserialize into a `RaBitQ`.
    pub fn load_from_json(path: &impl AsRef<Path>) -> Self {
        serde_json::from_slice(&std::fs::read(path).expect("open json error"))
            .expect("deserialize error")
    }

    /// Serializes the whole index, including `rand_bias`, as JSON to `path`.
    ///
    /// # Panics
    /// Panics if serialization or the file write fails.
    pub fn dump_to_json(&self, path: &impl AsRef<Path>) {
        std::fs::write(path, serde_json::to_string(&self).expect("serialize error"))
            .expect("write json error");
    }

    /// Loads an index from the directory layout written by [`RaBitQ::dump_to_dir`].
    ///
    /// `offsets_ids.ivecs` holds two vectors: the first is `offsets`, the last is
    /// `map_ids`. `factors.fvecs` is flattened and regrouped into 4-float records.
    ///
    /// Note: `dump_to_dir` does not persist `rand_bias`, so a fresh random bias is
    /// generated here; query results may therefore differ slightly from the
    /// index that was dumped.
    ///
    /// # Panics
    /// Panics if any file is missing or malformed, or if the dimension read
    /// from `orthogonal.fvecs` is not a multiple of 64.
    pub fn load_from_dir(path: &Path) -> Self {
        let orthogonal = matrix_from_fvecs(&path.join("orthogonal.fvecs"));
        let centroids = matrix_from_fvecs(&path.join("centroids.fvecs"));
        let offsets_ids =
            read_vecs::<u32>(&path.join("offsets_ids.ivecs")).expect("open offsets_ids error");
        let offsets = offsets_ids.first().expect("offsets is empty").clone();
        let map_ids = offsets_ids.last().expect("map_ids is empty").clone();
        let factors = read_vecs::<f32>(&path.join("factors.fvecs"))
            .expect("open factors error")
            .into_iter()
            .flatten()
            .collect::<Vec<f32>>()
            .chunks_exact(4)
            .map(|f| f.to_vec().into())
            .collect();
        let x_binary_vec = read_u64_vecs(&path.join("x_binary_vec.u64vecs"))
            .expect("open x_binary_vec error")
            .into_iter()
            .flatten()
            .collect();
        let dim = orthogonal.nrows();
        assert!(dim % 64 == 0);
        // `base` is stored row-per-vector on disk but column-per-vector in memory.
        let base = matrix_from_fvecs(&path.join("base.fvecs"))
            .transpose()
            .to_owned();
        Self {
            dim: dim as u32,
            base,
            orthogonal,
            centroids,
            rand_bias: gen_random_bias(dim),
            offsets,
            map_ids,
            x_binary_vec,
            factors,
        }
    }

    /// Writes the index into `path` (created if missing) as `.fvecs`/`.ivecs`/
    /// `.u64vecs` files readable by [`RaBitQ::load_from_dir`].
    ///
    /// `rand_bias` is not written.
    ///
    /// # Panics
    /// Panics if the directory cannot be created or any file write fails.
    pub fn dump_to_dir(&self, path: &Path) {
        std::fs::create_dir_all(path).expect("create dir error");
        write_matrix(&path.join("base.fvecs"), &self.base.transpose()).expect("write base error");
        write_matrix(&path.join("orthogonal.fvecs"), &self.orthogonal.as_ref())
            .expect("write orthogonal error");
        write_matrix(&path.join("centroids.fvecs"), &self.centroids.as_ref())
            .expect("write centroids error");
        write_vecs(
            &path.join("offsets_ids.ivecs"),
            &[&self.offsets, &self.map_ids],
        )
        .expect("write offsets_ids error");
        write_vecs(
            &path.join("factors.fvecs"),
            &[&self
                .factors
                .iter()
                .flat_map(|f| f.into_vec())
                .collect::<Vec<_>>()],
        )
        .expect("write factors error");
        write_vecs(
            &path.join("x_binary_vec.u64vecs"),
            &[&self.x_binary_vec],
        )
        .expect("write x_binary_vec error");
    }

    /// Builds an index from base vectors and precomputed centroids (`.fvecs`).
    ///
    /// Steps:
    /// 1. Zero-pad both matrices to a dimension that is a multiple of 64.
    /// 2. Rotate vectors and centroids by a random orthogonal matrix.
    /// 3. Assign each vector to its nearest centroid and encode the residual as
    ///    a binary code plus a [`Factor`].
    /// 4. Reorder all per-vector data by cluster, sorting by distance to the
    ///    centroid within each cluster.
    ///
    /// # Panics
    /// Panics if either file cannot be read or the two dimensions differ.
    pub fn from_path(base_path: &Path, centroid_path: &Path) -> Self {
        let mut base = matrix_from_fvecs(base_path);
        let n = base.nrows();
        let mut dim = base.ncols();
        let mut centroids = matrix_from_fvecs(centroid_path);
        let k = centroids.nrows();
        assert!(dim == centroids.ncols());
        if dim % 64 != 0 {
            // Binary codes are packed into u64 words, so pad with zeros.
            let dim_pad = dim.div_ceil(64) * 64;
            base = Mat::from_fn(n, dim_pad, |i, j| match j < dim {
                true => base.read(i, j),
                false => 0.0,
            });
            centroids = Mat::from_fn(k, dim_pad, |i, j| match j < dim {
                true => centroids.read(i, j),
                false => 0.0,
            });
            dim = dim_pad;
        }
        debug!("n: {}, dim: {}, k: {}", n, dim, k);
        let orthogonal = gen_random_qr_orthogonal(dim);
        let rand_bias = gen_random_bias(dim);
        debug!("projection x & c...");
        // Both are stored column-per-vector after projection.
        let x_projected = (&base * &orthogonal).transpose().to_owned();
        let centroids = (centroids * &orthogonal).transpose().to_owned();
        let dim_sqrt = (dim as f32).sqrt();
        let mut labels = vec![Vec::new(); k];
        let mut factors = vec![Factor::default(); n];
        let mut x_c_distance = vec![0.0; n];
        let mut x_binary_vec: Vec<Vec<u64>> = Vec::with_capacity(n);
        let mut x_signed_vec: Vec<Col<f32>> = Vec::with_capacity(n);
        let mut x_dot_product = vec![0.0; n];
        for (i, xp) in x_projected.col_iter().enumerate() {
            if i % 5000 == 0 {
                debug!("\t> preprocessing {}...", i);
            }
            let (min_label, min_dist) = kmeans_nearest_cluster(&centroids.as_ref(), &xp);
            labels[min_label].push((i as u32, min_dist));
            let x_c_quantized = xp - centroids.col(min_label);
            x_c_distance[i] = x_c_quantized.norm_l2();
            factors[i].center_distance_square = x_c_distance[i].powi(2);
            x_binary_vec.push(vector_binarize_u64(&x_c_quantized.as_ref()));
            x_signed_vec.push(vector_binarize_one(&x_c_quantized.as_ref()));
            let norm = x_c_distance[i] * dim_sqrt;
            // Normalized inner product between the residual and its sign vector;
            // fall back to a constant when the residual is (near) zero.
            x_dot_product[i] = if norm.is_normal() {
                x_c_quantized.as_ref().adjoint() * &x_signed_vec[i] / norm
            } else {
                DEFAULT_X_DOT_PRODUCT
            };
        }
        debug!("computing factors...");
        let error_base = 2.0 * EPSILON / (dim as f32 - 1.0).sqrt();
        let one_vec: Row<f32> = Row::ones(dim);
        for i in 0..n {
            let x_c_over_ip = x_c_distance[i] / x_dot_product[i];
            let factor = &mut factors[i];
            // Mathematically non-negative (the inner product is <= 1), but
            // rounding can push it slightly below zero; clamp so `sqrt` never
            // produces NaN.
            factor.error_bound = error_base
                * (x_c_over_ip * x_c_over_ip - factor.center_distance_square)
                    .max(0.0)
                    .sqrt();
            factor.factor_ip = -2.0 / dim_sqrt * x_c_over_ip;
            // `one_vec * signs` is the sum of the residual's ±1 signs.
            factor.factor_ppc = factor.factor_ip * (&one_vec * &x_signed_vec[i]);
        }
        // Within each cluster, order members by their distance to the centroid.
        // `total_cmp` matches the ordering used in `query` and cannot panic.
        let labels = labels
            .into_iter()
            .map(|mut v| {
                v.sort_by(|a, b| a.1.total_cmp(&b.1));
                v.into_iter().map(|(i, _)| i).collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();
        debug!("sort by labels...");
        let mut offsets = vec![0; k + 1];
        for i in 0..k {
            offsets[i + 1] = offsets[i] + labels[i].len() as u32;
        }
        let flat_labels: Vec<u32> = labels.into_iter().flatten().collect();
        // Build the cluster-ordered base directly as `dim x n`, one column per
        // vector, instead of building `n x dim`, transposing and copying.
        let base = Mat::from_fn(dim, n, |row, col| base.read(flat_labels[col] as usize, row));
        let x_binary_vec = flat_labels
            .iter()
            .flat_map(|&id| x_binary_vec[id as usize].iter().copied())
            .collect();
        let factors = flat_labels.iter().map(|&i| factors[i as usize]).collect();
        Self {
            dim: dim as u32,
            base,
            orthogonal,
            rand_bias,
            offsets,
            map_ids: flat_labels,
            centroids,
            x_binary_vec,
            factors,
        }
    }

    /// Searches the `probe` clusters nearest to `query` and returns the re-ranked
    /// top-`topk` results as `(distance, id)` pairs produced by the re-ranker.
    ///
    /// `query` may be shorter than the index dimension; it is zero-padded the
    /// same way the base vectors were. A `probe` of 0 scans no clusters.
    ///
    /// # Panics
    /// Panics if `query.len()` rounded up to a multiple of 64 differs from the
    /// index dimension.
    pub fn query(
        &self,
        query: &[f32],
        probe: usize,
        topk: usize,
        heuristic_rank: bool,
    ) -> Vec<(f32, u32)> {
        let dim = self.dim as usize;
        assert_eq!(dim, query.len().div_ceil(64) * 64);
        let mut query_vec = query.to_vec();
        query_vec.resize(dim, 0.0);
        let y_projected = project(&query_vec, &self.orthogonal.as_ref());
        let k = self.centroids.ncols();

        // Squared distance from the rotated query to every centroid.
        let mut lists = Vec::with_capacity(k);
        for (i, centroid) in self.centroids.col_iter().enumerate() {
            let dist = l2_squared_distance(
                centroid
                    .try_as_slice()
                    .expect("failed to get centroid slice"),
                y_projected.as_slice(),
            );
            lists.push((dist, i));
        }

        // Keep the `probe` nearest clusters, nearest first. Guard the empty case:
        // `length - 1` would underflow when `probe == 0` or there are no centroids.
        let length = probe.min(k);
        if length > 0 {
            lists.select_nth_unstable_by(length - 1, |a, b| a.0.total_cmp(&b.0));
        }
        lists.truncate(length);
        lists.sort_by(|a, b| a.0.total_cmp(&b.0));

        let mut re_ranker = new_re_ranker(&query_vec, topk, heuristic_rank);
        // Scratch buffers reused across clusters.
        let mut residual = vec![0f32; dim];
        let mut quantized = vec![0u8; dim.div_ceil(64) * 64];
        let mut rough_distances = Vec::new();
        let mut binary_vec = vec![0u64; dim.div_ceil(64) * THETA_LOG_DIM as usize];
        for &(dist, i) in &lists {
            let (lower_bound, upper_bound) =
                min_max_residual(&mut residual, &y_projected.as_ref(), &self.centroids.col(i));
            let delta = (upper_bound - lower_bound) * SCALAR;
            let one_over_delta = delta.recip();
            let scalar_sum = scalar_quantize(
                &mut quantized,
                &residual,
                &self.rand_bias,
                lower_bound,
                one_over_delta,
            );
            binary_vec.fill(0);
            vector_binarize_query(&quantized, &mut binary_vec);
            self.calculate_rough_distance(
                dist,
                &binary_vec,
                lower_bound,
                scalar_sum as f32,
                delta,
                i,
                &mut rough_distances,
            );
            re_ranker.rank_batch(&rough_distances, &self.base.as_ref(), &self.map_ids);
            rough_distances.clear();
        }
        METRICS.add_query_count(1);
        re_ranker.get_result()
    }

    /// Appends an estimated squared distance for every vector in `cluster_id`
    /// to `rough_distances`, paired with its cluster-ordered position `j`.
    ///
    /// The estimate is
    /// `||x - c||² + ||y - c||² + lower_bound * factor_ppc
    ///  + (2 * <x_bin, y_bin> - scalar_sum) * factor_ip * delta
    ///  - error_bound * ||y - c||`,
    /// where `y_c_distance_square` is `||y - c||²`. Because the error term is
    /// subtracted, this is presumably meant as a conservative (low) estimate
    /// for the re-ranker; confirm against `rerank`.
    #[allow(clippy::too_many_arguments)] // all inputs are per-cluster query state
    fn calculate_rough_distance(
        &self,
        y_c_distance_square: f32,
        y_binary_vec: &[u64],
        lower_bound: f32,
        scalar_sum: f32,
        delta: f32,
        cluster_id: usize,
        rough_distances: &mut Vec<(f32, u32)>,
    ) {
        let dist_sqrt = y_c_distance_square.sqrt();
        // Number of u64 words per base-vector code (`dim / 64`).
        let binary_offset = y_binary_vec.len() / THETA_LOG_DIM as usize;
        let (start, end) = (self.offsets[cluster_id], self.offsets[cluster_id + 1]);
        rough_distances.reserve((end - start) as usize);
        for j in start..end {
            let ju = j as usize;
            let factor = &self.factors[ju];
            let binary_ip = asymmetric_binary_dot_product(
                &self.x_binary_vec[ju * binary_offset..(ju + 1) * binary_offset],
                y_binary_vec,
            ) as f32;
            rough_distances.push((
                factor.center_distance_square
                    + y_c_distance_square
                    + lower_bound * factor.factor_ppc
                    + (2.0 * binary_ip - scalar_sum) * factor.factor_ip * delta
                    - factor.error_bound * dist_sqrt,
                j,
            ));
        }
    }
}