use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
use ailake_vec::{
exact_distance,
rabitq::{bits_from_signs, encode_batch, RaBitQCodebook, RaBitQVec},
};
use half::f16;
use serde::{Deserialize, Serialize};
/// Build-time configuration for [`RaBitQIndex`].
///
/// This type is (de)serialized with serde; a missing `keep_raw` field falls
/// back to `default_keep_raw()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQConfig {
    /// Seed passed to `RaBitQCodebook::new` by `RaBitQIndex::build`.
    pub seed: u64,
    /// Whether an f16 copy of the raw vectors should be retained for exact
    /// re-ranking.
    ///
    /// NOTE(review): `RaBitQIndex::build` takes a separate `keep_raw` argument
    /// and does not read this field — confirm which one is meant to be
    /// authoritative.
    #[serde(default = "default_keep_raw")]
    pub keep_raw: bool,
}
/// Serde default for `RaBitQConfig::keep_raw` when the field is absent.
///
/// Also used by `RaBitQConfig::default`, so both paths agree on one value.
fn default_keep_raw() -> bool {
    true
}

impl Default for RaBitQConfig {
    /// `seed = 0`, and `keep_raw` taken from [`default_keep_raw`].
    ///
    /// Reusing the serde default function keeps the in-code default and the
    /// deserialization default from drifting apart.
    fn default() -> Self {
        Self {
            seed: 0,
            keep_raw: default_keep_raw(),
        }
    }
}
/// A flat RaBitQ index: `search` scores every encoded entry, then optionally
/// re-ranks the best candidates against stored raw vectors.
///
/// Serialized with bincode by `RaBitQSerializer`. bincode writes struct fields
/// in declaration order, so reordering fields changes the byte format.
#[derive(Debug, Serialize, Deserialize)]
pub struct RaBitQIndex {
    /// Codebook used to encode `entries` and to prepare queries.
    ///
    /// After deserialization `rebuild_proj()` must be called before searching;
    /// `RaBitQSerializer::from_bytes` does this.
    pub codebook: RaBitQCodebook,
    /// One encoded vector per input vector, in input order.
    pub entries: Vec<RaBitQVec>,
    /// Row ids (as raw `u64`), parallel to `entries`: `row_ids[i]` labels `entries[i]`.
    pub row_ids: Vec<u64>,
    /// Metric used to turn inner-product estimates into distances.
    pub metric: VectorMetric,
    /// Dimensionality of the indexed vectors.
    pub dim: u32,
    /// Optional row-major f16 copy of the input vectors.
    ///
    /// Row `i` occupies `raw[i * dim..(i + 1) * dim]`; it is used for exact
    /// re-ranking in `search`.
    pub raw_f16: Option<Vec<f16>>,
}
impl RaBitQIndex {
    /// Builds an index by RaBitQ-encoding every vector in `vectors`.
    ///
    /// * `row_ids` — identifiers reported by [`search`](Self::search);
    ///   `row_ids[i]` labels `vectors[i]`.
    /// * `vectors` — input vectors; all must share the dimension of the first.
    /// * `metric` — metric used to score candidates at search time.
    /// * `config` — only `config.seed` is read here (it seeds the codebook).
    /// * `keep_raw` — when `true`, an f16 copy of every vector is stored so
    ///   [`search`](Self::search) can re-rank candidates with exact distances.
    ///
    /// NOTE(review): `config.keep_raw` is not consulted; the explicit
    /// `keep_raw` argument decides. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Returns [`AilakeError::InvalidArgument`] when `vectors` is empty, when
    /// `row_ids.len() != vectors.len()`, or when the vectors do not all share
    /// one dimension.
    pub fn build(
        row_ids: &[RowId],
        vectors: &[Vec<f32>],
        metric: VectorMetric,
        config: RaBitQConfig,
        keep_raw: bool,
    ) -> AilakeResult<Self> {
        if vectors.is_empty() {
            return Err(AilakeError::InvalidArgument(
                "RaBitQIndex::build requires at least one vector".into(),
            ));
        }
        // `search` maps entry positions through `row_ids`, so the two must line up.
        if row_ids.len() != vectors.len() {
            return Err(AilakeError::InvalidArgument(format!(
                "RaBitQIndex::build got {} row ids for {} vectors",
                row_ids.len(),
                vectors.len()
            )));
        }
        let dim = vectors[0].len();
        // The raw f16 buffer is sliced in fixed `dim`-sized strides, so one
        // ragged vector would misalign every row stored after it.
        if let Some(bad) = vectors.iter().position(|v| v.len() != dim) {
            return Err(AilakeError::InvalidArgument(format!(
                "RaBitQIndex::build: vector {} has dimension {}, expected {}",
                bad,
                vectors[bad].len(),
                dim
            )));
        }

        let codebook = RaBitQCodebook::new(dim, config.seed);
        let entries = encode_batch(&codebook, vectors);
        // Row-major f16 copy: row `i` occupies `raw[i * dim..(i + 1) * dim]`.
        let raw_f16: Option<Vec<f16>> = keep_raw.then(|| {
            let mut raw = Vec::with_capacity(vectors.len() * dim);
            raw.extend(vectors.iter().flatten().map(|&x| f16::from_f32(x)));
            raw
        });

        Ok(Self {
            codebook,
            entries,
            row_ids: row_ids.iter().map(|r| r.as_u64()).collect(),
            metric,
            dim: dim as u32,
            raw_f16,
        })
    }

    /// Number of indexed rows (the length of `row_ids`).
    pub fn node_count(&self) -> u64 {
        self.row_ids.len() as u64
    }

    /// Returns up to `top_k` `(row id, distance)` pairs, nearest first.
    ///
    /// Every entry is scored from the binary RaBitQ inner-product estimate.
    /// The best `rerank_factor * top_k` candidates are kept; the factor
    /// defaults to 1 and is clamped to at least 1. If raw vectors were
    /// retained at build time, the candidates are re-scored with
    /// `exact_distance` before the final cut. Otherwise the estimated
    /// distances are returned.
    ///
    /// Returns an empty vector when the index is empty or `top_k == 0`.
    ///
    /// NOTE(review): `query` is expected to have `self.dim` components; how
    /// `prepare_query` reacts to a mismatch is not visible here.
    pub fn search(
        &self,
        query: &[f32],
        top_k: usize,
        rerank_factor: Option<usize>,
    ) -> Vec<(RowId, f32)> {
        // `top_k == 0` must bail out here: it makes `candidates` zero, and the
        // partial selection below would then compute `candidates - 1` and underflow.
        if self.entries.is_empty() || top_k == 0 {
            return vec![];
        }
        debug_assert!(
            self.codebook.is_ready(),
            "RaBitQCodebook not initialized — call rebuild_proj() after deserialization"
        );
        let (q_proj, q_scale) = self.codebook.prepare_query(query);
        let b_q = bits_from_signs(&q_proj);
        let q_norm = query.iter().map(|x| x * x).sum::<f32>().sqrt();

        // Approximate distance for every entry.
        // NOTE(review): the `ip * q_norm * entry.norm` rescaling implies the
        // estimate is for unit-normalized vectors — confirm against ailake_vec.
        let mut scored: Vec<(usize, f32)> = self
            .entries
            .iter()
            .enumerate()
            .map(|(i, entry)| {
                let ip = self.codebook.estimate_ip_binary(&b_q, q_scale, entry);
                let dist = match self.metric {
                    VectorMetric::Cosine | VectorMetric::NormalizedCosine => 1.0 - ip,
                    VectorMetric::DotProduct => -ip * q_norm * entry.norm,
                    VectorMetric::Euclidean => {
                        // ||q - x||^2 = ||q||^2 + ||x||^2 - 2 q·x; clamp at zero so
                        // estimation error cannot feed a negative value into sqrt.
                        let norm_x = entry.norm;
                        (q_norm * q_norm + norm_x * norm_x - 2.0 * ip * q_norm * norm_x)
                            .max(0.0)
                            .sqrt()
                    }
                };
                (i, dist)
            })
            .collect();

        // Partially select the best `candidates` entries, then sort only those.
        let candidates = rerank_factor.unwrap_or(1).max(1).saturating_mul(top_k);
        if candidates < scored.len() {
            scored.select_nth_unstable_by(candidates - 1, Self::by_distance);
            scored.truncate(candidates);
        }
        scored.sort_unstable_by(Self::by_distance);

        match self.raw_f16.as_deref() {
            Some(raw) => {
                let dim = self.dim as usize;
                // One scratch buffer reused for every candidate's f16 -> f32 decode.
                let mut db_f32: Vec<f32> = Vec::with_capacity(dim);
                let mut reranked: Vec<(usize, f32)> = scored
                    .iter()
                    .map(|&(i, _)| {
                        db_f32.clear();
                        db_f32.extend(raw[i * dim..(i + 1) * dim].iter().map(|x| x.to_f32()));
                        (i, exact_distance(self.metric, query, &db_f32))
                    })
                    .collect();
                reranked.sort_unstable_by(Self::by_distance);
                reranked.truncate(top_k);
                self.label_hits(reranked)
            }
            None => {
                scored.truncate(top_k);
                self.label_hits(scored)
            }
        }
    }

    /// Ascending-distance comparator for `(entry index, distance)` pairs.
    ///
    /// NaN distances compare equal to each other and greater than every
    /// number. This keeps the order total; the std sort/select routines may
    /// panic on comparators that are not.
    fn by_distance(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
    }

    /// Maps `(entry index, distance)` hits to `(RowId, distance)` results.
    fn label_hits(&self, hits: Vec<(usize, f32)>) -> Vec<(RowId, f32)> {
        hits.into_iter()
            .map(|(i, d)| (RowId::new(self.row_ids[i]), d))
            .collect()
    }
}
/// bincode-based (de)serialization for [`RaBitQIndex`].
pub struct RaBitQSerializer;
impl RaBitQSerializer {
    /// Serializes `index` with bincode.
    ///
    /// # Errors
    ///
    /// Returns [`AilakeError::Bincode`] if bincode fails.
    pub fn to_bytes(index: &RaBitQIndex) -> AilakeResult<Vec<u8>> {
        bincode::serialize(index).map_err(|e| AilakeError::Bincode(e.to_string()))
    }

    /// Deserializes an index produced by [`to_bytes`](Self::to_bytes).
    ///
    /// It also rebuilds the codebook projection and checks the invariants
    /// that `RaBitQIndex::search` relies on when indexing.
    ///
    /// # Errors
    ///
    /// Returns [`AilakeError::Bincode`] in either of these cases:
    /// * decoding fails;
    /// * the decoded index is internally inconsistent — the entry and row-id
    ///   counts differ, or a raw buffer's length is not `entries.len() * dim`.
    ///
    /// Such payloads would otherwise load fine and panic later in `search`.
    pub fn from_bytes(bytes: &[u8]) -> AilakeResult<RaBitQIndex> {
        let mut idx: RaBitQIndex =
            bincode::deserialize(bytes).map_err(|e| AilakeError::Bincode(e.to_string()))?;
        if idx.entries.len() != idx.row_ids.len() {
            return Err(AilakeError::Bincode(format!(
                "RaBitQ index has {} entries but {} row ids",
                idx.entries.len(),
                idx.row_ids.len()
            )));
        }
        if let Some(raw) = &idx.raw_f16 {
            let expected = idx.entries.len().checked_mul(idx.dim as usize);
            if expected != Some(raw.len()) {
                return Err(AilakeError::Bincode(format!(
                    "RaBitQ raw buffer has {} values, expected {} entries x dim {}",
                    raw.len(),
                    idx.entries.len(),
                    idx.dim
                )));
            }
        }
        // The projection is restored here; `search` debug-asserts that the
        // codebook is ready.
        idx.codebook.rebuild_proj();
        Ok(idx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_vec::cosine_distance;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    /// `n` pseudo-random unit vectors of length `dim`, deterministic per `seed`.
    fn unit_vecs(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| {
                let v: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
                let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                v.iter().map(|x| x / norm).collect()
            })
            .collect()
    }

    /// Row ids `0..n`.
    fn sequential_ids(n: usize) -> Vec<RowId> {
        (0..n as u64).map(RowId::new).collect()
    }

    #[test]
    fn top1_is_self() {
        let dim = 32;
        let vecs = unit_vecs(50, dim, 1);
        let row_ids = sequential_ids(50);
        let idx = RaBitQIndex::build(
            &row_ids,
            &vecs,
            VectorMetric::Cosine,
            RaBitQConfig::default(),
            false,
        )
        .unwrap();
        let results = idx.search(&vecs[0], 1, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, RowId::new(0));
    }

    #[test]
    fn rerank_improves_recall() {
        let dim = 64;
        let n = 200;
        let vecs = unit_vecs(n, dim, 42);
        let row_ids = sequential_ids(n);
        let idx = RaBitQIndex::build(
            &row_ids,
            &vecs,
            VectorMetric::Cosine,
            RaBitQConfig::default(),
            true,
        )
        .unwrap();
        let query = &vecs[5];
        let mut gt: Vec<(f32, u64)> = vecs
            .iter()
            .enumerate()
            .map(|(i, v)| (cosine_distance(query, v), i as u64))
            .collect();
        gt.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        let gt_top5: std::collections::HashSet<u64> =
            gt.iter().take(5).map(|(_, id)| *id).collect();
        let results_reranked = idx.search(query, 5, Some(4));
        let found: std::collections::HashSet<u64> =
            results_reranked.iter().map(|(id, _)| id.as_u64()).collect();
        let recall = found.intersection(&gt_top5).count() as f64 / 5.0;
        assert!(recall >= 0.6, "recall@5 with reranking = {recall:.2}");
    }

    #[test]
    fn build_rejects_empty_input() {
        let result = RaBitQIndex::build(
            &[],
            &[],
            VectorMetric::Cosine,
            RaBitQConfig::default(),
            false,
        );
        assert!(result.is_err());
    }

    #[test]
    fn serialization_roundtrip() {
        let dim = 32;
        let vecs = unit_vecs(20, dim, 7);
        let row_ids = sequential_ids(20);
        let idx = RaBitQIndex::build(
            &row_ids,
            &vecs,
            VectorMetric::Cosine,
            RaBitQConfig::default(),
            false,
        )
        .unwrap();
        let bytes = RaBitQSerializer::to_bytes(&idx).unwrap();
        let idx2 = RaBitQSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(idx2.dim, idx.dim);
        assert_eq!(idx2.row_ids, idx.row_ids);
        assert_eq!(idx2.entries.len(), idx.entries.len());
        for (a, b) in idx.entries.iter().zip(&idx2.entries) {
            assert_eq!(a.code, b.code);
        }
        // The restored index (with its rebuilt projection) must answer
        // queries exactly like the original.
        let q = &vecs[0];
        assert_eq!(idx2.search(q, 3, None), idx.search(q, 3, None));
    }

    #[test]
    fn serialization_roundtrip_keeps_raw() {
        let vecs = unit_vecs(10, 16, 3);
        let row_ids = sequential_ids(10);
        let idx = RaBitQIndex::build(
            &row_ids,
            &vecs,
            VectorMetric::Cosine,
            RaBitQConfig::default(),
            true,
        )
        .unwrap();
        let bytes = RaBitQSerializer::to_bytes(&idx).unwrap();
        let idx2 = RaBitQSerializer::from_bytes(&bytes).unwrap();

        assert_eq!(idx2.raw_f16, idx.raw_f16);
        let q = &vecs[2];
        assert_eq!(idx2.search(q, 3, Some(2)), idx.search(q, 3, Some(2)));
    }
}