use std::time::Instant;
use faer::dyn_stack::{MemBuffer, MemStack};
use faer::sparse::linalg::amd;
use faer::sparse::SymbolicSparseColMatRef;
/// Computes an approximate minimum degree (AMD) fill-reducing ordering for an
/// `n x n` sparsity pattern stored in compressed sparse column (CSC) form.
///
/// `col_ptr[j]..col_ptr[j + 1]` indexes the row indices of column `j` inside
/// `row_ind`. Only the first `n + 1` entries of `col_ptr` are used. The
/// ordering itself is delegated to faer's `amd::order` with default control
/// parameters.
///
/// Returns `perm`, where `perm[k]` is the original index placed at position `k`.
/// The identity ordering `0..n` is returned instead when:
/// - `deadline` has already passed on entry (the deadline is checked only once,
///   up front; a running AMD computation is never interrupted), or
/// - faer's `amd::order` reports an error.
///
/// An empty vector is returned for `n == 0`.
///
/// # Panics
///
/// Panics if the CSC structure is malformed:
/// - `col_ptr` has fewer than `n + 1` entries,
/// - `col_ptr` is not non-decreasing,
/// - `col_ptr[n] > row_ind.len()`, or
/// - any referenced row index is `>= n`.
///
/// Without these checks, malformed input would reach `new_unchecked` and cause
/// undefined behaviour.
pub fn amd_with_deadline(
    n: usize,
    col_ptr: &[usize],
    row_ind: &[usize],
    deadline: Option<Instant>,
) -> Vec<usize> {
    if n == 0 {
        return vec![];
    }
    if let Some(dl) = deadline {
        if Instant::now() >= dl {
            return (0..n).collect();
        }
    }

    // Establish every structural invariant that `new_unchecked` relies on, so that
    // bad input from a caller of this safe function panics rather than causing UB.
    assert!(
        col_ptr.len() > n,
        "col_ptr has length {}, expected at least n + 1 = {}",
        col_ptr.len(),
        n + 1
    );
    // faer requires exactly ncols + 1 column pointers; ignore any trailing extras.
    let col_ptr = &col_ptr[..=n];
    assert!(
        col_ptr.windows(2).all(|w| w[0] <= w[1]),
        "col_ptr must be non-decreasing"
    );
    let nnz = col_ptr[n];
    assert!(
        nnz <= row_ind.len(),
        "col_ptr[n] = {} exceeds row_ind length {}",
        nnz,
        row_ind.len()
    );
    assert!(
        row_ind[col_ptr[0]..nnz].iter().all(|&i| i < n),
        "row index out of bounds for a {} x {} pattern",
        n,
        n
    );

    let mut perm = vec![0usize; n];
    let mut perm_inv = vec![0usize; n];
    // SAFETY: the assertions above guarantee that `col_ptr` has exactly n + 1
    // non-decreasing entries, that `col_ptr[n] <= row_ind.len()`, and that every
    // row index in `row_ind[col_ptr[0]..col_ptr[n]]` is < n (= nrows).
    let a =
        unsafe { SymbolicSparseColMatRef::<usize>::new_unchecked(n, n, col_ptr, None, row_ind) };
    let req = amd::order_scratch::<usize>(n, nnz);
    let mut mem = MemBuffer::new(req);
    let stack = MemStack::new(&mut mem);
    if amd::order(&mut perm, &mut perm_inv, a, amd::Control::default(), stack).is_err() {
        // Degrade to the natural ordering rather than failing the caller.
        return (0..n).collect();
    }
    perm
}
/// Builds the inverse of a permutation: if `perm[k] == i`, the result maps `i` back to `k`.
///
/// Panics (index out of bounds) if any entry of `perm` is `>= perm.len()`.
fn inverse_perm(perm: &[usize]) -> Vec<usize> {
    let mut inverse = vec![0usize; perm.len()];
    perm.iter()
        .enumerate()
        .for_each(|(new_pos, &old_idx)| inverse[old_idx] = new_pos);
    inverse
}
/// Symmetrically permutes an `n x n` matrix stored as CSC triplet arrays
/// (`col_ptr`, `row_ind`, `values`) and returns the result as upper-triangular CSC.
///
/// `perm[k]` is the original index that ends up at position `k`, so an original
/// entry `(i, j)` moves to `(inv[i], inv[j])`, where `inv` is the inverse permutation.
/// Any entry that lands below the diagonal is reflected to the upper triangle.
/// In the output, entries are ordered by column and then by row, so row indices
/// are sorted within each column.
///
/// Duplicate positions are not merged. For example, if the input stores both
/// `(i, j)` and `(j, i)`, both copies appear in the output.
pub fn permute_sym_upper(
    n: usize,
    col_ptr: &[usize],
    row_ind: &[usize],
    values: &[f64],
    perm: &[usize],
) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
    let old_to_new = inverse_perm(perm);

    // Relocate every stored entry, folding it into the upper triangle as (row, col, value).
    let mut triplets: Vec<(usize, usize, f64)> = (0..n)
        .flat_map(|old_col| {
            let new_col = old_to_new[old_col];
            (col_ptr[old_col]..col_ptr[old_col + 1]).map(move |k| (k, new_col))
        })
        .map(|(k, new_col)| {
            let new_row = old_to_new[row_ind[k]];
            (new_row.min(new_col), new_row.max(new_col), values[k])
        })
        .collect();

    // Column-major order: primary key column, secondary key row.
    triplets.sort_unstable_by_key(|&(row, col, _)| (col, row));

    // Count entries per column into slot col + 1, then take a running sum.
    let mut new_col_ptr = vec![0usize; n + 1];
    for &(_, col, _) in &triplets {
        new_col_ptr[col + 1] += 1;
    }
    let mut running = 0usize;
    for slot in new_col_ptr.iter_mut() {
        running += *slot;
        *slot = running;
    }

    let (new_row_ind, new_values): (Vec<usize>, Vec<f64>) =
        triplets.into_iter().map(|(row, _, val)| (row, val)).unzip();

    (new_col_ptr, new_row_ind, new_values)
}
/// Gathers `v` through `perm`: position `k` of the result holds `v[perm[k]]`.
///
/// The result has `perm.len()` elements. Panics if any `perm[k] >= v.len()`.
pub fn permute_vec(v: &[f64], perm: &[usize]) -> Vec<f64> {
    let mut gathered = Vec::with_capacity(perm.len());
    for &src in perm {
        gathered.push(v[src]);
    }
    gathered
}
/// Scatters `v` back through `perm`: the result at `perm[k]` receives `v[k]`.
///
/// This undoes [`permute_vec`] for a valid permutation. The result has `v.len()`
/// elements; slots not targeted by `perm` stay `0.0`. Panics if `perm` is longer
/// than `v` or if any `perm[k] >= v.len()`.
pub fn inv_permute_vec(v: &[f64], perm: &[usize]) -> Vec<f64> {
    let mut scattered = vec![0.0f64; v.len()];
    perm.iter()
        .enumerate()
        .for_each(|(src, &dst)| scattered[dst] = v[src]);
    scattered
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single node with no off-diagonal entries has only one ordering.
    #[test]
    fn test_amd_n1() {
        let col_ptr = vec![0usize, 0];
        let row_ind: Vec<usize> = vec![];
        let perm = amd_with_deadline(1, &col_ptr, &row_ind, None);
        assert_eq!(perm, vec![0]);
    }

    /// Star graph stored as an upper triangle: node 0 is adjacent to nodes 1..=4.
    /// A minimum-degree ordering should eliminate a leaf before the hub.
    #[test]
    fn test_amd_star_graph() {
        let n = 5;
        let col_ptr = vec![0, 0, 1, 2, 3, 4];
        let row_ind = vec![0, 0, 0, 0];
        let perm = amd_with_deadline(n, &col_ptr, &row_ind, None);
        assert_eq!(perm.len(), n);
        assert_ne!(perm[0], 0, "Central node 0 should not be eliminated first");
        assert!(
            perm[0] >= 1 && perm[0] <= 4,
            "First eliminated node should be a leaf"
        );
        let mut check = perm.clone();
        check.sort_unstable();
        assert_eq!(check, (0..n).collect::<Vec<_>>());
    }

    /// Path graph 0-1-2-3: the result must be a permutation of 0..4.
    #[test]
    fn test_amd_returns_valid_permutation() {
        let n = 4;
        let col_ptr = vec![0, 0, 1, 2, 3];
        let row_ind = vec![0, 1, 2];
        let perm = amd_with_deadline(n, &col_ptr, &row_ind, None);
        assert_eq!(perm.len(), n);
        let mut check = perm.clone();
        check.sort_unstable();
        assert_eq!(check, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_amd_empty() {
        let perm = amd_with_deadline(0, &[0], &[], None);
        assert!(perm.is_empty());
    }

    /// A deadline that has already passed must short-circuit to the identity ordering.
    #[test]
    fn test_amd_expired_deadline_returns_identity() {
        let col_ptr = vec![0usize, 0, 1, 2];
        let row_ind = vec![0usize, 1];
        // `Instant` is monotonic, so by the time the function checks, now() >= deadline.
        let deadline = std::time::Instant::now();
        let perm = amd_with_deadline(3, &col_ptr, &row_ind, Some(deadline));
        assert_eq!(perm, vec![0, 1, 2]);
    }

    #[test]
    fn test_inverse_perm() {
        let perm = vec![2, 0, 3, 1];
        let inv = inverse_perm(&perm);
        assert_eq!(inv, vec![1, 3, 0, 2]);
    }

    /// A = [[4, 1], [1, 3]] with the two indices swapped becomes [[3, 1], [1, 4]].
    #[test]
    fn test_permute_sym_upper_swap() {
        let n = 2;
        let col_ptr = vec![0, 1, 3];
        let row_ind = vec![0, 0, 1];
        let values = vec![4.0, 1.0, 3.0];
        let perm = vec![1, 0];
        let (new_col_ptr, new_row_ind, new_values) =
            permute_sym_upper(n, &col_ptr, &row_ind, &values, &perm);
        assert_eq!(new_col_ptr, vec![0, 1, 3]);
        assert_eq!(new_row_ind, vec![0, 0, 1]);
        let eps = 1e-14;
        assert!(
            (new_values[0] - 3.0).abs() < eps,
            "A_p[0,0]={}",
            new_values[0]
        );
        assert!(
            (new_values[1] - 1.0).abs() < eps,
            "A_p[0,1]={}",
            new_values[1]
        );
        assert!(
            (new_values[2] - 4.0).abs() < eps,
            "A_p[1,1]={}",
            new_values[2]
        );
    }

    /// A = [[1, 2, 0], [2, 3, 4], [0, 4, 5]] with perm = [2, 0, 1] gives
    /// A_p = [[5, 0, 4], [0, 1, 2], [4, 2, 3]]. Entry (1, 2) of A lands at
    /// (2, 0), below the diagonal, and must be folded to (0, 2).
    #[test]
    fn test_permute_sym_upper_3x3_folds_lower_entries() {
        let n = 3;
        let col_ptr = vec![0, 1, 3, 5];
        let row_ind = vec![0, 0, 1, 1, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let perm = vec![2, 0, 1];
        let (new_col_ptr, new_row_ind, new_values) =
            permute_sym_upper(n, &col_ptr, &row_ind, &values, &perm);
        assert_eq!(new_col_ptr, vec![0, 1, 2, 5]);
        assert_eq!(new_row_ind, vec![0, 1, 0, 1, 2]);
        assert_eq!(new_values, vec![5.0, 1.0, 4.0, 2.0, 3.0]);
    }

    #[test]
    fn test_permute_vec() {
        let v = vec![10.0, 20.0, 30.0, 40.0];
        let perm = vec![2, 0, 3, 1];
        let result = permute_vec(&v, &perm);
        assert_eq!(result, vec![30.0, 10.0, 40.0, 20.0]);
    }

    #[test]
    fn test_inv_permute_vec() {
        let v = vec![30.0, 10.0, 40.0, 20.0];
        let perm = vec![2, 0, 3, 1];
        let result = inv_permute_vec(&v, &perm);
        assert_eq!(result, vec![10.0, 20.0, 30.0, 40.0]);
    }

    /// `inv_permute_vec` must undo `permute_vec` for any valid permutation.
    #[test]
    fn test_permute_inv_permute_roundtrip() {
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let perm = vec![3, 0, 4, 2, 1];
        let pv = permute_vec(&v, &perm);
        let recovered = inv_permute_vec(&pv, &perm);
        for (a, b) in v.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 1e-14, "roundtrip failed: {} != {}", a, b);
        }
    }
}