use crate::consts::{EPS, REALMAX, REALMIN};
use crate::linalg::inprod;
use crate::mat::Mat;
use crate::math;
use crate::powalg::hess_mul_into;
use crate::util::interval_max;
/// Capacity of the `xgrid`/`fgrid` sample buffers that `trsbox` hands to `interval_max`.
///
/// The 2-D search requests `2 * round(17 * hangt_bd + 4.1)` samples with `0 < hangt_bd <= 1`,
/// which is at most `2 * 21 = 42`.
const GRID_SIZE_MAX: usize = 42;

/// Preallocated scratch storage for `trsbox`.
///
/// Build it once with `TrsboxWs::new(n, npt)` and reuse it across calls so the solver itself
/// never allocates. `n` is the number of variables; `npt` is the length of `pq` (presumably the
/// number of interpolation points — confirm against the caller).
#[derive(Debug, Clone)]
pub(crate) struct TrsboxWs {
    // Working copies of the model data (rescaled inside `trsbox` when the gradient is huge).
    gopt: Vec<f64>,
    pq: Vec<f64>,
    hq: Mat,
    // Per-variable bound status: 0 = free, 1 = fixed at the upper bound, -1 = at the lower bound.
    xbdi: Vec<i32>,
    // Length-`n` work vectors for the CG and 2-D search phases.
    gnew: Vec<f64>,
    s: Vec<f64>,
    xnew: Vec<f64>,
    xtest: Vec<f64>,
    sbound: Vec<f64>,
    dold: Vec<f64>,
    hdred: Vec<f64>,
    dred: Vec<f64>,
    ssq: Vec<f64>,
    tanbd: Vec<f64>,
    sqdscr: Vec<f64>,
    hs: Vec<f64>,
    // Length-`npt` scratch passed through to `hess_mul_into`.
    dxpt: Vec<f64>,
    // Sample abscissae and values for `interval_max`, `GRID_SIZE_MAX` entries each.
    xgrid: Vec<f64>,
    fgrid: Vec<f64>,
}

impl TrsboxWs {
    /// Allocate zero-filled buffers for `n` variables and a `pq` of length `npt`.
    pub(crate) fn new(n: usize, npt: usize) -> Self {
        let zeros = |len: usize| vec![0.0_f64; len];
        Self {
            gopt: zeros(n),
            pq: zeros(npt),
            hq: Mat::zeros(n, n),
            xbdi: vec![0_i32; n],
            gnew: zeros(n),
            s: zeros(n),
            xnew: zeros(n),
            xtest: zeros(n),
            sbound: zeros(n),
            dold: zeros(n),
            hdred: zeros(n),
            dred: zeros(n),
            ssq: zeros(n),
            tanbd: zeros(n),
            sqdscr: zeros(n),
            hs: zeros(n),
            dxpt: zeros(npt),
            xgrid: zeros(GRID_SIZE_MAX),
            fgrid: zeros(GRID_SIZE_MAX),
        }
    }
}
/// Approximately solve BOBYQA's bound-constrained trust-region subproblem.
///
/// Port of PRIMA's BOBYQA `trsbox` (the tests replay captured PRIMA states). Starting from
/// `d = 0` it runs two phases:
///
/// 1. A truncated conjugate-gradient (CG) iteration on the variables not fixed at a bound. The
///    iteration restarts whenever a free variable hits one of its bounds; that variable is then
///    fixed.
/// 2. If phase 1 reaches the trust-region boundary, a sequence of 2-D searches. Each one rotates
///    the free part of `d` within the plane spanned by the reduced `d` and the reduced gradient.
///
/// # Parameters
/// * `delta` – trust-region radius; the free part of `d` is kept within `delta`.
/// * `gopt_in` – model gradient at `xopt`; its length defines `n`.
/// * `hq_in`, `pq_in`, `xpt` – Hessian data, combined by `hess_mul_into` into Hessian–vector
///   products (presumably explicit part `hq` plus implicit part from `pq`/`xpt`; see `powalg`).
/// * `sl`, `su` – lower/upper bounds that `xopt + d` must respect.
/// * `tol` – relative tolerance for the termination tests, measured against `qred`, the
///   accumulated model reduction.
/// * `xopt` – the point the step is taken from.
/// * `d` – output; overwritten with the step (zero-filled first).
/// * `ws` – scratch buffers sized by `TrsboxWs::new(n, npt)`. `copy_from_slice` panics if the
///   sizes disagree.
///
/// # Returns
/// `crvmin`, which is one of:
/// * `0.0` if the 2-D search phase was entered, or if no positive curvature was recorded;
/// * otherwise the smallest Rayleigh quotient `s'Hs / s's` over CG steps of positive length
///   that were not cut short by a bound.
///
/// When the model was rescaled internally, `crvmin` is divided by the scale factor before
/// returning.
#[expect(clippy::too_many_arguments)]
#[expect(clippy::too_many_lines)]
#[expect(clippy::similar_names)]
#[expect(clippy::neg_cmp_op_on_partial_ord)]
#[expect(clippy::cast_possible_truncation)]
#[expect(clippy::cast_sign_loss)]
#[expect(clippy::cast_possible_wrap)]
#[expect(clippy::needless_range_loop)]
pub(crate) fn trsbox(
    delta: f64,
    gopt_in: &[f64],
    hq_in: &Mat,
    pq_in: &[f64],
    sl: &[f64],
    su: &[f64],
    tol: f64,
    xopt: &[f64],
    xpt: &Mat,
    d: &mut [f64],
    ws: &mut TrsboxWs,
) -> f64 {
    let n = gopt_in.len();
    // Borrow every scratch buffer out of the workspace at once.
    let TrsboxWs {
        gopt,
        pq,
        hq,
        xbdi,
        gnew,
        s,
        xnew,
        xtest,
        sbound,
        dold,
        hdred,
        dred,
        ssq,
        tanbd,
        sqdscr,
        hs,
        dxpt,
        xgrid,
        fgrid,
    } = ws;
    // If the gradient is huge, multiply gradient and both Hessian parts by the same factor
    // `modscal`. This leaves the model's minimiser unchanged. `crvmin` is a curvature, so it
    // scales with the model and is divided by `modscal` before returning.
    let max_abs_gopt = gopt_in.iter().fold(0.0_f64, |m, &v| m.max(math::abs(v)));
    let (scaled, modscal): (bool, f64);
    if max_abs_gopt > 1.0e12 {
        // Never let the factor underflow below 2 * REALMIN.
        let ms = (2.0 * REALMIN).max(1.0 / max_abs_gopt);
        for i in 0..n {
            gopt[i] = gopt_in[i] * ms;
        }
        for (pk, &pk_in) in pq.iter_mut().zip(pq_in) {
            *pk = pk_in * ms;
        }
        hq.copy_from(hq_in);
        for j in 0..n {
            for v in hq.col_mut(j) {
                *v *= ms;
            }
        }
        scaled = true;
        modscal = ms;
    } else {
        gopt.copy_from_slice(gopt_in);
        pq.copy_from_slice(pq_in);
        hq.copy_from(hq_in);
        scaled = false;
        modscal = 1.0;
    }
    let mut iact: Option<usize>;
    let mut dredsq = 0.0;
    let mut ggsav;
    // xbdi[i]: 0 = free, 1 = fixed at su[i], -1 = fixed at sl[i].
    // A variable starts fixed if it already sits at a bound and the descent direction -gopt[i]
    // points outward. The lower-bound pass runs second, so it wins if both tests hold.
    xbdi.fill(0);
    for i in 0..n {
        if xopt[i] >= su[i] && gopt[i] <= 0.0 {
            xbdi[i] = 1;
        }
    }
    for i in 0..n {
        if xopt[i] <= sl[i] && gopt[i] >= 0.0 {
            xbdi[i] = -1;
        }
    }
    let mut nact = xbdi.iter().filter(|&&v| v != 0).count();
    d.fill(0.0);
    // -REALMAX is the "not set yet" sentinel for crvmin.
    let mut crvmin = -REALMAX;
    gnew.copy_from_slice(gopt);
    // Squared norm of the reduced (free-variable) gradient.
    let mut gredsq = masked_sumsq(gnew, xbdi);
    let mut delsq = delta * delta;
    // Accumulated model reduction.
    let mut qred = 0.0;
    let mut beta = 0.0;
    // CG iterations since the last (re)start.
    let mut itercg = 0_usize;
    let mut twod_search = false;
    debug_assert!(nact <= n, "nact counts at-bound variables of n total");
    let mut maxiter = (10_000).min((n - nact) * (n - nact));
    s.fill(0.0);
    // ---- Phase 1: truncated CG on the free variables. ----
    for _iter in 0..maxiter {
        // Squared radius still available to the free part of d; none left means we are on the
        // trust-region boundary, so hand over to the 2-D search.
        let resid = delsq - masked_sumsq(d, xbdi);
        if resid <= 0.0 {
            twod_search = true;
            break;
        }
        // Steepest descent right after a (re)start, otherwise the CG update with ratio `beta`.
        if itercg == 0 {
            for i in 0..n {
                s[i] = -gnew[i];
            }
        } else {
            for i in 0..n {
                s[i] = beta * s[i] - gnew[i];
            }
        }
        // Fixed variables do not move.
        for i in 0..n {
            if xbdi[i] != 0 {
                s[i] = 0.0;
            }
        }
        let stepsq = inprod(s, s);
        let ds = masked_inprod(d, s, xbdi);
        // Stop when any of these holds (written negated so NaNs also stop):
        // - the direction is negligible;
        // - the reduced gradient is small relative to the reduction so far;
        // - `ds` is NaN.
        if !(stepsq > EPS * delsq && gredsq * delsq > (tol * qred) * (tol * qred) && !ds.is_nan()) {
            break;
        }
        // `bstep` is the positive root of stepsq*t^2 + 2*ds*t - resid = 0, i.e. the step along
        // s that reaches the trust-region boundary. The `max` keeps rounding from shrinking the
        // square root below either term. The `ds >= 0` branch uses the rationalised form to
        // avoid cancellation.
        let sqrtd = math::sqrt(stepsq * resid + ds * ds)
            .max(math::sqrt(stepsq * resid))
            .max(math::abs(ds));
        let bstep = if ds >= 0.0 {
            resid / (sqrtd + ds)
        } else {
            (sqrtd - ds) / stepsq
        };
        if bstep <= 0.0 || !bstep.is_finite() {
            break;
        }
        hess_mul_into(s, xpt, pq, Some(&*hq), dxpt, hs);
        // Curvature of the model along s (free part only).
        let shs = masked_inprod(s, hs, xbdi);
        let mut stplen = bstep;
        // With positive curvature, take the CG line-search step gredsq / shs if it is shorter.
        if shs > 0.0 {
            stplen = bstep.min(gredsq / shs);
        }
        // Per-coordinate limit on the step so that xopt + d + t*s stays inside [sl, su].
        for i in 0..n {
            xnew[i] = xopt[i] + d[i];
        }
        for i in 0..n {
            xtest[i] = xnew[i] + stplen * s[i];
        }
        sbound.fill(stplen);
        for i in 0..n {
            if s[i] > 0.0 && xtest[i] > su[i] {
                sbound[i] = (su[i] - xnew[i]) / s[i];
            }
        }
        for i in 0..n {
            if s[i] < 0.0 && xtest[i] < sl[i] {
                sbound[i] = (sl[i] - xnew[i]) / s[i];
            }
        }
        for i in 0..n {
            if sbound[i].is_nan() {
                sbound[i] = stplen;
            }
        }
        // If a bound cuts the step short, `iact` is the first index with the smallest limit.
        iact = None;
        if sbound.iter().any(|&v| v < stplen) {
            let mut imin = 0;
            for i in 1..n {
                if sbound[i] < sbound[imin] {
                    imin = i;
                }
            }
            iact = Some(imin);
            stplen = sbound[imin];
        }
        let mut sdec = 0.0;
        if stplen > 0.0 {
            itercg += 1;
            // Only steps not truncated by a bound contribute to crvmin.
            let rayleighq = shs / stepsq;
            if iact.is_none() && rayleighq > 0.0 {
                if crvmin <= -REALMAX {
                    crvmin = rayleighq;
                } else {
                    crvmin = crvmin.min(rayleighq);
                }
            }
            ggsav = gredsq;
            for i in 0..n {
                gnew[i] += stplen * hs[i];
            }
            gredsq = masked_sumsq(gnew, xbdi);
            dold.copy_from_slice(d);
            for i in 0..n {
                d[i] += stplen * s[i];
            }
            // Revert and stop if the step overflowed.
            let abs_sum: f64 = d.iter().map(|&v| math::abs(v)).sum();
            if !abs_sum.is_finite() {
                d.copy_from_slice(dold);
                break;
            }
            // Model decrease along s: t*gredsq_old - t^2*shs/2, clamped at zero.
            sdec = (stplen * (ggsav - 0.5 * stplen * shs)).max(0.0);
            qred += sdec;
        } else {
            ggsav = gredsq;
        }
        if let Some(ia) = iact {
            // A free variable hit a bound: fix it there (the sign of s picks which bound) and
            // restart CG. xbdi is updated before the possible exit below.
            nact += 1;
            debug_assert!(math::abs(s[ia]) > 0.0, "S(IACT) /= 0");
            xbdi[ia] = 1.0_f64.copysign(s[ia]) as i32;
            if nact >= n {
                break;
            }
            // The newly fixed component no longer counts toward the free-part radius.
            delsq -= d[ia] * d[ia];
            if delsq <= 0.0 {
                twod_search = true;
                break;
            }
            beta = 0.0;
            itercg = 0;
            gredsq = masked_sumsq(gnew, xbdi);
        } else if stplen < bstep {
            // Interior step: continue CG unless either holds:
            // - the iteration count reached the number of free variables;
            // - progress (`sdec`) is small relative to `qred`.
            debug_assert!(
                nact < n,
                "nact==n breaks before reaching the CG continuation test"
            );
            if itercg >= n - nact || sdec <= tol * qred || sdec.is_nan() || qred.is_nan() {
                break;
            }
            beta = gredsq / ggsav;
        } else {
            // The step reached the trust-region boundary.
            twod_search = true;
            break;
        }
    }
    if twod_search {
        crvmin = 0.0;
        debug_assert!(
            nact < n,
            "twod_search is only set while a free variable remains"
        );
        maxiter = 10 * (n - nact);
    } else {
        maxiter = 0;
    }
    // ---- Phase 2: 2-D searches on the boundary. ----
    // `nactsav` records nact at the last recomputation of dred/hdred; starting one below nact
    // (together with `iter1 == 1`) forces a recomputation on the first pass.
    let mut nactsav: isize = nact as isize - 1;
    hdred.fill(0.0);
    for iter1 in 1..=maxiter {
        // Fix any free variable that has reached a bound.
        for i in 0..n {
            xnew[i] = xopt[i] + d[i];
        }
        for i in 0..n {
            if xbdi[i] == 0 && xnew[i] >= su[i] {
                xbdi[i] = 1;
            }
        }
        for i in 0..n {
            if xbdi[i] == 0 && xnew[i] <= sl[i] {
                xbdi[i] = -1;
            }
        }
        nact = xbdi.iter().filter(|&&v| v != 0).count();
        // A 2-D search needs at least two free variables.
        if nact >= n - 1 {
            break;
        }
        gredsq = masked_sumsq(gnew, xbdi);
        let dredg = masked_inprod(d, gnew, xbdi);
        // Recompute the reduced d and its Hessian product when the active set grew; otherwise
        // hdred is carried forward by the rotation update at the end of the loop.
        if iter1 == 1 || nact as isize > nactsav {
            dredsq = masked_sumsq(d, xbdi);
            dred.copy_from_slice(d);
            for i in 0..n {
                if xbdi[i] != 0 {
                    dred[i] = 0.0;
                }
            }
            hess_mul_into(dred, xpt, pq, Some(&*hq), dxpt, hdred);
            nactsav = nact as isize;
        }
        // gredsq*dredsq - dredg^2 >= 0 by Cauchy–Schwarz; it is ~0 when the reduced gradient is
        // parallel to the reduced d, leaving nothing to rotate towards.
        let mut temp = gredsq * dredsq - dredg * dredg;
        if !(temp > tol * tol * (gredsq * dredsq).max(qred * qred)) {
            break;
        }
        temp = math::sqrt(temp);
        // On the free variables, s is orthogonal to the reduced d and has the same norm
        // (algebraically: s.dred = 0 and |s|^2 = dredsq).
        for i in 0..n {
            s[i] = (dredg * d[i] - dredsq * gnew[i]) / temp;
        }
        for i in 0..n {
            if xbdi[i] != 0 {
                s[i] = 0.0;
            }
        }
        // Reduced-gradient component along s: (dredg^2 - dredsq*gredsq) / temp = -temp.
        let sredg = -temp;
        // For each free variable, bound `hangt` (tangent of half the rotation angle) so that
        // the rotated xopt + d stays within [sl, su]. Coordinate i moves on a circle of radius
        // sqrt(d_i^2 + s_i^2).
        ssq.fill(0.0);
        for i in 0..n {
            ssq[i] = d[i] * d[i] + s[i] * s[i];
        }
        tanbd.fill(1.0);
        sqdscr.fill(-REALMAX);
        for i in 0..n {
            if xbdi[i] == 0 && xopt[i] - sl[i] < math::sqrt(ssq[i]) {
                sqdscr[i] = math::sqrt(0.0_f64.max(ssq[i] - (xopt[i] - sl[i]) * (xopt[i] - sl[i])));
            }
        }
        for i in 0..n {
            if sqdscr[i] - s[i] > 0.0 {
                tanbd[i] = tanbd[i].min((xnew[i] - sl[i]) / (sqdscr[i] - s[i]));
            }
        }
        for v in sqdscr.iter_mut() {
            *v = -REALMAX;
        }
        for i in 0..n {
            if xbdi[i] == 0 && su[i] - xopt[i] < math::sqrt(ssq[i]) {
                sqdscr[i] = math::sqrt(0.0_f64.max(ssq[i] - (su[i] - xopt[i]) * (su[i] - xopt[i])));
            }
        }
        for i in 0..n {
            if sqdscr[i] + s[i] > 0.0 {
                tanbd[i] = tanbd[i].min((su[i] - xnew[i]) / (sqdscr[i] + s[i]));
            }
        }
        for v in tanbd.iter_mut() {
            if v.is_nan() {
                *v = 0.0;
            }
        }
        // `hangt_bd` stays 1.0 unless some coordinate limits the rotation; then `iact` is the
        // first such coordinate with the tightest limit.
        iact = None;
        let mut hangt_bd = 1.0;
        if tanbd.iter().any(|&v| v < 1.0) {
            let mut imin = 0;
            for i in 1..n {
                if tanbd[i] < tanbd[imin] {
                    imin = i;
                }
            }
            iact = Some(imin);
            hangt_bd = tanbd[imin];
        }
        if hangt_bd <= 0.0 {
            break;
        }
        hess_mul_into(s, xpt, pq, Some(&*hq), dxpt, hs);
        let shs = masked_inprod(s, hs, xbdi);
        let dhs = masked_inprod(d, hs, xbdi);
        let dhd = masked_inprod(d, hdred, xbdi);
        // Argument layout expected by `interval_fun_trsbox`.
        let args = [shs, dhd, dhs, dredg, sredg];
        if args.iter().any(|v| v.is_nan()) {
            break;
        }
        // Here 0 < hangt_bd <= 1, so grid_size <= 2 * 21 = GRID_SIZE_MAX.
        let grid_size = 2 * ((17.0 * hangt_bd + 4.1).round() as usize);
        debug_assert!(
            grid_size <= GRID_SIZE_MAX,
            "interval_max grid over capacity"
        );
        // Presumably the sampled maximiser of the model reduction over [0, hangt_bd]; see
        // `util::interval_max`.
        let hangt = interval_max(
            interval_fun_trsbox,
            0.0,
            hangt_bd,
            &args,
            grid_size,
            &mut xgrid[..grid_size],
            &mut fgrid[..grid_size],
        );
        let sdec = interval_fun_trsbox(hangt, &args);
        if !(sdec > 0.0) {
            break;
        }
        // cos and sin of the rotation angle 2*atan(hangt). For hangt in [0, 1] the `min` terms
        // equal the plain formulas in exact arithmetic and only guard against rounding.
        let cth = ((1.0 - hangt * hangt) / (1.0 + hangt * hangt)).min(1.0 - hangt * hangt);
        let sth = ((hangt + hangt) / (1.0 + hangt * hangt)).min(hangt + hangt);
        // Gradient change from d_red -> cth*d_red + sth*s.
        for i in 0..n {
            gnew[i] = gnew[i] + (cth - 1.0) * hdred[i] + sth * hs[i];
        }
        dold.copy_from_slice(d);
        for i in 0..n {
            if xbdi[i] == 0 {
                d[i] = cth * d[i] + sth * s[i];
            }
        }
        let abs_sum: f64 = d.iter().map(|&v| math::abs(v)).sum();
        if !abs_sum.is_finite() {
            d.copy_from_slice(dold);
            break;
        }
        for i in 0..n {
            hdred[i] = cth * hdred[i] + sth * hs[i];
        }
        qred += sdec;
        if let Some(ia) = iact {
            // The rotation stopped on coordinate `ia`'s bound: fix it there. Its side of the
            // box midpoint selects which bound it is.
            if hangt >= hangt_bd {
                xbdi[ia] = 1.0_f64.copysign(xopt[ia] + d[ia] - 0.5 * (sl[ia] + su[ia])) as i32;
            } else if !(sdec > tol * qred) {
                break;
            }
        } else if !(sdec > tol * qred) {
            break;
        }
    }
    // Clamp xopt + d into the box, snap fixed variables exactly onto their bounds, and rebuild d.
    for i in 0..n {
        xnew[i] = (xopt[i] + d[i]).min(su[i]).max(sl[i]);
    }
    for i in 0..n {
        if xbdi[i] == -1 {
            xnew[i] = sl[i];
        }
    }
    for i in 0..n {
        if xbdi[i] == 1 {
            xnew[i] = su[i];
        }
    }
    for i in 0..n {
        d[i] = xnew[i] - xopt[i];
    }
    // Never-set (sentinel) or NaN curvature is reported as 0.
    if crvmin <= -REALMAX || crvmin.is_nan() {
        crvmin = 0.0;
    }
    // Undo the model scaling on the curvature.
    if scaled && crvmin > 0.0 {
        crvmin /= modscal;
    }
    crvmin
}
/// Model reduction from the 2-D search rotation whose half-angle tangent is `hangt`.
///
/// `args` must be `[shs, dhd, dhs, dredg, sredg]`, laid out exactly as `trsbox` assembles them.
/// With `h = hangt` and `sth = 2h / (1 + h^2)` (the sine of the rotation angle), the value is
/// `sth * (h * dredg - sredg - sth * (shs + h * (h * dhd - 2 * dhs)) / 2)`.
///
/// Returns `0.0` without reading `args` unless `math::abs(hangt) > 0.0`, i.e. for a zero (or
/// NaN) `hangt`.
fn interval_fun_trsbox(hangt: f64, args: &[f64]) -> f64 {
    if math::abs(hangt) > 0.0 {
        let (shs, dhd, dhs, dredg, sredg) = (args[0], args[1], args[2], args[3], args[4]);
        let sth = (hangt + hangt) / (1.0 + hangt * hangt);
        let curvature = shs + hangt * (hangt * dhd - dhs - dhs);
        sth * (hangt * dredg - sredg - 0.5 * sth * curvature)
    } else {
        0.0
    }
}
/// Update the trust-region radius based on `ratio`.
///
/// `ratio` is presumably actual over predicted reduction; `dnorm` is presumably the trial step
/// length. Confirm both against the callers. With `shrunk = gamma1 * delta_in` the new radius is:
///
/// * `ratio <= eta1`: `min(shrunk, dnorm)`;
/// * `eta1 < ratio <= eta2`: `max(shrunk, dnorm)`;
/// * otherwise (including a NaN `ratio`, which fails both `<=` tests):
///   `max(shrunk, gamma2 * dnorm)`.
pub(crate) fn trrad(
    delta_in: f64,
    dnorm: f64,
    eta1: f64,
    eta2: f64,
    gamma1: f64,
    gamma2: f64,
    ratio: f64,
) -> f64 {
    let shrunk = gamma1 * delta_in;
    if ratio <= eta1 {
        return shrunk.min(dnorm);
    }
    if ratio <= eta2 {
        return shrunk.max(dnorm);
    }
    shrunk.max(gamma2 * dnorm)
}
/// Sum of `x[i]^2` over the free indices (`xbdi[i] == 0`).
///
/// Entries fixed at a bound (`xbdi[i] != 0`) are skipped. Accumulation runs left to right from
/// `+0.0`. Panics if `xbdi` is shorter than `x`.
fn masked_sumsq(x: &[f64], xbdi: &[i32]) -> f64 {
    x.iter()
        .enumerate()
        .filter(|&(i, _)| xbdi[i] == 0)
        .fold(0.0, |acc, (_, &xi)| acc + xi * xi)
}
/// Inner product of `x` and `y` restricted to the free indices (`xbdi[i] == 0`).
///
/// Accumulation runs left to right from `+0.0`. Panics if `y` or `xbdi` is shorter than `x`.
fn masked_inprod(x: &[f64], y: &[f64], xbdi: &[i32]) -> f64 {
    x.iter()
        .enumerate()
        .filter(|&(i, _)| xbdi[i] == 0)
        .fold(0.0, |acc, (i, &xi)| acc + xi * y[i])
}
#[cfg(test)]
mod tests {
    #![expect(clippy::similar_names)]
    // Regression tests: full replays of captured PRIMA states plus small hand-checked cases.
    use super::*;
    use crate::mat::Mat;
    use crate::test_support::{self, DiffStats};

    /// Replays every captured PRIMA `trrad` call and compares the returned radius.
    #[test]
    fn trrad_matches_prima_on_every_captured_state() {
        let states = test_support::load_states("trrad");
        assert!(!states.is_empty());
        let mut stats = DiffStats::default();
        for st in &states {
            let (e, x) = (&st.entry, &st.exit);
            let delta = trrad(
                e.f64("delta_in"),
                e.f64("dnorm"),
                e.f64("eta1"),
                e.f64("eta2"),
                e.f64("gamma1"),
                e.f64("gamma2"),
                e.f64("ratio"),
            );
            stats.f64("delta", delta, x.f64("delta"));
        }
        stats.report("trrad");
    }

    /// Replays every captured PRIMA `trsbox` call and compares `crvmin` and the step `d`.
    #[test]
    fn trsbox_matches_prima_on_every_captured_state() {
        let states = test_support::load_states("trsbox");
        assert!(!states.is_empty());
        let mut stats = DiffStats::default();
        for st in &states {
            let (e, x) = (&st.entry, &st.exit);
            let (gopt_in, pq_in) = (e.vec("gopt_in"), e.vec("pq_in"));
            let (hq_in, xpt) = (e.mat("hq_in"), e.mat("xpt"));
            let (sl, su, xopt) = (e.vec("sl"), e.vec("su"), e.vec("xopt"));
            let mut d = vec![0.0; gopt_in.len()];
            let mut ws = TrsboxWs::new(gopt_in.len(), pq_in.len());
            let crvmin = trsbox(
                e.f64("delta"),
                &gopt_in,
                &hq_in,
                &pq_in,
                &sl,
                &su,
                e.f64("tol"),
                &xopt,
                &xpt,
                &mut d,
                &mut ws,
            );
            stats.f64("crvmin", crvmin, x.f64("crvmin"));
            stats.slice("d", &d, &x.vec("d"));
        }
        stats.report("trsbox");
    }

    #[test]
    fn interval_fun_trsbox_is_zero_at_the_origin_and_matches_the_formula() {
        let args = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(interval_fun_trsbox(0.0, &args), 0.0);
        // hangt = 1: sth = 1, inner = 1 + (2 - 6) = -3, f = 1 * (4 - 5 + 1.5) = 0.5.
        let f = interval_fun_trsbox(1.0, &args);
        assert!(math::abs(f - 0.5) < 1e-15);
    }

    #[test]
    fn trrad_takes_the_shrink_keep_and_expand_branches() {
        // Shrink: min(gamma1 * delta_in, dnorm).
        assert_eq!(trrad(1.0, 0.4, 0.1, 0.7, 0.5, 2.0, 0.05), 0.4);
        // Keep: max(gamma1 * delta_in, dnorm).
        assert_eq!(trrad(1.0, 0.8, 0.1, 0.7, 0.5, 2.0, 0.5), 0.8);
        // Expand: max(gamma1 * delta_in, gamma2 * dnorm).
        assert_eq!(trrad(1.0, 0.8, 0.1, 0.7, 0.5, 2.0, 0.9), 1.6);
    }

    #[test]
    fn trrad_thresholds_are_inclusive_and_nan_ratio_expands() {
        // ratio == eta1 still shrinks; here gamma1 * delta_in is the smaller term.
        assert_eq!(trrad(1.0, 0.8, 0.1, 0.7, 0.5, 2.0, 0.1), 0.5);
        // ratio == eta2 still keeps.
        assert_eq!(trrad(1.0, 0.8, 0.1, 0.7, 0.5, 2.0, 0.7), 0.8);
        // A NaN ratio fails both `<=` tests and falls through to the expand branch.
        assert_eq!(trrad(1.0, 0.8, 0.1, 0.7, 0.5, 2.0, f64::NAN), 1.6);
    }

    #[test]
    fn trsbox_takes_the_unconstrained_newton_step_on_a_separable_quadratic() {
        let n = 2;
        let npt = 4;
        let xpt = Mat::zeros(n, npt);
        let pq = vec![0.0; npt];
        let mut hq = Mat::zeros(n, n);
        hq[[0, 0]] = 1.0;
        hq[[1, 1]] = 1.0;
        let (sl, su) = (vec![-10.0; n], vec![10.0; n]);
        let mut d = vec![0.0; n];
        let mut ws = TrsboxWs::new(n, npt);
        let crvmin = trsbox(
            5.0,
            &[1.0, 0.5],
            &hq,
            &pq,
            &sl,
            &su,
            1e-2,
            &[0.0; 2],
            &xpt,
            &mut d,
            &mut ws,
        );
        assert!(math::abs(d[0] + 1.0) < 1e-12 && math::abs(d[1] + 0.5) < 1e-12);
        assert!(math::abs(crvmin - 1.0) < 1e-12);
    }

    #[test]
    fn trsbox_fixes_a_variable_pinned_at_its_bound() {
        let n = 2;
        let npt = 4;
        let xpt = Mat::zeros(n, npt);
        let pq = vec![0.0; npt];
        let mut hq = Mat::zeros(n, n);
        hq[[0, 0]] = 1.0;
        hq[[1, 1]] = 1.0;
        // x0 starts on its upper bound (su[0] = 0) and the gradient pushes it outward.
        let (sl, su) = (vec![-10.0; n], vec![0.0, 10.0]);
        let mut d = vec![0.0; n];
        let mut ws = TrsboxWs::new(n, npt);
        let _ = trsbox(
            5.0,
            &[-1.0, 1.0],
            &hq,
            &pq,
            &sl,
            &su,
            1e-2,
            &[0.0; 2],
            &xpt,
            &mut d,
            &mut ws,
        );
        assert_eq!(d[0], 0.0);
        assert!(math::abs(d[1] + 1.0) < 1e-12, "d[1] = {}", d[1]);
    }

    #[test]
    fn trsbox_stops_on_the_trust_region_boundary_for_a_linear_model() {
        // With a zero Hessian, the first CG step runs straight to the boundary (|d| = delta).
        // That switches to the 2-D search, which exits at once because d is parallel to the
        // gradient. crvmin is therefore 0.
        let n = 2;
        let npt = 4;
        let xpt = Mat::zeros(n, npt);
        let pq = vec![0.0; npt];
        let hq = Mat::zeros(n, n);
        let (sl, su) = (vec![-10.0; n], vec![10.0; n]);
        let mut d = vec![0.0; n];
        let mut ws = TrsboxWs::new(n, npt);
        let crvmin = trsbox(
            0.5,
            &[1.0, 0.0],
            &hq,
            &pq,
            &sl,
            &su,
            1e-2,
            &[0.0; 2],
            &xpt,
            &mut d,
            &mut ws,
        );
        assert!(math::abs(d[0] + 0.5) < 1e-12, "d[0] = {}", d[0]);
        assert_eq!(d[1], 0.0);
        assert_eq!(crvmin, 0.0);
    }
}