use std::sync::Arc;
use crate::expr::Expr;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Sign {
Nonnegative,
Nonpositive,
Zero,
Unknown,
}
impl Sign {
pub fn is_nonneg(self) -> bool {
matches!(self, Sign::Nonnegative | Sign::Zero)
}
pub fn is_nonpos(self) -> bool {
matches!(self, Sign::Nonpositive | Sign::Zero)
}
pub fn is_zero(self) -> bool {
matches!(self, Sign::Zero)
}
pub fn negate(self) -> Self {
match self {
Sign::Nonnegative => Sign::Nonpositive,
Sign::Nonpositive => Sign::Nonnegative,
Sign::Zero => Sign::Zero,
Sign::Unknown => Sign::Unknown,
}
}
}
pub fn add_sign(a: Sign, b: Sign) -> Sign {
use Sign::*;
match (a, b) {
(Zero, x) | (x, Zero) => x,
(Nonnegative, Nonnegative) => Nonnegative,
(Nonpositive, Nonpositive) => Nonpositive,
(Nonnegative, Nonpositive) | (Nonpositive, Nonnegative) => Unknown,
(Unknown, _) | (_, Unknown) => Unknown,
}
}
pub fn mul_sign(a: Sign, b: Sign) -> Sign {
use Sign::*;
match (a, b) {
(Zero, _) | (_, Zero) => Zero,
(Nonnegative, Nonnegative) | (Nonpositive, Nonpositive) => Nonnegative,
(Nonnegative, Nonpositive) | (Nonpositive, Nonnegative) => Nonpositive,
(Unknown, _) | (_, Unknown) => Unknown,
}
}
impl Expr {
pub fn sign(&self) -> Sign {
match self {
Expr::Variable(v) => {
if v.nonneg {
Sign::Nonnegative
} else if v.nonpos {
Sign::Nonpositive
} else {
Sign::Unknown
}
}
Expr::Constant(c) => {
if c.value.is_nonneg() && c.value.is_nonpos() {
Sign::Zero
} else if c.value.is_nonneg() {
Sign::Nonnegative
} else if c.value.is_nonpos() {
Sign::Nonpositive
} else {
Sign::Unknown
}
}
Expr::Add(a, b) => add_sign(a.sign(), b.sign()),
Expr::Neg(a) => a.sign().negate(),
Expr::Mul(a, b) => mul_sign(a.sign(), b.sign()),
Expr::MatMul(a, b) => {
let as_ = a.sign();
let bs = b.sign();
if as_.is_zero() || bs.is_zero() {
Sign::Zero
} else if (as_.is_nonneg() && bs.is_nonneg()) || (as_.is_nonpos() && bs.is_nonpos())
{
Sign::Nonnegative
} else {
Sign::Unknown
}
}
Expr::Sum(a, _) => a.sign(),
Expr::Reshape(a, _) => a.sign(),
Expr::Index(a, _) => a.sign(),
Expr::VStack(exprs) => combine_signs(exprs),
Expr::HStack(exprs) => combine_signs(exprs),
Expr::Transpose(a) => a.sign(),
Expr::Trace(a) => a.sign(),
Expr::Norm1(_) | Expr::Norm2(_) | Expr::NormInf(_) => Sign::Nonnegative,
Expr::Abs(_) => Sign::Nonnegative,
Expr::Pos(_) => Sign::Nonnegative,
Expr::NegPart(_) => Sign::Nonnegative,
Expr::Maximum(exprs) => {
if exprs.iter().any(|e| e.sign().is_nonneg()) {
Sign::Nonnegative
} else if exprs.iter().all(|e| e.sign().is_nonpos()) {
Sign::Nonpositive
} else {
Sign::Unknown
}
}
Expr::Minimum(exprs) => {
if exprs.iter().any(|e| e.sign().is_nonpos()) {
Sign::Nonpositive
} else if exprs.iter().all(|e| e.sign().is_nonneg()) {
Sign::Nonnegative
} else {
Sign::Unknown
}
}
Expr::QuadForm(_, p) => {
if let Some(p_val) = p.constant_value() {
use super::curvature::PsdStatus;
match PsdStatus::of_array(p_val) {
PsdStatus::Psd => Sign::Nonnegative,
PsdStatus::Nsd => Sign::Nonpositive,
PsdStatus::Neither => Sign::Unknown,
}
} else {
Sign::Unknown
}
}
Expr::SumSquares(_) => Sign::Nonnegative,
Expr::QuadOverLin(_, _) => Sign::Nonnegative,
Expr::Exp(_) => Sign::Nonnegative, Expr::Log(_x) => {
Sign::Unknown
}
Expr::Entropy(_) => {
Sign::Unknown
}
Expr::Power(x, p) => {
if *p > 0.0 {
if x.sign().is_nonneg() {
Sign::Nonnegative
} else {
Sign::Unknown
}
} else if *p < 0.0 {
if x.sign().is_nonneg() {
Sign::Nonnegative
} else {
Sign::Unknown
}
} else {
Sign::Nonnegative
}
}
Expr::Cumsum(x, _) => x.sign(), Expr::Diag(x) => x.sign(), }
}
pub fn is_nonneg(&self) -> bool {
self.sign().is_nonneg()
}
pub fn is_nonpos(&self) -> bool {
self.sign().is_nonpos()
}
}
fn combine_signs(exprs: &[Arc<Expr>]) -> Sign {
if exprs.is_empty() {
return Sign::Zero;
}
let mut all_nonneg = true;
let mut all_nonpos = true;
let mut all_zero = true;
for e in exprs {
let s = e.sign();
if !s.is_nonneg() {
all_nonneg = false;
}
if !s.is_nonpos() {
all_nonpos = false;
}
if !s.is_zero() {
all_zero = false;
}
}
if all_zero {
Sign::Zero
} else if all_nonneg {
Sign::Nonnegative
} else if all_nonpos {
Sign::Nonpositive
} else {
Sign::Unknown
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::expr::{constant, nonneg_variable, variable};
#[test]
fn test_sign_basics() {
assert!(Sign::Nonnegative.is_nonneg());
assert!(!Sign::Nonnegative.is_nonpos());
assert!(!Sign::Nonpositive.is_nonneg());
assert!(Sign::Nonpositive.is_nonpos());
assert!(Sign::Zero.is_nonneg());
assert!(Sign::Zero.is_nonpos());
assert!(Sign::Zero.is_zero());
}
#[test]
fn test_negate_sign() {
assert_eq!(Sign::Nonnegative.negate(), Sign::Nonpositive);
assert_eq!(Sign::Nonpositive.negate(), Sign::Nonnegative);
assert_eq!(Sign::Zero.negate(), Sign::Zero);
}
#[test]
fn test_add_sign() {
use Sign::*;
assert_eq!(add_sign(Nonnegative, Nonnegative), Nonnegative);
assert_eq!(add_sign(Nonpositive, Nonpositive), Nonpositive);
assert_eq!(add_sign(Nonnegative, Nonpositive), Unknown);
assert_eq!(add_sign(Zero, Nonnegative), Nonnegative);
}
#[test]
fn test_mul_sign() {
use Sign::*;
assert_eq!(mul_sign(Nonnegative, Nonnegative), Nonnegative);
assert_eq!(mul_sign(Nonpositive, Nonpositive), Nonnegative);
assert_eq!(mul_sign(Nonnegative, Nonpositive), Nonpositive);
assert_eq!(mul_sign(Zero, Unknown), Zero);
}
#[test]
fn test_variable_sign() {
let x = variable(5);
assert_eq!(x.sign(), Sign::Unknown);
let y = nonneg_variable(5);
assert_eq!(y.sign(), Sign::Nonnegative);
}
#[test]
fn test_constant_sign() {
let c = constant(5.0);
assert_eq!(c.sign(), Sign::Nonnegative);
let c = constant(-5.0);
assert_eq!(c.sign(), Sign::Nonpositive);
let c = constant(0.0);
assert_eq!(c.sign(), Sign::Zero);
}
#[test]
fn test_norm_sign() {
let x = variable(5);
let n = Expr::Norm2(Arc::new(x));
assert_eq!(n.sign(), Sign::Nonnegative);
}
}