use crate::local_solver::builders::{COBYLABuilder, LocalSolverConfig};
use crate::problem::Problem;
use ndarray::Array1;
use std::fmt;
use std::ops::Index;
use thiserror::Error;
#[cfg(feature = "checkpointing")]
use std::path::PathBuf;
/// Configuration parameters for an OQNLP run.
///
/// A copy of these parameters is stored in `OQNLPCheckpoint::params` when the
/// `checkpointing` feature is enabled. Default values come from
/// `impl Default for OQNLPParams`: 300 iterations, population size 1000, wait cycle 15,
/// threshold factor 0.2, distance factor 0.75, COBYLA (configured with
/// `COBYLABuilder::default().build()`), and seed 0.
///
/// NOTE(review): with `checkpointing` enabled the serde derives follow field declaration
/// order, which matters for non-self-describing formats — avoid reordering fields if
/// previously written checkpoints must stay loadable.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub struct OQNLPParams {
/// Number of iterations to run (default 300).
pub iterations: usize,
/// Population size (default 1000). NOTE(review): presumably the number of candidate
/// points generated for the reference set — confirm against the solver.
pub population_size: usize,
/// Wait cycle (default 15); the same field also appears in `FilterParams`.
/// NOTE(review): presumably the number of non-improving cycles tolerated — confirm.
pub wait_cycle: usize,
/// Threshold factor (default 0.2); the same field also appears in `FilterParams`.
pub threshold_factor: f64,
/// Distance factor (default 0.75); the same field also appears in `FilterParams`.
pub distance_factor: f64,
/// Which local solver algorithm to use (default `LocalSolverType::COBYLA`).
pub local_solver_type: LocalSolverType,
/// Configuration handed to the local solver (default: `COBYLABuilder::default().build()`).
/// NOTE(review): should agree with `local_solver_type`; this type does not enforce that.
pub local_solver_config: LocalSolverConfig,
/// Seed value (default 0). NOTE(review): presumably seeds the random number generator.
pub seed: u64,
}
impl Default for OQNLPParams {
    /// Baseline configuration: 300 iterations, a population of 1000, wait cycle 15,
    /// threshold factor 0.2, distance factor 0.75, seed 0, and the COBYLA local solver
    /// configured with its builder defaults.
    fn default() -> Self {
        let local_solver_config = COBYLABuilder::default().build();

        Self {
            seed: 0,
            iterations: 300,
            population_size: 1000,
            wait_cycle: 15,
            threshold_factor: 0.2,
            distance_factor: 0.75,
            local_solver_type: LocalSolverType::COBYLA,
            local_solver_config,
        }
    }
}
/// Filter-related parameters.
///
/// These three fields share their names (and presumably their meaning — confirm) with the
/// corresponding fields of `OQNLPParams`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub struct FilterParams {
/// Distance factor (mirrors `OQNLPParams::distance_factor`).
pub distance_factor: f64,
/// Wait cycle (mirrors `OQNLPParams::wait_cycle`).
pub wait_cycle: usize,
/// Threshold factor (mirrors `OQNLPParams::threshold_factor`).
pub threshold_factor: f64,
}
/// A point paired with its objective value.
///
/// NOTE(review): presumably the result of a local solver run. Lower objective values are
/// treated as better by `SolutionSet::best_solution`, which selects the minimum.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub struct LocalSolution {
/// Coordinates of the solution (also returned, cloned, by `x()`).
pub point: Array1<f64>,
/// Objective value at `point` (also returned by `fun()`).
pub objective: f64,
}
impl LocalSolution {
pub fn fun(&self) -> f64 {
self.objective
}
pub fn x(&self) -> Array1<f64> {
self.point.clone()
}
}
/// An ordered collection of `LocalSolution`s.
///
/// Supports `set[i]` indexing, `len`/`is_empty`, `best_solution` (lowest objective) and
/// iteration via `solutions()`; `Display` prints a summary header followed by each solution.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub struct SolutionSet {
/// The stored solutions, in storage order.
pub solutions: Array1<LocalSolution>,
}
impl SolutionSet {
pub fn len(&self) -> usize {
self.solutions.len()
}
pub fn is_empty(&self) -> bool {
self.solutions.is_empty()
}
pub fn best_solution(&self) -> Option<&LocalSolution> {
self.solutions.iter().min_by(|a, b| a.objective.total_cmp(&b.objective))
}
pub fn solutions(&self) -> impl Iterator<Item = &LocalSolution> {
self.solutions.iter()
}
pub fn display_with_constraints<P: Problem>(
&self,
problem: &P,
constraint_descriptions: Option<&[&str]>,
) -> String {
let mut result = String::new();
let constraints = problem.constraints();
result.push_str("━━━━━━━━━━━ Solution Set ━━━━━━━━━━━\n");
result.push_str(&format!("Total solutions: {}\n", self.solutions.len()));
if !self.solutions.is_empty() {
if let Some(best) = self.best_solution() {
result.push_str(&format!("Best objective value: {:.8e}\n", best.objective));
}
}
result.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
for (i, solution) in self.solutions.iter().enumerate() {
result.push_str(&format!("Solution #{}\n", i + 1));
result.push_str(&format!(" Objective: {:.8e}\n", solution.objective));
result.push_str(" Parameters:\n");
result.push_str(&format!(" {:.8e}\n", solution.point));
if !constraints.is_empty() {
result.push_str(" Constraint violations:\n");
for (j, constraint_fn) in constraints.iter().enumerate() {
let x_slice: Vec<f64> = solution.point.to_vec();
let constraint_value = constraint_fn(&x_slice, &mut ());
let status = if constraint_value >= 0.0 { "✓" } else { "✗" };
let violation = if constraint_value < 0.0 {
format!(" (violated by {:.6})", -constraint_value)
} else {
" (satisfied)".to_string()
};
let description = if let Some(descriptions) = constraint_descriptions {
if j < descriptions.len() {
format!(" [{}]", descriptions[j])
} else {
String::new()
}
} else {
String::new()
};
result.push_str(&format!(
" Constraint {}{}: {} {:.6e}{}\n",
j + 1,
description,
status,
constraint_value,
violation
));
}
}
if i < self.solutions.len() - 1 {
result.push_str("――――――――――――――――――――――――――――――――――――\n");
}
}
result
}
pub fn display_with_problem<P: Problem>(&self, problem: &P) -> String {
let constraints = problem.constraints();
if constraints.is_empty() {
format!("{}", self)
} else {
self.display_with_constraints(problem, None)
}
}
}
impl Index<usize> for SolutionSet {
type Output = LocalSolution;
fn index(&self, index: usize) -> &Self::Output {
&self.solutions[index]
}
}
impl fmt::Display for SolutionSet {
    /// Prints a header with the solution count and, for a non-empty set, the best objective,
    /// followed by one section per solution separated by a thin rule.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f, "━━━━━━━━━━━ Solution Set ━━━━━━━━━━━")?;
        writeln!(f, "Total solutions: {}", self.solutions.len())?;
        // An empty set has no best solution, so this line is skipped in that case.
        if let Some(best) = self.best_solution() {
            writeln!(f, "Best objective value: {:.8e}", best.objective)?;
        }
        writeln!(f, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")?;

        for (position, entry) in self.solutions.iter().enumerate() {
            // Emitting the rule before every entry except the first yields the same output
            // as emitting it after every entry except the last.
            if position > 0 {
                writeln!(f, "――――――――――――――――――――――――――――――――――――")?;
            }
            writeln!(f, "Solution #{}", position + 1)?;
            writeln!(f, " Objective: {:.8e}", entry.objective)?;
            writeln!(f, " Parameters:")?;
            writeln!(f, " {:.8e}", entry.point)?;
        }
        Ok(())
    }
}
/// Local optimization algorithm used by the solver.
///
/// `COBYLA` is always available; every other variant is compiled in only when the `argmin`
/// feature is enabled.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub enum LocalSolverType {
    #[cfg(feature = "argmin")]
    LBFGS,
    #[cfg(feature = "argmin")]
    NelderMead,
    #[cfg(feature = "argmin")]
    SteepestDescent,
    #[cfg(feature = "argmin")]
    TrustRegion,
    #[cfg(feature = "argmin")]
    NewtonCG,
    COBYLA,
}

impl LocalSolverType {
    /// Parses a solver name case-insensitively.
    ///
    /// Surrounding whitespace and any `-` / `_` separators are ignored. As a result every
    /// solver accepts its spaced-out spellings uniformly: `"Nelder-Mead"`, `"nelder_mead"`
    /// and `"neldermead"` are equivalent, and so are `"steepest-descent"` and
    /// `"SteepestDescent"`. Previously only some solvers accepted a hyphenated form.
    ///
    /// # Errors
    ///
    /// Returns `Err("Invalid solver type.")` for unrecognized names, including names of
    /// solvers whose cargo feature is not enabled.
    pub fn from_string(s: &str) -> Result<Self, &'static str> {
        let key: String = s
            .trim()
            .chars()
            .filter(|c| !matches!(c, '-' | '_'))
            .flat_map(char::to_lowercase)
            .collect();

        match key.as_str() {
            #[cfg(feature = "argmin")]
            "lbfgs" => Ok(Self::LBFGS),
            #[cfg(feature = "argmin")]
            "neldermead" => Ok(Self::NelderMead),
            #[cfg(feature = "argmin")]
            "steepestdescent" => Ok(Self::SteepestDescent),
            #[cfg(feature = "argmin")]
            "trustregion" => Ok(Self::TrustRegion),
            #[cfg(feature = "argmin")]
            "newtoncg" => Ok(Self::NewtonCG),
            "cobyla" => Ok(Self::COBYLA),
            _ => Err("Invalid solver type."),
        }
    }
}

/// Enables `"cobyla".parse::<LocalSolverType>()`; delegates to
/// [`LocalSolverType::from_string`].
impl std::str::FromStr for LocalSolverType {
    type Err = &'static str;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::from_string(s)
    }
}
/// Errors reported while evaluating a problem's objective, gradient, Hessian or constraints.
///
/// Each variant's `Display` text is defined by its `#[error(...)]` attribute (via `thiserror`).
#[derive(Debug, Error)]
pub enum EvaluationError {
/// The input was rejected; `reason` explains why.
#[error("Invalid input: {reason}")]
InvalidInput { reason: String },
/// A division by zero was encountered during evaluation.
#[error("Division by zero encountered")]
DivisionByZero,
/// A square root of a negative number was requested; `value` is the offending argument.
#[error("Negative square root: attempted sqrt({value})")]
NegativeSqrt { value: f64 },
/// No objective function is available although the local solver needs one.
#[error("Objective function not implemented and needed for local solver")]
ObjectiveFunctionNotImplemented,
/// No gradient is available although the local solver needs one.
#[error("Gradient not implemented and needed for local solver")]
GradientNotImplemented,
/// No Hessian is available although the local solver needs one.
#[error("Hessian not implemented and needed for local solver")]
HessianNotImplemented,
/// Evaluating the objective function failed; `reason` describes the failure.
#[error("Objective function evaluation failed: {reason}")]
ObjectiveFunctionEvaluationFailed { reason: String },
/// Evaluating the gradient failed; `reason` describes the failure.
#[error("Gradient evaluation failed: {reason}")]
GradientEvaluationFailed { reason: String },
/// Evaluating the Hessian failed; `reason` describes the failure.
#[error("Hessian evaluation failed: {reason}")]
HessianEvaluationFailed { reason: String },
/// Constraints are required but not provided (per the message, only COBYLA supports them).
#[error(
"Constraints not implemented and needed for constrained solver (only COBYLA supports constraints)"
)]
ConstraintNotImplemented,
/// A constraint index outside `0..max_index` was requested.
/// NOTE(review): the message prints an exclusive range, so `max_index` is presumably the
/// number of constraints — confirm at the construction sites.
#[error("Invalid constraint index {index}, valid range is 0..{max_index}")]
InvalidConstraintIndex { index: usize, max_index: usize },
/// Evaluating constraint `index` failed; `reason` describes the failure.
#[error("Constraint {index} evaluation failed: {reason}")]
ConstraintEvaluationFailed { index: usize, reason: String },
}
/// Settings controlling where and how checkpoints are written (`checkpointing` feature only).
///
/// Defaults (from `impl Default for CheckpointConfig`): directory `./checkpoints`, name
/// `oqnlp_checkpoint`, `save_frequency` 25, `keep_all` false, `auto_resume` true.
#[cfg(feature = "checkpointing")]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub struct CheckpointConfig {
/// Directory that holds checkpoint files (default `./checkpoints`).
pub checkpoint_dir: PathBuf,
/// Base name for checkpoint files (default `oqnlp_checkpoint`).
pub checkpoint_name: String,
/// How often a checkpoint is saved (default 25). NOTE(review): unit presumably
/// iterations — confirm at the call site.
pub save_frequency: usize,
/// Whether every checkpoint is kept (default `false`). NOTE(review): presumably `false`
/// means only the most recent one is retained — confirm.
pub keep_all: bool,
/// Whether a run resumes from an existing checkpoint automatically (default `true`).
pub auto_resume: bool,
}
#[cfg(feature = "checkpointing")]
impl Default for CheckpointConfig {
    /// Checkpoints live in `./checkpoints` under the name `oqnlp_checkpoint`, with
    /// `save_frequency` 25, `keep_all` disabled and `auto_resume` enabled.
    fn default() -> Self {
        let checkpoint_dir = PathBuf::from("./checkpoints");
        let checkpoint_name = String::from("oqnlp_checkpoint");

        Self {
            checkpoint_dir,
            checkpoint_name,
            save_frequency: 25,
            keep_all: false,
            auto_resume: true,
        }
    }
}
/// Snapshot of an OQNLP run's state (`checkpointing` feature only).
///
/// NOTE(review): `batch_iterations` and `enable_parallel` exist only when the `rayon`
/// feature is enabled, so the serialized layout depends on the enabled features. A checkpoint
/// written by one build configuration may fail to load in another (e.g. a missing
/// `enable_parallel` field) — confirm whether cross-feature resumption must be supported.
#[cfg(feature = "checkpointing")]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "checkpointing", derive(serde::Serialize, serde::Deserialize))]
pub struct OQNLPCheckpoint {
/// Parameters of the run being checkpointed.
pub params: OQNLPParams,
/// Iteration counter at the time of the snapshot.
pub current_iteration: usize,
/// Current merit threshold (may be `f64::INFINITY`, as in the tests).
pub merit_threshold: f64,
/// Solutions found so far, or `None` if no solution set has been recorded.
pub solution_set: Option<SolutionSet>,
/// Reference set points; serialized as a plain list of `Vec<f64>` through the helper
/// functions named below.
#[cfg_attr(
feature = "checkpointing",
serde(
serialize_with = "serialize_vec_array1",
deserialize_with = "deserialize_vec_array1"
)
)]
pub reference_set: Vec<Array1<f64>>,
/// Counter of unchanged cycles at the time of the snapshot.
pub unchanged_cycles: usize,
/// Elapsed run time in seconds (the `Display` impl prints it with an `s` suffix).
pub elapsed_time: f64,
/// Solutions held by the distance filter.
pub distance_filter_solutions: Vec<LocalSolution>,
/// Seed value at the time of the snapshot.
pub current_seed: u64,
/// Optional objective target; `None` when no target is set.
pub target_objective: Option<f64>,
/// Flag recorded with the run state; presumably excludes out-of-bounds points — confirm.
pub exclude_out_of_bounds: bool,
/// Batch iteration setting (only present with the `rayon` feature).
#[cfg(feature = "rayon")]
pub batch_iterations: Option<usize>,
/// Parallel-execution flag (only present with the `rayon` feature).
#[cfg(feature = "rayon")]
pub enable_parallel: bool,
/// Absolute tolerance recorded with the run state.
pub abs_tol: f64,
/// Relative tolerance recorded with the run state.
pub rel_tol: f64,
/// Free-form timestamp string (the tests use ISO-8601 values such as `2025-07-27T12:00:00Z`).
pub timestamp: String,
}
/// Serde `serialize_with` helper for `OQNLPCheckpoint::reference_set`.
///
/// Emits a sequence with a known length in which every `Array1<f64>` is written as an owned
/// `Vec<f64>`, so the serialized form is a plain list of lists.
#[cfg(feature = "checkpointing")]
fn serialize_vec_array1<S>(vec: &Vec<Array1<f64>>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::ser::SerializeSeq;

    let mut seq = serializer.serialize_seq(Some(vec.len()))?;
    vec.iter()
        .map(|array| array.to_vec())
        .try_for_each(|row| seq.serialize_element(&row))?;
    seq.end()
}
/// Serde `deserialize_with` counterpart of `serialize_vec_array1`.
///
/// Reads a list of `Vec<f64>` and converts every entry into an `Array1<f64>`.
#[cfg(feature = "checkpointing")]
fn deserialize_vec_array1<'de, D>(deserializer: D) -> Result<Vec<Array1<f64>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;

    <Vec<Vec<f64>> as Deserialize<'de>>::deserialize(deserializer)
        .map(|rows| rows.into_iter().map(Array1::from).collect())
}
#[cfg(feature = "checkpointing")]
impl fmt::Display for OQNLPCheckpoint {
    /// Prints the run state (timestamp, progress counters, solution summary) followed by the
    /// run parameters and tolerances, framed by heavy rules.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let params = &self.params;

        writeln!(f, "━━━━━━━━━━━ OQNLP Checkpoint ━━━━━━━━━━━")?;
        writeln!(f, "Timestamp: {}", self.timestamp)?;
        writeln!(f, "Current iteration: {}", self.current_iteration)?;
        writeln!(f, "Elapsed time: {:.2}s", self.elapsed_time)?;
        writeln!(f, "Unchanged cycles: {}", self.unchanged_cycles)?;
        writeln!(f, "Merit threshold: {:.8e}", self.merit_threshold)?;
        writeln!(f, "Reference set size: {}", self.reference_set.len())?;
        writeln!(f, "Distance filter solutions: {}", self.distance_filter_solutions.len())?;
        writeln!(f, "Current seed: {}", self.current_seed)?;

        match &self.solution_set {
            Some(set) => {
                writeln!(f, "Solution set: {} solutions", set.len())?;
                if let Some(best) = set.best_solution() {
                    writeln!(f, "Best objective: {:.8e}", best.objective)?;
                }
            }
            None => writeln!(f, "Solution set: None")?,
        }

        writeln!(f, "Parameters:")?;
        writeln!(f, " Population size: {}", params.population_size)?;
        writeln!(f, " Iterations: {}", params.iterations)?;
        writeln!(f, " Wait cycle: {}", params.wait_cycle)?;
        writeln!(f, " Threshold factor: {}", params.threshold_factor)?;
        writeln!(f, " Distance factor: {}", params.distance_factor)?;
        writeln!(f, " Local solver: {:?}", params.local_solver_type)?;
        writeln!(f, " Seed: {}", params.seed)?;

        match self.target_objective {
            Some(target) => writeln!(f, " Target objective: {:.8e}", target)?,
            None => writeln!(f, " Target objective: None")?,
        }

        writeln!(f, " Exclude out of bounds: {}", self.exclude_out_of_bounds)?;
        writeln!(f, " Absolute tolerance: {:.2e}", self.abs_tol)?;
        writeln!(f, " Relative tolerance: {:.2e}", self.rel_tol)?;
        writeln!(f, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
    }
}
#[cfg(test)]
mod tests_types {
    use super::*;
    use crate::problem::Problem;
    use ndarray::array;

    /// Problem with the single constraint `g(x) = 1 - x[0] - x[1]`, i.e. `x[0] + x[1] <= 1`
    /// in the `g(x) >= 0` convention reported by `SolutionSet::display_with_constraints`.
    #[derive(Debug, Clone)]
    struct TestProblemWithConstraints;

    impl Problem for TestProblemWithConstraints {
        fn objective(&self, x: &Array1<f64>) -> Result<f64, EvaluationError> {
            Ok(x[0].powi(2) + x[1].powi(2))
        }

        fn variable_bounds(&self) -> ndarray::Array2<f64> {
            ndarray::array![[-2.0, 2.0], [-2.0, 2.0]]
        }

        fn constraints(&self) -> Vec<fn(&[f64], &mut ()) -> f64> {
            vec![|x: &[f64], _: &mut ()| 1.0 - x[0] - x[1]]
        }
    }

    /// Same objective and bounds as above, without overriding `constraints`.
    #[derive(Debug, Clone)]
    struct TestProblemNoConstraints;

    impl Problem for TestProblemNoConstraints {
        fn objective(&self, x: &Array1<f64>) -> Result<f64, EvaluationError> {
            Ok(x[0].powi(2) + x[1].powi(2))
        }

        fn variable_bounds(&self) -> ndarray::Array2<f64> {
            ndarray::array![[-2.0, 2.0], [-2.0, 2.0]]
        }
    }

    #[test]
    fn test_oqnlp_params_default() {
        let params = OQNLPParams::default();
        assert_eq!(params.iterations, 300);
        assert_eq!(params.population_size, 1000);
        assert_eq!(params.wait_cycle, 15);
        assert_eq!(params.threshold_factor, 0.2);
        assert_eq!(params.distance_factor, 0.75);
        assert_eq!(params.local_solver_type, LocalSolverType::COBYLA);
        assert_eq!(params.seed, 0);
    }

    #[test]
    fn test_solution_set_len() {
        let solutions = Array1::from_vec(vec![
            LocalSolution { point: array![1.0, 2.0], objective: -1.0 },
            LocalSolution { point: array![3.0, 4.0], objective: -2.0 },
        ]);
        let solution_set: SolutionSet = SolutionSet { solutions };
        assert_eq!(solution_set.len(), 2);
    }

    #[test]
    fn test_solution_set_is_empty() {
        let solutions: Array1<LocalSolution> = Array1::from_vec(vec![]);
        let solution_set: SolutionSet = SolutionSet { solutions };
        assert!(solution_set.is_empty());

        let solutions: Array1<LocalSolution> =
            Array1::from_vec(vec![LocalSolution { point: array![1.0], objective: -1.0 }]);
        let solution_set: SolutionSet = SolutionSet { solutions };
        assert!(!solution_set.is_empty());
    }

    #[test]
    fn test_solution_set_index() {
        let solutions: Array1<LocalSolution> = Array1::from_vec(vec![
            LocalSolution { point: array![1.0, 2.0], objective: -1.0 },
            LocalSolution { point: array![3.0, 4.0], objective: -2.0 },
        ]);
        let solution_set: SolutionSet = SolutionSet { solutions };
        assert_eq!(solution_set[0].objective, -1.0);
        assert_eq!(solution_set[1].objective, -2.0);
    }

    #[test]
    fn test_solution_set_iterates_in_storage_order() {
        let solutions: Array1<LocalSolution> = Array1::from_vec(vec![
            LocalSolution { point: array![1.0], objective: -1.0 },
            LocalSolution { point: array![2.0], objective: -2.0 },
        ]);
        let solution_set = SolutionSet { solutions };
        let objectives: Vec<f64> = solution_set.solutions().map(|s| s.objective).collect();
        assert_eq!(objectives, vec![-1.0, -2.0]);
    }

    #[test]
    fn test_solution_set_display() {
        let solutions: Array1<LocalSolution> =
            Array1::from_vec(vec![LocalSolution { point: array![1.0], objective: -1.0 }]);
        let solution_set: SolutionSet = SolutionSet { solutions };
        let display_output: String = format!("{}", solution_set);
        assert!(display_output.contains("Solution Set"));
        assert!(display_output.contains("Total solutions: 1"));
        assert!(display_output.contains("Best objective value"));
        assert!(display_output.contains("Solution #1"));
    }

    #[test]
    fn test_empty_solution_set_display() {
        let solutions: Array1<LocalSolution> = Array1::from_vec(vec![]);
        let solution_set: SolutionSet = SolutionSet { solutions };
        let display_output: String = format!("{}", solution_set);
        assert!(display_output.contains("Solution Set"));
        assert!(display_output.contains("Total solutions: 0"));
        // An empty set has no best solution to report.
        assert!(!display_output.contains("Best objective value"));
    }

    #[test]
    #[should_panic]
    fn test_solution_set_index_out_of_bounds() {
        let solutions: Array1<LocalSolution> = Array1::from_vec(vec![]);
        let solution_set: SolutionSet = SolutionSet { solutions };
        let _should_panic: LocalSolution = solution_set[0].clone();
    }

    #[cfg(feature = "argmin")]
    #[test]
    fn test_local_solver_type_from_string() {
        assert_eq!(LocalSolverType::from_string("LBFGS"), Ok(LocalSolverType::LBFGS));
        assert_eq!(LocalSolverType::from_string("Nelder-Mead"), Ok(LocalSolverType::NelderMead));
        assert_eq!(
            LocalSolverType::from_string("SteepestDescent"),
            Ok(LocalSolverType::SteepestDescent)
        );
        assert_eq!(LocalSolverType::from_string("TrustRegion"), Ok(LocalSolverType::TrustRegion));
        assert_eq!(LocalSolverType::from_string("NewtonCG"), Ok(LocalSolverType::NewtonCG));
        assert_eq!(LocalSolverType::from_string("Invalid"), Err("Invalid solver type."));
    }

    #[test]
    fn test_local_solver_type_from_string_cobyla() {
        assert_eq!(LocalSolverType::from_string("COBYLA"), Ok(LocalSolverType::COBYLA));
        assert_eq!(LocalSolverType::from_string("cobyla"), Ok(LocalSolverType::COBYLA));
        assert_eq!(LocalSolverType::from_string("Invalid"), Err("Invalid solver type."));
    }

    #[test]
    fn test_local_solution_f_x() {
        let local_solution = LocalSolution { point: array![1.0], objective: -1.0 };
        assert_eq!(local_solution.fun(), -1.0);
        assert_eq!(local_solution.x(), array![1.0]);
    }

    #[test]
    fn test_solution_set_best_solution() {
        // Distinct objectives so the test actually verifies that the minimum is selected.
        let solutions: Array1<LocalSolution> = Array1::from_vec(vec![
            LocalSolution { point: array![1.0], objective: -1.0 },
            LocalSolution { point: array![2.0], objective: -3.0 },
            LocalSolution { point: array![3.0], objective: -2.0 },
        ]);
        let solution_set: SolutionSet = SolutionSet { solutions };
        let best_solution = solution_set.best_solution().expect("No best solution found");
        assert_eq!(best_solution.objective, -3.0);
        assert_eq!(best_solution.point, array![2.0]);
    }

    #[test]
    fn test_solution_set_best_solution_tie_returns_first() {
        let solutions: Array1<LocalSolution> = Array1::from_vec(vec![
            LocalSolution { point: array![1.0], objective: -1.0 },
            LocalSolution { point: array![2.0], objective: -1.0 },
            LocalSolution { point: array![3.0], objective: -1.0 },
        ]);
        let solution_set: SolutionSet = SolutionSet { solutions };
        let best_solution = solution_set.best_solution().expect("No best solution found");
        assert_eq!(best_solution.objective, -1.0);
        assert_eq!(best_solution.point, array![1.0]);
    }

    #[test]
    fn test_empty_solution_set_has_no_best_solution() {
        let solutions: Array1<LocalSolution> = Array1::from_vec(vec![]);
        let solution_set = SolutionSet { solutions };
        assert!(solution_set.best_solution().is_none());
    }

    #[cfg(feature = "checkpointing")]
    #[test]
    fn test_oqnlp_checkpoint_display() {
        let solution_set = SolutionSet {
            solutions: Array1::from_vec(vec![
                LocalSolution { point: array![1.0, 2.0], objective: -1.5 },
                LocalSolution { point: array![3.0, 4.0], objective: -2.0 },
            ]),
        };
        let checkpoint = OQNLPCheckpoint {
            params: OQNLPParams {
                iterations: 100,
                population_size: 500,
                wait_cycle: 20,
                threshold_factor: 0.3,
                distance_factor: 0.8,
                seed: 42,
                local_solver_type: LocalSolverType::COBYLA,
                local_solver_config: crate::local_solver::builders::COBYLABuilder::default()
                    .build(),
            },
            current_iteration: 50,
            merit_threshold: 1.25,
            solution_set: Some(solution_set),
            reference_set: vec![array![1.0, 2.0], array![3.0, 4.0]],
            unchanged_cycles: 5,
            elapsed_time: 123.45,
            distance_filter_solutions: vec![],
            current_seed: 42,
            target_objective: Some(-1.5),
            exclude_out_of_bounds: true,
            #[cfg(feature = "rayon")]
            batch_iterations: Some(4),
            #[cfg(feature = "rayon")]
            enable_parallel: true,
            abs_tol: 1e-8,
            rel_tol: 1e-6,
            timestamp: "2025-07-27T12:00:00Z".to_string(),
        };
        let display_output = format!("{}", checkpoint);
        assert!(display_output.contains("OQNLP Checkpoint"));
        assert!(display_output.contains("2025-07-27T12:00:00Z"));
        assert!(display_output.contains("Current iteration: 50"));
        assert!(display_output.contains("Merit threshold: 1.25"));
        assert!(display_output.contains("Solution set: 2 solutions"));
        assert!(display_output.contains("Best objective: -2"));
        assert!(display_output.contains("Reference set size: 2"));
        assert!(display_output.contains("Unchanged cycles: 5"));
        assert!(display_output.contains("Elapsed time: 123.45s"));
        assert!(display_output.contains("Population size: 500"));
        assert!(display_output.contains("Wait cycle: 20"));
        assert!(display_output.contains("Local solver: COBYLA"));
    }

    #[test]
    fn test_solution_set_display_with_constraints() {
        let solutions =
            Array1::from_vec(vec![LocalSolution { point: array![0.3, 0.3], objective: 0.18 }]);
        let solution_set = SolutionSet { solutions };
        let problem = TestProblemWithConstraints;
        let constraint_descriptions = ["x[0] + x[1] <= 1.0"];
        let display_output =
            solution_set.display_with_constraints(&problem, Some(&constraint_descriptions));
        assert!(display_output.contains("Solution Set"));
        assert!(display_output.contains("Total solutions: 1"));
        assert!(display_output.contains("Constraint violations:"));
        assert!(display_output.contains("Constraint 1 [x[0] + x[1] <= 1.0]"));
        assert!(display_output.contains("✓"));
        assert!(display_output.contains("(satisfied)"));
    }

    #[test]
    fn test_solution_set_display_with_constraints_violated() {
        // g(0.8, 0.8) = 1 - 1.6 = -0.6 < 0, so the constraint is violated.
        let solutions =
            Array1::from_vec(vec![LocalSolution { point: array![0.8, 0.8], objective: 1.28 }]);
        let solution_set = SolutionSet { solutions };
        let display_output =
            solution_set.display_with_constraints(&TestProblemWithConstraints, None);
        assert!(display_output.contains("Constraint violations:"));
        assert!(display_output.contains("Constraint 1: ✗"));
        assert!(display_output.contains("(violated by 0.600000)"));
        assert!(!display_output.contains("(satisfied)"));
    }

    #[test]
    fn test_solution_set_display_with_problem_no_constraints() {
        let solutions =
            Array1::from_vec(vec![LocalSolution { point: array![1.0, 1.0], objective: 2.0 }]);
        let solution_set = SolutionSet { solutions };
        let problem = TestProblemNoConstraints;
        let display_output = solution_set.display_with_problem(&problem);
        assert!(display_output.contains("Solution Set"));
        assert!(display_output.contains("Total solutions: 1"));
        assert!(!display_output.contains("Constraint violations:"));
    }

    #[cfg(feature = "checkpointing")]
    #[test]
    fn test_oqnlp_checkpoint_display_no_solutions() {
        let checkpoint = OQNLPCheckpoint {
            params: OQNLPParams::default(),
            current_iteration: 10,
            merit_threshold: f64::INFINITY,
            solution_set: None,
            reference_set: vec![],
            unchanged_cycles: 0,
            elapsed_time: 15.5,
            distance_filter_solutions: vec![],
            current_seed: 0,
            target_objective: None,
            exclude_out_of_bounds: false,
            #[cfg(feature = "rayon")]
            batch_iterations: None,
            #[cfg(feature = "rayon")]
            enable_parallel: false,
            abs_tol: 1e-8,
            rel_tol: 1e-6,
            timestamp: "2025-07-27T10:00:00Z".to_string(),
        };
        let display_output = format!("{}", checkpoint);
        assert!(display_output.contains("OQNLP Checkpoint"));
        assert!(display_output.contains("2025-07-27T10:00:00Z"));
        assert!(display_output.contains("Current iteration: 10"));
        assert!(display_output.contains("Solution set: None"));
        assert!(display_output.contains("Reference set size: 0"));
    }
}