opensrdk_symbolic_computation/
constant_value.rs

1use opensrdk_linear_algebra::{sparse::SparseTensor, Matrix, Tensor};
2use serde::{Deserialize, Serialize};
3
4use crate::Expression;
5
6#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
7pub enum ConstantValue {
8    Scalar(f64),
9    Tensor(SparseTensor),
10    Matrix(Matrix),
11}
12
13impl ConstantValue {
14    pub fn sizes(&self) -> Vec<usize> {
15        match self {
16            ConstantValue::Scalar(_) => vec![],
17            ConstantValue::Tensor(v) => {
18                (0..v.rank()).into_iter().map(|rank| v.size(rank)).collect()
19            }
20            ConstantValue::Matrix(v) => vec![v.rows(), v.cols()],
21        }
22    }
23
24    pub fn elems(&self) -> Vec<f64> {
25        match self {
26            ConstantValue::Scalar(v) => vec![*v],
27            ConstantValue::Tensor(v) => v.elems().into_iter().map(|(_, v)| *v).collect(),
28            ConstantValue::Matrix(v) => v.elems().to_vec(),
29        }
30    }
31
32    pub fn elems_mut(&mut self) -> Vec<&mut f64> {
33        match self {
34            ConstantValue::Scalar(v) => vec![v],
35            ConstantValue::Tensor(v) => v.elems_mut().into_iter().map(|(_, v)| v).collect(),
36            ConstantValue::Matrix(v) => v.elems_mut().iter_mut().collect(),
37        }
38    }
39
40    pub fn into_scalar(&self) -> f64 {
41        if let ConstantValue::Scalar(v) = self {
42            *v
43        } else {
44            panic!()
45        }
46    }
47
48    pub fn into_tensor(self) -> SparseTensor {
49        if let ConstantValue::Tensor(v) = self {
50            v
51        } else {
52            panic!()
53        }
54    }
55
56    pub fn into_tensor_ref(&self) -> &SparseTensor {
57        if let ConstantValue::Tensor(v) = self {
58            v
59        } else {
60            panic!()
61        }
62    }
63
64    pub fn into_matrix(self) -> Matrix {
65        if let ConstantValue::Matrix(v) = self {
66            v
67        } else {
68            panic!()
69        }
70    }
71
72    pub fn into_matrix_ref(&self) -> &Matrix {
73        if let ConstantValue::Matrix(v) = self {
74            v
75        } else {
76            panic!()
77        }
78    }
79}
80
81// impl From<Expression> for ConstantValue {
82//     fn from(v: Expression) -> Self {
83//         match v {
84//             Expression::Constant(a) => a,
85//             Expression::Variable(_, _) => todo!(),
86//             Expression::PartialVariable(_) => todo!(),
87//             Expression::Add(_, _) => todo!(),
88//             Expression::Sub(_, _) => todo!(),
89//             Expression::Mul(_, _) => todo!(),
90//             Expression::Div(_, _) => todo!(),
91//             Expression::Neg(_) => todo!(),
92//             Expression::Transcendental(_) => todo!(),
93//             Expression::Tensor(_) => todo!(),
94//             Expression::Matrix(_) => todo!(),
95//         }
96//     }
97// }
98
99impl ConstantValue {
100    pub fn add(&self, rhs: ConstantValue) -> ConstantValue {
101        match (self, rhs) {
102            (ConstantValue::Scalar(lhs), ConstantValue::Scalar(rhs)) => {
103                ConstantValue::Scalar(lhs + rhs)
104            }
105            (ConstantValue::Scalar(lhs), ConstantValue::Tensor(rhs)) => {
106                ConstantValue::Tensor(lhs + rhs)
107            }
108            (ConstantValue::Scalar(lhs), ConstantValue::Matrix(rhs)) => {
109                ConstantValue::Matrix(lhs + rhs)
110            }
111            (ConstantValue::Tensor(lhs), ConstantValue::Tensor(rhs)) => {
112                ConstantValue::Tensor(lhs + rhs)
113            }
114            (ConstantValue::Tensor(lhs), ConstantValue::Scalar(rhs)) => {
115                ConstantValue::Tensor(lhs.clone() + rhs)
116            }
117            (ConstantValue::Matrix(lhs), ConstantValue::Matrix(rhs)) => {
118                ConstantValue::Matrix(lhs + rhs)
119            }
120            (ConstantValue::Matrix(lhs), ConstantValue::Scalar(rhs)) => {
121                ConstantValue::Matrix(lhs.clone() + rhs)
122            }
123            _ => panic!(),
124        }
125    }
126
127    pub fn sub(&self, rhs: ConstantValue) -> ConstantValue {
128        match (self, rhs) {
129            (ConstantValue::Scalar(lhs), ConstantValue::Scalar(rhs)) => {
130                ConstantValue::Scalar(lhs - rhs)
131            }
132            (ConstantValue::Scalar(lhs), ConstantValue::Tensor(rhs)) => {
133                ConstantValue::Tensor(lhs - rhs)
134            }
135            (ConstantValue::Scalar(lhs), ConstantValue::Matrix(rhs)) => {
136                ConstantValue::Matrix(lhs - rhs)
137            }
138            (ConstantValue::Tensor(lhs), ConstantValue::Tensor(rhs)) => {
139                ConstantValue::Tensor(lhs - rhs)
140            }
141            (ConstantValue::Tensor(lhs), ConstantValue::Scalar(rhs)) => {
142                ConstantValue::Tensor(lhs.clone() - rhs)
143            }
144            (ConstantValue::Matrix(lhs), ConstantValue::Matrix(rhs)) => {
145                ConstantValue::Matrix(lhs - rhs)
146            }
147            (ConstantValue::Matrix(lhs), ConstantValue::Scalar(rhs)) => {
148                ConstantValue::Matrix(lhs.clone() - rhs)
149            }
150            _ => panic!(),
151        }
152    }
153
154    pub fn mul(&self, rhs: ConstantValue) -> ConstantValue {
155        match (self, rhs) {
156            (ConstantValue::Scalar(lhs), ConstantValue::Scalar(rhs)) => {
157                ConstantValue::Scalar(lhs * rhs)
158            }
159            (ConstantValue::Scalar(lhs), ConstantValue::Tensor(rhs)) => {
160                ConstantValue::Tensor(lhs * rhs)
161            }
162            (ConstantValue::Scalar(lhs), ConstantValue::Matrix(rhs)) => {
163                ConstantValue::Matrix(lhs * rhs)
164            }
165            (ConstantValue::Tensor(lhs), ConstantValue::Tensor(rhs)) => {
166                ConstantValue::Tensor(lhs * rhs)
167            }
168            (ConstantValue::Tensor(lhs), ConstantValue::Scalar(rhs)) => {
169                ConstantValue::Tensor(lhs.clone() * rhs)
170            }
171            (ConstantValue::Matrix(lhs), ConstantValue::Matrix(rhs)) => {
172                ConstantValue::Matrix(lhs * rhs)
173            }
174            (ConstantValue::Matrix(lhs), ConstantValue::Scalar(rhs)) => {
175                ConstantValue::Matrix(lhs.clone() * rhs)
176            }
177            _ => panic!(),
178        }
179    }
180
181    pub fn div(self, rhs: &ConstantValue) -> ConstantValue {
182        match (self, rhs) {
183            (ConstantValue::Scalar(lhs), ConstantValue::Scalar(rhs)) => {
184                ConstantValue::Scalar(lhs / rhs)
185            }
186            (ConstantValue::Scalar(lhs), ConstantValue::Tensor(rhs)) => {
187                ConstantValue::Tensor(lhs / rhs.clone())
188            }
189            (ConstantValue::Scalar(lhs), ConstantValue::Matrix(rhs)) => {
190                ConstantValue::Matrix(lhs / rhs.clone())
191            }
192            (ConstantValue::Tensor(lhs), ConstantValue::Tensor(rhs)) => {
193                ConstantValue::Tensor(lhs / rhs)
194            }
195            (ConstantValue::Tensor(lhs), ConstantValue::Scalar(rhs)) => {
196                ConstantValue::Tensor(lhs / rhs)
197            }
198            (ConstantValue::Matrix(lhs), ConstantValue::Matrix(rhs)) => {
199                ConstantValue::Matrix(lhs / rhs)
200            }
201            (ConstantValue::Matrix(lhs), ConstantValue::Scalar(rhs)) => {
202                ConstantValue::Matrix(lhs / rhs)
203            }
204            _ => panic!(),
205        }
206    }
207}