use crate::ipopt_cq::IpoptCqHandle;
use crate::ipopt_data::IpoptDataHandle;
use crate::ipopt_nlp::IpoptNlp;
use crate::iterates_vector::IteratesVector;
use crate::kkt::pd_search_dir_calc::PdSearchDirCalc;
use crate::mu::oracle::r#trait::MuOracle;
use pounce_common::types::Number;
use std::cell::RefCell;
use std::rc::Rc;
/// Barrier-parameter oracle based on the probing (cubic centering) heuristic.
///
/// The proposed barrier parameter is `min((mu_aff / mu_curr)^3, sigma_max) * mu_curr`,
/// bounded to `[mu_min, mu_max]`.
#[derive(Debug, Clone, PartialEq)]
pub struct ProbingMuOracle {
    /// Upper cap on the centering factor `sigma = (mu_aff / mu_curr)^3`.
    pub sigma_max: Number,
    /// Lower bound applied to the returned barrier parameter.
    pub mu_min: Number,
    /// Upper bound applied to the returned barrier parameter.
    pub mu_max: Number,
    /// Average complementarity at the current iterate, cached by the last affine probe.
    pub mu_curr: Number,
    /// Average complementarity after the affine step, cached by the last affine probe.
    pub mu_aff: Number,
}
impl Default for ProbingMuOracle {
    /// Builds an oracle with `sigma_max = 100`, `mu_min = 1e-11`, `mu_max = 1e5`.
    ///
    /// It also sets a neutral `mu_curr = mu_aff = 1`: the ratio is 1, so `sigma = 1`
    /// until the first affine probe overwrites them.
    ///
    /// NOTE(review): the three tuning values appear to mirror Ipopt's
    /// `sigma_max` / `mu_min` / `mu_max` option defaults — confirm.
    fn default() -> Self {
        let neutral_compl = 1.0;
        Self {
            mu_min: 1e-11,
            mu_max: 1e5,
            sigma_max: 100.0,
            mu_curr: neutral_compl,
            mu_aff: neutral_compl,
        }
    }
}
impl ProbingMuOracle {
    /// Creates an oracle with the parameters from [`Default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Applies the probing formula and returns the unbounded barrier parameter.
    ///
    /// The formula is `sigma = min((mu_aff / mu_curr)^3, sigma_max)`, and the
    /// result is `sigma * mu_curr`.
    ///
    /// The result is *not* bounded by `mu_min`/`mu_max`; callers apply those bounds.
    ///
    /// If `mu_curr == 0` and `mu_aff >= 0`, the ratio is `inf` or `NaN`. `min` then
    /// yields `sigma_max`, so for a finite `sigma_max` the result is `0`.
    #[must_use]
    pub fn probing_mu(mu_curr: Number, mu_aff: Number, sigma_max: Number) -> Number {
        let sigma = (mu_aff / mu_curr).powi(3).min(sigma_max);
        sigma * mu_curr
    }

    /// Computes the affine step and returns the bounded probing barrier parameter.
    ///
    /// Steps:
    /// 1. Asks `pd_search_dir` to compute the affine step. Returns `None` if it
    ///    reports failure.
    /// 2. Reads the step from `data.delta_aff`. Returns `None` if it is absent.
    /// 3. Queries `cq` for two values:
    ///    - `mu_curr`, via `curr_avrg_compl()`;
    ///    - `mu_aff`, the average complementarity after the maximal primal/dual
    ///      affine step lengths for `tau`.
    /// 4. Caches both values in `self.mu_curr` / `self.mu_aff`, so a later
    ///    [`MuOracle::calculate_mu`] reuses them.
    /// 5. Returns [`Self::probing_mu`] bounded to `[mu_min, mu_max]`.
    ///
    /// `tau` is forwarded unchanged to `aff_step_alpha_primal_max` /
    /// `aff_step_alpha_dual_max`. It is presumably the fraction-to-the-boundary
    /// factor — confirm in the `cq` implementation.
    ///
    /// When `None` is returned, `self` is left unmodified.
    pub fn calculate_mu_with_affine_step(
        &mut self,
        data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
        pd_search_dir: &mut PdSearchDirCalc,
        tau: Number,
    ) -> Option<Number> {
        if !pd_search_dir.compute_affine_step(data, cq, nlp) {
            return None;
        }
        // Clone so the `data` borrow ends with this statement instead of spanning the
        // `cq` queries below (which may need to borrow `data` themselves — not verified).
        let delta_aff: IteratesVector = data.borrow().delta_aff.clone()?;
        let cq_ref = cq.borrow();
        let mu_curr = cq_ref.curr_avrg_compl();
        let alpha_pri = cq_ref.aff_step_alpha_primal_max(&delta_aff, tau);
        let alpha_du = cq_ref.aff_step_alpha_dual_max(&delta_aff, tau);
        let mu_aff = cq_ref.aff_step_compl_avrg(&delta_aff, alpha_pri, alpha_du);
        self.mu_curr = mu_curr;
        self.mu_aff = mu_aff;
        let raw = Self::probing_mu(mu_curr, mu_aff, self.sigma_max);
        // `min` then `max` instead of `clamp`, for three reasons:
        // - inverted or NaN bounds cannot panic;
        // - if `mu_min > mu_max`, the lower bound wins;
        // - a NaN `raw` is replaced by a bound instead of being propagated.
        // For valid bounds the result equals `raw.clamp(mu_min, mu_max)`.
        Some(raw.min(self.mu_max).max(self.mu_min))
    }
}
impl MuOracle for ProbingMuOracle {
    /// Recomputes the probing barrier parameter from the cached `mu_curr` / `mu_aff`.
    ///
    /// The cached values come from the most recent
    /// [`ProbingMuOracle::calculate_mu_with_affine_step`] call, or from the defaults
    /// or caller-assigned fields. The method applies [`ProbingMuOracle::probing_mu`]
    /// and bounds the result to `[mu_min, mu_max]`. It always returns `Some`.
    fn calculate_mu(&mut self) -> Option<Number> {
        let raw = Self::probing_mu(self.mu_curr, self.mu_aff, self.sigma_max);
        // `min` then `max` instead of `clamp`, for three reasons:
        // - inverted or NaN bounds cannot panic;
        // - if `mu_min > mu_max`, the lower bound wins;
        // - a NaN `raw` is replaced by a bound instead of being propagated.
        // For valid bounds the result equals `raw.clamp(mu_min, mu_max)`.
        Some(raw.min(self.mu_max).max(self.mu_min))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison for the probing-formula checks.
    fn close(actual: Number, expected: Number, tol: Number) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn probing_aff_equals_curr_keeps_mu() {
        // sigma = (1 / 1)^3 = 1, so mu is unchanged.
        let mu = ProbingMuOracle::probing_mu(1.0, 1.0, 100.0);
        assert_eq!(mu, 1.0);
    }

    #[test]
    fn probing_aff_half_curr() {
        // sigma = 0.5^3 = 0.125.
        let mu = ProbingMuOracle::probing_mu(1.0, 0.5, 100.0);
        assert!(close(mu, 0.125, 1e-15));
    }

    #[test]
    fn probing_caps_at_sigma_max() {
        // (10 / 1)^3 = 1000 exceeds sigma_max, so sigma is capped at 100.
        let mu = ProbingMuOracle::probing_mu(1.0, 10.0, 100.0);
        assert!(close(mu, 100.0, 1e-13));
    }

    #[test]
    fn calculate_mu_via_trait_clamped() {
        // Raw mu = 0.001^3 * 1 = 1e-9, which is below mu_min and is lifted to 0.5.
        let mut oracle = ProbingMuOracle {
            mu_min: 0.5,
            mu_max: 10.0,
            mu_aff: 0.001,
            ..ProbingMuOracle::default()
        };
        assert_eq!(oracle.calculate_mu(), Some(0.5));
    }
}