Skip to main content

gam_solve/
objective_base.rs

1use ndarray::Array2;
2use std::sync::Arc;
3
4use crate::rho_optimizer::RhoBlockAdditiveOuterHessian;
5use gam_problem::OuterStrategyError;
6
7pub use gam_problem::HessianResult;
8
9pub fn add_rho_block_dense_to_hessian(
10    hessian: &mut HessianResult,
11    rho_block: &Array2<f64>,
12) -> Result<(), String> {
13    if rho_block.nrows() != rho_block.ncols() {
14        return Err(OuterStrategyError::RhoBlockShape {
15            reason: format!(
16                "rho-block Hessian update must be square, got {}x{}",
17                rho_block.nrows(),
18                rho_block.ncols()
19            ),
20        }
21        .into());
22    }
23    match hessian {
24        HessianResult::Analytic(h) => {
25            if rho_block.nrows() > h.nrows() || rho_block.ncols() > h.ncols() {
26                return Err(OuterStrategyError::RhoBlockShape {
27                    reason: format!(
28                        "rho-block Hessian update shape mismatch: got {}x{}, outer Hessian is {}x{}",
29                        rho_block.nrows(),
30                        rho_block.ncols(),
31                        h.nrows(),
32                        h.ncols()
33                    ),
34                }
35                .into());
36            }
37            let k = rho_block.nrows();
38            let mut sl = h.slice_mut(ndarray::s![..k, ..k]);
39            sl += rho_block;
40            Ok(())
41        }
42        HessianResult::Operator(op) => {
43            let base = Arc::clone(op);
44            let dim = base.dim();
45            if rho_block.nrows() > dim {
46                return Err(OuterStrategyError::RhoBlockShape {
47                    reason: format!(
48                        "rho-block Hessian update dimension mismatch: got {}x{}, operator dim is {}",
49                        rho_block.nrows(),
50                        rho_block.ncols(),
51                        dim
52                    ),
53                }
54                .into());
55            }
56            *hessian = HessianResult::Operator(Arc::new(RhoBlockAdditiveOuterHessian {
57                base,
58                rho_block: rho_block.clone(),
59                dim,
60            }));
61            Ok(())
62        }
63        HessianResult::Unavailable => Ok(()),
64    }
65}
66
/// Adds a quadratic penalty to `cost` when the inner solve stopped at its iteration cap.
///
/// - If `cost` is not finite, or `inner_failed_max_iterations` is `false`,
///   `cost` is returned unchanged.
/// - Otherwise, if `relative_gradient_norm` is finite, the result is
///   `cost + 0.5 * relative_gradient_norm^2`.
/// - If the norm is NaN or infinite, the result is `f64::INFINITY`.
#[inline]
pub(crate) fn failed_inner_residual_barrier_cost(
    cost: f64,
    inner_failed_max_iterations: bool,
    relative_gradient_norm: f64,
) -> f64 {
    let penalty_applies = inner_failed_max_iterations && cost.is_finite();
    if !penalty_applies {
        return cost;
    }
    match relative_gradient_norm {
        // Evaluated as (0.5 * r) * r to match the original rounding exactly.
        r if r.is_finite() => cost + 0.5 * r * r,
        // A NaN/infinite residual cannot be priced, so the point is made infeasible.
        _ => f64::INFINITY,
    }
}