opensrdk_symbolic_computation/
constant_value.rs1use opensrdk_linear_algebra::{sparse::SparseTensor, Matrix, Tensor};
2use serde::{Deserialize, Serialize};
3
4use crate::Expression;
5
6#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
7pub enum ConstantValue {
8 Scalar(f64),
9 Tensor(SparseTensor),
10 Matrix(Matrix),
11}
12
13impl ConstantValue {
14 pub fn sizes(&self) -> Vec<usize> {
15 match self {
16 ConstantValue::Scalar(_) => vec![],
17 ConstantValue::Tensor(v) => {
18 (0..v.rank()).into_iter().map(|rank| v.size(rank)).collect()
19 }
20 ConstantValue::Matrix(v) => vec![v.rows(), v.cols()],
21 }
22 }
23
24 pub fn elems(&self) -> Vec<f64> {
25 match self {
26 ConstantValue::Scalar(v) => vec![*v],
27 ConstantValue::Tensor(v) => v.elems().into_iter().map(|(_, v)| *v).collect(),
28 ConstantValue::Matrix(v) => v.elems().to_vec(),
29 }
30 }
31
32 pub fn elems_mut(&mut self) -> Vec<&mut f64> {
33 match self {
34 ConstantValue::Scalar(v) => vec![v],
35 ConstantValue::Tensor(v) => v.elems_mut().into_iter().map(|(_, v)| v).collect(),
36 ConstantValue::Matrix(v) => v.elems_mut().iter_mut().collect(),
37 }
38 }
39
40 pub fn into_scalar(&self) -> f64 {
41 if let ConstantValue::Scalar(v) = self {
42 *v
43 } else {
44 panic!()
45 }
46 }
47
48 pub fn into_tensor(self) -> SparseTensor {
49 if let ConstantValue::Tensor(v) = self {
50 v
51 } else {
52 panic!()
53 }
54 }
55
56 pub fn into_tensor_ref(&self) -> &SparseTensor {
57 if let ConstantValue::Tensor(v) = self {
58 v
59 } else {
60 panic!()
61 }
62 }
63
64 pub fn into_matrix(self) -> Matrix {
65 if let ConstantValue::Matrix(v) = self {
66 v
67 } else {
68 panic!()
69 }
70 }
71
72 pub fn into_matrix_ref(&self) -> &Matrix {
73 if let ConstantValue::Matrix(v) = self {
74 v
75 } else {
76 panic!()
77 }
78 }
79}
80
81impl ConstantValue {
100 pub fn add(&self, rhs: ConstantValue) -> ConstantValue {
101 match (self, rhs) {
102 (ConstantValue::Scalar(lhs), ConstantValue::Scalar(rhs)) => {
103 ConstantValue::Scalar(lhs + rhs)
104 }
105 (ConstantValue::Scalar(lhs), ConstantValue::Tensor(rhs)) => {
106 ConstantValue::Tensor(lhs + rhs)
107 }
108 (ConstantValue::Scalar(lhs), ConstantValue::Matrix(rhs)) => {
109 ConstantValue::Matrix(lhs + rhs)
110 }
111 (ConstantValue::Tensor(lhs), ConstantValue::Tensor(rhs)) => {
112 ConstantValue::Tensor(lhs + rhs)
113 }
114 (ConstantValue::Tensor(lhs), ConstantValue::Scalar(rhs)) => {
115 ConstantValue::Tensor(lhs.clone() + rhs)
116 }
117 (ConstantValue::Matrix(lhs), ConstantValue::Matrix(rhs)) => {
118 ConstantValue::Matrix(lhs + rhs)
119 }
120 (ConstantValue::Matrix(lhs), ConstantValue::Scalar(rhs)) => {
121 ConstantValue::Matrix(lhs.clone() + rhs)
122 }
123 _ => panic!(),
124 }
125 }
126
127 pub fn sub(&self, rhs: ConstantValue) -> ConstantValue {
128 match (self, rhs) {
129 (ConstantValue::Scalar(lhs), ConstantValue::Scalar(rhs)) => {
130 ConstantValue::Scalar(lhs - rhs)
131 }
132 (ConstantValue::Scalar(lhs), ConstantValue::Tensor(rhs)) => {
133 ConstantValue::Tensor(lhs - rhs)
134 }
135 (ConstantValue::Scalar(lhs), ConstantValue::Matrix(rhs)) => {
136 ConstantValue::Matrix(lhs - rhs)
137 }
138 (ConstantValue::Tensor(lhs), ConstantValue::Tensor(rhs)) => {
139 ConstantValue::Tensor(lhs - rhs)
140 }
141 (ConstantValue::Tensor(lhs), ConstantValue::Scalar(rhs)) => {
142 ConstantValue::Tensor(lhs.clone() - rhs)
143 }
144 (ConstantValue::Matrix(lhs), ConstantValue::Matrix(rhs)) => {
145 ConstantValue::Matrix(lhs - rhs)
146 }
147 (ConstantValue::Matrix(lhs), ConstantValue::Scalar(rhs)) => {
148 ConstantValue::Matrix(lhs.clone() - rhs)
149 }
150 _ => panic!(),
151 }
152 }
153
154 pub fn mul(&self, rhs: ConstantValue) -> ConstantValue {
155 match (self, rhs) {
156 (ConstantValue::Scalar(lhs), ConstantValue::Scalar(rhs)) => {
157 ConstantValue::Scalar(lhs * rhs)
158 }
159 (ConstantValue::Scalar(lhs), ConstantValue::Tensor(rhs)) => {
160 ConstantValue::Tensor(lhs * rhs)
161 }
162 (ConstantValue::Scalar(lhs), ConstantValue::Matrix(rhs)) => {
163 ConstantValue::Matrix(lhs * rhs)
164 }
165 (ConstantValue::Tensor(lhs), ConstantValue::Tensor(rhs)) => {
166 ConstantValue::Tensor(lhs * rhs)
167 }
168 (ConstantValue::Tensor(lhs), ConstantValue::Scalar(rhs)) => {
169 ConstantValue::Tensor(lhs.clone() * rhs)
170 }
171 (ConstantValue::Matrix(lhs), ConstantValue::Matrix(rhs)) => {
172 ConstantValue::Matrix(lhs * rhs)
173 }
174 (ConstantValue::Matrix(lhs), ConstantValue::Scalar(rhs)) => {
175 ConstantValue::Matrix(lhs.clone() * rhs)
176 }
177 _ => panic!(),
178 }
179 }
180
181 pub fn div(self, rhs: &ConstantValue) -> ConstantValue {
182 match (self, rhs) {
183 (ConstantValue::Scalar(lhs), ConstantValue::Scalar(rhs)) => {
184 ConstantValue::Scalar(lhs / rhs)
185 }
186 (ConstantValue::Scalar(lhs), ConstantValue::Tensor(rhs)) => {
187 ConstantValue::Tensor(lhs / rhs.clone())
188 }
189 (ConstantValue::Scalar(lhs), ConstantValue::Matrix(rhs)) => {
190 ConstantValue::Matrix(lhs / rhs.clone())
191 }
192 (ConstantValue::Tensor(lhs), ConstantValue::Tensor(rhs)) => {
193 ConstantValue::Tensor(lhs / rhs)
194 }
195 (ConstantValue::Tensor(lhs), ConstantValue::Scalar(rhs)) => {
196 ConstantValue::Tensor(lhs / rhs)
197 }
198 (ConstantValue::Matrix(lhs), ConstantValue::Matrix(rhs)) => {
199 ConstantValue::Matrix(lhs / rhs)
200 }
201 (ConstantValue::Matrix(lhs), ConstantValue::Scalar(rhs)) => {
202 ConstantValue::Matrix(lhs / rhs)
203 }
204 _ => panic!(),
205 }
206 }
207}