use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use nalgebra::DMatrix;
use nalgebra_sparse::CscMatrix;
use super::shape::Shape;
/// Opaque identifier handed out to expression nodes.
///
/// Values come from a single process-wide counter, so successive calls to
/// [`ExprId::new`] return distinct ids (until the `u64` counter wraps around).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExprId(u64);

impl ExprId {
    /// Allocates the next unused identifier.
    pub fn new() -> Self {
        // Function-local static: one counter shared by every caller.
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let issued = COUNTER.fetch_add(1, Ordering::SeqCst);
        Self(issued)
    }

    /// Returns the integer wrapped by this identifier.
    pub fn raw(&self) -> u64 {
        let Self(value) = *self;
        value
    }
}

impl Default for ExprId {
    /// Same as [`ExprId::new`]: every default value is a freshly allocated id,
    /// not a fixed sentinel.
    fn default() -> Self {
        ExprId::new()
    }
}
/// Numeric payload stored in one of three forms.
#[derive(Debug, Clone)]
pub enum Array {
/// Dense matrix of `f64` entries.
Dense(DMatrix<f64>),
/// Sparse matrix of `f64` entries in compressed-sparse-column storage.
Sparse(CscMatrix<f64>),
/// A single `f64` value.
Scalar(f64),
}
impl Array {
    /// Returns the shape of the stored data: `rows x cols` for the matrix
    /// variants and the scalar shape for [`Array::Scalar`].
    pub fn shape(&self) -> Shape {
        match self {
            Array::Dense(m) => Shape::matrix(m.nrows(), m.ncols()),
            Array::Sparse(m) => Shape::matrix(m.nrows(), m.ncols()),
            Array::Scalar(_) => Shape::scalar(),
        }
    }

    /// Returns the number of logical entries (`rows * cols`; `1` for a scalar).
    ///
    /// For the sparse variant this counts every position, not only the
    /// explicitly stored ones.
    pub fn size(&self) -> usize {
        match self {
            Array::Dense(m) => m.nrows() * m.ncols(),
            Array::Sparse(m) => m.nrows() * m.ncols(),
            Array::Scalar(_) => 1,
        }
    }

    /// Returns the value when this array is a scalar or a dense `1 x 1` matrix.
    ///
    /// Every other case, including every sparse matrix, yields `None`.
    pub fn as_scalar(&self) -> Option<f64> {
        match self {
            Array::Scalar(v) => Some(*v),
            Array::Dense(m) if m.nrows() == 1 && m.ncols() == 1 => Some(m[(0, 0)]),
            _ => None,
        }
    }

    /// Returns `true` when every inspected entry is `>= 0.0`.
    ///
    /// For the sparse variant only the explicitly stored values are inspected.
    /// An array without entries is trivially non-negative, and any `NaN` entry
    /// makes the result `false`.
    pub fn is_nonneg(&self) -> bool {
        match self {
            Array::Scalar(v) => *v >= 0.0,
            Array::Dense(m) => m.iter().all(|&v| v >= 0.0),
            Array::Sparse(m) => m.values().iter().all(|&v| v >= 0.0),
        }
    }

    /// Returns `true` when every inspected entry is `<= 0.0`.
    ///
    /// Mirrors [`Array::is_nonneg`]: sparse arrays are judged by their stored
    /// values only, and any `NaN` entry makes the result `false`.
    pub fn is_nonpos(&self) -> bool {
        match self {
            Array::Scalar(v) => *v <= 0.0,
            Array::Dense(m) => m.iter().all(|&v| v <= 0.0),
            Array::Sparse(m) => m.values().iter().all(|&v| v <= 0.0),
        }
    }

    /// Tests whether the array is positive semidefinite.
    ///
    /// * `Scalar(v)` is PSD exactly when `v >= 0.0`.
    /// * A dense matrix yields `None` when it is not square or when two mirrored
    ///   entries differ by more than `1e-10`; otherwise `Some(..)` with the
    ///   verdict.
    /// * Sparse matrices are not analysed and always yield `None`.
    ///
    /// A plain Cholesky factorisation only succeeds for strictly positive-definite
    /// input, so singular PSD matrices (the zero matrix, rank-deficient Gram
    /// matrices, ...) would be rejected by it. When the plain factorisation fails,
    /// the test is therefore repeated on a copy whose diagonal has been raised by
    /// `1e-10 * max(1, largest |entry|)`, which accepts matrices whose smallest
    /// eigenvalue is zero or negative only within that tolerance.
    pub fn is_psd(&self) -> Option<bool> {
        const TOL: f64 = 1e-10;

        match self {
            Array::Scalar(v) => Some(*v >= 0.0),
            Array::Dense(m) => {
                if m.nrows() != m.ncols() {
                    return None;
                }
                let n = m.nrows();

                // Written as `> TOL` so NaN differences keep their original
                // treatment (they do not count as an asymmetry).
                let asymmetric = (0..n)
                    .any(|i| ((i + 1)..n).any(|j| (m[(i, j)] - m[(j, i)]).abs() > TOL));
                if asymmetric {
                    return None;
                }

                // Fast path: strictly positive definite.
                if m.clone().cholesky().is_some() {
                    return Some(true);
                }

                // Non-finite entries can never be certified; do not let the
                // diagonal shift below mask them.
                if m.iter().any(|v| !v.is_finite()) {
                    return Some(false);
                }

                // Semidefinite path: lift the spectrum by a tiny, scale-aware
                // amount and retry.
                let scale = m.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
                let shift = TOL * scale.max(1.0);
                let mut shifted = m.clone();
                for i in 0..n {
                    shifted[(i, i)] += shift;
                }
                Some(shifted.cholesky().is_some())
            }
            Array::Sparse(_) => None,
        }
    }

    /// Wraps a single value as [`Array::Scalar`].
    pub fn from_scalar(v: f64) -> Self {
        Array::Scalar(v)
    }

    /// Builds a dense column (`len x 1`) from the given values.
    pub fn from_vec(v: Vec<f64>) -> Self {
        let n = v.len();
        Array::Dense(DMatrix::from_vec(n, 1, v))
    }

    /// Wraps an existing dense matrix as [`Array::Dense`].
    pub fn from_matrix(m: DMatrix<f64>) -> Self {
        Array::Dense(m)
    }
}
impl From<f64> for Array {
fn from(v: f64) -> Self {
Array::Scalar(v)
}
}
impl From<Vec<f64>> for Array {
fn from(v: Vec<f64>) -> Self {
Array::from_vec(v)
}
}
impl From<DMatrix<f64>> for Array {
fn from(m: DMatrix<f64>) -> Self {
Array::Dense(m)
}
}
/// Payload of an `Expr::Variable` node.
#[derive(Debug, Clone)]
pub struct VariableData {
/// Identifier of the variable; this is the value reported by `Expr::variable_id`
/// and collected by `Expr::variables`.
pub id: ExprId,
/// Shape reported by `Expr::shape` for this variable.
pub shape: Shape,
/// Optional name of the variable.
pub name: Option<String>,
/// NOTE(review): presumably marks the variable as constrained to be non-negative;
/// the flag is not read anywhere in this file, so confirm against its consumers.
pub nonneg: bool,
/// NOTE(review): presumably marks the variable as constrained to be non-positive;
/// the flag is not read anywhere in this file, so confirm against its consumers.
pub nonpos: bool,
}
#[derive(Debug, Clone)]
pub struct ConstantData {
pub id: ExprId,
pub value: Array,
}
impl ConstantData {
pub fn shape(&self) -> Shape {
self.value.shape()
}
}
/// Describes which part of an expression an index operation selects.
#[derive(Debug, Clone)]
pub struct IndexSpec {
    /// One entry per indexed axis, in axis order.
    ///
    /// `Some((start, stop, step))` selects a strided half-open range on that
    /// axis; `None` keeps the whole axis.
    pub ranges: Vec<Option<(usize, usize, usize)>>,
}

impl IndexSpec {
    /// Selects a single position on each axis: index `i` becomes the unit-step
    /// range `(i, i + 1, 1)`.
    pub fn element(indices: Vec<usize>) -> Self {
        let mut ranges = Vec::with_capacity(indices.len());
        for at in indices {
            ranges.push(Some((at, at + 1, 1)));
        }
        Self { ranges }
    }

    /// Selects `start..stop` with step 1 on the first axis only.
    pub fn range(start: usize, stop: usize) -> Self {
        let only_axis = Some((start, stop, 1));
        Self {
            ranges: vec![only_axis],
        }
    }

    /// A spec with a single "whole axis" entry.
    pub fn all() -> Self {
        let mut ranges = Vec::with_capacity(1);
        ranges.push(None);
        Self { ranges }
    }
}
/// Node of an expression tree.
///
/// Child expressions are held behind `Arc`. The notes on each variant describe the
/// operand layout and what `Expr::shape` reports for it. The mathematical meaning of a
/// variant is not established by this file; where it is mentioned it is inferred from
/// the variant name (NOTE(review): confirm against the evaluator/canonicalizer).
#[derive(Debug, Clone)]
pub enum Expr {
/// Leaf node for a variable; its shape is the one stored in `VariableData`.
Variable(VariableData),
/// Leaf node for a constant; its shape is the shape of the stored `Array`.
Constant(ConstantData),
/// Two operands; the shape is the broadcast of both operand shapes, or scalar when
/// broadcasting fails.
Add(Arc<Expr>, Arc<Expr>),
/// One operand; the operand's shape is preserved.
Neg(Arc<Expr>),
/// Two operands; shape rule identical to `Add` (broadcast, scalar on failure).
Mul(Arc<Expr>, Arc<Expr>),
/// Operand plus an optional axis; with `None` the reported shape is scalar.
Sum(Arc<Expr>, Option<usize>),
/// Operand plus the target shape, which is reported verbatim.
Reshape(Arc<Expr>, Shape),
/// Operand plus the selection applied to it.
Index(Arc<Expr>, IndexSpec),
/// Operands whose row counts are added; the column count is taken from the first operand.
VStack(Vec<Arc<Expr>>),
/// Operands whose column counts are added; the row count is taken from the first operand.
HStack(Vec<Arc<Expr>>),
/// One operand; the shape is the transposed operand shape.
Transpose(Arc<Expr>),
/// One operand; the shape is scalar.
Trace(Arc<Expr>),
/// Two operands; the shape is `Shape::matmul` of the operand shapes, or scalar when
/// that fails.
MatMul(Arc<Expr>, Arc<Expr>),
/// One operand; the shape is scalar (by name: the 1-norm).
Norm1(Arc<Expr>),
/// One operand; the shape is scalar (by name: the 2-norm).
Norm2(Arc<Expr>),
/// One operand; the shape is scalar (by name: the infinity norm).
NormInf(Arc<Expr>),
/// One operand; the operand's shape is preserved.
Abs(Arc<Expr>),
/// One operand; the operand's shape is preserved.
Pos(Arc<Expr>),
/// One operand; the operand's shape is preserved.
NegPart(Arc<Expr>),
/// Any number of operands; the shape is that of the first operand (scalar when empty).
Maximum(Vec<Arc<Expr>>),
/// Any number of operands; the shape is that of the first operand (scalar when empty).
Minimum(Vec<Arc<Expr>>),
/// Two operands; the shape is scalar.
QuadForm(Arc<Expr>, Arc<Expr>),
/// One operand; the shape is scalar.
SumSquares(Arc<Expr>),
/// Two operands; the shape is scalar.
QuadOverLin(Arc<Expr>, Arc<Expr>),
/// One operand; the operand's shape is preserved.
Exp(Arc<Expr>),
/// One operand; the operand's shape is preserved.
Log(Arc<Expr>),
/// One operand; the operand's shape is preserved.
Entropy(Arc<Expr>),
/// Operand plus an `f64` parameter (by name: the exponent); the operand's shape is preserved.
Power(Arc<Expr>, f64),
/// Operand plus an optional axis; the operand's shape is preserved.
Cumsum(Arc<Expr>, Option<usize>),
/// One operand; a vector operand of size `n` gives an `n x n` matrix shape, any other
/// operand gives a vector of length `min(rows, cols)`.
Diag(Arc<Expr>),
}
impl Expr {
pub fn shape(&self) -> Shape {
match self {
Expr::Variable(v) => v.shape.clone(),
Expr::Constant(c) => c.shape(),
Expr::Add(a, b) => a
.shape()
.broadcast(&b.shape())
.unwrap_or_else(Shape::scalar),
Expr::Neg(a) => a.shape(),
Expr::Mul(a, b) => a
.shape()
.broadcast(&b.shape())
.unwrap_or_else(Shape::scalar),
Expr::Sum(a, axis) => {
if axis.is_some() {
let dims = a.shape();
if dims.ndim() <= 1 {
Shape::scalar()
} else {
Shape::vector(dims.cols())
}
} else {
Shape::scalar()
}
}
Expr::Reshape(_, shape) => shape.clone(),
Expr::Index(a, spec) => {
let base = a.shape();
let mut new_dims = Vec::new();
for (i, r) in spec.ranges.iter().enumerate() {
match r {
Some((start, stop, step)) => {
let size = (stop - start + step - 1) / step;
if size > 1 {
new_dims.push(size);
}
}
None => {
if i < base.ndim() {
new_dims.push(base.dims()[i]);
}
}
}
}
if new_dims.is_empty() {
Shape::scalar()
} else {
Shape::from_dims(new_dims)
}
}
Expr::VStack(exprs) => {
if exprs.is_empty() {
return Shape::scalar();
}
let first = exprs[0].shape();
let total_rows: usize = exprs.iter().map(|e| e.shape().rows()).sum();
Shape::matrix(total_rows, first.cols())
}
Expr::HStack(exprs) => {
if exprs.is_empty() {
return Shape::scalar();
}
let first = exprs[0].shape();
let total_cols: usize = exprs.iter().map(|e| e.shape().cols()).sum();
Shape::matrix(first.rows(), total_cols)
}
Expr::Transpose(a) => a.shape().transpose(),
Expr::Trace(_) => Shape::scalar(),
Expr::MatMul(a, b) => a.shape().matmul(&b.shape()).unwrap_or_else(Shape::scalar),
Expr::Norm1(_) | Expr::Norm2(_) | Expr::NormInf(_) => Shape::scalar(),
Expr::Abs(a) | Expr::Pos(a) | Expr::NegPart(a) => a.shape(),
Expr::Maximum(exprs) | Expr::Minimum(exprs) => {
if exprs.is_empty() {
Shape::scalar()
} else {
exprs[0].shape()
}
}
Expr::QuadForm(_, _) | Expr::SumSquares(_) | Expr::QuadOverLin(_, _) => Shape::scalar(),
Expr::Exp(a) | Expr::Log(a) | Expr::Entropy(a) | Expr::Power(a, _) => a.shape(),
Expr::Cumsum(a, _) => a.shape(),
Expr::Diag(a) => {
let s = a.shape();
if s.is_vector() {
let n = s.size();
Shape::matrix(n, n)
} else {
let n = s.rows().min(s.cols());
Shape::vector(n)
}
}
}
}
pub fn variable_id(&self) -> Option<ExprId> {
match self {
Expr::Variable(v) => Some(v.id),
_ => None,
}
}
pub fn is_constant(&self) -> bool {
matches!(self, Expr::Constant(_))
}
pub fn is_variable(&self) -> bool {
matches!(self, Expr::Variable(_))
}
pub fn constant_value(&self) -> Option<&Array> {
match self {
Expr::Constant(c) => Some(&c.value),
_ => None,
}
}
pub fn variables(&self) -> Vec<ExprId> {
let mut vars = Vec::new();
self.collect_variables(&mut vars);
vars.sort_by_key(|id| id.0);
vars.dedup();
vars
}
fn collect_variables(&self, vars: &mut Vec<ExprId>) {
match self {
Expr::Variable(v) => vars.push(v.id),
Expr::Constant(_) => {}
Expr::Add(a, b) | Expr::Mul(a, b) | Expr::MatMul(a, b) => {
a.collect_variables(vars);
b.collect_variables(vars);
}
Expr::Neg(a)
| Expr::Sum(a, _)
| Expr::Reshape(a, _)
| Expr::Index(a, _)
| Expr::Transpose(a)
| Expr::Trace(a) => {
a.collect_variables(vars);
}
Expr::VStack(exprs) | Expr::HStack(exprs) => {
for e in exprs {
e.collect_variables(vars);
}
}
Expr::Norm1(a)
| Expr::Norm2(a)
| Expr::NormInf(a)
| Expr::Abs(a)
| Expr::Pos(a)
| Expr::NegPart(a)
| Expr::SumSquares(a) => {
a.collect_variables(vars);
}
Expr::Maximum(exprs) | Expr::Minimum(exprs) => {
for e in exprs {
e.collect_variables(vars);
}
}
Expr::QuadForm(x, p) | Expr::QuadOverLin(x, p) => {
x.collect_variables(vars);
p.collect_variables(vars);
}
Expr::Exp(a) | Expr::Log(a) | Expr::Entropy(a) | Expr::Power(a, _) => {
a.collect_variables(vars);
}
Expr::Cumsum(a, _) | Expr::Diag(a) => {
a.collect_variables(vars);
}
}
}
}
impl From<f64> for Expr {
fn from(value: f64) -> Self {
crate::expr::constant(value)
}
}
/// Lifts an integer into an expression via `crate::expr::constant`.
impl From<i32> for Expr {
    fn from(integer: i32) -> Self {
        // Every `i32` is exactly representable as `f64`, so the widening is lossless.
        let widened = f64::from(integer);
        crate::expr::constant(widened)
    }
}
/// Produces an owned expression from a borrowed one by cloning it.
impl From<&Expr> for Expr {
    fn from(borrowed: &Expr) -> Self {
        Clone::clone(borrowed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly allocated ids must differ.
    #[test]
    fn test_expr_id() {
        let first = ExprId::new();
        let second = ExprId::new();
        assert_ne!(first, second);
    }

    /// A positive scalar converts back to its value and is non-negative only.
    #[test]
    fn test_array_scalar() {
        let five = Array::Scalar(5.0);
        assert_eq!(five.as_scalar(), Some(5.0));
        assert!(five.is_nonneg());
        assert!(!five.is_nonpos());
    }

    /// `from_vec` builds a column with one row per value.
    #[test]
    fn test_array_from_vec() {
        let column = Array::from_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(column.shape(), Shape::matrix(3, 1));
        assert!(column.is_nonneg());
    }

    /// A variable node reports the shape stored in its payload.
    #[test]
    fn test_variable_shape() {
        let payload = VariableData {
            id: ExprId::new(),
            shape: Shape::vector(5),
            name: Some("x".to_string()),
            nonneg: false,
            nonpos: false,
        };
        let node = Expr::Variable(payload);
        assert_eq!(node.shape(), Shape::vector(5));
        assert!(node.is_variable());
    }

    /// A constant node reports the shape of its stored array.
    #[test]
    fn test_constant_shape() {
        let payload = ConstantData {
            id: ExprId::new(),
            value: Array::from_vec(vec![1.0, 2.0, 3.0]),
        };
        let node = Expr::Constant(payload);
        assert_eq!(node.shape(), Shape::matrix(3, 1));
        assert!(node.is_constant());
    }
}