use serde::{Deserialize, Serialize};
/// The comparison a [`QuadraticConstraint`] applies between its evaluated value
/// (`x^T Q x + linear^T x`) and its right-hand side `rhs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuadraticConstraintType {
    /// Satisfied when `value <= rhs`.
    LessEq,
    /// Satisfied when `value >= rhs`.
    GreaterEq,
    /// Satisfied when `|value - rhs| <= tolerance`.
    Equality { tolerance: f32 },
}
/// A single constraint of the form `x^T Q x + linear^T x  (op)  rhs`, where `(op)` is
/// selected by `constraint_type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuadraticConstraint {
    /// The quadratic-term matrix `Q`, indexed as `q_matrix[i][j]` (the coefficient of
    /// `x[i] * x[j]`). Rows may be ragged; entries outside the stored bounds contribute
    /// nothing when the constraint is evaluated.
    q_matrix: Vec<Vec<f32>>,
    /// Linear-term coefficients; paired element-wise with `x`, so any surplus entries on
    /// either side are ignored.
    linear: Vec<f32>,
    /// Right-hand side the evaluated value is compared against.
    rhs: f32,
    /// Which comparison (`<=`, `>=`, or equality within a tolerance) is enforced.
    constraint_type: QuadraticConstraintType,
    /// Scaling factor for this constraint's violation; every constructor sets it to `1.0`.
    weight: f32,
}
impl QuadraticConstraint {
pub fn less_eq(q_matrix: Vec<Vec<f32>>, linear: Vec<f32>, rhs: f32) -> Self {
Self {
q_matrix,
linear,
rhs,
constraint_type: QuadraticConstraintType::LessEq,
weight: 1.0,
}
}
pub fn greater_eq(q_matrix: Vec<Vec<f32>>, linear: Vec<f32>, rhs: f32) -> Self {
Self {
q_matrix,
linear,
rhs,
constraint_type: QuadraticConstraintType::GreaterEq,
weight: 1.0,
}
}
pub fn equality(q_matrix: Vec<Vec<f32>>, linear: Vec<f32>, rhs: f32, tolerance: f32) -> Self {
Self {
q_matrix,
linear,
rhs,
constraint_type: QuadraticConstraintType::Equality { tolerance },
weight: 1.0,
}
}
pub fn ball(center: Vec<f32>, radius: f32) -> Self {
let n = center.len();
let q_matrix: Vec<Vec<f32>> = (0..n)
.map(|i| {
let mut row = vec![0.0; n];
row[i] = 1.0;
row
})
.collect();
let linear: Vec<f32> = center.iter().map(|&ci| -2.0 * ci).collect();
let center_norm_sq: f32 = center.iter().map(|c| c * c).sum();
let rhs = radius * radius - center_norm_sq;
Self::less_eq(q_matrix, linear, rhs)
}
pub fn ellipsoid(a_matrix: Vec<Vec<f32>>, center: Vec<f32>) -> Self {
let n = center.len();
let q_matrix = a_matrix.clone();
let linear: Vec<f32> = (0..n)
.map(|i| {
let ac_i: f32 = a_matrix[i]
.iter()
.zip(center.iter())
.map(|(a, c)| a * c)
.sum();
-2.0 * ac_i
})
.collect();
let a_center: Vec<f32> = (0..n)
.map(|i| {
a_matrix[i]
.iter()
.zip(center.iter())
.map(|(a, c)| a * c)
.sum()
})
.collect();
let center_a_center: f32 = center
.iter()
.zip(a_center.iter())
.map(|(c, ac)| c * ac)
.sum();
let rhs = 1.0 - center_a_center;
Self::less_eq(q_matrix, linear, rhs)
}
pub fn with_weight(mut self, weight: f32) -> Self {
self.weight = weight;
self
}
fn quadratic_form(&self, x: &[f32]) -> f32 {
let n = x.len();
let mut result = 0.0;
for i in 0..n {
for j in 0..n {
if i < self.q_matrix.len() && j < self.q_matrix[i].len() {
result += x[i] * self.q_matrix[i][j] * x[j];
}
}
}
result
}
fn linear_term(&self, x: &[f32]) -> f32 {
self.linear.iter().zip(x.iter()).map(|(c, xi)| c * xi).sum()
}
pub fn evaluate(&self, x: &[f32]) -> f32 {
self.quadratic_form(x) + self.linear_term(x)
}
pub fn check(&self, x: &[f32]) -> bool {
let val = self.evaluate(x);
match &self.constraint_type {
QuadraticConstraintType::LessEq => val <= self.rhs,
QuadraticConstraintType::GreaterEq => val >= self.rhs,
QuadraticConstraintType::Equality { tolerance } => (val - self.rhs).abs() <= *tolerance,
}
}
pub fn violation(&self, x: &[f32]) -> f32 {
let val = self.evaluate(x);
match &self.constraint_type {
QuadraticConstraintType::LessEq => (val - self.rhs).max(0.0),
QuadraticConstraintType::GreaterEq => (self.rhs - val).max(0.0),
QuadraticConstraintType::Equality { tolerance } => {
let diff = (val - self.rhs).abs();
(diff - tolerance).max(0.0)
}
}
}
pub fn gradient(&self, x: &[f32]) -> Vec<f32> {
let n = x.len();
let m = self.q_matrix.len(); let mut grad = self.linear.clone();
grad.resize(n, 0.0);
for (i, grad_i) in grad.iter_mut().enumerate() {
if i < m {
let cols = self.q_matrix[i].len().min(n);
for (j, &xj) in x.iter().enumerate().take(cols) {
*grad_i += self.q_matrix[i][j] * xj;
}
}
for (j, &xj) in x.iter().enumerate().take(m) {
if i < self.q_matrix[j].len() {
*grad_i += self.q_matrix[j][i] * xj;
}
}
}
grad
}
pub fn project(&self, x: &[f32], max_iters: usize, step_size: f32) -> Vec<f32> {
if self.check(x) {
return x.to_vec();
}
let mut current = x.to_vec();
let n = current.len();
for _ in 0..max_iters {
let val = self.evaluate(¤t);
let grad = self.gradient(¤t);
let step = match &self.constraint_type {
QuadraticConstraintType::LessEq => {
if val <= self.rhs {
break;
}
val - self.rhs
}
QuadraticConstraintType::GreaterEq => {
if val >= self.rhs {
break;
}
self.rhs - val
}
QuadraticConstraintType::Equality { tolerance } => {
if (val - self.rhs).abs() <= *tolerance {
break;
}
val - self.rhs
}
};
let grad_norm: f32 = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
if grad_norm < f32::EPSILON {
break;
}
let factor = step_size * step / grad_norm;
for i in 0..n {
current[i] -= factor * grad[i];
}
if step.abs() < 1e-6 {
break;
}
}
current
}
pub fn weight(&self) -> f32 {
self.weight
}
pub fn q_matrix(&self) -> &[Vec<f32>] {
&self.q_matrix
}
pub fn linear(&self) -> &[f32] {
&self.linear
}
pub fn rhs(&self) -> f32 {
self.rhs
}
}
/// An ordered collection of [`QuadraticConstraint`]s that are checked, scored and projected
/// onto together.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuadraticConstraintSet {
    /// Constraints in insertion order; projection visits them in this order.
    constraints: Vec<QuadraticConstraint>,
}
impl QuadraticConstraintSet {
pub fn new(constraints: Vec<QuadraticConstraint>) -> Self {
Self { constraints }
}
pub fn check_all(&self, x: &[f32]) -> bool {
self.constraints.iter().all(|c| c.check(x))
}
pub fn total_violation(&self, x: &[f32]) -> f32 {
self.constraints
.iter()
.map(|c| c.violation(x) * c.weight())
.sum()
}
pub fn project(&self, x: &[f32], max_outer_iters: usize, max_inner_iters: usize) -> Vec<f32> {
let mut current = x.to_vec();
for _ in 0..max_outer_iters {
let prev = current.clone();
for c in &self.constraints {
current = c.project(¤t, max_inner_iters, 0.1);
}
let diff: f32 = current
.iter()
.zip(prev.iter())
.map(|(a, b)| (a - b).abs())
.sum();
if diff < 1e-6 {
break;
}
}
current
}
pub fn len(&self) -> usize {
self.constraints.len()
}
pub fn is_empty(&self) -> bool {
self.constraints.is_empty()
}
pub fn add(&mut self, constraint: QuadraticConstraint) {
self.constraints.push(constraint);
}
pub fn constraints(&self) -> &[QuadraticConstraint] {
&self.constraints
}
}
impl Default for QuadraticConstraintSet {
    /// Creates a set holding no constraints, exactly as `QuadraticConstraintSet::new(vec![])`.
    fn default() -> Self {
        let no_constraints = vec![];
        Self::new(no_constraints)
    }
}