use crate::error::FeralError;
use crate::sparse::csc::{CscMatrix, CscPattern};
/// Computes an ordering of `0..matrix.n` that keeps the Schur indices last.
///
/// The indices *not* listed in `schur_indices` come first, ordered by AMD run
/// on the symmetric pattern of `matrix` restricted to those indices. The
/// entries of `schur_indices` are then appended unchanged, in the order the
/// caller supplied them.
///
/// When `schur_indices` is empty, AMD is simply run on the full symmetric
/// pattern.
///
/// # Errors
///
/// Returns `FeralError::InvalidInput` when
/// * `schur_indices` is non-empty and has exactly `matrix.n` entries,
/// * an entry is `>= matrix.n`,
/// * an entry appears more than once, or
/// * the AMD step itself reports an error.
pub fn compute_schur_aware_perm(
    matrix: &CscMatrix,
    schur_indices: &[usize],
) -> Result<Vec<usize>, FeralError> {
    let n = matrix.n;
    let n_schur = schur_indices.len();

    // Nothing is pinned to the tail: plain AMD on the whole pattern.
    if n_schur == 0 {
        return run_amd(&matrix.symmetric_pattern());
    }
    if n_schur == n {
        return Err(FeralError::InvalidInput(
            "schur_indices.len() == n is not allowed; elimination set would be empty".to_string(),
        ));
    }

    // Membership mask; filling it doubles as range and duplicate validation.
    let mut in_schur = vec![false; n];
    for &idx in schur_indices {
        match in_schur.get_mut(idx) {
            None => {
                return Err(FeralError::InvalidInput(format!(
                    "schur_indices entry {} out of range for n={}",
                    idx, n
                )));
            }
            Some(seen) => {
                if *seen {
                    return Err(FeralError::InvalidInput(format!(
                        "schur_indices contains duplicate entry {}",
                        idx
                    )));
                }
                *seen = true;
            }
        }
    }

    // `free[local] = original` and `sub_of[original] = local` for every
    // non-Schur index; Schur indices keep the `usize::MAX` marker in `sub_of`.
    let n_f = n - n_schur;
    let free: Vec<usize> = (0..n).filter(|&orig| !in_schur[orig]).collect();
    let mut sub_of = vec![usize::MAX; n];
    for (local, &orig) in free.iter().enumerate() {
        sub_of[orig] = local;
    }

    // Order only the non-Schur part of the pattern.
    let sub_pattern = restrict_pattern_to_subgraph(&matrix.symmetric_pattern(), &sub_of, n_f);
    let sub_perm = run_amd(&sub_pattern)?;

    // Translate the local ordering back to original indices, then append the
    // Schur indices exactly as given.
    let perm: Vec<usize> = sub_perm
        .into_iter()
        .map(|local| free[local])
        .chain(schur_indices.iter().copied())
        .collect();
    debug_assert_eq!(perm.len(), n);
    Ok(perm)
}
/// Builds the sub-pattern of `full` made of the entries whose row and column
/// are both kept by `sub_of`.
///
/// `sub_of[i]` is the local index assigned to original index `i`, or
/// `usize::MAX` when `i` is dropped. `n_f` is the dimension of the returned
/// pattern. Row indices inside every output column are sorted ascending.
fn restrict_pattern_to_subgraph(full: &CscPattern, sub_of: &[usize], n_f: usize) -> CscPattern {
    // Calls `visit(local_col, local_row)` once per entry of `full` whose row
    // and column both survive the restriction, in column-major order.
    fn visit_kept_entries(
        full: &CscPattern,
        sub_of: &[usize],
        mut visit: impl FnMut(usize, usize),
    ) {
        for col in 0..full.n {
            let local_col = sub_of[col];
            if local_col == usize::MAX {
                continue;
            }
            let kept_rows = (full.col_ptr[col]..full.col_ptr[col + 1])
                .map(|k| sub_of[full.row_idx[k]])
                .filter(|&local_row| local_row != usize::MAX);
            for local_row in kept_rows {
                visit(local_col, local_row);
            }
        }
    }

    // Pass 1: number of surviving entries per local column.
    let mut per_column = vec![0usize; n_f];
    visit_kept_entries(full, sub_of, |local_col, _| per_column[local_col] += 1);

    // Prefix sums turn the counts into column pointers.
    let mut col_ptr = Vec::with_capacity(n_f + 1);
    let mut running = 0usize;
    col_ptr.push(running);
    for &count in &per_column {
        running += count;
        col_ptr.push(running);
    }

    // Pass 2: scatter the local row indices into their column slots.
    let mut row_idx = vec![0usize; running];
    let mut cursor = col_ptr[..n_f].to_vec();
    visit_kept_entries(full, sub_of, |local_col, local_row| {
        row_idx[cursor[local_col]] = local_row;
        cursor[local_col] += 1;
    });

    // The local numbering need not follow the original row order, so sort
    // each column's rows.
    for bounds in col_ptr.windows(2) {
        row_idx[bounds[0]..bounds[1]].sort_unstable();
    }

    CscPattern {
        n: n_f,
        col_ptr,
        row_idx,
    }
}
/// Runs `feral_amd::amd_order` on `pattern` and returns its result as
/// `usize` indices.
///
/// A pattern with `n == 0` yields an empty vector without calling AMD.
///
/// # Errors
///
/// Returns `FeralError::InvalidInput` when
/// * a column pointer or row index does not fit in `i32`,
/// * `feral_ordering_core::CscPattern::new` rejects the converted buffers,
/// * `feral_amd::amd_order` fails, or
/// * a returned index is negative or `>= pattern.n`.
fn run_amd(pattern: &CscPattern) -> Result<Vec<usize>, FeralError> {
    // Narrows a `usize` index buffer to the `i32` buffer handed to AMD.
    fn to_i32_buffer(indices: &[usize]) -> Result<Vec<i32>, FeralError> {
        indices
            .iter()
            .map(|&value| i32::try_from(value))
            .collect::<Result<Vec<i32>, _>>()
            .map_err(|_| {
                FeralError::InvalidInput("matrix too large for i32-indexed AMD".to_string())
            })
    }

    if pattern.n == 0 {
        return Ok(Vec::new());
    }

    let col_buf = to_i32_buffer(&pattern.col_ptr)?;
    let row_buf = to_i32_buffer(&pattern.row_idx)?;

    let pat = feral_ordering_core::CscPattern::new(pattern.n, &col_buf, &row_buf)
        .ok_or_else(|| FeralError::InvalidInput("malformed CSC pattern".to_string()))?;
    let raw_perm = feral_amd::amd_order(&pat)
        .map_err(|e| FeralError::InvalidInput(format!("AMD failed: {}", e)))?;

    // Convert back to `usize`, rejecting anything outside `0..pattern.n`.
    raw_perm
        .into_iter()
        .map(|raw| {
            usize::try_from(raw)
                .map_err(|_| FeralError::InvalidInput("AMD returned negative index".to_string()))
                .and_then(|idx| {
                    if idx < pattern.n {
                        Ok(idx)
                    } else {
                        Err(FeralError::InvalidInput(
                            "AMD returned out-of-range index".to_string(),
                        ))
                    }
                })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // 6x6 fixture given as lower-triangle triplets: a unit diagonal for
    // indices 0..=3, plus rows 4 and 5, each of which has an entry in every
    // column up to and including its own diagonal.
    fn small_kkt() -> CscMatrix {
        let rows = vec![
            0, 1, 2, 3, // diagonal of the leading block
            4, 4, 4, 4, 4, // row 4
            5, 5, 5, 5, 5, 5, // row 5
        ];
        let cols = vec![
            0, 1, 2, 3, //
            0, 1, 2, 3, 4, //
            0, 1, 2, 3, 4, 5,
        ];
        let vals = vec![
            1.0, 1.0, 1.0, 1.0, //
            0.5, 0.5, 0.5, 0.5, 2.0, //
            0.3, 0.3, 0.3, 0.3, 0.7, 3.0,
        ];
        CscMatrix::from_triplets(6, &rows, &cols, &vals).unwrap()
    }

    // Returns `v` sorted ascending, for order-insensitive comparisons.
    fn sorted(mut v: Vec<usize>) -> Vec<usize> {
        v.sort();
        v
    }

    #[test]
    fn perm_places_schur_tail_at_end_in_user_order() {
        let perm = compute_schur_aware_perm(&small_kkt(), &[5, 4]).unwrap();
        assert_eq!(perm.len(), 6);
        // The tail must follow the caller's order (5 before 4), not index order.
        assert_eq!(&perm[4..], &[5, 4]);
        assert_eq!(sorted(perm[..4].to_vec()), vec![0, 1, 2, 3]);
    }

    #[test]
    fn perm_is_a_valid_permutation() {
        let perm = compute_schur_aware_perm(&small_kkt(), &[4, 5]).unwrap();
        assert_eq!(sorted(perm), vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn empty_schur_falls_back_to_full_amd() {
        let perm = compute_schur_aware_perm(&small_kkt(), &[]).unwrap();
        assert_eq!(perm.len(), 6);
        assert_eq!(sorted(perm), vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn duplicate_schur_indices_rejected() {
        let outcome = compute_schur_aware_perm(&small_kkt(), &[4, 4]);
        assert!(matches!(outcome, Err(FeralError::InvalidInput(_))));
    }

    #[test]
    fn out_of_range_schur_index_rejected() {
        let outcome = compute_schur_aware_perm(&small_kkt(), &[6]);
        assert!(matches!(outcome, Err(FeralError::InvalidInput(_))));
    }

    #[test]
    fn full_schur_rejected() {
        let outcome = compute_schur_aware_perm(&small_kkt(), &[0, 1, 2, 3, 4, 5]);
        assert!(matches!(outcome, Err(FeralError::InvalidInput(_))));
    }

    #[test]
    fn n_zero_with_empty_schur_returns_empty() {
        let empty = CscMatrix::from_triplets(0, &[], &[], &[]).unwrap();
        let perm = compute_schur_aware_perm(&empty, &[]).unwrap();
        assert!(perm.is_empty());
    }

    #[test]
    fn schur_size_one_works() {
        let perm = compute_schur_aware_perm(&small_kkt(), &[3]).unwrap();
        assert_eq!(perm.len(), 6);
        assert_eq!(perm[5], 3);
        assert_eq!(sorted(perm[..5].to_vec()), vec![0, 1, 2, 4, 5]);
    }

    #[test]
    fn restrict_pattern_drops_schur_edges() {
        let full = small_kkt().symmetric_pattern();
        // Keep indices 0..=3 under their own numbers; drop 4 and 5.
        let sub_of = vec![0, 1, 2, 3, usize::MAX, usize::MAX];
        let sub = restrict_pattern_to_subgraph(&full, &sub_of, 4);
        assert_eq!(sub.n, 4);
        // Only the diagonal entry of each kept column is expected to remain.
        for j in 0..4 {
            let first = sub.col_ptr[j];
            let nnz_j = sub.col_ptr[j + 1] - first;
            assert_eq!(nnz_j, 1, "column {} expected 1 entry, got {}", j, nnz_j);
            assert_eq!(sub.row_idx[first], j);
        }
    }
}