use innr::binary::{binary_dot, binary_hamming, binary_jaccard, encode_binary};
use innr::{cosine, dot};
use std::time::Instant;
fn main() {
println!("Binary (1-Bit) Quantization");
println!("===========================\n");
demo_encoding();
demo_operations();
demo_memory();
demo_recall();
println!("Done!");
}
/// Section 1: encodes a small f32 vector with threshold 0.0, prints the
/// resulting bits, and checks them against the expected pattern.
fn demo_encoding() {
    println!("1. Encoding: f32 -> Binary");
    println!(" -----------------------\n");

    let embedding = [0.5f32, -0.3, 0.9, 0.0, -0.7, 0.1, 0.0, 0.8];
    let packed = encode_binary(&embedding, 0.0);

    println!(" f32 values: {:?}", embedding);

    // Render every bit as "1"/"0" and join them, instead of printing piecewise.
    let rendered: Vec<&str> = (0..embedding.len())
        .map(|idx| if packed.get(idx) { "1" } else { "0" })
        .collect();
    print!(" Binary bits: [");
    print!("{}", rendered.join(", "));
    println!("]");
    println!(" Rule: 1 if value > threshold (0.0), else 0");
    println!();

    // Expected bit per position; the two 0.0 inputs must map to 0 because the
    // comparison against the threshold is strict.
    let expected = [true, false, true, false, false, true, false, true];
    for (idx, want) in expected.into_iter().enumerate() {
        assert!(packed.get(idx) == want);
    }
}
/// Section 2: encodes two fixed sign patterns and prints/validates their
/// Hamming distance, binary dot product and Jaccard similarity.
fn demo_operations() {
    println!("2. Binary Similarity Operations");
    println!(" ----------------------------\n");

    // Alternating signs vs. signs flipping every two positions.
    let alternating = [1.0f32, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
    let paired = [1.0f32, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0];
    let lhs = encode_binary(&alternating, 0.0);
    let rhs = encode_binary(&paired, 0.0);

    println!(" a bits: 1 0 1 0 1 0 1 0");
    println!(" b bits: 1 1 0 0 1 1 0 0\n");

    let differing_bits = binary_hamming(&lhs, &rhs);
    let shared_ones = binary_dot(&lhs, &rhs);
    let similarity = binary_jaccard(&lhs, &rhs);

    // Positions 1, 2, 5 and 6 differ.
    println!(" Hamming distance: {} (bits that differ)", differing_bits);
    assert_eq!(differing_bits, 4);

    // Positions 0 and 4 are set in both.
    println!(" Binary dot: {} (bits both 1)", shared_ones);
    assert_eq!(shared_ones, 2);

    // 2 shared ones out of 6 positions set in either vector.
    println!(" Jaccard: {:.4} (|A & B| / |A | B|)", similarity);
    assert!((similarity - 2.0 / 6.0).abs() < 1e-6);

    println!();
}
/// Section 3: prints a table comparing the storage needed for f32 vectors
/// (4 bytes per dimension) with bit-packed vectors stored in 64-bit words.
fn demo_memory() {
    println!("3. Memory Reduction: 32x Compression");
    println!(" ----------------------------------\n");

    // (label, dimensions, number of documents)
    let scenarios: [(&str, usize, usize); 4] = [
        ("384d, 1M docs", 384, 1_000_000),
        ("768d, 1M docs", 768, 1_000_000),
        ("768d, 10M docs", 768, 10_000_000),
        ("1536d, 10M docs", 1536, 10_000_000),
    ];

    println!(
        " {:25} {:>10} {:>10} {:>7}",
        "Config", "f32", "Binary", "Ratio"
    );
    println!(" {}", "-".repeat(55));

    for (label, dims, doc_count) in scenarios {
        let dims = dims as u64;
        let doc_count = doc_count as u64;

        let dense_total = doc_count * dims * 4;
        // One bit per dimension, rounded up to whole 8-byte words.
        let packed_total = doc_count * dims.div_ceil(64) * 8;
        let shrink_factor = dense_total as f64 / packed_total as f64;

        println!(
            " {:25} {:>10} {:>10} {:>6.1}x",
            label,
            format_bytes(dense_total),
            format_bytes(packed_total),
            shrink_factor
        );
    }
    println!();
}
/// Section 4: measures how many of the exact (cosine) top-k neighbours are
/// also found by Hamming ranking on the binary codes, then times binary
/// ranking against full-precision dot-product ranking over the same queries.
fn demo_recall() {
    println!("4. Recall Trade-off: Binary vs Exact");
    println!(" ---------------------------------\n");

    let dim = 768;
    let n_docs = 5000;
    let n_queries = 50;
    let k = 10;

    let corpus: Vec<Vec<f32>> = (0..n_docs)
        .map(|seed| generate_normalized(dim, seed as u64))
        .collect();
    let corpus_bits: Vec<_> = corpus.iter().map(|doc| encode_binary(doc, 0.0)).collect();
    // Query seeds are offset so they never coincide with a document seed.
    let queries: Vec<Vec<f32>> = (0..n_queries)
        .map(|seed| generate_normalized(dim, (seed + 100_000) as u64))
        .collect();

    // Encodes a query and returns every document index with its Hamming
    // distance, closest first. The sort is stable, so ties keep corpus order.
    let rank_by_hamming = |query: &[f32]| -> Vec<(usize, u32)> {
        let query_bits = encode_binary(query, 0.0);
        let mut ranked: Vec<(usize, u32)> = corpus_bits
            .iter()
            .enumerate()
            .map(|(idx, doc_bits)| (idx, binary_hamming(&query_bits, doc_bits)))
            .collect();
        ranked.sort_by_key(|&(_, distance)| distance);
        ranked
    };

    // --- Recall: overlap between exact top-k and binary top-k ---
    let mut recall_sum = 0.0f64;
    for query in &queries {
        let mut by_cosine: Vec<(usize, f32)> = Vec::with_capacity(corpus.len());
        for (idx, doc) in corpus.iter().enumerate() {
            by_cosine.push((idx, cosine(query, doc)));
        }
        // Highest similarity first.
        by_cosine.sort_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));
        let exact_ids: Vec<usize> = by_cosine.iter().take(k).map(|&(idx, _)| idx).collect();

        let binary_ids: Vec<usize> = rank_by_hamming(query)
            .iter()
            .take(k)
            .map(|&(idx, _)| idx)
            .collect();

        let hits = exact_ids
            .iter()
            .filter(|id| binary_ids.contains(id))
            .count();
        recall_sum += hits as f64 / k as f64;
    }

    // --- Timing: binary ranking (query encoding included) ---
    let binary_clock = Instant::now();
    for query in &queries {
        let ranked = rank_by_hamming(query);
        // Keep the optimizer from discarding the work being timed.
        std::hint::black_box(&ranked);
    }
    let binary_time = binary_clock.elapsed();

    // --- Timing: full-precision dot-product ranking ---
    let exact_clock = Instant::now();
    for query in &queries {
        let mut ranked: Vec<(usize, f32)> = Vec::with_capacity(corpus.len());
        for (idx, doc) in corpus.iter().enumerate() {
            ranked.push((idx, dot(query, doc)));
        }
        ranked.sort_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));
        std::hint::black_box(&ranked);
    }
    let exact_time = exact_clock.elapsed();

    let query_count = n_queries as f64;
    let avg_recall = recall_sum / query_count;

    println!(
        " Corpus: {} docs x {}d, {} queries, k={}\n",
        n_docs, dim, n_queries, k
    );
    println!(" Recall@{}: {:.1}%", k, avg_recall * 100.0);
    println!(" (overlap between binary top-k and exact top-k)\n");
    println!(
        " Binary scoring: {:?} ({:.1} us/query)",
        binary_time,
        binary_time.as_micros() as f64 / query_count
    );
    println!(
        " Exact scoring: {:?} ({:.1} us/query)",
        exact_time,
        exact_time.as_micros() as f64 / query_count
    );
    println!();
    println!(" Binary quantization trades recall for speed and 32x less memory.");
    println!(" Typical usage: binary retrieves top-1000 candidates, then rerank");
    println!(" with full-precision vectors for the final top-k.");
    println!();
}
/// Deterministically generates a `dim`-dimensional vector from `seed` and
/// scales it to unit L2 length.
///
/// Each component is derived from `seed` and its index with wrapping 64-bit
/// arithmetic, then mapped from the upper 31 bits of the mixed value into
/// roughly `[-1.0, 1.0)`. The same `(dim, seed)` pair always yields the same
/// vector.
///
/// If the norm is not greater than `f32::EPSILON` (for example when `dim` is
/// 0), the vector is returned without scaling to avoid dividing by ~zero.
fn generate_normalized(dim: usize, seed: u64) -> Vec<f32> {
    let mut v: Vec<f32> = (0..dim)
        .map(|i| {
            // The index term must wrap as well: a plain `*` exceeds u64::MAX
            // once i >= 13 and panics in builds with overflow checks enabled
            // (e.g. the default debug profile).
            let x = seed
                .wrapping_mul(6364136223846793005)
                .wrapping_add((i as u64).wrapping_mul(1442695040888963407));
            // Upper 31 bits -> [0, 1) -> [-1, 1).
            ((x >> 33) as f32 / (1u64 << 31) as f32) * 2.0 - 1.0
        })
        .collect();

    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > f32::EPSILON {
        for x in &mut v {
            *x /= norm;
        }
    }
    v
}
/// Formats a byte count using decimal (power-of-1000) units with one decimal
/// place, e.g. `1.5 GB`, `48.0 MB`, `1.0 KB`; counts below 1000 are printed
/// as a whole number followed by `B`.
fn format_bytes(bytes: u64) -> String {
    // Largest unit first so the first threshold that fits wins.
    const UNITS: [(u64, &str); 3] = [
        (1_000_000_000, "GB"),
        (1_000_000, "MB"),
        (1_000, "KB"),
    ];

    match UNITS.iter().find(|(threshold, _)| bytes >= *threshold) {
        Some(&(threshold, suffix)) => {
            format!("{:.1} {}", bytes as f64 / threshold as f64, suffix)
        }
        None => format!("{} B", bytes),
    }
}