use std::collections::HashMap;
use nalgebra::DMatrix;
use nalgebra_sparse::CscMatrix;
use crate::expr::{ExprId, Shape};
use crate::sparse::{csc_add, csc_neg, csc_scale};
/// A linear (affine) expression: a set of per-variable coefficient matrices plus a
/// dense constant term.
///
/// `LinExpr::variable` uses the `shape.size()` x `shape.size()` identity as the
/// coefficient of a lone variable. This suggests each coefficient maps the
/// (vectorized) variable onto the expression's elements.
/// NOTE(review): the exact row/column convention of the coefficient matrices is
/// not visible here — confirm against the code that consumes `coeffs`.
#[derive(Debug, Clone)]
pub struct LinExpr {
    /// Coefficient matrix for every variable that appears in the expression,
    /// keyed by the variable's expression id.
    pub coeffs: HashMap<ExprId, CscMatrix<f64>>,
    /// Dense constant (offset) term.
    pub constant: DMatrix<f64>,
    /// Logical shape of the expression.
    pub shape: Shape,
}
impl LinExpr {
    /// Creates the zero expression of `shape`.
    ///
    /// The result has no variable terms and an all-zero constant of size
    /// `shape.rows()` x `shape.cols()`.
    pub fn zeros(shape: Shape) -> Self {
        let rows = shape.rows();
        let cols = shape.cols();
        LinExpr {
            coeffs: HashMap::new(),
            constant: DMatrix::zeros(rows, cols),
            shape,
        }
    }

    /// Creates the expression consisting of the single variable `var_id`.
    ///
    /// The variable's coefficient is the `shape.size()` x `shape.size()` identity
    /// matrix, and the constant is zero.
    pub fn variable(var_id: ExprId, shape: Shape) -> Self {
        let size = shape.size();
        let identity = CscMatrix::identity(size);
        let mut coeffs = HashMap::new();
        coeffs.insert(var_id, identity);
        LinExpr {
            coeffs,
            constant: DMatrix::zeros(shape.rows(), shape.cols()),
            shape,
        }
    }

    /// Wraps a dense matrix as a constant expression with no variable terms.
    ///
    /// The shape is `Shape::matrix(value.nrows(), value.ncols())`.
    pub fn constant(value: DMatrix<f64>) -> Self {
        let shape = Shape::matrix(value.nrows(), value.ncols());
        LinExpr {
            coeffs: HashMap::new(),
            constant: value,
            shape,
        }
    }

    /// Creates a scalar constant expression whose 1x1 constant equals `value`.
    pub fn scalar(value: f64) -> Self {
        LinExpr {
            coeffs: HashMap::new(),
            constant: DMatrix::from_element(1, 1, value),
            shape: Shape::scalar(),
        }
    }

    /// Returns `true` when the expression has no variable terms.
    pub fn is_constant(&self) -> bool {
        self.coeffs.is_empty()
    }

    /// Number of elements in the expression's shape (`Shape::size`).
    pub fn size(&self) -> usize {
        self.shape.size()
    }

    /// Returns the sum `self + other`.
    ///
    /// Coefficient matrices of variables present in both operands are combined
    /// with `csc_add`. Variables present in only one operand are copied over.
    /// Constants of equal dimensions are added element-wise, and a 1x1 constant
    /// on either side is broadcast over the other constant. The result takes the
    /// shape of the operand with the larger `size()` (`self` on ties).
    ///
    /// NOTE(review): coefficient matrices are merged with `csc_add` without any
    /// broadcasting. Adding a scalar expression with variables to a non-scalar
    /// one therefore relies on `csc_add`'s handling of differing dimensions.
    /// Confirm that this is intended or rejected upstream.
    ///
    /// # Panics
    ///
    /// Panics if the two constants have different dimensions and neither is 1x1.
    /// Silently dropping one side (the previous fallback) would yield a
    /// numerically wrong expression.
    pub fn add(&self, other: &LinExpr) -> LinExpr {
        let coeffs = if self.coeffs.is_empty() {
            other.coeffs.clone()
        } else if other.coeffs.is_empty() {
            self.coeffs.clone()
        } else {
            let mut merged = self.coeffs.clone();
            merged.reserve(other.coeffs.len());
            for (var_id, coeff) in &other.coeffs {
                merged
                    .entry(*var_id)
                    .and_modify(|c| *c = csc_add(c, coeff))
                    .or_insert_with(|| coeff.clone());
            }
            merged
        };

        let new_constant = Self::add_constants(&self.constant, &other.constant);

        let new_shape = if self.shape.size() >= other.shape.size() {
            self.shape.clone()
        } else {
            other.shape.clone()
        };

        LinExpr {
            coeffs,
            constant: new_constant,
            shape: new_shape,
        }
    }

    /// Adds two dense constants, broadcasting a 1x1 operand over the other.
    ///
    /// Panics when the dimensions differ and neither operand is 1x1.
    fn add_constants(lhs: &DMatrix<f64>, rhs: &DMatrix<f64>) -> DMatrix<f64> {
        if lhs.shape() == rhs.shape() {
            lhs + rhs
        } else if rhs.shape() == (1, 1) {
            let offset = rhs[(0, 0)];
            lhs.map(|v| v + offset)
        } else if lhs.shape() == (1, 1) {
            let offset = lhs[(0, 0)];
            rhs.map(|v| v + offset)
        } else {
            panic!(
                "LinExpr::add: cannot add constants of incompatible dimensions {}x{} and {}x{}",
                lhs.nrows(),
                lhs.ncols(),
                rhs.nrows(),
                rhs.ncols()
            );
        }
    }

    /// Returns `-self`: every coefficient is negated via `csc_neg`, and so is the
    /// constant.
    pub fn neg(&self) -> LinExpr {
        let coeffs = self.coeffs.iter().map(|(k, v)| (*k, csc_neg(v))).collect();
        LinExpr {
            coeffs,
            constant: -&self.constant,
            shape: self.shape.clone(),
        }
    }

    /// Returns `scalar * self`: every coefficient is scaled via `csc_scale`, and
    /// so is the constant.
    pub fn scale(&self, scalar: f64) -> LinExpr {
        let coeffs = self
            .coeffs
            .iter()
            .map(|(k, v)| (*k, csc_scale(v, scalar)))
            .collect();
        LinExpr {
            coeffs,
            constant: &self.constant * scalar,
            shape: self.shape.clone(),
        }
    }

    /// Returns the ids of all variables with a coefficient entry, sorted by
    /// `ExprId::raw`.
    pub fn variables(&self) -> Vec<ExprId> {
        let mut vars: Vec<_> = self.coeffs.keys().copied().collect();
        vars.sort_by_key(|id| id.raw());
        vars
    }
}
/// A scalar quadratic expression made of three parts:
/// - quadratic coefficient matrices, keyed by an ordered pair of variable ids;
/// - a linear part;
/// - a scalar constant.
///
/// NOTE(review): whether a quadratic term denotes `x' P y` or `0.5 * x' P y` is
/// not visible here — confirm against the solver interface that consumes it.
#[derive(Debug, Clone)]
pub struct QuadExpr {
    /// Quadratic coefficient matrices, keyed by the pair of variable ids they
    /// couple. `QuadExpr::quadratic` stores its matrix under `(var, var)`.
    pub quad_coeffs: HashMap<(ExprId, ExprId), CscMatrix<f64>>,
    /// Linear part of the expression.
    pub linear: LinExpr,
    /// Scalar constant offset.
    pub constant: f64,
}
impl QuadExpr {
    /// Lifts a linear expression into a quadratic one with no quadratic terms.
    ///
    /// The variable coefficients are moved unchanged into a scalar-shaped linear
    /// part whose constant is a zero 1x1 matrix. A 1x1 linear constant becomes
    /// `constant`. Any other constant is discarded and `constant` is `0.0`.
    ///
    /// NOTE(review): the coefficients of a non-scalar input are kept as-is while
    /// the shape is forced to scalar. Confirm that callers only pass scalar
    /// expressions.
    pub fn from_linear(linear: LinExpr) -> Self {
        let LinExpr {
            coeffs,
            constant: lin_constant,
            ..
        } = linear;

        let offset = match lin_constant.shape() {
            (1, 1) => lin_constant[(0, 0)],
            _ => 0.0,
        };

        QuadExpr {
            quad_coeffs: HashMap::new(),
            linear: LinExpr {
                coeffs,
                constant: DMatrix::zeros(1, 1),
                shape: Shape::scalar(),
            },
            constant: offset,
        }
    }

    /// Creates a purely quadratic expression in `var_id`.
    ///
    /// The matrix `p` is stored under the key `(var_id, var_id)`. The linear part
    /// is `LinExpr::zeros(Shape::scalar())` and the constant is zero.
    pub fn quadratic(var_id: ExprId, p: CscMatrix<f64>) -> Self {
        QuadExpr {
            quad_coeffs: HashMap::from([((var_id, var_id), p)]),
            linear: LinExpr::zeros(Shape::scalar()),
            constant: 0.0,
        }
    }

    /// Returns `true` when there are no quadratic terms.
    pub fn is_linear(&self) -> bool {
        self.quad_coeffs.is_empty()
    }

    /// Returns the sum `self + other`.
    ///
    /// Quadratic matrices sharing a key are combined with `csc_add`, and keys
    /// present in only one operand are copied over. The linear parts are summed
    /// with `LinExpr::add`, and the constants are added.
    pub fn add(&self, other: &QuadExpr) -> QuadExpr {
        let mut merged = self.quad_coeffs.clone();
        for (&pair, rhs) in other.quad_coeffs.iter() {
            match merged.get_mut(&pair) {
                Some(lhs) => *lhs = csc_add(lhs, rhs),
                None => {
                    merged.insert(pair, rhs.clone());
                }
            }
        }

        QuadExpr {
            quad_coeffs: merged,
            linear: self.linear.add(&other.linear),
            constant: self.constant + other.constant,
        }
    }

    /// Returns `scalar * self`, scaling the quadratic matrices (via `csc_scale`),
    /// the linear part and the constant.
    pub fn scale(&self, scalar: f64) -> QuadExpr {
        let mut scaled = HashMap::with_capacity(self.quad_coeffs.len());
        for (&pair, mat) in &self.quad_coeffs {
            scaled.insert(pair, csc_scale(mat, scalar));
        }

        QuadExpr {
            quad_coeffs: scaled,
            linear: self.linear.scale(scalar),
            constant: self.constant * scalar,
        }
    }

    /// Returns every variable id referenced by the linear part or by either side
    /// of a quadratic key. The ids are sorted by `ExprId::raw` and duplicates are
    /// removed.
    pub fn variables(&self) -> Vec<ExprId> {
        let mut ids = self.linear.variables();
        ids.extend(self.quad_coeffs.keys().flat_map(|&(first, second)| [first, second]));
        ids.sort_by_key(|id| id.raw());
        ids.dedup();
        ids
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero expression has no variables and reports its shape's size.
    #[test]
    fn test_lin_expr_zeros() {
        let zero = LinExpr::zeros(Shape::vector(5));
        assert!(zero.is_constant());
        assert_eq!(zero.size(), 5);
    }

    /// A single-variable expression lists exactly that variable.
    #[test]
    fn test_lin_expr_variable() {
        let x = ExprId::new();
        let expr = LinExpr::variable(x, Shape::vector(3));
        assert!(!expr.is_constant());
        assert_eq!(expr.variables(), vec![x]);
    }

    /// Summing two distinct variables keeps both.
    #[test]
    fn test_lin_expr_add() {
        let (x, y) = (ExprId::new(), ExprId::new());
        let lhs = LinExpr::variable(x, Shape::vector(3));
        let rhs = LinExpr::variable(y, Shape::vector(3));
        let total = lhs.add(&rhs);
        assert_eq!(total.variables().len(), 2);
    }

    /// Lifting a linear expression yields no quadratic terms.
    #[test]
    fn test_quad_expr_from_linear() {
        let lifted = QuadExpr::from_linear(LinExpr::variable(ExprId::new(), Shape::scalar()));
        assert!(lifted.is_linear());
    }
}