use crate::conv_check::{RestoFilterConvCheck, RestoPenaltyConvCheck};
use crate::init::RestoIterateInitializer;
use crate::min_c_1nrm::MinC1NormRestoration;
use crate::output::{InfPrTag, PrintInfoString, RestoIterationOutput};
use crate::r#trait::{RestorationOutcome, RestorationPhase};
use crate::resto_nlp::RestoIpoptNlp;
use pounce_algorithm::ipopt_cq::IpoptCqHandle;
use pounce_algorithm::ipopt_data::IpoptDataHandle;
use pounce_algorithm::ipopt_nlp::IpoptNlp;
use pounce_algorithm::kkt::aug_system_solver::AugSystemSolver;
use std::cell::RefCell;
use std::rc::Rc;
/// Line-search strategy used by the outer algorithm.
///
/// [`RestoAlgorithmBuilder::build`] uses this to decide which restoration convergence
/// check goes into the bundle: `Filter` yields a `RestoFilterConvCheck`, `Penalty` a
/// `RestoPenaltyConvCheck`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OuterLineSearch {
    /// Filter line search (the builder's default).
    Filter,
    /// Penalty line search.
    Penalty,
}
/// The restoration convergence check chosen by [`RestoAlgorithmBuilder::build`].
///
/// The variant follows the builder's `outer_line_search` option.
pub enum RestoConvCheckSlot {
    /// Selected when `outer_line_search` is `OuterLineSearch::Filter`; carries the
    /// builder's `obj_max_inc`.
    Filter(RestoFilterConvCheck),
    /// Selected when `outer_line_search` is `OuterLineSearch::Penalty`; built with
    /// `RestoPenaltyConvCheck::new()`.
    Penalty(RestoPenaltyConvCheck),
}
/// All restoration-phase components produced by one [`RestoAlgorithmBuilder::build`] call.
///
/// NOTE(review): the `RestorationPhase` impl for this type forwards only to `driver`;
/// `nlp`, `init`, `conv_check` and `iter_output` are not passed to it there. Confirm how
/// these components reach the driver.
pub struct RestoAlgorithmBundle {
    /// Restoration NLP, built from the problem dimensions, reference point, `rho` and `eta_factor`.
    pub nlp: RestoIpoptNlp,
    /// Iterate initializer, built from the problem dimensions, reference point and `rho`.
    pub init: RestoIterateInitializer,
    /// Convergence check matching the builder's `outer_line_search`.
    pub conv_check: RestoConvCheckSlot,
    /// Iteration output configured with the builder's print options.
    pub iter_output: RestoIterationOutput,
    /// Restoration driver; `perform_restoration` delegates to it.
    pub driver: MinC1NormRestoration,
}
/// Option set from which [`RestoAlgorithmBuilder::build`] assembles a [`RestoAlgorithmBundle`].
///
/// Each field is copied into exactly the component(s) listed in its doc comment.
#[derive(Debug, Clone)]
pub struct RestoAlgorithmBuilder {
    /// Passed to `RestoIpoptNlp::new` and `RestoIterateInitializer::with_rho`.
    pub rho: f64,
    /// Passed to `RestoIpoptNlp::new`.
    pub eta_factor: f64,
    /// Copied onto the restoration NLP.
    pub evaluate_orig_obj_at_resto_trial: bool,
    /// Copied onto the `MinC1NormRestoration` driver.
    pub bound_mult_reset_threshold: f64,
    /// Copied onto the `MinC1NormRestoration` driver.
    pub constr_mult_reset_threshold: f64,
    /// Copied onto the `MinC1NormRestoration` driver.
    pub expect_infeasible_problem: bool,
    /// Copied onto the `MinC1NormRestoration` driver.
    pub start_with_resto: bool,
    /// Selects which `RestoConvCheckSlot` variant is built.
    pub outer_line_search: OuterLineSearch,
    /// Copied onto the `RestoIterationOutput`.
    pub inf_pr_output: InfPrTag,
    /// Copied onto the `RestoIterationOutput`.
    pub print_info_string: PrintInfoString,
    /// Applied only to the filter convergence check; ignored when
    /// `outer_line_search` is `Penalty`.
    pub obj_max_inc: f64,
}
impl Default for RestoAlgorithmBuilder {
    /// Produces the option values used when the caller overrides nothing.
    fn default() -> Self {
        Self {
            // Options routed into the restoration NLP.
            rho: 1000.0,
            eta_factor: 1.0,
            evaluate_orig_obj_at_resto_trial: true,
            // Options routed into the MinC1NormRestoration driver.
            bound_mult_reset_threshold: 1000.0,
            constr_mult_reset_threshold: 0.0,
            expect_infeasible_problem: false,
            start_with_resto: false,
            // Convergence-check selection plus the filter-only setting.
            outer_line_search: OuterLineSearch::Filter,
            obj_max_inc: 5.0,
            // Options routed into the iteration output.
            inf_pr_output: InfPrTag::Original,
            print_info_string: PrintInfoString::No,
        }
    }
}
impl RestoAlgorithmBuilder {
    /// Creates a builder holding the values from [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Assembles every restoration-phase component from this builder's options.
    ///
    /// * `n_orig`, `m_eq`, `m_ineq` — problem dimensions, forwarded to both the restoration
    ///   NLP and the iterate initializer.
    /// * `x_ref_vals` — reference point, borrowed by `RestoIpoptNlp::new` and copied into the
    ///   initializer. Its length is not checked against `n_orig` here
    ///   (NOTE(review): presumably expected to equal `n_orig`; confirm with callers).
    ///
    /// The builder is borrowed, not consumed, so the same options can produce several bundles.
    #[must_use]
    pub fn build(
        &self,
        n_orig: pounce_common::types::Index,
        m_eq: pounce_common::types::Index,
        m_ineq: pounce_common::types::Index,
        x_ref_vals: &[f64],
    ) -> RestoAlgorithmBundle {
        let mut nlp =
            RestoIpoptNlp::new(n_orig, m_eq, m_ineq, x_ref_vals, self.rho, self.eta_factor);
        nlp.evaluate_orig_obj_at_resto_trial = self.evaluate_orig_obj_at_resto_trial;

        let init = RestoIterateInitializer::with_dims(n_orig, m_eq, m_ineq, x_ref_vals.to_vec())
            .with_rho(self.rho);

        // `obj_max_inc` is applied only to the filter check; the penalty check is used
        // exactly as `RestoPenaltyConvCheck::new()` returns it.
        let conv_check = match self.outer_line_search {
            OuterLineSearch::Filter => {
                let mut cc = RestoFilterConvCheck::new();
                cc.obj_max_inc = self.obj_max_inc;
                RestoConvCheckSlot::Filter(cc)
            }
            OuterLineSearch::Penalty => RestoConvCheckSlot::Penalty(RestoPenaltyConvCheck::new()),
        };

        let iter_output = RestoIterationOutput {
            print_info_string: self.print_info_string,
            inf_pr_output: self.inf_pr_output,
            ..RestoIterationOutput::default()
        };

        // Fields not listed here keep `MinC1NormRestoration::default()` values.
        let driver = MinC1NormRestoration {
            bound_mult_reset_threshold: self.bound_mult_reset_threshold,
            constr_mult_reset_threshold: self.constr_mult_reset_threshold,
            expect_infeasible_problem: self.expect_infeasible_problem,
            start_with_resto: self.start_with_resto,
            ..MinC1NormRestoration::default()
        };

        RestoAlgorithmBundle {
            nlp,
            init,
            conv_check,
            iter_output,
            driver,
        }
    }
}
impl RestorationPhase for RestoAlgorithmBundle {
    /// Runs restoration by forwarding all arguments to the bundled
    /// [`MinC1NormRestoration`] driver and returning its outcome unchanged.
    ///
    /// NOTE(review): `nlp`, `init`, `conv_check` and `iter_output` stored in this bundle are
    /// not consulted here; confirm the driver obtains equivalent components elsewhere.
    fn perform_restoration(
        &mut self,
        data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
        aug_solver: &mut dyn AugSystemSolver,
    ) -> RestorationOutcome {
        let Self { driver, .. } = self;
        driver.perform_restoration(data, cq, nlp, aug_solver)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_options_match_upstream() {
        let b = RestoAlgorithmBuilder::new();
        assert_eq!(b.rho, 1e3);
        assert_eq!(b.eta_factor, 1.0);
        assert!(b.evaluate_orig_obj_at_resto_trial);
        assert_eq!(b.bound_mult_reset_threshold, 1e3);
        assert_eq!(b.constr_mult_reset_threshold, 0.0);
        assert!(!b.expect_infeasible_problem);
        assert!(!b.start_with_resto);
        assert_eq!(b.outer_line_search, OuterLineSearch::Filter);
        assert_eq!(b.inf_pr_output, InfPrTag::Original);
        assert_eq!(b.print_info_string, PrintInfoString::No);
        assert_eq!(b.obj_max_inc, 5.0);
    }

    #[test]
    fn build_propagates_nlp_dims_and_rho_eta() {
        let b = RestoAlgorithmBuilder {
            rho: 2.5,
            eta_factor: 7.0,
            ..RestoAlgorithmBuilder::default()
        };
        let bundle = b.build(3, 2, 1, &[0.1, 0.2, 0.3]);
        assert_eq!(bundle.nlp.n_orig, 3);
        assert_eq!(bundle.nlp.m_eq, 2);
        assert_eq!(bundle.nlp.m_ineq, 1);
        assert_eq!(bundle.nlp.rho, 2.5);
        assert_eq!(bundle.nlp.eta_factor, 7.0);
    }

    #[test]
    fn build_propagates_evaluate_orig_obj_flag() {
        let b = RestoAlgorithmBuilder {
            evaluate_orig_obj_at_resto_trial: false,
            ..RestoAlgorithmBuilder::default()
        };
        let bundle = b.build(1, 0, 0, &[0.0]);
        assert!(!bundle.nlp.evaluate_orig_obj_at_resto_trial);
    }

    #[test]
    fn outer_filter_selects_filter_conv_check() {
        let b = RestoAlgorithmBuilder {
            outer_line_search: OuterLineSearch::Filter,
            obj_max_inc: 12.0,
            ..RestoAlgorithmBuilder::default()
        };
        let bundle = b.build(1, 0, 0, &[0.0]);
        match bundle.conv_check {
            RestoConvCheckSlot::Filter(cc) => {
                assert_eq!(cc.obj_max_inc, 12.0);
            }
            RestoConvCheckSlot::Penalty(_) => panic!("expected filter conv check"),
        }
    }

    #[test]
    fn outer_penalty_selects_penalty_conv_check() {
        let b = RestoAlgorithmBuilder {
            outer_line_search: OuterLineSearch::Penalty,
            ..RestoAlgorithmBuilder::default()
        };
        let bundle = b.build(1, 0, 0, &[0.0]);
        assert!(matches!(bundle.conv_check, RestoConvCheckSlot::Penalty(_)));
    }

    #[test]
    fn iter_output_carries_inf_pr_and_print_info_options() {
        let b = RestoAlgorithmBuilder {
            inf_pr_output: InfPrTag::Internal,
            print_info_string: PrintInfoString::Yes,
            ..RestoAlgorithmBuilder::default()
        };
        let bundle = b.build(1, 0, 0, &[0.0]);
        assert_eq!(bundle.iter_output.inf_pr_output, InfPrTag::Internal);
        assert_eq!(bundle.iter_output.print_info_string, PrintInfoString::Yes);
    }

    #[test]
    fn driver_picks_up_reset_thresholds_and_flags() {
        let b = RestoAlgorithmBuilder {
            bound_mult_reset_threshold: 2.5e2,
            constr_mult_reset_threshold: 7.0,
            expect_infeasible_problem: true,
            start_with_resto: true,
            ..RestoAlgorithmBuilder::default()
        };
        let bundle = b.build(1, 0, 0, &[0.0]);
        assert_eq!(bundle.driver.bound_mult_reset_threshold, 2.5e2);
        assert_eq!(bundle.driver.constr_mult_reset_threshold, 7.0);
        assert!(bundle.driver.expect_infeasible_problem);
        assert!(bundle.driver.start_with_resto);
    }

    /// Every driver option set by the builder must carry its default value into the bundle,
    /// not only the bound-multiplier threshold.
    #[test]
    fn bundle_driver_default_state_is_propagated_from_builder() {
        let bundle = RestoAlgorithmBuilder::new().build(1, 0, 0, &[0.0]);
        assert_eq!(bundle.driver.bound_mult_reset_threshold, 1e3);
        assert_eq!(bundle.driver.constr_mult_reset_threshold, 0.0);
        assert!(!bundle.driver.expect_infeasible_problem);
        assert!(!bundle.driver.start_with_resto);
    }

    /// Defaults must also reach the NLP, the convergence check and the iteration output.
    #[test]
    fn default_builder_propagates_defaults_to_every_component() {
        let bundle = RestoAlgorithmBuilder::new().build(1, 0, 0, &[0.0]);
        assert_eq!(bundle.nlp.rho, 1e3);
        assert_eq!(bundle.nlp.eta_factor, 1.0);
        assert!(bundle.nlp.evaluate_orig_obj_at_resto_trial);
        assert_eq!(bundle.iter_output.inf_pr_output, InfPrTag::Original);
        assert_eq!(bundle.iter_output.print_info_string, PrintInfoString::No);
        match bundle.conv_check {
            RestoConvCheckSlot::Filter(cc) => assert_eq!(cc.obj_max_inc, 5.0),
            RestoConvCheckSlot::Penalty(_) => panic!("expected filter conv check by default"),
        }
    }

    /// `build` borrows the builder, so one option set can produce independent bundles.
    #[test]
    fn builder_is_reusable_across_builds() {
        let b = RestoAlgorithmBuilder::new();
        let first = b.build(1, 0, 0, &[0.0]);
        let second = b.build(2, 1, 0, &[0.0, 1.0]);
        assert_eq!(first.nlp.n_orig, 1);
        assert_eq!(first.nlp.m_eq, 0);
        assert_eq!(second.nlp.n_orig, 2);
        assert_eq!(second.nlp.m_eq, 1);
    }
}