use crate::error::{Error, Result};
/// Tuning knobs for `colamd`.
#[derive(Debug, Clone)]
pub struct ColamdOptions {
    /// A row is treated as dense, and ignored when scoring columns, when its
    /// number of entries exceeds `ceil(dense_row_threshold * n_cols)`.
    pub dense_row_threshold: f64,
    /// A column is treated as dense, and placed at the end of the ordering,
    /// when its number of stored entries exceeds
    /// `ceil(dense_col_threshold * n_rows)`.
    pub dense_col_threshold: f64,
    /// NOTE(review): `colamd` in this file never reads this flag; statistics
    /// are always filled in. Confirm whether that is intended.
    pub compute_stats: bool,
}

impl Default for ColamdOptions {
    /// Both dense thresholds default to one half; statistics are switched off.
    fn default() -> Self {
        const HALF: f64 = 0.5;
        ColamdOptions {
            compute_stats: false,
            dense_col_threshold: HALF,
            dense_row_threshold: HALF,
        }
    }
}

/// Counters describing how `colamd` treated its input.
#[derive(Clone, Debug, Default)]
pub struct ColamdStats {
    /// Rows classified as dense (see `ColamdOptions::dense_row_threshold`).
    pub n_dense_rows: usize,
    /// Columns classified as dense (see `ColamdOptions::dense_col_threshold`).
    pub n_dense_cols: usize,
    /// Columns with no stored entries at all. Columns that only become empty
    /// once dense rows are discarded are not counted here.
    pub n_empty_cols: usize,
    /// Length of the returned permutation.
    pub n_cols_ordered: usize,
}
/// Sentinel stored in `Column::head` or `Row::degree` to mark a column or row
/// as removed from further consideration.
const DEAD: i32 = -1;

/// Per-column bookkeeping used by `colamd`.
#[derive(Debug, Clone)]
struct Column {
    /// Current entry count of the column; once dense rows have been discarded
    /// it only counts entries that lie in live rows.
    degree: i32,
    /// Offset of the column's first entry in `row_indices` while the column
    /// is live, or `DEAD` once it has been set aside or ordered.
    head: i32,
    /// Number of stored entries in the column.
    length: i32,
    /// Key of the score bucket the column is linked into.
    score: i32,
    /// Previous column in the same score bucket, or -1 at the bucket head.
    prev: i32,
    /// Next column in the same score bucket, or -1 at the bucket tail.
    next: i32,
    /// Position in the output permutation, or -1 while not yet placed.
    order: i32,
}

impl Column {
    /// Creates a blank column.
    ///
    /// `head` starts out equal to `DEAD`, so a fresh column reports
    /// `is_dead()` until `head` is assigned.
    fn new() -> Self {
        Self {
            degree: 0,
            head: DEAD,
            length: 0,
            score: 0,
            prev: -1,
            next: -1,
            order: -1,
        }
    }

    /// Whether the column has been removed from the active set.
    fn is_dead(&self) -> bool {
        self.head == DEAD
    }

    /// Whether the column is still active (the negation of `is_dead`).
    fn is_alive(&self) -> bool {
        !self.is_dead()
    }
}

/// Per-row bookkeeping used by `colamd`.
#[derive(Debug, Clone)]
struct Row {
    /// Number of remaining entries in the row, or `DEAD` once the row has
    /// been classified as dense or has no columns left.
    degree: i32,
}

impl Row {
    /// Creates a live row with no entries counted yet.
    fn new() -> Self {
        Self { degree: 0 }
    }

    /// Whether the row still takes part in column scoring.
    fn is_alive(&self) -> bool {
        self.degree != DEAD
    }
}
/// Computes a fill-reducing column ordering for a sparse matrix stored in
/// compressed sparse column (CSC) form.
///
/// This is a simplified, greedy ordering in the spirit of COLAMD:
///
/// 1. Columns with no stored entries are set aside as *empty*. Columns with
///    more than `ceil(dense_col_threshold * n_rows)` stored entries are set
///    aside as *dense*.
/// 2. Over the remaining columns, rows with more than
///    `ceil(dense_row_threshold * n_cols)` entries are marked dense and
///    ignored from then on.
/// 3. Each remaining column is scored by its number of entries in live rows.
///    Columns whose score is zero join the empty group.
/// 4. The lowest-scored column is repeatedly taken as the pivot. Every live
///    row containing the pivot has its degree decremented (a row whose degree
///    reaches zero is retired), and the live columns sharing those rows are
///    rescored. A column whose score drops to zero is placed immediately.
///
/// Row indices that are negative or `>= n_rows` are skipped when counting row
/// degrees and scores, but they still count toward a column's raw entry count
/// in step 1. `options.compute_stats` is not consulted; the statistics are
/// always filled in.
///
/// # Returns
///
/// `(perm, stats)`, where `perm[k]` is the original index of the column placed
/// at position `k`. `perm` is always a permutation of `0..n_cols`: empty
/// columns first, then columns in pivot order, then dense columns.
///
/// # Errors
///
/// Returns `Error::InvalidArgument` in these cases:
/// - `col_ptrs.len() != n_cols + 1`;
/// - any entry of `col_ptrs` is negative;
/// - `col_ptrs` decreases;
/// - `row_indices` is shorter than `col_ptrs[n_cols]`.
pub fn colamd(
    n_rows: usize,
    n_cols: usize,
    col_ptrs: &[i64],
    row_indices: &[i64],
    options: &ColamdOptions,
) -> Result<(Vec<usize>, ColamdStats)> {
    validate_csc(n_cols, col_ptrs, row_indices)?;
    if n_cols == 0 {
        return Ok((vec![], ColamdStats::default()));
    }

    let mut stats = ColamdStats::default();
    let mut cols: Vec<Column> = vec![Column::new(); n_cols];
    let mut rows: Vec<Row> = vec![Row::new(); n_rows];
    let dense_row_count = (options.dense_row_threshold * n_cols as f64).ceil() as i32;
    let dense_col_count = (options.dense_col_threshold * n_rows as f64).ceil() as i32;
    let mut dense_cols: Vec<usize> = Vec::new();
    let mut empty_cols: Vec<usize> = Vec::new();

    // Step 1: classify columns by their raw stored-entry count.
    for (j, col) in cols.iter_mut().enumerate() {
        let start = col_ptrs[j] as usize;
        // `validate_csc` guarantees col_ptrs is non-decreasing, so no underflow.
        let degree = (col_ptrs[j + 1] as usize - start) as i32;
        if degree == 0 {
            empty_cols.push(j);
            col.head = DEAD;
            stats.n_empty_cols += 1;
        } else if degree > dense_col_count {
            dense_cols.push(j);
            col.head = DEAD;
            stats.n_dense_cols += 1;
        } else {
            col.degree = degree;
            col.length = degree;
            col.head = start as i32;
        }
    }

    // Step 2: count each row's entries over the surviving columns, then
    // retire dense rows.
    for (j, col) in cols.iter().enumerate() {
        if col.is_alive() {
            for i in column_rows(col_ptrs, row_indices, n_rows, j) {
                rows[i].degree += 1;
            }
        }
    }
    for row in rows.iter_mut() {
        if row.degree > dense_row_count {
            row.degree = DEAD;
            stats.n_dense_rows += 1;
        }
    }

    // Step 3: score surviving columns by their entries in live rows.
    for (j, col) in cols.iter_mut().enumerate() {
        if col.is_dead() {
            continue;
        }
        let score = alive_entries(col_ptrs, row_indices, &rows, n_rows, j);
        col.degree = score;
        col.score = score;
        if score == 0 {
            empty_cols.push(j);
            col.head = DEAD;
        }
    }

    // The bucket count comes from the actual scores. Scores never grow later,
    // because rows are only ever retired, so every rescoring stays within this
    // bound. `min(n_rows, n_cols)` is not a valid bound: in a tall matrix, or
    // with duplicate row indices, a column can exceed it. Such a column would
    // sit outside every bucket and never be ordered.
    let max_degree = cols
        .iter()
        .filter(|c| c.is_alive())
        .map(|c| c.score as usize)
        .max()
        .unwrap_or(0);
    let mut degree_head: Vec<i32> = vec![-1; max_degree + 1];
    for j in 0..n_cols {
        if cols[j].is_alive() {
            let d = cols[j].score as usize;
            degree_list_insert(&mut cols, &mut degree_head, j, d);
        }
    }

    // For each live row, the live columns it appears in.
    let mut row_col_adj: Vec<Vec<usize>> = vec![Vec::new(); n_rows];
    for (j, col) in cols.iter().enumerate() {
        if col.is_alive() {
            for i in column_rows(col_ptrs, row_indices, n_rows, j) {
                if rows[i].is_alive() {
                    row_col_adj[i].push(j);
                }
            }
        }
    }

    let mut perm: Vec<usize> = Vec::with_capacity(n_cols);
    for &j in &empty_cols {
        cols[j].order = perm.len() as i32;
        perm.push(j);
    }

    // Step 4: greedy elimination by lowest score.
    let n_to_order = n_cols - dense_cols.len();
    let mut n_ordered = perm.len();
    let mut min_degree = 0usize;
    // Scratch space reused across pivots. `col_seen` is cleared entry by entry
    // after each pivot, so the per-pivot cost does not scale with `n_cols`.
    let mut col_seen = vec![false; n_cols];
    let mut affected_rows: Vec<usize> = Vec::new();
    let mut affected_cols: Vec<usize> = Vec::new();
    while n_ordered < n_to_order {
        while min_degree <= max_degree && degree_head[min_degree] < 0 {
            min_degree += 1;
        }
        if min_degree > max_degree {
            break;
        }

        // Pop the first column of the lowest non-empty bucket as the pivot.
        let pivot_col = degree_head[min_degree] as usize;
        let next = cols[pivot_col].next;
        degree_head[min_degree] = next;
        if next >= 0 {
            cols[next as usize].prev = -1;
        }
        cols[pivot_col].order = perm.len() as i32;
        cols[pivot_col].head = DEAD;
        perm.push(pivot_col);
        n_ordered += 1;

        affected_rows.clear();
        affected_rows.extend(
            column_rows(col_ptrs, row_indices, n_rows, pivot_col).filter(|&i| rows[i].is_alive()),
        );

        // Live columns sharing a row with the pivot, each listed once. The
        // pivot itself is already dead, so the liveness test excludes it.
        affected_cols.clear();
        for &i in &affected_rows {
            for &j in &row_col_adj[i] {
                if !col_seen[j] && cols[j].is_alive() {
                    col_seen[j] = true;
                    affected_cols.push(j);
                }
            }
        }
        for &j in &affected_cols {
            col_seen[j] = false;
        }

        // Detach the pivot from its rows; a row left with no columns retires.
        for &i in &affected_rows {
            row_col_adj[i].retain(|&j| j != pivot_col);
            rows[i].degree -= 1;
            if rows[i].degree <= 0 {
                rows[i].degree = DEAD;
            }
        }

        // Rescore the neighbours and move each into its new bucket.
        for &j in &affected_cols {
            let old_score = cols[j].score as usize;
            if old_score <= max_degree {
                degree_list_remove(&mut cols, &mut degree_head, j, old_score);
            }
            let new_score = alive_entries(col_ptrs, row_indices, &rows, n_rows, j);
            cols[j].degree = new_score;
            cols[j].score = new_score;
            if new_score == 0 {
                cols[j].order = perm.len() as i32;
                cols[j].head = DEAD;
                perm.push(j);
                n_ordered += 1;
            } else {
                let d = new_score as usize;
                if d <= max_degree {
                    degree_list_insert(&mut cols, &mut degree_head, j, d);
                    min_degree = min_degree.min(d);
                }
            }
        }
    }

    // Every live column should have been placed above. Append any stragglers
    // so that `perm` is a permutation of `0..n_cols` regardless.
    for (j, col) in cols.iter_mut().enumerate() {
        if col.is_alive() {
            col.order = perm.len() as i32;
            col.head = DEAD;
            perm.push(j);
        }
    }

    for &j in &dense_cols {
        cols[j].order = perm.len() as i32;
        perm.push(j);
    }
    stats.n_cols_ordered = perm.len();
    Ok((perm, stats))
}

/// Checks that `col_ptrs`/`row_indices` form a well-formed CSC structure with
/// `n_cols` columns, so every `col_ptrs[j]..col_ptrs[j + 1]` range is an
/// in-bounds, non-reversed slice of `row_indices`.
fn validate_csc(n_cols: usize, col_ptrs: &[i64], row_indices: &[i64]) -> Result<()> {
    if col_ptrs.len() != n_cols + 1 {
        return Err(Error::InvalidArgument {
            arg: "col_ptrs",
            reason: format!(
                "length {} does not match n_cols + 1 = {}",
                col_ptrs.len(),
                n_cols + 1
            ),
        });
    }
    if let Some(k) = col_ptrs.iter().position(|&p| p < 0) {
        return Err(Error::InvalidArgument {
            arg: "col_ptrs",
            reason: format!("entry {} is negative ({})", k, col_ptrs[k]),
        });
    }
    if let Some(j) = col_ptrs.windows(2).position(|w| w[1] < w[0]) {
        return Err(Error::InvalidArgument {
            arg: "col_ptrs",
            reason: format!(
                "entries must be non-decreasing, but col_ptrs[{}] = {} > col_ptrs[{}] = {}",
                j,
                col_ptrs[j],
                j + 1,
                col_ptrs[j + 1]
            ),
        });
    }
    let nnz = col_ptrs[n_cols];
    // Compare in u64 so an `nnz` beyond `usize::MAX` cannot wrap on 32-bit targets.
    if (row_indices.len() as u64) < nnz as u64 {
        return Err(Error::InvalidArgument {
            arg: "row_indices",
            reason: format!("length {} is less than nnz = {}", row_indices.len(), nnz),
        });
    }
    Ok(())
}

/// Iterates over the row indices of column `j` that lie in `0..n_rows`.
fn column_rows<'a>(
    col_ptrs: &[i64],
    row_indices: &'a [i64],
    n_rows: usize,
    j: usize,
) -> impl Iterator<Item = usize> + 'a {
    let start = col_ptrs[j] as usize;
    let end = col_ptrs[j + 1] as usize;
    row_indices[start..end].iter().filter_map(move |&r| {
        // Compare in u64 so negative or oversized indices cannot wrap into range.
        if r >= 0 && (r as u64) < n_rows as u64 {
            Some(r as usize)
        } else {
            None
        }
    })
}

/// Counts the entries of column `j` that lie in rows that are still alive.
fn alive_entries(
    col_ptrs: &[i64],
    row_indices: &[i64],
    rows: &[Row],
    n_rows: usize,
    j: usize,
) -> i32 {
    column_rows(col_ptrs, row_indices, n_rows, j)
        .filter(|&i| rows[i].is_alive())
        .count() as i32
}

/// Pushes column `j` onto the front of the doubly linked list for bucket `d`.
fn degree_list_insert(cols: &mut [Column], degree_head: &mut [i32], j: usize, d: usize) {
    let head = degree_head[d];
    cols[j].next = head;
    cols[j].prev = -1;
    if head >= 0 {
        cols[head as usize].prev = j as i32;
    }
    degree_head[d] = j as i32;
}

/// Unlinks column `j` from the doubly linked list for bucket `d`.
fn degree_list_remove(cols: &mut [Column], degree_head: &mut [i32], j: usize, d: usize) {
    let prev = cols[j].prev;
    let next = cols[j].next;
    if prev >= 0 {
        cols[prev as usize].next = next;
    } else if degree_head[d] == j as i32 {
        degree_head[d] = next;
    }
    if next >= 0 {
        cols[next as usize].prev = prev;
    }
}

/// Regression tests for the fixes above. They live next to `colamd` so they
/// ship together with the fix.
#[cfg(test)]
mod colamd_regression_tests {
    use super::*;

    /// More rows than columns: column 0 has three live entries, which exceeds
    /// `min(n_rows, n_cols) = 2`. It must still appear in the permutation.
    #[test]
    fn tall_matrix_orders_every_column() {
        let col_ptrs = vec![0i64, 3, 4];
        let row_indices = vec![0i64, 1, 2, 3];
        let (perm, stats) =
            colamd(5, 2, &col_ptrs, &row_indices, &ColamdOptions::default()).unwrap();
        assert_eq!(perm, vec![1, 0]);
        assert_eq!(stats.n_cols_ordered, 2);
    }

    #[test]
    fn rejects_negative_col_ptr() {
        assert!(colamd(2, 2, &[0, -1, 1], &[0], &ColamdOptions::default()).is_err());
    }

    #[test]
    fn rejects_decreasing_col_ptrs() {
        assert!(colamd(2, 2, &[0, 2, 1], &[0, 1], &ColamdOptions::default()).is_err());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `perm` contains each index in `0..n` exactly once.
    fn assert_is_permutation(perm: &[usize], n: usize) {
        let mut sorted = perm.to_vec();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
    }

    #[test]
    fn test_colamd_empty_matrix() {
        let col_ptrs = vec![0i64];
        let row_indices: Vec<i64> = vec![];
        let (perm, stats) =
            colamd(0, 0, &col_ptrs, &row_indices, &ColamdOptions::default()).unwrap();
        assert!(perm.is_empty());
        assert_eq!(stats.n_cols_ordered, 0);
    }

    #[test]
    fn test_colamd_single_column() {
        let col_ptrs = vec![0i64, 2];
        let row_indices = vec![0i64, 1];
        let (perm, stats) =
            colamd(2, 1, &col_ptrs, &row_indices, &ColamdOptions::default()).unwrap();
        assert_eq!(perm, vec![0]);
        assert_eq!(stats.n_cols_ordered, 1);
    }

    #[test]
    fn test_colamd_diagonal_matrix() {
        let col_ptrs = vec![0i64, 1, 2, 3];
        let row_indices = vec![0i64, 1, 2];
        let (perm, stats) =
            colamd(3, 3, &col_ptrs, &row_indices, &ColamdOptions::default()).unwrap();
        assert_eq!(perm.len(), 3);
        assert_eq!(stats.n_cols_ordered, 3);
        assert_is_permutation(&perm, 3);
    }

    #[test]
    fn test_colamd_tridiagonal_matrix() {
        let col_ptrs = vec![0i64, 2, 4, 6, 7];
        let row_indices = vec![0i64, 1, 1, 2, 2, 3, 3];
        let (perm, stats) =
            colamd(4, 4, &col_ptrs, &row_indices, &ColamdOptions::default()).unwrap();
        assert_eq!(perm.len(), 4);
        assert_eq!(stats.n_cols_ordered, 4);
        assert_is_permutation(&perm, 4);
    }

    #[test]
    fn test_colamd_with_empty_column() {
        // Column 1 has no entries, so it is placed first.
        let col_ptrs = vec![0i64, 2, 2, 4];
        let row_indices = vec![0i64, 1, 0, 1];
        let (perm, stats) =
            colamd(2, 3, &col_ptrs, &row_indices, &ColamdOptions::default()).unwrap();
        assert_eq!(perm.len(), 3);
        assert_eq!(stats.n_empty_cols, 1);
        assert_eq!(perm[0], 1);
    }

    #[test]
    fn test_colamd_arrow_matrix() {
        // Column 0 touches all 4 rows. With a 0.7 threshold the dense cutoff
        // is ceil(0.7 * 4) = 3 entries, so column 0 is dense and goes last.
        let col_ptrs = vec![0i64, 4, 6, 8, 10];
        let row_indices = vec![0i64, 1, 2, 3, 0, 1, 0, 2, 0, 3];
        let options = ColamdOptions {
            dense_col_threshold: 0.7,
            ..Default::default()
        };
        let (perm, stats) = colamd(4, 4, &col_ptrs, &row_indices, &options).unwrap();
        assert_eq!(perm.len(), 4);
        assert_eq!(stats.n_dense_cols, 1);
        assert_eq!(perm[3], 0);
    }

    #[test]
    fn test_colamd_permutation_validity() {
        let col_ptrs = vec![0i64, 3, 5, 8, 10, 12];
        let row_indices = vec![0i64, 2, 4, 1, 3, 0, 2, 4, 1, 3, 0, 4];
        let (perm, _) =
            colamd(5, 5, &col_ptrs, &row_indices, &ColamdOptions::default()).unwrap();
        assert_eq!(perm.len(), 5);
        assert_is_permutation(&perm, 5);
    }

    #[test]
    fn test_colamd_rejects_wrong_col_ptrs_len() {
        let result = colamd(2, 2, &[0, 1], &[0], &ColamdOptions::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_colamd_rejects_short_row_indices() {
        let result = colamd(2, 1, &[0, 3], &[0, 1], &ColamdOptions::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_colamd_ignores_out_of_range_rows() {
        // Row index 5 lies outside 0..2 and is skipped when scoring.
        let options = ColamdOptions {
            dense_col_threshold: 1.0,
            ..Default::default()
        };
        let (perm, stats) = colamd(2, 2, &[0, 2, 3], &[0, 5, 1], &options).unwrap();
        assert_is_permutation(&perm, 2);
        assert_eq!(stats.n_cols_ordered, 2);
    }
}