use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
/// Tuning knobs for probabilistic neighbour routing.
///
/// Three presets exist: [`Default`] (balanced), [`ProbabilisticRoutingConfig::fast`]
/// (higher threshold, fewer mandatory tests, larger skip ratio) and
/// [`ProbabilisticRoutingConfig::accurate`] (lower threshold, more mandatory tests,
/// smaller skip ratio).
#[derive(Debug, Clone)]
pub struct ProbabilisticRoutingConfig {
    /// Base probability an edge's estimate must reach for the edge to be explored.
    pub probability_threshold: f32,
    /// Not read anywhere in this module.
    pub delta: f32,
    /// Not read anywhere in this module.
    pub epsilon: f32,
    /// Whether the threshold may be rescaled by the observed improvement density.
    pub adaptive: bool,
    /// Minimum number of neighbours explored per node expansion.
    pub min_edges_to_test: usize,
    /// Maximum fraction of a node's neighbours that filtering is meant to skip.
    pub max_skip_ratio: f32,
}

impl ProbabilisticRoutingConfig {
    /// Preset favouring throughput: prunes edges more aggressively than the default.
    pub fn fast() -> Self {
        Self {
            adaptive: true,
            probability_threshold: 0.5,
            min_edges_to_test: 1,
            max_skip_ratio: 0.7,
            epsilon: 0.15,
            delta: 1.2,
        }
    }

    /// Preset favouring recall: prunes edges less aggressively than the default.
    pub fn accurate() -> Self {
        Self {
            adaptive: true,
            probability_threshold: 0.1,
            min_edges_to_test: 4,
            max_skip_ratio: 0.3,
            epsilon: 0.05,
            delta: 1.0,
        }
    }
}

impl Default for ProbabilisticRoutingConfig {
    /// Balanced preset sitting between [`ProbabilisticRoutingConfig::fast`] and
    /// [`ProbabilisticRoutingConfig::accurate`].
    fn default() -> Self {
        Self {
            adaptive: true,
            probability_threshold: 0.3,
            min_edges_to_test: 2,
            max_skip_ratio: 0.5,
            epsilon: 0.1,
            delta: 1.0,
        }
    }
}
/// Scores graph edges by how likely they are to move the search towards the query.
///
/// Combines a purely geometric per-edge score with `density_estimate`, an exponential
/// moving average of how often tested neighbours improved the best distance.
#[derive(Debug)]
pub struct EdgeProbabilityEstimator {
    /// EMA of the per-search improvement rate; starts at 1.0 (no damping).
    density_estimate: f32,
    /// Number of samples folded into `density_estimate`; saturates instead of overflowing.
    sample_count: u32,
    config: ProbabilisticRoutingConfig,
}

impl EdgeProbabilityEstimator {
    /// Weight given to the newest sample in [`Self::update_density`].
    const DENSITY_SMOOTHING: f32 = 0.1;
    /// Samples that must be exceeded before the adaptive threshold is used.
    const ADAPTIVE_WARMUP_SAMPLES: u32 = 10;

    /// Creates an estimator with a neutral density of 1.0 and no samples.
    pub fn new(config: ProbabilisticRoutingConfig) -> Self {
        Self {
            density_estimate: 1.0,
            sample_count: 0,
            config,
        }
    }

    /// Estimates the probability that following the edge `current_pos -> neighbor_pos`
    /// leads towards `query`.
    ///
    /// The score is the product of
    /// * an angle factor `(cos θ + 1) / 2`, where θ is the angle between
    ///   `query - current_pos` and `neighbor_pos - current_pos` (0.5 when either vector
    ///   has norm ≤ 1e-10), and
    /// * a distance factor `1 / (1 + r)`, where `r` is the edge length divided by
    ///   `current_best_dist`, capped at 2.0 (`r = 1.0` when `current_best_dist` ≤ 1e-10),
    ///
    /// multiplied by the running density estimate and clamped to `[0, 1]`.
    ///
    /// Slices are paired element-wise with `zip`, so components beyond the shortest
    /// input are ignored.
    pub fn estimate_edge_probability(
        &self,
        query: &[f32],
        current_best_dist: f32,
        current_pos: &[f32],
        neighbor_pos: &[f32],
    ) -> f32 {
        let to_query: Vec<f32> = query
            .iter()
            .zip(current_pos.iter())
            .map(|(q, c)| q - c)
            .collect();
        let to_neighbor: Vec<f32> = neighbor_pos
            .iter()
            .zip(current_pos.iter())
            .map(|(n, c)| n - c)
            .collect();
        let dot: f32 = crate::simd::dot(&to_query, &to_neighbor);
        let norm_q: f32 = crate::simd::norm(&to_query);
        let norm_n: f32 = crate::simd::norm(&to_neighbor);
        // Degenerate directions carry no angular information: treat them as orthogonal.
        let cos_angle = if norm_q > 1e-10 && norm_n > 1e-10 {
            (dot / (norm_q * norm_n)).clamp(-1.0, 1.0)
        } else {
            0.0
        };
        let neighbor_dist_from_current = norm_n;
        let dist_ratio = if current_best_dist > 1e-10 {
            (neighbor_dist_from_current / current_best_dist).min(2.0)
        } else {
            1.0
        };
        // Map cos θ from [-1, 1] to [0, 1]; long edges relative to the best distance
        // are penalised.
        let angle_factor = (cos_angle + 1.0) / 2.0;
        let dist_factor = 1.0 / (1.0 + dist_ratio);
        let base_prob = angle_factor * dist_factor;
        let adjusted_prob = base_prob * self.density_estimate;
        adjusted_prob.clamp(0.0, 1.0)
    }

    /// Folds one search's outcome into the running density estimate.
    ///
    /// The improvement rate `neighbors_improved / neighbors_tested` is blended in with
    /// weight 0.1. Calls with `neighbors_tested == 0` carry no information and are ignored.
    pub fn update_density(&mut self, neighbors_improved: usize, neighbors_tested: usize) {
        if neighbors_tested > 0 {
            let improvement_rate = neighbors_improved as f32 / neighbors_tested as f32;
            let alpha = Self::DENSITY_SMOOTHING;
            self.density_estimate =
                (1.0 - alpha) * self.density_estimate + alpha * improvement_rate;
            // Saturate: a plain `+= 1` panics in debug builds after u32::MAX samples and,
            // in release builds, wraps to 0 and silently restarts the adaptive warm-up.
            self.sample_count = self.sample_count.saturating_add(1);
        }
    }

    /// Returns the probability an edge must reach to be tested.
    ///
    /// With `config.adaptive` set and more than 10 samples collected, the configured
    /// threshold is scaled by the density estimate and clamped to `[0.05, 0.8]`;
    /// otherwise the configured threshold is returned unchanged.
    pub fn get_threshold(&self) -> f32 {
        if self.config.adaptive && self.sample_count > Self::ADAPTIVE_WARMUP_SAMPLES {
            (self.config.probability_threshold * self.density_estimate).clamp(0.05, 0.8)
        } else {
            self.config.probability_threshold
        }
    }
}
/// Search-heap entry ordered by `distance` alone.
///
/// `Ord` is reversed so that `BinaryHeap` (a max-heap) pops the *smallest* distance
/// first. Equality uses the same `f32::total_cmp` order as `Ord`. This keeps `Eq`
/// reflexive (a NaN distance equals itself) and consistent with `cmp` (`0.0` and `-0.0`
/// are distinct).
#[derive(Debug, Clone)]
struct ProbabilisticCandidate {
    id: u32,
    distance: f32,
    // Never read; the attribute silences the unused-field lint.
    #[allow(dead_code)]
    probability: f32,
}

impl PartialEq for ProbabilisticCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Derive equality from `cmp`; `f32 == f32` would make NaN unequal to itself,
        // which breaks the `Eq` contract and disagrees with `Ord`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for ProbabilisticCandidate {}

impl PartialOrd for ProbabilisticCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ProbabilisticCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed: larger distance compares as "smaller", giving a min-heap on distance.
        self.distance.total_cmp(&other.distance).reverse()
    }
}
/// Best-first graph searcher that scores each expanded node's unvisited neighbours and
/// only computes query distances for those kept by `filter_neighbors`.
///
/// Statistics accumulate across calls until `reset_stats` is invoked.
#[derive(Debug)]
pub struct ProbabilisticRouter {
// Routing parameters; `new` also hands a clone of them to `estimator`.
config: ProbabilisticRoutingConfig,
// Per-edge probability model; its density is updated once at the end of each `search`.
estimator: EdgeProbabilityEstimator,
// Cumulative edge and search counters.
stats: ProbabilisticStats,
}
/// Cumulative counters collected by `ProbabilisticRouter`.
#[derive(Debug, Default, Clone)]
pub struct ProbabilisticStats {
    /// Neighbours scored by the neighbour filter.
    pub edges_considered: u64,
    /// Neighbours kept by the neighbour filter.
    pub edges_tested: u64,
    /// Neighbours dropped by the neighbour filter.
    pub edges_skipped: u64,
    /// Not incremented by `ProbabilisticRouter`.
    pub beneficial_skips: u64,
    /// Number of searches started.
    pub total_searches: u64,
}

impl ProbabilisticStats {
    /// Fraction of considered edges that were skipped; 0.0 before any edge was seen.
    pub fn skip_ratio(&self) -> f32 {
        match self.edges_considered {
            0 => 0.0,
            considered => self.edges_skipped as f32 / considered as f32,
        }
    }

    /// Ratio of considered to tested edges (the name suggests it is read as a
    /// throughput multiplier); 1.0 when nothing has been tested.
    pub fn estimated_qps_factor(&self) -> f32 {
        if self.edges_tested == 0 {
            return 1.0;
        }
        self.edges_considered as f32 / self.edges_tested as f32
    }
}
impl ProbabilisticRouter {
    /// Creates a router; the estimator receives its own clone of `config`.
    pub fn new(config: ProbabilisticRoutingConfig) -> Self {
        let estimator = EdgeProbabilityEstimator::new(config.clone());
        Self {
            config,
            estimator,
            stats: ProbabilisticStats::default(),
        }
    }

    /// Ranks `neighbors` by estimated improvement probability and keeps the most promising.
    ///
    /// Every neighbour is scored with [`EdgeProbabilityEstimator::estimate_edge_probability`]
    /// and the list is sorted by descending probability. The returned prefix contains
    /// every neighbour up to the last one scoring at or above the estimator's current
    /// threshold. It is then extended if needed so that it contains at least
    /// `config.min_edges_to_test` entries and no more than a `config.max_skip_ratio`
    /// fraction of the input is skipped (all capped at `neighbors.len()`).
    ///
    /// Adds to the `edges_considered`, `edges_tested` and `edges_skipped` counters.
    ///
    /// Returns `(id, position, probability)` triples, most promising first.
    pub fn filter_neighbors<'a>(
        &mut self,
        query: &[f32],
        current_pos: &[f32],
        current_best_dist: f32,
        neighbors: &'a [(u32, &[f32])],
    ) -> Vec<(u32, &'a [f32], f32)> {
        let threshold = self.estimator.get_threshold();
        let total = neighbors.len();
        let mut candidates: Vec<(u32, &'a [f32], f32)> = Vec::with_capacity(total);
        for &(id, pos) in neighbors {
            let prob = self.estimator.estimate_edge_probability(
                query,
                current_best_dist,
                current_pos,
                pos,
            );
            candidates.push((id, pos, prob));
        }
        // Most promising first; `total_cmp` yields a total order even for NaN scores.
        candidates.sort_unstable_by(|a, b| b.2.total_cmp(&a.2));

        let to_test = self.edges_to_test(&candidates, threshold);
        self.stats.edges_considered += total as u64;
        self.stats.edges_tested += to_test as u64;
        self.stats.edges_skipped += (total - to_test) as u64;
        candidates.truncate(to_test);
        candidates
    }

    /// Length of the best-first prefix of `sorted` that must be tested.
    fn edges_to_test(&self, sorted: &[(u32, &[f32], f32)], threshold: f32) -> usize {
        let total = sorted.len();
        // Float-to-int `as` saturates (NaN or negative -> 0), and `saturating_sub` covers
        // ratios above 1.0, so this can never underflow.
        let max_to_skip = ((total as f32) * self.config.max_skip_ratio) as usize;
        let floor = self
            .config
            .min_edges_to_test
            .max(total.saturating_sub(max_to_skip))
            .min(total);
        // Keep everything up to the last above-threshold score.
        let above_threshold = sorted
            .iter()
            .rposition(|&(_, _, prob)| prob >= threshold)
            .map_or(0, |i| i + 1);
        above_threshold.max(floor)
    }

    /// Best-first search from `entry_point` that only evaluates neighbours kept by
    /// [`Self::filter_neighbors`].
    ///
    /// * `entry_pos` — vector of the entry node.
    /// * `get_neighbors` — adjacency list of a node as `(id, vector)` pairs.
    /// * `ef` — result beam width. The search stops once the closest unexpanded candidate
    ///   is farther than the worst of `ef` results.
    ///
    /// Returns at most `max(ef, 1)` `(id, distance)` pairs sorted by ascending distance.
    /// Afterwards the estimator's density is updated with the fraction of evaluated
    /// neighbours that improved the best distance seen.
    pub fn search(
        &mut self,
        query: &[f32],
        entry_point: u32,
        entry_pos: &[f32],
        get_neighbors: impl Fn(u32) -> Vec<(u32, Vec<f32>)>,
        ef: usize,
    ) -> Vec<(u32, f32)> {
        use std::collections::HashMap;

        self.stats.total_searches += 1;
        let mut visited: HashSet<u32> = HashSet::new();
        // `ProbabilisticCandidate`'s reversed `Ord` makes this pop the closest node first.
        let mut candidates: BinaryHeap<ProbabilisticCandidate> = BinaryHeap::new();
        // Holds negated distances, so `peek()` yields the current *worst* result.
        let mut results: BinaryHeap<ProbabilisticCandidate> = BinaryHeap::new();
        // Vectors of nodes waiting in `candidates`. `get_neighbors` only reports the
        // neighbours' vectors, so a node's own position must be remembered until it is
        // expanded.
        let mut positions: HashMap<u32, Vec<f32>> = HashMap::new();

        let entry_dist = euclidean_distance(query, entry_pos);
        candidates.push(ProbabilisticCandidate {
            id: entry_point,
            distance: entry_dist,
            probability: 1.0,
        });
        results.push(ProbabilisticCandidate {
            id: entry_point,
            distance: -entry_dist,
            probability: 1.0,
        });
        positions.insert(entry_point, entry_pos.to_vec());
        visited.insert(entry_point);

        let mut current_best_dist = entry_dist;
        let mut neighbors_improved = 0;
        let mut neighbors_tested = 0;

        while let Some(current) = candidates.pop() {
            if results.len() >= ef {
                if let Some(worst) = results.peek() {
                    if current.distance > -worst.distance {
                        break;
                    }
                }
            }
            // Every queued id had its vector recorded when it was pushed, and `visited`
            // ensures it is pushed only once; the zero-vector fallback is defensive only.
            let current_pos = positions
                .remove(&current.id)
                .unwrap_or_else(|| vec![0.0; query.len()]);

            let raw_neighbors = get_neighbors(current.id);
            let unvisited: Vec<(u32, &[f32])> = raw_neighbors
                .iter()
                .filter(|(id, _)| !visited.contains(id))
                .map(|(id, pos)| (*id, pos.as_slice()))
                .collect();
            let filtered =
                self.filter_neighbors(query, &current_pos, current_best_dist, &unvisited);

            for (neighbor_id, neighbor_pos, _prob) in filtered {
                // `insert` returns false for ids listed twice in one adjacency list.
                if !visited.insert(neighbor_id) {
                    continue;
                }
                neighbors_tested += 1;
                let dist = euclidean_distance(query, neighbor_pos);
                if dist < current_best_dist {
                    current_best_dist = dist;
                    neighbors_improved += 1;
                }
                let worst_dist = results
                    .peek()
                    .map_or(f32::INFINITY, |worst| -worst.distance);
                if results.len() < ef || dist < worst_dist {
                    candidates.push(ProbabilisticCandidate {
                        id: neighbor_id,
                        distance: dist,
                        probability: 1.0,
                    });
                    results.push(ProbabilisticCandidate {
                        id: neighbor_id,
                        distance: -dist,
                        probability: 1.0,
                    });
                    positions.insert(neighbor_id, neighbor_pos.to_vec());
                    if results.len() > ef {
                        results.pop();
                    }
                }
            }
        }

        self.estimator
            .update_density(neighbors_improved, neighbors_tested);
        let mut final_results: Vec<(u32, f32)> =
            results.into_iter().map(|c| (c.id, -c.distance)).collect();
        final_results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        final_results
    }

    /// Cumulative statistics since construction or the last [`Self::reset_stats`].
    pub fn stats(&self) -> &ProbabilisticStats {
        &self.stats
    }

    /// Zeroes all statistics counters.
    pub fn reset_stats(&mut self) {
        self.stats = ProbabilisticStats::default();
    }
}
/// Orders a node's out-edges so the most promising ones are explored first.
#[derive(Debug)]
pub struct ProbabilisticEdgeSelector {
    // Not read by `order_edges`; presumably kept for future tuning.
    #[allow(dead_code)]
    config: ProbabilisticRoutingConfig,
}

impl ProbabilisticEdgeSelector {
    /// Creates a selector holding `config`.
    pub fn new(config: ProbabilisticRoutingConfig) -> Self {
        Self { config }
    }

    /// Returns `(id, distance_to_query)` for every edge. Edges are ordered by descending
    /// improvement probability relative to `current_dist`, with ties broken by ascending
    /// distance to `query`.
    pub fn order_edges(
        &self,
        query: &[f32],
        current_pos: &[f32],
        current_dist: f32,
        edges: &[(u32, Vec<f32>)],
    ) -> Vec<(u32, f32)> {
        let mut scored: Vec<(u32, f32, f32)> = edges
            .iter()
            .map(|(id, pos)| {
                let prob = estimate_improvement_probability(query, current_pos, current_dist, pos);
                let dist = euclidean_distance(query, pos);
                (*id, dist, prob)
            })
            .collect();
        // `total_cmp` keeps the comparator a total order even if a score is NaN. The
        // former `partial_cmp(..).unwrap_or(Equal)` fallback could be intransitive, which
        // `sort_unstable_by` is permitted to panic on.
        scored.sort_unstable_by(|a, b| b.2.total_cmp(&a.2).then_with(|| a.1.total_cmp(&b.1)));
        scored.into_iter().map(|(id, dist, _)| (id, dist)).collect()
    }
}
/// Heuristic probability that stepping to `neighbor_pos` brings the search closer to
/// `query`.
///
/// Returns 1.0 when the neighbour is strictly closer than `current_dist`. Otherwise it
/// returns `(current_dist / neighbor_dist)²` with the divisor floored at `1e-10`, clamped
/// to `[0, 1]`. `_current_pos` is accepted but ignored.
fn estimate_improvement_probability(
    query: &[f32],
    _current_pos: &[f32],
    current_dist: f32,
    neighbor_pos: &[f32],
) -> f32 {
    let candidate_dist = euclidean_distance(query, neighbor_pos);
    if candidate_dist < current_dist {
        return 1.0;
    }
    let closeness = current_dist / candidate_dist.max(1e-10);
    closeness.powi(2).clamp(0.0, 1.0)
}
use crate::distance::l2_distance as euclidean_distance;
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Deterministic test vector: component `i` is `sin((0.1 * seed + i) * 0.01)`.
    fn make_vector(dim: usize, seed: u32) -> Vec<f32> {
        (0..dim)
            .map(|i| ((seed as f32 * 0.1 + i as f32) * 0.01).sin())
            .collect()
    }

    #[test]
    fn test_probability_estimation() {
        let config = ProbabilisticRoutingConfig::default();
        let estimator = EdgeProbabilityEstimator::new(config);
        let query = vec![1.0, 0.0, 0.0];
        let current = vec![0.5, 0.0, 0.0];
        let current_dist = euclidean_distance(&query, &current);
        // Edge pointing from `current` towards the query.
        let neighbor_good = vec![0.8, 0.0, 0.0];
        let prob_good =
            estimator.estimate_edge_probability(&query, current_dist, &current, &neighbor_good);
        // Edge pointing from `current` directly away from the query.
        let neighbor_bad = vec![0.2, 0.0, 0.0];
        let prob_bad =
            estimator.estimate_edge_probability(&query, current_dist, &current, &neighbor_bad);
        assert!(prob_good > prob_bad);
    }

    #[test]
    fn test_filter_neighbors() {
        let config = ProbabilisticRoutingConfig::default();
        let mut router = ProbabilisticRouter::new(config);
        let query = make_vector(64, 100);
        let current = make_vector(64, 50);
        let current_dist = euclidean_distance(&query, &current);
        let neighbors: Vec<(u32, Vec<f32>)> =
            (0..10).map(|i| (i, make_vector(64, i * 10))).collect();
        let neighbors_ref: Vec<_> = neighbors
            .iter()
            .map(|(id, pos)| (*id, pos.as_slice()))
            .collect();
        let filtered = router.filter_neighbors(&query, &current, current_dist, &neighbors_ref);
        assert!(filtered.len() <= neighbors.len());
        assert!(filtered.len() >= router.config.min_edges_to_test.min(neighbors.len()));
    }

    #[test]
    fn test_stats_tracking() {
        let config = ProbabilisticRoutingConfig::default();
        let mut router = ProbabilisticRouter::new(config);
        let query = make_vector(64, 100);
        let current = make_vector(64, 50);
        let current_dist = euclidean_distance(&query, &current);
        let neighbors: Vec<(u32, Vec<f32>)> =
            (0..20).map(|i| (i, make_vector(64, i * 5))).collect();
        let neighbors_ref: Vec<_> = neighbors
            .iter()
            .map(|(id, pos)| (*id, pos.as_slice()))
            .collect();
        router.filter_neighbors(&query, &current, current_dist, &neighbors_ref);
        assert!(router.stats.edges_considered > 0);
        assert!(router.stats.edges_tested > 0);
        let skip_ratio = router.stats.skip_ratio();
        assert!((0.0..=1.0).contains(&skip_ratio));
    }

    #[test]
    fn test_edge_selector() {
        let config = ProbabilisticRoutingConfig::default();
        let selector = ProbabilisticEdgeSelector::new(config);
        let query = vec![1.0, 0.0];
        let current = vec![0.5, 0.0];
        let current_dist = euclidean_distance(&query, &current);
        let edges = vec![
            (0, vec![0.2, 0.0]), // farther from the query than `current`
            (1, vec![0.8, 0.0]), // closer to the query than `current`
            (2, vec![0.5, 0.5]), // farther, but less so than edge 0
        ];
        let ordered = selector.order_edges(&query, &current, current_dist, &edges);
        assert_eq!(ordered[0].0, 1);
    }

    #[test]
    fn test_config_presets() {
        let fast = ProbabilisticRoutingConfig::fast();
        let accurate = ProbabilisticRoutingConfig::accurate();
        assert!(fast.probability_threshold > accurate.probability_threshold);
        assert!(accurate.min_edges_to_test > fast.min_edges_to_test);
    }
}