1use std::fmt;
4
5use crate::sym_expr::SymExpr;
6
/// A tensor whose values are all known integers.
///
/// Only scalars (0D) and vectors (1D) can be represented.
#[derive(Clone, Eq, Hash, PartialEq)]
pub enum Constant {
    /// A single integer value.
    Scalar(i32),
    /// A one-dimensional sequence of integer values.
    Vector(Vec<i32>),
}

impl Constant {
    /// Returns the number of dimensions: `1` for a vector, `0` for a scalar.
    pub fn ndim(&self) -> usize {
        match self {
            Self::Vector(_) => 1,
            Self::Scalar(_) => 0,
        }
    }

    /// Borrows the values as a slice. A scalar yields a one-element slice.
    pub fn values(&self) -> &[i32] {
        match self {
            Self::Vector(items) => items.as_slice(),
            Self::Scalar(value) => std::slice::from_ref(value),
        }
    }

    /// Consumes the constant and returns its values as a `Vec`.
    ///
    /// A scalar becomes a one-element vector.
    pub fn into_vec(self) -> Vec<i32> {
        match self {
            Self::Vector(items) => items,
            Self::Scalar(value) => Vec::from([value]),
        }
    }
}

impl fmt::Debug for Constant {
    /// Formats a scalar as a bare number (`3`) and a vector as a list (`[1, 2]`).
    ///
    /// Formatter flags such as width or `{:#?}` are not forwarded; the output
    /// is always the compact form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Vector(items) => write!(f, "{items:?}"),
            Self::Scalar(value) => write!(f, "{value}"),
        }
    }
}
45
46#[derive(Clone, Debug, PartialEq)]
47enum SymTensorKind {
48 Scalar(SymExpr),
49 Vector(Vec<SymExpr>),
50 Shape(Vec<SymExpr>),
51 Unknown {
52 note: &'static str,
54 },
55}
56
57#[derive(Clone, Debug, PartialEq)]
91pub struct SymTensor(SymTensorKind);
92
93impl SymTensor {
94 pub fn unknown(note: &'static str) -> Self {
99 Self(SymTensorKind::Unknown { note })
100 }
101
102 pub fn from_shape(shape: Vec<SymExpr>) -> Self {
104 Self(SymTensorKind::Shape(shape))
105 }
106
107 pub fn from_fixed_shape(shape: &[usize]) -> Self {
109 Self(SymTensorKind::Shape(
110 shape
111 .iter()
112 .copied()
113 .map(|size| SymExpr::Value(size as i32))
114 .collect(),
115 ))
116 }
117
118 pub fn from_vec(vec: Vec<SymExpr>) -> Self {
120 Self(SymTensorKind::Vector(vec))
121 }
122
123 pub fn from_scalar(item: SymExpr) -> Self {
125 Self(SymTensorKind::Scalar(item))
126 }
127
128 pub fn as_scalar(&self) -> Option<&SymExpr> {
130 match &self.0 {
131 SymTensorKind::Scalar(item) => Some(item),
132 _ => None,
133 }
134 }
135
136 pub fn as_vector(&self) -> Option<&[SymExpr]> {
138 match &self.0 {
139 SymTensorKind::Vector(vec) => Some(vec),
140 _ => None,
141 }
142 }
143
144 pub fn to_constant(&self) -> Option<Constant> {
147 match &self.0 {
148 SymTensorKind::Scalar(val) => match val {
149 SymExpr::Value(v) => Some(Constant::Scalar(*v)),
150 _ => None,
151 },
152 SymTensorKind::Vector(vec) => {
153 let values = vec
154 .iter()
155 .map(|v| match v {
156 SymExpr::Value(v) => Some(*v),
157 _ => None,
158 })
159 .collect::<Option<Vec<i32>>>()?;
160 Some(Constant::Vector(values))
161 }
162 SymTensorKind::Shape(_) | SymTensorKind::Unknown { .. } => None,
163 }
164 }
165
166 pub fn ndim(&self) -> Option<usize> {
168 match &self.0 {
169 SymTensorKind::Scalar(_) => Some(0),
170 SymTensorKind::Vector(_) => Some(1),
171 SymTensorKind::Shape(val) => Some(val.len()),
172 SymTensorKind::Unknown { .. } => None,
173 }
174 }
175
176 pub fn size(&self, index: usize) -> Option<SymExpr> {
181 match &self.0 {
182 SymTensorKind::Scalar(_) => None,
183 SymTensorKind::Vector(val) => {
184 if index == 0 {
185 Some(SymExpr::Value(val.len() as i32))
186 } else {
187 None
188 }
189 }
190 SymTensorKind::Shape(val) => val.get(index).cloned(),
191 SymTensorKind::Unknown { .. } => None,
192 }
193 }
194
195 pub fn shape(&self) -> Option<impl ExactSizeIterator<Item = SymExpr> + Clone> {
197 let ndim = self.ndim()?;
198 let dims = (0..ndim).map(|d| self.size(d).unwrap());
199 Some(dims)
200 }
201
202 pub fn values(&self) -> Option<&[SymExpr]> {
204 match &self.0 {
205 SymTensorKind::Scalar(item) => Some(std::slice::from_ref(item)),
206 SymTensorKind::Vector(val) => Some(val),
207 SymTensorKind::Shape(_) | SymTensorKind::Unknown { .. } => None,
208 }
209 }
210
211 pub fn simplify(self) -> Self {
215 match self.0 {
216 SymTensorKind::Scalar(item) => Self::from_scalar(item.simplify()),
217 SymTensorKind::Vector(vec) => {
218 Self::from_vec(vec.into_iter().map(|x| x.simplify()).collect())
219 }
220 SymTensorKind::Shape(shape) => {
221 Self::from_shape(shape.into_iter().map(|d| d.simplify()).collect())
222 }
223 _ => self,
224 }
225 }
226}
227
228#[cfg(test)]
229pub(crate) use tests::{sym_elems, sym_shape, sym_vec};
230
#[cfg(test)]
mod tests {
    use super::{Constant, SymExpr, SymTensor};

    /// Builds a `Vec<SymExpr>`, converting each argument with `SymExpr::from`.
    macro_rules! sym_elems {
        ($($x:expr),* $(,)?) => {
            vec![$(SymExpr::from($x)),*]
        };
    }

    /// Builds a 1D `SymTensor` whose elements are converted with `SymExpr::from`.
    macro_rules! sym_vec {
        ($($x:expr),* $(,)?) => {
            SymTensor::from_vec(vec![$(SymExpr::from($x)),*])
        };
    }

    /// Builds a shape-only `SymTensor` whose dimension sizes are converted
    /// with `SymExpr::from`.
    macro_rules! sym_shape {
        ($($x:expr),* $(,)?) => {
            SymTensor::from_shape(vec![$(SymExpr::from($x)),*])
        };
    }

    pub(crate) use {sym_elems, sym_shape, sym_vec};

    #[test]
    fn test_scalar() {
        let x = SymTensor::from_scalar("x".into());
        assert_eq!(x.ndim(), Some(0));
        assert_eq!(x.size(0), None);
        assert_eq!(x.values(), Some(["x".into()].as_slice()));
        assert_eq!(x.shape().unwrap().len(), 0);
    }

    #[test]
    fn test_vector() {
        let x = SymTensor::from_vec(vec!["x".into(), 2.into()]);
        assert_eq!(x.ndim(), Some(1));
        assert_eq!(x.size(0), Some(2.into()));
        assert_eq!(x.size(1), None);
        assert_eq!(x.values(), Some(["x".into(), 2.into()].as_slice()));
        assert_eq!(x.shape().unwrap().collect::<Vec<_>>(), sym_elems![2]);
    }

    #[test]
    fn test_tensor_with_shape() {
        let x = SymTensor::from_shape(vec!["x".into(), 2.into()]);
        assert_eq!(x.ndim(), Some(2));
        assert_eq!(x.size(0), Some("x".into()));
        assert_eq!(x.size(1), Some(2.into()));
        assert_eq!(x.size(2), None);
        assert_eq!(x.values(), None);
        assert_eq!(
            x.shape().unwrap().collect::<Vec<_>>(),
            vec!["x".into(), 2.into()]
        );
    }

    #[test]
    fn test_from_fixed_shape() {
        let x = SymTensor::from_fixed_shape(&[2, 3]);
        assert_eq!(x.ndim(), Some(2));
        assert_eq!(x.shape().unwrap().collect::<Vec<_>>(), sym_elems![2, 3]);
        assert_eq!(x, sym_shape![2, 3]);
        assert_eq!(x.values(), None);
        assert_eq!(x.to_constant(), None);

        let rank0 = SymTensor::from_fixed_shape(&[]);
        assert_eq!(rank0.ndim(), Some(0));
    }

    #[test]
    fn test_to_constant() {
        assert_eq!(
            SymTensor::from_scalar(5.into()).to_constant(),
            Some(Constant::Scalar(5))
        );
        assert_eq!(
            sym_vec![1, 2].to_constant(),
            Some(Constant::Vector(vec![1, 2]))
        );
        assert_eq!(sym_vec![].to_constant(), Some(Constant::Vector(vec![])));
        assert_eq!(sym_vec!["x", 2].to_constant(), None);
        assert_eq!(SymTensor::from_scalar("x".into()).to_constant(), None);
        assert_eq!(sym_shape![2, 3].to_constant(), None);
        assert_eq!(SymTensor::unknown("test").to_constant(), None);
    }

    #[test]
    fn test_simplify() {
        let matrix = SymTensor::from_shape(vec![
            SymExpr::pos_var("rows") + SymExpr::from(0),
            SymExpr::pos_var("cols") * SymExpr::from(1),
        ])
        .simplify();
        assert_eq!(
            matrix.shape().unwrap().collect::<Vec<_>>(),
            vec!["rows".into(), "cols".into()]
        );

        let x = SymExpr::var("x");
        let add_expr = x.clone() + SymExpr::from(0);
        let scalar = SymTensor::from_scalar(add_expr.clone()).simplify();
        assert_eq!(scalar.as_scalar().unwrap(), &x);

        let vec = SymTensor::from_vec(vec![add_expr.clone(), add_expr.clone()]).simplify();
        assert_eq!(vec.as_vector().unwrap(), [x.clone(), x.clone()]);
    }

    #[test]
    fn test_unknown_shape() {
        let x = SymTensor::unknown("missing input shape");
        assert!(x.shape().is_none());
        assert_eq!(x.values(), None);
        assert_eq!(x.ndim(), None);
        assert_eq!(x.size(0), None);
        assert_eq!(x.clone().simplify(), x);
    }
}