use crate::lu::*;
use crate::lufact::{lucomp, lucopy, ludfs, maxmatch};
/// Computes a sparse LU factorization of the `mrows` x `ncols` matrix `A`,
/// processing one column per step.
///
/// `A` is described in compressed-sparse-column form by `desc_a` (structure)
/// and `a_nz` (values). If `desc_a.base` is 0, the `colptr`/`rowind` arrays of
/// `desc_a` are shifted to 1-based indexing **in place** and `desc_a.base` is
/// set to 1, because every helper called here works with 1-based indices.
///
/// Columns are visited in the order given by `gp.col_perm` (shifted by
/// `gp.col_perm_base`), or in natural order when no permutation is supplied.
/// A maximum matching of A's structure (`maxmatch`) supplies a preferred
/// pivot row for each column. `lucopy` applies the pivoting/dropping policy
/// from `gp` and reports the row it actually chose through `zpivot`.
///
/// # Errors
///
/// * `Err(-1)` – `gp.col_perm` is present but its length differs from `ncols`.
/// * `Err(1)` – the row suggested by the matching is already pivotal.
/// * `Err(-100)` – the symbolic step (`ludfs`) returned a nonzero status.
/// * `Err(jcol)` – `lucopy` reported `zpivot == -1` at (1-based) step `jcol`.
pub fn dgstrf(gp: &GP, mrows: i32, ncols: i32, a_nz: &[f64], desc_a: &mut CSC) -> Result<LU, i32> {
    let nrow: i32 = mrows;
    let ncol: i32 = ncols;
    let pivot_policy: i32 = gp.pivot_policy;
    let pivot_threshold: f64 = gp.pivot_threshold;
    let drop_threshold: f64 = gp.drop_threshold;
    let col_fill_ratio: f64 = gp.col_fill_ratio;
    let fill_ratio: f64 = gp.fill_ratio;
    let expand_ratio: f64 = gp.expand_ratio;
    let user_col_perm: &Option<Vec<i32>> = &gp.col_perm;
    let user_col_perm_base: i32 = gp.col_perm_base;
    log::info!(
        "piv_pol={} piv_thr={} drop_thr={} col_fill_rt={}",
        pivot_policy,
        pivot_threshold,
        drop_threshold,
        col_fill_ratio
    );

    if let Some(col_perm) = user_col_perm {
        if col_perm.len() != ncol as usize {
            return Err(-1);
        }
    }

    let a_n = desc_a.n;
    let a_nnz = desc_a.nnz;
    // The helpers below index A with 1-based offsets; convert in place once.
    if desc_a.base == 0 {
        for jcol in 0..a_n + 1 {
            desc_a.colptr[jcol as usize] += 1;
        }
        for k in 0..a_nnz {
            desc_a.rowind[k as usize] += 1;
        }
        desc_a.base = 1;
    }
    let a_colptr = &mut desc_a.colptr;
    let a_rowind = &mut desc_a.rowind;

    // Dense per-row work arrays shared by ludfs/lucomp/lucopy.
    let mut rwork = vec![0.0; nrow as usize];
    let mut twork = vec![0.0; nrow as usize];
    let mut found = vec![0; nrow as usize];
    let mut child = vec![0; nrow as usize];
    let mut parent = vec![0; nrow as usize];
    let mut pattern = vec![0; nrow as usize];
    let mut lu = LU::new(nrow, ncol, (a_nnz as f64 * fill_ratio) as i32);
    // cmatch[c-1] / rmatch[r-1] hold the 1-based matched row / column,
    // with 0 meaning "unmatched". Freshly allocated, so already zero.
    let mut cmatch = vec![0; ncol as usize];
    let mut rmatch = vec![0; nrow as usize];

    // maxmatch uses these LU arrays as scratch; start from a clean state.
    lu.l_colptr.fill(0);
    lu.u_colptr.fill(0);
    lu.col_perm.fill(0);
    lu.lu_rowind.fill(0);
    lu.row_perm.fill(0);
    maxmatch(
        nrow, ncol, a_colptr, a_rowind, &mut lu.l_colptr, &mut lu.u_colptr, &mut lu.row_perm, &mut lu.col_perm, &mut lu.lu_rowind, &mut rmatch, &mut cmatch, );
    if cmatch.contains(&0) {
        log::warn!("perfect matching not found");
    }

    let mut lastlu = 0;
    let mut flops: f64 = 0.0;
    let mut zpivot: i32 = 0;
    let mut local_pivot_policy = pivot_policy;
    lu.u_colptr[0] = 1;
    // row_perm was scratch for maxmatch; 0 now means "row not yet pivotal".
    lu.row_perm.fill(0);

    if let Some(user_col_perm) = user_col_perm {
        log::info!("user_col_perm_base={}", user_col_perm_base);
        for jcol in 0..ncol {
            lu.col_perm[jcol as usize] = user_col_perm[jcol as usize] + (1 - user_col_perm_base);
        }
    } else {
        for jcol in 0..ncol {
            lu.col_perm[jcol as usize] = jcol + 1;
        }
    }

    for jcol in 1..=ncol {
        // A step can append up to `nrow` entries. Grow by `expand_ratio`,
        // but never to less than that, so a ratio <= 1 or an empty initial
        // estimate cannot leave the buffers too short.
        if lastlu + nrow >= lu.lu_size {
            let new_size: i32 =
                ((lu.lu_size as f64 * expand_ratio) as i32).max(lastlu + nrow + 1);
            log::info!("expanding to {} nonzeros...", new_size);
            lu.lu_nz.resize(new_size as usize, 0.0);
            lu.lu_rowind.resize(new_size as usize, 0);
            lu.lu_size = new_size;
        }

        // Column of A processed at this step, and its 1-based entry range.
        let this_col = lu.col_perm[(jcol - 1) as usize];
        let col_start = a_colptr[(this_col - 1) as usize];
        let col_end = a_colptr[this_col as usize];
        for i in col_start..col_end {
            pattern[(a_rowind[(i - 1) as usize] - 1) as usize] = 1;
        }
        // Row matched to this column; 0 when the matching left it unmatched.
        let orig_row = cmatch[(this_col - 1) as usize];
        if orig_row != 0 {
            pattern[(orig_row - 1) as usize] = 2;
            if lu.row_perm[(orig_row - 1) as usize] != 0 {
                log::error!("pivot row from max-matching already used");
                return Err(1);
            }
        }

        let info = ludfs(
            jcol,
            a_nz,
            a_rowind,
            a_colptr,
            &mut lastlu,
            &mut lu.lu_rowind,
            &mut lu.l_colptr,
            &mut lu.u_colptr,
            &mut lu.row_perm,
            &mut lu.col_perm,
            &mut rwork,
            &mut found,
            &mut parent,
            &mut child,
        );
        if info != 0 {
            return Err(-100);
        }

        // lucomp does not see the running total, so its result is this
        // step's contribution and must be accumulated.
        flops += lucomp(
            jcol,
            &mut lastlu,
            &mut lu.lu_nz,
            &mut lu.lu_rowind,
            &mut lu.l_colptr,
            &mut lu.u_colptr,
            &lu.row_perm,
            &lu.col_perm,
            &mut rwork,
            &mut found,
            &mut pattern,
        );

        if orig_row != 0 && rwork[(orig_row - 1) as usize] == 0.0 {
            use std::fmt::Write as _;
            // Dump the column actually being factored (the permuted one).
            let mut buf = String::from("matching to as zero: ");
            for i in col_start..col_end {
                // Writing into a String cannot fail.
                let _ = write!(
                    buf,
                    "({}, {:?}) ",
                    a_rowind[(i - 1) as usize],
                    a_nz[(i - 1) as usize]
                );
            }
            log::warn!("{}. orig_row={}", buf, orig_row);
        }

        let nz_count_limit = (col_fill_ratio * ((col_end - col_start + 1) as f64)) as i32;
        lucopy(
            local_pivot_policy,
            pivot_threshold,
            drop_threshold,
            nz_count_limit,
            jcol,
            ncol,
            &mut lastlu,
            &mut lu.lu_nz,
            &mut lu.lu_rowind,
            &mut lu.l_colptr,
            &mut lu.u_colptr,
            &mut lu.row_perm,
            &mut lu.col_perm,
            &mut rwork,
            &mut pattern,
            &mut twork,
            &mut flops,
            &mut zpivot,
        );
        if zpivot == -1 {
            return Err(jcol);
        }

        // Clear exactly the marks set above so `pattern` is clean next step.
        for i in col_start..col_end {
            pattern[(a_rowind[(i - 1) as usize] - 1) as usize] = 0;
        }
        if orig_row != 0 {
            pattern[(orig_row - 1) as usize] = 0;
        }

        // Swap the matching so this column is matched to the chosen pivot row.
        // Any 0 sentinel (unmatched row or column) is skipped rather than
        // being used as an index.
        let pivt_row = zpivot;
        if pivt_row > 0 {
            let othr_col = rmatch[(pivt_row - 1) as usize];
            cmatch[(this_col - 1) as usize] = pivt_row;
            if othr_col != 0 {
                cmatch[(othr_col - 1) as usize] = orig_row;
            }
            if orig_row != 0 {
                rmatch[(orig_row - 1) as usize] = othr_col;
            }
            rmatch[(pivt_row - 1) as usize] = this_col;
        }

        // After nrow steps every row has been used; switch policy for any
        // remaining (wide-matrix) columns.
        if jcol == nrow {
            local_pivot_policy = -1;
        }
    }

    // Rows that never became pivotal are numbered after the ncol pivot rows.
    let mut next_row = ncol + 1;
    for p in lu.row_perm.iter_mut().take(nrow as usize) {
        if *p == 0 {
            *p = next_row;
            next_row += 1;
        }
    }
    // Rewrite stored row indices through the final row permutation.
    for i in 0..lastlu as usize {
        lu.lu_rowind[i] = lu.row_perm[(lu.lu_rowind[i] - 1) as usize];
    }

    // Smallest |entry| stored just before each L column start; reported
    // as a diagnostic (presumably U's diagonal — TODO confirm for wide A).
    let min_ujj = (0..ncol as usize)
        .map(|j| lu.lu_nz[(lu.l_colptr[j] - 2) as usize].abs())
        .fold(f64::INFINITY, f64::min);
    log::debug!("min |ujj|: {}", min_ujj);

    log::info!("flops: {}", flops);
    log::info!("nonzeros: {}", lastlu);
    Ok(lu)
}