1use crate::enums::shape_dim::ShapeDim;
24
25pub trait Shape {
32 fn shape(&self) -> ShapeDim;
34
35 fn shape_1d(&self) -> usize {
39 match self.shape() {
40 ShapeDim::Rank0(n) => n,
41 ShapeDim::Rank1(n) => n,
42 ShapeDim::Rank2 { rows, .. } => rows,
43 ShapeDim::Rank3 { x, .. } => x,
44 ShapeDim::Rank4 { a, .. } => a,
45 ShapeDim::RankN(dims) => *dims.get(0).unwrap_or(&1),
46 ShapeDim::Collection(items) => items.iter().map(|x| x.shape_1d()).sum(),
47 ShapeDim::Dictionary { .. } => panic!("shape_1d: incompatible Dictionary shape"),
48 ShapeDim::Unknown => panic!("shape_1d: incompatible Unknown shape"),
49 }
50 }
51
52 fn shape_2d(&self) -> (usize, usize) {
56 match self.shape() {
57 ShapeDim::Rank0(n) => (n, 1),
58 ShapeDim::Rank1(n) => (n, 1),
59 ShapeDim::Rank2 { rows, cols } => (rows, cols),
60 ShapeDim::Rank3 { x, y, .. } => (x, y),
61 ShapeDim::Rank4 { a, b, .. } => (a, b),
62 ShapeDim::RankN(dims) => (*dims.get(0).unwrap_or(&1), *dims.get(1).unwrap_or(&1)),
63 ShapeDim::Collection(items) => {
64 let mut total_rows = 0usize;
65 let mut ref_cols: Option<usize> = None;
66
67 for item in items {
68 let (rows, cols) = item.shape_2d();
69 total_rows += rows;
70
71 match ref_cols {
72 None => ref_cols = Some(cols),
73 Some(prev) if prev == cols => {}
74 Some(prev) => panic!(
75 "shape_2d: column mismatch in Collection: {} vs {}",
76 prev, cols
77 ),
78 }
79 }
80
81 (total_rows, ref_cols.unwrap_or(1))
82 }
83 ShapeDim::Dictionary { .. } => panic!("shape_2d: incompatible Dictionary shape"),
84 ShapeDim::Unknown => panic!("shape_2d: incompatible Unknown shape"),
85 }
86 }
87
88 fn shape_3d(&self) -> (usize, usize, usize) {
92 match self.shape() {
93 ShapeDim::Rank0(n) => (n, 1, 1),
94 ShapeDim::Rank1(n) => (n, 1, 1),
95 ShapeDim::Rank2 { rows, cols } => (rows, cols, 1),
96 ShapeDim::Rank3 { x, y, z } => (x, y, z),
97 ShapeDim::Rank4 { a, b, c, .. } => (a, b, c),
98 ShapeDim::RankN(dims) => (
99 *dims.get(0).unwrap_or(&1),
100 *dims.get(1).unwrap_or(&1),
101 *dims.get(2).unwrap_or(&1),
102 ),
103 ShapeDim::Collection(items) => {
104 let mut total_a = 0usize;
105 let mut ref_b: Option<usize> = None;
106 let mut ref_c: Option<usize> = None;
107
108 for item in items {
109 let (a, b, c) = item.shape_3d();
110 total_a += a;
111
112 match ref_b {
113 None => ref_b = Some(b),
114 Some(prev) if prev == b => {}
115 Some(prev) => panic!(
116 "shape_3d: 2nd dim mismatch in Collection: {} vs {}",
117 prev, b
118 ),
119 }
120
121 match ref_c {
122 None => ref_c = Some(c),
123 Some(prev) if prev == c => {}
124 Some(prev) => panic!(
125 "shape_3d: 3rd dim mismatch in Collection: {} vs {}",
126 prev, c
127 ),
128 }
129 }
130
131 (total_a, ref_b.unwrap_or(1), ref_c.unwrap_or(1))
132 }
133 ShapeDim::Dictionary { .. } => panic!("shape_3d: incompatible Dictionary shape"),
134 ShapeDim::Unknown => panic!("shape_3d: incompatible Unknown shape"),
135 }
136 }
137
138 fn shape_4d(&self) -> (usize, usize, usize, usize) {
142 match self.shape() {
143 ShapeDim::Rank0(n) => (n, 1, 1, 1),
144 ShapeDim::Rank1(n) => (n, 1, 1, 1),
145 ShapeDim::Rank2 { rows, cols } => (rows, cols, 1, 1),
146 ShapeDim::Rank3 { x, y, z } => (x, y, z, 1),
147 ShapeDim::Rank4 { a, b, c, d } => (a, b, c, d),
148 ShapeDim::RankN(dims) => (
149 *dims.get(0).unwrap_or(&1),
150 *dims.get(1).unwrap_or(&1),
151 *dims.get(2).unwrap_or(&1),
152 *dims.get(3).unwrap_or(&1),
153 ),
154 ShapeDim::Collection(items) => {
155 let mut total_a = 0usize;
156 let mut ref_b: Option<usize> = None;
157 let mut ref_c: Option<usize> = None;
158 let mut ref_d: Option<usize> = None;
159
160 for item in items {
161 let (a, b, c, d) = item.shape_4d();
162 total_a += a;
163
164 match ref_b {
165 None => ref_b = Some(b),
166 Some(prev) if prev == b => {}
167 Some(prev) => panic!("shape_4d: 2nd dim mismatch: {} vs {}", prev, b),
168 }
169
170 match ref_c {
171 None => ref_c = Some(c),
172 Some(prev) if prev == c => {}
173 Some(prev) => panic!("shape_4d: 3rd dim mismatch: {} vs {}", prev, c),
174 }
175
176 match ref_d {
177 None => ref_d = Some(d),
178 Some(prev) if prev == d => {}
179 Some(prev) => panic!("shape_4d: 4th dim mismatch: {} vs {}", prev, d),
180 }
181 }
182
183 (
184 total_a,
185 ref_b.unwrap_or(1),
186 ref_c.unwrap_or(1),
187 ref_d.unwrap_or(1),
188 )
189 }
190 ShapeDim::Dictionary { .. } => panic!("shape_4d: incompatible Dictionary shape"),
191 ShapeDim::Unknown => panic!("shape_4d: incompatible Unknown shape"),
192 }
193 }
194}