1use crate::ipopt_cq::IpoptCqHandle;
17use crate::ipopt_data::IpoptDataHandle;
18use crate::ipopt_nlp::IpoptNlp;
19use crate::iterates_vector::{IteratesVector, IteratesVectorMut};
20use crate::kkt::pd_full_space_solver::PdFullSpaceSolver;
21use crate::kkt::search_dir_calc::SearchDirCalculator;
22use pounce_common::types::Number;
23use std::cell::{RefCell, RefMut};
24use std::rc::Rc;
25
26pub struct PdSearchDirCalc {
27 pd_solver: Rc<RefCell<PdFullSpaceSolver>>,
33 pub fast_step_computation: bool,
36 pub mehrotra_algorithm: bool,
40}
41
42impl PdSearchDirCalc {
43 pub fn new(pd_solver: PdFullSpaceSolver) -> Self {
44 Self {
45 pd_solver: Rc::new(RefCell::new(pd_solver)),
46 fast_step_computation: false,
47 mehrotra_algorithm: false,
48 }
49 }
50
51 pub fn pd_solver_rc(&self) -> Rc<RefCell<PdFullSpaceSolver>> {
55 Rc::clone(&self.pd_solver)
56 }
57
58 pub fn pd_solver_mut(&self) -> RefMut<'_, PdFullSpaceSolver> {
62 self.pd_solver.borrow_mut()
63 }
64
65 pub fn compute_search_direction(
69 &mut self,
70 data: &IpoptDataHandle,
71 cq: &IpoptCqHandle,
72 nlp: &Rc<RefCell<dyn IpoptNlp>>,
73 ) -> bool {
74 let improve_solution = data.borrow().delta.is_some();
75
76 if improve_solution && self.fast_step_computation {
77 return true;
78 }
79
80 let curr = {
81 let d = data.borrow();
82 d.curr
83 .clone()
84 .unwrap_or_else(|| panic!("PdSearchDirCalc: IpoptData::curr is unset"))
85 };
86
87 let mut rhs = curr.make_new_zeroed();
89 {
90 let cq_ref = cq.borrow();
91 rhs.x.copy(&*cq_ref.curr_grad_lag_with_damping_x());
92 rhs.s.copy(&*cq_ref.curr_grad_lag_with_damping_s());
93 rhs.y_c.copy(&*cq_ref.curr_c());
94 rhs.y_d.copy(&*cq_ref.curr_d_minus_s());
95 }
96
97 let nbounds = {
98 let n = nlp.borrow();
99 n.x_l().dim() + n.x_u().dim() + n.d_l().dim() + n.d_u().dim()
100 };
101
102 if nbounds > 0 && self.mehrotra_algorithm {
103 let delta_aff = {
104 let d = data.borrow();
105 d.delta_aff
106 .clone()
107 .unwrap_or_else(|| panic!("PdSearchDirCalc: delta_aff missing for Mehrotra"))
108 };
109 self.fill_mehrotra_z_blocks(&delta_aff, cq, nlp, &mut rhs);
110 } else {
111 let cq_ref = cq.borrow();
112 rhs.z_l.copy(&*cq_ref.curr_relaxed_compl_x_l());
113 rhs.z_u.copy(&*cq_ref.curr_relaxed_compl_x_u());
114 rhs.v_l.copy(&*cq_ref.curr_relaxed_compl_s_l());
115 rhs.v_u.copy(&*cq_ref.curr_relaxed_compl_s_u());
116 }
117
118 let frozen_rhs = rhs.freeze();
119
120 let mut delta = frozen_rhs.make_new_zeroed();
123 if improve_solution {
124 let prev = {
125 let d = data.borrow();
126 let Some(p) = d.delta.clone() else {
127 unreachable!("PdSearchDirCalc: delta cleared between is_some() and clone()")
128 };
129 p
130 };
131 delta.add_one_vector(-1.0, &prev, 0.0);
132 }
133
134 let allow_inexact = self.fast_step_computation;
135 let ok = self.pd_solver.borrow_mut().solve(
136 data,
137 cq,
138 nlp,
139 -1.0,
140 0.0,
141 &frozen_rhs,
142 &mut delta,
143 allow_inexact,
144 improve_solution,
145 );
146
147 if ok {
148 data.borrow_mut().set_delta(delta.freeze());
149 }
150 ok
151 }
152
153 pub fn compute_affine_step(
163 &mut self,
164 data: &IpoptDataHandle,
165 cq: &IpoptCqHandle,
166 nlp: &Rc<RefCell<dyn IpoptNlp>>,
167 ) -> bool {
168 let curr = {
169 let d = data.borrow();
170 d.curr
171 .clone()
172 .unwrap_or_else(|| panic!("PdSearchDirCalc: IpoptData::curr is unset"))
173 };
174
175 let mut rhs = curr.make_new_zeroed();
176 {
177 let cq_ref = cq.borrow();
178 rhs.x.copy(&*cq_ref.curr_grad_lag_x());
184 rhs.s.copy(&*cq_ref.curr_grad_lag_s());
185 rhs.y_c.copy(&*cq_ref.curr_c());
186 rhs.y_d.copy(&*cq_ref.curr_d_minus_s());
187 rhs.z_l.copy(&*cq_ref.curr_compl_x_l());
190 rhs.z_u.copy(&*cq_ref.curr_compl_x_u());
191 rhs.v_l.copy(&*cq_ref.curr_compl_s_l());
192 rhs.v_u.copy(&*cq_ref.curr_compl_s_u());
193 }
194
195 let frozen_rhs = rhs.freeze();
196 let mut delta_aff = frozen_rhs.make_new_zeroed();
197
198 let ok = self.pd_solver.borrow_mut().solve(
207 data,
208 cq,
209 nlp,
210 -1.0,
211 0.0,
212 &frozen_rhs,
213 &mut delta_aff,
214 false,
215 false,
216 );
217
218 if ok {
219 data.borrow_mut().set_delta_aff(delta_aff.freeze());
220 }
221 ok
222 }
223
224 pub fn compute_centering_step(
232 &mut self,
233 data: &IpoptDataHandle,
234 cq: &IpoptCqHandle,
235 nlp: &Rc<RefCell<dyn IpoptNlp>>,
236 ) -> bool {
237 let curr = {
238 let d = data.borrow();
239 d.curr
240 .clone()
241 .unwrap_or_else(|| panic!("PdSearchDirCalc: IpoptData::curr is unset"))
242 };
243 let avrg_compl = cq.borrow().curr_avrg_compl();
244
245 let mut rhs = curr.make_new_zeroed();
246 {
252 let cq_ref = cq.borrow();
253 rhs.x
254 .add_one_vector(-avrg_compl, &*cq_ref.grad_kappa_times_damping_x(), 0.0);
255 rhs.s
256 .add_one_vector(-avrg_compl, &*cq_ref.grad_kappa_times_damping_s(), 0.0);
257 }
258 rhs.y_c.set(0.0);
259 rhs.y_d.set(0.0);
260 rhs.z_l.set(avrg_compl);
261 rhs.z_u.set(avrg_compl);
262 rhs.v_l.set(avrg_compl);
263 rhs.v_u.set(avrg_compl);
264
265 let frozen_rhs = rhs.freeze();
266 let mut delta_cen = frozen_rhs.make_new_zeroed();
267
268 let ok = self.pd_solver.borrow_mut().solve(
273 data,
274 cq,
275 nlp,
276 1.0,
277 0.0,
278 &frozen_rhs,
279 &mut delta_cen,
280 false,
281 false,
282 );
283
284 if ok {
285 data.borrow_mut().set_delta_cen(delta_cen.freeze());
286 }
287 ok
288 }
289
290 pub fn compute_soc_step(
304 &mut self,
305 data: &IpoptDataHandle,
306 cq: &IpoptCqHandle,
307 nlp: &Rc<RefCell<dyn IpoptNlp>>,
308 c_soc: &dyn pounce_linalg::Vector,
309 dms_soc: &dyn pounce_linalg::Vector,
310 alpha_primal_soc: Number,
311 soc_method: i32,
312 ) -> Option<IteratesVector> {
313 let curr = {
314 let d = data.borrow();
315 d.curr
316 .clone()
317 .unwrap_or_else(|| panic!("PdSearchDirCalc::compute_soc_step: curr is unset"))
318 };
319 let mut rhs = curr.make_new_zeroed();
320 {
321 let cq_ref = cq.borrow();
322 rhs.x.copy(&*cq_ref.curr_grad_lag_with_damping_x());
323 rhs.s.copy(&*cq_ref.curr_grad_lag_with_damping_s());
324 if soc_method == 1 {
325 rhs.x.scal(alpha_primal_soc);
326 rhs.s.scal(alpha_primal_soc);
327 }
328 rhs.y_c.copy(c_soc);
329 rhs.y_d.copy(dms_soc);
330 rhs.z_l.copy(&*cq_ref.curr_relaxed_compl_x_l());
331 rhs.z_u.copy(&*cq_ref.curr_relaxed_compl_x_u());
332 rhs.v_l.copy(&*cq_ref.curr_relaxed_compl_s_l());
333 rhs.v_u.copy(&*cq_ref.curr_relaxed_compl_s_u());
334 }
335 let frozen_rhs = rhs.freeze();
336 let mut delta_soc = frozen_rhs.make_new_zeroed();
337 let ok = self.pd_solver.borrow_mut().solve(
338 data,
339 cq,
340 nlp,
341 -1.0,
342 0.0,
343 &frozen_rhs,
344 &mut delta_soc,
345 false,
346 false,
347 );
348 if ok {
349 Some(delta_soc.freeze())
350 } else {
351 None
352 }
353 }
354
355 fn fill_mehrotra_z_blocks(
359 &self,
360 delta_aff: &IteratesVector,
361 cq: &IpoptCqHandle,
362 nlp: &Rc<RefCell<dyn IpoptNlp>>,
363 rhs: &mut IteratesVectorMut,
364 ) {
365 let n = nlp.borrow();
366 let cq_ref = cq.borrow();
367
368 n.px_l()
370 .trans_mult_vector(1.0, &*delta_aff.x, 0.0, &mut *rhs.z_l);
371 rhs.z_l.element_wise_multiply(&*delta_aff.z_l);
372 rhs.z_l.axpy(1.0, &*cq_ref.curr_relaxed_compl_x_l());
373
374 n.px_u()
376 .trans_mult_vector(-1.0, &*delta_aff.x, 0.0, &mut *rhs.z_u);
377 rhs.z_u.element_wise_multiply(&*delta_aff.z_u);
378 rhs.z_u.axpy(1.0, &*cq_ref.curr_relaxed_compl_x_u());
379
380 n.pd_l()
382 .trans_mult_vector(1.0, &*delta_aff.s, 0.0, &mut *rhs.v_l);
383 rhs.v_l.element_wise_multiply(&*delta_aff.v_l);
384 rhs.v_l.axpy(1.0, &*cq_ref.curr_relaxed_compl_s_l());
385
386 n.pd_u()
388 .trans_mult_vector(-1.0, &*delta_aff.s, 0.0, &mut *rhs.v_u);
389 rhs.v_u.element_wise_multiply(&*delta_aff.v_u);
390 rhs.v_u.axpy(1.0, &*cq_ref.curr_relaxed_compl_s_u());
391 }
392}
393
394impl SearchDirCalculator for PdSearchDirCalc {}
395
396pub fn mehrotra_corrector_lower(
401 delta_aff_x_lo: Number,
402 delta_aff_z: Number,
403 relaxed_compl: Number,
404) -> Number {
405 delta_aff_x_lo * delta_aff_z + relaxed_compl
406}
407
408pub fn mehrotra_corrector_upper(
409 delta_aff_x_up: Number,
410 delta_aff_z: Number,
411 relaxed_compl: Number,
412) -> Number {
413 -delta_aff_x_up * delta_aff_z + relaxed_compl
414}
415
416pub fn relaxed_complementarity(x: Number, z: Number, mu: Number) -> Number {
417 x * z - mu
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// When `x * z == mu` the relaxed residual vanishes.
    #[test]
    fn relaxed_complementarity_vanishes_when_product_equals_mu() {
        let (x, z, mu) = (2.0, 0.5, 1.0);
        let residual = relaxed_complementarity(x, z, mu);
        assert_eq!(residual, 0.0);
    }

    /// Lower-bound corrector is `dx * dz + relaxed_compl`.
    #[test]
    fn lower_corrector_adds_affine_product() {
        let corrector = mehrotra_corrector_lower(1.0, 2.0, 0.5);
        assert_eq!(corrector, 2.5);
    }

    /// Upper-bound corrector enters the step with a negative sign.
    #[test]
    fn upper_corrector_subtracts_affine_product() {
        let corrector = mehrotra_corrector_upper(1.0, 2.0, 0.5);
        assert_eq!(corrector, -1.5);
    }
}