// opensrdk_symbolic_computation/expression/size.rs

1use crate::Expression;
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
5pub enum Size {
6    One,
7    Many,
8}
9
10impl Expression {
11    pub fn sizes(&self) -> Vec<Size> {
12        match self {
13            Expression::Variable(_, sizes) => sizes.clone(),
14            Expression::Constant(v) => v.sizes().into_abstract_size(),
15            Expression::PartialVariable(v) => v.sizes().into_abstract_size(),
16            Expression::Add(l, _) => l.sizes(),
17            Expression::Sub(l, _) => l.sizes(),
18            Expression::Mul(l, _) => l.sizes(),
19            Expression::Div(l, _) => l.sizes(),
20            Expression::Neg(v) => v.sizes(),
21            Expression::Transcendental(v) => v.sizes(),
22            Expression::Tensor(v) => v.sizes(),
23            Expression::Matrix(v) => v.sizes(),
24        }
25    }
26
27    pub fn is_same_size(&self, other: &Expression) -> bool {
28        let sl = self.sizes();
29        let sr = other.sizes();
30
31        if sl.len() == 0 || sr.len() == 0 {
32            return true;
33        }
34
35        sl == sr
36    }
37
38    pub fn not_1dimension_ranks(&self) -> usize {
39        self.sizes().iter().filter(|&d| *d != Size::One).count()
40    }
41
42    pub fn mathematical_sizes(&self) -> Vec<Size> {
43        if self.sizes() == vec![Size::One; 2] {
44            vec![]
45        } else {
46            self.sizes()
47        }
48    }
49}
50
51pub trait AbstractSize {
52    fn into_abstract_size(&self) -> Vec<Size>;
53}
54
55impl AbstractSize for [usize] {
56    fn into_abstract_size(&self) -> Vec<Size> {
57        self.iter()
58            .map(|&size| if size > 1 { Size::Many } else { Size::One })
59            .collect()
60    }
61}
62
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::{new_variable, new_variable_tensor, Expression, Size};

    /// Builds a sparse tensor expression with `rank` axes of extent `extent`,
    /// storing each `(i, value)` entry at the diagonal index `[i; rank]`.
    fn sparse_expression(extent: usize, rank: usize, entries: &[(usize, f64)]) -> Expression {
        let elems = entries
            .iter()
            .map(|&(i, value)| (vec![i; rank], value))
            .collect::<HashMap<_, _>>();
        Expression::from(SparseTensor::from(vec![extent; rank], elems).unwrap())
    }

    #[test]
    fn sizes_of_constants() {
        let scalar = Expression::from(5.0f64);
        let vector = Expression::from(vec![5.0f64; 8]);
        let tensor = sparse_expression(6, 8, &[(3, 2.0), (1, 3.0), (4, 4.0), (5, 2.0)]);

        assert_eq!(Vec::<Size>::new(), scalar.sizes());
        assert_eq!(vec![Size::Many; 1], vector.sizes());
        assert_eq!(vec![Size::Many; 8], tensor.sizes());
    }

    #[test]
    fn scalar_variable_has_no_axes() {
        let x = new_variable("x".to_string());

        assert_eq!(Vec::<Size>::new(), x.sizes());
    }

    #[test]
    fn is_same_size_compares_axis_shapes() {
        let c1 = sparse_expression(6, 8, &[(3, 2.0), (1, 3.0), (4, 4.0), (5, 2.0)]);
        let c2 = sparse_expression(8, 8, &[(4, 1.0), (1, 4.0), (6, 3.0)]);
        let c3 = sparse_expression(8, 6, &[(4, 1.0), (1, 4.0), (6, 3.0)]);

        // Same rank with every axis `Many`: compatible even though extents differ.
        assert!(c1.is_same_size(&c2));
        // Different rank: incompatible.
        assert!(!c1.is_same_size(&c3));
    }

    #[test]
    fn counts_non_unit_axes() {
        let c = sparse_expression(6, 8, &[(3, 2.0), (1, 3.0), (4, 4.0), (5, 2.0)]);

        assert_eq!(8usize, c.not_1dimension_ranks());
    }

    #[test]
    fn determinant_is_mathematical_scalar() {
        let x = new_variable_tensor("x".to_string(), vec![Size::Many, Size::Many]);
        let det = x.det();

        assert_eq!(vec![Size::One; 2], det.sizes());
        assert_eq!(Vec::<Size>::new(), det.mathematical_sizes());
    }
}