use crate::linalg::{inprod, matprod12_into, matprod21_into, norm};
use crate::mat::Mat;
use crate::math;
use crate::powalg::{CalWs, calden_into, hess_mul_into};
/// Reusable scratch storage for [`setdrop_tr`] and [`geostep`].
///
/// Buffer sizes are fixed by [`GeostepWs::new`]; `n` is the number of rows of `xpt` and `npt`
/// its number of columns. Every buffer is overwritten by the routines that use it, so nothing
/// stored here carries meaning between calls.
#[derive(Debug, Clone)]
pub(crate) struct GeostepWs {
    /// Per-column squared distances (length `npt`); written by both routines.
    distsq: Vec<f64>,
    /// `setdrop_tr`: `max(1, distsq / rho^2)^4` per column (length `npt`).
    weight: Vec<f64>,
    /// `setdrop_tr`: denominators from `calden_into` or the caller-supplied slice (length `npt`).
    den: Vec<f64>,
    /// `setdrop_tr`: `weight * den`, with NaN / excluded entries replaced by `-1` (length `npt`).
    score: Vec<f64>,
    /// `geostep`: copy of row `knew` of `zmat` (length `npt - n - 1`).
    zrow_knew: Vec<f64>,
    /// `geostep`: output of `matprod21_into(zmat, zrow_knew, ..)` (length `npt`).
    pqlag: Vec<f64>,
    /// `geostep`: copy of column `kopt` of `xpt` (length `n`).
    xopt: Vec<f64>,
    /// `geostep`: output of `hess_mul_into` (length `n`).
    hm: Vec<f64>,
    /// `geostep`: `bmat[:, knew] + hm`; negated in place for the second Cauchy pass (length `n`).
    glag: Vec<f64>,
    /// `geostep`: per-column directional quantity derived from `glag` (length `npt`).
    dderiv: Vec<f64>,
    /// `geostep`: rows hold `slbd`, `subd`, `stpm` for each column (`3 x npt`).
    stplen: Mat,
    /// `geostep`: per-column `[ilbd, iubd, 0]`, signed 1-based coordinate indices (length `npt`).
    isbd: Vec<[i64; 3]>,
    /// `geostep`: `xpt[:, k] - xopt` for the column being processed (length `n`).
    xdiff: Vec<f64>,
    /// `geostep`: per-coordinate step fraction associated with the lower bound `sl` (length `n`).
    lfrac: Vec<f64>,
    /// `geostep`: per-coordinate step fraction associated with the upper bound `su` (length `n`).
    ufrac: Vec<f64>,
    /// `geostep`: candidate values for the lower step limit `slbd` (length `n`).
    slbd_test: Vec<f64>,
    /// `geostep`: candidate values for the upper step limit `subd` (length `n`).
    subd_test: Vec<f64>,
    /// `geostep`: value computed from each candidate step length (`3 x npt`).
    vlag: Mat,
    /// `geostep`: `0.5 * (t * (1 - t) * distsq)^2` for each candidate step length `t` (`3 x npt`).
    betabd: Mat,
    /// `geostep`: `vlag^2 * (vlag^2 + alpha * betabd)`, the quantity being maximised (`3 x npt`).
    predsq: Mat,
    /// `geostep`: end point of the selected line step, clamped to `[sl, su]` (length `n`).
    xline: Vec<f64>,
    /// `geostep`: `calden_into` output for the line step (length `npt`).
    den_line: Vec<f64>,
    /// `geostep`: best point found by the two Cauchy passes (length `n`).
    xcauchy: Vec<f64>,
    /// `geostep`: Cauchy step under construction; `2 * delbar` marks a still-free coordinate
    /// (length `n`).
    s: Vec<f64>,
    /// `geostep`: coordinates currently free in the Cauchy construction (length `n`).
    mask_free: Vec<bool>,
    /// `geostep`: trial point `xopt - grdstp * glag` (length `n`).
    xtemp: Vec<f64>,
    /// `geostep`: coordinates fixed at `sl` in the current round (length `n`).
    mask_fixl: Vec<bool>,
    /// `geostep`: coordinates fixed at `su` in the current round (length `n`).
    mask_fixu: Vec<bool>,
    /// `geostep`: coordinates that remain free after the current round (length `n`).
    new_mask_free: Vec<bool>,
    /// `geostep`: candidate point of the current Cauchy pass (length `n`).
    x: Vec<f64>,
    /// `geostep`: output of `matprod12_into(s, xpt, ..)` (length `npt`).
    sxpt: Vec<f64>,
    /// `geostep`: `xcauchy - xopt` (length `n`).
    s_cauchy: Vec<f64>,
    /// `geostep`: `calden_into` output for `s_cauchy` (length `npt`).
    den_cauchy: Vec<f64>,
    /// `geostep`: buffer handed to `hess_mul_into` (length `npt`).
    dxpt: Vec<f64>,
    /// Workspace handed to `calden_into`.
    cal: CalWs,
}
impl GeostepWs {
    /// Allocates zero-initialised scratch for `n` variables and `npt` interpolation points.
    ///
    /// `zrow_knew` receives `npt - n - 1` entries. That subtraction is carried out in `usize`,
    /// so callers must pass `npt >= n + 1`.
    pub(crate) fn new(n: usize, npt: usize) -> Self {
        // One constructor per buffer shape keeps the field list below readable.
        let coords = || vec![0.0_f64; n];
        let points = || vec![0.0_f64; npt];
        let flags = || vec![false; n];
        let triple = || Mat::zeros(3, npt);

        Self {
            distsq: points(),
            weight: points(),
            den: points(),
            score: points(),
            zrow_knew: vec![0.0_f64; npt - n - 1],
            pqlag: points(),
            xopt: coords(),
            hm: coords(),
            glag: coords(),
            dderiv: points(),
            stplen: triple(),
            isbd: vec![[0_i64; 3]; npt],
            xdiff: coords(),
            lfrac: coords(),
            ufrac: coords(),
            slbd_test: coords(),
            subd_test: coords(),
            vlag: triple(),
            betabd: triple(),
            predsq: triple(),
            xline: coords(),
            den_line: points(),
            xcauchy: coords(),
            s: coords(),
            mask_free: flags(),
            xtemp: coords(),
            mask_fixl: flags(),
            mask_fixu: flags(),
            new_mask_free: flags(),
            x: coords(),
            sxpt: points(),
            s_cauchy: coords(),
            den_cauchy: points(),
            dxpt: points(),
            cal: CalWs::new(n, npt),
        }
    }
}
/// Picks the column of `xpt` to replace after a trust-region step `d`, or `None` when no column
/// qualifies.
///
/// Each column `k` receives `score[k] = max(1, distsq[k] / rho^2)^4 * den[k]`, where
/// `distsq[k]` is the squared distance from `xpt[:, k]` to `xpt[:, kopt] + d` (when `ximproved`)
/// or to `xpt[:, kopt]` (otherwise), and `den` is `den_precomputed` when supplied, else the
/// output of `calden_into`. NaN scores, and the score of `kopt` itself when `!ximproved`, are
/// replaced by `-1`.
///
/// * If some score exceeds `1`, or `ximproved` holds and some score is positive, the first
///   column attaining the largest score is returned.
/// * Otherwise, if `ximproved`, the first column attaining the largest `distsq` is returned.
/// * Otherwise the result is `None`.
///
/// `_delta` is accepted but not read. `ws` supplies scratch storage and is overwritten.
#[expect(clippy::needless_range_loop)]
#[expect(clippy::too_many_arguments)]
pub(crate) fn setdrop_tr(
    kopt: usize,
    ximproved: bool,
    bmat: &Mat,
    d: &[f64],
    _delta: f64,
    rho: f64,
    xpt: &Mat,
    zmat: &Mat,
    ws: &mut GeostepWs,
    den_precomputed: Option<&[f64]>,
) -> Option<usize> {
    let n = xpt.nrows();
    let npt = xpt.ncols();
    let GeostepWs {
        distsq,
        weight,
        den,
        score,
        cal,
        ..
    } = ws;

    // Squared distance of every column from the reference point. `d` is only read when the
    // step improved the objective, in which case the reference is the new point.
    for k in 0..npt {
        distsq[k] = (0..n).fold(0.0_f64, |acc, i| {
            let anchor = if ximproved {
                xpt[[i, kopt]] + d[i]
            } else {
                xpt[[i, kopt]]
            };
            let gap = xpt[[i, k]] - anchor;
            acc + gap * gap
        });
    }

    // Denominators: reuse the caller's values when available, otherwise compute them.
    if let Some(src) = den_precomputed {
        den.copy_from_slice(src);
    } else {
        calden_into(kopt, bmat, d, xpt, zmat, cal, den);
    }

    // weight = max(1, distsq / rho^2)^4, evaluated as (m * m) * (m * m).
    let rho_sq = rho * rho;
    for k in 0..npt {
        let ratio = (distsq[k] / rho_sq).max(1.0);
        let ratio_sq = ratio * ratio;
        weight[k] = ratio_sq * ratio_sq;
        score[k] = weight[k] * den[k];
    }

    // Without an improvement the current best point must never be the one dropped.
    if !ximproved {
        score[kopt] = -1.0;
    }
    // NaN scores must not take part in the comparisons below.
    for k in 0..npt {
        if score[k].is_nan() {
            score[k] = -1.0;
        }
    }

    let worth_dropping =
        score.iter().any(|&v| v > 1.0) || (ximproved && score.iter().any(|&v| v > 0.0));
    if worth_dropping {
        // First index attaining the maximum score.
        let mut winner = 0;
        for k in 1..npt {
            if score[k] > score[winner] {
                winner = k;
            }
        }
        return Some(winner);
    }

    if !ximproved {
        return None;
    }

    // An improving step always replaces something: fall back to the farthest column.
    let mut farthest = 0;
    let mut longest = distsq[0];
    for k in 1..npt {
        if distsq[k] > longest {
            longest = distsq[k];
            farthest = k;
        }
    }
    Some(farthest)
}
/// Fills `d` with a step away from `xpt[:, kopt]`, built from row `knew` of `zmat` and column
/// `knew` of `bmat`.
///
/// What the body does, in order:
/// 1. Builds `pqlag` and `glag`. If the sum of `|glag|` is not finite, `d` becomes
///    `xpt[:, knew] - xopt` scaled by `min(0.5, delbar / norm(d))` and the function returns.
/// 2. For every column `k != kopt` of `xpt`, derives step-length limits `slbd <= subd` along
///    `xpt[:, k] - xopt` from `delbar`, `sl` and `su`, plus a third candidate `stpm`.
/// 3. Evaluates `predsq` for the three candidates of every column, takes the largest, and sets
///    `d` to the corresponding point (clamped to `[sl, su]`) minus `xopt`.
/// 4. Returns here if `delbar > 1.0e-2` (an `f32` literal widened to `f64`). Otherwise builds a
///    bounded gradient-direction step (`s_cauchy`) in two passes, the second with `glag`
///    negated, and switches `d` to it when its `knew`-th `calden_into` value beats the line
///    step's.
/// 5. Still on that second path, if `d` is all zeros or not finite, falls back to the scaled
///    `xpt[:, knew] - xopt` vector from step 1.
///
/// `knew` and `kopt` are 0-based column indices of `xpt` (the tests in this file pass the
/// captured index minus one). `sl` / `su` are per-coordinate lower / upper limits applied to
/// points, not to steps. `ws` supplies every scratch buffer and is overwritten; in particular
/// `ws.glag` is left negated when the second pass has run.
///
/// NOTE(review): names and the tests in this file suggest a port of PRIMA's BOBYQA `geostep`.
/// The mathematical meaning of `bmat`, `zmat`, `pqlag` and `glag` is not visible here —
/// confirm against that reference before relying on it.
///
/// # Panics
/// Indexing panics if `d`, `sl` or `su` hold fewer than `xpt.nrows()` entries, or if `ws` was
/// sized for different dimensions.
#[expect(clippy::too_many_arguments)] #[expect(clippy::too_many_lines)] #[expect(clippy::needless_range_loop)] #[expect(clippy::similar_names)] #[expect(clippy::cast_possible_wrap)] #[expect(clippy::cast_possible_truncation)] #[expect(clippy::cast_sign_loss)] pub(crate) fn geostep(
knew: usize,
kopt: usize,
bmat: &Mat,
delbar: f64,
sl: &[f64],
su: &[f64],
xpt: &Mat,
zmat: &Mat,
d: &mut [f64],
ws: &mut GeostepWs,
) {
let n = xpt.nrows();
let npt = xpt.ncols();
let GeostepWs {
distsq,
zrow_knew,
pqlag,
xopt,
hm,
glag,
dderiv,
stplen,
isbd,
xdiff,
lfrac,
ufrac,
slbd_test,
subd_test,
vlag,
betabd,
predsq,
xline,
den_line,
xcauchy,
s,
mask_free,
xtemp,
mask_fixl,
mask_fixu,
new_mask_free,
x,
sxpt,
s_cauchy,
den_cauchy,
dxpt,
cal,
..
} = ws;
// Copy row `knew` of `zmat`, then derive `pqlag` from it.
// NOTE(review): `matprod21_into` is presumably the matrix-vector product
// `zmat * zrow_knew` — confirm in `linalg`.
for j in 0..zmat.ncols() {
zrow_knew[j] = zmat[[knew, j]];
}
matprod21_into(zmat, zrow_knew, pqlag);
let alpha = pqlag[knew];
// `xopt` is a private copy of column `kopt`; `hm` receives the `hess_mul_into` result, with
// `dxpt` handed over as an extra buffer (its role is defined in `powalg`).
xopt.copy_from_slice(xpt.col(kopt));
hess_mul_into(xopt, xpt, pqlag, None, dxpt, hm);
for i in 0..n {
glag[i] = bmat[[i, knew]] + hm[i];
}
// If `glag` contains NaN or infinity nothing below is meaningful: return a step that points
// from `xopt` towards `xpt[:, knew]`, scaled by min(0.5, delbar / ||d||).
let glag_abs_sum: f64 = glag.iter().map(|&v| math::abs(v)).sum();
if !glag_abs_sum.is_finite() {
for i in 0..n {
d[i] = xpt[[i, knew]] - xopt[i];
}
let scale = (0.5_f64).min(delbar / norm(d));
for i in 0..n {
d[i] *= scale;
}
return;
}
// dderiv[k] = (matprod12_into result for column k) - glag . xopt.
// NOTE(review): `matprod12_into` is presumably the vector-matrix product `glag^T * xpt`,
// which would make this glag . (xpt[:, k] - xopt) — confirm.
matprod12_into(glag, xpt, dderiv);
let glag_dot_xopt = inprod(glag, xopt);
for k in 0..npt {
dderiv[k] -= glag_dot_xopt;
}
// Squared Euclidean distance of every column from `xopt`.
distsq.fill(0.0);
for k in 0..npt {
let xk = &xpt.col(k)[..n];
let mut sq = 0.0;
for i in 0..n {
let diff = xk[i] - xopt[i];
sq += diff * diff;
}
distsq[k] = sq;
}
// Columns skipped below keep zero step lengths and zero bound markers.
stplen.fill(0.0);
for e in isbd.iter_mut() {
*e = [0i64; 3];
}
// For each column k, work out the admissible step lengths along xpt[:, k] - xopt.
for k in 0..npt {
// The line through `xopt` and itself is skipped, as is any column whose `dderiv` is NaN.
if k == kopt || dderiv[k].is_nan() {
dderiv[k] = 0.0;
continue;
}
// Before looking at `sl` / `su`, step lengths are limited to [-subd_init, subd_init], i.e.
// a displacement of length at most `delbar`.
let subd_init = delbar / math::sqrt(distsq[k]);
let mut subd = subd_init;
let mut slbd = -subd_init;
let mut ilbd: i64 = 0;
let mut iubd: i64 = 0;
let sumin = (1.0_f64).min(subd_init);
xdiff.fill(0.0);
lfrac.fill(0.0);
ufrac.fill(0.0);
let xk = &xpt.col(k)[..n];
// Per coordinate: `lfrac` / `ufrac` default to +-subd_init (signed against / with `xdiff`)
// and are replaced by the step length that lands exactly on `sl[i]` / `su[i]` when that
// bound lies within |xdiff[i]| * subd_init of `xopt[i]`.
for i in 0..n {
xdiff[i] = xk[i] - xopt[i];
lfrac[i] = subd_init.copysign(-xdiff[i]);
if sl[i] - xopt[i] > -math::abs(xdiff[i]) * subd_init {
lfrac[i] = (sl[i] - xopt[i]) / xdiff[i];
}
ufrac[i] = subd_init.copysign(xdiff[i]);
if su[i] - xopt[i] < math::abs(xdiff[i]) * subd_init {
ufrac[i] = (su[i] - xopt[i]) / xdiff[i];
}
}
// Candidates for the lower step limit: `lfrac` where the line moves up in coordinate i,
// `ufrac` where it moves down, the unconstrained limit where it does not move.
slbd_test.fill(slbd);
for i in 0..n {
if xdiff[i] > 0.0 {
slbd_test[i] = lfrac[i];
} else if xdiff[i] < 0.0 {
slbd_test[i] = ufrac[i];
}
}
let any_slbd_better = slbd_test.iter().any(|&v| v > slbd);
if any_slbd_better {
// Largest non-NaN candidate, first index on ties. `any_slbd_better` guarantees at least one
// candidate compares greater than `slbd`, so `best_i` is always `Some` here.
let mut best_val = f64::NEG_INFINITY;
let mut best_i: Option<usize> = None;
for i in 0..n {
if !slbd_test[i].is_nan() && slbd_test[i] > best_val {
best_val = slbd_test[i];
best_i = Some(i);
}
}
let ilbd_pos = best_i.unwrap(); slbd = slbd_test[ilbd_pos];
// Signed 1-based coordinate index: negative when the limit came from `sl` (xdiff > 0),
// positive when it came from `su`; it is decoded after the line search below.
ilbd = -((ilbd_pos + 1) as i64) * (1.0_f64.copysign(xdiff[ilbd_pos]) as i64);
}
// Mirror image for the upper step limit.
subd_test.fill(subd_init);
for i in 0..n {
if xdiff[i] > 0.0 {
subd_test[i] = ufrac[i];
} else if xdiff[i] < 0.0 {
subd_test[i] = lfrac[i];
}
}
let any_subd_better = subd_test.iter().any(|&v| v < subd_init);
if any_subd_better {
// Smallest non-NaN candidate, first index on ties; the limit is never allowed below `sumin`.
let mut best_val = f64::INFINITY;
let mut best_i: Option<usize> = None;
for i in 0..n {
if !subd_test[i].is_nan() && subd_test[i] < best_val {
best_val = subd_test[i];
best_i = Some(i);
}
}
let iubd_pos = best_i.unwrap(); subd = sumin.max(subd_test[iubd_pos]);
iubd = ((iubd_pos + 1) as i64) * (1.0_f64.copysign(xdiff[iubd_pos]) as i64);
}
// Third candidate: 0.5 in general; for the `knew` column it is
// -0.5 * dderiv / (1 - dderiv) when 1 - dderiv is non-zero, else `slbd`. It is then clamped
// into [slbd, subd].
let mut stpm = 0.5_f64; if k == knew {
stpm = slbd;
if math::abs(1.0 - dderiv[k]) > 0.0 {
stpm = -0.5 * dderiv[k] / (1.0 - dderiv[k]);
}
}
stpm = slbd.max(subd.min(stpm));
stplen[[0, k]] = slbd;
stplen[[1, k]] = subd;
stplen[[2, k]] = stpm;
// `stpm` has no associated bound, hence the trailing 0.
isbd[k] = [ilbd, iubd, 0i64];
}
// vlag = t * (1 - t) * dderiv for each candidate step length t ...
vlag.fill(0.0);
for k in 0..npt {
for i in 0..3 {
vlag[[i, k]] = stplen[[i, k]] * (1.0 - stplen[[i, k]]) * dderiv[k];
}
}
// ... except in column `knew`, where it is t * (t * (1 - dderiv) + dderiv).
for i in 0..3 {
let t = stplen[[i, knew]];
vlag[[i, knew]] = t * (t * (1.0 - dderiv[knew]) + dderiv[knew]);
}
// NaN entries are zeroed so they cannot win the maximisation below.
for k in 0..npt {
for i in 0..3 {
if vlag[[i, k]].is_nan() {
vlag[[i, k]] = 0.0;
}
}
}
// betabd = 0.5 * (t * (1 - t) * distsq)^2.
betabd.fill(0.0);
for k in 0..npt {
for i in 0..3 {
let t = stplen[[i, k]] * (1.0 - stplen[[i, k]]) * distsq[k];
betabd[[i, k]] = 0.5 * t * t;
}
}
// predsq = vlag^2 * (vlag^2 + alpha * betabd), with NaN replaced by 0.
predsq.fill(0.0);
for k in 0..npt {
for i in 0..3 {
let v = vlag[[i, k]];
let v2 = v * v;
predsq[[i, k]] = v2 * (v2 + alpha * betabd[[i, k]]);
if predsq[[i, k]].is_nan() {
predsq[[i, k]] = 0.0;
}
}
}
// Largest `predsq` overall: first the best column per row, then the best row. Strict `>`
// keeps the first index on ties in both stages.
let mut ksqs = [0usize; 3];
for i in 0..3 {
let mut best = predsq[[i, 0]];
let mut best_k = 0;
for k in 1..npt {
if predsq[[i, k]] > best {
best = predsq[[i, k]];
best_k = k;
}
}
ksqs[i] = best_k;
}
let picked = [
predsq[[0, ksqs[0]]],
predsq[[1, ksqs[1]]],
predsq[[2, ksqs[2]]],
];
let mut isq = 0;
let mut best_predsq = picked[0];
for i in 1..3 {
if picked[i] > best_predsq {
best_predsq = picked[i];
isq = i;
}
}
let ksq = ksqs[isq];
let stpsiz = stplen[[isq, ksq]];
let ibd = isbd[ksq][isq];
// End point of the chosen line step, clamped coordinate-wise into [sl, su].
for i in 0..n {
xline[i] = (xopt[i] + stpsiz * (xpt[[i, ksq]] - xopt[i]))
.min(su[i])
.max(sl[i]);
}
// If the step length was set by a bound, put that coordinate exactly on the bound
// (negative marker: `sl`, positive marker: `su`; the magnitude is the 1-based coordinate).
if ibd < 0 {
let idx = (-ibd) as usize - 1; xline[idx] = sl[idx];
}
if ibd > 0 {
let idx = ibd as usize - 1; xline[idx] = su[idx];
}
for i in 0..n {
d[i] = xline[i] - xopt[i];
}
// Denominators for the line step; only entry `knew` is consulted, at the very end.
calden_into(kopt, bmat, d, xpt, zmat, cal, den_line);
// The alternative step below is only attempted when `delbar` does not exceed the
// single-precision literal 1.0e-2 widened to f64.
if delbar > f64::from(1.0e-2_f32) {
return;
}
// `bigstp` (2 * delbar) is stored in `s[i]` as a marker for "coordinate i is still free".
let bigstp = delbar + delbar;
xcauchy.copy_from_slice(xopt);
let mut vlagsq_cauchy = 0.0_f64;
// Two passes: the second one flips the sign of `glag` (in place, and it is not restored).
for uphill in 0usize..=1 {
if uphill == 1 {
for i in 0..n {
glag[i] = -glag[i];
}
}
s.fill(0.0);
mask_free.fill(false);
// A coordinate is free when `xopt` is strictly above `sl` with positive `glag`, or strictly
// below `su` with negative `glag`.
for i in 0..n {
mask_free[i] =
(xopt[i] - sl[i]).min(glag[i]) > 0.0 || (xopt[i] - su[i]).max(glag[i]) < 0.0;
}
for i in 0..n {
if mask_free[i] {
s[i] = bigstp;
}
}
// Squared norm of `glag` over the free coordinates; nothing to do if it vanishes or is NaN.
let mut ggfree = 0.0_f64;
for i in 0..n {
if mask_free[i] {
ggfree += glag[i] * glag[i];
}
}
if ggfree <= 0.0 || ggfree.is_nan() {
continue;
}
let mut sfixsq = 0.0_f64;
let mut grdstp = 0.0_f64;
xtemp.fill(0.0);
// Up to n rounds: pick the step length along -glag that uses up the remaining budget
// delbar^2 - sfixsq, fix every free coordinate that would reach a bound, and repeat while
// something new was fixed and free coordinates remain.
for _k in 0..n {
let resis = delbar * delbar - sfixsq;
if resis <= 0.0 {
break;
}
let ssqsav = sfixsq;
grdstp = math::sqrt(resis / ggfree);
for i in 0..n {
xtemp[i] = xopt[i] - grdstp * glag[i];
}
mask_fixl.fill(false);
mask_fixu.fill(false);
for i in 0..n {
mask_fixl[i] = s[i] >= bigstp && xtemp[i] <= sl[i];
mask_fixu[i] = s[i] >= bigstp && xtemp[i] >= su[i];
}
new_mask_free.fill(false);
for i in 0..n {
new_mask_free[i] = s[i] >= bigstp && !(mask_fixl[i] || mask_fixu[i]);
}
// Newly fixed coordinates step exactly onto their bound.
for i in 0..n {
if mask_fixl[i] {
s[i] = sl[i] - xopt[i];
}
if mask_fixu[i] {
s[i] = su[i] - xopt[i];
}
}
for i in 0..n {
if mask_fixl[i] || mask_fixu[i] {
sfixsq += s[i] * s[i];
}
}
ggfree = 0.0;
for i in 0..n {
if new_mask_free[i] {
ggfree += glag[i] * glag[i];
}
}
mask_free.copy_from_slice(new_mask_free);
if !(sfixsq > ssqsav && ggfree > 0.0) {
break;
}
}
// Assemble the candidate point `x`: `sl` where glag > 0 and `su` elsewhere, then `xopt`
// where the step is zero, then the clamped gradient point for still-free coordinates.
x.fill(0.0);
for i in 0..n {
if glag[i] > 0.0 {
x[i] = sl[i];
}
}
for i in 0..n {
if glag[i] <= 0.0 {
x[i] = su[i];
}
}
for i in 0..n {
if math::abs(s[i]) <= 0.0 {
x[i] = xopt[i];
}
}
for i in 0..n {
xtemp[i] = (xopt[i] - grdstp * glag[i]).min(su[i]).max(sl[i]);
}
for i in 0..n {
if s[i] >= bigstp {
x[i] = xtemp[i];
}
}
// Replace the "free" marker by the actual step component.
for i in 0..n {
if s[i] >= bigstp {
s[i] = -grdstp * glag[i];
}
}
// gs = glag . s; curv = sum_k pqlag[k] * sxpt[k]^2, negated on the second pass to match the
// negated `glag`.
let gs = inprod(glag, s);
matprod12_into(s, xpt, sxpt);
let mut curv = 0.0_f64;
for k in 0..npt {
let t = pqlag[k] * sxpt[k];
curv += sxpt[k] * t;
}
if uphill == 1 {
curv = -curv;
}
// When `curv` lies strictly between -gs and -(1 + sqrt(2)) * gs the step is shortened by
// -gs / curv (and `x` recomputed); otherwise the full step is kept.
let vlagsq: f64;
if curv > -gs && curv < -(1.0 + math::sqrt(2.0)) * gs {
let scaling = -gs / curv;
for i in 0..n {
x[i] = (xopt[i] + scaling * s[i]).min(su[i]).max(sl[i]);
}
let half_gs_scaling = 0.5 * gs * scaling;
vlagsq = half_gs_scaling * half_gs_scaling;
} else {
let t = gs + 0.5 * curv;
vlagsq = t * t;
}
// Keep whichever pass gives the larger value (the first pass wins ties).
if vlagsq > vlagsq_cauchy {
xcauchy.copy_from_slice(x);
vlagsq_cauchy = vlagsq;
}
}
for i in 0..n {
s_cauchy[i] = xcauchy[i] - xopt[i];
}
// Prefer the alternative step when its `knew`-th denominator is larger than the line step's
// (floored at 0), or when the line step's is NaN.
calden_into(kopt, bmat, s_cauchy, xpt, zmat, cal, den_cauchy);
if den_cauchy[knew] > den_line[knew].max(0.0) || den_line[knew].is_nan() {
d.copy_from_slice(s_cauchy);
}
// Last resort: a zero or non-finite `d` is replaced by the scaled direction towards
// `xpt[:, knew]`, exactly as in the early exit above.
let d_abs_sum: f64 = d.iter().map(|&v| math::abs(v)).sum();
if d_abs_sum <= 0.0 || !d_abs_sum.is_finite() {
for i in 0..n {
d[i] = xpt[[i, knew]] - xopt[i];
}
let scale = (0.5_f64).min(delbar / norm(d));
for i in 0..n {
d[i] *= scale;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mat::Mat;
    use crate::test_support::{self, DiffStats};

    /// Replays every captured `setdrop_tr` call and compares the chosen index with the recorded
    /// one. The recording is 1-based, with `0` standing for "nothing chosen".
    #[test]
    fn setdrop_tr_matches_prima_on_every_captured_state() {
        let corpus = test_support::load_states("setdrop_tr");
        assert!(!corpus.is_empty());

        for st in &corpus {
            let inputs = &st.entry;
            let xpt = inputs.mat("xpt");
            let mut ws = GeostepWs::new(xpt.nrows(), xpt.ncols());

            let kopt = inputs.usize("kopt") - 1;
            let ximproved = inputs.i64("ximproved") != 0;
            let bmat = inputs.mat("bmat");
            let d = inputs.vec("d");
            let delta = inputs.f64("delta");
            let rho = inputs.f64("rho");
            let zmat = inputs.mat("zmat");

            let knew = setdrop_tr(
                kopt, ximproved, &bmat, &d, delta, rho, &xpt, &zmat, &mut ws, None,
            );

            assert_eq!(
                knew.map_or(0, |k| k + 1),
                st.exit.usize("knew"),
                "{}: knew",
                st.problem
            );
        }
    }

    /// Replays every captured `geostep` call and feeds the resulting step into the shared
    /// difference statistics, which are reported once at the end.
    #[test]
    fn geostep_matches_prima_on_every_captured_state() {
        let corpus = test_support::load_states("geostep");
        assert!(!corpus.is_empty());

        let mut stats = DiffStats::default();
        for st in &corpus {
            let inputs = &st.entry;
            let xpt = inputs.mat("xpt");
            let mut d = vec![0.0; xpt.nrows()];
            let mut ws = GeostepWs::new(xpt.nrows(), xpt.ncols());

            let knew = inputs.usize("knew") - 1;
            let kopt = inputs.usize("kopt") - 1;
            let bmat = inputs.mat("bmat");
            let delbar = inputs.f64("delbar");
            let sl = inputs.vec("sl");
            let su = inputs.vec("su");
            let zmat = inputs.mat("zmat");

            geostep(
                knew, kopt, &bmat, delbar, &sl, &su, &xpt, &zmat, &mut d, &mut ws,
            );

            stats.slice("d", &d, &st.exit.vec("d"));
        }
        stats.report("geostep");
    }

    /// With `ximproved == false` the current best column must never be selected, whichever
    /// column plays that role.
    #[test]
    fn setdrop_tr_without_improvement_never_drops_kopt() {
        let n = 2usize;
        let npt = 5usize;

        // Column-major: each row below is one column of `xpt`.
        let xpt = Mat::from_col_major(
            n,
            npt,
            vec![
                0.0, 0.0, //
                0.5, 0.0, //
                0.0, 0.5, //
                -0.5, 0.0, //
                0.0, -0.5, //
            ],
        );
        let bmat = Mat::zeros(n, npt + n);
        // Column-major: each row below is one column of `zmat`.
        let zmat = Mat::from_col_major(
            npt,
            npt - n - 1,
            vec![
                -5.656_854_249_492_38,
                2.828_427_124_746_19,
                0.0,
                2.828_427_124_746_19,
                0.0,
                //
                0.0,
                0.0,
                2.828_427_124_746_19,
                -5.656_854_249_492_38,
                2.828_427_124_746_19,
            ],
        );

        let d = vec![0.1, 0.1];
        let delta = 1.0;
        let rho = 0.5;

        for kopt in 0..npt {
            // A fresh workspace per call keeps the iterations independent of each other.
            let mut ws = GeostepWs::new(n, npt);
            let knew = setdrop_tr(
                kopt, false, &bmat, &d, delta, rho, &xpt, &zmat, &mut ws, None,
            );
            assert!(
                knew != Some(kopt),
                "setdrop_tr with ximproved=false returned kopt={kopt}"
            );
        }
    }

    /// All columns coincide and only `bmat[0, 0]` is non-zero; the call must report that
    /// nothing is dropped.
    #[test]
    fn setdrop_tr_excludes_kopt_even_when_it_would_otherwise_win() {
        let n = 1;
        let npt = 3;
        let kopt = 0;

        let xpt = Mat::zeros(n, npt);
        let zmat = Mat::zeros(npt, npt - n - 1);
        let mut bmat = Mat::zeros(n, npt + n);
        bmat[[0, 0]] = 3.0;

        let mut ws = GeostepWs::new(n, npt);
        let knew = setdrop_tr(
            kopt,
            false,
            &bmat,
            &[1.0],
            1.0,
            1.0,
            &xpt,
            &zmat,
            &mut ws,
            None,
        );

        assert_eq!(knew, None);
    }

    /// On the first captured state the returned step must be non-zero and must keep
    /// `xopt + d` within `[sl, su]` up to a tolerance of `1e-12`.
    #[test]
    #[expect(clippy::similar_names)]
    fn geostep_returns_a_nonzero_step_inside_the_bounds() {
        let corpus = test_support::load_states("geostep");
        let e = &corpus[0].entry;

        let xpt = e.mat("xpt");
        let n = xpt.nrows();
        let xopt_col = e.usize("kopt") - 1;
        let xopt: Vec<f64> = xpt.col(xopt_col).to_vec();
        let sl = e.vec("sl");
        let su = e.vec("su");

        let mut d = vec![0.0; n];
        let mut ws = GeostepWs::new(n, xpt.ncols());
        geostep(
            e.usize("knew") - 1,
            xopt_col,
            &e.mat("bmat"),
            e.f64("delbar"),
            &sl,
            &su,
            &xpt,
            &e.mat("zmat"),
            &mut d,
            &mut ws,
        );

        let step_norm = norm(&d);
        assert!(
            step_norm > 0.0,
            "geostep returned zero step, ||d|| = {step_norm}"
        );

        for i in 0..n {
            let xi = xopt[i] + d[i];
            assert!(
                xi >= sl[i] - 1e-12,
                "d violates lower bound at i={i}: xopt[i]+d[i]={xi} < sl[i]={}",
                sl[i]
            );
            assert!(
                xi <= su[i] + 1e-12,
                "d violates upper bound at i={i}: xopt[i]+d[i]={xi} > su[i]={}",
                su[i]
            );
        }
    }
}