// opensrdk_symbolic_computation/expression/mod.rs

1pub mod assign;
2pub mod differential;
3pub mod matrix_expression;
4pub mod operators;
5pub mod partial_variable;
6pub mod size;
7pub mod tensor_expression;
8pub mod tex_code;
9pub mod transcendental_expression;
10pub mod variable;
11
12pub use assign::*;
13pub use differential::*;
14pub use matrix_expression::*;
15use opensrdk_linear_algebra::{sparse::SparseTensor, Matrix};
16pub use partial_variable::*;
17pub use size::*;
18pub use tensor_expression::*;
19pub use tex_code::*;
20pub use transcendental_expression::*;
21pub use variable::*;
22
23use crate::{ConstantValue, ExpressionArray};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26
27#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
28pub enum Expression {
29    Variable(String, Vec<Size>),
30    Constant(ConstantValue),
31    PartialVariable(ExpressionArray),
32    Add(Box<Expression>, Box<Expression>),
33    Sub(Box<Expression>, Box<Expression>),
34    Mul(Box<Expression>, Box<Expression>),
35    Div(Box<Expression>, Box<Expression>),
36    Neg(Box<Expression>),
37    Transcendental(Box<TranscendentalExpression>),
38    Tensor(Box<TensorExpression>),
39    Matrix(Box<MatrixExpression>),
40}
41
42impl From<f64> for Expression {
43    fn from(v: f64) -> Self {
44        Expression::Constant(ConstantValue::Scalar(v))
45    }
46}
47
48impl From<Vec<f64>> for Expression {
49    fn from(v: Vec<f64>) -> Self {
50        Expression::Constant(ConstantValue::Tensor(v.into()))
51    }
52}
53
54impl From<SparseTensor> for Expression {
55    fn from(v: SparseTensor) -> Self {
56        Expression::Constant(ConstantValue::Tensor(v))
57    }
58}
59
60impl From<Matrix> for Expression {
61    fn from(v: Matrix) -> Self {
62        Expression::Constant(ConstantValue::Matrix(v))
63    }
64}
65
66impl From<ConstantValue> for Expression {
67    fn from(v: ConstantValue) -> Self {
68        match v {
69            ConstantValue::Scalar(v) => v.into(),
70            ConstantValue::Matrix(v) => v.into(),
71            ConstantValue::Tensor(v) => v.into(),
72        }
73    }
74}
75
76impl From<Expression> for ConstantValue {
77    fn from(v: Expression) -> Self {
78        match v {
79            Expression::Constant(a) => a,
80            Expression::Variable(_, _) => todo!(),
81            Expression::PartialVariable(_) => todo!(),
82            Expression::Add(_, _) => todo!(),
83            Expression::Sub(_, _) => todo!(),
84            Expression::Mul(_, _) => todo!(),
85            Expression::Div(_, _) => todo!(),
86            Expression::Neg(_) => todo!(),
87            Expression::Transcendental(_) => todo!(),
88            Expression::Tensor(_) => todo!(),
89            Expression::Matrix(_) => todo!(),
90        }
91    }
92}
93
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::{ConstantValue, Expression};

    /// Converts `expression` into a `ConstantValue` and back again, asserting
    /// that the round trip reproduces the original expression.
    fn assert_round_trip(expression: Expression) {
        let constant: ConstantValue = expression.clone().into();
        assert_eq!(Expression::from(constant), expression);
    }

    #[test]
    fn it_works() {
        let a = 5.0f64;
        let b = vec![a; 8];
        let mut hash = HashMap::new();
        hash.insert(vec![3usize; 8], 2.0);
        hash.insert(vec![1usize; 8], 3.0);
        hash.insert(vec![4usize; 8], 4.0);
        hash.insert(vec![5usize; 8], 2.0);
        let c = SparseTensor::from(vec![6usize; 8], hash).unwrap();

        let ea = Expression::from(a);
        let eb = Expression::from(b);
        let ec = Expression::from(c);

        // Each constructor must produce a `Constant` node of the expected kind.
        assert_eq!(ea, Expression::Constant(ConstantValue::Scalar(a)));
        assert!(matches!(eb, Expression::Constant(ConstantValue::Tensor(_))));
        assert!(matches!(ec, Expression::Constant(ConstantValue::Tensor(_))));

        // Expression -> ConstantValue -> Expression must be lossless.
        assert_round_trip(ea);
        assert_round_trip(eb);
        assert_round_trip(ec);
    }

    #[test]
    fn constant_value_into_expression_wraps_directly() {
        assert_eq!(
            Expression::from(ConstantValue::Scalar(1.5)),
            Expression::Constant(ConstantValue::Scalar(1.5))
        );
    }

    #[test]
    #[should_panic]
    fn non_constant_expression_into_constant_value_panics() {
        let negated = Expression::Neg(Box::new(Expression::from(1.0)));
        let _: ConstantValue = negated.into();
    }
}