use std::sync::Arc;
use crate::expr::Expr;
/// A constraint over [`Expr`] values, stored in one of three forms.
///
/// The constructors on this type ([`Constraint::eq`], [`Constraint::leq`],
/// [`Constraint::geq`], [`Constraint::soc`]) build these variants; the
/// expressions are reference-counted with [`Arc`].
#[derive(Debug, Clone)]
pub enum Constraint {
/// Built by [`Constraint::eq`] as `lhs + (-rhs)`, i.e. the expression that
/// the equality `lhs == rhs` requires to be zero. `is_dcp` accepts it only
/// when the expression is affine.
Zero(Arc<Expr>),
/// Built by [`Constraint::leq`] as `rhs + (-lhs)` and by
/// [`Constraint::geq`] as `lhs + (-rhs)`, i.e. the expression that the
/// inequality requires to be non-negative. `is_dcp` accepts it only when
/// the expression is concave.
NonNeg(Arc<Expr>),
/// Pair of expressions built by [`Constraint::soc`]. `is_dcp` accepts it
/// only when both `t` and `x` are affine.
///
/// NOTE(review): presumably a second-order cone constraint `||x||_2 <= t`;
/// the exact convention is not visible in this file - confirm against the
/// code that consumes this variant.
SOC {
/// First argument given to [`Constraint::soc`].
t: Arc<Expr>,
/// Second argument given to [`Constraint::soc`].
x: Arc<Expr>,
},
}
impl Constraint {
fn broadcast_scalar(scalar: &Expr, target_shape: &crate::expr::Shape) -> Expr {
use crate::expr::{constant, ones};
if let Expr::Constant(data) = scalar {
if let Some(val) = data.value.as_scalar() {
if target_shape.is_scalar() {
return scalar.clone();
}
return constant(val) * ones(target_shape.clone());
}
}
scalar.clone()
}
pub fn eq(lhs: Expr, rhs: Expr) -> Self {
let (lhs, rhs) = Self::broadcast_if_needed(lhs, rhs);
Constraint::Zero(Arc::new(Expr::Add(
Arc::new(lhs),
Arc::new(Expr::Neg(Arc::new(rhs))),
)))
}
pub fn leq(lhs: Expr, rhs: Expr) -> Self {
let (lhs, rhs) = Self::broadcast_if_needed(lhs, rhs);
Constraint::NonNeg(Arc::new(Expr::Add(
Arc::new(rhs),
Arc::new(Expr::Neg(Arc::new(lhs))),
)))
}
pub fn geq(lhs: Expr, rhs: Expr) -> Self {
let (lhs, rhs) = Self::broadcast_if_needed(lhs, rhs);
Constraint::NonNeg(Arc::new(Expr::Add(
Arc::new(lhs),
Arc::new(Expr::Neg(Arc::new(rhs))),
)))
}
fn broadcast_if_needed(lhs: Expr, rhs: Expr) -> (Expr, Expr) {
let lhs_shape = lhs.shape();
let rhs_shape = rhs.shape();
if lhs_shape == rhs_shape {
return (lhs, rhs);
}
if lhs_shape.is_scalar() && !rhs_shape.is_scalar() {
(Self::broadcast_scalar(&lhs, &rhs_shape), rhs)
} else if rhs_shape.is_scalar() && !lhs_shape.is_scalar() {
(lhs, Self::broadcast_scalar(&rhs, &lhs_shape))
} else {
(lhs, rhs)
}
}
pub fn soc(t: Expr, x: Expr) -> Self {
Constraint::SOC {
t: Arc::new(t),
x: Arc::new(x),
}
}
pub fn is_dcp(&self) -> bool {
match self {
Constraint::Zero(expr) => expr.is_affine(),
Constraint::NonNeg(expr) => expr.is_concave(),
Constraint::SOC { t, x } => t.is_affine() && x.is_affine(),
}
}
pub fn expressions(&self) -> Vec<&Expr> {
match self {
Constraint::Zero(e) => vec![e.as_ref()],
Constraint::NonNeg(e) => vec![e.as_ref()],
Constraint::SOC { t, x } => vec![t.as_ref(), x.as_ref()],
}
}
pub fn variables(&self) -> Vec<crate::expr::ExprId> {
let mut vars = Vec::new();
for expr in self.expressions() {
vars.extend(expr.variables());
}
vars.sort_by_key(|id| id.raw());
vars.dedup();
vars
}
}
/// Method-style constructors for [`Constraint`], taking `self` as the
/// left-hand side and anything convertible into an [`Expr`] as the right-hand
/// side.
///
/// The implementation for `Expr` in this file forwards to `Constraint::eq`,
/// `Constraint::leq` and `Constraint::geq` respectively.
pub trait ConstraintExt {
/// Builds the constraint `self == rhs`.
fn eq<E: Into<Expr>>(&self, rhs: E) -> Constraint;
/// Builds the constraint `self <= rhs`.
fn le<E: Into<Expr>>(&self, rhs: E) -> Constraint;
/// Builds the constraint `self >= rhs`.
fn ge<E: Into<Expr>>(&self, rhs: E) -> Constraint;
}
/// Lets an `Expr` act as the left-hand side of a constraint. Each method
/// clones `self`, converts `rhs` with `Into<Expr>`, and forwards to the
/// matching `Constraint` constructor.
impl ConstraintExt for Expr {
    /// Forwards to `Constraint::eq(self.clone(), rhs.into())`.
    fn eq<E: Into<Expr>>(&self, rhs: E) -> Constraint {
        let lhs = self.clone();
        let rhs: Expr = rhs.into();
        Constraint::eq(lhs, rhs)
    }

    /// Forwards to `Constraint::leq(self.clone(), rhs.into())`.
    fn le<E: Into<Expr>>(&self, rhs: E) -> Constraint {
        let lhs = self.clone();
        let rhs: Expr = rhs.into();
        Constraint::leq(lhs, rhs)
    }

    /// Forwards to `Constraint::geq(self.clone(), rhs.into())`.
    fn ge<E: Into<Expr>>(&self, rhs: E) -> Constraint {
        let lhs = self.clone();
        let rhs: Expr = rhs.into();
        Constraint::geq(lhs, rhs)
    }
}
/// Builds a constraint from comparison-style syntax by expanding to the
/// matching `ConstraintExt` method:
///
/// * `constraint!(lhs >= rhs)` expands to `ConstraintExt::ge(&lhs, rhs)`
/// * `constraint!(lhs <= rhs)` expands to `ConstraintExt::le(&lhs, rhs)`
/// * `constraint!(lhs == rhs)` expands to `ConstraintExt::eq(&lhs, rhs)`
///
/// The left-hand side must be a single token tree: an identifier, a literal,
/// or a parenthesised expression such as `(a + b)`. `macro_rules!` cannot
/// split an arbitrary token run at the comparison operator, so compound
/// left-hand sides need the parentheses.
///
/// The right-hand side may be any expression, with or without parentheses,
/// e.g. `constraint!(x >= -1.0)` or `constraint!(x <= 2.0 * y)`.
///
/// The left-hand side is only borrowed, so it stays usable after the call.
#[macro_export]
macro_rules! constraint {
    // A parenthesised group is a single `tt`, so these three arms also cover
    // every `(..) op (..)` combination; no dedicated arms are needed.
    ($lhs:tt >= $($rhs:tt)+) => {
        $crate::constraints::ConstraintExt::ge(&$lhs, $($rhs)+)
    };
    ($lhs:tt <= $($rhs:tt)+) => {
        $crate::constraints::ConstraintExt::le(&$lhs, $($rhs)+)
    };
    ($lhs:tt == $($rhs:tt)+) => {
        $crate::constraints::ConstraintExt::eq(&$lhs, $($rhs)+)
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{constant, variable};

    /// `eq` on a variable and a constant yields a DCP `Zero` constraint.
    #[test]
    fn test_equality_constraint() {
        let constr = Constraint::eq(variable(5), constant(1.0));
        assert!(constr.is_dcp());
        assert!(
            matches!(constr, Constraint::Zero(_)),
            "Expected Zero constraint"
        );
    }

    /// `geq` on a variable and a constant yields a DCP `NonNeg` constraint.
    #[test]
    fn test_inequality_constraint() {
        let constr = Constraint::geq(variable(5), constant(0.0));
        assert!(constr.is_dcp());
        assert!(
            matches!(constr, Constraint::NonNeg(_)),
            "Expected NonNeg constraint"
        );
    }

    /// An SOC constraint built from two plain variables is DCP.
    #[test]
    fn test_soc_constraint() {
        let bound = variable(());
        let body = variable(5);
        assert!(Constraint::soc(bound, body).is_dcp());
    }

    /// `norm2(x) >= 1` must be rejected by the DCP check.
    #[test]
    fn test_non_dcp_constraint() {
        let norm_x = Expr::Norm2(Arc::new(variable(5)));
        let constr = Constraint::geq(norm_x, constant(1.0));
        assert!(!constr.is_dcp());
    }

    /// The extension-trait methods accept plain floats on the right.
    #[test]
    fn test_constraint_ext() {
        let x = variable(5);
        assert!(x.eq(1.0).is_dcp());
        assert!(x.le(2.0).is_dcp());
        assert!(x.ge(0.0).is_dcp());
    }

    /// A scalar right-hand side against a vector variable still produces a
    /// DCP `NonNeg` constraint.
    #[test]
    fn test_broadcasting_scalar_to_vector() {
        let constr = variable(5).ge(0.0);
        assert!(constr.is_dcp());
        assert!(
            matches!(constr, Constraint::NonNeg(_)),
            "Expected NonNeg constraint"
        );
    }

    /// The `constraint!` macro supports all three comparison operators.
    #[test]
    fn test_broadcasting_with_macro() {
        let x = variable(3);
        assert!(constraint!(x >= 0.0).is_dcp());
        assert!(constraint!(x <= 10.0).is_dcp());
        assert!(constraint!(x == 5.0).is_dcp());
    }

    /// Operands that already share a shape are accepted as they are.
    #[test]
    fn test_no_broadcasting_when_shapes_match() {
        use crate::expr::zeros;
        let constr = variable(4).ge(zeros(4));
        assert!(constr.is_dcp());
    }

    /// Each short method works on a freshly created variable.
    #[test]
    fn test_new_short_methods() {
        assert!(variable(5).ge(1.0).is_dcp());
        assert!(variable(5).le(10.0).is_dcp());
        assert!(variable(5).eq(5.0).is_dcp());
    }
}