1use crate::alg_types::SolverReturn;
20use crate::return_codes::AlgorithmMode;
21use pounce_common::types::{Index, Number};
22use std::collections::BTreeMap;
23
/// Linear / nonlinear classification of a single variable or constraint.
///
/// Slices of this type are filled in by [`TNLP::get_variables_linearity`]
/// and [`TNLP::get_constraints_linearity`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Linearity {
    /// Marked as linear.
    Linear,
    /// Marked as nonlinear.
    NonLinear,
}
30
/// Indexing convention reported through `NlpInfo::index_style`.
///
/// NOTE(review): presumably this governs the `irow`/`jcol` indices written
/// for [`SparsityRequest::Structure`] — confirm against the consumer.
///
/// The explicit discriminants let the style be turned into its numeric code
/// with `as`, e.g. `IndexStyle::Fortran as i32 == 1`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IndexStyle {
    /// Zero-based indices (the test problem in this file reports its
    /// Jacobian structure this way).
    C = 0,
    /// Presumably one-based indices, following the Fortran convention.
    Fortran = 1,
}
39
40#[derive(Debug, Clone, Copy)]
42pub struct NlpInfo {
43 pub n: Index,
44 pub m: Index,
45 pub nnz_jac_g: Index,
46 pub nnz_h_lag: Index,
47 pub index_style: IndexStyle,
48}
49
50#[derive(Debug, Default, Clone)]
53pub struct MetaData {
54 pub strings: BTreeMap<String, Vec<String>>,
55 pub integers: BTreeMap<String, Vec<Index>>,
56 pub numerics: BTreeMap<String, Vec<Number>>,
57}
58
59#[derive(Debug)]
61pub struct BoundsInfo<'a> {
62 pub x_l: &'a mut [Number],
63 pub x_u: &'a mut [Number],
64 pub g_l: &'a mut [Number],
65 pub g_u: &'a mut [Number],
66}
67
68#[derive(Debug)]
71pub struct StartingPoint<'a> {
72 pub init_x: bool,
73 pub x: &'a mut [Number],
74 pub init_z: bool,
75 pub z_l: &'a mut [Number],
76 pub z_u: &'a mut [Number],
77 pub init_lambda: bool,
78 pub lambda: &'a mut [Number],
79}
80
81#[derive(Debug)]
83pub struct ScalingRequest<'a> {
84 pub obj_scaling: &'a mut Number,
85 pub use_x_scaling: &'a mut bool,
86 pub x_scaling: &'a mut [Number],
87 pub use_g_scaling: &'a mut bool,
88 pub g_scaling: &'a mut [Number],
89}
90
91#[derive(Debug)]
95pub enum SparsityRequest<'a> {
96 Structure {
100 irow: &'a mut [Index],
101 jcol: &'a mut [Index],
102 },
103 Values { values: &'a mut [Number] },
107}
108
109#[derive(Debug)]
111pub struct Solution<'a> {
112 pub status: SolverReturn,
113 pub x: &'a [Number],
114 pub z_l: &'a [Number],
115 pub z_u: &'a [Number],
116 pub g: &'a [Number],
117 pub lambda: &'a [Number],
118 pub obj_value: Number,
119}
120
121#[derive(Debug, Clone, Copy)]
123pub struct IterStats {
124 pub mode: AlgorithmMode,
125 pub iter: Index,
126 pub obj_value: Number,
127 pub inf_pr: Number,
128 pub inf_du: Number,
129 pub mu: Number,
130 pub d_norm: Number,
131 pub regularization_size: Number,
132 pub alpha_du: Number,
133 pub alpha_pr: Number,
134 pub ls_trials: Index,
135}
136
/// Opaque solver-iteration data passed to [`TNLP::finalize_solution`] and
/// [`TNLP::intermediate_callback`].
///
/// Currently a placeholder. The private unit field prevents construction by
/// struct literal outside this module, so code elsewhere must use
/// [`Default`].
#[derive(Default, Debug)]
pub struct IpoptData {
    _private: (),
}
144
/// Opaque handle to solver-computed quantities passed to
/// [`TNLP::finalize_solution`] and [`TNLP::intermediate_callback`].
/// Presumably mirrors Ipopt's `IpoptCalculatedQuantities` — confirm.
///
/// Currently a placeholder. The private unit field prevents construction by
/// struct literal outside this module, so code elsewhere must use
/// [`Default`].
#[derive(Default, Debug)]
pub struct IpoptCq {
    _private: (),
}
151
152pub trait TNLP {
158 fn get_nlp_info(&mut self) -> Option<NlpInfo>;
160
161 fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool;
163
164 fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool;
166
167 fn eval_f(&mut self, x: &[Number], new_x: bool) -> Option<Number>;
169
170 fn eval_grad_f(&mut self, x: &[Number], new_x: bool, grad_f: &mut [Number]) -> bool;
172
173 fn eval_g(&mut self, x: &[Number], new_x: bool, g: &mut [Number]) -> bool;
175
176 fn eval_jac_g(&mut self, x: Option<&[Number]>, new_x: bool, mode: SparsityRequest<'_>) -> bool;
179
180 fn eval_h(
184 &mut self,
185 _x: Option<&[Number]>,
186 _new_x: bool,
187 _obj_factor: Number,
188 _lambda: Option<&[Number]>,
189 _new_lambda: bool,
190 _mode: SparsityRequest<'_>,
191 ) -> bool {
192 false
193 }
194
195 fn finalize_solution(&mut self, sol: Solution<'_>, ip_data: &IpoptData, ip_cq: &IpoptCq);
197
198 fn get_var_con_metadata(&mut self, _var: &mut MetaData, _con: &mut MetaData) -> bool {
203 false
204 }
205
206 fn get_scaling_parameters(&mut self, _req: ScalingRequest<'_>) -> bool {
209 false
210 }
211
212 fn get_variables_linearity(&mut self, _types: &mut [Linearity]) -> bool {
214 false
215 }
216
217 fn get_constraints_linearity(&mut self, _types: &mut [Linearity]) -> bool {
220 false
221 }
222
223 fn get_number_of_nonlinear_variables(&mut self) -> Index {
226 -1
227 }
228
229 fn get_list_of_nonlinear_variables(&mut self, _pos_nonlin_vars: &mut [Index]) -> bool {
232 false
233 }
234
235 fn intermediate_callback(
238 &mut self,
239 _stats: IterStats,
240 _ip_data: &IpoptData,
241 _ip_cq: &IpoptCq,
242 ) -> bool {
243 true
244 }
245
246 fn finalize_metadata(&mut self, _var: &MetaData, _con: &MetaData) {}
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal problem: minimise `x0^2 + x1^2` subject to `x0 + x1 = 1`.
    ///
    /// Only the mandatory `TNLP` methods are implemented, so every optional
    /// callback falls back to the trait's default body.
    struct Mini;

    impl TNLP for Mini {
        fn get_nlp_info(&mut self) -> Option<NlpInfo> {
            Some(NlpInfo {
                n: 2,
                m: 1,
                nnz_jac_g: 2,
                nnz_h_lag: 2,
                index_style: IndexStyle::C,
            })
        }

        fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool {
            // +/-1e19 presumably stands in for "unbounded" on the variables.
            b.x_l.iter_mut().for_each(|v| *v = -1e19);
            b.x_u.iter_mut().for_each(|v| *v = 1e19);
            // g_l == g_u: the single constraint is an equality.
            b.g_l[0] = 1.0;
            b.g_u[0] = 1.0;
            true
        }

        fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool {
            assert!(sp.init_x);
            sp.x[0] = 0.5;
            sp.x[1] = 0.5;
            true
        }

        fn eval_f(&mut self, x: &[Number], _new_x: bool) -> Option<Number> {
            Some(x[0] * x[0] + x[1] * x[1])
        }

        fn eval_grad_f(&mut self, x: &[Number], _new_x: bool, grad_f: &mut [Number]) -> bool {
            grad_f[0] = 2.0 * x[0];
            grad_f[1] = 2.0 * x[1];
            true
        }

        fn eval_g(&mut self, x: &[Number], _new_x: bool, g: &mut [Number]) -> bool {
            g[0] = x[0] + x[1];
            true
        }

        fn eval_jac_g(
            &mut self,
            _x: Option<&[Number]>,
            _new_x: bool,
            mode: SparsityRequest<'_>,
        ) -> bool {
            match mode {
                SparsityRequest::Structure { irow, jcol } => {
                    irow.copy_from_slice(&[0, 0]);
                    jcol.copy_from_slice(&[0, 1]);
                }
                SparsityRequest::Values { values } => {
                    values.copy_from_slice(&[1.0, 1.0]);
                }
            }
            true
        }

        fn finalize_solution(&mut self, _sol: Solution<'_>, _d: &IpoptData, _q: &IpoptCq) {}
    }

    #[test]
    fn tnlp_is_object_safe() {
        // Driving the problem through a trait object makes this test fail to
        // compile if `TNLP` ever stops being object safe.
        let mut t: Box<dyn TNLP> = Box::new(Mini);
        let info = t.get_nlp_info().expect("get_nlp_info");
        assert_eq!(info.n, 2);
        assert_eq!(info.m, 1);
        assert_eq!(info.index_style, IndexStyle::C);

        let mut x_l = [0.0; 2];
        let mut x_u = [0.0; 2];
        let mut g_l = [0.0; 1];
        let mut g_u = [0.0; 1];
        assert!(t.get_bounds_info(BoundsInfo {
            x_l: &mut x_l,
            x_u: &mut x_u,
            g_l: &mut g_l,
            g_u: &mut g_u
        }));
        assert_eq!(g_l[0], 1.0);

        assert_eq!(t.eval_f(&[3.0, 4.0], true), Some(25.0));

        let mut grad = [0.0; 2];
        assert!(t.eval_grad_f(&[3.0, 4.0], true, &mut grad));
        assert_eq!(grad, [6.0, 8.0]);

        let mut g = [0.0; 1];
        assert!(t.eval_g(&[3.0, 4.0], true, &mut g));
        assert_eq!(g, [7.0]);

        // Mini does not override `eval_h`, so the default must report `false`.
        let mut tmp_v = [0.0; 0];
        assert!(!t.eval_h(
            None,
            false,
            1.0,
            None,
            false,
            SparsityRequest::Values { values: &mut tmp_v }
        ));

        assert_eq!(t.get_number_of_nonlinear_variables(), -1);
    }

    #[test]
    fn starting_point_fills_primal_only() {
        let mut t = Mini;
        let mut x = [0.0; 2];
        // Sentinel contents: Mini must leave the dual buffers untouched.
        let mut z_l = [9.0; 2];
        let mut z_u = [9.0; 2];
        let mut lambda = [9.0; 1];
        assert!(t.get_starting_point(StartingPoint {
            init_x: true,
            x: &mut x,
            init_z: false,
            z_l: &mut z_l,
            z_u: &mut z_u,
            init_lambda: false,
            lambda: &mut lambda,
        }));
        assert_eq!(x, [0.5, 0.5]);
        assert_eq!(z_l, [9.0; 2]);
        assert_eq!(z_u, [9.0; 2]);
        assert_eq!(lambda, [9.0; 1]);
    }

    #[test]
    fn optional_callbacks_fall_back_to_defaults() {
        let mut t: Box<dyn TNLP> = Box::new(Mini);

        let mut var_md = MetaData::default();
        let mut con_md = MetaData::default();
        assert!(!t.get_var_con_metadata(&mut var_md, &mut con_md));
        assert!(var_md.strings.is_empty());
        assert!(var_md.integers.is_empty());
        assert!(con_md.numerics.is_empty());
        // The default `finalize_metadata` is a no-op; it must simply not panic.
        t.finalize_metadata(&var_md, &con_md);

        let mut obj_scaling = 1.0;
        let mut use_x_scaling = false;
        let mut x_scaling = [1.0; 2];
        let mut use_g_scaling = false;
        let mut g_scaling = [1.0; 1];
        assert!(!t.get_scaling_parameters(ScalingRequest {
            obj_scaling: &mut obj_scaling,
            use_x_scaling: &mut use_x_scaling,
            x_scaling: &mut x_scaling,
            use_g_scaling: &mut use_g_scaling,
            g_scaling: &mut g_scaling,
        }));
        // The default must leave every out-parameter untouched.
        assert_eq!(obj_scaling, 1.0);
        assert!(!use_x_scaling);
        assert!(!use_g_scaling);
        assert_eq!(x_scaling, [1.0; 2]);
        assert_eq!(g_scaling, [1.0; 1]);

        let mut var_types = [Linearity::NonLinear; 2];
        assert!(!t.get_variables_linearity(&mut var_types));
        assert_eq!(var_types, [Linearity::NonLinear; 2]);
        let mut con_types = [Linearity::NonLinear; 1];
        assert!(!t.get_constraints_linearity(&mut con_types));
        assert_eq!(con_types, [Linearity::NonLinear; 1]);

        let mut nonlin = [0; 2];
        assert!(!t.get_list_of_nonlinear_variables(&mut nonlin));
        assert_eq!(nonlin, [0, 0]);
        assert_eq!(t.get_number_of_nonlinear_variables(), -1);
    }

    #[test]
    fn sparsity_request_round_trip() {
        let mut t = Mini;
        let mut irow = [0; 2];
        let mut jcol = [0; 2];
        assert!(t.eval_jac_g(
            None,
            false,
            SparsityRequest::Structure {
                irow: &mut irow,
                jcol: &mut jcol
            }
        ));
        assert_eq!(irow, [0, 0]);
        assert_eq!(jcol, [0, 1]);

        let mut vals = [0.0; 2];
        assert!(t.eval_jac_g(
            Some(&[1.0, 2.0]),
            true,
            SparsityRequest::Values { values: &mut vals }
        ));
        assert_eq!(vals, [1.0, 1.0]);
    }
}