use super::adjacency::AdjacencyGraph;
use crate::error::{SparseError, SparseResult};
/// Output of a minimum-degree fill-reducing ordering (`amd` / `amd_simple`).
#[derive(Debug, Clone)]
pub struct AmdResult {
/// Elimination order: `perm[k]` is the original node placed at position `k`.
pub perm: Vec<usize>,
/// Inverse of `perm`: `inv_perm[node]` is the position of `node` in `perm`.
pub inv_perm: Vec<usize>,
/// Heuristic fill count: one entry per node (the diagonal) plus, for each
/// elimination step, the number of neighbors reached by that step's pivot.
/// NOTE(review): presumably approximates nnz of the triangular factor — confirm.
pub estimated_nnz: usize,
}
/// Lifecycle state of one node in the quotient graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeStatus {
/// Not yet eliminated: a variable, possibly the representative of a supervariable.
Active,
/// Chosen as a pivot and turned into an element; the payload is the elimination step index.
Eliminated(usize),
/// Merged into another node; the payload is the node it was absorbed into.
/// Used both for an element absorbed by a later pivot and for a variable
/// merged into an indistinguishable supervariable representative.
Absorbed(usize),
}
/// Mutable quotient-graph state driven by the `amd` ordering loop.
struct QuotientGraph {
/// Number of nodes in the input graph.
n: usize,
/// Per-node adjacency. For an active variable these are node ids that are
/// resolved through `Absorbed` chains whenever they are read, so entries may
/// name absorbed nodes. For an eliminated node (element) this is the list of
/// active representatives reached when it was eliminated.
adj: Vec<Vec<usize>>,
/// Current lifecycle state of every node.
status: Vec<NodeStatus>,
/// Pivot-selection key: starts as the input graph's degree and is recomputed
/// for every node reached by an elimination.
degree: Vec<usize>,
/// Number of original nodes represented by each supervariable (starts at 1).
weight: Vec<usize>,
}
impl QuotientGraph {
    /// Builds the initial quotient graph: every node is an active variable of
    /// weight 1 whose adjacency list and degree are copied from `graph`.
    fn new(graph: &AdjacencyGraph) -> Self {
        let n = graph.num_nodes();
        let adj: Vec<Vec<usize>> = (0..n).map(|u| graph.neighbors(u).to_vec()).collect();
        let degree: Vec<usize> = (0..n).map(|u| graph.degree(u)).collect();
        Self {
            n,
            adj,
            status: vec![NodeStatus::Active; n],
            degree,
            weight: vec![1; n],
        }
    }

    /// Follows `Absorbed` links from `node` and returns the first node that is
    /// not absorbed (an active supervariable or an eliminated element).
    ///
    /// The walk stops after `n` steps so a cyclic chain cannot loop forever. In
    /// that case the returned node may itself still be `Absorbed`, and callers
    /// treat it as neither a variable nor an element.
    fn find_representative(&self, mut node: usize) -> usize {
        let mut steps = 0;
        while let NodeStatus::Absorbed(parent) = self.status[node] {
            node = parent;
            steps += 1;
            if steps > self.n {
                break;
            }
        }
        node
    }

    /// Returns `u`'s adjacency resolved to representatives, sorted and
    /// de-duplicated, with `u` itself removed. Two active variables with equal
    /// results are treated as indistinguishable by `mass_eliminate`.
    fn sorted_rep_adj(&self, u: usize) -> Vec<usize> {
        let mut reps: Vec<usize> = self.adj[u]
            .iter()
            .map(|&v| self.find_representative(v))
            .filter(|&v| v != u)
            .collect();
        reps.sort_unstable();
        reps.dedup();
        reps
    }

    /// Weighted external degree of active variable `u`.
    ///
    /// This is the total weight of the distinct active representatives adjacent
    /// to `u`, either directly or through one adjacent element. `u` itself is
    /// excluded.
    fn approximate_external_degree(&self, u: usize) -> usize {
        let mut seen = vec![false; self.n];
        seen[u] = true;
        let mut total = 0;
        for &v in &self.adj[u] {
            let rep = self.find_representative(v);
            match self.status[rep] {
                NodeStatus::Active => {
                    if !seen[rep] {
                        seen[rep] = true;
                        total += self.weight[rep];
                    }
                }
                NodeStatus::Eliminated(_) => {
                    for &w in &self.adj[rep] {
                        let wr = self.find_representative(w);
                        if matches!(self.status[wr], NodeStatus::Active) && !seen[wr] {
                            seen[wr] = true;
                            total += self.weight[wr];
                        }
                    }
                }
                // Only possible when `find_representative` hit its step cap.
                NodeStatus::Absorbed(_) => {}
            }
        }
        total
    }

    /// Eliminates `pivot` at elimination step `step` and returns its reach.
    ///
    /// The reach is the set of distinct active representatives adjacent to
    /// `pivot`, either directly or through an adjacent element.
    ///
    /// Side effects:
    /// - `pivot` becomes an element whose adjacency is the reach.
    /// - Every element adjacent to `pivot` is absorbed into it.
    /// - Each reached variable's adjacency is rewritten to representatives,
    ///   with `pivot` included, and its degree is recomputed.
    /// - Indistinguishable reached variables are then merged by `mass_eliminate`.
    fn eliminate(&mut self, pivot: usize, step: usize) -> Vec<usize> {
        self.status[pivot] = NodeStatus::Eliminated(step);
        let mut reach = Vec::new();
        let mut seen = vec![false; self.n];
        seen[pivot] = true;
        // The pivot's list is replaced by `reach` below, so move it out instead of cloning.
        let pivot_adj = std::mem::take(&mut self.adj[pivot]);
        for &v in &pivot_adj {
            let rep = self.find_representative(v);
            // Some entries resolve to the pivot itself: a self-loop, or a second
            // entry for an element already absorbed into `pivot` in this loop.
            // Without this guard, the `Eliminated` arm below would set the pivot
            // to `Absorbed(pivot)`. That cycle would hide the new element from
            // every later degree and reach computation.
            if rep == pivot {
                continue;
            }
            match self.status[rep] {
                NodeStatus::Active => {
                    if !seen[rep] {
                        seen[rep] = true;
                        reach.push(rep);
                    }
                }
                NodeStatus::Eliminated(_) => {
                    // Once absorbed, this element's list is never consulted again,
                    // so release it rather than cloning it.
                    let element_adj = std::mem::take(&mut self.adj[rep]);
                    for &w in &element_adj {
                        let wr = self.find_representative(w);
                        if matches!(self.status[wr], NodeStatus::Active) && !seen[wr] {
                            seen[wr] = true;
                            reach.push(wr);
                        }
                    }
                    self.status[rep] = NodeStatus::Absorbed(pivot);
                }
                NodeStatus::Absorbed(_) => {}
            }
        }
        self.adj[pivot] = reach.clone();

        // Rewrite each reached variable's list in terms of current
        // representatives. Dependencies on absorbed elements collapse into the
        // single new element `pivot`.
        for &r in &reach {
            let mut new_adj = Vec::new();
            let mut has_pivot_element = false;
            for &v in &self.adj[r] {
                let rep = self.find_representative(v);
                if rep == pivot {
                    if !has_pivot_element {
                        new_adj.push(pivot);
                        has_pivot_element = true;
                    }
                } else if rep != r
                    && matches!(
                        self.status[rep],
                        NodeStatus::Active | NodeStatus::Eliminated(_)
                    )
                    && !new_adj.contains(&rep)
                {
                    new_adj.push(rep);
                }
            }
            if !has_pivot_element {
                new_adj.push(pivot);
            }
            self.adj[r] = new_adj;
        }
        for &r in &reach {
            self.degree[r] = self.approximate_external_degree(r);
        }
        self.mass_eliminate(&reach);
        reach
    }

    /// Detects indistinguishable variables among `reach` and merges them into
    /// supervariables.
    ///
    /// Candidates are bucketed by a hash of `sorted_rep_adj`. Each bucket is
    /// then verified exactly by `try_merge_group`.
    fn mass_eliminate(&mut self, reach: &[usize]) {
        if reach.len() < 2 {
            return;
        }
        let mut keyed: Vec<(u64, usize)> = reach
            .iter()
            .copied()
            .filter(|&u| matches!(self.status[u], NodeStatus::Active))
            .map(|u| {
                let hash = self.sorted_rep_adj(u).iter().fold(0u64, |acc, &v| {
                    acc.wrapping_mul(6364136223846793005).wrapping_add(v as u64)
                });
                (hash, u)
            })
            .collect();
        // Sorting by (hash, node) keeps bucket contents, and therefore the
        // chosen representatives, deterministic.
        keyed.sort_unstable();
        let mut start = 0;
        while start < keyed.len() {
            let hash = keyed[start].0;
            let end = start
                + keyed[start..]
                    .iter()
                    .take_while(|&&(h, _)| h == hash)
                    .count();
            if end - start > 1 {
                let group: Vec<usize> = keyed[start..end].iter().map(|&(_, u)| u).collect();
                self.try_merge_group(&group);
            }
            start = end;
        }
    }

    /// Merges every pair of active variables in `group` whose resolved
    /// adjacencies are identical. The earliest member of each equivalence class
    /// becomes the representative and accumulates the members' weights.
    ///
    /// Members are compared pairwise, not only against `group[0]`. So a hash
    /// collision with the first member no longer stops genuinely equal members
    /// from merging with each other.
    fn try_merge_group(&mut self, group: &[usize]) {
        if group.len() < 2 {
            return;
        }
        let adjs: Vec<Vec<usize>> = group.iter().map(|&u| self.sorted_rep_adj(u)).collect();
        for (i, &representative) in group.iter().enumerate() {
            if !matches!(self.status[representative], NodeStatus::Active) {
                continue;
            }
            for (j, &other) in group.iter().enumerate().skip(i + 1) {
                if matches!(self.status[other], NodeStatus::Active) && adjs[i] == adjs[j] {
                    self.status[other] = NodeStatus::Absorbed(representative);
                    self.weight[representative] += self.weight[other];
                }
            }
        }
    }
}
/// Computes an approximate-minimum-degree (AMD) fill-reducing ordering of `graph`.
///
/// The algorithm works on a quotient graph. At each step it eliminates the
/// active (super)variable with the smallest degree key, and it merges
/// indistinguishable variables into supervariables along the way. Variables
/// merged into a supervariable are placed immediately after that
/// supervariable's representative in the resulting order.
///
/// Returns:
/// - the permutation (`perm[k]` is the original node at position `k`);
/// - its inverse;
/// - a fill estimate: `n` plus the reach size of every pivot.
///
/// # Errors
/// The current implementation always returns `Ok`.
pub fn amd(graph: &AdjacencyGraph) -> SparseResult<AmdResult> {
    let n = graph.num_nodes();
    if n == 0 {
        return Ok(AmdResult {
            perm: Vec::new(),
            inv_perm: Vec::new(),
            estimated_nnz: 0,
        });
    }

    let mut qg = QuotientGraph::new(graph);
    let mut perm = Vec::with_capacity(n);
    let mut estimated_nnz = 0usize;
    for step in 0..n {
        // Ties go to the lowest index (`min_by_key` returns the first minimum).
        let pivot = (0..n)
            .filter(|&u| matches!(qg.status[u], NodeStatus::Active))
            .min_by_key(|&u| qg.degree[u]);
        let pivot = match pivot {
            Some(p) => p,
            // Every remaining node has been eliminated or merged away.
            None => break,
        };
        perm.push(pivot);
        let reach = qg.eliminate(pivot, step);
        estimated_nnz += reach.len();
    }

    let mut is_pivot = vec![false; n];
    for &p in &perm {
        is_pivot[p] = true;
    }

    // A merged variable is eliminated together with the supervariable it joined.
    // That supervariable's representative is the first pivot on its `Absorbed`
    // chain. Do not use `find_representative` here: it would follow the chain
    // past that pivot into whichever later pivot absorbed the element.
    let owning_pivot = |start: usize| -> Option<usize> {
        let mut node = start;
        for _ in 0..=n {
            match qg.status[node] {
                NodeStatus::Absorbed(parent) if is_pivot[parent] => return Some(parent),
                NodeStatus::Absorbed(parent) => node = parent,
                NodeStatus::Active | NodeStatus::Eliminated(_) => return None,
            }
        }
        None
    };
    let mut members: Vec<Vec<usize>> = vec![Vec::new(); n];
    for u in (0..n).filter(|&u| !is_pivot[u]) {
        if let Some(owner) = owning_pivot(u) {
            members[owner].push(u);
        }
    }

    let mut full_perm = Vec::with_capacity(n);
    let mut emitted = vec![false; n];
    for &p in &perm {
        if !emitted[p] {
            full_perm.push(p);
            emitted[p] = true;
        }
        for &u in &members[p] {
            if !emitted[u] {
                full_perm.push(u);
                emitted[u] = true;
            }
        }
    }
    // Safety net: anything not yet placed (e.g. a malformed chain) goes last.
    for u in 0..n {
        if !emitted[u] {
            full_perm.push(u);
            emitted[u] = true;
        }
    }

    let mut inv_perm = vec![0usize; n];
    for (new_i, &old_i) in full_perm.iter().enumerate() {
        inv_perm[old_i] = new_i;
    }
    // One entry per diagonal position.
    estimated_nnz += n;
    Ok(AmdResult {
        perm: full_perm,
        inv_perm,
        estimated_nnz,
    })
}
/// Reference minimum-degree ordering on an explicit elimination graph.
///
/// At each step the uneliminated node with the fewest uneliminated neighbors
/// is eliminated; ties go to the lowest index. Its remaining neighbors are then
/// joined into a clique by adding the missing fill edges. This is simpler but
/// slower than [`amd`].
///
/// `estimated_nnz` is `n` plus the number of uneliminated neighbors each pivot
/// had when it was eliminated.
///
/// # Errors
/// The current implementation always returns `Ok`.
pub fn amd_simple(graph: &AdjacencyGraph) -> SparseResult<AmdResult> {
    let node_count = graph.num_nodes();
    if node_count == 0 {
        return Ok(AmdResult {
            perm: Vec::new(),
            inv_perm: Vec::new(),
            estimated_nnz: 0,
        });
    }

    let mut links: Vec<Vec<usize>> = (0..node_count)
        .map(|node| graph.neighbors(node).to_vec())
        .collect();
    let mut removed = vec![false; node_count];
    let mut order = Vec::with_capacity(node_count);
    // Start with one entry per diagonal position.
    let mut nnz_estimate = node_count;

    // Exactly one node is appended per iteration, so this runs `node_count` times.
    while order.len() < node_count {
        let live_degree =
            |node: usize| links[node].iter().filter(|&&nbr| !removed[nbr]).count();
        let pivot = (0..node_count)
            .filter(|&cand| !removed[cand])
            .min_by_key(|&cand| live_degree(cand))
            .unwrap_or(0);

        removed[pivot] = true;
        order.push(pivot);

        let clique: Vec<usize> = links[pivot]
            .iter()
            .copied()
            .filter(|&nbr| !removed[nbr])
            .collect();
        nnz_estimate += clique.len();

        // Connect every still-unconnected pair of the pivot's neighbors (fill edges).
        for (offset, &a) in clique.iter().enumerate() {
            for &b in &clique[offset + 1..] {
                if !links[a].contains(&b) {
                    links[a].push(b);
                    links[b].push(a);
                }
            }
        }
    }

    let mut inv_perm = vec![0usize; node_count];
    order
        .iter()
        .enumerate()
        .for_each(|(position, &node)| inv_perm[node] = position);

    Ok(AmdResult {
        perm: order,
        inv_perm,
        estimated_nnz: nnz_estimate,
    })
}
// Unit tests for `amd` and `amd_simple` on small synthetic graphs.
#[cfg(test)]
mod tests {
use super::*;
/// Builds the path 0 - 1 - ... - (n-1), storing each edge in both directions.
fn path_graph(n: usize) -> AdjacencyGraph {
let mut adj = vec![Vec::new(); n];
for i in 0..n.saturating_sub(1) {
adj[i].push(i + 1);
adj[i + 1].push(i);
}
AdjacencyGraph::from_adjacency_list(adj)
}
/// Builds a `rows x cols` 4-neighbor grid; node `(r, c)` has id `r * cols + c`.
/// Each edge is stored in both directions.
fn grid_graph(rows: usize, cols: usize) -> AdjacencyGraph {
let n = rows * cols;
let mut adj = vec![Vec::new(); n];
for r in 0..rows {
for c in 0..cols {
let u = r * cols + c;
// Edge to the right-hand neighbor.
if c + 1 < cols {
let v = r * cols + c + 1;
adj[u].push(v);
adj[v].push(u);
}
// Edge to the neighbor below.
if r + 1 < rows {
let v = (r + 1) * cols + c;
adj[u].push(v);
adj[v].push(u);
}
}
}
AdjacencyGraph::from_adjacency_list(adj)
}
// `amd` must return every node exactly once.
#[test]
fn test_amd_valid_permutation() {
let graph = path_graph(8);
let result = amd(&graph).expect("AMD");
assert_eq!(result.perm.len(), 8);
let mut sorted = result.perm.clone();
sorted.sort_unstable();
assert_eq!(sorted, (0..8).collect::<Vec<_>>());
}
// `amd_simple` must return every node exactly once.
#[test]
fn test_amd_simple_valid_permutation() {
let graph = path_graph(8);
let result = amd_simple(&graph).expect("AMD simple");
assert_eq!(result.perm.len(), 8);
let mut sorted = result.perm.clone();
sorted.sort_unstable();
assert_eq!(sorted, (0..8).collect::<Vec<_>>());
}
// The fill estimate must lie between n (diagonal only) and n^2 (dense).
#[test]
fn test_amd_fill_estimate_reasonable() {
let graph = grid_graph(4, 4);
let result = amd(&graph).expect("AMD grid");
assert_eq!(result.perm.len(), 16);
assert!(result.estimated_nnz >= 16);
assert!(result.estimated_nnz <= 256);
}
// An empty graph yields an empty ordering and zero fill.
#[test]
fn test_amd_empty_graph() {
let graph = AdjacencyGraph::from_adjacency_list(Vec::new());
let result = amd(&graph).expect("AMD empty");
assert!(result.perm.is_empty());
assert_eq!(result.estimated_nnz, 0);
}
// An isolated node: identity ordering; fill is just its diagonal entry.
#[test]
fn test_amd_single_node() {
let graph = AdjacencyGraph::from_adjacency_list(vec![Vec::new()]);
let result = amd(&graph).expect("AMD single");
assert_eq!(result.perm, vec![0]);
assert_eq!(result.estimated_nnz, 1); }
// `inv_perm` must be the exact inverse of `perm`.
#[test]
fn test_amd_inverse_perm_consistency() {
let graph = grid_graph(3, 3);
let result = amd(&graph).expect("AMD");
for (new_i, &old_i) in result.perm.iter().enumerate() {
assert_eq!(
result.inv_perm[old_i], new_i,
"inv_perm inconsistency at old_i={}",
old_i
);
}
}
// Star with hub 0 and leaves 1..n-1: the high-degree hub should be ordered
// in the second half.
#[test]
fn test_amd_star_graph() {
let n = 6;
let mut adj = vec![Vec::new(); n];
for i in 1..n {
adj[0].push(i);
adj[i].push(0);
}
let graph = AdjacencyGraph::from_adjacency_list(adj);
let result = amd(&graph).expect("AMD star");
assert_eq!(result.perm.len(), n);
let center_pos = result.inv_perm[0];
assert!(
center_pos >= n / 2,
"center node should be eliminated late, got position {}",
center_pos
);
}
// Both orderings must be valid permutations of the same node set.
#[test]
fn test_amd_simple_vs_quotient_same_size() {
let graph = path_graph(6);
let r1 = amd(&graph).expect("AMD quotient");
let r2 = amd_simple(&graph).expect("AMD simple");
assert_eq!(r1.perm.len(), r2.perm.len());
let mut s1 = r1.perm.clone();
s1.sort_unstable();
assert_eq!(s1, (0..6).collect::<Vec<_>>());
let mut s2 = r2.perm.clone();
s2.sort_unstable();
assert_eq!(s2, (0..6).collect::<Vec<_>>());
}
}