use crate::algorithm::sparse_linalg::ordering::{ColamdOptions, colamd};
use crate::error::Result;
use super::types::{QrOptions, QrOrdering, QrSymbolic};
/// Runs the symbolic (pattern-only) analysis for a sparse QR factorization.
///
/// The matrix pattern is read in compressed-column form: the row indices of
/// column `j` are `row_indices[col_ptrs[j]..col_ptrs[j + 1]]`.
///
/// Steps performed:
/// 1. choose a column permutation according to `options.ordering`
///    (identity, or the result of [`colamd`] with default options);
/// 2. apply that permutation to the columns of the pattern;
/// 3. build the elimination tree of the permuted pattern;
/// 4. derive a per-column count for `R` and sum it into `predicted_r_nnz`.
///
/// # Arguments
/// * `col_ptrs` - column pointers, indexed up to `n` inclusive.
/// * `row_indices` - row index of every stored entry.
/// * `m` - number of rows.
/// * `n` - number of columns.
/// * `options` - only `options.ordering` is consulted here.
///
/// # Errors
/// Propagates any error returned by [`colamd`].
///
/// # Panics
/// The helpers index `col_ptrs` and `row_indices` directly, so malformed
/// input (too-short `col_ptrs`, out-of-range or negative indices) may panic.
pub fn sparse_qr_symbolic(
    col_ptrs: &[i64],
    row_indices: &[i64],
    m: usize,
    n: usize,
    options: &QrOptions,
) -> Result<QrSymbolic> {
    // Step 1: fill-reducing (or trivial) column ordering.
    let col_perm: Vec<usize> = match options.ordering {
        QrOrdering::Identity => (0..n).collect(),
        // The statistics returned alongside the permutation are not needed.
        QrOrdering::Colamd => colamd(m, n, col_ptrs, row_indices, &ColamdOptions::default())?.0,
    };

    // Step 2: pattern of A with its columns reordered by `col_perm`.
    let (permuted_ptrs, permuted_rows) = permute_columns(col_ptrs, row_indices, n, &col_perm);

    // Steps 3 and 4 both operate on the permuted pattern.
    let etree = compute_etree_ata(&permuted_ptrs, &permuted_rows, m, n);
    let r_col_counts = compute_r_col_counts(&permuted_ptrs, &permuted_rows, &etree, m, n);
    let predicted_r_nnz = r_col_counts.iter().copied().sum::<usize>();

    Ok(QrSymbolic {
        m,
        n,
        etree,
        r_col_counts,
        col_perm,
        predicted_r_nnz,
    })
}
/// Builds the compressed-column pattern obtained by reordering columns.
///
/// New column `j` is a copy of old column `perm[j]`; the row indices inside
/// each column keep their original order.
///
/// Returns `(new_col_ptrs, new_row_indices)`, where `new_col_ptrs` has
/// `n + 1` entries and starts at `0`.
///
/// Panics if `perm` has fewer than `n` entries or if an indexed column range
/// falls outside `col_ptrs` / `row_indices`.
fn permute_columns(
    col_ptrs: &[i64],
    row_indices: &[i64],
    n: usize,
    perm: &[usize],
) -> (Vec<i64>, Vec<i64>) {
    let mut permuted_ptrs: Vec<i64> = Vec::with_capacity(n + 1);
    permuted_ptrs.push(0);
    // A permutation keeps the total entry count, so this is a good size hint.
    let mut permuted_rows: Vec<i64> = Vec::with_capacity(row_indices.len());

    for &source_col in &perm[..n] {
        let lo = col_ptrs[source_col] as usize;
        let hi = col_ptrs[source_col + 1] as usize;
        permuted_rows.extend_from_slice(&row_indices[lo..hi]);
        // The running length is exactly the end offset of the column just copied.
        permuted_ptrs.push(permuted_rows.len() as i64);
    }

    (permuted_ptrs, permuted_rows)
}
/// Computes the column elimination tree from the pattern of `A` alone.
///
/// Two columns are linked whenever they share a row: for every row, the first
/// column that contains it is remembered, and each later column containing the
/// same row becomes the parent of the current root of that first column's
/// subtree.
///
/// Returns a vector of length `n` where entry `j` is the parent column of `j`,
/// or `-1` when `j` is a root.
///
/// Panics if a row index is negative or `>= m`, or if `col_ptrs` does not
/// describe valid ranges into `row_indices`.
fn compute_etree_ata(col_ptrs: &[i64], row_indices: &[i64], m: usize, n: usize) -> Vec<i64> {
    /// Returns the representative of `start`'s set and points every node on
    /// the walked path directly at it.
    fn find_and_compress(ancestor: &mut [usize], start: usize) -> usize {
        let mut root = start;
        while ancestor[root] != root {
            root = ancestor[root];
        }
        let mut cursor = start;
        while cursor != root {
            // Redirect `cursor` to the root and continue with its old ancestor.
            cursor = std::mem::replace(&mut ancestor[cursor], root);
        }
        root
    }

    let mut parent = vec![-1i64; n];
    // Every column starts as the root of its own singleton set.
    let mut ancestor: Vec<usize> = (0..n).collect();
    // First column in which each row has been seen so far.
    let mut first_col: Vec<Option<usize>> = vec![None; m];

    for j in 0..n {
        let lo = col_ptrs[j] as usize;
        let hi = col_ptrs[j + 1] as usize;
        for &raw_row in &row_indices[lo..hi] {
            let row = raw_row as usize;
            match first_col[row] {
                None => first_col[row] = Some(j),
                Some(k) => {
                    let root = find_and_compress(&mut ancestor, k);
                    // Attach the subtree under column `j` unless it is already there.
                    if root != j {
                        parent[root] = j as i64;
                        ancestor[root] = j;
                    }
                }
            }
        }
    }

    parent
}
/// Produces a per-column entry count for `R` from the column pattern and the
/// elimination tree.
///
/// NOTE(review): this appears to be a heuristic estimate rather than an exact
/// symbolic count — confirm how callers use it.
///
/// The computation is:
/// 1. start each column at its stored entry count, capped at `m`;
/// 2. visit columns in increasing order and raise each parent's count to at
///    least `max(count[child], 1)`;
/// 3. finally force every count to be at least `1`.
///
/// `_row_indices` is accepted for signature symmetry but not read. Tree
/// entries that are negative or `>= n` are treated as "no parent".
fn compute_r_col_counts(
    col_ptrs: &[i64],
    _row_indices: &[i64],
    etree: &[i64],
    m: usize,
    n: usize,
) -> Vec<usize> {
    // Step 1: stored entries per column, never more than the row count.
    let mut counts: Vec<usize> = (0..n)
        .map(|col| {
            let lo = col_ptrs[col] as usize;
            let hi = col_ptrs[col + 1] as usize;
            (hi - lo).min(m)
        })
        .collect();

    // Step 2: push counts up the tree, children before parents by index.
    for child in 0..n {
        let link = etree[child];
        if link < 0 {
            continue;
        }
        let parent = link as usize;
        if parent >= n {
            continue;
        }
        // (count - 1, floored at 0) off-diagonal entries plus one diagonal.
        let inherited = counts[child].max(1);
        if counts[parent] < inherited {
            counts[parent] = inherited;
        }
    }

    // Step 3: every column holds at least one entry.
    for slot in counts.iter_mut() {
        if *slot == 0 {
            *slot = 1;
        }
    }

    counts
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x3 pattern with a single entry per column on the diagonal; without
    /// reordering the reported permutation must be the identity.
    #[test]
    fn test_symbolic_identity_ordering() {
        let col_ptrs = [0i64, 1, 2, 3];
        let row_indices = [0i64, 1, 2];
        let opts = QrOptions::no_ordering();

        let sym = sparse_qr_symbolic(&col_ptrs, &row_indices, 3, 3, &opts).unwrap();

        assert_eq!(sym.m, 3);
        assert_eq!(sym.n, 3);
        assert_eq!(sym.col_perm, vec![0, 1, 2]);
    }

    /// 4x4 banded pattern (diagonal plus first sub-diagonal); every column of
    /// `R` must be reported with at least one entry.
    #[test]
    fn test_symbolic_tridiagonal() {
        let col_ptrs = [0i64, 2, 4, 6, 7];
        let row_indices = [0i64, 1, 1, 2, 2, 3, 3];
        let opts = QrOptions::no_ordering();

        let sym = sparse_qr_symbolic(&col_ptrs, &row_indices, 4, 4, &opts).unwrap();

        assert_eq!(sym.m, 4);
        assert_eq!(sym.n, 4);
        assert!(sym.r_col_counts.iter().all(|&count| count >= 1));
    }

    /// 4x3 pattern analysed with the default options (presumably COLAMD
    /// ordering — confirm against `QrOptions::default`); the result must be a
    /// valid permutation of the three columns.
    #[test]
    fn test_symbolic_with_colamd() {
        let col_ptrs = [0i64, 3, 5, 7];
        let row_indices = [0i64, 1, 2, 1, 3, 0, 3];
        let opts = QrOptions::default();

        let sym = sparse_qr_symbolic(&col_ptrs, &row_indices, 4, 3, &opts).unwrap();

        assert_eq!(sym.m, 4);
        assert_eq!(sym.n, 3);
        assert_eq!(sym.col_perm.len(), 3);

        // Sorting a permutation of 0..3 must give exactly 0, 1, 2.
        let mut seen = sym.col_perm.clone();
        seen.sort_unstable();
        assert_eq!(seen, vec![0, 1, 2]);
    }

    /// Consecutive columns share one row each, so the tree is the chain
    /// 0 -> 1 -> 2 with column 2 as the root.
    #[test]
    fn test_etree_chain() {
        let col_ptrs = [0i64, 2, 4, 6];
        let row_indices = [0i64, 1, 1, 2, 2, 3];

        let tree = compute_etree_ata(&col_ptrs, &row_indices, 4, 3);

        assert_eq!(tree[0], 1);
        assert_eq!(tree[1], 2);
        assert_eq!(tree[2], -1);
    }
}