pounce_algorithm/line_search/
penalty_acceptor.rs1use crate::ipopt_cq::IpoptCqHandle;
35use crate::ipopt_data::IpoptDataHandle;
36use crate::iterates_vector::IteratesVector;
37use crate::line_search::filter_acceptor::AcceptDecision;
38use crate::line_search::ls_acceptor::BacktrackingLsAcceptor;
39use pounce_common::types::Number;
40use pounce_common::utils::compare_le;
41use pounce_linalg::Vector;
42use std::rc::Rc;
43
44pub struct PenaltyLsAcceptor {
45 pub rho: Number,
48 pub nu_inc: Number,
50 pub nu_init: Number,
52 pub nu_max: Number,
54 pub eta_penalty: Number,
56 nu: Number,
57 last_nu: Number,
58 cache: Option<RefCache>,
61}
62
63struct RefCache {
65 theta_ref: Number,
66 barr_ref: Number,
67 grad_barr_t_delta: Number,
68 dwd: Number,
69 c_ref: Rc<dyn Vector>,
71 d_minus_s_ref: Rc<dyn Vector>,
73 jac_c_delta: Rc<dyn Vector>,
75 jac_d_delta_minus_ds: Rc<dyn Vector>,
77}
78
79impl Default for PenaltyLsAcceptor {
80 fn default() -> Self {
81 Self {
82 rho: 0.1,
83 nu_inc: 1e-4,
84 nu_init: 1e-6,
85 nu_max: 1e40,
86 eta_penalty: 1e-8,
87 nu: 1e-6,
88 last_nu: 1e-6,
89 cache: None,
90 }
91 }
92}
93
94impl PenaltyLsAcceptor {
95 pub fn new() -> Self {
96 Self::default()
97 }
98
99 pub fn nu(&self) -> Number {
100 self.nu
101 }
102
103 pub fn last_nu(&self) -> Number {
104 self.last_nu
105 }
106
107 pub fn reset(&mut self) {
110 self.nu = self.nu_init;
111 self.last_nu = self.nu_init;
112 self.cache = None;
113 }
114
115 pub fn update_nu(
124 &mut self,
125 grad_barr_t_delta: Number,
126 delta_w_delta: Number,
127 reference_theta: Number,
128 ) {
129 self.last_nu = self.nu;
130 if reference_theta > 0.0 {
131 let nu_plus =
132 (grad_barr_t_delta + 0.5 * delta_w_delta) / ((1.0 - self.rho) * reference_theta);
133 if self.nu < nu_plus {
134 self.nu = nu_plus + self.nu_inc;
135 }
136 }
137 }
138
139 fn calc_pred(&self, alpha: Number) -> Number {
143 let cache = self
144 .cache
145 .as_ref()
146 .expect("calc_pred called before init_this_line_search");
147 let mut tmp_c = cache.c_ref.make_new();
149 tmp_c.set(0.0);
150 tmp_c.add_two_vectors(1.0, &*cache.c_ref, alpha, &*cache.jac_c_delta, 0.0);
151 let mut tmp_d = cache.d_minus_s_ref.make_new();
152 tmp_d.set(0.0);
153 tmp_d.add_two_vectors(
154 1.0,
155 &*cache.d_minus_s_ref,
156 alpha,
157 &*cache.jac_d_delta_minus_ds,
158 0.0,
159 );
160 let theta_2 = tmp_c.asum() + tmp_d.asum();
161
162 let pred = -alpha * cache.grad_barr_t_delta - 0.5 * alpha * alpha * cache.dwd
163 + self.nu * (cache.theta_ref - theta_2);
164 if pred < 0.0 {
165 0.0
166 } else {
167 pred
168 }
169 }
170}
171
172impl BacktrackingLsAcceptor for PenaltyLsAcceptor {
173 fn reset(&mut self) {
174 PenaltyLsAcceptor::reset(self);
175 }
176
177 fn init_this_line_search(
181 &mut self,
182 _data: &IpoptDataHandle,
183 cq: &IpoptCqHandle,
184 delta: &IteratesVector,
185 ) {
186 let cqr = cq.borrow();
187 let theta_ref = cqr.curr_constraint_violation();
188 let barr_ref = cqr.curr_barrier_obj();
189 let grad_barr_t_delta = cqr.curr_grad_barr_t_delta(&*delta.x, &*delta.s);
190 let dwd = cqr.curr_dwd(&*delta.x, &*delta.s);
191
192 let c_ref = cqr.curr_c();
194 let d_minus_s_ref = cqr.curr_d_minus_s();
195 let jac_c_delta = cqr.curr_jac_c_times_vec(&*delta.x);
196 let jac_d_delta = cqr.curr_jac_d_times_vec(&*delta.x);
198 let mut tmp = jac_d_delta.make_new();
199 tmp.set(0.0);
200 tmp.add_two_vectors(1.0, &*jac_d_delta, -1.0, &*delta.s, 0.0);
201 let jac_d_delta_minus_ds: Rc<dyn Vector> = Rc::from(tmp);
202 drop(cqr);
203
204 self.cache = Some(RefCache {
205 theta_ref,
206 barr_ref,
207 grad_barr_t_delta,
208 dwd,
209 c_ref,
210 d_minus_s_ref,
211 jac_c_delta,
212 jac_d_delta_minus_ds,
213 });
214
215 self.update_nu(grad_barr_t_delta, dwd, theta_ref);
217 }
218
219 fn check_trial_point(
233 &mut self,
234 alpha_primal: Number,
235 _theta: Number,
236 _phi: Number,
237 _d_phi: Number,
238 theta_trial: Number,
239 phi_trial: Number,
240 ) -> AcceptDecision {
241 let cache = match &self.cache {
245 Some(c) => c,
246 None => return AcceptDecision::Accept,
247 };
248
249 let pred = self.calc_pred(alpha_primal);
250 let ref_merit = cache.barr_ref + self.nu * cache.theta_ref;
251 let ared = ref_merit - (phi_trial + self.nu * theta_trial);
252
253 if compare_le(self.eta_penalty * pred, ared, ref_merit.abs()) {
254 AcceptDecision::Accept
255 } else {
256 AcceptDecision::Reject
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_bump_when_theta_zero() {
        let mut a = PenaltyLsAcceptor::new();
        let nu0 = a.nu();
        a.update_nu(10.0, 5.0, 0.0);
        assert_eq!(a.nu(), nu0);
        assert_eq!(a.last_nu(), nu0);
    }

    #[test]
    fn bump_when_nu_plus_exceeds_current() {
        let mut a = PenaltyLsAcceptor {
            rho: 0.1,
            nu_inc: 1e-4,
            nu: 0.0,
            last_nu: 0.0,
            ..Default::default()
        };
        // nu_plus = (1 + 0) / ((1 - 0.1) * 1) = 1 / 0.9, then + nu_inc.
        a.update_nu(1.0, 0.0, 1.0);
        assert!(a.last_nu() == 0.0);
        let expected = 1.0 / 0.9 + 1e-4;
        assert!((a.nu() - expected).abs() < 1e-12);
    }

    #[test]
    fn no_bump_when_already_above_nu_plus() {
        let mut a = PenaltyLsAcceptor {
            rho: 0.1,
            nu_inc: 1e-4,
            nu: 1e6,
            last_nu: 1e6,
            ..Default::default()
        };
        a.update_nu(1.0, 0.0, 1.0);
        assert_eq!(a.nu(), 1e6);
    }

    #[test]
    fn reset_restores_init() {
        let mut a = PenaltyLsAcceptor::new();
        a.update_nu(10.0, 0.0, 1.0);
        let bumped = a.nu();
        assert!(bumped > a.nu_init);
        PenaltyLsAcceptor::reset(&mut a);
        assert_eq!(a.nu(), a.nu_init);
    }

    #[test]
    fn check_trial_point_without_cache_accepts() {
        let mut a = PenaltyLsAcceptor::new();
        assert_eq!(
            a.check_trial_point(1.0, 1.0, 10.0, -1.0, 0.5, 8.0),
            AcceptDecision::Accept
        );
    }

    /// Builds a `RefCache` from plain slices; arguments mirror the `RefCache` fields.
    // One argument per `RefCache` field keeps the fixtures readable side by side;
    // bundling them into a struct would only move the same list elsewhere.
    #[allow(clippy::too_many_arguments)]
    fn cache_for_test(
        theta_ref: Number,
        barr_ref: Number,
        grad_barr_t_delta: Number,
        dwd: Number,
        c_ref: &[Number],
        d_minus_s_ref: &[Number],
        jac_c_delta: &[Number],
        jac_d_delta_minus_ds: &[Number],
    ) -> RefCache {
        use pounce_linalg::dense_vector::DenseVectorSpace;
        use pounce_linalg::Vector;
        let mkr = |v: &[Number]| -> Rc<dyn Vector> {
            let mut x = DenseVectorSpace::new(v.len() as i32).make_new_dense();
            x.values_mut().copy_from_slice(v);
            Rc::new(x)
        };
        RefCache {
            theta_ref,
            barr_ref,
            grad_barr_t_delta,
            dwd,
            c_ref: mkr(c_ref),
            d_minus_s_ref: mkr(d_minus_s_ref),
            jac_c_delta: mkr(jac_c_delta),
            jac_d_delta_minus_ds: mkr(jac_d_delta_minus_ds),
        }
    }

    /// Reference point whose full step (`alpha = 1`) drives both linearized residuals
    /// to zero: c_ref + jac_c_delta = [0, 0] and d_minus_s_ref + jac_d_delta_minus_ds = [0].
    /// theta_ref = 3, barr_ref = 0, grad_barr_t_delta = -2, dwd = 0.
    fn aligned_cache() -> RefCache {
        cache_for_test(3.0, 0.0, -2.0, 0.0, &[1.0, 2.0], &[0.0], &[-1.0, -2.0], &[0.0])
    }

    #[test]
    fn calc_pred_matches_closed_form() {
        let mut a = PenaltyLsAcceptor::new();
        // alpha = 0.5: tmp_c = [0.5, 1.5], tmp_d = [3] -> theta_2 = 5.
        // pred = -0.5*2 - 0.5*0.25*4 + 0.5*(3 - 5) = -2.5, clamped to 0.
        a.nu = 0.5;
        a.cache = Some(cache_for_test(
            3.0,
            0.0,
            2.0,
            4.0,
            &[1.0, 2.0],
            &[4.0],
            &[-1.0, -1.0],
            &[-2.0],
        ));
        assert!((a.calc_pred(0.5) - 0.0).abs() < 1e-12);
    }

    #[test]
    fn calc_pred_positive_when_directions_align() {
        let mut a = PenaltyLsAcceptor::new();
        // alpha = 1: theta_2 = 0, pred = -1*(-2) - 0 + 1*(3 - 0) = 5.
        a.nu = 1.0;
        a.cache = Some(aligned_cache());
        assert!((a.calc_pred(1.0) - 5.0).abs() < 1e-12);
    }

    #[test]
    fn check_trial_point_accepts_when_ared_meets_pred() {
        let mut a = PenaltyLsAcceptor::new();
        // pred = 5, ref_merit = 0 + 1*3 = 3, ared = 3 - (-3 + 1*0) = 6 >= 0.5*5.
        a.nu = 1.0;
        a.eta_penalty = 0.5;
        a.cache = Some(aligned_cache());
        assert_eq!(
            a.check_trial_point(1.0, 3.0, 0.0, -2.0, 0.0, -3.0),
            AcceptDecision::Accept
        );
    }

    #[test]
    fn check_trial_point_rejects_insufficient_decrease() {
        let mut a = PenaltyLsAcceptor::new();
        // pred = 5, ared = 3 - (0 + 1*2.999) = 0.001 < 0.5*5.
        a.nu = 1.0;
        a.eta_penalty = 0.5;
        a.cache = Some(aligned_cache());
        assert_eq!(
            a.check_trial_point(1.0, 3.0, 0.0, -2.0, 2.999, 0.0),
            AcceptDecision::Reject
        );
    }
}