use std::sync::Arc;
use nalgebra::DMatrix;
use nalgebra_sparse::CscMatrix;
use super::expression::{Array, ConstantData, Expr, ExprId};
use super::shape::Shape;
/// Creates a scalar constant expression holding `value`.
///
/// The returned expression is an `Expr::Constant` carrying a newly created
/// `ExprId` and the value stored as `Array::Scalar`.
pub fn constant(value: f64) -> Expr {
    let id = ExprId::new();
    let value = Array::Scalar(value);
    Expr::Constant(ConstantData { id, value })
}
/// Creates a constant expression from a flat list of values.
///
/// The values are converted with `Array::from_vec`; the tests in this file
/// expect the resulting expression to have shape `n x 1` for `n` values.
pub fn constant_vec(values: Vec<f64>) -> Expr {
    let id = ExprId::new();
    let value = Array::from_vec(values);
    Expr::Constant(ConstantData { id, value })
}
/// Creates a dense `rows x cols` matrix constant from a flat buffer.
///
/// `values` is handed to `DMatrix::from_vec`, which fills the matrix
/// column by column: the first `rows` entries form the first column, the
/// next `rows` entries the second column, and so on.
///
/// # Panics
///
/// Panics if `values.len()` is not equal to `rows * cols`, or if that
/// product overflows `usize`.
pub fn constant_matrix(values: Vec<f64>, rows: usize, cols: usize) -> Expr {
    // Validate up front so a size mismatch is reported with a message that
    // names this function and the offending dimensions, instead of surfacing
    // as an opaque allocation panic from inside nalgebra.
    assert!(
        rows.checked_mul(cols) == Some(values.len()),
        "constant_matrix: a {} x {} matrix needs rows * cols elements, but {} were provided",
        rows,
        cols,
        values.len()
    );
    let matrix = DMatrix::from_vec(rows, cols, values);
    Expr::Constant(ConstantData {
        id: ExprId::new(),
        value: Array::Dense(matrix),
    })
}
/// Creates a dense matrix constant from an existing `DMatrix<f64>`.
///
/// The matrix is moved into the expression unchanged as `Array::Dense`.
pub fn constant_dmatrix(matrix: DMatrix<f64>) -> Expr {
    let id = ExprId::new();
    let value = Array::Dense(matrix);
    Expr::Constant(ConstantData { id, value })
}
/// Creates a sparse matrix constant from a `CscMatrix<f64>`.
///
/// The matrix is moved into the expression unchanged as `Array::Sparse`.
pub fn constant_sparse(matrix: CscMatrix<f64>) -> Expr {
    let id = ExprId::new();
    let value = Array::Sparse(matrix);
    Expr::Constant(ConstantData { id, value })
}
/// Creates an all-zero constant of the requested shape.
///
/// A scalar shape yields `Array::Scalar(0.0)`; any other shape yields a dense
/// `rows x cols` matrix of zeros.
pub fn zeros(shape: impl Into<Shape>) -> Expr {
    let dims: Shape = shape.into();
    let value = if dims.is_scalar() {
        Array::Scalar(0.0)
    } else {
        let (rows, cols) = (dims.rows(), dims.cols());
        Array::Dense(DMatrix::zeros(rows, cols))
    };
    let id = ExprId::new();
    Expr::Constant(ConstantData { id, value })
}
/// Creates an all-ones constant of the requested shape.
///
/// A scalar shape yields `Array::Scalar(1.0)`; any other shape yields a dense
/// `rows x cols` matrix whose entries are all `1.0`.
pub fn ones(shape: impl Into<Shape>) -> Expr {
    let dims: Shape = shape.into();
    let value = if dims.is_scalar() {
        Array::Scalar(1.0)
    } else {
        let (rows, cols) = (dims.rows(), dims.cols());
        Array::Dense(DMatrix::from_element(rows, cols, 1.0))
    };
    let id = ExprId::new();
    Expr::Constant(ConstantData { id, value })
}
/// Creates the `n x n` identity matrix as a dense constant.
pub fn eye(n: usize) -> Expr {
    let id = ExprId::new();
    let identity = DMatrix::identity(n, n);
    Expr::Constant(ConstantData {
        id,
        value: Array::Dense(identity),
    })
}
/// Conversion of plain numeric data into a constant [`Expr`].
///
/// This file implements it for `f64`, `i32`, `Vec<f64>`, `&[f64]`,
/// `DMatrix<f64>` and `CscMatrix<f64>`.
pub trait IntoConstant {
/// Consumes `self` and returns it as a constant expression.
fn into_constant(self) -> Expr;
}
/// Turns an `f64` into a scalar constant by delegating to [`constant`].
impl IntoConstant for f64 {
fn into_constant(self) -> Expr {
constant(self)
}
}
/// Turns an `i32` into a scalar constant by widening it to `f64` and
/// delegating to [`constant`].
impl IntoConstant for i32 {
    fn into_constant(self) -> Expr {
        // Every `i32` is exactly representable as an `f64`, so the lossless
        // `From` conversion yields the same value as a cast.
        constant(f64::from(self))
    }
}
/// Turns an owned `Vec<f64>` into a constant by delegating to
/// [`constant_vec`]; the vector is moved, not copied.
impl IntoConstant for Vec<f64> {
fn into_constant(self) -> Expr {
constant_vec(self)
}
}
/// Turns a borrowed slice into a constant by copying its elements into a new
/// `Vec<f64>` and delegating to [`constant_vec`].
impl IntoConstant for &[f64] {
    fn into_constant(self) -> Expr {
        // The expression owns its data, so the borrowed slice must be copied.
        let owned: Vec<f64> = self.to_owned();
        constant_vec(owned)
    }
}
/// Turns a dense `DMatrix<f64>` into a constant by delegating to
/// [`constant_dmatrix`].
impl IntoConstant for DMatrix<f64> {
fn into_constant(self) -> Expr {
constant_dmatrix(self)
}
}
/// Turns a sparse `CscMatrix<f64>` into a constant by delegating to
/// [`constant_sparse`].
impl IntoConstant for CscMatrix<f64> {
fn into_constant(self) -> Expr {
constant_sparse(self)
}
}
/// Creates a scalar constant (see [`constant`]) wrapped in an [`Arc`].
pub fn const_arc(value: f64) -> Arc<Expr> {
    let expr = constant(value);
    Arc::new(expr)
}
/// Creates a vector constant (see [`constant_vec`]) wrapped in an [`Arc`].
pub fn const_vec_arc(values: Vec<f64>) -> Arc<Expr> {
    let expr = constant_vec(values);
    Arc::new(expr)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A scalar constant exposes the value it was built from.
    #[test]
    fn test_constant_scalar() {
        let expr = constant(5.0);
        match &expr {
            Expr::Constant(data) => assert_eq!(data.value.as_scalar(), Some(5.0)),
            _ => panic!("Expected Constant"),
        }
    }

    /// Three values produce a 3 x 1 expression.
    #[test]
    fn test_constant_vec() {
        let expr = constant_vec(vec![1.0, 2.0, 3.0]);
        let expected = Shape::matrix(3, 1);
        assert_eq!(expr.shape(), expected);
    }

    /// A zero matrix has the requested shape and is both non-negative and
    /// non-positive.
    #[test]
    fn test_zeros() {
        let expr = zeros((3, 4));
        assert_eq!(expr.shape(), Shape::matrix(3, 4));
        match &expr {
            Expr::Constant(data) => {
                assert!(data.value.is_nonneg());
                assert!(data.value.is_nonpos());
            }
            _ => panic!("Expected Constant"),
        }
    }

    /// A single dimension produces a 5 x 1 expression.
    #[test]
    fn test_ones() {
        let expr = ones(5);
        let expected = Shape::matrix(5, 1);
        assert_eq!(expr.shape(), expected);
    }

    /// The identity constant is square.
    #[test]
    fn test_eye() {
        let expr = eye(3);
        let expected = Shape::matrix(3, 3);
        assert_eq!(expr.shape(), expected);
    }

    /// The trait conversions type-check for a float literal and a vector.
    #[test]
    fn test_into_constant() {
        let _scalar: Expr = 5.0.into_constant();
        let _vector: Expr = vec![1.0, 2.0].into_constant();
    }
}