use std::sync::Arc;
use crate::expr::Expr;
/// Builds the 1-norm atom of `x`.
///
/// The argument is cloned and stored behind an `Arc` inside an `Expr::Norm1` node.
pub fn norm1(x: &Expr) -> Expr {
    let operand = Arc::new(x.clone());
    Expr::Norm1(operand)
}
/// Builds the Euclidean (2-) norm atom of `x`.
///
/// The argument is cloned into an `Arc` and wrapped in an `Expr::Norm2` node.
/// For an affine argument the result is convex and nonnegative. The norm of a
/// non-affine convex expression is reported with `Curvature::Unknown`.
pub fn norm2(x: &Expr) -> Expr {
    Expr::Norm2(Arc::from(x.clone()))
}
/// Builds the infinity-norm atom of `x`.
///
/// The argument is cloned into an `Arc` and wrapped in an `Expr::NormInf` node.
pub fn norm_inf(x: &Expr) -> Expr {
    let shared: Arc<Expr> = x.clone().into();
    Expr::NormInf(shared)
}
/// Builds the `p`-norm atom of `x`, panicking on an unsupported `p`.
///
/// This is the infallible counterpart of `try_norm`; use that function when
/// `p` comes from untrusted input.
///
/// # Panics
///
/// Panics with "unsupported norm p-value" whenever `try_norm` returns an error.
pub fn norm(x: &Expr, p: f64) -> Expr {
    let checked = try_norm(x, p);
    checked.expect("unsupported norm p-value")
}
/// Fallible constructor for the supported vector norms of `x`.
///
/// Maps `p == 1.0` to `norm1`, `p == 2.0` to `norm2`, and `p == +inf` to
/// `norm_inf`.
///
/// # Errors
///
/// Returns `CvxError::InvalidProblem` for any other `p`, including
/// `f64::NEG_INFINITY` and NaN.
pub fn try_norm(x: &Expr, p: f64) -> crate::Result<Expr> {
    if p == 1.0 {
        Ok(norm1(x))
    } else if p == 2.0 {
        Ok(norm2(x))
    } else if p == f64::INFINITY {
        // `p.is_infinite()` is also true for `-inf`, which is not the max-norm
        // (nor a norm at all); only positive infinity selects `norm_inf`.
        Ok(norm_inf(x))
    } else {
        Err(crate::CvxError::InvalidProblem(format!(
            "norm p={} is not supported; use p=1, 2, or inf",
            p
        )))
    }
}
/// Builds the absolute-value atom of `x`.
///
/// The argument is cloned into an `Arc` and wrapped in an `Expr::Abs` node.
/// It is convex for an affine argument such as a plain variable.
pub fn abs(x: &Expr) -> Expr {
    let inner = Arc::new(x.clone());
    Expr::Abs(inner)
}
/// Builds the `Expr::Pos` atom of `x` (positive part, per its name).
///
/// The argument is cloned into an `Arc`. The atom is convex for an affine
/// argument such as a plain variable.
pub fn pos(x: &Expr) -> Expr {
    Expr::Pos(Arc::from(x.clone()))
}
/// Builds the `Expr::NegPart` atom of `x` (negative part, per its name).
///
/// The argument is cloned into an `Arc` before being wrapped.
pub fn neg_part(x: &Expr) -> Expr {
    let shared: Arc<Expr> = x.clone().into();
    Expr::NegPart(shared)
}
/// Builds the maximum atom over `exprs`.
///
/// A single-element input is returned as-is instead of being wrapped.
/// Otherwise every expression is moved into an `Arc` and collected into an
/// `Expr::Maximum` node. The result is convex when all arguments are affine.
///
/// NOTE(review): an empty `exprs` yields an `Expr::Maximum` with no
/// arguments; confirm downstream code rejects or handles that case.
pub fn maximum(mut exprs: Vec<Expr>) -> Expr {
    if exprs.len() != 1 {
        let args = exprs.into_iter().map(Arc::new).collect();
        return Expr::Maximum(args);
    }
    exprs.pop().expect("length was just checked to be 1")
}
/// Convenience wrapper: the maximum of exactly two expressions.
///
/// Both arguments are cloned and forwarded to `maximum`.
pub fn max2(a: &Expr, b: &Expr) -> Expr {
    let pair = vec![a.clone(), b.clone()];
    maximum(pair)
}
/// Builds the minimum atom over `exprs`.
///
/// A single-element input is returned as-is instead of being wrapped.
/// Otherwise every expression is moved into an `Arc` and collected into an
/// `Expr::Minimum` node. The result is concave when all arguments are affine.
///
/// NOTE(review): an empty `exprs` yields an `Expr::Minimum` with no
/// arguments; confirm downstream code rejects or handles that case.
pub fn minimum(exprs: Vec<Expr>) -> Expr {
    match exprs.len() {
        1 => exprs
            .into_iter()
            .next()
            .expect("length was just checked to be 1"),
        _ => Expr::Minimum(exprs.into_iter().map(Arc::new).collect()),
    }
}
/// Convenience wrapper: the minimum of exactly two expressions.
///
/// Both arguments are cloned and forwarded to `minimum`.
pub fn min2(a: &Expr, b: &Expr) -> Expr {
    let pair = vec![a.clone(), b.clone()];
    minimum(pair)
}
/// Builds the quadratic-form atom of vector expression `x` with matrix expression `p`.
///
/// Both arguments are cloned into `Arc`s and stored in `Expr::QuadForm` in the
/// order `(x, p)`. With `x` a variable and `p` a constant identity matrix, the
/// result is convex.
pub fn quad_form(x: &Expr, p: &Expr) -> Expr {
    let (vector, matrix) = (Arc::new(x.clone()), Arc::new(p.clone()));
    Expr::QuadForm(vector, matrix)
}
/// Builds the sum-of-squares atom of `x`.
///
/// The argument is cloned into an `Arc` and wrapped in `Expr::SumSquares`.
/// For a plain variable the result is convex and nonnegative.
pub fn sum_squares(x: &Expr) -> Expr {
    Expr::SumSquares(Arc::from(x.clone()))
}
/// Builds the `Expr::QuadOverLin` atom from `x` and `y`.
///
/// Both arguments are cloned into `Arc`s and stored in the order `(x, y)`.
pub fn quad_over_lin(x: &Expr, y: &Expr) -> Expr {
    let first: Arc<Expr> = x.clone().into();
    let second: Arc<Expr> = y.clone().into();
    Expr::QuadOverLin(first, second)
}
/// Builds the exponential atom of `x` (an `Expr::Exp` node over a cloned `x`).
pub fn exp(x: &Expr) -> Expr {
    let arg = Arc::new(x.clone());
    Expr::Exp(arg)
}
/// Builds the logarithm atom of `x` (an `Expr::Log` node over a cloned `x`).
pub fn log(x: &Expr) -> Expr {
    Expr::Log(Arc::from(x.clone()))
}
/// Builds the entropy atom of `x` (an `Expr::Entropy` node over a cloned `x`).
pub fn entropy(x: &Expr) -> Expr {
    let shared: Arc<Expr> = x.clone().into();
    Expr::Entropy(shared)
}
/// Builds the power atom raising `x` to the constant exponent `p`.
///
/// The exponent is stored unvalidated alongside a cloned, `Arc`-wrapped `x`
/// in an `Expr::Power` node.
pub fn power(x: &Expr, p: f64) -> Expr {
    let base = Arc::new(x.clone());
    Expr::Power(base, p)
}
/// Builds the square-root atom of `x`, expressed as `power(x, 0.5)`.
pub fn sqrt(x: &Expr) -> Expr {
    const HALF: f64 = 0.5;
    power(x, HALF)
}
#[cfg(test)]
mod tests {
    // DCP curvature and sign checks for the atom constructors above.
    use super::*;
    use crate::{dcp::Curvature, expr::variable};

    #[test]
    fn test_norm2_convex() {
        let atom = norm2(&variable(5));
        assert_eq!(atom.curvature(), Curvature::Convex);
        assert!(atom.is_nonneg());
    }

    #[test]
    fn test_norm1_convex() {
        let atom = norm1(&variable(5));
        assert_eq!(atom.curvature(), Curvature::Convex);
    }

    #[test]
    fn test_abs_convex() {
        let atom = abs(&variable(5));
        assert_eq!(atom.curvature(), Curvature::Convex);
    }

    #[test]
    fn test_pos_convex() {
        let atom = pos(&variable(5));
        assert_eq!(atom.curvature(), Curvature::Convex);
    }

    #[test]
    fn test_maximum_convex() {
        // Elements are created left to right, matching two sequential `variable` calls.
        let atom = maximum(vec![variable(5), variable(5)]);
        assert_eq!(atom.curvature(), Curvature::Convex);
    }

    #[test]
    fn test_minimum_concave() {
        let atom = minimum(vec![variable(5), variable(5)]);
        assert_eq!(atom.curvature(), Curvature::Concave);
    }

    #[test]
    fn test_sum_squares_convex() {
        let atom = sum_squares(&variable(5));
        assert_eq!(atom.curvature(), Curvature::Convex);
        assert!(atom.is_nonneg());
    }

    #[test]
    fn test_quad_form_psd() {
        let vec_var = variable(2);
        let eye = crate::expr::constant_dmatrix(nalgebra::DMatrix::identity(2, 2));
        let atom = quad_form(&vec_var, &eye);
        assert_eq!(atom.curvature(), Curvature::Convex);
    }

    #[test]
    fn test_norm_of_affine_is_convex() {
        let lhs = variable(5);
        let rhs = variable(5);
        let atom = norm2(&(&lhs + &rhs));
        assert_eq!(atom.curvature(), Curvature::Convex);
    }

    #[test]
    fn test_norm_of_convex_is_unknown() {
        // A norm of a (non-affine) convex expression is not DCP-certified.
        let inner = norm2(&variable(5));
        let atom = norm2(&inner);
        assert_eq!(atom.curvature(), Curvature::Unknown);
    }
}