1use crate::alg_types::SolverReturn;
20use crate::return_codes::AlgorithmMode;
21use pounce_common::types::{Index, Number};
22use std::collections::BTreeMap;
23
/// Classification of a single variable or constraint as linear or nonlinear.
///
/// Slices of this type are the output buffers of
/// [`TNLP::get_variables_linearity`] and [`TNLP::get_constraints_linearity`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Linearity {
    /// The entry appears only linearly in the problem.
    Linear,
    /// The entry appears nonlinearly in the problem.
    NonLinear,
}
30
/// Numbering convention for the sparse row/column indices an implementation
/// reports, e.g. through [`SparsityRequest::Structure`].
///
/// The explicit discriminants mirror Ipopt's `IndexStyleEnum`
/// (`C_STYLE = 0`, `FORTRAN_STYLE = 1`): each value equals the index of the
/// first row/column under that convention.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IndexStyle {
    /// Zero-based indices.
    C = 0,
    /// One-based indices.
    Fortran = 1,
}
39
40#[derive(Debug, Clone, Copy)]
42pub struct NlpInfo {
43 pub n: Index,
44 pub m: Index,
45 pub nnz_jac_g: Index,
46 pub nnz_h_lag: Index,
47 pub index_style: IndexStyle,
48}
49
50#[derive(Debug, Default, Clone)]
53pub struct MetaData {
54 pub strings: BTreeMap<String, Vec<String>>,
55 pub integers: BTreeMap<String, Vec<Index>>,
56 pub numerics: BTreeMap<String, Vec<Number>>,
57}
58
/// Metadata key `"idx_names"`.
///
/// Presumably this is the [`MetaData::strings`] entry that holds the
/// per-variable / per-constraint names, mirroring Ipopt's `idx_names`
/// convention — TODO confirm.
pub const IDX_NAMES: &str = "idx_names";
68
69#[derive(Debug)]
71pub struct BoundsInfo<'a> {
72 pub x_l: &'a mut [Number],
73 pub x_u: &'a mut [Number],
74 pub g_l: &'a mut [Number],
75 pub g_u: &'a mut [Number],
76}
77
78#[derive(Debug)]
81pub struct StartingPoint<'a> {
82 pub init_x: bool,
83 pub x: &'a mut [Number],
84 pub init_z: bool,
85 pub z_l: &'a mut [Number],
86 pub z_u: &'a mut [Number],
87 pub init_lambda: bool,
88 pub lambda: &'a mut [Number],
89}
90
91#[derive(Debug)]
93pub struct ScalingRequest<'a> {
94 pub obj_scaling: &'a mut Number,
95 pub use_x_scaling: &'a mut bool,
96 pub x_scaling: &'a mut [Number],
97 pub use_g_scaling: &'a mut bool,
98 pub g_scaling: &'a mut [Number],
99}
100
101#[derive(Debug)]
105pub enum SparsityRequest<'a> {
106 Structure {
110 irow: &'a mut [Index],
111 jcol: &'a mut [Index],
112 },
113 Values { values: &'a mut [Number] },
117}
118
119#[derive(Debug)]
121pub struct Solution<'a> {
122 pub status: SolverReturn,
123 pub x: &'a [Number],
124 pub z_l: &'a [Number],
125 pub z_u: &'a [Number],
126 pub g: &'a [Number],
127 pub lambda: &'a [Number],
128 pub obj_value: Number,
129}
130
131#[derive(Debug, Clone, Copy)]
133pub struct IterStats {
134 pub mode: AlgorithmMode,
135 pub iter: Index,
136 pub obj_value: Number,
137 pub inf_pr: Number,
138 pub inf_du: Number,
139 pub mu: Number,
140 pub d_norm: Number,
141 pub regularization_size: Number,
142 pub alpha_du: Number,
143 pub alpha_pr: Number,
144 pub ls_trials: Index,
145}
146
/// Opaque handle to the solver's iterate data, passed to
/// [`TNLP::finalize_solution`] and [`TNLP::intermediate_callback`].
///
/// It has no public content yet. The private unit field stops code outside
/// this module from building it with a struct literal, so `Default` is the
/// only way to obtain one.
#[derive(Default, Debug)]
pub struct IpoptData {
    _private: (),
}
154
/// Opaque handle to the solver's calculated quantities, passed to
/// [`TNLP::finalize_solution`] and [`TNLP::intermediate_callback`].
///
/// It has no public content yet. The private unit field stops code outside
/// this module from building it with a struct literal, so `Default` is the
/// only way to obtain one.
#[derive(Default, Debug)]
pub struct IpoptCq {
    _private: (),
}
161
162pub trait TNLP {
168 fn get_nlp_info(&mut self) -> Option<NlpInfo>;
170
171 fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool;
173
174 fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool;
176
177 fn eval_f(&mut self, x: &[Number], new_x: bool) -> Option<Number>;
179
180 fn eval_grad_f(&mut self, x: &[Number], new_x: bool, grad_f: &mut [Number]) -> bool;
182
183 fn eval_g(&mut self, x: &[Number], new_x: bool, g: &mut [Number]) -> bool;
185
186 fn eval_jac_g(&mut self, x: Option<&[Number]>, new_x: bool, mode: SparsityRequest<'_>) -> bool;
189
190 fn eval_h(
194 &mut self,
195 _x: Option<&[Number]>,
196 _new_x: bool,
197 _obj_factor: Number,
198 _lambda: Option<&[Number]>,
199 _new_lambda: bool,
200 _mode: SparsityRequest<'_>,
201 ) -> bool {
202 false
203 }
204
205 fn finalize_solution(&mut self, sol: Solution<'_>, ip_data: &IpoptData, ip_cq: &IpoptCq);
207
208 fn get_var_con_metadata(&mut self, _var: &mut MetaData, _con: &mut MetaData) -> bool {
213 false
214 }
215
216 fn get_scaling_parameters(&mut self, _req: ScalingRequest<'_>) -> bool {
219 false
220 }
221
222 fn get_variables_linearity(&mut self, _types: &mut [Linearity]) -> bool {
224 false
225 }
226
227 fn get_constraints_linearity(&mut self, _types: &mut [Linearity]) -> bool {
230 false
231 }
232
233 fn get_number_of_nonlinear_variables(&mut self) -> Index {
236 -1
237 }
238
239 fn get_list_of_nonlinear_variables(&mut self, _pos_nonlin_vars: &mut [Index]) -> bool {
242 false
243 }
244
245 fn intermediate_callback(
248 &mut self,
249 _stats: IterStats,
250 _ip_data: &IpoptData,
251 _ip_cq: &IpoptCq,
252 ) -> bool {
253 true
254 }
255
256 fn finalize_metadata(&mut self, _var: &MetaData, _con: &MetaData) {}
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal problem used to exercise the trait:
    /// minimize `x0^2 + x1^2` subject to `x0 + x1 = 1`.
    ///
    /// Only the required methods are implemented, so every provided method
    /// runs its default body.
    struct Mini;

    impl TNLP for Mini {
        fn get_nlp_info(&mut self) -> Option<NlpInfo> {
            Some(NlpInfo {
                n: 2,
                m: 1,
                nnz_jac_g: 2,
                nnz_h_lag: 2,
                index_style: IndexStyle::C,
            })
        }

        fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool {
            // +/-1e19 is Ipopt's default "unbounded" threshold.
            b.x_l.iter_mut().for_each(|v| *v = -1e19);
            b.x_u.iter_mut().for_each(|v| *v = 1e19);
            // Equality constraint: g_l == g_u.
            b.g_l[0] = 1.0;
            b.g_u[0] = 1.0;
            true
        }

        fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool {
            assert!(sp.init_x);
            sp.x[0] = 0.5;
            sp.x[1] = 0.5;
            true
        }

        fn eval_f(&mut self, x: &[Number], _new_x: bool) -> Option<Number> {
            Some(x[0] * x[0] + x[1] * x[1])
        }

        fn eval_grad_f(&mut self, x: &[Number], _new_x: bool, grad_f: &mut [Number]) -> bool {
            grad_f[0] = 2.0 * x[0];
            grad_f[1] = 2.0 * x[1];
            true
        }

        fn eval_g(&mut self, x: &[Number], _new_x: bool, g: &mut [Number]) -> bool {
            g[0] = x[0] + x[1];
            true
        }

        fn eval_jac_g(
            &mut self,
            _x: Option<&[Number]>,
            _new_x: bool,
            mode: SparsityRequest<'_>,
        ) -> bool {
            match mode {
                SparsityRequest::Structure { irow, jcol } => {
                    irow.copy_from_slice(&[0, 0]);
                    jcol.copy_from_slice(&[0, 1]);
                }
                SparsityRequest::Values { values } => {
                    values.copy_from_slice(&[1.0, 1.0]);
                }
            }
            true
        }

        fn finalize_solution(&mut self, _sol: Solution<'_>, _d: &IpoptData, _q: &IpoptCq) {}
    }

    /// The trait can be used through `dyn TNLP`, and the required methods
    /// behave as implemented when called that way.
    #[test]
    fn tnlp_is_object_safe() {
        let mut t: Box<dyn TNLP> = Box::new(Mini);

        let info = t.get_nlp_info().expect("get_nlp_info");
        assert_eq!(info.n, 2);
        assert_eq!(info.m, 1);
        assert_eq!(info.nnz_jac_g, 2);
        assert_eq!(info.nnz_h_lag, 2);
        assert_eq!(info.index_style, IndexStyle::C);

        let mut x_l = [0.0; 2];
        let mut x_u = [0.0; 2];
        let mut g_l = [0.0; 1];
        let mut g_u = [0.0; 1];
        assert!(t.get_bounds_info(BoundsInfo {
            x_l: &mut x_l,
            x_u: &mut x_u,
            g_l: &mut g_l,
            g_u: &mut g_u
        }));
        assert!(x_l.iter().all(|&v| v == -1e19));
        assert!(x_u.iter().all(|&v| v == 1e19));
        assert_eq!(g_l[0], 1.0);
        assert_eq!(g_u[0], 1.0);

        let mut grad = [0.0; 2];
        assert!(t.eval_grad_f(&[3.0, 4.0], true, &mut grad));
        assert_eq!(grad, [6.0, 8.0]);

        // eval_h is not overridden, so the default reports "not provided".
        let mut tmp_v = [0.0; 0];
        assert!(!t.eval_h(
            None,
            false,
            1.0,
            None,
            false,
            SparsityRequest::Values { values: &mut tmp_v }
        ));

        assert_eq!(t.get_number_of_nonlinear_variables(), -1);
    }

    /// Starting point, objective and constraint callbacks write the expected values.
    #[test]
    fn required_methods_evaluate_problem() {
        let mut t = Mini;

        let mut x = [0.0; 2];
        let mut z_l = [0.0; 2];
        let mut z_u = [0.0; 2];
        let mut lambda = [0.0; 1];
        assert!(t.get_starting_point(StartingPoint {
            init_x: true,
            x: &mut x,
            init_z: false,
            z_l: &mut z_l,
            z_u: &mut z_u,
            init_lambda: false,
            lambda: &mut lambda,
        }));
        assert_eq!(x, [0.5, 0.5]);
        // Buffers whose init flag is false are left untouched by Mini.
        assert_eq!(z_l, [0.0, 0.0]);
        assert_eq!(lambda, [0.0]);

        assert_eq!(t.eval_f(&[3.0, 4.0], true), Some(25.0));

        let mut g = [0.0; 1];
        assert!(t.eval_g(&[3.0, 4.0], true, &mut g));
        assert_eq!(g, [7.0]);
    }

    /// Every provided method that Mini does not override returns its documented default.
    #[test]
    fn provided_methods_use_defaults() {
        let mut t = Mini;

        let mut var = MetaData::default();
        let mut con = MetaData::default();
        assert!(!t.get_var_con_metadata(&mut var, &mut con));
        assert!(var.strings.is_empty() && con.strings.is_empty());
        t.finalize_metadata(&var, &con);

        let mut obj_scaling = 1.0;
        let mut use_x_scaling = false;
        let mut x_scaling = [1.0; 2];
        let mut use_g_scaling = false;
        let mut g_scaling = [1.0; 1];
        assert!(!t.get_scaling_parameters(ScalingRequest {
            obj_scaling: &mut obj_scaling,
            use_x_scaling: &mut use_x_scaling,
            x_scaling: &mut x_scaling,
            use_g_scaling: &mut use_g_scaling,
            g_scaling: &mut g_scaling,
        }));
        assert!(!use_x_scaling && !use_g_scaling);
        assert_eq!(obj_scaling, 1.0);

        let mut var_types = [Linearity::NonLinear; 2];
        assert!(!t.get_variables_linearity(&mut var_types));
        let mut con_types = [Linearity::NonLinear; 1];
        assert!(!t.get_constraints_linearity(&mut con_types));

        assert_eq!(t.get_number_of_nonlinear_variables(), -1);
        let mut pos = [0; 2];
        assert!(!t.get_list_of_nonlinear_variables(&mut pos));

        let mut irow = [0; 0];
        let mut jcol = [0; 0];
        assert!(!t.eval_h(
            None,
            false,
            1.0,
            None,
            false,
            SparsityRequest::Structure {
                irow: &mut irow,
                jcol: &mut jcol
            }
        ));
    }

    /// Both SparsityRequest variants reach the implementation and are filled correctly.
    #[test]
    fn sparsity_request_round_trip() {
        let mut t = Mini;
        let mut irow = [0; 2];
        let mut jcol = [0; 2];
        assert!(t.eval_jac_g(
            None,
            false,
            SparsityRequest::Structure {
                irow: &mut irow,
                jcol: &mut jcol
            }
        ));
        assert_eq!(irow, [0, 0]);
        assert_eq!(jcol, [0, 1]);

        let mut vals = [0.0; 2];
        assert!(t.eval_jac_g(
            Some(&[1.0, 2.0]),
            true,
            SparsityRequest::Values { values: &mut vals }
        ));
        assert_eq!(vals, [1.0, 1.0]);
    }
}