use serde::{Deserialize, Serialize};
/// Shared (`Arc`), thread-safe (`Send + Sync`) constraint callback mapping a point `x`
/// to a scalar constraint value.
type ConstraintFn = std::sync::Arc<dyn Fn(&[f32]) -> f32 + Send + Sync>;
/// Shared (`Arc`), thread-safe (`Send + Sync`) gradient callback mapping a point `x`
/// to a vector; expected to be the gradient of the matching [`ConstraintFn`].
type GradientFn = std::sync::Arc<dyn Fn(&[f32]) -> Vec<f32> + Send + Sync>;
/// A named scalar constraint on a point `x: &[f32]`.
///
/// Depending on [`NonlinearConstraintType`], it is either an inequality (satisfied when
/// the value is `<= 0.0`) or an equality (satisfied when `|value| <= tolerance`).
/// An analytic gradient can optionally be attached. The constraint also carries a weight.
///
/// The closures are held behind `Arc`, so cloning shares them rather than copying them.
#[derive(Clone)]
pub struct NonlinearConstraint {
    /// Human-readable identifier, returned by `name()`.
    name: String,
    /// Computes the constraint value at a point.
    constraint_fn: ConstraintFn,
    /// Optional gradient of `constraint_fn`; `project` only moves a point when present.
    gradient_fn: Option<GradientFn>,
    /// Selects how the constraint value is judged (inequality vs. equality with tolerance).
    constraint_type: NonlinearConstraintType,
    /// Weight returned by `weight()`. It is not applied inside this type; presumably
    /// callers use it to scale the violation penalty (verify against callers).
    weight: f32,
}
impl std::fmt::Debug for NonlinearConstraint {
    // The stored closures have no `Debug` representation, so the output lists the
    // metadata fields and only reports whether a gradient closure is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_gradient = self.gradient_fn.is_some();
        let mut builder = f.debug_struct("NonlinearConstraint");
        builder.field("name", &self.name);
        builder.field("has_gradient", &has_gradient);
        builder.field("constraint_type", &self.constraint_type);
        builder.field("weight", &self.weight);
        builder.finish()
    }
}
/// How a [`NonlinearConstraint`]'s scalar value is judged; serializable via serde.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum NonlinearConstraintType {
    /// Satisfied when the constraint value is `<= 0.0`.
    Inequality,
    /// Satisfied when the absolute constraint value is `<= tolerance`.
    /// A negative `tolerance` can never be met.
    Equality { tolerance: f32 },
}
impl NonlinearConstraint {
pub fn inequality<F>(name: impl Into<String>, f: F) -> Self
where
F: Fn(&[f32]) -> f32 + Send + Sync + 'static,
{
Self {
name: name.into(),
constraint_fn: std::sync::Arc::new(f),
gradient_fn: None,
constraint_type: NonlinearConstraintType::Inequality,
weight: 1.0,
}
}
pub fn equality<F>(name: impl Into<String>, f: F, tolerance: f32) -> Self
where
F: Fn(&[f32]) -> f32 + Send + Sync + 'static,
{
Self {
name: name.into(),
constraint_fn: std::sync::Arc::new(f),
gradient_fn: None,
constraint_type: NonlinearConstraintType::Equality { tolerance },
weight: 1.0,
}
}
pub fn with_gradient<G>(mut self, grad: G) -> Self
where
G: Fn(&[f32]) -> Vec<f32> + Send + Sync + 'static,
{
self.gradient_fn = Some(std::sync::Arc::new(grad));
self
}
pub fn with_weight(mut self, weight: f32) -> Self {
self.weight = weight;
self
}
pub fn evaluate(&self, x: &[f32]) -> f32 {
(self.constraint_fn)(x)
}
pub fn gradient(&self, x: &[f32]) -> Option<Vec<f32>> {
self.gradient_fn.as_ref().map(|g| g(x))
}
pub fn check(&self, x: &[f32]) -> bool {
let val = self.evaluate(x);
match self.constraint_type {
NonlinearConstraintType::Inequality => val <= 0.0,
NonlinearConstraintType::Equality { tolerance } => val.abs() <= tolerance,
}
}
pub fn violation(&self, x: &[f32]) -> f32 {
let val = self.evaluate(x);
match self.constraint_type {
NonlinearConstraintType::Inequality => val.max(0.0),
NonlinearConstraintType::Equality { tolerance } => (val.abs() - tolerance).max(0.0),
}
}
pub fn project(&self, x: &[f32], max_iter: usize, step_size: f32) -> Vec<f32> {
if let Some(grad_fn) = &self.gradient_fn {
let mut result = x.to_vec();
for _ in 0..max_iter {
if self.check(&result) {
break;
}
let grad = grad_fn(&result);
let val = self.evaluate(&result);
for (ri, &gi) in result.iter_mut().zip(grad.iter()) {
*ri -= step_size * val.signum() * gi;
}
}
result
} else {
x.to_vec()
}
}
pub fn name(&self) -> &str {
&self.name
}
pub fn weight(&self) -> f32 {
self.weight
}
pub fn has_gradient(&self) -> bool {
self.gradient_fn.is_some()
}
}