use crate::alg_types::SolverReturn;
use crate::return_codes::AlgorithmMode;
use pounce_common::types::{Index, Number};
use std::collections::BTreeMap;
/// Whether a variable or constraint enters the problem linearly.
///
/// Reported through [`TNLP::get_variables_linearity`] and
/// [`TNLP::get_constraints_linearity`]. `Hash` is derived alongside the
/// equality traits so values can be used as keys in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Linearity {
    /// The entity appears only linearly.
    Linear,
    /// The entity appears nonlinearly.
    NonLinear,
}
/// Numbering convention (zero- or one-based) for the indices a problem
/// reports, presumably including the `irow`/`jcol` arrays filled for a
/// [`SparsityRequest::Structure`] request.
///
/// The discriminant equals the first valid index: `IndexStyle::C as i32 == 0`
/// and `IndexStyle::Fortran as i32 == 1`. `Hash` is derived so styles can be
/// used as keys in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IndexStyle {
    /// Zero-based indices.
    C = 0,
    /// One-based indices.
    Fortran = 1,
}
/// Problem dimensions and index convention reported by
/// [`TNLP::get_nlp_info`].
///
/// `PartialEq`/`Eq` are derived so whole records can be compared directly
/// instead of field by field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NlpInfo {
    /// Number of variables.
    pub n: Index,
    /// Number of constraints.
    pub m: Index,
    /// Number of nonzero entries in the constraint Jacobian structure.
    pub nnz_jac_g: Index,
    /// Number of nonzero entries in the Hessian-of-the-Lagrangian structure.
    pub nnz_h_lag: Index,
    /// Whether the problem's reported indices are zero- or one-based.
    pub index_style: IndexStyle,
}
/// Named auxiliary data attached to variables or constraints.
///
/// Each map associates a metadata name with a vector of values of a single
/// type (presumably one entry per variable or constraint — confirm against
/// the consumers of [`TNLP::get_var_con_metadata`]). `BTreeMap` keeps
/// iteration sorted by name, so output built from it is deterministic.
/// `PartialEq` is derived so metadata sets can be compared directly.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct MetaData {
    /// String-valued metadata, keyed by name.
    pub strings: BTreeMap<String, Vec<String>>,
    /// `Index`-valued metadata, keyed by name.
    pub integers: BTreeMap<String, Vec<Index>>,
    /// `Number`-valued metadata, keyed by name.
    pub numerics: BTreeMap<String, Vec<Number>>,
}
/// Output buffers filled by [`TNLP::get_bounds_info`].
///
/// The problem writes lower/upper bounds for the variables (`x_l`, `x_u`)
/// and for the constraint functions (`g_l`, `g_u`). Giving a constraint the
/// same lower and upper bound expresses an equality constraint.
/// NOTE(review): the slices are presumably sized `n` and `m` from
/// [`NlpInfo`] — confirm against the caller.
#[derive(Debug)]
pub struct BoundsInfo<'a> {
/// Lower bounds on the variables.
pub x_l: &'a mut [Number],
/// Upper bounds on the variables.
pub x_u: &'a mut [Number],
/// Lower bounds on the constraint functions.
pub g_l: &'a mut [Number],
/// Upper bounds on the constraint functions.
pub g_u: &'a mut [Number],
}
/// Output buffers filled by [`TNLP::get_starting_point`].
///
/// Each `init_*` flag presumably tells the problem whether the solver wants
/// the following slice(s) initialised — confirm against the caller. Field
/// meanings follow the conventional Ipopt naming (`z_*` = bound multipliers,
/// `lambda` = constraint multipliers); nothing in this type enforces that.
#[derive(Debug)]
pub struct StartingPoint<'a> {
/// Whether an initial primal point should be written into `x`.
pub init_x: bool,
/// Initial values of the primal variables.
pub x: &'a mut [Number],
/// Whether initial bound multipliers should be written into `z_l`/`z_u`.
pub init_z: bool,
/// Initial multipliers for the variable lower bounds.
pub z_l: &'a mut [Number],
/// Initial multipliers for the variable upper bounds.
pub z_u: &'a mut [Number],
/// Whether initial constraint multipliers should be written into `lambda`.
pub init_lambda: bool,
/// Initial constraint multipliers.
pub lambda: &'a mut [Number],
}
/// Output slots filled by [`TNLP::get_scaling_parameters`].
///
/// The problem can write an objective scaling factor and, optionally,
/// per-variable and per-constraint scaling factors. It sets the matching
/// `use_*` flag to say whether that vector was supplied. These are presumed
/// semantics; confirm against the solver code that reads this struct.
#[derive(Debug)]
pub struct ScalingRequest<'a> {
/// Scaling factor applied to the objective.
pub obj_scaling: &'a mut Number,
/// Set to `true` if `x_scaling` was filled.
pub use_x_scaling: &'a mut bool,
/// Per-variable scaling factors.
pub x_scaling: &'a mut [Number],
/// Set to `true` if `g_scaling` was filled.
pub use_g_scaling: &'a mut bool,
/// Per-constraint scaling factors.
pub g_scaling: &'a mut [Number],
}
/// Two-phase request used by [`TNLP::eval_jac_g`] and [`TNLP::eval_h`].
///
/// `Structure` asks for the coordinates of the nonzero entries. `Values`
/// asks for the numeric values of those entries, presumably in the same
/// order in which the structure was reported.
#[derive(Debug)]
pub enum SparsityRequest<'a> {
/// Fill `irow[k]`/`jcol[k]` with the row/column of the k-th nonzero,
/// presumably numbered according to [`NlpInfo::index_style`].
Structure {
irow: &'a mut [Index],
jcol: &'a mut [Index],
},
/// Fill `values[k]` with the value of the k-th nonzero.
Values { values: &'a mut [Number] },
}
/// Final result handed to [`TNLP::finalize_solution`].
///
/// All slices are borrowed only for the duration of the call (lifetime
/// `'a`), so the problem must copy anything it wants to keep. Field meanings
/// below follow the conventional Ipopt naming and are assumptions to confirm
/// against the code that builds this struct.
#[derive(Debug)]
pub struct Solution<'a> {
/// Termination status reported by the solver.
pub status: SolverReturn,
/// Final values of the primal variables.
pub x: &'a [Number],
/// Final multipliers for the variable lower bounds.
pub z_l: &'a [Number],
/// Final multipliers for the variable upper bounds.
pub z_u: &'a [Number],
/// Constraint function values at `x`.
pub g: &'a [Number],
/// Final constraint multipliers.
pub lambda: &'a [Number],
/// Objective value at `x`.
pub obj_value: Number,
}
/// Per-iteration progress statistics passed to
/// [`TNLP::intermediate_callback`].
///
/// Field names mirror the arguments of Ipopt's `intermediate_callback`. The
/// per-field descriptions follow that convention and are assumptions to
/// confirm against the code that fills this struct.
#[derive(Debug, Clone, Copy)]
pub struct IterStats {
/// Current algorithm mode/phase.
pub mode: AlgorithmMode,
/// Iteration counter.
pub iter: Index,
/// Objective value at the current iterate.
pub obj_value: Number,
/// Primal infeasibility (constraint violation).
pub inf_pr: Number,
/// Dual infeasibility.
pub inf_du: Number,
/// Barrier parameter.
pub mu: Number,
/// Norm of the primal step.
pub d_norm: Number,
/// Size of the regularization applied in this iteration.
pub regularization_size: Number,
/// Step size taken for the dual variables.
pub alpha_du: Number,
/// Step size taken for the primal variables.
pub alpha_pr: Number,
/// Number of line-search trial points.
pub ls_trials: Index,
}
/// Opaque handle to solver iteration data, passed to callbacks.
///
/// Currently carries no data. The private unit field stops code outside this
/// module from building it with a struct literal; use `Default` instead.
#[derive(Debug, Default)]
pub struct IpoptData {
_private: (),
}
/// Opaque handle to solver-computed quantities, passed to callbacks.
///
/// The name presumably mirrors Ipopt's `IpoptCalculatedQuantities`. The
/// handle currently carries no data. The private unit field stops code
/// outside this module from building it with a struct literal; use `Default`
/// instead.
#[derive(Debug, Default)]
pub struct IpoptCq {
_private: (),
}
/// A user-implemented nonlinear program, queried by the solver through
/// callbacks.
///
/// The required methods describe the problem (dimensions, bounds, starting
/// point) and evaluate the objective, its gradient, the constraints and the
/// constraint Jacobian. The remaining methods have defaults that decline the
/// corresponding optional feature. The trait is object-safe and can be used
/// as `Box<dyn TNLP>`.
///
/// Most callbacks return `bool` (or `Option`). Presumably `false`/`None`
/// reports failure, or "not provided", to the solver; confirm against the
/// driver that calls these methods.
pub trait TNLP {
/// Reports the problem dimensions and index convention.
fn get_nlp_info(&mut self) -> Option<NlpInfo>;
/// Writes variable and constraint bounds into `b`.
fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool;
/// Writes the requested parts of the initial point into `sp`.
fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool;
/// Evaluates the objective at `x`.
///
/// `new_x` presumably signals that `x` differs from the previous
/// evaluation call, so cached work can be reused when it is `false`.
fn eval_f(&mut self, x: &[Number], new_x: bool) -> Option<Number>;
/// Writes the gradient of the objective at `x` into `grad_f`.
fn eval_grad_f(&mut self, x: &[Number], new_x: bool, grad_f: &mut [Number]) -> bool;
/// Writes the constraint function values at `x` into `g`.
fn eval_g(&mut self, x: &[Number], new_x: bool, g: &mut [Number]) -> bool;
/// Fills the constraint Jacobian's sparsity structure or values, as
/// selected by `mode`.
///
/// `x` may be `None`, e.g. when only the structure is requested.
fn eval_jac_g(&mut self, x: Option<&[Number]>, new_x: bool, mode: SparsityRequest<'_>) -> bool;
/// Fills the sparsity structure or values of the Hessian of the
/// Lagrangian. The objective term is weighted by `obj_factor` and the
/// constraint terms by `lambda`.
///
/// The default provides no Hessian and returns `false` without touching
/// `mode`. NOTE(review): presumably the solver then falls back to a Hessian
/// approximation — confirm.
fn eval_h(
&mut self,
_x: Option<&[Number]>,
_new_x: bool,
_obj_factor: Number,
_lambda: Option<&[Number]>,
_new_lambda: bool,
_mode: SparsityRequest<'_>,
) -> bool {
false
}
/// Receives the final solution together with opaque solver handles.
fn finalize_solution(&mut self, sol: Solution<'_>, ip_data: &IpoptData, ip_cq: &IpoptCq);
/// Optionally fills metadata for variables (`_var`) and constraints
/// (`_con`). The default leaves both untouched and returns `false`.
fn get_var_con_metadata(&mut self, _var: &mut MetaData, _con: &mut MetaData) -> bool {
false
}
/// Optionally supplies scaling factors. The default leaves the request
/// untouched and returns `false`.
fn get_scaling_parameters(&mut self, _req: ScalingRequest<'_>) -> bool {
false
}
/// Optionally marks each variable as linear or nonlinear. The default
/// leaves `_types` untouched and returns `false`.
fn get_variables_linearity(&mut self, _types: &mut [Linearity]) -> bool {
false
}
/// Optionally marks each constraint as linear or nonlinear. The default
/// leaves `_types` untouched and returns `false`.
fn get_constraints_linearity(&mut self, _types: &mut [Linearity]) -> bool {
false
}
/// Number of variables that appear nonlinearly.
///
/// The default returns `-1`, which presumably means "not provided".
fn get_number_of_nonlinear_variables(&mut self) -> Index {
-1
}
/// Optionally writes the indices of the nonlinear variables into
/// `_pos_nonlin_vars`. The default leaves it untouched and returns `false`.
fn get_list_of_nonlinear_variables(&mut self, _pos_nonlin_vars: &mut [Index]) -> bool {
false
}
/// Called with per-iteration statistics.
///
/// The default returns `true`. NOTE(review): presumably returning `false`
/// asks the solver to stop early — confirm against the driver.
fn intermediate_callback(
&mut self,
_stats: IterStats,
_ip_data: &IpoptData,
_ip_cq: &IpoptCq,
) -> bool {
true
}
/// Hook that receives variable and constraint metadata, presumably at the
/// end of a solve. The default does nothing.
fn finalize_metadata(&mut self, _var: &MetaData, _con: &MetaData) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal problem: minimize `x0^2 + x1^2` subject to `x0 + x1 = 1`,
    /// with zero-based (C-style) sparsity indices and no Hessian callback.
    struct Mini;

    impl TNLP for Mini {
        fn get_nlp_info(&mut self) -> Option<NlpInfo> {
            Some(NlpInfo {
                n: 2,
                m: 1,
                nnz_jac_g: 2,
                nnz_h_lag: 2,
                index_style: IndexStyle::C,
            })
        }

        fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool {
            b.x_l.iter_mut().for_each(|v| *v = -1e19);
            b.x_u.iter_mut().for_each(|v| *v = 1e19);
            b.g_l[0] = 1.0;
            b.g_u[0] = 1.0;
            true
        }

        fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool {
            assert!(sp.init_x);
            sp.x[0] = 0.5;
            sp.x[1] = 0.5;
            true
        }

        fn eval_f(&mut self, x: &[Number], _new_x: bool) -> Option<Number> {
            Some(x[0] * x[0] + x[1] * x[1])
        }

        fn eval_grad_f(&mut self, x: &[Number], _new_x: bool, grad_f: &mut [Number]) -> bool {
            grad_f[0] = 2.0 * x[0];
            grad_f[1] = 2.0 * x[1];
            true
        }

        fn eval_g(&mut self, x: &[Number], _new_x: bool, g: &mut [Number]) -> bool {
            g[0] = x[0] + x[1];
            true
        }

        fn eval_jac_g(
            &mut self,
            _x: Option<&[Number]>,
            _new_x: bool,
            mode: SparsityRequest<'_>,
        ) -> bool {
            match mode {
                SparsityRequest::Structure { irow, jcol } => {
                    irow.copy_from_slice(&[0, 0]);
                    jcol.copy_from_slice(&[0, 1]);
                }
                SparsityRequest::Values { values } => {
                    values.copy_from_slice(&[1.0, 1.0]);
                }
            }
            true
        }

        fn finalize_solution(&mut self, _sol: Solution<'_>, _d: &IpoptData, _q: &IpoptCq) {}
    }

    #[test]
    fn tnlp_is_object_safe() {
        let mut t: Box<dyn TNLP> = Box::new(Mini);
        let info = t.get_nlp_info().expect("get_nlp_info");
        assert_eq!(info.n, 2);
        assert_eq!(info.m, 1);
        assert_eq!(info.index_style, IndexStyle::C);
        let mut x_l = [0.0; 2];
        let mut x_u = [0.0; 2];
        let mut g_l = [0.0; 1];
        let mut g_u = [0.0; 1];
        assert!(t.get_bounds_info(BoundsInfo {
            x_l: &mut x_l,
            x_u: &mut x_u,
            g_l: &mut g_l,
            g_u: &mut g_u
        }));
        assert_eq!(g_l[0], 1.0);
        let mut grad = [0.0; 2];
        assert!(t.eval_grad_f(&[3.0, 4.0], true, &mut grad));
        assert_eq!(grad, [6.0, 8.0]);
        let mut tmp_v = [0.0; 0];
        assert!(!t.eval_h(
            None,
            false,
            1.0,
            None,
            false,
            SparsityRequest::Values { values: &mut tmp_v }
        ));
        assert_eq!(t.get_number_of_nonlinear_variables(), -1);
    }

    #[test]
    fn sparsity_request_round_trip() {
        let mut t = Mini;
        let mut irow = [0; 2];
        let mut jcol = [0; 2];
        assert!(t.eval_jac_g(
            None,
            false,
            SparsityRequest::Structure {
                irow: &mut irow,
                jcol: &mut jcol
            }
        ));
        assert_eq!(irow, [0, 0]);
        assert_eq!(jcol, [0, 1]);
        let mut vals = [0.0; 2];
        assert!(t.eval_jac_g(
            Some(&[1.0, 2.0]),
            true,
            SparsityRequest::Values { values: &mut vals }
        ));
        assert_eq!(vals, [1.0, 1.0]);
    }

    #[test]
    fn starting_point_fills_x() {
        let mut t = Mini;
        let mut x = [0.0; 2];
        let mut z_l = [0.0; 2];
        let mut z_u = [0.0; 2];
        let mut lambda = [0.0; 1];
        assert!(t.get_starting_point(StartingPoint {
            init_x: true,
            x: &mut x,
            init_z: false,
            z_l: &mut z_l,
            z_u: &mut z_u,
            init_lambda: false,
            lambda: &mut lambda,
        }));
        assert_eq!(x, [0.5, 0.5]);
        // Mini only initialises x; the other buffers must be left alone.
        assert_eq!(z_l, [0.0, 0.0]);
        assert_eq!(lambda, [0.0]);
    }

    #[test]
    fn objective_and_constraint_values() {
        let mut t = Mini;
        assert_eq!(t.eval_f(&[3.0, 4.0], true), Some(25.0));
        let mut g = [0.0; 1];
        assert!(t.eval_g(&[3.0, 4.0], false, &mut g));
        assert_eq!(g, [7.0]);
    }

    #[test]
    fn optional_callbacks_default_to_declining() {
        let mut t = Mini;

        let mut var_md = MetaData::default();
        let mut con_md = MetaData::default();
        assert!(!t.get_var_con_metadata(&mut var_md, &mut con_md));
        assert!(var_md.strings.is_empty() && var_md.integers.is_empty());
        assert!(con_md.numerics.is_empty());
        // The default hook is a no-op; calling it must not panic.
        t.finalize_metadata(&var_md, &con_md);

        let mut obj_scaling = 1.0;
        let mut use_x = false;
        let mut use_g = false;
        let mut x_scaling = [1.0; 2];
        let mut g_scaling = [1.0; 1];
        assert!(!t.get_scaling_parameters(ScalingRequest {
            obj_scaling: &mut obj_scaling,
            use_x_scaling: &mut use_x,
            x_scaling: &mut x_scaling,
            use_g_scaling: &mut use_g,
            g_scaling: &mut g_scaling,
        }));
        // The default must not touch the request's slots.
        assert_eq!(obj_scaling, 1.0);
        assert!(!use_x && !use_g);

        let mut var_lin = [Linearity::NonLinear; 2];
        let mut con_lin = [Linearity::Linear; 1];
        assert!(!t.get_variables_linearity(&mut var_lin));
        assert!(!t.get_constraints_linearity(&mut con_lin));
        assert_eq!(var_lin, [Linearity::NonLinear; 2]);
        assert_eq!(con_lin, [Linearity::Linear]);

        let mut nonlin: [Index; 2] = [0; 2];
        assert!(!t.get_list_of_nonlinear_variables(&mut nonlin));
        assert_eq!(nonlin, [0, 0]);
    }
}