use rand::prelude::*;
use rand::rngs::StdRng;
/// A fixed set of projection directions that vectors are projected onto.
///
/// `ProjectionCache::new` fills `directions` with `num_projections` vectors of
/// length `dim`. All fields are public, so nothing enforces that relationship
/// after construction.
#[derive(Debug, Clone)]
pub struct ProjectionCache {
/// One direction vector per projection. `new` scales each one to unit
/// Euclidean length when its norm exceeds `1e-8`.
pub directions: Vec<Vec<f32>>,
/// Number of projections requested at construction.
pub num_projections: usize,
/// Length of each direction vector requested at construction.
pub dim: usize,
}
impl ProjectionCache {
    /// Creates `num_projections` pseudo-random directions of length `dim`.
    ///
    /// Components are drawn from a `StdRng` seeded with `seed`, so the same
    /// arguments always yield the same directions. Each raw `Standard` sample
    /// `x` is mapped to `x * 2.0 - 1.0`, and every direction is then scaled to
    /// unit Euclidean length. A direction whose norm is not above `1e-8` (for
    /// example when `dim == 0`) is left unscaled to avoid dividing by ~zero.
    pub fn new(dim: usize, num_projections: usize, seed: u64) -> Self {
        let mut rng = StdRng::seed_from_u64(seed);
        let directions: Vec<Vec<f32>> = (0..num_projections)
            .map(|_| {
                let mut dir: Vec<f32> = (0..dim)
                    .map(|_| rng.sample::<f32, _>(rand::distributions::Standard) * 2.0 - 1.0)
                    .collect();
                let norm = dir.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 1e-8 {
                    dir.iter_mut().for_each(|x| *x /= norm);
                }
                dir
            })
            .collect();
        Self {
            directions,
            num_projections,
            dim,
        }
    }

    /// Returns the dot product of `vector` with every stored direction, in
    /// direction order.
    ///
    /// Only the first `vector.len()` components of each direction are used.
    ///
    /// # Panics
    ///
    /// Panics if `vector` is longer than any stored direction.
    #[inline]
    pub fn project(&self, vector: &[f32]) -> Vec<f32> {
        self.directions
            .iter()
            .map(|dir| Self::dot_product_simd(vector, dir))
            .collect()
    }

    /// Writes the dot product of `vector` with direction `i` into `out[i]`.
    ///
    /// Elements of `out` beyond the number of stored directions are left
    /// untouched.
    ///
    /// # Panics
    ///
    /// Panics if `out` is shorter than the number of stored directions, or if
    /// `vector` is longer than any stored direction. The length of `out` is
    /// checked before anything is written, so a too-small buffer is never
    /// left partially filled.
    #[inline]
    pub fn project_into(&self, vector: &[f32], out: &mut [f32]) {
        assert!(
            out.len() >= self.directions.len(),
            "output buffer holds {} values but {} projections are required",
            out.len(),
            self.directions.len()
        );
        for (slot, dir) in out.iter_mut().zip(&self.directions) {
            *slot = Self::dot_product_simd(vector, dir);
        }
    }

    /// Dot product of `a` with the first `a.len()` elements of `b`.
    ///
    /// The work is split across four independent accumulators (a shape the
    /// compiler can vectorize); leftover elements are folded into the first
    /// accumulator, and the accumulators are summed in index order.
    ///
    /// # Panics
    ///
    /// Panics if `b` is shorter than `a`.
    #[inline(always)]
    fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
        // A single up-front slice check replaces one bounds check per element
        // and guarantees both chunk iterators have identical shapes.
        let b = &b[..a.len()];

        let a_chunks = a.chunks_exact(4);
        let b_chunks = b.chunks_exact(4);
        let a_tail = a_chunks.remainder();
        let b_tail = b_chunks.remainder();

        let mut acc = [0.0f32; 4];
        for (ca, cb) in a_chunks.zip(b_chunks) {
            acc[0] += ca[0] * cb[0];
            acc[1] += ca[1] * cb[1];
            acc[2] += ca[2] * cb[2];
            acc[3] += ca[3] * cb[3];
        }
        for (x, y) in a_tail.iter().zip(b_tail) {
            acc[0] += x * y;
        }
        acc[0] + acc[1] + acc[2] + acc[3]
    }
}
/// Per-projection sorted views of a set of keys, plus optional histograms.
///
/// `WindowCache::build` fills the projection and sorted fields; `histograms`
/// and `cdfs` stay `None` until `build_histograms` is called.
#[derive(Debug, Clone)]
pub struct WindowCache {
/// `key_projections[k][p]` is projection `p` of key `k`.
pub key_projections: Vec<Vec<f32>>,
/// `sorted_indices[p]` lists key indices ordered by ascending projection `p`.
pub sorted_indices: Vec<Vec<usize>>,
/// `sorted_values[p]` holds the projection-`p` values in that same ascending order.
pub sorted_values: Vec<Vec<f32>>,
/// `histograms[p][b]` accumulates `1 / num_keys` for every key whose
/// projection `p` falls into bin `b`.
pub histograms: Option<Vec<Vec<f32>>>,
/// `cdfs[p][b]` is the running sum of `histograms[p]` up to and including bin `b`.
pub cdfs: Option<Vec<Vec<f32>>>,
/// Number of keys passed to `build`.
pub num_keys: usize,
}
impl WindowCache {
    /// Projects every key with `proj_cache` and sorts the keys along each
    /// projection.
    ///
    /// For each projection `p`, `sorted_values[p]` holds the projected values
    /// in ascending order and `sorted_indices[p]` holds the matching key
    /// indices. Values are ordered with `f32::total_cmp`, so NaNs cannot break
    /// the sort; keys with equal projected values keep ascending index order,
    /// which makes the result deterministic. Histograms are not built here.
    ///
    /// # Panics
    ///
    /// Panics if `proj_cache.project` returns fewer than
    /// `proj_cache.num_projections` values for a key.
    pub fn build(keys: &[&[f32]], proj_cache: &ProjectionCache) -> Self {
        let num_keys = keys.len();
        let num_proj = proj_cache.num_projections;
        let key_projections: Vec<Vec<f32>> =
            keys.iter().map(|key| proj_cache.project(key)).collect();

        let mut sorted_indices = Vec::with_capacity(num_proj);
        let mut sorted_values = Vec::with_capacity(num_proj);
        for p in 0..num_proj {
            let mut indexed: Vec<(usize, f32)> = key_projections
                .iter()
                .enumerate()
                .map(|(i, projs)| (i, projs[p]))
                .collect();
            // `total_cmp` is a genuine total order (unlike `partial_cmp` with a
            // fallback), which `sort_unstable_by` requires to behave sanely.
            indexed.sort_unstable_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
            let (indices, values): (Vec<usize>, Vec<f32>) = indexed.into_iter().unzip();
            sorted_indices.push(indices);
            sorted_values.push(values);
        }

        Self {
            key_projections,
            sorted_indices,
            sorted_values,
            histograms: None,
            cdfs: None,
            num_keys,
        }
    }

    /// Builds a normalized `num_bins`-bin histogram and its running sum (CDF)
    /// for every projection, storing them in `histograms` and `cdfs`.
    ///
    /// Bins span the minimum to maximum of each projection's sorted values
    /// (the first and last entries of `sorted_values[p]`); each value adds
    /// `1 / num_keys` to its bin. Projections without values keep all-zero
    /// rows. With `num_bins == 0` every row is empty.
    pub fn build_histograms(&mut self, num_bins: usize) {
        let num_proj = self.sorted_values.len();
        let mut histograms = vec![vec![0.0f32; num_bins]; num_proj];
        let mut cdfs = vec![vec![0.0f32; num_bins]; num_proj];

        // With zero bins there is nothing to fill, and `num_bins - 1` would underflow.
        if num_bins > 0 {
            let max_bin = num_bins - 1;
            let last_bin = max_bin as f32;
            let weight = 1.0 / self.num_keys as f32;

            let rows = self
                .sorted_values
                .iter()
                .zip(&mut histograms)
                .zip(&mut cdfs);
            for ((vals, hist), cdf) in rows {
                let (min_val, max_val) = match (vals.first(), vals.last()) {
                    (Some(&lo), Some(&hi)) => (lo, hi),
                    _ => continue,
                };
                // Guard against a zero-width range when all values coincide.
                let range = (max_val - min_val).max(1e-8);
                for &v in vals {
                    let scaled = ((v - min_val) / range * last_bin).clamp(0.0, last_bin);
                    // The integer `min` covers f32 rounding of very large bin counts.
                    let bin = (scaled as usize).min(max_bin);
                    hist[bin] += weight;
                }
                let mut cumsum = 0.0f32;
                for (slot, &mass) in cdf.iter_mut().zip(hist.iter()) {
                    cumsum += mass;
                    *slot = cumsum;
                }
            }
        }

        self.histograms = Some(histograms);
        self.cdfs = Some(cdfs);
    }

    /// Returns the ascending projected values for projection `projection_idx`.
    ///
    /// # Panics
    ///
    /// Panics if `projection_idx` is out of range.
    #[inline]
    pub fn get_sorted(&self, projection_idx: usize) -> &[f32] {
        &self.sorted_values[projection_idx]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` keys of length `dim`; every component of key `i` equals
    /// `i * step`.
    fn constant_keys(count: usize, dim: usize, step: f32) -> Vec<Vec<f32>> {
        (0..count).map(|i| vec![i as f32 * step; dim]).collect()
    }

    /// Borrows each owned key as a slice, the form `WindowCache::build` takes.
    fn as_slices(keys: &[Vec<f32>]) -> Vec<&[f32]> {
        keys.iter().map(Vec::as_slice).collect()
    }

    /// The constructor records its arguments and yields unit-length directions.
    #[test]
    fn test_projection_cache() {
        let cache = ProjectionCache::new(64, 8, 42);
        assert_eq!(cache.num_projections, 8);
        assert_eq!(cache.dim, 64);
        for direction in cache.directions.iter() {
            let squared_sum: f32 = direction.iter().map(|c| c * c).sum();
            let length = squared_sum.sqrt();
            assert!((length - 1.0).abs() < 1e-5);
        }
    }

    /// Building a window cache keeps the key count and one ordering per projection.
    #[test]
    fn test_window_cache() {
        let proj_cache = ProjectionCache::new(32, 4, 42);
        let owned_keys = constant_keys(10, 32, 0.1);
        let borrowed_keys = as_slices(&owned_keys);
        let window_cache = WindowCache::build(&borrowed_keys, &proj_cache);
        assert_eq!(window_cache.num_keys, 10);
        assert_eq!(window_cache.sorted_indices.len(), 4);
    }

    /// Every CDF built from the histograms ends at (approximately) one.
    #[test]
    fn test_histograms() {
        let proj_cache = ProjectionCache::new(16, 2, 42);
        let owned_keys = constant_keys(20, 16, 0.05);
        let borrowed_keys = as_slices(&owned_keys);
        let mut window_cache = WindowCache::build(&borrowed_keys, &proj_cache);
        window_cache.build_histograms(10);
        assert!(window_cache.cdfs.is_some());
        for cdf in window_cache.cdfs.as_ref().unwrap() {
            let final_mass = *cdf.last().unwrap();
            assert!((final_mass - 1.0).abs() < 1e-5);
        }
    }
}