use crate::index::morton::{morton_encode_3d, quantize};
use crate::index::bq::BqSignature;
/// Number of principal components extracted by `compute_pca_basis`.
///
/// `compute_sequence_id` reads exactly the first three basis vectors (`pca_basis[0..3]`)
/// as the axes of its 3-D Morton code, so this must stay at least 3.
pub const PCA_DIMS: usize = 3;
/// Inner product of `a` and `b`.
///
/// The two slices are zipped, so if their lengths differ only the common prefix
/// contributes; no length check is performed.
#[inline]
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Scales `v` in place to unit L2 norm.
///
/// Takes a slice rather than `&mut Vec<f32>`: nothing here needs the `Vec`, and existing
/// `&mut Vec<f32>` arguments still coerce.
///
/// Vectors whose norm is not greater than `1e-9` (including the all-zero vector) are left
/// unchanged, and so are vectors whose norm is NaN, because the comparison fails.
pub fn normalize(v: &mut [f32]) {
    let norm = dot(v, v).sqrt();
    if norm > 1e-9 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
/// Dot product of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// No normalization happens here. The result equals the cosine similarity only when both
/// inputs are already unit-norm (for example, after `normalize`). For other inputs the
/// clamp can saturate distinct scores to the same bound. A NaN dot product stays NaN.
pub fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let similarity = dot(a, b);
    similarity.clamp(-1.0f32, 1.0)
}
/// Approximates the dominant eigenvector of `XᵀX`, where the rows of `X` are `data`.
///
/// Starts from a deterministic pseudo-random vector (a multiplicative hash of each
/// coordinate index, shifted to be roughly zero-centred). Each of the `iters` iterations
/// computes `Σ row · (row · v)` and re-normalizes it. There is no convergence check. If
/// every projection is zero (for example, `data` is empty), the result is the zero vector.
fn power_iteration(data: &[&[f32]], dim: usize, iters: usize) -> Vec<f32> {
    let mut estimate: Vec<f32> = (0..dim)
        .map(|idx| {
            let h = (idx as u32).wrapping_mul(2654435761);
            h as f32 / u32::MAX as f32 - 0.5
        })
        .collect();
    normalize(&mut estimate);

    for _ in 0..iters {
        let mut next = vec![0.0f32; dim];
        for &row in data {
            let weight = dot(row, &estimate);
            next.iter_mut()
                .zip(row)
                .for_each(|(acc, &x)| *acc += weight * x);
        }
        normalize(&mut next);
        estimate = next;
    }
    estimate
}
/// Extracts `PCA_DIMS` principal directions of `data` by power iteration with deflation.
///
/// Each direction comes from 20 power iterations on a working copy of the data. Its
/// component is then subtracted from every row (`row -= (row · pc) · pc`) before the next
/// direction is sought. The rows are not mean-centred, so these are the leading
/// eigenvectors of the uncentred matrix `XᵀX`.
pub fn compute_pca_basis(data: &[&[f32]], dim: usize) -> Vec<Vec<f32>> {
    let mut deflated: Vec<Vec<f32>> = data.iter().map(|row| row.to_vec()).collect();
    let mut basis: Vec<Vec<f32>> = Vec::with_capacity(PCA_DIMS);

    while basis.len() < PCA_DIMS {
        let component = {
            let views: Vec<&[f32]> = deflated.iter().map(Vec::as_slice).collect();
            power_iteration(&views, dim, 20)
        };
        for row in &mut deflated {
            let coeff = dot(row, &component);
            row.iter_mut()
                .zip(&component)
                .for_each(|(r, &c)| *r -= coeff * c);
        }
        basis.push(component);
    }
    basis
}
/// Clusters `data` into `k` centers, assigning each row to the center with the largest
/// dot product.
///
/// Seeds are `k` evenly strided rows of `data`, copied as-is and not re-normalized. When
/// `k > data.len()`, some seeds repeat. Each of the `iters` iterations reassigns every row.
/// It then replaces each non-empty cluster's center with the L2-normalized mean of its
/// members. An empty cluster keeps its previous center.
///
/// Degenerate inputs do not panic:
/// - `k == 0` returns an empty `Vec`.
/// - Empty `data` returns `k` zero vectors of length `dim`, because there is nothing to
///   seed from. This keeps `k` entries for callers that index centers by cluster id.
///
/// Scores are compared with `f32::total_cmp`, so NaN scores no longer cause a panic.
pub fn kmeans(data: &[&[f32]], k: usize, dim: usize, iters: usize) -> Vec<Vec<f32>> {
    if k == 0 {
        return Vec::new();
    }
    if data.is_empty() {
        return vec![vec![0.0; dim]; k];
    }

    let step = (data.len() / k).max(1);
    let mut centers: Vec<Vec<f32>> = (0..k)
        .map(|i| data[(i * step) % data.len()].to_vec())
        .collect();

    for _ in 0..iters {
        let mut sums: Vec<Vec<f32>> = vec![vec![0.0; dim]; k];
        let mut counts: Vec<usize> = vec![0; k];

        // Assignment step: accumulate each row into its best-matching center.
        for &row in data {
            let best = centers
                .iter()
                .enumerate()
                .map(|(i, c)| (i, dot(row, c)))
                .max_by(|a, b| a.1.total_cmp(&b.1))
                .map_or(0, |(i, _)| i);
            for (s, &r) in sums[best].iter_mut().zip(row) {
                *s += r;
            }
            counts[best] += 1;
        }

        // Update step: normalized mean of members; empty clusters are left untouched.
        for ((center, sum), &count) in centers.iter_mut().zip(&sums).zip(&counts) {
            if count == 0 {
                continue;
            }
            let inv = 1.0 / count as f32;
            let mut mean: Vec<f32> = sum.iter().map(|x| x * inv).collect();
            normalize(&mut mean);
            *center = mean;
        }
    }
    centers
}
/// Computes the 64-bit sort key ("sequence id") used to order entries in `ErpcIndex`.
///
/// Bit layout, as assembled on the last line:
/// - bits 56..64: `cluster_id`, the index of the center in `centers` with the largest dot
///   product with `vec`. Only 8 bits are available. Ids >= 256 lose their high bits in the
///   shift and alias lower clusters.
/// - bits 16..: `morton_encode_3d` of the residual `vec - center` projected onto
///   `pca_basis[0..3]`, with each projection passed through `quantize(_, 13)`.
///   NOTE(review): this assumes the Morton code fits in 40 bits (3 × 13 = 39 when
///   interleaved), so it does not reach the cluster field. Confirm against
///   `morton_encode_3d`.
/// - bits 0..16: `bq.data[0] >> 48`. This is the top 16 bits if `data[0]` is a 64-bit
///   word (TODO confirm).
///
/// If `centers` is empty, `cluster_id` is 0 and `vec` itself is used as the residual.
/// Scores are compared with `f32::total_cmp`, so NaN does not cause a panic.
///
/// # Panics
/// Panics if `pca_basis` has fewer than three vectors.
pub fn compute_sequence_id(
    vec: &[f32],
    pca_basis: &[Vec<f32>],
    centers: &[Vec<f32>],
    bq: &BqSignature,
) -> u64 {
    let cluster_id = centers
        .iter()
        .enumerate()
        .map(|(i, c)| (i, dot(vec, c)))
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map_or(0, |(i, _)| i);

    let residual: Vec<f32> = match centers.get(cluster_id) {
        Some(center) => vec.iter().zip(center).map(|(v, c)| v - c).collect(),
        // No centers to subtract: treat the vector itself as the residual.
        None => vec.to_vec(),
    };

    let px = dot(&residual, &pca_basis[0]);
    let py = dot(&residual, &pca_basis[1]);
    let pz = dot(&residual, &pca_basis[2]);
    let qx = quantize(px, 13);
    let qy = quantize(py, 13);
    let qz = quantize(pz, 13);
    let morton = morton_encode_3d(qx, qy, qz);
    let bq_prefix = (bq.data[0] >> 48) as u64;
    ((cluster_id as u64) << 56) | (morton << 16) | bq_prefix
}
/// One indexed vector: its sort key, its row position in the input data, and its
/// binary-quantization signature.
///
/// `#[repr(C)]` together with `bytemuck::Pod`/`Zeroable` makes this castable to and from
/// raw bytes, so the field order and types are part of the layout and must not change.
#[repr(C)]
#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
pub struct SeqEntry {
    /// Sort key from `compute_sequence_id`; `ErpcIndex::sequence` is sorted by it.
    pub seq_id: u64,
    /// Row index in the flat input data. The row occupies `phys_idx * dim .. (phys_idx + 1) * dim`.
    pub phys_idx: u64,
    /// Signature produced by `BqSignature::from_vector` for this row.
    pub bq: BqSignature,
}
/// Tuning parameters for `ErpcIndex`, normally chosen by [`ErpcParams::compute_for`].
///
/// `#[repr(C)]` together with `bytemuck::Pod`/`Zeroable` makes this castable to and from
/// raw bytes, so the field order and types are part of the layout.
#[repr(C)]
#[derive(Clone, Copy, Debug, bytemuck::Pod, bytemuck::Zeroable)]
pub struct ErpcParams {
    /// Number of K-Means clusters.
    pub k_clusters: u64,
    /// Number of highest-scoring clusters scanned per query.
    pub probe_count: u64,
    /// Maximum number of BQ-ranked candidates re-scored with `cosine_sim` per query.
    pub bq_refined_count: u64,
    /// Multiplier applied by `adaptive_wing` to the segment-size power law.
    pub wing_scale: f64,
}

impl ErpcParams {
    /// Chooses index parameters for `n` vectors. Larger `effort` widens the search.
    ///
    /// With `p = 1.2`:
    /// - `k_clusters = clamp(⌊512 · nᵖ / (90_000ᵖ + nᵖ)⌋, 16, 256)`
    /// - `probe_count = max(3, ⌊1.8 · ln k_clusters⌋)`
    /// - `bq_refined_count = max(100, ⌊(1200 + 3000·effort) · nᵖ / (200_000ᵖ + nᵖ)⌋)`
    /// - `wing_scale = 2 + 15·effort`
    ///
    /// Float-to-integer casts saturate, so negative or NaN intermediate values fall back to
    /// the lower bounds. `_dim` is currently unused.
    pub fn compute_for(n: usize, _dim: usize, effort: f32) -> Self {
        // `compute_sequence_id` stores the cluster id in the top 8 bits of the key
        // (`cluster_id << 56`). Ids >= 256 would have their high bits shifted out and alias
        // lower clusters. The uncapped curve passes 256 at roughly n = 90_000.
        const MAX_CLUSTERS: usize = 256;

        let n_f64 = n as f64;
        let p = 1.2;
        let n_p = n_f64.powf(p);

        let k_half_p = 90_000_f64.powf(p);
        let k_clusters = ((512.0 * n_p / (k_half_p + n_p)) as usize).clamp(16, MAX_CLUSTERS);
        let probe_count = 3.max((1.8 * (k_clusters as f64).ln()) as usize);

        let refine_max = 1200.0 + effort as f64 * 3000.0;
        let refine_half_p = 200_000_f64.powf(p);
        let bq_refined_count = 100.max((refine_max * n_p / (refine_half_p + n_p)) as usize);

        let wing_scale = 2.0 + (effort as f64 * 15.0);

        Self {
            k_clusters: k_clusters as u64,
            probe_count: probe_count as u64,
            bq_refined_count: bq_refined_count as u64,
            wing_scale,
        }
    }
}
/// Exponent of the segment-size power law used by [`adaptive_wing`].
pub const WING_EXPONENT: f64 = 0.6;
/// Lower bound on the wing half-width, applied before capping at the segment size.
pub const WING_MIN: usize = 50;

/// Half-width of the scan window around a query's position inside a cluster segment.
///
/// Computes `⌊seg_size^WING_EXPONENT · wing_scale⌋`, raises it to at least `WING_MIN`,
/// then caps it at `seg_size`. The result never exceeds `seg_size`. Negative or NaN
/// products saturate to 0 in the `as usize` cast, so they yield the lower bound.
#[inline]
pub fn adaptive_wing(seg_size: usize, wing_scale: f64) -> usize {
    let scaled = (seg_size as f64).powf(WING_EXPONENT) * wing_scale;
    // Not `clamp`: `seg_size` may be smaller than `WING_MIN`, and `clamp` panics when min > max.
    let floored = std::cmp::max(scaled as usize, WING_MIN);
    std::cmp::min(floored, seg_size)
}
/// Approximate nearest-neighbour index combining a K-Means partition, a residual-space
/// PCA projection, and binary-quantization (BQ) signatures.
///
/// Every row is summarised by a `SeqEntry` whose `seq_id` (from `compute_sequence_id`)
/// groups entries by cluster and orders them by Morton code within a cluster. `sequence`
/// is sorted by that key.
pub struct ErpcIndex {
    /// Residual-space principal directions from `compute_pca_basis`.
    pub pca_basis: Vec<Vec<f32>>,
    /// K-Means centers, indexed by cluster id.
    pub centers: Vec<Vec<f32>>,
    /// One entry per indexed row, sorted ascending by `seq_id`.
    pub sequence: Vec<SeqEntry>,
    /// Length of every indexed row.
    pub dim: usize,
    /// Parameters chosen by `ErpcParams::compute_for` at build time.
    pub params: ErpcParams,
}

impl ErpcIndex {
    /// Builds an index over `flat_data`, interpreted as consecutive rows of `dim` values.
    ///
    /// A trailing partial row (when `flat_data.len()` is not a multiple of `dim`) is
    /// ignored. The build proceeds in order:
    /// 1. Choose parameters.
    /// 2. Train K-Means centers.
    /// 3. Compute each row's residual against its nearest center.
    /// 4. Fit the PCA basis on those residuals.
    /// 5. Compute every row's sequence id and BQ signature, then sort by sequence id.
    ///
    /// Progress is logged to stderr.
    ///
    /// # Panics
    /// Panics if `dim == 0`.
    pub fn build<T: crate::VectorType>(flat_data: &[T], dim: usize, effort: f32) -> Self {
        assert!(dim > 0, "ErpcIndex::build: dim must be non-zero");

        // `chunks_exact` drops a trailing partial row, keeping the row count equal to
        // `flat_data.len() / dim`. With `chunks`, a short last row would be indexed and
        // `search` would later slice past the end of `flat_data`.
        let rows: Vec<Vec<f32>> = flat_data
            .chunks_exact(dim)
            .map(|chunk| chunk.iter().map(|x| x.to_f32()).collect())
            .collect();
        let n = rows.len();
        let data: Vec<&[f32]> = rows.iter().map(Vec::as_slice).collect();

        let params = ErpcParams::compute_for(n, dim, effort);
        eprintln!("[ERPC] Configured parameters: {:?}", params);

        eprintln!("[ERPC] Step 1: 训练 K-Means 聚类中心 (K={})...", params.k_clusters as usize);
        let centers = kmeans(&data, params.k_clusters as usize, dim, 10);

        eprintln!("[ERPC] Step 2: 计算全量残差 (vec - nearest_center)...");
        let residuals: Vec<Vec<f32>> = data
            .iter()
            .map(|vec| {
                let best_idx = centers
                    .iter()
                    .enumerate()
                    .map(|(i, c)| (i, dot(vec, c)))
                    .max_by(|a, b| a.1.total_cmp(&b.1))
                    .map_or(0, |(i, _)| i);
                vec.iter()
                    .zip(&centers[best_idx])
                    .map(|(v, c)| v - c)
                    .collect()
            })
            .collect();

        eprintln!("[ERPC] Step 3: 在残差空间上训练 PCA 基向量...");
        let res_views: Vec<&[f32]> = residuals.iter().map(Vec::as_slice).collect();
        let pca_basis = compute_pca_basis(&res_views, dim);

        eprintln!("[ERPC] Step 4: 计算全量 Sequence ID + BQ 签名...");
        let mut sequence: Vec<SeqEntry> = data
            .iter()
            .enumerate()
            .map(|(i, vec)| {
                let bq = BqSignature::from_vector(vec);
                let seq_id = compute_sequence_id(vec, &pca_basis, &centers, &bq);
                SeqEntry { seq_id, phys_idx: i as u64, bq }
            })
            .collect();
        sequence.sort_unstable_by_key(|e| e.seq_id);

        Self { pca_basis, centers, sequence, dim, params }
    }

    /// Returns `(segment, wing)` index ranges into `self.sequence` for the clusters a query
    /// probes.
    ///
    /// Clusters are ranked by dot product with `query` (highest first), and the top
    /// `params.probe_count` are taken. For each cluster:
    /// - `segment` is the run of entries whose `seq_id` carries that cluster id in its top
    ///   8 bits.
    /// - `wing` spans up to `adaptive_wing(segment.len(), ..)` entries on either side of
    ///   the point where `seq_query`'s low 56 bits would be inserted, clipped to the segment.
    ///
    /// Empty segments are skipped. A segment already returned is not returned again; this
    /// can happen when cluster ids >= 256 alias lower ones in the 8-bit field. Distinct
    /// segments are therefore disjoint.
    fn probe_windows(
        &self,
        query: &[f32],
        seq_query: u64,
    ) -> Vec<(std::ops::Range<usize>, std::ops::Range<usize>)> {
        const LOW_56_BITS: u64 = 0x00ff_ffff_ffff_ffff;

        let mut cluster_scores: Vec<(usize, f32)> = self
            .centers
            .iter()
            .enumerate()
            .map(|(i, c)| (i, dot(query, c)))
            .collect();
        cluster_scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));

        let mut windows: Vec<(std::ops::Range<usize>, std::ops::Range<usize>)> = Vec::new();
        for &(cluster_id, _) in cluster_scores.iter().take(self.params.probe_count as usize) {
            let cluster_lo = (cluster_id as u64) << 56;
            let cluster_hi = cluster_lo | LOW_56_BITS;
            let seg_start = self.sequence.partition_point(|e| e.seq_id < cluster_lo);
            let seg_end = self.sequence.partition_point(|e| e.seq_id <= cluster_hi);
            if seg_start >= seg_end || windows.iter().any(|(seg, _)| seg.start == seg_start) {
                continue;
            }
            let morton_target = (seq_query & LOW_56_BITS) | cluster_lo;
            let morton_pos = seg_start
                + self.sequence[seg_start..seg_end].partition_point(|e| e.seq_id < morton_target);
            let wing = adaptive_wing(seg_end - seg_start, self.params.wing_scale);
            let wing_start = morton_pos.saturating_sub(wing).max(seg_start);
            let wing_end = (morton_pos + wing).min(seg_end);
            windows.push((seg_start..seg_end, wing_start..wing_end));
        }
        windows
    }

    /// Returns up to `top_k` `(row index, similarity)` pairs for `query`, best first.
    ///
    /// The search runs in three stages:
    /// 1. Gather candidates from the wing windows of the probed clusters (see
    ///    `probe_windows`).
    /// 2. Rank them by Hamming distance between BQ signatures.
    /// 3. Re-score the best `params.bq_refined_count` with `cosine_sim` against row
    ///    `phys_idx` of `flat_data`, then sort by that score.
    ///
    /// `flat_data` is presumably the same data passed to `build`; verify at call sites.
    ///
    /// NOTE(review): `cosine_sim` is a clamped dot product. Scores are true cosine
    /// similarities, and the clamp is harmless, only if `query` and the stored rows are
    /// unit-norm. Confirm that callers normalize.
    ///
    /// # Panics
    /// Panics if a shortlisted row lies outside `flat_data`.
    pub fn search<T: crate::VectorType>(
        &self,
        query: &[f32],
        flat_data: &[T],
        top_k: usize,
    ) -> Vec<(usize, f32)> {
        let bq_query = BqSignature::from_vector(query);
        let seq_query = compute_sequence_id(query, &self.pca_basis, &self.centers, &bq_query);

        let mut bq_candidates: Vec<(usize, u32)> = Vec::new();
        for (_, wing) in self.probe_windows(query, seq_query) {
            bq_candidates.extend(
                self.sequence[wing]
                    .iter()
                    .map(|e| (e.phys_idx as usize, bq_query.hamming_distance(&e.bq))),
            );
        }
        // Sort by Hamming distance, breaking ties by row index. This makes the shortlist
        // deterministic and places any exact duplicates next to each other so `dedup`
        // removes them.
        bq_candidates.sort_unstable_by_key(|&(idx, dist)| (dist, idx));
        bq_candidates.dedup();

        let refine_len = bq_candidates.len().min(self.params.bq_refined_count as usize);
        let shortlist = &bq_candidates[..refine_len];

        let dim = self.dim;
        let mut candidates: Vec<(usize, f32)> = Vec::with_capacity(refine_len);
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
            // SAFETY: `T` has the same `TypeId` as `f32`, so it *is* `f32`. The pointer,
            // length, and alignment of `flat_data` are therefore valid for `[f32]`, and the
            // returned slice borrows `flat_data` for the same lifetime.
            let rows: &[f32] = unsafe {
                std::slice::from_raw_parts(flat_data.as_ptr().cast::<f32>(), flat_data.len())
            };
            for &(idx, _) in shortlist {
                let offset = idx * dim;
                candidates.push((idx, cosine_sim(query, &rows[offset..offset + dim])));
            }
        } else {
            // One reusable conversion buffer for the whole refinement pass.
            let mut scratch: Vec<f32> = Vec::with_capacity(dim);
            for &(idx, _) in shortlist {
                let offset = idx * dim;
                scratch.clear();
                scratch.extend(flat_data[offset..offset + dim].iter().map(|x| x.to_f32()));
                candidates.push((idx, cosine_sim(query, &scratch)));
            }
        }

        candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
        candidates.truncate(top_k);
        candidates
    }

    /// Reports how much work a `search` for `query` would do, without reading row data.
    ///
    /// Returns `(indexed, in_probed_clusters, refined)`:
    /// - `indexed`: the number of indexed entries.
    /// - `in_probed_clusters`: the total size of the probed cluster segments.
    /// - `refined`: the number of wing candidates, capped at `params.bq_refined_count`.
    ///   This is the number of `cosine_sim` calls `search` makes when row indices are
    ///   unique.
    pub fn count_cosine_computations(&self, query: &[f32]) -> (usize, usize, usize) {
        let bq_query = BqSignature::from_vector(query);
        let seq_query = compute_sequence_id(query, &self.pca_basis, &self.centers, &bq_query);

        let (total_in_clusters, total_bq_candidates) = self
            .probe_windows(query, seq_query)
            .into_iter()
            .fold((0usize, 0usize), |(in_clusters, in_wings), (segment, wing)| {
                (in_clusters + segment.len(), in_wings + wing.len())
            });

        let total_computed = total_bq_candidates.min(self.params.bq_refined_count as usize);
        (self.sequence.len(), total_in_clusters, total_computed)
    }
}