use std::sync::Arc;
use super::expression::{Expr, ExprId, VariableData};
use super::shape::Shape;
/// Fluent builder for variable expressions.
///
/// Start from [`VariableBuilder::new`] (or the `scalar`/`vector`/`matrix` shortcuts),
/// chain optional attributes, and finish with [`VariableBuilder::build`].
#[must_use = "a VariableBuilder does nothing until `.build()` is called"]
#[derive(Debug, Default)]
pub struct VariableBuilder {
    /// Shape of the variable that `build` will produce.
    shape: Shape,
    /// Optional name; `None` leaves the variable unnamed.
    name: Option<String>,
    /// Whether the variable is flagged as nonnegative.
    nonneg: bool,
    /// Whether the variable is flagged as nonpositive.
    nonpos: bool,
}
impl VariableBuilder {
    /// Starts a builder for a variable of the given shape.
    ///
    /// All other attributes take their `Default` values: no name, and neither the
    /// nonnegative nor the nonpositive flag set.
    pub fn new(shape: impl Into<Shape>) -> Self {
        Self {
            shape: shape.into(),
            ..Default::default()
        }
    }

    /// Starts a builder whose shape is `Shape::scalar()`.
    pub fn scalar() -> Self {
        Self::new(Shape::scalar())
    }

    /// Starts a builder whose shape is `Shape::vector(n)`.
    pub fn vector(n: usize) -> Self {
        Self::new(Shape::vector(n))
    }

    /// Starts a builder whose shape is `Shape::matrix(m, n)`.
    pub fn matrix(m: usize, n: usize) -> Self {
        Self::new(Shape::matrix(m, n))
    }

    /// Sets the variable's name, replacing any name set earlier in the chain.
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Flags the variable as nonnegative.
    ///
    /// The two sign flags are mutually exclusive: this clears any earlier
    /// `nonpos()` call, so the last sign setter in a chain wins.
    pub fn nonneg(mut self) -> Self {
        self.nonneg = true;
        self.nonpos = false;
        self
    }

    /// Flags the variable as nonpositive.
    ///
    /// The two sign flags are mutually exclusive: this clears any earlier
    /// `nonneg()` call, so the last sign setter in a chain wins.
    pub fn nonpos(mut self) -> Self {
        self.nonpos = true;
        self.nonneg = false;
        self
    }

    /// Consumes the builder and produces an `Expr::Variable`.
    ///
    /// Each call obtains a new id from `ExprId::new()`. NOTE(review): presumably
    /// ids are unique per call; confirm in the `expression` module.
    #[must_use = "the built variable expression is returned and should be used"]
    pub fn build(self) -> Expr {
        Expr::Variable(VariableData {
            id: ExprId::new(),
            shape: self.shape,
            name: self.name,
            nonneg: self.nonneg,
            nonpos: self.nonpos,
        })
    }
}
/// Creates an unnamed variable of the given shape with no sign flags set.
pub fn variable(shape: impl Into<Shape>) -> Expr {
    let builder = VariableBuilder::new(shape);
    builder.build()
}
/// Post-construction combinators for setting attributes on variable expressions.
///
/// Every method consumes the expression and returns the updated one. For the
/// `Expr` implementation in this module, expressions that are not
/// `Expr::Variable` are returned unchanged.
pub trait VariableExt {
    /// Flags the variable as nonnegative, clearing the nonpositive flag.
    #[must_use = "the updated expression is returned; dropping it discards the change"]
    fn nonneg(self) -> Expr;

    /// Flags the variable as nonpositive, clearing the nonnegative flag.
    #[must_use = "the updated expression is returned; dropping it discards the change"]
    fn nonpos(self) -> Expr;

    /// Sets (or replaces) the variable's name.
    #[must_use = "the updated expression is returned; dropping it discards the change"]
    fn named(self, name: impl Into<String>) -> Expr;
}
impl VariableExt for Expr {
fn nonneg(self) -> Expr {
match self {
Expr::Variable(mut v) => {
v.nonneg = true;
v.nonpos = false;
Expr::Variable(v)
}
other => other,
}
}
fn nonpos(self) -> Expr {
match self {
Expr::Variable(mut v) => {
v.nonpos = true;
v.nonneg = false;
Expr::Variable(v)
}
other => other,
}
}
fn named(self, name: impl Into<String>) -> Expr {
match self {
Expr::Variable(mut v) => {
v.name = Some(name.into());
Expr::Variable(v)
}
other => other,
}
}
}
/// Creates a variable of the given shape carrying the given name.
pub fn named_variable(name: impl Into<String>, shape: impl Into<Shape>) -> Expr {
    let builder = VariableBuilder::new(shape);
    builder.name(name).build()
}
/// Creates an unnamed variable of the given shape flagged as nonnegative.
pub fn nonneg_variable(shape: impl Into<Shape>) -> Expr {
    let builder = VariableBuilder::new(shape).nonneg();
    builder.build()
}
/// Creates an unnamed variable of the given shape flagged as nonpositive.
pub fn nonpos_variable(shape: impl Into<Shape>) -> Expr {
    let builder = VariableBuilder::new(shape).nonpos();
    builder.build()
}
pub fn scalar_var() -> Expr {
VariableBuilder::scalar().build()
}
pub fn vector_var(n: usize) -> Expr {
VariableBuilder::vector(n).build()
}
pub fn matrix_var(m: usize, n: usize) -> Expr {
VariableBuilder::matrix(m, n).build()
}
/// Creates an unnamed variable of the given shape, wrapped in an [`Arc`] so it
/// can be shared between several expression trees.
pub fn var(shape: impl Into<Shape>) -> Arc<Expr> {
    Arc::from(variable(shape))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows the `VariableData` inside `expr`, panicking if it is any other variant.
    fn as_variable(expr: &Expr) -> &VariableData {
        match expr {
            Expr::Variable(v) => v,
            _ => panic!("Expected Variable"),
        }
    }

    #[test]
    fn test_variable_builder() {
        let x = VariableBuilder::vector(5).name("x").nonneg().build();
        let v = as_variable(&x);
        assert_eq!(v.shape, Shape::vector(5));
        assert_eq!(v.name, Some("x".to_string()));
        assert!(v.nonneg);
        assert!(!v.nonpos);
    }

    #[test]
    fn test_builder_defaults_are_unconstrained() {
        let x = VariableBuilder::scalar().build();
        let v = as_variable(&x);
        assert_eq!(v.shape, Shape::scalar());
        assert_eq!(v.name, None);
        assert!(!v.nonneg);
        assert!(!v.nonpos);
    }

    #[test]
    fn test_builder_sign_flags_last_wins() {
        let x = VariableBuilder::scalar().nonneg().nonpos().build();
        let v = as_variable(&x);
        assert!(v.nonpos);
        assert!(!v.nonneg);

        let y = VariableBuilder::scalar().nonpos().nonneg().build();
        let v = as_variable(&y);
        assert!(v.nonneg);
        assert!(!v.nonpos);
    }

    #[test]
    fn test_variable_function() {
        let x = variable((3, 4));
        assert_eq!(x.shape(), Shape::matrix(3, 4));
    }

    #[test]
    fn test_variable_ext() {
        let x = variable(5).nonneg().named("x");
        let v = as_variable(&x);
        assert!(v.nonneg);
        assert_eq!(v.name, Some("x".to_string()));
    }

    #[test]
    fn test_variable_ext_sign_flags_last_wins() {
        let x = scalar_var().nonneg().nonpos();
        let v = as_variable(&x);
        assert!(v.nonpos);
        assert!(!v.nonneg);

        let y = scalar_var().nonpos().nonneg();
        let v = as_variable(&y);
        assert!(v.nonneg);
        assert!(!v.nonpos);
    }

    #[test]
    fn test_convenience_functions() {
        assert_eq!(scalar_var().shape(), Shape::scalar());
        assert_eq!(vector_var(5).shape(), Shape::vector(5));
        assert_eq!(matrix_var(3, 4).shape(), Shape::matrix(3, 4));
    }

    #[test]
    fn test_attribute_convenience_functions() {
        let x = named_variable("y", (2, 3));
        let v = as_variable(&x);
        assert_eq!(v.name, Some("y".to_string()));
        assert_eq!(v.shape, Shape::matrix(2, 3));
        assert!(!v.nonneg);
        assert!(!v.nonpos);

        let n = nonneg_variable((2, 3));
        let v = as_variable(&n);
        assert!(v.nonneg);
        assert!(!v.nonpos);

        let p = nonpos_variable((2, 3));
        let v = as_variable(&p);
        assert!(v.nonpos);
        assert!(!v.nonneg);
    }

    #[test]
    fn test_var_returns_shared_variable() {
        let x = var((3, 4));
        assert_eq!(x.shape(), Shape::matrix(3, 4));
        assert!(matches!(*x, Expr::Variable(_)));
    }
}