use serde::{Deserialize, Serialize};
use super::{
Constraint, LinearConstraint, NonlinearConstraint, QuadraticConstraint, SetMembershipConstraint,
};
/// How a [`SoftHardConstraint`] is enforced: exactly (hard) or via a penalty (soft).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ConstraintMode {
    /// Must hold: any positive violation makes [`SoftHardConstraint::loss`] return
    /// `f32::MAX`. This is the default mode.
    #[default]
    Hard,
    /// May be traded off: violations are charged through the constraint's
    /// [`PenaltyFunction`], scaled by its weight.
    Soft,
}
/// Shape of the cost charged for a constraint violation in soft mode.
///
/// See [`PenaltyFunction::compute`] and [`PenaltyFunction::gradient`]; every variant
/// charges nothing (and has zero gradient) for a violation `<= 0.0`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub enum PenaltyFunction {
    /// Linear penalty: `weight * violation`.
    L1,
    /// Quadratic penalty: `weight * violation * violation`. This is the default.
    #[default]
    L2,
    /// Huber penalty: quadratic for violations up to `delta`, linear beyond it.
    // NOTE(review): assumes `delta > 0`; a negative `delta` makes the linear branch
    // negative — confirm callers never construct one.
    Huber { delta: f32 },
    /// Logarithmic barrier that grows without bound as the violation approaches
    /// `slack`; the penalty is `f32::MAX` once `violation >= slack`.
    LogBarrier { slack: f32 },
    /// All-or-nothing: free up to `threshold`, `f32::MAX` beyond it (zero gradient).
    Exact { threshold: f32 },
}
impl PenaltyFunction {
    /// Returns the cost of a violation of size `violation`, scaled by `weight`.
    ///
    /// Non-positive violations (constraint satisfied) always cost `0.0`. For
    /// `violation > 0`:
    ///
    /// * `L1`: `weight * violation`
    /// * `L2`: `weight * violation²`
    /// * `Huber { delta }`: `0.5 * weight * violation²` up to `delta`, then
    ///   `weight * (delta * violation - 0.5 * delta²)` (continuous at `delta`).
    /// * `LogBarrier { slack }`: `weight * ln(slack / (slack - violation))`. This is `0`
    ///   at `violation == 0`, strictly positive inside the barrier, and diverges as
    ///   `violation` approaches `slack`. It is `f32::MAX` once `violation >= slack`.
    /// * `Exact { threshold }`: `0.0` up to `threshold`, `f32::MAX` beyond it.
    ///
    /// `f32::MAX` serves as an "infeasible" sentinel.
    pub fn compute(&self, violation: f32, weight: f32) -> f32 {
        if violation <= 0.0 {
            return 0.0;
        }
        match *self {
            Self::L1 => weight * violation,
            Self::L2 => weight * violation * violation,
            Self::Huber { delta } => {
                if violation <= delta {
                    weight * 0.5 * violation * violation
                } else {
                    weight * (delta * violation - 0.5 * delta * delta)
                }
            }
            Self::LogBarrier { slack } => {
                let remaining = slack - violation;
                if remaining > 0.0 {
                    // Normalising by `slack` anchors the barrier at 0 for a zero
                    // violation. The bare `-ln(slack - violation)` is negative whenever
                    // `slack - violation > 1`, so it rewarded violations when `slack > 1`.
                    // It also jumped discontinuously at `violation == 0`. The derivative
                    // is unchanged. `remaining > 0` with `violation > 0` implies
                    // `slack > 0`, so the division is safe.
                    weight * (slack / remaining).ln()
                } else {
                    f32::MAX
                }
            }
            Self::Exact { threshold } => {
                if violation > threshold {
                    f32::MAX
                } else {
                    0.0
                }
            }
        }
    }

    /// Returns the derivative of [`compute`](Self::compute) with respect to `violation`.
    ///
    /// Non-positive violations have zero gradient. `LogBarrier` returns
    /// `weight / (slack - violation)` inside the barrier and `f32::MAX` outside it.
    /// `Exact` is piecewise constant, so its gradient is `0.0`.
    pub fn gradient(&self, violation: f32, weight: f32) -> f32 {
        if violation <= 0.0 {
            return 0.0;
        }
        match *self {
            Self::L1 => weight,
            Self::L2 => 2.0 * weight * violation,
            Self::Huber { delta } => {
                if violation <= delta {
                    weight * violation
                } else {
                    weight * delta
                }
            }
            Self::LogBarrier { slack } => {
                let remaining = slack - violation;
                if remaining > 0.0 {
                    weight / remaining
                } else {
                    f32::MAX
                }
            }
            Self::Exact { .. } => 0.0,
        }
    }
}
/// A constraint `C` tagged with how it is enforced and, in soft mode, how violations
/// are charged.
///
/// Built with [`SoftHardConstraint::hard`] / [`SoftHardConstraint::soft`] and refined
/// with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct SoftHardConstraint<C> {
    /// The wrapped constraint.
    constraint: C,
    /// Hard (must hold) or soft (penalised).
    mode: ConstraintMode,
    /// Penalty shape used by `loss`/`loss_gradient` in soft mode; unused in hard mode.
    penalty: PenaltyFunction,
    /// Multiplier passed to the penalty function (`1.0` from the constructors).
    weight: f32,
    /// Ordering key for `ConstraintSet::by_priority`, higher first (`0` from the
    /// constructors).
    priority: u32,
}
impl<C> SoftHardConstraint<C> {
    /// Shared constructor: `constraint` in `mode`, default penalty, unit weight,
    /// priority `0`.
    fn from_parts(constraint: C, mode: ConstraintMode) -> Self {
        Self {
            constraint,
            mode,
            penalty: PenaltyFunction::default(),
            weight: 1.0,
            priority: 0,
        }
    }

    /// Wraps `constraint` as a hard constraint.
    pub fn hard(constraint: C) -> Self {
        Self::from_parts(constraint, ConstraintMode::Hard)
    }

    /// Wraps `constraint` as a soft constraint (default penalty, weight `1.0`).
    pub fn soft(constraint: C) -> Self {
        Self::from_parts(constraint, ConstraintMode::Soft)
    }

    /// Replaces the penalty function, consuming and returning the wrapper.
    pub fn with_penalty(self, penalty: PenaltyFunction) -> Self {
        Self { penalty, ..self }
    }

    /// Replaces the penalty weight, consuming and returning the wrapper.
    pub fn with_weight(self, weight: f32) -> Self {
        Self { weight, ..self }
    }

    /// Replaces the priority, consuming and returning the wrapper.
    pub fn with_priority(self, priority: u32) -> Self {
        Self { priority, ..self }
    }

    /// Enforcement mode.
    pub fn mode(&self) -> ConstraintMode {
        self.mode
    }

    /// `true` when the mode is [`ConstraintMode::Hard`].
    pub fn is_hard(&self) -> bool {
        matches!(self.mode, ConstraintMode::Hard)
    }

    /// `true` when the mode is [`ConstraintMode::Soft`].
    pub fn is_soft(&self) -> bool {
        matches!(self.mode, ConstraintMode::Soft)
    }

    /// Borrows the wrapped constraint.
    pub fn inner(&self) -> &C {
        &self.constraint
    }

    /// Penalty weight.
    pub fn weight(&self) -> f32 {
        self.weight
    }

    /// Priority (higher sorts first in `ConstraintSet::by_priority`).
    pub fn priority(&self) -> u32 {
        self.priority
    }

    /// Penalty function used in soft mode.
    pub fn penalty(&self) -> PenaltyFunction {
        self.penalty
    }
}
/// Common interface for evaluating a constraint against a point `x`.
pub trait ViolationComputable {
    /// Magnitude by which `x` violates the constraint.
    ///
    /// Callers in this module treat values `> 0.0` as violated (hard-mode `loss`);
    /// penalty functions charge nothing for values `<= 0.0`.
    fn violation(&self, x: &[f32]) -> f32;
    /// Returns `true` if `x` satisfies the constraint.
    // NOTE(review): nothing enforces that `check(x)` agrees with `violation(x) <= 0.0`.
    // Hard-mode `loss` relies on `violation`, while `ConstraintSet::all_*_satisfied`
    // rely on `check` — confirm implementors keep the two consistent.
    fn check(&self, x: &[f32]) -> bool;
}
impl ViolationComputable for Constraint {
    /// Evaluates the scalar constraint on `x`.
    ///
    /// A constraint bound to a dimension looks only at `x[dim]`, and contributes `0.0`
    /// when `x` is too short to contain that dimension. A dimension-less constraint is
    /// applied to every component and the violations are summed.
    // NOTE(review): an out-of-range dimension is silently treated as satisfied; confirm
    // this is intended rather than a caller error worth surfacing.
    fn violation(&self, x: &[f32]) -> f32 {
        match self.dimension() {
            Some(dim) => x
                .get(dim)
                .map_or(0.0, |&component| Constraint::violation(self, component)),
            None => x
                .iter()
                .map(|&component| Constraint::violation(self, component))
                .sum(),
        }
    }

    /// Returns whether `x` satisfies the scalar constraint.
    ///
    /// Follows the same dimension rules as `violation`: an out-of-range dimension counts
    /// as satisfied, and a dimension-less constraint must hold for every component.
    fn check(&self, x: &[f32]) -> bool {
        match self.dimension() {
            Some(dim) => match x.get(dim) {
                Some(&component) => Constraint::check(self, component),
                None => true,
            },
            None => x
                .iter()
                .all(|&component| Constraint::check(self, component)),
        }
    }
}
impl ViolationComputable for LinearConstraint {
fn violation(&self, x: &[f32]) -> f32 {
LinearConstraint::violation(self, x)
}
fn check(&self, x: &[f32]) -> bool {
LinearConstraint::check(self, x)
}
}
impl ViolationComputable for QuadraticConstraint {
fn violation(&self, x: &[f32]) -> f32 {
QuadraticConstraint::violation(self, x)
}
fn check(&self, x: &[f32]) -> bool {
QuadraticConstraint::check(self, x)
}
}
impl ViolationComputable for NonlinearConstraint {
fn violation(&self, x: &[f32]) -> f32 {
NonlinearConstraint::violation(self, x)
}
fn check(&self, x: &[f32]) -> bool {
NonlinearConstraint::check(self, x)
}
}
impl ViolationComputable for SetMembershipConstraint {
fn violation(&self, x: &[f32]) -> f32 {
SetMembershipConstraint::violation(self, x)
}
fn check(&self, x: &[f32]) -> bool {
SetMembershipConstraint::check(self, x)
}
}
impl<C: ViolationComputable> SoftHardConstraint<C> {
pub fn check(&self, x: &[f32]) -> bool {
self.constraint.check(x)
}
pub fn violation(&self, x: &[f32]) -> f32 {
self.constraint.violation(x)
}
pub fn loss(&self, x: &[f32]) -> f32 {
let viol = self.constraint.violation(x);
match self.mode {
ConstraintMode::Hard => {
if viol > 0.0 {
f32::MAX
} else {
0.0
}
}
ConstraintMode::Soft => self.penalty.compute(viol, self.weight),
}
}
pub fn loss_gradient(&self, x: &[f32]) -> f32 {
let viol = self.constraint.violation(x);
match self.mode {
ConstraintMode::Hard => 0.0, ConstraintMode::Soft => self.penalty.gradient(viol, self.weight),
}
}
}
/// An ordered collection of hard and soft constraints over one constraint type `C`.
#[derive(Debug)]
pub struct ConstraintSet<C> {
    /// Constraints in insertion order.
    constraints: Vec<SoftHardConstraint<C>>,
}

// Implemented by hand: `#[derive(Default)]` adds a `C: Default` bound, so
// `ConstraintSet::<C>::default()` would not compile unless the constraint type itself
// implemented `Default`, even though an empty set never needs a `C` value.
impl<C> Default for ConstraintSet<C> {
    fn default() -> Self {
        Self {
            constraints: Vec::new(),
        }
    }
}
// Construction and bookkeeping need nothing from `C`, so they live in an unbounded impl,
// mirroring the split used for `SoftHardConstraint`. Only evaluation requires
// `C: ViolationComputable`.
impl<C> ConstraintSet<C> {
    /// Creates an empty set.
    pub fn new() -> Self {
        Self {
            constraints: Vec::new(),
        }
    }

    /// Appends `constraint` as a hard constraint (default penalty, weight `1.0`,
    /// priority `0`).
    pub fn add_hard(&mut self, constraint: C) {
        self.constraints.push(SoftHardConstraint::hard(constraint));
    }

    /// Appends `constraint` as a soft constraint charged with `penalty`, scaled by
    /// `weight` (priority `0`).
    pub fn add_soft(&mut self, constraint: C, penalty: PenaltyFunction, weight: f32) {
        self.constraints.push(
            SoftHardConstraint::soft(constraint)
                .with_penalty(penalty)
                .with_weight(weight),
        );
    }

    /// Appends an already-configured constraint.
    pub fn add(&mut self, constraint: SoftHardConstraint<C>) {
        self.constraints.push(constraint);
    }

    /// Iterates over the hard constraints in insertion order.
    pub fn hard_constraints(&self) -> impl Iterator<Item = &SoftHardConstraint<C>> {
        self.constraints.iter().filter(|c| c.is_hard())
    }

    /// Iterates over the soft constraints in insertion order.
    pub fn soft_constraints(&self) -> impl Iterator<Item = &SoftHardConstraint<C>> {
        self.constraints.iter().filter(|c| c.is_soft())
    }

    /// Returns all constraints ordered by descending priority. The sort is stable, so
    /// ties keep insertion order.
    pub fn by_priority(&self) -> Vec<&SoftHardConstraint<C>> {
        let mut sorted: Vec<_> = self.constraints.iter().collect();
        sorted.sort_by_key(|c| std::cmp::Reverse(c.priority()));
        sorted
    }

    /// Number of constraints (hard and soft).
    pub fn len(&self) -> usize {
        self.constraints.len()
    }

    /// `true` when the set holds no constraints.
    pub fn is_empty(&self) -> bool {
        self.constraints.is_empty()
    }
}

impl<C: ViolationComputable> ConstraintSet<C> {
    /// Returns `true` if every hard constraint's `check` passes at `x` (vacuously `true`
    /// when there are none). Soft constraints are ignored.
    pub fn all_hard_satisfied(&self, x: &[f32]) -> bool {
        self.constraints
            .iter()
            .filter(|c| c.is_hard())
            .all(|c| c.check(x))
    }

    /// Returns `true` if every constraint, hard or soft, passes `check` at `x`.
    pub fn all_satisfied(&self, x: &[f32]) -> bool {
        self.constraints.iter().all(|c| c.check(x))
    }

    /// Sum of the soft constraints' losses at `x`.
    ///
    /// Not capped: several `f32::MAX` terms (e.g. out-of-domain barriers) sum to `+inf`.
    pub fn soft_loss(&self, x: &[f32]) -> f32 {
        self.constraints
            .iter()
            .filter(|c| c.is_soft())
            .map(|c| c.loss(x))
            .sum()
    }

    /// Total loss of all constraints at `x`.
    ///
    /// Returns `f32::MAX` ("infeasible") as soon as any single term is `f32::MAX` or
    /// larger. Such terms come from a violated hard constraint, an out-of-domain
    /// `LogBarrier`/`Exact` penalty, or a soft penalty that overflowed to `+inf`.
    /// Otherwise returns the sum of all terms.
    pub fn total_loss(&self, x: &[f32]) -> f32 {
        let mut loss = 0.0;
        for c in &self.constraints {
            let l = c.loss(x);
            // `>=` rather than `==` so an overflowed `+inf` term is also reported via the
            // `f32::MAX` sentinel instead of leaking out as `inf`.
            if l >= f32::MAX {
                return f32::MAX;
            }
            loss += l;
        }
        loss
    }
}