#![allow(dead_code)]
/// Tuning knobs for adaptive approximate nearest-neighbour search.
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// Minimum candidates evaluated before early termination may trigger.
    pub min_candidates: usize,
    /// Required confidence that the true top-k has been found before a
    /// search is allowed to stop early.
    pub confidence_threshold: f32,
    /// Target dimensionality for projection (0 = disabled).
    /// NOTE(review): not consumed in this file — confirm usage at call sites.
    pub projection_dims: usize,
    /// Whether distances may be estimated from a subset of dimensions.
    /// NOTE(review): not consumed in this file — confirm usage at call sites.
    pub dimension_sampling: bool,
    /// Fraction of dimensions used when sampling is enabled.
    pub sampling_ratio: f32,
    /// Whether hub-node statistics should influence the search.
    /// NOTE(review): not consumed in this file — confirm usage at call sites.
    pub hubness_aware: bool,
    /// Occurrence-rate threshold above which a node counts as a hub.
    pub hub_threshold: f32,
}

impl Default for AdaptiveConfig {
    /// Balanced defaults: moderate confidence, projection and sampling and
    /// hubness handling all off.
    fn default() -> Self {
        Self {
            min_candidates: 10,
            confidence_threshold: 0.9,
            projection_dims: 0,
            dimension_sampling: false,
            sampling_ratio: 0.25,
            hubness_aware: false,
            hub_threshold: 0.1,
        }
    }
}

impl AdaptiveConfig {
    /// High-recall preset: many candidates, near-certain confidence,
    /// exact (unsampled) distances. Fields not listed keep their defaults.
    pub fn conservative() -> Self {
        Self {
            min_candidates: 50,
            confidence_threshold: 0.99,
            sampling_ratio: 1.0,
            hub_threshold: 0.05,
            ..Self::default()
        }
    }

    /// Speed-first preset: few candidates, low confidence bar, projection,
    /// dimension sampling, and hubness handling all enabled.
    pub fn aggressive() -> Self {
        Self {
            min_candidates: 5,
            confidence_threshold: 0.7,
            projection_dims: 32,
            dimension_sampling: true,
            sampling_ratio: 0.1,
            hubness_aware: true,
            hub_threshold: 0.15,
        }
    }
}
/// Online statistics oracle that decides when a k-NN candidate scan can
/// stop early: once the estimated probability that an unseen candidate
/// beats the current k-th best distance drops below
/// `1 - confidence_threshold`.
#[derive(Debug)]
pub struct EarlyTerminationOracle {
    /// Best `k` distances observed so far, kept sorted ascending.
    top_k_distances: Vec<f32>,
    /// Number of neighbours requested.
    k: usize,
    /// Total number of distances observed.
    num_evaluated: usize,
    /// Running mean of all observed distances (Welford's algorithm).
    distance_mean: f32,
    /// Welford M2 accumulator (sum of squared deviations from the mean);
    /// divide by (n - 1) to obtain the sample variance.
    distance_var: f32,
    config: AdaptiveConfig,
}

impl EarlyTerminationOracle {
    /// Creates an oracle that tracks the best `k` distances under `config`.
    pub fn new(k: usize, config: AdaptiveConfig) -> Self {
        Self {
            top_k_distances: Vec::with_capacity(k),
            k,
            num_evaluated: 0,
            distance_mean: 0.0,
            distance_var: 0.0,
            config,
        }
    }

    /// Records one candidate distance: updates the running mean/variance
    /// (Welford's online algorithm) and the sorted top-k set.
    pub fn observe(&mut self, distance: f32) {
        self.num_evaluated += 1;
        // Welford's online update; `distance_var` holds the M2 accumulator.
        let delta = distance - self.distance_mean;
        self.distance_mean += delta / self.num_evaluated as f32;
        let delta2 = distance - self.distance_mean;
        self.distance_var += delta * delta2;
        // Bug fix: with k == 0 the original indexed `top_k_distances[k - 1]`,
        // underflowing `k - 1` and panicking. There is no top-k to maintain.
        if self.k == 0 {
            return;
        }
        if self.top_k_distances.len() < self.k {
            // Sorted insertion (O(log k) search + O(k) shift) replaces the
            // original push-then-full-sort; `total_cmp` keeps the same total
            // order the original sort used.
            let pos = self
                .top_k_distances
                .partition_point(|d| d.total_cmp(&distance).is_lt());
            self.top_k_distances.insert(pos, distance);
        } else if distance < self.top_k_distances[self.k - 1] {
            // Insert the better distance in order, then evict the worst.
            let pos = self
                .top_k_distances
                .partition_point(|d| d.total_cmp(&distance).is_lt());
            self.top_k_distances.insert(pos, distance);
            self.top_k_distances.pop();
        }
    }

    /// Returns true when at least `min_candidates` have been seen, the
    /// top-k set is full, and the estimated probability of a closer
    /// neighbour falls below `1 - confidence_threshold`.
    pub fn should_terminate(&self) -> bool {
        if self.num_evaluated < self.config.min_candidates {
            return false;
        }
        // Bug fix: k == 0 previously underflowed `k - 1` below; with
        // nothing left to find, terminating is always safe.
        if self.k == 0 {
            return true;
        }
        if self.top_k_distances.len() < self.k {
            return false;
        }
        // Sample variance from the Welford M2 accumulator.
        let variance = self.distance_var / (self.num_evaluated as f32 - 1.0).max(1.0);
        let std_dev = variance.sqrt().max(1e-9);
        let threshold = self.top_k_distances[self.k - 1];
        // Logistic approximation to the normal CDF (scale factor 1.7):
        // estimated probability that a fresh candidate lands below the
        // current k-th best distance.
        let z_score = (threshold - self.distance_mean) / std_dev;
        let prob_better = 1.0 / (1.0 + (-1.7 * z_score).exp());
        prob_better < (1.0 - self.config.confidence_threshold)
    }

    /// Current best distances, sorted ascending (may hold fewer than `k`
    /// entries before enough candidates have been observed).
    pub fn top_k(&self) -> &[f32] {
        &self.top_k_distances
    }

    /// Total number of distances observed so far.
    pub fn num_evaluated(&self) -> usize {
        self.num_evaluated
    }
}
/// Per-dimension variance statistics used to rank dimensions by how much
/// they contribute to distances.
#[derive(Debug, Clone)]
pub struct DimensionImportance {
    /// Sample variance of each dimension (Bessel-corrected).
    variances: Vec<f32>,
    /// Mean of each dimension.
    means: Vec<f32>,
    /// Dimension indices sorted by variance, highest first.
    importance_order: Vec<usize>,
    /// Number of vectors the statistics were computed from.
    num_samples: usize,
}

impl DimensionImportance {
    /// Computes per-dimension mean and variance over a row-major flat
    /// buffer of `num_vectors` vectors of length `dimension`, then ranks
    /// dimensions by descending variance.
    ///
    /// Panics if `vectors.len() != num_vectors * dimension`.
    pub fn estimate(vectors: &[f32], num_vectors: usize, dimension: usize) -> Self {
        assert_eq!(vectors.len(), num_vectors * dimension);
        // Pass 1: per-dimension sums, then divide into means.
        let mut means = vec![0.0f32; dimension];
        for row in vectors.chunks_exact(dimension) {
            for (sum, &value) in means.iter_mut().zip(row) {
                *sum += value;
            }
        }
        for m in means.iter_mut() {
            *m /= num_vectors as f32;
        }
        // Pass 2: accumulate squared deviations from the means.
        let mut variances = vec![0.0f32; dimension];
        for row in vectors.chunks_exact(dimension) {
            for ((acc, &value), &mean) in variances.iter_mut().zip(row).zip(&means) {
                let centered = value - mean;
                *acc += centered * centered;
            }
        }
        // Bessel's correction; the max(1.0) guards the single-sample case.
        for v in variances.iter_mut() {
            *v /= (num_vectors as f32 - 1.0).max(1.0);
        }
        // Rank dimensions by variance, highest first.
        let mut importance_order: Vec<usize> = (0..dimension).collect();
        importance_order.sort_unstable_by(|&x, &y| variances[y].total_cmp(&variances[x]));
        Self {
            variances,
            means,
            importance_order,
            num_samples: num_vectors,
        }
    }

    /// All dimension indices in descending-variance order.
    pub fn order(&self) -> &[usize] {
        self.importance_order.as_slice()
    }

    /// The `k` highest-variance dimension indices (clamped to the count).
    pub fn top_dimensions(&self, k: usize) -> &[usize] {
        let take = k.min(self.importance_order.len());
        &self.importance_order[..take]
    }

    /// Variance of dimension `dim`; 0.0 for an out-of-range index.
    pub fn weight(&self, dim: usize) -> f32 {
        match self.variances.get(dim) {
            Some(&v) => v,
            None => 0.0,
        }
    }

    /// Sum of all per-dimension variances.
    pub fn total_variance(&self) -> f32 {
        self.variances.iter().copied().sum()
    }

    /// Summed variance of the `k` most important dimensions.
    pub fn cumulative_variance(&self, k: usize) -> f32 {
        self.importance_order
            .iter()
            .take(k)
            .map(|&d| self.variances[d])
            .sum()
    }
}
/// Counts how often each node appears in query results to identify "hub"
/// nodes that show up in a disproportionate fraction of top-k lists.
#[derive(Debug, Clone)]
pub struct HubnessTracker {
    /// Per-node count of top-k appearances.
    occurrence_counts: Vec<usize>,
    /// Number of queries recorded so far.
    num_queries: usize,
    /// Only the first `k` indices of each result list are counted.
    k: usize,
    /// Occurrence rate above which a node is classified as a hub.
    hub_threshold: f32,
}

impl HubnessTracker {
    /// Creates a tracker for `num_nodes` nodes with top-`k` result lists.
    pub fn new(num_nodes: usize, k: usize, hub_threshold: f32) -> Self {
        Self {
            occurrence_counts: vec![0; num_nodes],
            num_queries: 0,
            k,
            hub_threshold,
        }
    }

    /// Records one query result. Only the first `k` indices count;
    /// out-of-range indices are silently ignored.
    pub fn record_result(&mut self, top_k_indices: &[usize]) {
        self.num_queries += 1;
        for &idx in top_k_indices.iter().take(self.k) {
            if let Some(count) = self.occurrence_counts.get_mut(idx) {
                *count += 1;
            }
        }
    }

    /// Fraction of queries in which `node_idx` appeared (0.0 before any
    /// query, or for an unknown index).
    pub fn hub_score(&self, node_idx: usize) -> f32 {
        if self.num_queries == 0 {
            return 0.0;
        }
        let hits = self.occurrence_counts.get(node_idx).copied().unwrap_or(0);
        hits as f32 / self.num_queries as f32
    }

    /// True when the node's occurrence rate exceeds `hub_threshold`
    /// (never before the first query has been recorded).
    pub fn is_hub(&self, node_idx: usize) -> bool {
        self.num_queries > 0 && self.hub_score(node_idx) > self.hub_threshold
    }

    /// Indices of every node currently classified as a hub, ascending.
    pub fn hubs(&self) -> Vec<usize> {
        (0..self.occurrence_counts.len())
            .filter(|&i| self.is_hub(i))
            .collect()
    }

    /// Aggregate hubness statistics over all tracked nodes.
    pub fn stats(&self) -> HubnessStats {
        let total_nodes = self.occurrence_counts.len();
        let max_occurrences = self.occurrence_counts.iter().copied().max().unwrap_or(0);
        let mean_occurrences = if total_nodes == 0 {
            0.0
        } else {
            self.occurrence_counts.iter().sum::<usize>() as f32 / total_nodes as f32
        };
        HubnessStats {
            num_hubs: self.hubs().len(),
            total_nodes,
            max_occurrences,
            mean_occurrences,
            queries_processed: self.num_queries,
        }
    }
}

/// Snapshot of hub statistics produced by [`HubnessTracker::stats`].
#[derive(Debug, Clone)]
pub struct HubnessStats {
    /// Nodes whose occurrence rate exceeds the hub threshold.
    pub num_hubs: usize,
    /// Total number of tracked nodes.
    pub total_nodes: usize,
    /// Highest appearance count of any single node.
    pub max_occurrences: usize,
    /// Mean appearance count across all nodes.
    pub mean_occurrences: f32,
    /// Number of queries recorded.
    pub queries_processed: usize,
}
/// Partial squared-L2 estimator: computes the exact squared distance over
/// only the highest-variance dimensions and extrapolates to the full
/// dimensionality.
///
/// Returns `(sampled_dist, estimated_full)`: the exact distance over the
/// sampled dimensions, and that value scaled by the ratio of total to
/// sampled variance (falling back to a uniform dimension ratio when the
/// sampled variance is numerically negligible).
pub fn sampled_l2_squared(
    a: &[f32],
    b: &[f32],
    importance: &DimensionImportance,
    sample_fraction: f32,
) -> (f32, f32) {
    debug_assert_eq!(a.len(), b.len());
    let dim = a.len();
    // At least one dimension, never more than all of them.
    let num_samples = ((dim as f32 * sample_fraction) as usize).max(1).min(dim);
    let sampled_dist: f32 = importance
        .top_dimensions(num_samples)
        .iter()
        .map(|&d| {
            let diff = a[d] - b[d];
            diff * diff
        })
        .sum();
    let sampled_variance = importance.cumulative_variance(num_samples);
    let total_variance = importance.total_variance();
    let scale = if sampled_variance > 1e-9 {
        total_variance / sampled_variance
    } else {
        dim as f32 / num_samples as f32
    };
    (sampled_dist, sampled_dist * scale)
}
/// Two-phase distance check: a cheap sampled estimate first, then the
/// exact squared L2 distance only for candidates that survive the prune.
///
/// Returns `None` when the estimate exceeds `threshold * 1.5` (the 1.5x
/// margin absorbs estimator noise), otherwise `Some(exact_distance)`.
pub fn two_phase_l2_squared(
    a: &[f32],
    b: &[f32],
    importance: &DimensionImportance,
    threshold: f32,
    sample_fraction: f32,
) -> Option<f32> {
    let (_, estimated) = sampled_l2_squared(a, b, importance, sample_fraction);
    if estimated > threshold * 1.5 {
        return None;
    }
    // Survived the prune: pay for the exact computation.
    let mut full_dist = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        let d = x - y;
        full_dist += d * d;
    }
    Some(full_dist)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic LCG producing values in [0, 1); thread-local state so
    /// tests stay independent of each other.
    fn rand_f32() -> f32 {
        use std::cell::Cell;
        thread_local! {
            static SEED: Cell<u32> = const { Cell::new(12345) };
        }
        SEED.with(|cell| {
            let state = cell.get().wrapping_mul(1103515245).wrapping_add(12345);
            cell.set(state);
            state as f32 / u32::MAX as f32
        })
    }

    #[test]
    fn test_early_termination_basic() {
        let mut oracle = EarlyTerminationOracle::new(3, AdaptiveConfig::default());
        // Below min_candidates the oracle must never fire.
        for i in 0..5 {
            oracle.observe(i as f32);
            assert!(!oracle.should_terminate());
        }
        // Flood with clearly-worse distances; just ensure the check runs.
        for _ in 0..100 {
            oracle.observe(10.0 + rand_f32() * 0.1);
        }
        let _ = oracle.should_terminate();
    }

    #[test]
    fn test_dimension_importance() {
        // Dimension 0 varies (0..=3); dimensions 1 and 2 are constant.
        let vectors = vec![
            0.0, 0.5, 0.5, 1.0, 0.5, 0.5, 2.0, 0.5, 0.5, 3.0, 0.5, 0.5,
        ];
        let importance = DimensionImportance::estimate(&vectors, 4, 3);
        // The varying dimension must rank first.
        assert_eq!(importance.order()[0], 0);
        // Cumulative variance is monotone in the number of dimensions kept.
        let cums: Vec<f32> = (1..=3).map(|k| importance.cumulative_variance(k)).collect();
        assert!(cums[0] <= cums[1]);
        assert!(cums[1] <= cums[2]);
    }

    #[test]
    fn test_hubness_tracker() {
        let mut tracker = HubnessTracker::new(10, 3, 0.3);
        // Nodes 0-2 appear in 10 of 13 queries; nodes 5-7 in only 3.
        for _ in 0..10 {
            tracker.record_result(&[0, 1, 2]);
        }
        for _ in 0..3 {
            tracker.record_result(&[5, 6, 7]);
        }
        assert!(tracker.is_hub(0));
        assert!(!tracker.is_hub(5));
        assert!(tracker.stats().num_hubs > 0);
    }

    #[test]
    fn test_sampled_distance() {
        let vectors: Vec<f32> = (0..100).map(|i| i as f32 / 100.0 - 0.5).collect();
        let importance = DimensionImportance::estimate(&vectors, 10, 10);
        let a = [0.0f32; 10];
        let b: Vec<f32> = (0..10).map(|i| i as f32 * 0.1).collect();
        let (_sampled, estimated) = sampled_l2_squared(&a, &b, &importance, 0.5);
        let full: f32 = a.iter().zip(&b).map(|(x, y)| (x - y) * (x - y)).sum();
        // The extrapolated estimate is positive and in the right ballpark.
        assert!(estimated > 0.0);
        assert!(estimated < full * 3.0);
    }
}