#include "ipm/basiclu/lu_internal.h"
#include "ipm/basiclu/lu_list.h"
#include "ipm/basiclu/lu_file.h"
/* Separator stored in Uindex: ends a list of the U file and fills unused
   slots. List entries themselves are nonnegative indices. compress_packed
   temporarily uses values below GAP as "start of list i" tags. */
#define GAP (-1)
/* Involution between k >= 0 and -k-1 < 0. bfs_path uses it to store a BFS
   parent in marked[] as a negative number and to read it back. */
#define FLIP(i) (-(i)-1)
/*
 * Locates the value j in index[], scanning forward from position start.
 *
 * end >= 0: only index[start..end-1] is examined. The position of the first
 *           match is returned. If j does not occur, max(start, end) is
 *           returned, so a caller can test "result < end".
 * end <  0: index[] is treated as a list terminated by a negative entry. The
 *           scan stops at j or at the terminator, and end itself (a negative
 *           value) is returned when j is absent, so a caller can test
 *           "result >= 0".
 */
static lu_int find(
    const lu_int j, const lu_int *index, lu_int start, const lu_int end)
{
    lu_int pos = start;

    if (end < 0)
    {
        /* Terminated list: walk until j or the first negative entry. */
        while (index[pos] >= 0)
        {
            if (index[pos] == j)
                return pos;
            pos++;
        }
        return index[pos] == j ? pos : end;
    }

    /* Bounded range [start, end). */
    while (pos < end)
    {
        if (index[pos] == j)
            break;
        pos++;
    }
    return pos;
}
/*
 * Breadth-first search from node j0 in the graph whose adjacency list of
 * node j is index[begin[j] .. end[j]-1]. The search stops as soon as an edge
 * leading back to j0 is seen, i.e. a cycle through j0 has been found.
 *
 * On success the cycle is written to the tail of jlist:
 *     jlist[top] = j0, jlist[top+1], ..., jlist[m-1]
 * Here jlist[t+1] appears in the adjacency list of jlist[t], and jlist[m-1]
 * has an edge to j0. The return value is top. If no such cycle exists,
 * jlist is untouched and m is returned.
 *
 * marked[] is scratch space. A nonnegative entry means "not queued yet";
 * a queued node k stores FLIP(parent) < 0. Before returning, the marked[]
 * entry of every node that entered the queue is reset to 0.
 * queue must hold every node that can be enqueued (presumably m entries;
 * TODO confirm against callers).
 */
static lu_int bfs_path
(
    const lu_int m,
    const lu_int j0,
    const lu_int *begin,
    const lu_int *end,
    const lu_int *index,
    lu_int *jlist,
    lu_int *marked,
    lu_int *queue
)
{
    lu_int head = 0;    /* next queue slot to expand */
    lu_int tail = 1;    /* first free queue slot */
    lu_int top = m;     /* jlist is filled from the back */
    lu_int node = j0;   /* node being expanded; closes the cycle if found */
    lu_int found = 0;
    lu_int p, nbr;

    queue[0] = j0;
    while (!found && head < tail)
    {
        node = queue[head++];
        for (p = begin[node]; p < end[node] && !found; p++)
        {
            nbr = index[p];
            if (nbr == j0)
            {
                found = 1;                  /* edge node -> j0 closes cycle */
            }
            else if (marked[nbr] >= 0)
            {
                marked[nbr] = FLIP(node);   /* remember BFS parent */
                queue[tail++] = nbr;
            }
        }
    }

    if (found)
    {
        /* Follow parent links from node back to j0. */
        while (node != j0)
        {
            jlist[--top] = node;
            node = FLIP(marked[node]);
            assert(node >= 0);
        }
        jlist[--top] = j0;
    }

    /* Restore marked[] for every node that was queued. */
    for (p = 0; p < tail; p++)
        marked[queue[p]] = 0;

    return top;
}
/*
 * Compacts a packed list file in place (lu_update applies it to
 * Ubegin/Uindex/Uvalue).
 *
 * Input layout:
 *   - List i starts at begin[i], for i = 0..m-1.
 *   - Each list is a run of nonnegative indices terminated by a GAP entry.
 *   - begin[m] is one past the last used position.
 *   - Position 0 must hold GAP.
 *   - Every slot in [1, begin[m]) that belongs to no list must hold GAP.
 *   (The last two requirements are asserted below.)
 *
 * Result:
 *   - The lists are stored contiguously, in the order they appeared in
 *     memory, each still followed by a GAP terminator.
 *   - Empty lists get begin[i] = 0, i.e. they point at the GAP in position 0.
 *   - begin[m] is set to the new end of the used region.
 *
 * Returns the number of list entries kept (terminators not counted).
 */
static lu_int compress_packed(const lu_int m, lu_int *begin, lu_int *index,
double *value)
{
lu_int i, p, get, put, nz = 0;
const lu_int end = begin[m];
/* Pass 1: replace the first entry of each nonempty list i by the tag
   GAP-i-1 (< GAP), parking the index it held in begin[i]. */
for (i = 0; i < m; i++)
{
p = begin[i];
if (index[p] == GAP)
{
/* empty list: point it at the GAP in position 0 */
begin[i] = 0;
}
else
{
assert(index[p] > GAP);
begin[i] = index[p];
index[p] = GAP-i-1;
}
}
assert(index[0] == GAP);
/* Pass 2: slide entries towards the front. i is the list being copied,
   or -1 while between lists. Position 0 keeps its GAP. */
i = -1;
put = 1;
for (get = 1; get < end; get++)
{
if (index[get] > GAP)
{
/* ordinary entry of list i */
assert(i >= 0);
index[put] = index[get];
value[put++] = value[get];
nz++;
}
else if (index[get] < GAP)
{
/* tagged first entry: decode i, restore the parked index and record
   the new start position of list i */
assert(i == -1);
i = GAP - index[get] - 1;
index[put] = begin[i];
begin[i] = put;
value[put++] = value[get];
nz++;
}
else if (i >= 0)
{
/* terminator of list i; a GAP seen between lists is dropped */
i = -1;
index[put++] = GAP;
}
}
assert(i == -1);
begin[m] = put;
return nz;
}
/*
 * Cyclically shifts pivot positions along the column cycle jlist[0..nswap]
 * (j0 = jlist[0], jn = jlist[nswap]). It updates:
 *   - the W lists (Wbegin/Wend/Windex/Wvalue, indexed by column j),
 *   - the U lists (Ubegin/Uindex/Uvalue, indexed by row i, each terminated
 *     by a negative index),
 *   - the pivot values col_pivot/row_pivot,
 *   - the maps pmap/qmap.
 *
 * Preconditions (asserted): nswap >= 1; qmap[pmap[j0]] == j0 and
 * qmap[pmap[jn]] == jn; col_pivot[j0] == 0 and row_pivot[i0] == 0 with
 * i0 = pmap[j0].
 *
 * The lookups below also require that:
 *   - the W list of jlist[n-1] contains jlist[n] (n = 1..nswap);
 *   - the W list of jn contains j0;
 *   - the U list of pmap[jlist[n+1]] contains pmap[jlist[n]]
 *     (n = 0..nswap-1);
 *   - the U list of i0 contains in = pmap[jn].
 * In lu_update, jlist is the cycle returned by bfs_path.
 *
 * Effect:
 *  - Column jlist[n] (n >= 1) takes over the W list of jlist[n-1], and j0
 *    takes over that of jn.
 *  - In each such list, the entry naming the new owner becomes the owner's
 *    col_pivot. It is replaced by the previous owner carrying that owner's
 *    old pivot value.
 *  - Exception: in the list handed to jlist[1], that entry is deleted,
 *    because j0's old pivot was zero.
 *  - The U lists rotate the same way in the opposite direction. In the list
 *    handed to row in, the entry in becomes row_pivot[in] and is deleted
 *    rather than replaced, because row i0's old pivot was zero.
 *  - pmap/qmap rotate so that jlist[n] gets the old row of jlist[n-1], and
 *    j0 gets the old row of jn.
 *
 * Exactly one entry is removed from the W lists and one from the U lists.
 * this->min_pivot and this->max_pivot are updated with each new col_pivot.
 */
static void permute(struct lu *this, const lu_int *jlist, lu_int nswap)
{
lu_int *pmap = this->pmap;
lu_int *qmap = this->qmap;
lu_int *Ubegin = this->Ubegin;
lu_int *Wbegin = this->Wbegin;
lu_int *Wend = this->Wend;
lu_int *Wflink = this->Wflink;
lu_int *Wblink = this->Wblink;
double *col_pivot = this->col_pivot;
double *row_pivot = this->row_pivot;
lu_int *Uindex = this->Uindex;
double *Uvalue = this->Uvalue;
lu_int *Windex = this->Windex;
double *Wvalue = this->Wvalue;
const lu_int j0 = jlist[0];
const lu_int jn = jlist[nswap];
const lu_int i0 = pmap[j0];
const lu_int in = pmap[jn];
lu_int begin, end, i, inext, j, jprev, n, where;
double piv;
assert(nswap >= 1);
assert(qmap[i0] == j0);
assert(qmap[in] == jn);
assert(row_pivot[i0] == 0);
assert(col_pivot[j0] == 0);
/* Save jn's W list and pivot; j0 inherits them after the loop. */
begin = Wbegin[jn];
end = Wend[jn];
piv = col_pivot[jn];
/* Rotate W lists and column pivots, n = nswap down to 1. Walking downward
   means col_pivot[jprev] still holds its old value when it is read. */
for (n = nswap; n > 0; n--)
{
j = jlist[n];
jprev = jlist[n-1];
Wbegin[j] = Wbegin[jprev];
Wend[j] = Wend[jprev];
/* swap positions of j and jprev in the Wflink/Wblink linkage
   (lu_list_swap is defined elsewhere) */
lu_list_swap(Wflink, Wblink, j, jprev);
where = find(j, Windex, Wbegin[j], Wend[j]);
assert(where < Wend[j]);
if (n > 1)
{
/* entry j becomes j's pivot; slot now names jprev with its old pivot */
assert(jprev != j0);
Windex[where] = jprev;
col_pivot[j] = Wvalue[where];
assert(col_pivot[j]);
Wvalue[where] = col_pivot[jprev];
}
else
{
/* jprev == j0 had a zero pivot: delete the entry (move last into it) */
assert(jprev == j0);
col_pivot[j] = Wvalue[where];
assert(col_pivot[j]);
Wend[j]--;
Windex[where] = Windex[Wend[j]];
Wvalue[where] = Wvalue[Wend[j]];
}
this->min_pivot = fmin(this->min_pivot, fabs(col_pivot[j]));
this->max_pivot = fmax(this->max_pivot, fabs(col_pivot[j]));
}
/* j0 inherits jn's W list: entry j0 becomes its pivot, replaced by jn
   carrying jn's old pivot. */
Wbegin[j0] = begin;
Wend[j0] = end;
where = find(j0, Windex, Wbegin[j0], Wend[j0]);
assert(where < Wend[j0]);
Windex[where] = jn;
col_pivot[j0] = Wvalue[where];
assert(col_pivot[j0]);
Wvalue[where] = piv;
this->min_pivot = fmin(this->min_pivot, fabs(col_pivot[j0]));
this->max_pivot = fmax(this->max_pivot, fabs(col_pivot[j0]));
/* Rotate U lists and row pivots, n = 0 up to nswap-1; i0's list is saved
   for row in. Walking upward keeps row_pivot[inext] unmodified when read. */
begin = Ubegin[i0];
for (n = 0; n < nswap; n++)
{
i = pmap[jlist[n]];
inext = pmap[jlist[n+1]];
Ubegin[i] = Ubegin[inext];
where = find(i, Uindex, Ubegin[i], -1);
assert(where >= 0);
Uindex[where] = inext;
row_pivot[i] = Uvalue[where];
assert(row_pivot[i]);
Uvalue[where] = row_pivot[inext];
}
/* Row in inherits i0's U list: take its pivot from entry in, then delete
   that entry by moving the last entry into its slot and moving the
   terminator up by one. */
Ubegin[in] = begin;
where = find(in, Uindex, Ubegin[in], -1);
assert(where >= 0);
row_pivot[in] = Uvalue[where];
assert(row_pivot[in]);
for (end = where; Uindex[end] >= 0; end++) ;
Uindex[where] = Uindex[end-1];
Uvalue[where] = Uvalue[end-1];
Uindex[end-1] = -1;
/* Rotate the permutation. Walking downward, pmap[jlist[n-1]] still holds
   its old row when read. */
for (n = nswap; n > 0; n--)
{
j = jlist[n];
i = pmap[jlist[n-1]];
pmap[j] = i;
qmap[i] = j;
}
pmap[j0] = in;
qmap[in] = j0;
#ifndef NDEBUG
/* each rotated column's pivot must equal the pivot of its new row */
for (n = 0; n <= nswap; n++)
{
j = jlist[n];
i = pmap[j];
assert(row_pivot[i] == col_pivot[j]);
}
#endif
}
#ifdef DEBUG_EXTRA
/*
 * Debugging aid. Checks that the U lists (Ubegin/Uindex/Uvalue, each
 * terminated by a negative index) and the W lists (Wbegin/Wend/Windex/Wvalue)
 * hold the same entries with identical values. Indices are related through
 * pmap/qmap.
 *
 * At the first entry that is missing from the other structure, or stored
 * there with a different value, an index pair identifying it is written to
 * *p_col / *p_row and the function returns. If everything matches, both
 * outputs are set to -1.
 */
static void check_consistency(struct lu *this, lu_int *p_col, lu_int *p_row)
{
    const lu_int m = this->m;
    const lu_int *pmap = this->pmap;
    const lu_int *qmap = this->qmap;
    const lu_int *Ubegin = this->Ubegin;
    const lu_int *Wbegin = this->Wbegin;
    const lu_int *Wend = this->Wend;
    const lu_int *Uindex = this->Uindex;
    const double *Uvalue = this->Uvalue;
    const lu_int *Windex = this->Windex;
    const double *Wvalue = this->Wvalue;
    lu_int ip, jp, ie, je, k, w;

    /* Pass 1: every U-list entry must be mirrored in the W lists. */
    for (ip = 0; ip < m; ip++)
    {
        jp = qmap[ip];
        for (k = Ubegin[ip]; Uindex[k] >= 0; k++)
        {
            ie = Uindex[k];
            je = qmap[ie];
            w = Wbegin[je];
            while (w < Wend[je] && Windex[w] != jp)
                w++;
            if (w >= Wend[je] || Wvalue[w] != Uvalue[k])
            {
                *p_col = jp;
                *p_row = ie;
                return;
            }
        }
    }

    /* Pass 2: every W-list entry must be mirrored in the U lists. */
    for (jp = 0; jp < m; jp++)
    {
        ip = pmap[jp];
        for (k = Wbegin[jp]; k < Wend[jp]; k++)
        {
            je = Windex[k];
            ie = pmap[je];
            w = Ubegin[ie];
            while (Uindex[w] >= 0 && Uindex[w] != ip)
                w++;
            if (Uindex[w] != ip || Uvalue[w] != Wvalue[k])
            {
                *p_col = je;
                *p_row = ip;
                return;
            }
        }
    }

    *p_col = -1;
    *p_row = -1;
}
#endif
/*
 * Updates the factorization after column jpivot = this->btran_for_update is
 * replaced.
 *
 * Inputs read from the factorization:
 *   - The new column (the "spike") is read from the U lists at position
 *     Ubegin[m], terminated by a negative index.
 *   - The row eta is read from Lindex/Lvalue[Rbegin[nforrest] ..
 *     Rbegin[nforrest+1]-1].
 * NOTE(review): both are presumably produced by preceding ftran/btran
 * "for update" solves (the ftran_for_update/btran_for_update fields are
 * reset to -1 at the end). Confirm against the callers.
 *
 * xtbl enters only the stability measure
 *     this->pivot_error = |newpiv - xtbl*oldpiv| / (1 + |newpiv|).
 * Here oldpiv is the previous pivot of column jpivot, and newpiv is the
 * pivot computed from the spike and the row eta.
 *
 * Returns:
 *   BASICLU_ERROR_singular_update  if newpiv == 0 or |newpiv| < abstol;
 *   BASICLU_REALLOCATE             if the free W space is too small; the
 *                                  shortfall is stored in this->addmemW;
 *   BASICLU_OK                     otherwise.
 * Both early returns happen before any U/W list, pivot or permutation is
 * modified. At that point only three things have been touched: the spike
 * entries (reordered so the diagonal comes last), marked/work1, and
 * this->marker.
 *
 * On success the update takes one of three forms:
 *   (a) The spike has a diagonal entry and no index in common with the
 *       row-eta pattern. The pivot sequence is extended by a symmetric
 *       permutation and nforrest is left unchanged.
 *   (b) The spike has no diagonal entry, and the depth-first tests along a
 *       cycle through jpivot succeed. permute() is applied along that cycle
 *       and nforrest is left unchanged.
 *   (c) Otherwise the row eta is kept (nforrest is incremented) and the
 *       pivot of jpivot/ipivot becomes newpiv.
 */
lu_int lu_update(struct lu *this, double xtbl)
{
const lu_int m = this->m;
const lu_int nforrest = this->nforrest;
lu_int Unz = this->Unz;
const lu_int pad = this->pad;
const double stretch = this->stretch;
lu_int *pmap = this->pmap;
lu_int *qmap = this->qmap;
lu_int *pivotcol = this->pivotcol;
lu_int *pivotrow = this->pivotrow;
lu_int *Ubegin = this->Ubegin;
lu_int *Rbegin = this->Rbegin;
lu_int *Wbegin = this->Wbegin;
lu_int *Wend = this->Wend;
lu_int *Wflink = this->Wflink;
lu_int *Wblink = this->Wblink;
double *col_pivot = this->col_pivot;
double *row_pivot = this->row_pivot;
lu_int *Lindex = this->Lindex;
double *Lvalue = this->Lvalue;
lu_int *Uindex = this->Uindex;
double *Uvalue = this->Uvalue;
lu_int *Windex = this->Windex;
double *Wvalue = this->Wvalue;
lu_int *marked = this->marked;
lu_int *iwork1 = this->iwork1;
lu_int *iwork2 = iwork1 + m;
double *work1 = this->work1;
lu_int jpivot = this->btran_for_update;
lu_int ipivot = pmap[jpivot];
double oldpiv = col_pivot[jpivot];
lu_int status = BASICLU_OK;
lu_int i, j, jnext, n, nz, t, put, pos, end, where, room, grow, used,need,M;
lu_int have_diag, intersect, istriangular, nz_roweta, nz_spike;
lu_int nreach, *col_reach, *row_reach;
double spike_diag, newpiv, piverr;
assert(nforrest < m);
/* Compact the spike in place without its entry in row ipivot (the
   diagonal). If present, the diagonal is re-stored directly after the
   nz_spike remaining entries, still before the terminator. Its value is
   kept in spike_diag. */
spike_diag = 0.0;
have_diag = 0;
put = Ubegin[m];
for (pos = put; (i = Uindex[pos]) >= 0; pos++)
{
if (i != ipivot)
{
Uindex[put] = i;
Uvalue[put++] = Uvalue[pos];
}
else
{
spike_diag = Uvalue[pos];
have_diag = 1;
}
}
if (have_diag)
{
Uindex[put] = ipivot;
Uvalue[put] = spike_diag;
}
nz_spike = put - Ubegin[m];
nz_roweta = Rbegin[nforrest+1] - Rbegin[nforrest];
/* Scatter the row eta into work1 and mark its pattern with M. */
M = ++this->marker;
for (pos = Rbegin[nforrest]; pos < Rbegin[nforrest+1]; pos++)
{
i = Lindex[pos];
marked[i] = M;
work1[i] = Lvalue[pos];
}
/* Compute the new pivot:
       newpiv = spike_diag - sum(spike[i] * eta[i])
   summed over the off-diagonal spike indices i that are in the row-eta
   pattern. intersect counts those common indices. */
newpiv = spike_diag;
intersect = 0;
for (pos = Ubegin[m]; pos < Ubegin[m] + nz_spike; pos++)
{
i = Uindex[pos];
assert(i != ipivot);
if (marked[i] == M)
{
newpiv -= Uvalue[pos] * work1[i];
intersect++;
}
}
/* Singularity test (nothing permanent changed yet). */
if (newpiv == 0 || fabs(newpiv) < this->abstol)
{
status = BASICLU_ERROR_singular_update;
return status;
}
/* Stability measure; normalized into this->pivot_error at the end. */
piverr = fabs(newpiv - xtbl*oldpiv);
/* Bound the W space needed to insert the spike. Consider each W list that
   receives an entry and has no free slot behind it
   (Wend[j] == Wbegin[Wflink[j]]). For each such list, count nz+1 slots
   (the list plus the new entry) and stretch*(nz+1)+pad slots of extra
   room. Compare the total with the space between Wbegin[m] and Wend[m]
   (presumably the free area of the W file). */
grow = 0;
for (pos = Ubegin[m]; pos < Ubegin[m] + nz_spike; pos++)
{
i = Uindex[pos];
assert(i != ipivot);
j = qmap[i];
jnext = Wflink[j];
if (Wend[j] == Wbegin[jnext])
{
nz = Wend[j] - Wbegin[j];
grow += nz+1;
grow += stretch*(nz+1) + pad;
}
}
room = Wend[m] - Wbegin[m];
if (grow > room)
{
this->addmemW = grow-room;
status = BASICLU_REALLOCATE;
return status;
}
/* From here on the factorization is modified.
   Remove the old column: for each index i in the current U list of
   ipivot, delete entry jpivot from the W list of qmap[i] (move the last
   entry into its slot and shorten the list). */
nz = 0;
for (pos = Ubegin[ipivot]; (i = Uindex[pos]) >= 0; pos++)
{
j = qmap[i];
end = Wend[j]--;
where = find(jpivot, Windex, Wbegin[j], end);
assert(where < end);
Windex[where] = Windex[end-1];
Wvalue[where] = Wvalue[end-1];
nz++;
}
Unz -= nz;
/* Blank out the old U list of ipivot with GAPs. */
for (pos = Ubegin[ipivot]; Uindex[pos] >= 0; pos++)
{
Uindex[pos] = GAP;
}
/* The spike becomes the U list of ipivot. The GAP written after its
   nz_spike off-diagonal entries terminates it; it overwrites the
   re-stored diagonal, if any. Ubegin[m] advances past it. */
Ubegin[ipivot] = Ubegin[m];
Ubegin[m] += nz_spike;
Uindex[Ubegin[m]++] = GAP;
/* Mirror the spike in the W lists: append (jpivot, value) to the W list
   of qmap[i]. A list with no free slot behind it is first handed to
   lu_file_reappend (defined elsewhere) with the requested extra room. */
for (pos = Ubegin[ipivot]; (i = Uindex[pos]) >= 0; pos++)
{
j = qmap[i];
jnext = Wflink[j];
if (Wend[j] == Wbegin[jnext])
{
nz = Wend[j] - Wbegin[j];
room = 1 + stretch*(nz+1) + pad;
lu_file_reappend(j, m, Wbegin, Wend, Wflink, Wblink, Windex,
Wvalue, room);
}
end = Wend[j]++;
Windex[end] = jpivot;
Wvalue[end] = Uvalue[pos];
}
Unz += nz_spike;
/* Provisional pivot: the spike diagonal (0.0 if the spike has none). */
col_pivot[jpivot] = spike_diag;
row_pivot[ipivot] = spike_diag;
row_reach = NULL;
col_reach = NULL;
if (have_diag)
{
/* Case (a). intersect == 0 means the spike and row-eta patterns are
   disjoint, so newpiv == spike_diag. The pivot sequence to append is
   (ipivot, jpivot), followed by each row-eta index i paired with
   qmap[i]. */
istriangular = intersect == 0;
if (istriangular)
{
this->min_pivot = fmin(this->min_pivot, fabs(newpiv));
this->max_pivot = fmax(this->max_pivot, fabs(newpiv));
nreach = nz_roweta + 1;
row_reach = iwork1;
col_reach = iwork2;
row_reach[0] = ipivot;
col_reach[0] = jpivot;
pos = Rbegin[nforrest];
for (n = 1; n < nreach; n++)
{
i = Lindex[pos++];
row_reach[n] = i;
col_reach[n] = qmap[i];
}
this->nsymperm_total++;
}
}
else
{
/* Case (b) candidate: the spike has no diagonal entry. Find a cycle
   through jpivot in the graph of the W lists; it occupies
   path[top..m-1]. A cycle is assumed to exist (asserted). iwork2 is
   used as the BFS queue here and reused as reach[] below. work1 is
   reused as the integer DFS stack. */
lu_int *path = iwork1, top;
lu_int *reach = iwork2, rtop;
lu_int *pstack = (void *) work1;
top = bfs_path(m, jpivot, Wbegin, Wend, Windex, path, marked, iwork2);
assert(top < m-1);
assert(path[top] == jpivot);
istriangular = 1;
rtop = m;
M = ++this->marker;
/* For each edge j -> jnext on the path:
     1. Temporarily redirect that entry of j's W list to j itself.
     2. Run lu_dfs from j. (lu_dfs is defined elsewhere; it appears to
        mark visited nodes with M and prepend them to reach[rtop..].)
     3. Relabel the start node in reach as jnext, and restore the entry.
   If jnext was nevertheless marked by the DFS, the triangular form is
   rejected. */
for (t = top; t < m-1 && istriangular; t++)
{
j = path[t];
jnext = path[t+1];
where = find(jnext, Windex, Wbegin[j], Wend[j]);
assert(where < Wend[j]);
Windex[where] = j;
rtop = lu_dfs(j, Wbegin, Wend, Windex, rtop, reach, pstack, marked,
M);
assert(reach[rtop] == j);
reach[rtop] = jnext;
Windex[where] = jnext;
istriangular = marked[jnext] != M;
}
if (istriangular)
{
/* Last path node: run the DFS from it and relabel it as jpivot in
   reach. Then unmark it temporarily (M-1 != M) and reject the
   triangular form if any column qmap[i], for i in the new U list of
   ipivot, is marked. */
j = path[m-1];
rtop = lu_dfs(j, Wbegin, Wend, Windex, rtop, reach, pstack, marked,
M);
assert(reach[rtop] == j);
reach[rtop] = jpivot;
marked[j]--;
for (pos = Ubegin[ipivot]; (i = Uindex[pos]) >= 0; pos++)
{
if (marked[qmap[i]] == M) istriangular = 0;
}
marked[j]++;
}
if (istriangular)
{
/* Case (b). Rotate the pivots along the cycle. permute() drops one
   stored entry (from both the U and the W lists), hence Unz--. The
   pivot sequence to append is reach[rtop..m-1] (columns) and their
   rows pmap[]; the rows are written into iwork1, since path is no
   longer needed. */
lu_int nswap = m-top-1;
permute(this, path+top, nswap);
Unz--;
assert(reach[rtop] == jpivot);
col_reach = reach + rtop;
row_reach = iwork1 + rtop;
nreach = m-rtop;
for (n = 0; n < nreach; n++)
{
row_reach[n] = pmap[col_reach[n]];
}
}
}
if (!istriangular)
{
/* Case (c). For every column j in the W list of jpivot, delete entry
   ipivot from the U list of pmap[j]: move the last entry into its slot
   and move the terminator up. Then empty the W list of jpivot. */
for (pos = Wbegin[jpivot]; pos < Wend[jpivot]; pos++)
{
j = Windex[pos];
assert(j != jpivot);
where = -1;
for (end = Ubegin[pmap[j]]; (i = Uindex[end]) >= 0; end++)
{
if (i == ipivot) where = end;
}
assert(where >= 0);
Uindex[where] = Uindex[end-1];
Uvalue[where] = Uvalue[end-1];
Uindex[end-1] = -1;
Unz--;
}
Wend[jpivot] = Wbegin[jpivot];
col_pivot[jpivot] = newpiv;
row_pivot[ipivot] = newpiv;
this->min_pivot = fmin(this->min_pivot, fabs(newpiv));
this->max_pivot = fmax(this->max_pivot, fabs(newpiv));
/* Keep the row eta as number nforrest: squeeze out explicit zeros, then
   record its size (Rnz) and its largest magnitude (max_eta). */
nz = 0;
put = Rbegin[nforrest];
double max_eta = 0;
for (pos = put; pos < Rbegin[nforrest+1]; pos++)
{
if (Lvalue[pos])
{
max_eta = fmax(max_eta, fabs(Lvalue[pos]));
Lindex[put] = Lindex[pos];
Lvalue[put++] = Lvalue[pos];
nz++;
}
}
Rbegin[nforrest+1] = put;
this->Rnz += nz;
this->max_eta = fmax(this->max_eta, max_eta);
/* The pivot sequence gains the single pair (ipivot, jpivot). */
nreach = 1;
row_reach = &ipivot;
col_reach = &jpivot;
this->nforrest++;
this->nforrest_total++;
}
/* Append (row_reach, col_reach) to pivotrow/pivotcol. If the stored
   sequence would exceed 2*m entries, lu_garbage_perm (defined
   elsewhere) is called first; it presumably compacts the sequence and
   resets pivotlen. */
if (this->pivotlen + nreach > 2*m)
{
lu_garbage_perm(this);
}
put = this->pivotlen;
for (n = 0; n < nreach; n++) pivotrow[put++] = row_reach[n];
put = this->pivotlen;
for (n = 0; n < nreach; n++) pivotcol[put++] = col_reach[n];
this->pivotlen += nreach;
/* Compact the U lists when used - Unz - m exceeds compress_thres * used.
   The m presumably accounts for one terminator per list. */
used = Ubegin[m];
if (used-Unz-m > this->compress_thres * used)
{
nz = compress_packed(m, Ubegin, Uindex, Uvalue);
assert(nz == Unz);
}
/* Compact the W lists when used space exceeds the space needed (Unz
   entries plus stretch*Unz + m*pad slack) by more than
   compress_thres * used. */
used = Wbegin[m];
need = Unz + stretch*Unz + m*pad;
if ((used-need) > this->compress_thres * used)
{
nz = lu_file_compress(m, Wbegin, Wend, Wflink, Windex, Wvalue,
stretch, pad);
assert(nz == Unz);
}
/* Bookkeeping. The pending btran/ftran "for update" results are consumed. */
this->pivot_error = piverr / (1.0 + fabs(newpiv));
this->Unz = Unz;
this->btran_for_update = -1;
this->ftran_for_update = -1;
this->update_cost_numer += nz_roweta;
this->nupdate++;
this->nupdate_total++;
#ifdef DEBUG_EXTRA
/* Expensive cross-check that the U and W lists agree. */
{
lu_int col, row;
check_consistency(this, &col, &row);
assert(col < 0 && row < 0);
}
#endif
return status;
}