use serde::{Deserialize, Serialize};
/// The relation a `LinearConstraint` requires between `coefficients · x` and its `rhs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LinearConstraintType {
    /// Satisfied when `coefficients · x <= rhs`.
    LessEq,
    /// Satisfied when `coefficients · x >= rhs`.
    GreaterEq,
    /// Satisfied when `|coefficients · x - rhs| <= tolerance`.
    Equality {
        /// Maximum allowed absolute deviation of `coefficients · x` from `rhs`.
        tolerance: f32,
    },
}
/// A single weighted linear constraint of the form `coefficients · x (relation) rhs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearConstraint {
    /// Coefficient vector `a`, dotted with the point being checked or projected.
    coefficients: Vec<f32>,
    /// Right-hand side `b` of the constraint.
    rhs: f32,
    /// Which relation (`<=`, `>=`, or equality within a tolerance) must hold.
    constraint_type: LinearConstraintType,
    /// Scale applied to this constraint's violation when a `LinearConstraintSet`
    /// totals violations; the constructors initialise it to `1.0`.
    weight: f32,
}
impl LinearConstraint {
    /// Creates a `coefficients · x <= rhs` constraint with weight `1.0`.
    pub fn less_eq(coefficients: Vec<f32>, rhs: f32) -> Self {
        Self {
            coefficients,
            rhs,
            constraint_type: LinearConstraintType::LessEq,
            weight: 1.0,
        }
    }

    /// Creates a `coefficients · x >= rhs` constraint with weight `1.0`.
    pub fn greater_eq(coefficients: Vec<f32>, rhs: f32) -> Self {
        Self {
            coefficients,
            rhs,
            constraint_type: LinearConstraintType::GreaterEq,
            weight: 1.0,
        }
    }

    /// Creates a `|coefficients · x - rhs| <= tolerance` constraint with weight `1.0`.
    pub fn equality(coefficients: Vec<f32>, rhs: f32, tolerance: f32) -> Self {
        Self {
            coefficients,
            rhs,
            constraint_type: LinearConstraintType::Equality { tolerance },
            weight: 1.0,
        }
    }

    /// Returns this constraint with its weight replaced by `weight`.
    ///
    /// The weight is only reported through [`weight`](Self::weight); it does
    /// not affect `check`, `violation`, or `project`.
    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }

    /// Dot product `coefficients · x` over the common prefix of both slices.
    ///
    /// Entries present in only one of the two slices are ignored, i.e. they
    /// behave as if the shorter side were zero-padded.
    fn dot(&self, x: &[f32]) -> f32 {
        self.coefficients
            .iter()
            .zip(x.iter())
            .map(|(a, xi)| a * xi)
            .sum()
    }

    /// Returns `true` when `x` satisfies the constraint.
    pub fn check(&self, x: &[f32]) -> bool {
        let ax = self.dot(x);
        match &self.constraint_type {
            LinearConstraintType::LessEq => ax <= self.rhs,
            LinearConstraintType::GreaterEq => ax >= self.rhs,
            LinearConstraintType::Equality { tolerance } => (ax - self.rhs).abs() <= *tolerance,
        }
    }

    /// Returns how far `coefficients · x` lies outside the feasible range, or
    /// `0.0` when the constraint is satisfied.
    ///
    /// The amount is measured in units of `coefficients · x`, not as a
    /// Euclidean distance. Note that `f32::max` discards NaN, so a NaN dot
    /// product yields `0.0` here even though `check` returns `false`.
    pub fn violation(&self, x: &[f32]) -> f32 {
        let ax = self.dot(x);
        match &self.constraint_type {
            LinearConstraintType::LessEq => (ax - self.rhs).max(0.0),
            LinearConstraintType::GreaterEq => (self.rhs - ax).max(0.0),
            LinearConstraintType::Equality { tolerance } => {
                let diff = (ax - self.rhs).abs();
                (diff - tolerance).max(0.0)
            }
        }
    }

    /// Projects `x` onto the hyperplane `coefficients · y = rhs` if `x`
    /// violates the constraint; otherwise returns a copy of `x`.
    ///
    /// For `Equality` the target is the centre of the tolerance band, not its
    /// edge. A copy of `x` is also returned when the coefficients that overlap
    /// `x` have a squared norm below `f32::EPSILON`.
    ///
    /// The result always has the same length as `x`: coefficients beyond
    /// `x.len()` are ignored (consistent with `dot`), and components of `x`
    /// beyond `coefficients.len()` are left untouched.
    pub fn project(&self, x: &[f32]) -> Vec<f32> {
        // Only the overlapping prefix contributes to `dot`, so only it can move
        // the dot product toward `rhs`; normalising by its norm makes the step
        // land exactly on the hyperplane.
        let active = &self.coefficients[..self.coefficients.len().min(x.len())];
        let norm_sq: f32 = active.iter().map(|a| a * a).sum();
        if norm_sq < f32::EPSILON {
            return x.to_vec();
        }
        let ax = self.dot(x);
        let needs_projection = match &self.constraint_type {
            LinearConstraintType::LessEq => ax > self.rhs,
            LinearConstraintType::GreaterEq => ax < self.rhs,
            LinearConstraintType::Equality { tolerance } => (ax - self.rhs).abs() > *tolerance,
        };
        if !needs_projection {
            return x.to_vec();
        }
        let factor = (ax - self.rhs) / norm_sq;
        // Start from a full copy so trailing components without a coefficient
        // are preserved instead of being dropped by `zip`.
        let mut projected = x.to_vec();
        for (xi, ai) in projected.iter_mut().zip(active) {
            *xi -= factor * ai;
        }
        projected
    }

    /// The weight set at construction (`1.0`) or via [`with_weight`](Self::with_weight).
    pub fn weight(&self) -> f32 {
        self.weight
    }

    /// The coefficient vector of the constraint.
    pub fn coefficients(&self) -> &[f32] {
        &self.coefficients
    }

    /// The right-hand side of the constraint.
    pub fn rhs(&self) -> f32 {
        self.rhs
    }
}
/// An ordered collection of `LinearConstraint`s that are checked and
/// projected together; `project` visits them in insertion order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearConstraintSet {
    /// The constraints, in the order they were supplied or added.
    constraints: Vec<LinearConstraint>,
}
impl LinearConstraintSet {
pub fn new(constraints: Vec<LinearConstraint>) -> Self {
Self { constraints }
}
pub fn check_all(&self, x: &[f32]) -> bool {
self.constraints.iter().all(|c| c.check(x))
}
pub fn check_each(&self, x: &[f32]) -> Vec<bool> {
self.constraints.iter().map(|c| c.check(x)).collect()
}
pub fn total_violation(&self, x: &[f32]) -> f32 {
self.constraints
.iter()
.map(|c| c.violation(x) * c.weight())
.sum()
}
pub fn project(&self, x: &[f32], max_iters: usize) -> Vec<f32> {
let mut current = x.to_vec();
for _ in 0..max_iters {
let prev = current.clone();
for c in &self.constraints {
current = c.project(¤t);
}
let diff: f32 = current
.iter()
.zip(prev.iter())
.map(|(a, b)| (a - b).abs())
.sum();
if diff < 1e-6 {
break;
}
}
current
}
pub fn len(&self) -> usize {
self.constraints.len()
}
pub fn is_empty(&self) -> bool {
self.constraints.is_empty()
}
pub fn add(&mut self, constraint: LinearConstraint) {
self.constraints.push(constraint);
}
pub fn constraints(&self) -> &[LinearConstraint] {
&self.constraints
}
}
impl Default for LinearConstraintSet {
    /// Produces a set that contains no constraints.
    fn default() -> Self {
        Self {
            constraints: Vec::new(),
        }
    }
}
/// A system of linear equations `matrix · x = rhs`, each row considered
/// satisfied when it is within `tolerance` of its right-hand side.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AffineEquality {
    /// Coefficient rows; each row is dotted with `x` to produce one equation's value.
    matrix: Vec<Vec<f32>>,
    /// Right-hand side values, paired positionally with the rows of `matrix`.
    // NOTE(review): presumably one entry per row — nothing enforces this.
    rhs: Vec<f32>,
    /// Maximum absolute deviation per row accepted by `check`.
    tolerance: f32,
}
impl AffineEquality {
    /// Stores the coefficient matrix, right-hand side and absolute tolerance as given.
    ///
    /// No validation is performed. NOTE(review): callers presumably supply one
    /// `rhs` entry per matrix row — confirm; surplus rows or `rhs` entries are
    /// silently ignored by the evaluation methods below.
    pub fn new(matrix: Vec<Vec<f32>>, rhs: Vec<f32>, tolerance: f32) -> Self {
        Self {
            matrix,
            rhs,
            tolerance,
        }
    }

    /// Evaluates `matrix · x`, yielding one value per row.
    ///
    /// Each row is dotted with `x` over their common prefix only.
    fn multiply(&self, x: &[f32]) -> Vec<f32> {
        let mut products: Vec<f32> = Vec::with_capacity(self.matrix.len());
        for row in &self.matrix {
            products.push(row.iter().zip(x.iter()).map(|(a, xi)| a * xi).sum::<f32>());
        }
        products
    }

    /// Yields the signed residual `(matrix · x)_i - rhs_i` for every row that
    /// has a paired `rhs` entry.
    fn signed_residuals<'a>(&'a self, x: &[f32]) -> impl Iterator<Item = f32> + 'a {
        self.multiply(x)
            .into_iter()
            .zip(self.rhs.iter())
            .map(|(axi, bi)| axi - bi)
    }

    /// Returns `true` when every paired row is within `tolerance` of its
    /// right-hand side.
    pub fn check(&self, x: &[f32]) -> bool {
        let tol = self.tolerance;
        self.signed_residuals(x).all(|d| d.abs() <= tol)
    }

    /// Euclidean (L2) norm of the residual vector `matrix · x - rhs`.
    pub fn residual(&self, x: &[f32]) -> f32 {
        let sum_sq = self.signed_residuals(x).map(|d| d.powi(2)).sum::<f32>();
        sum_sq.sqrt()
    }

    /// Absolute residual of each paired row, in row order.
    pub fn violations(&self, x: &[f32]) -> Vec<f32> {
        self.signed_residuals(x).map(f32::abs).collect()
    }

    /// The coefficient rows.
    pub fn matrix(&self) -> &[Vec<f32>] {
        &self.matrix
    }

    /// The right-hand side values.
    pub fn rhs(&self) -> &[f32] {
        &self.rhs
    }

    /// Number of rows in the coefficient matrix.
    pub fn num_equations(&self) -> usize {
        self.matrix.len()
    }

    /// Length of the first matrix row, or `None` for an empty matrix.
    ///
    /// Rows are not checked for equal length.
    pub fn num_variables(&self) -> Option<usize> {
        self.matrix.first().map(Vec::len)
    }
}