// opensrdk_symbolic_computation/expression/size.rs
use crate::Expression;
use serde::{Deserialize, Serialize};

4#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
5pub enum Size {
6 One,
7 Many,
8}
9
10impl Expression {
11 pub fn sizes(&self) -> Vec<Size> {
12 match self {
13 Expression::Variable(_, sizes) => sizes.clone(),
14 Expression::Constant(v) => v.sizes().into_abstract_size(),
15 Expression::PartialVariable(v) => v.sizes().into_abstract_size(),
16 Expression::Add(l, _) => l.sizes(),
17 Expression::Sub(l, _) => l.sizes(),
18 Expression::Mul(l, _) => l.sizes(),
19 Expression::Div(l, _) => l.sizes(),
20 Expression::Neg(v) => v.sizes(),
21 Expression::Transcendental(v) => v.sizes(),
22 Expression::Tensor(v) => v.sizes(),
23 Expression::Matrix(v) => v.sizes(),
24 }
25 }
26
27 pub fn is_same_size(&self, other: &Expression) -> bool {
28 let sl = self.sizes();
29 let sr = other.sizes();
30
31 if sl.len() == 0 || sr.len() == 0 {
32 return true;
33 }
34
35 sl == sr
36 }
37
38 pub fn not_1dimension_ranks(&self) -> usize {
39 self.sizes().iter().filter(|&d| *d != Size::One).count()
40 }
41
42 pub fn mathematical_sizes(&self) -> Vec<Size> {
43 if self.sizes() == vec![Size::One; 2] {
44 vec![]
45 } else {
46 self.sizes()
47 }
48 }
49}
50
51pub trait AbstractSize {
52 fn into_abstract_size(&self) -> Vec<Size>;
53}
54
55impl AbstractSize for [usize] {
56 fn into_abstract_size(&self) -> Vec<Size> {
57 self.iter()
58 .map(|&size| if size > 1 { Size::Many } else { Size::One })
59 .collect()
60 }
61}
62
#[cfg(test)]
mod tests {
    use std::{
        collections::{HashMap, HashSet},
        ops::Add,
    };

    use opensrdk_linear_algebra::{sparse::SparseTensor, Matrix, Tensor};

    use crate::{new_variable, new_variable_tensor, AbstractSize, Expression, Size};

    /// Non-zero entries shared by the rank-8, length-6 fixtures below.
    const RANK8_LEN6_ENTRIES: &[(usize, f64)] = &[(3, 2.0), (1, 3.0), (4, 4.0), (5, 2.0)];

    /// Non-zero entries shared by the length-8 fixtures below.
    const LEN8_ENTRIES: &[(usize, f64)] = &[(4, 1.0), (1, 4.0), (6, 3.0)];

    /// Builds a sparse-tensor expression with `rank` axes of length `len`, storing
    /// `value` at coordinate `[index; rank]` for every `(index, value)` pair.
    fn diagonal_sparse(len: usize, rank: usize, entries: &[(usize, f64)]) -> Expression {
        let elems: HashMap<Vec<usize>, f64> = entries
            .iter()
            .map(|&(index, value)| (vec![index; rank], value))
            .collect();
        Expression::from(SparseTensor::from(vec![len; rank], elems).unwrap())
    }

    #[test]
    fn it_works1() {
        let scalar = Expression::from(5.0f64);
        let vector = Expression::from(vec![5.0f64; 8]);
        let sparse = diagonal_sparse(6, 8, RANK8_LEN6_ENTRIES);

        assert!(scalar.sizes().is_empty());
        assert_eq!(vec![Size::Many], vector.sizes());
        assert_eq!(vec![Size::Many; 8], sparse.sizes());
    }

    #[test]
    fn it_works2() {
        let x = new_variable("x".to_string());

        assert!(x.sizes().is_empty());
    }

    #[test]
    fn it_works3() {
        let rank8_short = diagonal_sparse(6, 8, RANK8_LEN6_ENTRIES);
        let rank8_long = diagonal_sparse(8, 8, LEN8_ENTRIES);
        let rank6_long = diagonal_sparse(8, 6, LEN8_ENTRIES);

        assert!(rank8_short.is_same_size(&rank8_long));
        assert!(!rank8_short.is_same_size(&rank6_long));
    }

    #[test]
    fn it_works4() {
        let sparse = diagonal_sparse(6, 8, RANK8_LEN6_ENTRIES);

        assert_eq!(8usize, sparse.not_1dimension_ranks());
    }

    #[test]
    fn it_works5() {
        let x = new_variable_tensor("x".to_string(), vec![Size::Many, Size::Many]);
        let determinant = x.clone().det();

        assert_eq!(vec![Size::One, Size::One], determinant.sizes());
        assert!(determinant.mathematical_sizes().is_empty());
    }
}