use std::collections::BTreeMap;
use std::path::Path;
use std::time::Instant;
use rand::seq::SliceRandom;
use rustc_hash::FxHashMap;
use crate::engine::stages::collect::{collect_stage, CollectInput};
use crate::engine::stages::graph::{freeze_graph, graph_stage, GraphInput};
use crate::engine::stages::parse::{parse_stage, ParseInput};
use crate::predictive::embeddings::{node2vec_random_walks, train_skipgram, Word2VecConfig};
use crate::quantize::turbo_quant::{
quantize_scalar, uniform_codebook_4bit, QuantizedVector, TurboQuantCodebook, TurboQuantConfig,
};
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let na: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
let nb: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
if na == 0.0 || nb == 0.0 {
return 0.0;
}
dot / (na * nb)
}
fn mse(a: &[f64], b: &[f64]) -> f64 {
let n = a.len() as f64;
a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f64>() / n
}
fn exact_knn(query: &[f64], database: &[Vec<f64>], k: usize) -> Vec<usize> {
let mut scored: Vec<(usize, f64)> = database
.iter()
.enumerate()
.map(|(i, v)| (i, cosine_sim(query, v)))
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(k);
scored.into_iter().map(|(i, _)| i).collect()
}
fn recall_at_k(exact: &[usize], approx: &[usize]) -> f64 {
let hit_count = exact.iter().filter(|i| approx.contains(i)).count();
hit_count as f64 / exact.len() as f64
}
fn quantize_naive(x: &[f64], dim: usize) -> (Vec<u8>, f64) {
let norm: f64 = x.iter().map(|v| v * v).sum::<f64>().sqrt();
let inv = if norm > 0.0 { 1.0 / norm } else { 1.0 };
let (centroids, boundaries) = uniform_codebook_4bit(dim);
let indices: Vec<u8> = x
.iter()
.map(|v| quantize_scalar(v * inv, &boundaries))
.collect();
let recon: Vec<f64> = indices
.iter()
.map(|&idx| centroids[idx as usize] * norm)
.collect();
let cos = cosine_sim(x, &recon);
(indices, cos)
}
#[test]
#[ignore]
fn turbo_quant_benchmark() {
let t_total = Instant::now();
let repo = Path::new(env!("CARGO_MANIFEST_DIR"));
let collect_out = collect_stage(&CollectInput {
repo_path: repo,
exclude_patterns: &[],
max_files: 10_000,
})
.expect("collect_stage failed");
let parse_out = parse_stage(&ParseInput {
files: collect_out.all_paths(),
workers: 8,
progress: None,
})
.expect("parse_stage failed");
let graph_out = graph_stage(&GraphInput {
parse_results: &parse_out.results,
repo_path: repo,
})
.expect("graph_stage failed");
let frozen = freeze_graph(graph_out.mutable_graph, graph_out.value_store, None);
let code_graph = &frozen.graph;
let edges: Vec<(u32, u32)> = code_graph
.all_call_edges()
.iter()
.map(|(a, b)| (a.index() as u32, b.index() as u32))
.collect();
let num_nodes = code_graph.node_count();
println!(
"Graph: {} nodes, {} edges, {} files parsed",
num_nodes,
edges.len(),
parse_out.stats.files_parsed
);
let t_embed = Instant::now();
let walks = node2vec_random_walks(crate::predictive::embeddings::Node2VecParams {
edges: &edges,
num_nodes,
walk_length: 20,
walks_per_node: 10,
p: 1.0,
q: 1.0,
seed: Some(42),
});
let w2v = train_skipgram(
&walks,
&Word2VecConfig {
embedding_dim: 128,
seed: Some(42),
..Default::default()
},
);
let embed_ms = t_embed.elapsed().as_millis();
println!(
"Embeddings: {} vectors in {}ms",
w2v.embeddings.len(),
embed_ms
);
let embeddings_f64: FxHashMap<u32, Vec<f64>> = w2v
.embeddings
.iter()
.map(|(&id, v)| (id, v.iter().map(|&x| x as f64).collect()))
.collect();
assert!(
embeddings_f64.len() >= 50,
"Need at least 50 embeddings, got {}",
embeddings_f64.len()
);
let t_quant = Instant::now();
let cb = TurboQuantCodebook::new(TurboQuantConfig::default());
let mut ids: Vec<u32> = embeddings_f64.keys().copied().collect();
ids.sort();
let quantized: Vec<QuantizedVector> = ids
.iter()
.map(|id| cb.quantize(&embeddings_f64[id]))
.collect();
let quant_ms = t_quant.elapsed().as_millis();
let raw_bytes = ids.len() * 128 * 8; let packed_bytes: usize = quantized.iter().map(|qv| qv.indices.len() + 8).sum(); let compression_ratio = raw_bytes as f64 / packed_bytes as f64;
let mut cosines = Vec::with_capacity(ids.len());
let mut mses = Vec::with_capacity(ids.len());
for (i, id) in ids.iter().enumerate() {
let orig = &embeddings_f64[id];
let recon = cb.reconstruct(&quantized[i]);
cosines.push(cosine_sim(orig, &recon));
mses.push(mse(orig, &recon));
}
let avg_cosine = cosines.iter().sum::<f64>() / cosines.len() as f64;
let min_cosine = cosines.iter().cloned().fold(f64::INFINITY, f64::min);
let avg_mse = mses.iter().sum::<f64>() / mses.len() as f64;
let mut naive_cosines = Vec::with_capacity(ids.len());
for id in &ids {
let orig = &embeddings_f64[id];
let (_, cos) = quantize_naive(orig, 128);
naive_cosines.push(cos);
}
let naive_avg_cosine = naive_cosines.iter().sum::<f64>() / naive_cosines.len() as f64;
let db_vecs: Vec<Vec<f64>> = ids.iter().map(|id| embeddings_f64[id].clone()).collect();
let db_f64_refs: &[Vec<f64>] = &db_vecs;
let mut rng = rand::rng();
let mut query_indices: Vec<usize> = (0..ids.len()).collect();
query_indices.shuffle(&mut rng);
let num_queries = 50.min(ids.len());
query_indices.truncate(num_queries);
let ks = [1, 5, 10, 50];
let mut recall_sums: BTreeMap<usize, f64> = ks.iter().map(|&k| (k, 0.0)).collect();
for &qi in &query_indices {
let query = &db_vecs[qi];
let query_f64: &[f64] = query.as_slice();
for &k in &ks {
let effective_k = k.min(ids.len());
let exact_topk = exact_knn(query_f64, db_f64_refs, effective_k);
let approx_results = cb.knn_search(query_f64, &quantized, effective_k);
let approx_topk: Vec<usize> = approx_results.iter().map(|(i, _)| *i).collect();
*recall_sums.get_mut(&k).unwrap() += recall_at_k(&exact_topk, &approx_topk);
}
}
let recall: BTreeMap<String, f64> = recall_sums
.iter()
.map(|(&k, &sum)| (format!("recall@{}", k), sum / num_queries as f64))
.collect();
let mut rerank_recall_sums: BTreeMap<usize, f64> = ks.iter().map(|&k| (k, 0.0)).collect();
for &qi in &query_indices {
let query = &db_vecs[qi];
let query_f64: &[f64] = query.as_slice();
for &k in &ks {
let effective_k = k.min(ids.len());
let exact_topk = exact_knn(query_f64, db_f64_refs, effective_k);
let rerank_results = cb.knn_search_rerank(query_f64, &quantized, effective_k, 20);
let rerank_topk: Vec<usize> = rerank_results.iter().map(|(i, _)| *i).collect();
*rerank_recall_sums.get_mut(&k).unwrap() += recall_at_k(&exact_topk, &rerank_topk);
}
}
let rerank_recall: BTreeMap<String, f64> = rerank_recall_sums
.iter()
.map(|(&k, &sum)| (format!("recall@{}", k), sum / num_queries as f64))
.collect();
let total_ms = t_total.elapsed().as_millis();
let report = serde_json::json!({
"benchmark": "turbo_quant_on_repotoire",
"num_vectors": ids.len(),
"dim": 128,
"bits": 4,
"compression_ratio": format!("{:.1}x", compression_ratio),
"cosine_similarity": {
"mean": format!("{:.6}", avg_cosine),
"min": format!("{:.6}", min_cosine),
},
"mse": format!("{:.8}", avg_mse),
"naive_baseline": {
"cosine_mean": format!("{:.6}", naive_avg_cosine),
},
"recall_adc": recall,
"recall_rerank": rerank_recall,
"timing_ms": {
"embeddings": embed_ms,
"quantization": quant_ms,
"total": total_ms,
},
"graph": {
"nodes": num_nodes,
"edges": edges.len(),
"files_parsed": parse_out.stats.files_parsed,
}
});
println!("\n{}", serde_json::to_string_pretty(&report).unwrap());
assert!(
avg_cosine >= 0.99,
"Average cosine similarity {avg_cosine:.6} < 0.99"
);
assert!(
compression_ratio >= 7.0,
"Compression ratio {compression_ratio:.1}x < 7x"
);
assert!(
avg_cosine > naive_avg_cosine,
"Lloyd-Max ({avg_cosine:.6}) should beat naive ({naive_avg_cosine:.6})"
);
}