use crate::lu::dfs::dfs;
use crate::lu::file::{file_compress, file_reappend};
use crate::lu::garbage_perm::garbage_perm;
use crate::lu::list::list_swap;
use crate::lu::lu::*;
use crate::LUInt;
use crate::Status;
use std::time::Instant;
/// Sentinel written into `u_index` for erased slots and as the separator that
/// terminates each list in the packed U storage (see `compress_packed`).
const GAP: LUInt = -1;

/// Self-inverse encoding `x -> -x - 1`.
///
/// Maps nonnegative values onto negative ones (0 -> -1, 1 -> -2, ...) and back, so a
/// node index can be stored in a slot whose sign doubles as a "visited" flag.
macro_rules! flip {
    ($value:expr) => {
        -($value) - 1
    };
}
/// Locates the value `j` in `index`, scanning forward from position `start`.
///
/// * `end >= 0`: the scan covers `index[start..end]`; returns the position of the
///   first match, or `end` when `j` does not occur in that range.
/// * `end < 0`: the scan is bounded by the first negative entry of `index` instead
///   (a terminator must exist); returns the match position, or `end` itself (i.e. a
///   negative value) when `j` is absent.
fn find(j: usize, index: &[LUInt], mut start: LUInt, end: LUInt) -> LUInt {
    let target = j as LUInt;
    if end < 0 {
        // Open-ended list: stop at the match or at the negative terminator.
        loop {
            let entry = index[start as usize];
            if entry == target {
                return start;
            }
            if entry < 0 {
                return end;
            }
            start += 1;
        }
    }
    // Bounded range: `start == end` signals "not found".
    while start < end && index[start as usize] != target {
        start += 1;
    }
    start
}
/// Breadth-first search for a cycle through `j0` in the graph whose node `j` has
/// edges to `index[begin[j]..end[j]]`.
///
/// If some node `j` reachable from `j0` has an edge back to `j0`, the path
/// `j0 -> ... -> j` is written to `jlist[top..m]` (with `jlist[top] == j0`) and `top`
/// is returned. Otherwise `m` is returned and `jlist` is left untouched.
///
/// `marked` entries are read as "unvisited" when nonnegative. During the search the
/// parent of each newly queued node `k` is stored as `marked[k] = flip!(parent)`
/// (negative). Before returning, `marked` is reset to 0 for every node that entered
/// the queue, including `j0`. `queue` must have room for `m` entries.
fn bfs_path(
    m: usize,
    j0: usize,
    begin: &[LUInt],
    end: &[LUInt],
    index: &[LUInt],
    jlist: &mut [LUInt],
    marked: &mut [LUInt],
    queue: &mut [LUInt],
) -> usize {
    let mut j: LUInt = -1;
    let mut tail: usize = 1;
    let mut top = m;
    let mut found = false;
    queue[0] = j0 as LUInt;

    // The queue grows while it is being scanned, so the bound `tail` must be re-read
    // on every iteration. A `for front in 0..tail` range would be fixed at tail == 1
    // and only ever expand `j0` itself.
    let mut front = 0;
    while front < tail && !found {
        j = queue[front];
        for pos in begin[j as usize]..end[j as usize] {
            let k = index[pos as usize];
            if k == j0 as LUInt {
                found = true;
                break;
            }
            if marked[k as usize] >= 0 {
                // First visit: remember the parent j (encoded negative) and enqueue k.
                marked[k as usize] = flip!(j);
                queue[tail] = k;
                tail += 1;
            }
        }
        front += 1;
    }

    if found {
        // Follow parent links from j back to j0, filling jlist from the end.
        while j != j0 as LUInt {
            top -= 1;
            jlist[top] = j;
            j = flip!(marked[j as usize]);
            assert!(j >= 0);
        }
        top -= 1;
        jlist[top] = j0 as LUInt;
    }

    // Undo the parent encoding for everything that was queued.
    for &node in &queue[..tail] {
        marked[node as usize] = 0;
    }
    top
}
fn compress_packed(m: usize, begin: &mut [LUInt], index: &mut [LUInt], value: &mut [f64]) -> usize {
let mut nz = 0;
let end = begin[m as usize];
for i in 0..m {
let p = begin[i];
if index[p as usize] == GAP {
begin[i] = 0;
} else {
assert!(index[p as usize] > GAP);
begin[i] = index[p as usize]; index[p as usize] = GAP - i as LUInt - 1; }
}
assert_eq!(index[0], GAP);
let mut i = -1;
let mut put = 1;
for get in 1..end {
if index[get as usize] > GAP {
assert!(i >= 0);
index[put as usize] = index[get as usize];
value[put as usize] = value[get as usize];
put += 1;
nz += 1;
} else if index[get as usize] < GAP {
assert_eq!(i, -1);
i = GAP - index[get as usize] - 1;
index[put as usize] = begin[i as usize]; begin[i as usize] = put;
value[put as usize] = value[get as usize];
put += 1;
nz += 1;
} else if i >= 0 {
i = -1;
index[put as usize] = GAP;
put += 1;
}
}
assert_eq!(i, -1);
begin[m as usize] = put;
nz
}
/// Applies the cyclic symmetric permutation defined by `jlist[0..=nswap]` to `lu`.
///
/// With `j0 = jlist[0]`, `jn = jlist[nswap]`, `i0 = pmap[j0]` and `in_ = pmap[jn]`:
/// * W storage (begin/end and linked-list position) moves from `jlist[n - 1]` to
///   `jlist[n]` for `n = nswap..=1`, and `j0` receives the storage originally owned by
///   `jn`.
/// * Inside each moved W list, the entry whose index equals the new owner is relabeled.
///   Its value becomes the new `col_pivot` of the owner, and the previous pivot is
///   stored in its place. The list handed to `jlist[1]` instead loses that entry,
///   because the old pivot of `j0` is zero.
/// * U lists (`u_begin`) move from `pmap[jlist[n + 1]]` to `pmap[jlist[n]]`, and `in_`
///   receives the list of `i0`. The same relabel/pivot exchange is done; the list
///   handed to `in_` loses the entry instead.
/// * Finally `pmap`/`qmap` are rotated: `jlist[n]` takes the old row of `jlist[n - 1]`,
///   and `j0` takes `in_`.
///
/// `lu.min_pivot` / `lu.max_pivot` are widened by every new column pivot.
///
/// # Panics
/// `jlist` must hold at least `nswap + 1` entries. Asserted preconditions:
/// `nswap >= 1`, `qmap[i0] == j0`, `qmap[in_] == jn`, `row_pivot[i0] == 0.0`,
/// `col_pivot[j0] == 0.0`. Also asserts that every relabeled entry exists and every
/// new pivot is nonzero.
fn permute(lu: &mut LU, jlist: &[LUInt], nswap: usize) {
    let pmap = &mut pmap!(lu);
    let qmap = &mut qmap!(lu);
    let u_begin = &mut lu.u_begin;
    let w_begin = &mut lu.w_begin;
    let w_end = &mut lu.w_end;
    let w_flink = &mut lu.w_flink;
    let w_blink = &mut lu.w_blink;
    let col_pivot = &mut lu.col_pivot;
    let row_pivot = &mut lu.row_pivot;
    let u_index = &mut lu.u_index;
    let u_value = &mut lu.u_value;
    let w_index = &mut lu.w_index;
    let w_value = &mut lu.w_value;
    // Endpoints of the cycle and their current pivot rows.
    let j0 = jlist[0] as usize;
    let jn = jlist[nswap as usize] as usize;
    let i0 = pmap[j0] as usize;
    let in_ = pmap[jn as usize] as usize;
    assert!(nswap >= 1);
    assert_eq!(qmap[i0] as usize, j0);
    assert_eq!(qmap[in_] as usize, jn);
    assert_eq!(row_pivot[i0], 0.0);
    assert_eq!(col_pivot[j0], 0.0);
    // Save jn's W storage and pivot; they are handed to j0 after the loop.
    let begin = w_begin[jn];
    let end = w_end[jn];
    let piv = col_pivot[jn];
    // Walk the cycle backwards so that col_pivot[jprev] is still the old value
    // when it is written into j's list.
    for n in (1..=nswap).rev() {
        let j = jlist[n] as usize;
        let jprev = jlist[n - 1] as usize;
        // j takes over the W storage (and list position) of jprev.
        w_begin[j] = w_begin[jprev];
        w_end[j] = w_end[jprev];
        list_swap(w_flink, w_blink, j, jprev);
        let where_ = find(j, w_index, w_begin[j], w_end[j]);
        assert!(where_ < w_end[j]);
        if n > 1 {
            assert_ne!(jprev, j0);
            // Relabel the entry, take its value as j's pivot, store jprev's old pivot.
            w_index[where_ as usize] = jprev as LUInt;
            col_pivot[j] = w_value[where_ as usize];
            assert_ne!(col_pivot[j], 0.0);
            w_value[where_ as usize] = col_pivot[jprev as usize];
        } else {
            assert_eq!(jprev, j0);
            // j0's old pivot is zero: take the value as pivot and delete the entry
            // by moving the list's last entry into its slot.
            col_pivot[j] = w_value[where_ as usize];
            assert_ne!(col_pivot[j], 0.0);
            w_end[j] -= 1;
            w_index[where_ as usize] = w_index[w_end[j] as usize];
            w_value[where_ as usize] = w_value[w_end[j] as usize];
        }
        lu.min_pivot = f64::min(lu.min_pivot, col_pivot[j].abs());
        lu.max_pivot = f64::max(lu.max_pivot, col_pivot[j].abs());
    }
    // j0 receives jn's saved storage; its entry j0 is relabeled to jn.
    w_begin[j0] = begin;
    w_end[j0] = end as LUInt;
    let where_ = find(j0, w_index, w_begin[j0], w_end[j0]);
    assert!(where_ < w_end[j0]);
    w_index[where_ as usize] = jn as LUInt;
    col_pivot[j0] = w_value[where_ as usize];
    assert_ne!(col_pivot[j0], 0.0);
    w_value[where_ as usize] = piv;
    lu.min_pivot = f64::min(lu.min_pivot, col_pivot[j0].abs());
    lu.max_pivot = f64::max(lu.max_pivot, col_pivot[j0].abs());
    // U lists rotate the other way. Walk forward so row_pivot[inext] is still old
    // when it is stored.
    let begin = u_begin[i0];
    for n in 0..nswap {
        let i = pmap[jlist[n] as usize] as usize;
        let inext = pmap[jlist[n + 1] as usize];
        u_begin[i] = u_begin[inext as usize];
        // U lists are terminated by a negative index, hence end = -1.
        let where_ = find(i, u_index, u_begin[i], -1);
        assert!(where_ >= 0);
        u_index[where_ as usize] = inext;
        row_pivot[i] = u_value[where_ as usize];
        assert_ne!(row_pivot[i], 0.0);
        u_value[where_ as usize] = row_pivot[inext as usize];
    }
    // in_ receives i0's saved list; i0's old pivot is zero, so delete the entry.
    u_begin[in_] = begin;
    let where_ = find(in_, u_index, u_begin[in_], -1);
    assert!(where_ >= 0);
    row_pivot[in_] = u_value[where_ as usize];
    assert_ne!(row_pivot[in_], 0.0);
    // Find the terminator, move the last entry into the hole, shorten by one.
    let mut end = where_ as usize;
    while u_index[end] >= 0 {
        end += 1;
    }
    u_index[where_ as usize] = u_index[end - 1];
    u_value[where_ as usize] = u_value[end - 1];
    u_index[end - 1] = -1;
    // Rotate pmap/qmap: jlist[n] takes the (old) row of jlist[n - 1]; j0 takes in_.
    for n in (1..=nswap).rev() {
        let j = jlist[n];
        let i = pmap[jlist[n - 1] as usize];
        pmap[j as usize] = i;
        qmap[i as usize] = j;
    }
    pmap[j0] = in_ as LUInt;
    qmap[in_] = j0 as LUInt;
    // Debug check: row and column pivots must agree along the whole cycle.
    if cfg!(feature = "debug") {
        for n in 0..=nswap {
            let j = jlist[n];
            let i = pmap[j as usize];
            assert_eq!(row_pivot[i as usize], col_pivot[j as usize]);
        }
    }
}
/// Debug-only cross-check between the U lists and the W file.
///
/// For every entry `(i, ientry)` of the U list `u_begin[i]..`, W list
/// `qmap[ientry]` must contain `qmap[i]` with the same value. Conversely, for every
/// entry `jentry` of W list `j`, the U list of `pmap[jentry]` must contain `pmap[j]`
/// with the same value.
///
/// On the first mismatch, `*p_col` / `*p_row` receive the offending column and row
/// index and the function returns. If everything agrees, both are set to -1.
#[cfg(feature = "debug_extra")]
fn check_consistency(lu: &LU, p_col: &mut LUInt, p_row: &mut LUInt) {
    let m = lu.m;
    let pmap = &pmap!(lu);
    let qmap = &qmap!(lu);
    let u_begin = &lu.u_begin;
    let w_begin = &lu.w_begin;
    let w_end = &lu.w_end;
    let u_index = &lu.u_index;
    let u_value = &lu.u_value;
    let w_index = &lu.w_index;
    let w_value = &lu.w_value;

    // Every U entry must be mirrored in W.
    for i in 0..m {
        let j = qmap[i];
        let mut pos = u_begin[i] as usize;
        while u_index[pos] >= 0 {
            let ientry = u_index[pos];
            let jentry = qmap[ientry as usize] as usize;
            let mut where_ = w_begin[jentry];
            while where_ < w_end[jentry] && w_index[where_ as usize] != j {
                where_ += 1;
            }
            let found =
                where_ < w_end[jentry] && w_value[where_ as usize] == u_value[pos];
            if !found {
                *p_col = j;
                *p_row = ientry;
                return;
            }
            pos += 1;
        }
    }

    // Every W entry must be mirrored in U.
    for j in 0..m {
        let i = pmap[j];
        for pos in w_begin[j]..w_end[j] {
            let pos = pos as usize;
            let jentry = w_index[pos];
            let ientry = pmap[jentry as usize] as usize;
            let mut where_ = u_begin[ientry] as usize;
            while u_index[where_] >= 0 && u_index[where_] != i {
                where_ += 1;
            }
            let found = u_index[where_] == i && u_value[where_] == w_value[pos];
            if !found {
                *p_col = jentry;
                *p_row = i;
                return;
            }
        }
    }

    *p_col = -1;
    *p_row = -1;
}
/// Updates the LU factorization for the pivot column `jpivot = lu.btran_for_update`.
///
/// Inputs left by earlier calls. NOTE(review): the producers live outside this file;
/// this is the layout this function reads.
/// * The spike: `lu.u_index` / `lu.u_value` from `lu.u_begin[m]` up to the first
///   negative index.
/// * The row eta: `lu.l_index` / `lu.l_value` in `r_begin[nforrest]..r_begin[nforrest + 1]`.
///
/// The spike replaces the U list of `ipivot = pmap[jpivot]` and is mirrored into the
/// W file. Then one of two things happens:
/// * If the result is triangular up to a symmetric permutation, that permutation is
///   appended to the pivot sequence and `nsymperm_total` is incremented.
/// * Otherwise row `ipivot` is removed from U and the row eta is kept, with explicit
///   zeros dropped. The pivot of `jpivot`/`ipivot` becomes `newpiv`, and `nforrest` and
///   `nforrest_total` are incremented.
///
/// Here `newpiv = spike_diag - Σ spike[i] * eta[i]` over the rows shared by the spike
/// and the row eta. `lu.pivot_error` is set to
/// `|newpiv - xtbl * oldpiv| / (1 + |newpiv|)`, where `oldpiv` is `col_pivot[jpivot]`
/// on entry.
///
/// # Returns
/// * `Status::ErrorSingularUpdate` if `newpiv == 0` or `|newpiv| < lu.abstol`.
/// * `Status::Reallocate` if the W file lacks free space; `lu.addmem_w` is set to
///   the missing amount.
/// * `Status::OK` on success.
///
/// Before either early return, the only changes made are:
/// * the spike's diagonal entry (if present) is moved to the end of the spike;
/// * `lu.marker` is incremented and the row eta is scattered into `marked` / `work1`;
/// * for `Reallocate`, `lu.addmem_w` is set.
///
/// # Panics
/// If `lu.btran_for_update` or `lu.nupdate` is `None`, if `lu.nforrest >= m`, or when
/// an internal invariant guarded by an assertion is violated.
pub(crate) fn update(lu: &mut LU, xtbl: f64) -> Status {
    let m = lu.m;
    let nforrest = lu.nforrest;
    let mut u_nz = lu.u_nz;
    let pad = lu.pad;
    let stretch = lu.stretch;
    let jpivot = lu.btran_for_update.unwrap();
    let ipivot = pmap![lu][jpivot] as usize;
    let oldpiv = lu.col_pivot[jpivot];
    let tic = Instant::now();
    assert!(nforrest < m);

    // ---- Prepare the spike: move its diagonal entry (row ipivot), if any, to the end.
    let spike_start = lu.u_begin[m] as usize;
    let mut spike_diag = 0.0;
    let mut have_diag = false;
    let mut put = spike_start;
    let mut pos = spike_start;
    while lu.u_index[pos] >= 0 {
        let i = lu.u_index[pos];
        if i != ipivot as LUInt {
            lu.u_index[put] = i;
            lu.u_value[put] = lu.u_value[pos];
            put += 1;
        } else {
            spike_diag = lu.u_value[pos];
            have_diag = true;
        }
        pos += 1;
    }
    if have_diag {
        lu.u_index[put] = ipivot as LUInt;
        lu.u_value[put] = spike_diag;
    }
    // Spike entries excluding the diagonal.
    let nz_spike = put - spike_start;
    let nz_roweta = (r_begin![lu][nforrest + 1] - r_begin![lu][nforrest]) as usize;

    // ---- Compute the new pivot.
    // Scatter the row eta into work1 and flag its rows with a fresh marker.
    lu.marker += 1;
    let marker = lu.marker;
    for pos in r_begin![lu][nforrest]..r_begin![lu][nforrest + 1] {
        let i = lu.l_index[pos as usize] as usize;
        marked![lu][i] = marker;
        lu.work1[i] = lu.l_value[pos as usize];
    }
    // newpiv = spike_diag - <spike, row eta>; intersect counts the shared rows.
    let mut newpiv = spike_diag;
    let mut intersect = 0;
    for pos in spike_start..spike_start + nz_spike {
        let i = lu.u_index[pos] as usize;
        assert_ne!(i, ipivot);
        if marked![lu][i] == marker {
            newpiv -= lu.u_value[pos] * lu.work1[i];
            intersect += 1;
        }
    }
    if newpiv == 0.0 || newpiv.abs() < lu.abstol {
        return Status::ErrorSingularUpdate;
    }
    let piverr = (newpiv - xtbl * oldpiv).abs();

    // ---- Check that W has room for the spike.
    // A W list that ends exactly where its successor begins cannot grow in place and
    // will be relocated by `file_reappend` (with extra stretch/pad room).
    let mut grow = 0;
    for pos in spike_start..spike_start + nz_spike {
        let i = lu.u_index[pos] as usize;
        assert_ne!(i, ipivot);
        let j = qmap![lu][i] as usize;
        let jnext = lu.w_flink[j] as usize;
        if lu.w_end[j] == lu.w_begin[jnext] {
            let nz = (lu.w_end[j] - lu.w_begin[j]) as usize;
            grow += nz + 1;
            grow += (stretch * (nz + 1) as f64) as usize + pad;
        }
    }
    let room = (lu.w_end[m] - lu.w_begin[m]) as usize;
    if grow > room {
        lu.addmem_w = grow - room;
        return Status::Reallocate;
    }

    // ---- Replace the U list of ipivot by the spike.
    // Remove jpivot from the W list of qmap[i] for every i in the old U list of ipivot.
    let mut nz_removed = 0;
    let mut pos = lu.u_begin[ipivot] as usize;
    while lu.u_index[pos] >= 0 {
        let i = lu.u_index[pos] as usize;
        let j = qmap![lu][i] as usize;
        let end = lu.w_end[j] as usize;
        lu.w_end[j] -= 1;
        let where_ = find(jpivot, &lu.w_index, lu.w_begin[j], end as LUInt);
        assert!(where_ < end as LUInt);
        lu.w_index[where_ as usize] = lu.w_index[end - 1];
        lu.w_value[where_ as usize] = lu.w_value[end - 1];
        nz_removed += 1;
        pos += 1;
    }
    u_nz -= nz_removed;
    // Erase the old list.
    let mut pos = lu.u_begin[ipivot] as usize;
    while lu.u_index[pos] >= 0 {
        lu.u_index[pos] = GAP;
        pos += 1;
    }
    // Point ipivot at the spike. The slot after its off-diagonal entries (the
    // diagonal, if present) becomes the terminating GAP.
    lu.u_begin[ipivot] = lu.u_begin[m];
    lu.u_begin[m] += nz_spike as LUInt;
    lu.u_index[lu.u_begin[m] as usize] = GAP;
    lu.u_begin[m] += 1;
    // Append jpivot to the W list of qmap[i] for every i in the spike.
    let mut pos = lu.u_begin[ipivot] as usize;
    while lu.u_index[pos] >= 0 {
        let i = lu.u_index[pos] as usize;
        let j = qmap![lu][i] as usize;
        let jnext = lu.w_flink[j] as usize;
        if lu.w_end[j] == lu.w_begin[jnext] {
            let nz = (lu.w_end[j] - lu.w_begin[j]) as usize;
            let room = 1 + (stretch * (nz + 1) as f64) as usize + pad;
            file_reappend(
                j,
                m,
                &mut lu.w_begin,
                &mut lu.w_end,
                &mut lu.w_flink,
                &mut lu.w_blink,
                &mut lu.w_index,
                &mut lu.w_value,
                room,
            );
        }
        let end = lu.w_end[j] as usize;
        lu.w_end[j] += 1;
        lu.w_index[end] = jpivot as LUInt;
        lu.w_value[end] = lu.u_value[pos];
        pos += 1;
    }
    u_nz += nz_spike;
    lu.col_pivot[jpivot] = spike_diag;
    lu.row_pivot[ipivot] = spike_diag;

    // ---- Test whether the updated U is triangular up to a symmetric permutation.
    // `Some((rows, cols))` holds the entries to append to the pivot sequence.
    let triangular_reach: Option<(Vec<LUInt>, Vec<LUInt>)> = if have_diag {
        if intersect == 0 {
            lu.min_pivot = f64::min(lu.min_pivot, newpiv.abs());
            lu.max_pivot = f64::max(lu.max_pivot, newpiv.abs());
            // Reach = ipivot followed by the rows of the row eta (and their columns).
            let nreach = nz_roweta + 1;
            let mut row_reach: Vec<LUInt> = Vec::with_capacity(nreach);
            let mut col_reach: Vec<LUInt> = Vec::with_capacity(nreach);
            row_reach.push(ipivot as LUInt);
            col_reach.push(jpivot as LUInt);
            let eta_start = r_begin![lu][nforrest] as usize;
            for pos in eta_start..eta_start + nz_roweta {
                let i = lu.l_index[pos];
                row_reach.push(i);
                col_reach.push(qmap![lu][i as usize]);
            }
            lu.nsymperm_total += 1;
            Some((row_reach, col_reach))
        } else {
            None
        }
    } else {
        // Find a cycle jpivot -> ... -> jpivot in the graph of W. Its nodes end up in
        // path[top..m], where path = iwork1[..m].
        let top = {
            let (iwork1, iwork2) = iwork1!(lu).split_at_mut(m);
            let path = iwork1;
            let top = bfs_path(
                m,
                jpivot,
                &lu.w_begin,
                &lu.w_end,
                &lu.w_index,
                path,
                &mut marked!(lu),
                iwork2,
            );
            assert!(top < m - 1);
            assert_eq!(path[top], jpivot as LUInt);
            top
        };
        // Run `dfs` from each path node j with its path edge j -> jnext hidden.
        // reach = iwork1[m..] accumulates the DFS output in reach[rtop..m]. The test
        // fails as soon as jnext gets marked. The fresh marker is returned so that the
        // final check below uses the same one.
        let (mut istriangular, mut rtop, marker) = {
            let (iwork1, iwork2) = iwork1!(lu).split_at_mut(m);
            let path = iwork1;
            let reach = iwork2;
            let pstack = &mut lu.work1;
            let mut istriangular = true;
            let mut rtop = m;
            lu.marker += 1;
            let marker = lu.marker;
            for t in top..m - 1 {
                if !istriangular {
                    break;
                }
                let j = path[t] as usize;
                let jnext = path[t + 1] as usize;
                let where_ = find(jnext, &lu.w_index, lu.w_begin[j], lu.w_end[j]);
                assert!(where_ < lu.w_end[j]);
                // Temporarily turn the edge j -> jnext into a self-loop.
                lu.w_index[where_ as usize] = j as LUInt;
                rtop = dfs(
                    j,
                    &lu.w_begin,
                    Some(&lu.w_end),
                    &lu.w_index,
                    rtop,
                    reach,
                    pstack,
                    &mut marked!(lu),
                    marker,
                );
                assert_eq!(reach[rtop] as usize, j);
                reach[rtop] = jnext as LUInt;
                lu.w_index[where_ as usize] = jnext as LUInt; // restore the edge
                istriangular = marked![lu][jnext] != marker;
            }
            (istriangular, rtop, marker)
        };
        if istriangular {
            let (iwork1, iwork2) = iwork1!(lu).split_at_mut(m);
            let path = iwork1;
            let reach = iwork2;
            let pstack = &mut lu.work1;
            let j = path[m - 1] as usize;
            rtop = dfs(
                j,
                &lu.w_begin,
                Some(&lu.w_end),
                &lu.w_index,
                rtop,
                reach,
                pstack,
                &mut marked!(lu),
                marker,
            );
            assert_eq!(reach[rtop], j as LUInt);
            reach[rtop] = jpivot as LUInt;
            // With j unmarked for a moment, no column of the new ipivot list may carry
            // the marker.
            marked![lu][j] -= 1;
            let mut pos = lu.u_begin[ipivot] as usize;
            while lu.u_index[pos] >= 0 {
                let i = lu.u_index[pos] as usize;
                if marked![lu][qmap![lu][i] as usize] == marker {
                    istriangular = false;
                }
                pos += 1;
            }
            marked![lu][j] += 1;
        }
        if istriangular {
            let nswap = m - top - 1;
            // permute reads jlist[0..=nswap], i.e. the whole cycle path[top..m].
            let cycle: Vec<LUInt> = iwork1!(lu)[top..m].to_vec();
            permute(lu, &cycle, nswap);
            u_nz -= 1;
            let reach = &iwork1!(lu)[m..];
            assert_eq!(reach[rtop], jpivot as LUInt);
            let col_reach: Vec<LUInt> = reach[rtop..m].to_vec();
            let mut row_reach: Vec<LUInt> = Vec::with_capacity(col_reach.len());
            for &j in &col_reach {
                row_reach.push(pmap![lu][j as usize]);
            }
            Some((row_reach, col_reach))
        } else {
            None
        }
    };

    // ---- Forrest-Tomlin step when no symmetric permutation exists.
    let (row_reach, col_reach) = match triangular_reach {
        Some(reach) => reach,
        None => {
            // Remove row ipivot from U: delete ipivot from the U list of pmap[j] for
            // every j in the W list of jpivot.
            for pos in lu.w_begin[jpivot]..lu.w_end[jpivot] {
                let j = lu.w_index[pos as usize] as usize;
                assert_ne!(j, jpivot);
                let mut where_ = None;
                let mut end = lu.u_begin[pmap![lu][j] as usize] as usize;
                while lu.u_index[end] >= 0 {
                    if lu.u_index[end] as usize == ipivot {
                        where_ = Some(end);
                    }
                    end += 1;
                }
                let where_ = where_.expect("ipivot missing from U list");
                lu.u_index[where_] = lu.u_index[end - 1];
                lu.u_value[where_] = lu.u_value[end - 1];
                lu.u_index[end - 1] = -1;
                u_nz -= 1;
            }
            lu.w_end[jpivot] = lu.w_begin[jpivot];
            lu.col_pivot[jpivot] = newpiv;
            lu.row_pivot[ipivot] = newpiv;
            lu.min_pivot = f64::min(lu.min_pivot, newpiv.abs());
            lu.max_pivot = f64::max(lu.max_pivot, newpiv.abs());
            // Drop explicit zeros from the row eta and track its largest magnitude.
            let mut nz = 0;
            let mut put = r_begin![lu][nforrest] as usize;
            let mut max_eta = 0.0;
            for pos in put..r_begin![lu][nforrest + 1] as usize {
                if lu.l_value[pos] != 0.0 {
                    max_eta = f64::max(max_eta, lu.l_value[pos].abs());
                    lu.l_index[put] = lu.l_index[pos];
                    lu.l_value[put] = lu.l_value[pos];
                    put += 1;
                    nz += 1;
                }
            }
            r_begin![lu][nforrest + 1] = put as LUInt;
            lu.r_nz += nz;
            lu.max_eta = f64::max(lu.max_eta, max_eta);
            lu.nforrest += 1;
            lu.nforrest_total += 1;
            // The pivot sequence gains the single pair (ipivot, jpivot).
            (vec![ipivot as LUInt], vec![jpivot as LUInt])
        }
    };
    let nreach = row_reach.len();

    // ---- Append the reach to the pivot sequence.
    if lu.pivotlen + nreach > 2 * m {
        garbage_perm(lu);
    }
    let start = lu.pivotlen;
    pivotrow![lu][start..start + nreach].copy_from_slice(&row_reach);
    pivotcol![lu][start..start + nreach].copy_from_slice(&col_reach);
    lu.pivotlen += nreach;

    // ---- Compress storage once enough of it is wasted.
    // saturating_sub: used storage can legitimately be below the estimate (e.g.
    // empty U lists sharing slot 0), which must mean "no compression", not underflow.
    let used = lu.u_begin[m] as usize;
    if used.saturating_sub(u_nz + m) > (lu.compress_thres * used as f64) as usize {
        let nz = compress_packed(m, &mut lu.u_begin, &mut lu.u_index, &mut lu.u_value);
        assert_eq!(nz, u_nz);
    }
    let used = lu.w_begin[m] as usize;
    let need = u_nz + (stretch * u_nz as f64) as usize + m * pad;
    if used.saturating_sub(need) > (lu.compress_thres * used as f64) as usize {
        let nz = file_compress(
            m,
            &mut lu.w_begin,
            &mut lu.w_end,
            &mut lu.w_flink,
            &mut lu.w_index,
            &mut lu.w_value,
            stretch,
            pad,
        );
        assert_eq!(nz, u_nz);
    }

    // ---- Bookkeeping.
    let elapsed = tic.elapsed().as_secs_f64();
    lu.time_update += elapsed;
    lu.time_update_total += elapsed;
    lu.pivot_error = piverr / (1.0 + newpiv.abs());
    lu.u_nz = u_nz;
    lu.btran_for_update = None;
    lu.ftran_for_update = None;
    lu.update_cost_numer += nz_roweta as f64;
    lu.nupdate = Some(lu.nupdate.unwrap() + 1);
    lu.nupdate_total += 1;
    #[cfg(feature = "debug_extra")]
    {
        let mut col = -1;
        let mut row = -1;
        check_consistency(lu, &mut col, &mut row);
        assert!(col < 0 && row < 0);
    }
    Status::OK
}