use crate::RetrieveError;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// Tuning parameters for repairing neighbor lists after node deletions.
#[derive(Clone, Debug)]
pub struct RepairConfig {
    /// Upper bound on the pool of two-hop replacement candidates examined per node.
    pub max_candidates: usize,
    /// Neighbor-list length that repair tries to fill up to; lists are not grown past it.
    pub max_neighbors: usize,
    /// Whether a newly added edge `u -> v` should also be mirrored as `v -> u`
    /// (only when `v` has room and does not already list `u`).
    pub bidirectional: bool,
    /// Distance-ratio factor used by the candidate diversity-pruning step.
    pub alpha: f32,
}

impl Default for RepairConfig {
    /// 64 candidates, 16 neighbors, bidirectional edges enabled, `alpha = 1.2`.
    fn default() -> Self {
        RepairConfig {
            alpha: 1.2,
            bidirectional: true,
            max_neighbors: 16,
            max_candidates: 64,
        }
    }
}
/// Counters summarizing the work done by a repair call.
#[derive(Clone, Debug, Default)]
pub struct RepairStats {
    /// Number of live nodes whose neighbor lists were repaired.
    pub nodes_processed: usize,
    /// Edges to deleted nodes removed from repaired lists, as tallied by the repair routines.
    pub edges_removed: usize,
    /// Replacement edges appended to repaired nodes' neighbor lists.
    pub edges_added: usize,
    /// Reverse-edge insertions tallied by the repair routines when
    /// `RepairConfig::bidirectional` is enabled.
    pub bidirectional_edges: usize,
}
pub struct GraphRepairer<'a> {
config: RepairConfig,
deleted: HashSet<u32>,
get_neighbors: Box<dyn Fn(u32) -> Vec<u32> + 'a>,
compute_distance: Box<dyn Fn(u32, u32) -> f32 + 'a>,
set_neighbors: Box<dyn FnMut(u32, Vec<u32>) + 'a>,
}
impl<'a> GraphRepairer<'a> {
pub fn new<G, D, S>(
config: RepairConfig,
get_neighbors: G,
compute_distance: D,
set_neighbors: S,
) -> Self
where
G: Fn(u32) -> Vec<u32> + 'a,
D: Fn(u32, u32) -> f32 + 'a,
S: FnMut(u32, Vec<u32>) + 'a,
{
Self {
config,
deleted: HashSet::new(),
get_neighbors: Box::new(get_neighbors),
compute_distance: Box::new(compute_distance),
set_neighbors: Box::new(set_neighbors),
}
}
pub fn mark_deleted(&mut self, node_id: u32) -> Result<RepairStats, RetrieveError> {
self.deleted.insert(node_id);
let mut stats = RepairStats::default();
let neighbors_to_repair = (self.get_neighbors)(node_id);
for &neighbor in &neighbors_to_repair {
if self.deleted.contains(&neighbor) {
continue;
}
let repair_result = self.repair_single_node(neighbor, node_id)?;
stats.nodes_processed += 1;
stats.edges_removed += repair_result.edges_removed;
stats.edges_added += repair_result.edges_added;
stats.bidirectional_edges += repair_result.bidirectional_edges;
}
Ok(stats)
}
pub fn mark_deleted_batch(&mut self, node_ids: &[u32]) -> Result<RepairStats, RetrieveError> {
for &id in node_ids {
self.deleted.insert(id);
}
let mut neighbors_to_repair: HashSet<u32> = HashSet::new();
for &id in node_ids {
for neighbor in (self.get_neighbors)(id) {
if !self.deleted.contains(&neighbor) {
neighbors_to_repair.insert(neighbor);
}
}
}
let mut stats = RepairStats::default();
for neighbor in neighbors_to_repair {
let repair_result = self.repair_node_full(neighbor)?;
stats.nodes_processed += 1;
stats.edges_removed += repair_result.edges_removed;
stats.edges_added += repair_result.edges_added;
stats.bidirectional_edges += repair_result.bidirectional_edges;
}
Ok(stats)
}
fn repair_single_node(
&mut self,
node_id: u32,
deleted_neighbor: u32,
) -> Result<RepairStats, RetrieveError> {
let mut stats = RepairStats::default();
let mut neighbors: Vec<u32> = (self.get_neighbors)(node_id)
.into_iter()
.filter(|&n| n != deleted_neighbor && !self.deleted.contains(&n))
.collect();
stats.edges_removed = 1;
if neighbors.len() < self.config.max_neighbors {
let needed = self.config.max_neighbors - neighbors.len();
let candidates = self.find_replacement_candidates(node_id, &neighbors, needed)?;
for candidate in candidates {
if neighbors.len() >= self.config.max_neighbors {
break;
}
neighbors.push(candidate);
stats.edges_added += 1;
if self.config.bidirectional {
self.add_bidirectional_edge(candidate, node_id)?;
stats.bidirectional_edges += 1;
}
}
}
(self.set_neighbors)(node_id, neighbors);
Ok(stats)
}
fn repair_node_full(&mut self, node_id: u32) -> Result<RepairStats, RetrieveError> {
let mut stats = RepairStats::default();
let original = (self.get_neighbors)(node_id);
let mut neighbors: Vec<u32> = original
.into_iter()
.filter(|n| !self.deleted.contains(n))
.collect();
let removed = (self.get_neighbors)(node_id).len() - neighbors.len();
stats.edges_removed = removed;
if neighbors.len() < self.config.max_neighbors {
let needed = self.config.max_neighbors - neighbors.len();
let candidates = self.find_replacement_candidates(node_id, &neighbors, needed)?;
for candidate in candidates {
if neighbors.len() >= self.config.max_neighbors {
break;
}
neighbors.push(candidate);
stats.edges_added += 1;
if self.config.bidirectional {
self.add_bidirectional_edge(candidate, node_id)?;
stats.bidirectional_edges += 1;
}
}
}
(self.set_neighbors)(node_id, neighbors);
Ok(stats)
}
fn find_replacement_candidates(
&self,
from_node: u32,
existing_neighbors: &[u32],
needed: usize,
) -> Result<Vec<u32>, RetrieveError> {
let mut visited: HashSet<u32> = HashSet::new();
visited.insert(from_node);
visited.extend(existing_neighbors.iter().cloned());
visited.extend(self.deleted.iter().cloned());
#[derive(Clone, Copy)]
struct Candidate {
id: u32,
dist: f32,
}
impl PartialEq for Candidate {
fn eq(&self, other: &Self) -> bool {
self.dist == other.dist
}
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Candidate {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.dist.total_cmp(&other.dist).reverse()
}
}
let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
for &neighbor in existing_neighbors {
for two_hop in (self.get_neighbors)(neighbor) {
if visited.insert(two_hop) {
let dist = (self.compute_distance)(from_node, two_hop);
candidates.push(Candidate { id: two_hop, dist });
if candidates.len() > self.config.max_candidates {
let mut temp: Vec<_> = candidates.drain().collect();
temp.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
temp.truncate(self.config.max_candidates / 2);
for c in temp {
candidates.push(c);
}
}
}
}
}
let mut selected = Vec::new();
let sorted: Vec<_> = candidates.into_sorted_vec();
'outer: for candidate in sorted.iter().rev() {
for &existing in existing_neighbors.iter().chain(selected.iter()) {
let dist_to_existing = (self.compute_distance)(candidate.id, existing);
if dist_to_existing < candidate.dist * self.config.alpha {
continue 'outer; }
}
selected.push(candidate.id);
if selected.len() >= needed {
break;
}
}
Ok(selected)
}
fn add_bidirectional_edge(
&mut self,
to_node: u32,
from_node: u32,
) -> Result<(), RetrieveError> {
let mut neighbors = (self.get_neighbors)(to_node);
if neighbors.contains(&from_node) || neighbors.len() >= self.config.max_neighbors {
return Ok(());
}
neighbors.push(from_node);
(self.set_neighbors)(to_node, neighbors);
Ok(())
}
pub fn is_deleted(&self, node_id: u32) -> bool {
self.deleted.contains(&node_id)
}
pub fn deleted_count(&self) -> usize {
self.deleted.len()
}
}
pub fn compute_repair_operations(
deleted_node: u32,
neighbors_of_deleted: &[u32],
get_neighbors: impl Fn(u32) -> Vec<u32>,
compute_distance: impl Fn(u32, u32) -> f32,
config: &RepairConfig,
deleted_set: &HashSet<u32>,
) -> HashMap<u32, Vec<u32>> {
let mut operations: HashMap<u32, Vec<u32>> = HashMap::new();
for &neighbor in neighbors_of_deleted {
if deleted_set.contains(&neighbor) {
continue;
}
let current: Vec<u32> = get_neighbors(neighbor)
.into_iter()
.filter(|&n| n != deleted_node && !deleted_set.contains(&n))
.collect();
if current.len() >= config.max_neighbors {
operations.insert(neighbor, current);
continue;
}
let mut candidates: Vec<(u32, f32)> = Vec::new();
let mut visited: HashSet<u32> = current.iter().cloned().collect();
visited.insert(neighbor);
visited.extend(deleted_set.iter().cloned());
for &n in ¤t {
for two_hop in get_neighbors(n) {
if visited.insert(two_hop) {
let dist = compute_distance(neighbor, two_hop);
candidates.push((two_hop, dist));
}
}
}
candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
let mut new_neighbors = current;
for (candidate, _) in candidates {
if new_neighbors.len() >= config.max_neighbors {
break;
}
new_neighbors.push(candidate);
}
operations.insert(neighbor, new_neighbors);
}
operations
}
/// Counts how many live nodes are reachable from `entry_point`, and how many are not.
///
/// The walk is an iterative depth-first traversal that never enters a node for which
/// `is_deleted` returns true. If `entry_point` itself is deleted, nothing is reachable.
/// Live nodes are the ids below `total_nodes` that are not deleted.
///
/// Returns `(reachable, orphans)`, where `orphans` is the live-node count minus
/// `reachable`, floored at zero.
pub fn validate_connectivity(
    entry_point: u32,
    total_nodes: usize,
    get_neighbors: impl Fn(u32) -> Vec<u32>,
    is_deleted: impl Fn(u32) -> bool,
) -> (usize, usize) {
    let mut reached: HashSet<u32> = HashSet::new();
    let mut stack: Vec<u32> = vec![entry_point];

    while let Some(current) = stack.pop() {
        if is_deleted(current) || reached.contains(&current) {
            continue;
        }
        reached.insert(current);
        stack.extend(
            get_neighbors(current)
                .into_iter()
                .filter(|&next| !reached.contains(&next) && !is_deleted(next)),
        );
    }

    let deleted_in_range = (0..total_nodes as u32).filter(|&id| is_deleted(id)).count();
    let live_nodes = total_nodes - deleted_in_range;
    let reachable = reached.len();
    (reachable, live_nodes.saturating_sub(reachable))
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::cell::RefCell;

    /// Distance between node ids treated as points on a number line.
    fn line_distance(a: u32, b: u32) -> f32 {
        (a as f32 - b as f32).abs()
    }

    /// Small-graph configuration shared by several tests.
    fn small_config() -> RepairConfig {
        RepairConfig {
            max_candidates: 10,
            max_neighbors: 4,
            bidirectional: true,
            alpha: 1.0,
        }
    }

    #[test]
    fn test_repair_config_default() {
        let RepairConfig {
            max_candidates,
            max_neighbors,
            bidirectional,
            ..
        } = RepairConfig::default();
        assert_eq!(max_candidates, 64);
        assert_eq!(max_neighbors, 16);
        assert!(bidirectional);
    }

    #[test]
    fn test_compute_repair_operations() {
        // Path graph 0 - 1 - 2 - 3 - 4 with node 2 deleted.
        let path: Vec<Vec<u32>> = vec![vec![1], vec![0, 2], vec![1, 3], vec![2, 4], vec![3]];
        let config = small_config();
        let deleted: HashSet<u32> = HashSet::from([2]);
        let ops = compute_repair_operations(
            2,
            &path[2],
            |id| path[id as usize].clone(),
            line_distance,
            &config,
            &deleted,
        );
        let repaired_1 = ops.get(&1).expect("node 1 must have a repair entry");
        assert!(!repaired_1.contains(&2));
    }

    #[test]
    fn test_validate_connectivity() {
        // Four nodes, all mutually reachable, nothing deleted.
        let graph: Vec<Vec<u32>> = vec![vec![1, 2], vec![0, 2, 3], vec![0, 1, 3], vec![1, 2]];
        let result = validate_connectivity(0, 4, |id| graph[id as usize].clone(), |_| false);
        assert_eq!(result, (4, 0));
    }

    #[test]
    fn test_validate_connectivity_with_deletion() {
        // Path 0 - 1 - 2 - 3: deleting 1 cuts 2 and 3 off from entry 0.
        let path: Vec<Vec<u32>> = vec![vec![1], vec![0, 2], vec![1, 3], vec![2]];
        let (reachable, orphans) =
            validate_connectivity(0, 4, |id| path[id as usize].clone(), |id| id == 1);
        assert_eq!((reachable, orphans), (1, 2));
    }

    #[test]
    fn test_graph_repairer() {
        let graph: RefCell<Vec<Vec<u32>>> =
            RefCell::new(vec![vec![1, 2], vec![0, 2, 3], vec![0, 1, 3], vec![1, 2]]);
        let mut repairer = GraphRepairer::new(
            small_config(),
            |id| graph.borrow()[id as usize].clone(),
            line_distance,
            |id, neighbors| {
                graph.borrow_mut()[id as usize] = neighbors;
            },
        );
        let stats = repairer.mark_deleted(1).unwrap();
        assert!(stats.nodes_processed > 0);
        assert!(stats.edges_removed > 0);
        assert!(repairer.is_deleted(1));
    }

    #[test]
    fn test_batch_deletion() {
        let graph: RefCell<Vec<Vec<u32>>> = RefCell::new(vec![
            vec![1, 2],
            vec![0, 2, 3, 4],
            vec![0, 1, 3],
            vec![1, 2, 4],
            vec![1, 3],
        ]);
        let mut repairer = GraphRepairer::new(
            RepairConfig::default(),
            |id| graph.borrow()[id as usize].clone(),
            line_distance,
            |id, neighbors| {
                graph.borrow_mut()[id as usize] = neighbors;
            },
        );
        let stats = repairer.mark_deleted_batch(&[1, 3]).unwrap();
        assert!(repairer.is_deleted(1));
        assert!(repairer.is_deleted(3));
        assert_eq!(repairer.deleted_count(), 2);
        assert!(stats.nodes_processed > 0);
    }
}