use pounce_common::types::{Index, Number};
use crate::convenience::SensResult;
/// Default threshold used by `DiffHandoff::from_sens_result` when deciding
/// whether a multiplier is large enough to mark its variable or constraint
/// as active.
pub const DEFAULT_ACTIVE_TOL: Number = 1.0e-6;

/// Solution snapshot together with the active-set masks derived from its
/// multipliers.
///
/// The masks are normally produced by `DiffHandoff::masks`; `pinned_vars`
/// can additionally be extended with `DiffHandoff::pin`.
#[derive(Debug, Clone)]
pub struct DiffHandoff {
    /// Primal solution, one entry per variable.
    pub x: Vec<Number>,
    /// Objective value reported with `x`.
    pub obj_val: Number,
    /// Constraint multipliers, one per constraint row.
    pub lambda: Vec<Number>,
    /// Lower-bound multipliers (z_L), one per variable.
    pub mult_x_lower: Vec<Number>,
    /// Upper-bound multipliers (z_U), one per variable.
    pub mult_x_upper: Vec<Number>,
    /// `true` for every constraint row treated as active.
    pub active_constraints: Vec<bool>,
    /// `true` for every variable held fixed (active bound or explicitly pinned).
    pub pinned_vars: Vec<bool>,
    /// Tolerance that was used when the masks were derived.
    pub active_tol: Number,
}
impl DiffHandoff {
pub fn from_solution(
x: Vec<Number>,
obj_val: Number,
lambda: Vec<Number>,
mult_x_lower: Vec<Number>,
mult_x_upper: Vec<Number>,
equality_mask: &[bool],
active_tol: Number,
) -> Self {
debug_assert_eq!(mult_x_lower.len(), x.len(), "z_L length must match x");
debug_assert_eq!(mult_x_upper.len(), x.len(), "z_U length must match x");
let (pinned_vars, active_constraints) = Self::masks(
&mult_x_lower,
&mult_x_upper,
&lambda,
equality_mask,
active_tol,
);
Self {
x,
obj_val,
lambda,
mult_x_lower,
mult_x_upper,
active_constraints,
pinned_vars,
active_tol,
}
}
pub fn masks(
mult_x_lower: &[Number],
mult_x_upper: &[Number],
lambda: &[Number],
equality_mask: &[bool],
active_tol: Number,
) -> (Vec<bool>, Vec<bool>) {
debug_assert_eq!(
mult_x_lower.len(),
mult_x_upper.len(),
"z_L and z_U lengths must match"
);
debug_assert!(
equality_mask.is_empty() || equality_mask.len() == lambda.len(),
"equality_mask must be empty or length n_g"
);
let pinned_vars = mult_x_lower
.iter()
.zip(mult_x_upper.iter())
.map(|(&l, &u)| l > active_tol || u > active_tol)
.collect();
let active_constraints = lambda
.iter()
.enumerate()
.map(|(i, &lam)| {
equality_mask.get(i).copied().unwrap_or(false) || lam.abs() > active_tol
})
.collect();
(pinned_vars, active_constraints)
}
pub fn from_sens_result(res: &SensResult, equality_mask: &[bool]) -> Option<Self> {
let x = res.x.clone()?;
let obj_val = res.obj_val?;
let lambda = res.mult_g.clone()?;
let mult_x_lower = res.mult_x_l.clone()?;
let mult_x_upper = res.mult_x_u.clone()?;
Some(Self::from_solution(
x,
obj_val,
lambda,
mult_x_lower,
mult_x_upper,
equality_mask,
DEFAULT_ACTIVE_TOL,
))
}
pub fn pin(&mut self, indices: &[Index]) {
for &i in indices {
if i < 0 {
continue;
}
if let Some(slot) = self.pinned_vars.get_mut(i as usize) {
*slot = true;
}
}
}
pub fn n_x(&self) -> usize {
self.x.len()
}
pub fn n_g(&self) -> usize {
self.lambda.len()
}
pub fn n_pinned(&self) -> usize {
self.pinned_vars.iter().filter(|&&b| b).count()
}
pub fn n_active_constraints(&self) -> usize {
self.active_constraints.iter().filter(|&&b| b).count()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use pounce_nlp::return_codes::ApplicationReturnStatus;

    /// Builds a successful single-variable, single-constraint result.
    ///
    /// With `with_duals == false` the multipliers and constraint values are
    /// left unset. That is the shape `from_sens_result` must reject.
    fn solved_result(with_duals: bool) -> SensResult {
        SensResult {
            status: ApplicationReturnStatus::SolveSucceeded,
            error: None,
            x: Some(vec![1.0]),
            obj_val: Some(0.0),
            dx: None,
            dx_full: None,
            reduced_hessian: None,
            reduced_hessian_scaled: None,
            obj_scaling_factor: None,
            pin_g_scaling: None,
            kkt_perturbations: None,
            reduced_hessian_eigenvalues: None,
            reduced_hessian_eigenvectors: None,
            mult_g: with_duals.then(|| vec![0.0]),
            mult_x_l: with_duals.then(|| vec![0.0]),
            mult_x_u: with_duals.then(|| vec![0.0]),
            g: with_duals.then(|| vec![0.0]),
        }
    }

    #[test]
    fn from_sens_result_degenerate_equality_needs_the_mask() {
        // The single constraint has a zero multiplier, so only the equality
        // mask can keep it in the active set.
        let res = solved_result(true);
        let dropped = DiffHandoff::from_sens_result(&res, &[]).unwrap();
        assert_eq!(dropped.active_constraints, vec![false]);
        let kept = DiffHandoff::from_sens_result(&res, &[true]).unwrap();
        assert_eq!(kept.active_constraints, vec![true]);
    }

    #[test]
    fn from_sens_result_returns_none_without_duals() {
        let res = solved_result(false);
        assert!(DiffHandoff::from_sens_result(&res, &[]).is_none());
    }

    #[test]
    fn pins_active_bounds_and_marks_active_constraints() {
        let x = vec![0.0, 1.0, 2.0];
        let z_l = vec![5.0, 0.0, 0.0];
        let z_u = vec![0.0, 0.0, 3.0];
        let lambda = vec![0.0, 1e-9, 4.0];
        let eq = vec![true, false, false];
        let h = DiffHandoff::from_solution(x, 42.0, lambda, z_l, z_u, &eq, DEFAULT_ACTIVE_TOL);
        assert_eq!(h.pinned_vars, vec![true, false, true]);
        assert_eq!(h.active_constraints, vec![true, false, true]);
        assert_eq!(h.n_pinned(), 2);
        assert_eq!(h.n_active_constraints(), 2);
        assert_eq!(h.obj_val, 42.0);
    }

    #[test]
    fn empty_equality_mask_treats_only_nonzero_rows_as_active() {
        let h = DiffHandoff::from_solution(
            vec![0.0],
            0.0,
            vec![0.0, 5.0],
            vec![0.0],
            vec![0.0],
            &[],
            DEFAULT_ACTIVE_TOL,
        );
        assert_eq!(h.active_constraints, vec![false, true]);
    }

    #[test]
    fn pin_adds_integer_variables() {
        let mut h = DiffHandoff::from_solution(
            vec![0.0, 0.0, 0.0],
            0.0,
            vec![],
            vec![0.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0],
            &[],
            DEFAULT_ACTIVE_TOL,
        );
        assert_eq!(h.n_pinned(), 0);
        h.pin(&[1, 99]);
        assert_eq!(h.pinned_vars, vec![false, true, false]);
        assert_eq!(h.n_pinned(), 1);
    }

    #[test]
    fn pin_ignores_negative_indices() {
        let mut h = DiffHandoff::from_solution(
            vec![0.0, 0.0],
            0.0,
            vec![],
            vec![0.0, 0.0],
            vec![0.0, 0.0],
            &[],
            DEFAULT_ACTIVE_TOL,
        );
        h.pin(&[-1, 2]);
        assert_eq!(h.pinned_vars, vec![false, false]);
        assert_eq!(h.n_pinned(), 0);
    }
}