use crate::consts::{DAMAGING_ROUNDING, INFO_DFT};
use crate::linalg::{inprod, matprod12_into, matprod21_into, planerot, r1update};
use crate::mat::Mat;
use crate::math;
use crate::powalg::{CalWs, calbeta, calvlag_into, hess_mul_into};
/// Preallocated scratch buffers shared by `updateh`, `updateq` and `tryqalt`.
///
/// All lengths are fixed by `UpdateWs::new` from the number of variables `n` and the number of
/// interpolation points `npt`; the routines in this module only overwrite the contents.
#[derive(Debug, Clone)]
pub(crate) struct UpdateWs {
    /// Row `knew` of `zmat` (`npt - n - 1` entries); used by `updateh` and `updateq`.
    zrow: Vec<f64>,
    /// `hcol` in `updateh` (`npt + n` entries).
    hcol: Vec<f64>,
    /// `vlag` in `updateh` (`npt + n` entries).
    vlag: Vec<f64>,
    /// First coefficient vector of the `bmat` update in `updateh` (`n` entries).
    v1: Vec<f64>,
    /// Second coefficient vector of the `bmat` update in `updateh` (`n` entries).
    v2: Vec<f64>,
    /// Increment added to `pq` in `updateq` (`npt` entries).
    pqinc: Vec<f64>,
    /// Product of `zmat` with `zrow` in `updateq` (`npt` entries).
    zmat_zrow: Vec<f64>,
    /// Output buffer of `hess_mul_into` in `updateq` and `tryqalt` (`n` entries).
    hm: Vec<f64>,
    /// Projected current gradient in `tryqalt` (`n` entries).
    pgopt: Vec<f64>,
    /// Result of `matprod12_into(fval, zmat, ..)` in `tryqalt` (`npt - n - 1` entries).
    inner: Vec<f64>,
    /// Candidate replacement for `pq` in `tryqalt` (`npt` entries).
    pqalt: Vec<f64>,
    /// Candidate replacement for `gopt` in `tryqalt` (`n` entries).
    galt: Vec<f64>,
    /// Projected candidate gradient in `tryqalt` (`n` entries).
    pgalt: Vec<f64>,
    /// Scratch buffer handed to `hess_mul_into` (`npt` entries).
    dxpt: Vec<f64>,
    /// Workspace for `calbeta` / `calvlag_into` in `updateh`.
    cal: CalWs,
}

impl UpdateWs {
    /// Allocates zero-filled buffers for `n` variables and `npt` interpolation points.
    ///
    /// Requires `npt >= n + 1`; otherwise the `npt - n - 1` subtraction underflows.
    pub(crate) fn new(n: usize, npt: usize) -> Self {
        // Number of columns of `zmat` (and the length of one of its rows).
        let zmat_ncols = npt - n - 1;
        let zeros = |len: usize| -> Vec<f64> { vec![0.0; len] };
        Self {
            zrow: zeros(zmat_ncols),
            hcol: zeros(npt + n),
            vlag: zeros(npt + n),
            v1: zeros(n),
            v2: zeros(n),
            pqinc: zeros(npt),
            zmat_zrow: zeros(npt),
            hm: zeros(n),
            pgopt: zeros(n),
            inner: zeros(zmat_ncols),
            pqalt: zeros(npt),
            galt: zeros(n),
            pgalt: zeros(n),
            dxpt: zeros(npt),
            cal: CalWs::new(n, npt),
        }
    }
}
/// Updates `bmat` and `zmat` in place for the interpolation point `knew`.
///
/// The test module compares this routine against states recorded from PRIMA's `updateh`.
/// From the indexing below, `bmat` has `n` rows and `npt + n` columns and `zmat` has `npt`
/// rows, where `n = xpt.nrows()` and `npt = xpt.ncols()`; the tests give `zmat` `npt - n - 1`
/// columns.
///
/// * `knew` - 0-based index of the affected interpolation point; `None` returns `INFO_DFT`
///   without touching any argument.
/// * `kopt`, `d` - forwarded to `calbeta` and `calvlag_into` when `precomputed` is `None`.
/// * `ws` - scratch space; only `zrow`, `hcol`, `vlag`, `v1`, `v2` and `cal` are used.
/// * `precomputed` - optional `(vlag, beta)` supplied by the caller instead of recomputing
///   them; the slice is copied with `copy_from_slice`, so it must have length `npt + n`.
///
/// Returns `INFO_DFT` on success. Returns `DAMAGING_ROUNDING` when
/// `denom = alpha * beta + tau * tau` is not positive (or is NaN), or when the sum of the
/// absolute values of `hcol`, `vlag` and `beta` is not finite. In that case `bmat` and `zmat`
/// are left unchanged, although the scratch buffers in `ws` have already been written.
#[expect(clippy::too_many_arguments)]
#[expect(clippy::needless_range_loop)]
pub(crate) fn updateh(
    knew: Option<usize>,
    kopt: usize,
    d: &[f64],
    xpt: &Mat,
    bmat: &mut Mat,
    zmat: &mut Mat,
    ws: &mut UpdateWs,
    precomputed: Option<(&[f64], f64)>,
) -> i32 {
    let n = xpt.nrows();
    let npt = xpt.ncols();
    // No point to update: report the default status and leave everything untouched.
    let Some(knew) = knew else {
        return INFO_DFT;
    };
    let UpdateWs {
        zrow,
        hcol,
        vlag,
        v1,
        v2,
        cal,
        ..
    } = ws;
    // hcol[..npt] = zmat * zmat[knew, :]^T, hcol[npt..] = bmat[:, knew].
    for j in 0..zmat.ncols() {
        zrow[j] = zmat[[knew, j]];
    }
    matprod21_into(zmat, zrow, &mut hcol[..npt]);
    hcol[npt..npt + n].copy_from_slice(bmat.col(knew));
    // Take (vlag, beta) from the caller if provided; otherwise compute both from the current,
    // not yet modified, bmat and zmat.
    let beta = if let Some((vlag_src, beta_src)) = precomputed {
        vlag.copy_from_slice(vlag_src);
        beta_src
    } else {
        let beta = calbeta(kopt, bmat, d, xpt, zmat, cal);
        calvlag_into(kopt, bmat, d, xpt, zmat, cal, vlag);
        beta
    };
    let alpha = hcol[knew];
    let tau = vlag[knew];
    let denom = alpha * beta + tau * tau;
    // Every later use of `vlag` (finiteness check, bmat update, final zmat update) sees
    // vlag with 1 subtracted at position `knew`.
    vlag[knew] -= 1.0;
    // Bail out before bmat/zmat are modified. Writing the test as `!(... && denom > 0.0)`
    // also rejects a NaN denom.
    let hcol_sum_abs: f64 = hcol.iter().map(|v| math::abs(*v)).sum();
    let vlag_sum_abs: f64 = vlag.iter().map(|v| math::abs(*v)).sum();
    if !((hcol_sum_abs + vlag_sum_abs + math::abs(beta)).is_finite() && denom > 0.0) {
        return DAMAGING_ROUNDING;
    }
    // Coefficient vectors of the rank-two update of bmat.
    for i in 0..n {
        v1[i] = (alpha * vlag[npt + i] - tau * hcol[npt + i]) / denom;
        v2[i] = (-beta * hcol[npt + i] - tau * vlag[npt + i]) / denom;
    }
    // bmat += v1 * vlag^T over all npt + n columns.
    for j in 0..(npt + n) {
        let vj = vlag[j];
        let bj = &mut bmat.col_mut(j)[..n];
        let v1 = &v1[..n];
        for i in 0..n {
            bj[i] += v1[i] * vj;
        }
    }
    // bmat += v2 * hcol^T, as a separate pass after the v1 term.
    for j in 0..(npt + n) {
        let hj = hcol[j];
        let bj = &mut bmat.col_mut(j)[..n];
        let v2 = &v2[..n];
        for i in 0..n {
            bj[i] += v2[i] * hj;
        }
    }
    // Make the trailing n x n block bmat[:, npt..] symmetric by copying its lower triangle
    // onto its upper triangle.
    for j in 0..n {
        for i in 0..j {
            bmat[[i, npt + j]] = bmat[[j, npt + i]];
        }
    }
    // Zero zmat[knew, 1..] by rotating column pairs (0, j); presumably `planerot` returns the
    // rotation mapping (zmat[knew, 0], zmat[knew, j]) to (r, 0). The entry is then set to
    // exactly 0 either way. Entries not larger than the threshold are zeroed without rotating.
    // The threshold factor is an f32 literal widened to f64 (not exactly 1e-20), and max|zmat|
    // is rescanned on every iteration, i.e. after earlier rotations.
    // NOTE(review): Powell's original Fortran appears to compute this threshold once before
    // the loop; the rescan costs O(npt * (npt - n - 1)) per iteration. Presumably both details
    // mirror the PRIMA reference the tests compare against - confirm before changing either.
    for j in 1..(npt - n - 1) {
        let max_abs_zmat = zmat
            .data()
            .iter()
            .map(|v| math::abs(*v))
            .fold(0.0_f64, f64::max);
        if math::abs(zmat[[knew, j]]) > f64::from(1.0e-20_f32) * max_abs_zmat {
            let grot = planerot([zmat[[knew, 0]], zmat[[knew, j]]]);
            let (z0, zj) = zmat.two_cols_mut(0, j);
            for i in 0..npt {
                let (a, b) = (z0[i], zj[i]);
                z0[i] = a * grot[0][0] + b * grot[0][1];
                zj[i] = a * grot[1][0] + b * grot[1][1];
            }
        }
        zmat[[knew, j]] = 0.0;
    }
    // Final update of the first column of zmat:
    // zmat[:, 0] = (tau / sqrt(denom)) * zmat[:, 0] - (zmat[knew, 0] / sqrt(denom)) * vlag[..npt].
    let sqrtdn = math::sqrt(denom);
    let zknew1 = zmat[[knew, 0]];
    let tau_q = tau / sqrtdn;
    let zk_q = zknew1 / sqrtdn;
    let z0 = &mut zmat.col_mut(0)[..npt];
    let vl = &vlag[..npt];
    for i in 0..npt {
        z0[i] = tau_q * z0[i] - zk_q * vl[i];
    }
    INFO_DFT
}
/// Stores the new point `xnew` and its function value `f` in slot `knew`.
///
/// Column `knew` of `xpt` becomes `xnew` and `fval[knew]` becomes `f`. When `ximproved` is set,
/// `kopt` is moved to `knew` as well. A `None` index is a complete no-op: not even `kopt`
/// changes.
///
/// Panics if `knew` is out of range or `xnew` does not match the column length of `xpt`.
pub(crate) fn updatexf(
    knew: Option<usize>,
    ximproved: bool,
    f: f64,
    xnew: &[f64],
    kopt: &mut usize,
    fval: &mut [f64],
    xpt: &mut Mat,
) {
    if let Some(slot) = knew {
        xpt.col_mut(slot).copy_from_slice(xnew);
        fval[slot] = f;
        if ximproved {
            *kopt = slot;
        }
    }
}
/// Updates the quadratic model `(gopt, hq, pq)` for the interpolation point `knew`.
///
/// Does nothing when `knew` is `None`. Otherwise, in this order:
///
/// 1. `pq[knew]` is folded into `hq` through `r1update(hq, pq[knew], xdrop)`, then zeroed;
/// 2. `pqinc = moderr * zmat * zmat[knew, :]^T` is added to `pq`;
/// 3. `gopt` receives `moderr * bmat[:, knew]` and then
///    `hess_mul_into(xosav, xpt, pqinc, None, ..)`;
/// 4. if `ximproved`, `gopt` also receives `hess_mul_into(d, xpt, pq, Some(hq), ..)`,
///    evaluated with the already-updated `pq` and `hq`.
#[expect(clippy::too_many_arguments)]
pub(crate) fn updateq(
    knew: Option<usize>,
    ximproved: bool,
    bmat: &Mat,
    d: &[f64],
    moderr: f64,
    xdrop: &[f64],
    xosav: &[f64],
    xpt: &Mat,
    zmat: &Mat,
    gopt: &mut [f64],
    hq: &mut Mat,
    pq: &mut [f64],
    ws: &mut UpdateWs,
) {
    let Some(k) = knew else {
        return;
    };
    let UpdateWs {
        zrow,
        zmat_zrow,
        pqinc,
        hm,
        dxpt,
        ..
    } = ws;

    // Step 1: move the coefficient pq[k] into hq.
    r1update(hq, pq[k], xdrop);
    pq[k] = 0.0;

    // Step 2: pqinc = moderr * (zmat * zmat[k, :]^T), then pq += pqinc.
    for (j, z) in zrow[..zmat.ncols()].iter_mut().enumerate() {
        *z = zmat[[k, j]];
    }
    matprod21_into(zmat, zrow, zmat_zrow);
    let npt = pq.len();
    for ((inc, &zz), q) in pqinc[..npt]
        .iter_mut()
        .zip(&zmat_zrow[..npt])
        .zip(pq.iter_mut())
    {
        *inc = moderr * zz;
        *q += *inc;
    }

    // Step 3: gopt += moderr * bmat[:, k], then gopt += hess_mul(xosav, pqinc), one addition
    // at a time so the rounding sequence is unchanged.
    hess_mul_into(xosav, xpt, pqinc, None, dxpt, hm);
    let n = gopt.len();
    for ((g, &b), &h) in gopt
        .iter_mut()
        .zip(&bmat.col(k)[..n])
        .zip(&hm[..n])
    {
        *g += moderr * b;
        *g += h;
    }

    // Step 4: when the best point moved, shift the gradient by the updated model along d.
    if ximproved {
        hess_mul_into(d, xpt, pq, Some(hq), dxpt, hm);
        for (g, &h) in gopt.iter_mut().zip(&hm[..n]) {
            *g += h;
        }
    }
}
/// Decides whether to replace the current quadratic model by an alternative built from `fval`.
///
/// The alternative has implicit coefficients `pqalt = zmat * (fval^T * zmat)^T`, gradient
/// `galt = bmat[:, ..npt] * fval + hess_mul_into(xopt, xpt, pqalt, None, ..)` and a zero
/// explicit Hessian. Both gradients are projected with respect to the bounds at `xopt`, and
/// their squared norms are compared.
///
/// `itest` is reset to 0 when `ratio > 0.1` or `|pgopt|^2 < 10 * |pgalt|^2`; otherwise it is
/// incremented. Once it reaches 3, `gopt`, `pq` and `hq` are overwritten by `galt`, `pqalt` and
/// zero respectively, and `itest` is reset to 0.
#[expect(clippy::too_many_arguments)]
#[expect(clippy::similar_names)]
pub(crate) fn tryqalt(
    bmat: &Mat,
    fval: &[f64],
    ratio: f64,
    sl: &[f64],
    su: &[f64],
    xopt: &[f64],
    xpt: &Mat,
    zmat: &Mat,
    itest: &mut i32,
    gopt: &mut [f64],
    hq: &mut Mat,
    pq: &mut [f64],
    ws: &mut UpdateWs,
) {
    let n = gopt.len();
    let npt = pq.len();
    let UpdateWs {
        pgopt,
        inner,
        pqalt,
        galt,
        pgalt,
        hm,
        dxpt,
        ..
    } = ws;

    // Projected gradient of the current model.
    project_gradient(gopt, xopt, sl, su, pgopt);

    // Alternative model: implicit coefficients first, then its gradient at xopt.
    matprod12_into(fval, zmat, inner);
    matprod21_into(zmat, inner, pqalt);
    galt.fill(0.0);
    for (j, &fj) in fval[..npt].iter().enumerate() {
        for (g, &b) in galt[..n].iter_mut().zip(&bmat.col(j)[..n]) {
            *g += b * fj;
        }
    }
    hess_mul_into(xopt, xpt, pqalt, None, dxpt, hm);
    for (g, &h) in galt[..n].iter_mut().zip(&hm[..n]) {
        *g += h;
    }
    project_gradient(galt, xopt, sl, su, pgalt);

    // Count consecutive iterations where the alternative looks clearly better.
    let pgopt_sq = inprod(pgopt, pgopt);
    let pgalt_sq = inprod(pgalt, pgalt);
    let reset = ratio > 0.1 || pgopt_sq < 10.0 * pgalt_sq;
    *itest = if reset { 0 } else { *itest + 1 };
    if *itest < 3 {
        return;
    }

    // Third strike: adopt the alternative model.
    gopt.copy_from_slice(galt);
    pq.copy_from_slice(pqalt);
    hq.fill(0.0);
    *itest = 0;
}

/// Copies `grad` into `out`, then clamps components at active bounds.
///
/// Components are clamped to be `>= 0` where `x[i] >= su[i]` and `<= 0` where `x[i] <= sl[i]`;
/// the lower-bound rule is applied last, so it wins when both hold.
fn project_gradient(grad: &[f64], x: &[f64], sl: &[f64], su: &[f64], out: &mut [f64]) {
    out.copy_from_slice(grad);
    for (i, slot) in out.iter_mut().enumerate() {
        if x[i] >= su[i] {
            *slot = 0.0_f64.max(grad[i]);
        }
        if x[i] <= sl[i] {
            *slot = 0.0_f64.min(grad[i]);
        }
    }
}
// Tests for the update routines. The `*_matches_prima_on_every_captured_state` tests replay
// entry/exit states loaded by `test_support::load_states` (recorded from PRIMA, judging by the
// test names) and hand each computed/expected pair to `DiffStats`; the remaining tests pin
// small hand-built cases with n = 1 and npt = 4.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mat::Mat;
    use crate::test_support::{self, DiffStats};
    #[test]
    #[expect(clippy::similar_names)]
    fn updateh_matches_prima_on_every_captured_state() {
        let states = test_support::load_states("updateh");
        assert!(!states.is_empty());
        let mut stats = DiffStats::default();
        for st in &states {
            let (e, x) = (&st.entry, &st.exit);
            // Recorded indices are 1-based; 0 encodes "no point" and maps to `None`.
            let knew = match e.usize("knew") {
                0 => None,
                k => Some(k - 1),
            };
            let kopt = e.usize("kopt") - 1;
            let d = e.vec("d");
            let xpt = e.mat("xpt");
            let mut bmat = e.mat("bmat");
            let mut zmat = e.mat("zmat");
            let mut ws = UpdateWs::new(xpt.nrows(), xpt.ncols());
            // The returned status is not compared; only `bmat` and `zmat` are.
            let _info = updateh(knew, kopt, &d, &xpt, &mut bmat, &mut zmat, &mut ws, None);
            stats.mat("bmat", &bmat, &x.mat("bmat"));
            stats.mat("zmat", &zmat, &x.mat("zmat"));
        }
        stats.report("updateh");
    }
    #[test]
    #[expect(clippy::similar_names)]
    fn updatexf_matches_prima_on_every_captured_state() {
        let states = test_support::load_states("updatexf");
        assert!(!states.is_empty());
        let mut stats = DiffStats::default();
        for st in &states {
            let (e, x) = (&st.entry, &st.exit);
            // 1-based recorded index; 0 means "no point".
            let knew = match e.usize("knew") {
                0 => None,
                k => Some(k - 1),
            };
            let ximproved = e.i64("ximproved") != 0;
            let f = e.f64("f");
            let xnew = e.vec("xnew");
            let mut kopt = e.usize("kopt") - 1;
            let mut fval = e.vec("fval");
            let mut xpt = e.mat("xpt");
            updatexf(knew, ximproved, f, &xnew, &mut kopt, &mut fval, &mut xpt);
            // `kopt` is an index, so it must match exactly (after converting back to 1-based).
            assert_eq!(kopt + 1, x.usize("kopt"), "{}: kopt", st.problem);
            stats.slice("fval", &fval, &x.vec("fval"));
            stats.mat("xpt", &xpt, &x.mat("xpt"));
        }
        stats.report("updatexf");
    }
    #[test]
    #[expect(clippy::similar_names)]
    fn updateq_matches_prima_on_every_captured_state() {
        let states = test_support::load_states("updateq");
        assert!(!states.is_empty());
        let mut stats = DiffStats::default();
        for st in &states {
            let (e, x) = (&st.entry, &st.exit);
            // 1-based recorded index; 0 means "no point".
            let knew = match e.usize("knew") {
                0 => None,
                k => Some(k - 1),
            };
            let ximproved = e.i64("ximproved") != 0;
            let bmat = e.mat("bmat");
            let d = e.vec("d");
            let moderr = e.f64("moderr");
            let xdrop = e.vec("xdrop");
            let xosav = e.vec("xosav");
            let xpt = e.mat("xpt");
            let zmat = e.mat("zmat");
            let mut gopt = e.vec("gopt");
            let mut hq = e.mat("hq");
            let mut pq = e.vec("pq");
            let mut ws = UpdateWs::new(xpt.nrows(), xpt.ncols());
            updateq(
                knew, ximproved, &bmat, &d, moderr, &xdrop, &xosav, &xpt, &zmat, &mut gopt,
                &mut hq, &mut pq, &mut ws,
            );
            stats.slice("gopt", &gopt, &x.vec("gopt"));
            stats.mat("hq", &hq, &x.mat("hq"));
            stats.slice("pq", &pq, &x.vec("pq"));
        }
        stats.report("updateq");
    }
    #[test]
    #[expect(clippy::similar_names)]
    fn tryqalt_matches_prima_on_every_captured_state() {
        let states = test_support::load_states("tryqalt");
        assert!(!states.is_empty());
        let mut stats = DiffStats::default();
        for st in &states {
            let (e, x) = (&st.entry, &st.exit);
            let bmat = e.mat("bmat");
            let fval = e.vec("fval");
            let ratio = e.f64("ratio");
            let sl = e.vec("sl");
            let su = e.vec("su");
            let xopt = e.vec("xopt");
            let xpt = e.mat("xpt");
            let zmat = e.mat("zmat");
            let mut itest = i32::try_from(e.i64("itest")).unwrap();
            let mut gopt = e.vec("gopt");
            let mut hq = e.mat("hq");
            let mut pq = e.vec("pq");
            let mut ws = UpdateWs::new(xpt.nrows(), xpt.ncols());
            tryqalt(
                &bmat, &fval, ratio, &sl, &su, &xopt, &xpt, &zmat, &mut itest, &mut gopt, &mut hq,
                &mut pq, &mut ws,
            );
            // The failure counter is compared exactly; the model data go through `DiffStats`.
            assert_eq!(
                itest,
                i32::try_from(x.i64("itest")).unwrap(),
                "{}: itest",
                st.problem
            );
            stats.slice("gopt", &gopt, &x.vec("gopt"));
            stats.mat("hq", &hq, &x.mat("hq"));
            stats.slice("pq", &pq, &x.vec("pq"));
        }
        stats.report("tryqalt");
    }
    #[test]
    fn updateh_with_no_knew_leaves_h_untouched() {
        // n = 1, npt = 4: bmat is 1 x (npt + n), zmat is npt x (npt - n - 1).
        let xpt = Mat::zeros(1, 4);
        let mut bmat = Mat::from_col_major(1, 5, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let mut zmat = Mat::from_col_major(4, 2, (0..8).map(f64::from).collect());
        let (b0, z0) = (bmat.data().to_vec(), zmat.data().to_vec());
        let mut ws = UpdateWs::new(xpt.nrows(), xpt.ncols());
        let info = updateh(None, 0, &[0.0], &xpt, &mut bmat, &mut zmat, &mut ws, None);
        assert_eq!(info, crate::consts::INFO_DFT);
        assert_eq!(bmat.data(), &b0[..]);
        assert_eq!(zmat.data(), &z0[..]);
    }
    #[test]
    fn updatexf_without_improvement_keeps_kopt() {
        let mut xpt = Mat::zeros(1, 4);
        let mut fval = vec![3.0, 1.0, 4.0, 5.0];
        let mut kopt = 1;
        updatexf(Some(2), false, 2.0, &[0.5], &mut kopt, &mut fval, &mut xpt);
        assert_eq!(kopt, 1);
        assert_eq!(fval[2], 2.0);
        assert_eq!(xpt[[0, 2]], 0.5);
    }
    #[test]
    fn updatexf_with_no_knew_is_a_silent_no_op() {
        let mut xpt = Mat::from_col_major(1, 4, vec![1.0, 2.0, 3.0, 4.0]);
        let mut fval = vec![3.0, 1.0, 4.0, 5.0];
        let mut kopt = 1;
        let (x0, f0) = (xpt.data().to_vec(), fval.clone());
        // Even with `ximproved = true`, a `None` index must not touch `kopt`.
        updatexf(None, true, 9.0, &[7.0], &mut kopt, &mut fval, &mut xpt);
        assert_eq!(kopt, 1);
        assert_eq!(fval, f0);
        assert_eq!(xpt.data(), &x0[..]);
    }
    #[test]
    fn updateq_with_no_knew_is_a_silent_no_op() {
        let (n, npt) = (1, 4);
        let bmat = Mat::zeros(n, npt + n);
        let zmat = Mat::zeros(npt, npt - n - 1);
        let xpt = Mat::zeros(n, npt);
        let mut gopt = vec![5.0];
        let mut hq = Mat::from_col_major(1, 1, vec![7.0]);
        let mut pq = vec![1.0, 2.0, 3.0, 4.0];
        let (g0, h0, p0) = (gopt.clone(), hq.data().to_vec(), pq.clone());
        let mut ws = UpdateWs::new(n, npt);
        updateq(
            None,
            true,
            &bmat,
            &[0.0],
            0.5,
            &[0.0],
            &[0.0],
            &xpt,
            &zmat,
            &mut gopt,
            &mut hq,
            &mut pq,
            &mut ws,
        );
        assert_eq!(gopt, g0);
        assert_eq!(hq.data(), &h0[..]);
        assert_eq!(pq, p0);
    }
    #[test]
    #[expect(clippy::similar_names)]
    fn tryqalt_replaces_the_model_on_the_third_consecutive_failure() {
        // bmat, zmat and fval are all zero, so the alternative model is expected to come out
        // as zero; with ratio = 0 the counter goes 2 -> 3 and the model is replaced.
        let n = 1;
        let npt = 4;
        let bmat = Mat::zeros(n, npt + n);
        let zmat = Mat::zeros(npt, npt - n - 1);
        let xpt = Mat::zeros(n, npt);
        let fval = vec![0.0; npt];
        let (sl, su, xopt) = (vec![-1.0], vec![1.0], vec![0.0]);
        let mut itest = 2;
        let mut gopt = vec![5.0];
        let mut hq = Mat::from_col_major(1, 1, vec![7.0]);
        let mut pq = vec![1.0; npt];
        let mut ws = UpdateWs::new(n, npt);
        tryqalt(
            &bmat, &fval, 0.0, &sl, &su, &xopt, &xpt, &zmat, &mut itest, &mut gopt, &mut hq,
            &mut pq, &mut ws,
        );
        assert_eq!(itest, 0);
        assert_eq!(gopt, vec![0.0]);
        assert_eq!(hq[[0, 0]], 0.0);
        assert_eq!(pq, vec![0.0; npt]);
    }
    #[test]
    #[expect(clippy::similar_names)]
    fn tryqalt_resets_itest_when_only_the_ratio_is_good() {
        // Same zero data as above, but ratio = 0.5 > 0.1 resets the counter on its own.
        let (n, npt) = (1, 4);
        let bmat = Mat::zeros(n, npt + n);
        let zmat = Mat::zeros(npt, npt - n - 1);
        let xpt = Mat::zeros(n, npt);
        let fval = vec![0.0; npt];
        let (sl, su, xopt) = (vec![-1.0], vec![1.0], vec![0.0]);
        let mut itest = 1;
        let mut gopt = vec![5.0];
        let mut hq = Mat::from_col_major(1, 1, vec![7.0]);
        let mut pq = vec![1.0; npt];
        let mut ws = UpdateWs::new(n, npt);
        tryqalt(
            &bmat, &fval, 0.5, &sl, &su, &xopt, &xpt, &zmat, &mut itest, &mut gopt, &mut hq,
            &mut pq, &mut ws,
        );
        assert_eq!(itest, 0);
    }
}