1use crate::RetrieveError;
31use std::collections::{BinaryHeap, HashSet};
32
/// Tuning parameters for building and searching a [`DEGIndex`].
#[derive(Debug, Clone)]
pub struct DEGConfig {
    /// Edge budget a node carries from `add` until `build` assigns density-based budgets.
    pub base_edges: usize,
    /// Largest edge budget a node can receive (assigned to the least dense nodes).
    pub max_edges: usize,
    /// Smallest edge budget a node can receive (assigned to the densest nodes).
    pub min_edges: usize,
    /// Number of nearest neighbours whose distances are averaged to estimate local density.
    pub density_k: usize,
    /// Scale factor of the diversity test used while selecting neighbours: a candidate is
    /// kept only if `alpha * d(candidate, n) > d(node, candidate)` for every neighbour `n`
    /// selected so far.
    pub alpha: f32,
    /// Search effort limit: bounds both the candidate queue length and the number of
    /// visited nodes during a query.
    pub ef_search: usize,
}

impl Default for DEGConfig {
    /// Returns the stock configuration (budgets 8..=32, 10-NN density, `alpha` 1.2).
    fn default() -> Self {
        DEGConfig {
            ef_search: 100,
            alpha: 1.2,
            density_k: 10,
            min_edges: 8,
            max_edges: 32,
            base_edges: 16,
        }
    }
}
62
/// Per-node density statistics produced while building the index.
#[derive(Debug, Clone)]
pub struct DensityInfo {
    /// Density estimate, computed during `build` as `1 / (avg_neighbor_dist + 0.1)`;
    /// `0.0` for a node that has been added but not yet built.
    pub density: f32,
    /// Maximum number of edges this node may hold.
    pub edge_budget: usize,
    /// Mean distance to the node's nearest neighbours (`1.0` when it has none);
    /// `0.0` for a node that has been added but not yet built.
    pub avg_neighbor_dist: f32,
}
73
74pub struct DEGIndex {
76 config: DEGConfig,
77 dim: usize,
78 vectors: Vec<Vec<f32>>,
80 edges: Vec<Vec<u32>>,
82 density: Vec<DensityInfo>,
84 entry_point: Option<u32>,
86}
87
88impl DEGIndex {
89 pub fn new(dim: usize, config: DEGConfig) -> Self {
91 Self {
92 config,
93 dim,
94 vectors: Vec::new(),
95 edges: Vec::new(),
96 density: Vec::new(),
97 entry_point: None,
98 }
99 }
100
101 pub fn add(&mut self, vector: Vec<f32>) -> Result<u32, RetrieveError> {
103 if vector.len() != self.dim {
104 return Err(RetrieveError::DimensionMismatch {
105 query_dim: vector.len(),
106 doc_dim: self.dim,
107 });
108 }
109
110 let id = self.vectors.len() as u32;
111 self.vectors.push(vector);
112 self.edges.push(Vec::new());
113 self.density.push(DensityInfo {
114 density: 0.0,
115 edge_budget: self.config.base_edges,
116 avg_neighbor_dist: 0.0,
117 });
118
119 if self.entry_point.is_none() {
120 self.entry_point = Some(id);
121 }
122
123 Ok(id)
124 }
125
126 pub fn build(&mut self) -> Result<(), RetrieveError> {
128 if self.vectors.is_empty() {
129 return Ok(());
130 }
131
132 let n = self.vectors.len();
133
134 const DEG_SCALE_LIMIT: usize = 10_000;
137 if n > DEG_SCALE_LIMIT {
138 return Err(RetrieveError::InvalidParameter(format!(
139 "DEG construction is O(n^2); n={} exceeds practical limit of {}. \
140 Use HNSW for larger datasets.",
141 n, DEG_SCALE_LIMIT
142 )));
143 }
144
145 self.estimate_densities()?;
147
148 self.assign_edge_budgets();
150
151 for i in 0..n {
153 self.connect_node(i as u32)?;
154 }
155
156 self.select_entry_point();
158
159 Ok(())
160 }
161
162 fn estimate_densities(&mut self) -> Result<(), RetrieveError> {
164 let n = self.vectors.len();
165 let k = self.config.density_k.min(n - 1);
166
167 for i in 0..n {
168 let mut distances: Vec<(u32, f32)> = (0..n)
170 .filter(|&j| j != i)
171 .map(|j| (j as u32, self.distance(i as u32, j as u32)))
172 .collect();
173
174 distances.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
175
176 let k_neighbors: Vec<_> = distances.iter().take(k).collect();
178 let avg_dist = if k_neighbors.is_empty() {
179 1.0
180 } else {
181 k_neighbors.iter().map(|(_, d)| d).sum::<f32>() / k_neighbors.len() as f32
182 };
183
184 let density = 1.0 / (avg_dist + 0.1);
186
187 self.density[i] = DensityInfo {
188 density,
189 edge_budget: self.config.base_edges,
190 avg_neighbor_dist: avg_dist,
191 };
192 }
193
194 Ok(())
195 }
196
197 fn assign_edge_budgets(&mut self) {
199 let min_density = self
201 .density
202 .iter()
203 .map(|d| d.density)
204 .fold(f32::INFINITY, f32::min);
205 let max_density = self
206 .density
207 .iter()
208 .map(|d| d.density)
209 .fold(f32::NEG_INFINITY, f32::max);
210
211 let density_range = (max_density - min_density).max(0.1);
212
213 for info in &mut self.density {
214 let normalized = (info.density - min_density) / density_range;
216
217 let edge_range = (self.config.max_edges - self.config.min_edges) as f32;
220 let budget = self.config.max_edges - (normalized * edge_range) as usize;
221
222 info.edge_budget = budget.clamp(self.config.min_edges, self.config.max_edges);
223 }
224 }
225
226 fn connect_node(&mut self, node_id: u32) -> Result<(), RetrieveError> {
228 let budget = self.density[node_id as usize].edge_budget;
229
230 let mut candidates: Vec<(u32, f32)> = (0..self.vectors.len() as u32)
232 .filter(|&j| j != node_id)
233 .map(|j| (j, self.distance(node_id, j)))
234 .collect();
235
236 candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
237
238 let mut neighbors = Vec::new();
240
241 for (candidate, dist) in candidates {
242 if neighbors.len() >= budget {
243 break;
244 }
245
246 let is_diverse = neighbors.iter().all(|&n| {
250 let neighbor_dist = self.distance(candidate, n);
251 self.config.alpha * neighbor_dist > dist
252 });
253
254 if is_diverse {
255 neighbors.push(candidate);
256 }
257 }
258
259 for &neighbor in &neighbors {
261 let neighbor_edges = &mut self.edges[neighbor as usize];
262 if !neighbor_edges.contains(&node_id) {
263 let neighbor_budget = self.density[neighbor as usize].edge_budget;
264 if neighbor_edges.len() < neighbor_budget {
265 neighbor_edges.push(node_id);
266 }
267 }
268 }
269
270 self.edges[node_id as usize] = neighbors;
271
272 Ok(())
273 }
274
275 fn select_entry_point(&mut self) {
277 if self.vectors.is_empty() {
278 return;
279 }
280
281 let best = self
283 .density
284 .iter()
285 .enumerate()
286 .max_by(|a, b| a.1.density.total_cmp(&b.1.density))
287 .map(|(i, _)| i as u32);
288
289 if let Some(entry) = best {
290 self.entry_point = Some(entry);
291 }
292 }
293
294 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
296 if query.len() != self.dim {
297 return Err(RetrieveError::DimensionMismatch {
298 query_dim: query.len(),
299 doc_dim: self.dim,
300 });
301 }
302
303 if self.vectors.is_empty() {
304 return Ok(Vec::new());
305 }
306
307 let entry = self.entry_point.unwrap_or(0);
308
309 let mut visited: HashSet<u32> = HashSet::new();
311 let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
312 let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
313
314 let entry_dist = self.query_distance(entry, query);
316 candidates.push(Candidate {
317 id: entry,
318 distance: -entry_dist,
319 }); results.push(Candidate {
321 id: entry,
322 distance: entry_dist,
323 });
324 visited.insert(entry);
325
326 while let Some(Candidate {
327 id: current,
328 distance: neg_dist,
329 }) = candidates.pop()
330 {
331 let current_dist = -neg_dist;
332
333 let worst_result = results.peek().map(|c| c.distance).unwrap_or(f32::INFINITY);
335
336 if current_dist > worst_result && results.len() >= k {
337 break;
338 }
339
340 let local_density = self.density[current as usize].density;
342 let expansion = if local_density < 0.5 {
343 2 } else {
345 1 };
347
348 for &neighbor in &self.edges[current as usize] {
350 if visited.insert(neighbor) {
351 let dist = self.query_distance(neighbor, query);
352
353 if results.len() < k || dist < worst_result {
355 results.push(Candidate {
356 id: neighbor,
357 distance: dist,
358 });
359 while results.len() > k {
360 results.pop();
361 }
362 }
363
364 for _ in 0..expansion {
366 if candidates.len() < self.config.ef_search {
367 candidates.push(Candidate {
368 id: neighbor,
369 distance: -dist,
370 });
371 }
372 }
373 }
374 }
375
376 if visited.len() >= self.config.ef_search {
377 break;
378 }
379 }
380
381 let mut result_vec: Vec<(u32, f32)> =
383 results.into_iter().map(|c| (c.id, c.distance)).collect();
384 result_vec.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
385 result_vec.truncate(k);
386
387 Ok(result_vec)
388 }
389
390 fn distance(&self, a: u32, b: u32) -> f32 {
392 euclidean_distance(&self.vectors[a as usize], &self.vectors[b as usize])
393 }
394
395 fn query_distance(&self, id: u32, query: &[f32]) -> f32 {
397 euclidean_distance(&self.vectors[id as usize], query)
398 }
399
400 pub fn len(&self) -> usize {
402 self.vectors.len()
403 }
404
405 pub fn is_empty(&self) -> bool {
407 self.vectors.is_empty()
408 }
409
410 pub fn get_density(&self, id: u32) -> Option<&DensityInfo> {
412 self.density.get(id as usize)
413 }
414
415 pub fn edge_count(&self, id: u32) -> usize {
417 self.edges.get(id as usize).map(|e| e.len()).unwrap_or(0)
418 }
419}
420
/// Heap entry pairing a node id with the distance used as its priority.
///
/// Ordering and equality consider `distance` only and both follow
/// `f32::total_cmp`, so the `Eq`/`Ord` contracts hold even for NaN and signed zeros.
#[derive(Clone, Copy)]
struct Candidate {
    /// Node the entry refers to.
    id: u32,
    /// Priority key (callers may store a negated distance to invert the heap order).
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Derive equality from `cmp` so it can never disagree with the ordering
        // (plain `==` says NaN != NaN and -0.0 == 0.0, unlike `total_cmp`).
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
449
450use crate::distance::l2_distance as euclidean_distance;
451
452#[cfg(test)]
453#[allow(clippy::unwrap_used, clippy::expect_used)]
454mod tests {
455 use super::*;
456
457 fn create_clustered_data(
458 num_clusters: usize,
459 points_per_cluster: usize,
460 dim: usize,
461 ) -> Vec<Vec<f32>> {
462 let mut data = Vec::new();
463
464 for c in 0..num_clusters {
465 let center_offset = c as f32 * 10.0;
466
467 for p in 0..points_per_cluster {
468 let mut point = vec![0.0; dim];
469 for (d, val) in point.iter_mut().enumerate() {
470 *val = center_offset + ((p * d) % 10) as f32 * 0.1;
471 }
472 data.push(point);
473 }
474 }
475
476 data
477 }
478
479 #[test]
480 fn test_deg_basic() {
481 let mut index = DEGIndex::new(4, DEGConfig::default());
482
483 for i in 0..10 {
485 let v = vec![i as f32 * 0.1; 4];
486 index.add(v).unwrap();
487 }
488
489 index.build().unwrap();
491
492 assert_eq!(index.len(), 10);
493 assert!(index.entry_point.is_some());
494 }
495
496 #[test]
497 fn test_deg_search() {
498 let mut index = DEGIndex::new(
499 4,
500 DEGConfig {
501 density_k: 3,
502 base_edges: 4,
503 ..Default::default()
504 },
505 );
506
507 let data = create_clustered_data(3, 10, 4);
509 for v in data {
510 index.add(v).unwrap();
511 }
512
513 index.build().unwrap();
514
515 let query = vec![0.0; 4]; let results = index.search(&query, 5).unwrap();
518
519 assert!(!results.is_empty());
520 assert!(results.len() <= 5);
521
522 for i in 1..results.len() {
524 assert!(results[i - 1].1 <= results[i].1);
525 }
526 }
527
528 #[test]
529 fn test_density_estimation() {
530 let mut index = DEGIndex::new(2, DEGConfig::default());
531
532 for i in 0..10 {
535 index.add(vec![i as f32 * 0.1, i as f32 * 0.1]).unwrap();
536 }
537
538 index.add(vec![100.0, 100.0]).unwrap();
540
541 index.build().unwrap();
542
543 let isolated_density = index.get_density(10).unwrap().density;
545 let cluster_density = index.get_density(5).unwrap().density;
546
547 assert!(isolated_density < cluster_density);
548 }
549
550 #[test]
551 fn test_adaptive_edge_budget() {
552 let mut index = DEGIndex::new(
553 2,
554 DEGConfig {
555 min_edges: 2,
556 max_edges: 8,
557 base_edges: 4,
558 ..Default::default()
559 },
560 );
561
562 for i in 0..20 {
564 index.add(vec![i as f32 * 0.1, i as f32 * 0.05]).unwrap();
565 }
566
567 index.add(vec![50.0, 50.0]).unwrap();
569 index.add(vec![60.0, 60.0]).unwrap();
570
571 index.build().unwrap();
572
573 let isolated_budget = index.get_density(20).unwrap().edge_budget;
575 let cluster_budget = index.get_density(10).unwrap().edge_budget;
576
577 assert!(isolated_budget >= cluster_budget);
578 }
579
580 #[test]
581 fn test_config_defaults() {
582 let config = DEGConfig::default();
583
584 assert_eq!(config.base_edges, 16);
585 assert_eq!(config.max_edges, 32);
586 assert_eq!(config.min_edges, 8);
587 assert_eq!(config.density_k, 10);
588 }
589
590 #[test]
597 fn test_alpha_pruning_is_respected() {
598 fn build_line_index(alpha: f32) -> DEGIndex {
605 let mut index = DEGIndex::new(
606 1,
607 DEGConfig {
608 base_edges: 4,
609 max_edges: 4,
610 min_edges: 1,
611 alpha,
612 density_k: 2,
613 ..Default::default()
614 },
615 );
616 for i in 0..8 {
617 index.add(vec![i as f32]).unwrap();
618 }
619 index.build().unwrap();
620 index
621 }
622
623 let tight = build_line_index(1.0);
624 let loose = build_line_index(3.0);
625
626 let tight_results = tight.search(&[0.0], 8).unwrap();
633 let loose_results = loose.search(&[0.0], 8).unwrap();
634
635 assert!(
638 tight_results.len() <= loose_results.len() + 2,
639 "tight alpha should not produce dramatically more results than loose: tight={}, loose={}",
640 tight_results.len(),
641 loose_results.len()
642 );
643
644 assert!(
647 !tight_results.is_empty(),
648 "tight alpha search should return at least one result"
649 );
650 }
651
652 #[test]
655 fn test_deg_recall_regression() {
656 let mut index = DEGIndex::new(
657 4,
658 DEGConfig {
659 alpha: 1.2,
660 base_edges: 8,
661 max_edges: 16,
662 density_k: 5,
663 ..Default::default()
664 },
665 );
666 let data = create_clustered_data(3, 30, 4);
667 let queries: Vec<_> = data.iter().take(10).cloned().collect();
668 for v in &data {
669 index.add(v.clone()).unwrap();
670 }
671 index.build().unwrap();
672
673 let mut hits = 0;
674 for q in &queries {
675 let results = index.search(q, 1).unwrap();
676 if let Some((_, dist)) = results.first() {
677 if *dist < 0.05 {
678 hits += 1;
679 }
680 }
681 }
682 assert!(
683 hits >= 7,
684 "recall too low ({}/10): alpha pruning may be broken",
685 hits
686 );
687 }
688}