use crate::RetrieveError;
use std::collections::{BinaryHeap, HashSet};
/// Tuning parameters for a [`DEGIndex`].
#[derive(Clone, Debug)]
pub struct DEGConfig {
    /// Edge budget given to a node when it is added, before `build`
    /// replaces it with a density-derived budget.
    pub base_edges: usize,
    /// Upper bound on a node's edge budget (used for the sparsest nodes).
    pub max_edges: usize,
    /// Lower bound on a node's edge budget (used for the densest nodes).
    pub min_edges: usize,
    /// Number of nearest neighbours averaged when estimating local density.
    pub density_k: usize,
    /// Multiplier applied to the candidate distance in the neighbour
    /// diversity test during graph construction.
    pub alpha: f32,
    /// Caps both the candidate queue size and the number of visited nodes
    /// during a search.
    pub ef_search: usize,
}

impl Default for DEGConfig {
    /// Returns the stock configuration: budgets in `8..=32` with a base of
    /// 16, `density_k = 10`, `alpha = 1.2`, `ef_search = 100`.
    fn default() -> Self {
        // Edge-count limits, listed low to high.
        let (min_edges, base_edges, max_edges) = (8, 16, 32);
        DEGConfig {
            base_edges,
            max_edges,
            min_edges,
            density_k: 10,
            alpha: 1.2,
            ef_search: 100,
        }
    }
}
/// Per-node density statistics computed while building a `DEGIndex`.
#[derive(Clone, Debug)]
pub struct DensityInfo {
/// Density score, computed as `1.0 / (avg_neighbor_dist + 0.1)`;
/// it is `0.0` for a node that has been added but not yet built.
pub density: f32,
/// Maximum number of edges the node may hold in the graph.
pub edge_budget: usize,
/// Mean distance to the node's nearest neighbours (at most
/// `DEGConfig::density_k` of them); `0.0` until the index is built.
pub avg_neighbor_dist: f32,
}
/// Graph index over fixed-dimension `f32` vectors in which each node's
/// edge budget depends on its estimated local density.
///
/// `vectors`, `edges` and `density` are parallel: position `i` in each
/// describes the node with id `i`.
pub struct DEGIndex {
// Tuning parameters supplied at construction.
config: DEGConfig,
// Required length of every stored vector and every query.
dim: usize,
// Stored vectors, indexed by node id.
vectors: Vec<Vec<f32>>,
// Adjacency list (neighbour ids) of each node.
edges: Vec<Vec<u32>>,
// Density statistics and edge budget of each node.
density: Vec<DensityInfo>,
// Node where searches start: the first added node until `build`
// replaces it with the highest-density node; `None` while empty.
entry_point: Option<u32>,
}
impl DEGIndex {
    /// Creates an empty index for vectors of length `dim`.
    ///
    /// The graph is only constructed once vectors have been added with
    /// [`add`](Self::add) and [`build`](Self::build) has been called.
    pub fn new(dim: usize, config: DEGConfig) -> Self {
        Self {
            config,
            dim,
            vectors: Vec::new(),
            edges: Vec::new(),
            density: Vec::new(),
            entry_point: None,
        }
    }

    /// Stores `vector` and returns its id (its insertion position).
    ///
    /// The new node starts with no edges and placeholder density values; it
    /// is linked into the graph by the next call to [`build`](Self::build).
    /// The first vector ever added becomes the provisional entry point.
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::DimensionMismatch`] when `vector.len()`
    /// differs from the index dimension.
    pub fn add(&mut self, vector: Vec<f32>) -> Result<u32, RetrieveError> {
        if vector.len() != self.dim {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: vector.len(),
                doc_dim: self.dim,
            });
        }

        // NOTE(review): ids are `u32`, so this cast truncates once more than
        // `u32::MAX` vectors are stored.
        let id = self.vectors.len() as u32;
        self.vectors.push(vector);
        self.edges.push(Vec::new());
        self.density.push(DensityInfo {
            density: 0.0,
            edge_budget: self.config.base_edges,
            avg_neighbor_dist: 0.0,
        });
        if self.entry_point.is_none() {
            self.entry_point = Some(id);
        }
        Ok(id)
    }

    /// Builds (or rebuilds) the graph over all stored vectors.
    ///
    /// Steps: estimate each node's density, derive per-node edge budgets,
    /// connect every node in id order, then pick the densest node as the
    /// search entry point. An empty index is left untouched.
    ///
    /// Every node is compared against every other node, so the cost grows
    /// quadratically with the number of stored vectors.
    ///
    /// # Errors
    ///
    /// Propagates any error from the internal construction steps.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.vectors.is_empty() {
            return Ok(());
        }

        // Start from an empty graph so a rebuild never inherits edges that
        // were chosen under the previous densities and budgets.
        for adjacency in &mut self.edges {
            adjacency.clear();
        }

        self.estimate_densities()?;
        self.assign_edge_budgets();
        for node_id in 0..self.vectors.len() as u32 {
            self.connect_node(node_id)?;
        }
        self.select_entry_point();
        Ok(())
    }

    /// Computes `density` and `avg_neighbor_dist` for every node from the
    /// mean distance to its `density_k` nearest stored vectors.
    ///
    /// A node with no neighbours to average (single-vector index or
    /// `density_k == 0`) gets an average distance of `1.0`. The edge budget
    /// is reset to `base_edges` here and refined by `assign_edge_budgets`.
    fn estimate_densities(&mut self) -> Result<(), RetrieveError> {
        let n = self.vectors.len();
        // `saturating_sub` keeps this safe even if called on an empty index.
        let k = self.config.density_k.min(n.saturating_sub(1));

        for i in 0..n {
            let mut distances: Vec<f32> = (0..n)
                .filter(|&j| j != i)
                .map(|j| self.distance(i as u32, j as u32))
                .collect();
            distances.sort_unstable_by(f32::total_cmp);

            let nearest = &distances[..k.min(distances.len())];
            let avg_dist = if nearest.is_empty() {
                1.0
            } else {
                nearest.iter().sum::<f32>() / nearest.len() as f32
            };

            self.density[i] = DensityInfo {
                // The 0.1 offset keeps the score finite when the neighbours
                // coincide with the node (average distance of zero).
                density: 1.0 / (avg_dist + 0.1),
                edge_budget: self.config.base_edges,
                avg_neighbor_dist: avg_dist,
            };
        }
        Ok(())
    }

    /// Maps each node's density linearly onto an edge budget: the sparsest
    /// node receives the largest budget and the densest the smallest.
    ///
    /// The bounds come from `min_edges` / `max_edges`; if they were supplied
    /// in the wrong order they are swapped rather than causing an underflow
    /// or a `clamp` panic.
    fn assign_edge_budgets(&mut self) {
        let lo = self.config.min_edges.min(self.config.max_edges);
        let hi = self.config.min_edges.max(self.config.max_edges);
        let edge_range = (hi - lo) as f32;

        let (min_density, max_density) = self.density.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(low, high), info| (low.min(info.density), high.max(info.density)),
        );
        // A floor on the range avoids dividing by (nearly) zero when all
        // densities are alike.
        let density_range = (max_density - min_density).max(0.1);

        for info in &mut self.density {
            let normalized = (info.density - min_density) / density_range;
            let reduction = (normalized * edge_range) as usize;
            info.edge_budget = hi.saturating_sub(reduction).clamp(lo, hi);
        }
    }

    /// Chooses the outgoing edges of `node_id` and links it into the graph.
    ///
    /// Candidates are all other nodes in ascending distance order; one is
    /// kept only if it passes the diversity test against every neighbour
    /// kept so far, until the node's edge budget is reached. Each kept
    /// neighbour receives a reverse edge when its own budget allows. Reverse
    /// edges that earlier nodes already placed on `node_id` are retained in
    /// the remaining budget instead of being overwritten.
    fn connect_node(&mut self, node_id: u32) -> Result<(), RetrieveError> {
        let node = node_id as usize;
        let budget = self.density[node].edge_budget;

        let mut candidates: Vec<(u32, f32)> = (0..self.vectors.len() as u32)
            .filter(|&other| other != node_id)
            .map(|other| (other, self.distance(node_id, other)))
            .collect();
        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));

        let mut neighbors: Vec<u32> = Vec::new();
        for (candidate, dist) in candidates {
            if neighbors.len() >= budget {
                break;
            }
            // NOTE(review): for non-negative distances and `alpha >= 1` the
            // first comparison implies the second, so `alpha` only has an
            // effect when it is below 1. Confirm this is the intended rule.
            let is_diverse = neighbors.iter().all(|&kept| {
                let neighbor_dist = self.distance(candidate, kept);
                neighbor_dist > dist * self.config.alpha || neighbor_dist > dist
            });
            if is_diverse {
                neighbors.push(candidate);
            }
        }

        // Add the reverse edge wherever the neighbour still has room.
        for &neighbor in &neighbors {
            let neighbor_budget = self.density[neighbor as usize].edge_budget;
            let neighbor_edges = &mut self.edges[neighbor as usize];
            if neighbor_edges.len() < neighbor_budget && !neighbor_edges.contains(&node_id) {
                neighbor_edges.push(node_id);
            }
        }

        // Keep reverse edges that previously connected nodes stored on this
        // node, as far as the budget allows.
        let incoming = std::mem::take(&mut self.edges[node]);
        for source in incoming {
            if neighbors.len() >= budget {
                break;
            }
            if !neighbors.contains(&source) {
                neighbors.push(source);
            }
        }

        self.edges[node] = neighbors;
        Ok(())
    }

    /// Makes the highest-density node the search entry point.
    ///
    /// Leaves the entry point unchanged when the index holds no nodes.
    fn select_entry_point(&mut self) {
        let densest = self
            .density
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.density.total_cmp(&b.1.density))
            .map(|(i, _)| i as u32);
        if let Some(entry) = densest {
            self.entry_point = Some(entry);
        }
    }

    /// Returns up to `k` `(id, distance)` pairs for the stored vectors found
    /// closest to `query`, sorted by ascending distance.
    ///
    /// The search is a best-first walk over the graph starting at the entry
    /// point. It stops when the closest unexpanded candidate is farther than
    /// the current `k`-th result, when the candidate queue is exhausted, or
    /// when `ef_search` nodes have been visited. Fewer than `k` results are
    /// returned when fewer nodes were reached. An empty index or `k == 0`
    /// yields an empty vector.
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::DimensionMismatch`] when `query.len()`
    /// differs from the index dimension.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if query.len() != self.dim {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dim,
            });
        }
        if k == 0 || self.vectors.is_empty() {
            return Ok(Vec::new());
        }

        let entry = self.entry_point.unwrap_or(0);
        let entry_dist = self.query_distance(entry, query);

        let mut visited: HashSet<u32> = HashSet::new();
        visited.insert(entry);

        // `candidates` stores negated distances so the max-heap pops the
        // closest node first; `results` stores plain distances so its top is
        // the current worst result.
        let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
        candidates.push(Candidate {
            id: entry,
            distance: -entry_dist,
        });
        let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
        results.push(Candidate {
            id: entry,
            distance: entry_dist,
        });

        while let Some(Candidate {
            id: current,
            distance: neg_dist,
        }) = candidates.pop()
        {
            let current_dist = -neg_dist;
            let worst = results.peek().map_or(f32::INFINITY, |c| c.distance);
            if results.len() >= k && current_dist > worst {
                break;
            }

            for &neighbor in &self.edges[current as usize] {
                if !visited.insert(neighbor) {
                    continue;
                }
                let dist = self.query_distance(neighbor, query);

                // Re-read the bound each time: it tightens as results improve.
                let worst = results.peek().map_or(f32::INFINITY, |c| c.distance);
                if results.len() < k || dist < worst {
                    results.push(Candidate {
                        id: neighbor,
                        distance: dist,
                    });
                    if results.len() > k {
                        results.pop();
                    }
                }

                // Each node is queued at most once; a second heap entry for
                // the same node could only re-pop an already expanded node.
                if candidates.len() < self.config.ef_search {
                    candidates.push(Candidate {
                        id: neighbor,
                        distance: -dist,
                    });
                }
            }

            if visited.len() >= self.config.ef_search {
                break;
            }
        }

        let mut result_vec: Vec<(u32, f32)> =
            results.into_iter().map(|c| (c.id, c.distance)).collect();
        result_vec.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        result_vec.truncate(k);
        Ok(result_vec)
    }

    /// Distance between the stored vectors `a` and `b`, as computed by
    /// `euclidean_distance`. Panics if either id is out of range.
    fn distance(&self, a: u32, b: u32) -> f32 {
        euclidean_distance(&self.vectors[a as usize], &self.vectors[b as usize])
    }

    /// Distance between stored vector `id` and `query`, as computed by
    /// `euclidean_distance`. Panics if `id` is out of range.
    fn query_distance(&self, id: u32, query: &[f32]) -> f32 {
        euclidean_distance(&self.vectors[id as usize], query)
    }

    /// Number of vectors stored in the index.
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// Returns `true` when no vectors have been added.
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }

    /// Density statistics for node `id`, or `None` if the id is unknown.
    pub fn get_density(&self, id: u32) -> Option<&DensityInfo> {
        self.density.get(id as usize)
    }

    /// Number of edges currently stored on node `id`; `0` for unknown ids.
    pub fn edge_count(&self, id: u32) -> usize {
        self.edges.get(id as usize).map_or(0, Vec::len)
    }
}
/// Heap entry pairing a node id with the distance key it is ordered by.
///
/// Both ordering and equality look at `distance` only (the `id` is ignored)
/// and both use `f32::total_cmp`, so `Eq` and `Ord` always agree, including
/// for NaN and for `0.0` versus `-0.0`, as the `Ord` contract requires.
#[derive(Clone, Copy)]
struct Candidate {
    id: u32,
    distance: f32,
}

impl PartialEq for Candidate {
    /// Two candidates are equal exactly when `cmp` reports `Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    /// Total order on `distance` via `f32::total_cmp`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
use crate::distance::l2_distance as euclidean_distance;
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is itself a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Generates `num_clusters` groups of `points_per_cluster` points each.
    /// Cluster `c` is offset by `c * 10.0` on every axis, and coordinate `d`
    /// of point `p` adds `((p * d) % 10) * 0.1` on top of that offset.
    fn create_clustered_data(
        num_clusters: usize,
        points_per_cluster: usize,
        dim: usize,
    ) -> Vec<Vec<f32>> {
        (0..num_clusters)
            .flat_map(|cluster| {
                let center_offset = cluster as f32 * 10.0;
                (0..points_per_cluster).map(move |p| {
                    (0..dim)
                        .map(|d| center_offset + ((p * d) % 10) as f32 * 0.1)
                        .collect::<Vec<f32>>()
                })
            })
            .collect()
    }

    /// Building a small index keeps every vector and sets an entry point.
    #[test]
    fn test_deg_basic() {
        let mut index = DEGIndex::new(4, DEGConfig::default());
        (0..10).for_each(|step| {
            index.add(vec![step as f32 * 0.1; 4]).unwrap();
        });
        index.build().unwrap();

        assert_eq!(index.len(), 10);
        assert!(index.entry_point.is_some());
    }

    /// Search returns between one and `k` hits in ascending distance order.
    #[test]
    fn test_deg_search() {
        let config = DEGConfig {
            density_k: 3,
            base_edges: 4,
            ..Default::default()
        };
        let mut index = DEGIndex::new(4, config);
        for point in create_clustered_data(3, 10, 4) {
            index.add(point).unwrap();
        }
        index.build().unwrap();

        let query = vec![0.0; 4];
        let hits = index.search(&query, 5).unwrap();

        assert!(!hits.is_empty());
        assert!(hits.len() <= 5);
        assert!(hits.windows(2).all(|pair| pair[0].1 <= pair[1].1));
    }

    /// A far-away outlier is scored as less dense than a clustered point.
    #[test]
    fn test_density_estimation() {
        let mut index = DEGIndex::new(2, DEGConfig::default());
        for step in 0..10 {
            let coord = step as f32 * 0.1;
            index.add(vec![coord, coord]).unwrap();
        }
        index.add(vec![100.0, 100.0]).unwrap();
        index.build().unwrap();

        let cluster_density = index.get_density(5).unwrap().density;
        let isolated_density = index.get_density(10).unwrap().density;
        assert!(isolated_density < cluster_density);
    }

    /// An isolated node gets at least as large an edge budget as a node
    /// inside the cluster.
    #[test]
    fn test_adaptive_edge_budget() {
        let config = DEGConfig {
            min_edges: 2,
            max_edges: 8,
            base_edges: 4,
            ..Default::default()
        };
        let mut index = DEGIndex::new(2, config);
        for step in 0..20 {
            index
                .add(vec![step as f32 * 0.1, step as f32 * 0.05])
                .unwrap();
        }
        for far in [50.0, 60.0] {
            index.add(vec![far, far]).unwrap();
        }
        index.build().unwrap();

        let cluster_budget = index.get_density(10).unwrap().edge_budget;
        let isolated_budget = index.get_density(20).unwrap().edge_budget;
        assert!(isolated_budget >= cluster_budget);
    }

    /// The default configuration exposes the documented edge and density
    /// settings.
    #[test]
    fn test_config_defaults() {
        let DEGConfig {
            base_edges,
            max_edges,
            min_edges,
            density_k,
            ..
        } = DEGConfig::default();
        assert_eq!(base_edges, 16);
        assert_eq!(max_edges, 32);
        assert_eq!(min_edges, 8);
        assert_eq!(density_k, 10);
    }
}