use std::sync::Arc;
use crate::expr::{Array, Expr};
/// DCP curvature class of an expression.
///
/// The classes form a lattice: every `Constant` expression is also `Affine`, and every
/// `Affine` expression is both `Convex` and `Concave`. `Unknown` means the analysis could
/// not certify either convexity or concavity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Curvature {
    /// Does not depend on any variable.
    Constant,
    /// Affine in the variables (hence both convex and concave).
    Affine,
    /// Convex, not known to be affine.
    Convex,
    /// Concave, not known to be affine.
    Concave,
    /// Curvature could not be established.
    Unknown,
}

impl Curvature {
    /// `true` for `Constant`, `Affine` and `Convex`.
    pub fn is_convex(self) -> bool {
        !matches!(self, Self::Concave | Self::Unknown)
    }

    /// `true` for `Constant`, `Affine` and `Concave`.
    pub fn is_concave(self) -> bool {
        !matches!(self, Self::Convex | Self::Unknown)
    }

    /// `true` for `Constant` and `Affine`, i.e. exactly the classes that are both convex
    /// and concave.
    pub fn is_affine(self) -> bool {
        self.is_convex() && self.is_concave()
    }

    /// `true` only for `Constant`.
    pub fn is_constant(self) -> bool {
        self == Self::Constant
    }

    /// Curvature of `-e` given the curvature of `e`: swaps convex and concave, and leaves
    /// every other class unchanged.
    pub fn negate(self) -> Self {
        match self {
            Self::Convex => Self::Concave,
            Self::Concave => Self::Convex,
            Self::Constant | Self::Affine | Self::Unknown => self,
        }
    }
}
pub fn add_curvature(a: Curvature, b: Curvature) -> Curvature {
use Curvature::*;
match (a, b) {
(Constant, x) | (x, Constant) => x,
(Affine, Affine) => Affine,
(Affine, x) | (x, Affine) => x,
(Convex, Convex) => Convex,
(Concave, Concave) => Concave,
(Convex, Concave) | (Concave, Convex) => Unknown,
(Unknown, _) | (_, Unknown) => Unknown,
}
}
/// Curvature of `scalar * e`, where `e` has curvature `expr_curv`.
///
/// The rules by sign of the factor:
/// - positive: curvature is preserved;
/// - negative: convex and concave are swapped;
/// - zero (including `-0.0`): the product is `Constant`;
/// - NaN: the factor has no sign, so the result is `Unknown`. The previous comparison
///   chain treated NaN as negative.
pub fn scalar_mul_curvature(scalar: f64, expr_curv: Curvature) -> Curvature {
    use std::cmp::Ordering;
    match scalar.partial_cmp(&0.0) {
        Some(Ordering::Equal) => Curvature::Constant,
        Some(Ordering::Greater) => expr_curv,
        Some(Ordering::Less) => expr_curv.negate(),
        // Only NaN is unordered with respect to 0.0.
        None => Curvature::Unknown,
    }
}
/// Definiteness classification of a constant array, as used by the `QuadForm` curvature rule.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PsdStatus {
    /// Reported positive semidefinite by `Array::is_psd`.
    Psd,
    /// Detected as negative (semi)definite (see `of_array` for the exact test).
    Nsd,
    /// Neither of the above could be established.
    Neither,
}

impl PsdStatus {
    /// Classifies `arr`.
    ///
    /// - `Psd` whenever `arr.is_psd()` returns `Some(true)`.
    /// - If `is_psd` returns `Some(false)`:
    ///   - a scalar `v <= 0.0` is `Nsd`;
    ///   - a dense matrix `m` is `Nsd` when `-m` admits a Cholesky factorization.
    ///   - Every other case is `Neither`.
    ///
    /// NOTE(review): assuming `cholesky` is the usual positive-*definite* factorization,
    /// singular negative-semidefinite matrices are reported as `Neither` (conservative).
    /// Confirm against the matrix library in use.
    pub fn of_array(arr: &Array) -> Self {
        match (arr.is_psd(), arr) {
            (Some(true), _) => Self::Psd,
            (Some(false), Array::Scalar(v)) if *v <= 0.0 => Self::Nsd,
            (Some(false), Array::Dense(m)) if (-m.clone()).cholesky().is_some() => Self::Nsd,
            _ => Self::Neither,
        }
    }
}
impl Expr {
pub fn curvature(&self) -> Curvature {
match self {
Expr::Variable(_) => Curvature::Affine,
Expr::Constant(_) => Curvature::Constant,
Expr::Add(a, b) => add_curvature(a.curvature(), b.curvature()),
Expr::Neg(a) => a.curvature().negate(),
Expr::Mul(a, b) => mul_curvature(a, b),
Expr::MatMul(a, b) => matmul_curvature(a, b),
Expr::Sum(a, _) => a.curvature(),
Expr::Reshape(a, _) => a.curvature(),
Expr::Index(a, _) => a.curvature(),
Expr::VStack(exprs) => combine_all_curvatures(exprs),
Expr::HStack(exprs) => combine_all_curvatures(exprs),
Expr::Transpose(a) => a.curvature(),
Expr::Trace(a) => a.curvature(),
Expr::Norm1(x) | Expr::Norm2(x) | Expr::NormInf(x) => {
if x.curvature().is_affine() {
Curvature::Convex
} else {
Curvature::Unknown
}
}
Expr::Abs(x) => {
if x.curvature().is_affine() {
Curvature::Convex
} else {
Curvature::Unknown
}
}
Expr::Pos(x) => {
if x.curvature().is_convex() {
Curvature::Convex
} else {
Curvature::Unknown
}
}
Expr::NegPart(x) => {
if x.curvature().is_concave() {
Curvature::Convex
} else {
Curvature::Unknown
}
}
Expr::Maximum(exprs) => {
if exprs.iter().all(|e| e.curvature().is_convex()) {
Curvature::Convex
} else {
Curvature::Unknown
}
}
Expr::Minimum(exprs) => {
if exprs.iter().all(|e| e.curvature().is_concave()) {
Curvature::Concave
} else {
Curvature::Unknown
}
}
Expr::QuadForm(x, p) => {
if !x.curvature().is_affine() {
return Curvature::Unknown;
}
if let Some(p_val) = p.constant_value() {
match PsdStatus::of_array(p_val) {
PsdStatus::Psd => Curvature::Convex,
PsdStatus::Nsd => Curvature::Concave,
PsdStatus::Neither => Curvature::Unknown,
}
} else {
Curvature::Unknown
}
}
Expr::SumSquares(x) => {
if x.curvature().is_affine() {
Curvature::Convex
} else {
Curvature::Unknown
}
}
Expr::QuadOverLin(x, y) => {
if x.curvature().is_affine() && y.curvature().is_concave() {
Curvature::Convex
} else {
Curvature::Unknown
}
}
Expr::Exp(x) => {
if x.curvature().is_affine() {
Curvature::Convex
} else {
Curvature::Unknown
}
}
Expr::Log(x) => {
if x.curvature().is_concave() {
Curvature::Concave
} else {
Curvature::Unknown
}
}
Expr::Entropy(x) => {
if x.curvature().is_affine() {
Curvature::Concave
} else {
Curvature::Unknown
}
}
Expr::Power(x, p) => {
if *p == 0.0 {
Curvature::Constant
} else if *p == 1.0 {
x.curvature()
} else if *p == 2.0 {
if x.curvature().is_affine() {
Curvature::Convex
} else {
Curvature::Unknown
}
} else if *p > 1.0 || *p < 0.0 {
if x.curvature().is_affine() {
Curvature::Convex
} else {
Curvature::Unknown
}
} else if *p > 0.0 && *p < 1.0 {
if x.curvature().is_affine() {
Curvature::Concave
} else {
Curvature::Unknown
}
} else {
Curvature::Unknown
}
}
Expr::Cumsum(x, _) => x.curvature(), Expr::Diag(x) => x.curvature(), }
}
pub fn is_convex(&self) -> bool {
self.curvature().is_convex()
}
pub fn is_concave(&self) -> bool {
self.curvature().is_concave()
}
pub fn is_affine(&self) -> bool {
self.curvature().is_affine()
}
}
/// Curvature of the product `a * b`.
///
/// NOTE(review): presumably elementwise/broadcast multiplication; confirm against `Expr::Mul`.
///
/// A product is only certified when at least one factor is constant:
/// - a scalar constant scales the other factor's curvature by its sign;
/// - a non-scalar constant (entry signs unknown) only preserves affinity;
/// - a product of two non-constant factors is `Unknown`.
fn mul_curvature(a: &Expr, b: &Expr) -> Curvature {
    // Curvature of `factor * other`, where `factor` is already known to be constant.
    fn scaled_by_constant(factor: &Expr, other: Curvature) -> Curvature {
        match factor.constant_value().and_then(|arr| arr.as_scalar()) {
            Some(scalar) => scalar_mul_curvature(scalar, other),
            None if other.is_affine() => Curvature::Affine,
            None => Curvature::Unknown,
        }
    }

    let (ac, bc) = (a.curvature(), b.curvature());
    match (ac.is_constant(), bc.is_constant()) {
        (true, true) => Curvature::Constant,
        (true, false) => scaled_by_constant(a, bc),
        (false, true) => scaled_by_constant(b, ac),
        (false, false) => Curvature::Unknown,
    }
}
fn matmul_curvature(a: &Expr, b: &Expr) -> Curvature {
let ac = a.curvature();
let bc = b.curvature();
if ac.is_constant() {
return bc;
}
if bc.is_constant() {
return ac;
}
Curvature::Unknown
}
fn combine_all_curvatures(exprs: &[Arc<Expr>]) -> Curvature {
if exprs.is_empty() {
return Curvature::Constant;
}
let mut result = Curvature::Constant;
for e in exprs {
result = add_curvature(result, e.curvature());
}
result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{constant, variable};

    #[test]
    fn test_curvature_basics() {
        // (class, is_convex, is_concave, is_affine)
        let cases = [
            (Curvature::Constant, true, true, true),
            (Curvature::Affine, true, true, true),
            (Curvature::Convex, true, false, false),
            (Curvature::Concave, false, true, false),
        ];
        for &(curv, convex, concave, affine) in cases.iter() {
            assert_eq!(curv.is_convex(), convex, "{:?}.is_convex()", curv);
            assert_eq!(curv.is_concave(), concave, "{:?}.is_concave()", curv);
            assert_eq!(curv.is_affine(), affine, "{:?}.is_affine()", curv);
        }
    }

    #[test]
    fn test_negate_curvature() {
        let cases = [
            (Curvature::Convex, Curvature::Concave),
            (Curvature::Concave, Curvature::Convex),
            (Curvature::Affine, Curvature::Affine),
            (Curvature::Constant, Curvature::Constant),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(input.negate(), expected, "negate({:?})", input);
        }
    }

    #[test]
    fn test_add_curvature() {
        use Curvature::*;
        let cases = [
            (Convex, Convex, Convex),
            (Concave, Concave, Concave),
            (Affine, Affine, Affine),
            (Convex, Affine, Convex),
            (Concave, Affine, Concave),
            (Convex, Concave, Unknown),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(add_curvature(lhs, rhs), expected, "{:?} + {:?}", lhs, rhs);
        }
    }

    #[test]
    fn test_variable_is_affine() {
        let v = variable(5);
        assert!(v.is_affine());
        assert!(v.is_convex());
        assert!(v.is_concave());
    }

    #[test]
    fn test_constant_is_constant() {
        assert_eq!(constant(5.0).curvature(), Curvature::Constant);
    }

    #[test]
    fn test_norm_is_convex() {
        let norm = Expr::Norm2(Arc::new(variable(5)));
        assert_eq!(norm.curvature(), Curvature::Convex);
        assert!(norm.is_convex());
        assert!(!norm.is_concave());
    }

    #[test]
    fn test_neg_flips_curvature() {
        let norm = Expr::Norm2(Arc::new(variable(5)));
        let negated = Expr::Neg(Arc::new(norm));
        assert_eq!(negated.curvature(), Curvature::Concave);
    }
}