1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
pub mod assign;
pub mod differential;
pub mod matrix_expression;
pub mod operators;
pub mod partial_variable;
pub mod size;
pub mod tensor_expression;
pub mod tex_code;
pub mod transcendental_expression;
pub mod variable;

pub use assign::*;
pub use differential::*;
pub use matrix_expression::*;
use opensrdk_linear_algebra::{sparse::SparseTensor, Matrix};
pub use partial_variable::*;
pub use size::*;
pub use tensor_expression::*;
pub use tex_code::*;
pub use transcendental_expression::*;
pub use variable::*;

use crate::{ConstantValue, ExpressionArray};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Symbolic expression tree.
///
/// Leaves are named variables, partial variables and constant values; interior nodes are
/// arithmetic operators and boxed transcendental / tensor / matrix sub-expressions.
///
/// NOTE(review): the meaning of each variant below is read off its name and payload; the
/// actual evaluation / differentiation rules live in sibling modules (`operators`,
/// `differential`, ...) that are not visible here.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum Expression {
    /// Named variable together with a list of [`Size`]s (presumably its per-dimension
    /// sizes — confirm against the `size` module).
    Variable(String, Vec<Size>),
    /// Literal constant; a [`ConstantValue`] is a scalar, a matrix or a tensor.
    Constant(ConstantValue),
    /// Variable built from an [`ExpressionArray`] (semantics defined in `partial_variable`).
    PartialVariable(ExpressionArray),
    /// Addition node over two boxed sub-expressions (left, right).
    Add(Box<Expression>, Box<Expression>),
    /// Subtraction node over two boxed sub-expressions (left, right).
    Sub(Box<Expression>, Box<Expression>),
    /// Multiplication node over two boxed sub-expressions (left, right).
    Mul(Box<Expression>, Box<Expression>),
    /// Division node over two boxed sub-expressions (left, right).
    Div(Box<Expression>, Box<Expression>),
    /// Negation of a single boxed sub-expression.
    Neg(Box<Expression>),
    /// Boxed [`TranscendentalExpression`] sub-tree.
    Transcendental(Box<TranscendentalExpression>),
    /// Boxed [`TensorExpression`] sub-tree.
    Tensor(Box<TensorExpression>),
    /// Boxed [`MatrixExpression`] sub-tree.
    Matrix(Box<MatrixExpression>),
}

impl From<f64> for Expression {
    /// Lifts a plain `f64` into a scalar constant node.
    fn from(v: f64) -> Self {
        let scalar = ConstantValue::Scalar(v);
        Self::Constant(scalar)
    }
}

impl From<Vec<f64>> for Expression {
    fn from(v: Vec<f64>) -> Self {
        Expression::Constant(ConstantValue::Tensor(v.into()))
    }
}

impl From<SparseTensor> for Expression {
    /// Lifts a [`SparseTensor`] into a tensor constant node, taking ownership of it.
    fn from(v: SparseTensor) -> Self {
        let value = ConstantValue::Tensor(v);
        Self::Constant(value)
    }
}

impl From<Matrix> for Expression {
    /// Lifts a [`Matrix`] into a matrix constant node, taking ownership of it.
    fn from(v: Matrix) -> Self {
        let value = ConstantValue::Matrix(v);
        Self::Constant(value)
    }
}

impl From<ConstantValue> for Expression {
    fn from(v: ConstantValue) -> Self {
        match v {
            ConstantValue::Scalar(v) => v.into(),
            ConstantValue::Matrix(v) => v.into(),
            ConstantValue::Tensor(v) => v.into(),
        }
    }
}

impl From<Expression> for ConstantValue {
    /// Extracts the value held by an [`Expression::Constant`] node.
    ///
    /// No evaluation is performed: only an expression that is already a constant can be
    /// converted.
    ///
    /// # Panics
    ///
    /// Panics if `v` is any other variant (variable, partial variable, arithmetic node,
    /// transcendental / tensor / matrix sub-expression). The panic message includes the
    /// offending expression. The non-constant variants are listed explicitly rather than
    /// with a `_` catch-all so that adding a new `Expression` variant forces this
    /// conversion to be revisited.
    fn from(v: Expression) -> Self {
        match v {
            Expression::Constant(a) => a,
            Expression::Variable(..)
            | Expression::PartialVariable(_)
            | Expression::Add(..)
            | Expression::Sub(..)
            | Expression::Mul(..)
            | Expression::Div(..)
            | Expression::Neg(_)
            | Expression::Transcendental(_)
            | Expression::Tensor(_)
            | Expression::Matrix(_) => panic!(
                "cannot convert a non-constant Expression into a ConstantValue: {:?}",
                v
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::{ConstantValue, Expression};

    /// Round-trips scalar, dense-vector and sparse-tensor constants through `Expression`
    /// and back, asserting that no value is lost in either direction.
    #[test]
    fn it_works() {
        let a = 5.0f64;
        let b = vec![a; 8];
        let mut hash = HashMap::new();
        hash.insert(vec![3usize; 8], 2.0);
        hash.insert(vec![1usize; 8], 3.0);
        hash.insert(vec![4usize; 8], 4.0);
        hash.insert(vec![5usize; 8], 2.0);
        let c = SparseTensor::from(vec![6usize; 8], hash).unwrap();

        let ea = Expression::from(a);
        let eb = Expression::from(b.clone());
        let ec = Expression::from(c);

        // Every constructor must produce a `Constant` node of the matching kind.
        assert!(matches!(ea, Expression::Constant(ConstantValue::Scalar(_))));
        assert!(matches!(eb, Expression::Constant(ConstantValue::Tensor(_))));
        assert!(matches!(ec, Expression::Constant(ConstantValue::Tensor(_))));

        // Keep copies so the round-tripped values can be compared with the originals.
        let (ea_orig, eb_orig, ec_orig) = (ea.clone(), eb.clone(), ec.clone());

        let a_rev: ConstantValue = ea.into();
        let b_rev: ConstantValue = eb.into();
        let c_rev: ConstantValue = ec.into();

        // Expression -> ConstantValue must return exactly the wrapped value.
        assert_eq!(a_rev, ConstantValue::Scalar(a));
        assert_eq!(b_rev, ConstantValue::Tensor(b.into()));

        // ConstantValue -> Expression must rebuild the original node.
        assert_eq!(Expression::from(a_rev), ea_orig);
        assert_eq!(Expression::from(b_rev), eb_orig);
        assert_eq!(Expression::from(c_rev), ec_orig);
    }

    /// Converting a non-constant expression into a `ConstantValue` is unsupported and panics.
    #[test]
    #[should_panic]
    fn non_constant_into_constant_value_panics() {
        let x = Expression::Variable("x".to_string(), vec![]);
        let _: ConstantValue = x.into();
    }
}