1use std::{collections::BTreeMap, fmt::Display};
5
6use crate::shape::{Axis, Dimension};
7
/// Step between consecutive indices along one dimension of a view.
/// A stride of 0 marks a broadcast dimension (see `View::expand`).
pub(super) type Stride = usize;
9
/// One dimension of a strided view: its axis label, its size and its stride.
// Field order determines the derived `PartialOrd`/`Ord` (lexicographic by
// declaration order) and presumably the bitcode encoding — do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, bitcode::Encode, bitcode::Decode)]
pub(super) struct StridedDim {
    /// Axis id this dimension is labelled with.
    pub(super) axis: Axis,
    /// Number of indices along this dimension (for padded views this already
    /// includes the padding).
    pub(super) dim: Dimension,
    /// Step between consecutive indices along this dimension; 0 means the
    /// dimension is broadcast.
    pub(super) stride: Stride,
}
16
/// Describes how a logical shape is laid out: as strided dimensions,
/// optionally combined with padding.
// Variant order determines the derived `PartialOrd`/`Ord` and presumably the
// bitcode encoding — do not reorder.
#[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord, bitcode::Encode, bitcode::Decode)]
pub(super) enum View {
    /// No dimensions are tracked; `shape()` reports `[1]`, `rank()` 1 and
    /// `numel()` 0.
    None,
    /// One `StridedDim` per dimension, outermost first.
    Strided(Vec<StridedDim>),
    /// Strided dimensions whose sizes already include padding, plus padding
    /// groups. Each group lists axis ids and the `(left, right)` padding of
    /// their combined extent (a group gains axes when a padded axis is split).
    /// Negative amounts shrink the dimension instead of growing it.
    Padded(Vec<StridedDim>, Vec<(Vec<Axis>, (isize, isize))>),
}
28
29impl View {
30 pub(super) fn new(shape: &[usize]) -> Self {
31 let mut stride = 1;
32 let mut view: Vec<StridedDim> = shape
33 .iter()
34 .enumerate()
35 .rev()
36 .map(|(axis, dim)| {
37 let temp = stride;
38 stride *= dim;
39 StridedDim {
40 axis,
41 stride: temp,
42 dim: *dim,
43 }
44 })
45 .collect();
46 view.reverse();
47 return View::Strided(view);
48 }
49
50 pub(super) fn binded(shape: &[usize], axes: &[usize]) -> Self {
52 assert_eq!(shape.len(), axes.len());
53 let mut stride = 1;
54 let mut view: Vec<StridedDim> = shape
55 .iter()
56 .zip(axes)
57 .rev()
58 .map(|(&dim, &axis)| {
59 let temp = stride;
60 stride *= dim;
61 StridedDim {
62 axis,
63 stride: temp,
64 dim,
65 }
66 })
67 .collect();
68 view.reverse();
69 return View::Strided(view);
70 }
71
72 pub(super) fn shape(&self) -> Vec<usize> {
73 match self {
74 View::None => vec![1],
75 View::Strided(dims) => dims.iter().map(|dim| dim.dim).collect(),
76 View::Padded(dims, _) => dims.iter().map(|dim| dim.dim).collect(),
77 }
78 }
79
80 pub(super) fn rank(&self) -> usize {
81 match self {
82 View::None => 1,
83 View::Strided(dims) => dims.len(),
84 View::Padded(dims, _) => dims.len(),
85 }
86 }
87
88 pub(super) fn used_axes(&self) -> Vec<Axis> {
90 match self {
91 View::None => Vec::new(),
92 View::Strided(dims) | View::Padded(dims, _) => {
93 let mut res: Vec<Axis> = dims
94 .iter()
95 .flat_map(|x| if x.stride != 0 { Some(x.axis) } else { None })
96 .collect();
97 res.sort();
98 res
99 }
100 }
101 }
102
103 pub(super) fn requires_conditional_padding(&self) -> bool {
104 if let View::Padded(_, padding) = self {
106 return padding.iter().any(|(_, (lp, rp))| *lp > 0 || *rp > 0);
107 }
108 false
109 }
110
111 pub(super) fn original_numel(&self) -> usize {
112 match self {
114 View::None => 1,
115 View::Strided(dims) => dims
116 .iter()
117 .map(|dim| if dim.stride != 0 { dim.dim } else { 1 })
118 .product(),
119 View::Padded(dims, axes) => axes
120 .iter()
121 .map(|(axes, (lp, rp))| {
122 let numel: usize = dims
123 .iter()
124 .filter_map(|StridedDim { axis, dim, .. }| {
125 if axes.contains(axis) {
126 Some(*dim)
127 } else {
128 None
129 }
130 })
131 .product();
132 (numel as isize - lp - rp) as usize
134 })
135 .product(),
136 }
137 }
138
139 pub(super) fn permute(&mut self, axes: &[usize]) {
140 assert_eq!(self.rank(), axes.len());
142 match self {
143 View::None => {}
144 View::Strided(dims) => {
145 *dims = axes.iter().map(|axis| dims[*axis]).collect();
146 for (a, dim) in dims.iter_mut().enumerate() {
147 dim.axis = a;
148 }
149 }
150 View::Padded(dims, padding) => {
151 *dims = axes.iter().map(|axis| dims[*axis]).collect();
152 for (a, dim) in dims.iter_mut().enumerate() {
153 dim.axis = a;
154 }
155 let axes_map: BTreeMap<usize, usize> =
157 (0..axes.len()).zip(axes.iter().copied()).collect();
158 for (axes, _) in padding {
159 for d in axes {
160 *d = axes_map[d];
161 }
162 }
163 }
164 }
165 }
166
167 pub(super) fn pad_axis(&mut self, axis: Axis, left_pad: isize, right_pad: isize) {
170 let paxis = axis;
172 match self {
173 View::None => {}
174 View::Strided(dims) => {
175 if dims.iter().any(|&StridedDim { axis, .. }| axis == paxis) {
176 *self = View::Padded(
177 dims.iter()
178 .map(|&StridedDim { axis, dim, stride }| {
179 if axis == paxis {
180 StridedDim {
181 axis,
182 dim: (dim as isize + left_pad + right_pad) as usize,
183 stride,
184 }
185 } else {
186 StridedDim { axis, dim, stride }
187 }
188 })
189 .collect(),
190 vec![(vec![axis], (left_pad, right_pad))],
191 );
192 }
193 }
194 View::Padded(dims, padding) => {
195 if let Some(StridedDim { dim, .. }) = dims
196 .iter_mut()
197 .find(|StridedDim { axis, .. }| *axis == paxis)
198 {
199 *dim = (*dim as isize + left_pad + right_pad) as usize;
201 if let Some((_, (lp, rp))) =
202 padding.iter_mut().find(|(axes, _)| axes.contains(&axis))
203 {
204 *lp += left_pad;
205 *rp += right_pad;
206 } else {
207 padding.push((vec![axis], (left_pad, right_pad)));
208 padding.sort();
209 }
210 }
211 }
212 }
213 }
215
216 pub(super) fn expand(&mut self, axis: Axis, dimension: Dimension) {
217 match self {
232 View::None => {}
233 View::Strided(dims) => {
234 for StridedDim {
235 axis: paxis,
236 dim,
237 stride,
238 ..
239 } in dims.iter_mut()
240 {
241 if axis == *paxis {
242 assert_eq!(*dim, 1);
243 *stride = 0;
244 *dim = dimension;
245 }
246 }
247 }
248 View::Padded(dims, padding) => {
249 for StridedDim {
250 axis: paxis,
251 dim,
252 stride,
253 ..
254 } in dims.iter_mut()
255 {
256 if axis == *paxis {
257 assert_eq!(*dim, 1);
258 *stride = 0;
259 *dim = dimension;
260 for (axes, _) in padding.iter_mut() {
262 if let Some(id) = axes.iter().position(|&a| a == axis) {
263 axes.remove(id);
264 }
265 }
266 }
267 }
268 }
269 }
270 }
271
272 pub(super) fn numel(&self) -> usize {
273 match self {
274 View::None => 0,
275 View::Strided(dims) => dims.iter().map(|dim| dim.dim).product(),
276 View::Padded(dims, _) => dims.iter().map(|dim| dim.dim).product(),
277 }
278 }
279
280 pub(super) fn is_contiguous(&self) -> bool {
281 &View::new(&self.shape()) == self
282 }
283
284 pub(super) fn split_axis(&mut self, axis: Axis, dimensions: &[usize]) {
285 match self {
287 View::None => {}
288 View::Strided(dims) => {
289 for st_dim in dims.iter_mut() {
291 if axis < st_dim.axis {
292 st_dim.axis += dimensions.len() - 1;
293 }
294 }
295 if let Some((id, st_dim)) = dims
296 .iter_mut()
297 .enumerate()
298 .find(|(_, dim)| dim.axis == axis)
299 {
300 let mut stride = st_dim.stride;
301 dims.remove(id);
302 let mut temp_axis = axis + dimensions.len();
303 for dim in dimensions.iter().copied().rev() {
304 temp_axis -= 1;
305 dims.insert(
306 id,
307 StridedDim {
308 axis: temp_axis,
309 dim,
310 stride,
311 },
312 );
313 stride *= dim;
314 }
315 }
316 }
317 View::Padded(dims, padding) => {
318 let dim_len = dimensions.len();
319 for st_dim in dims.iter_mut() {
320 if axis < st_dim.axis {
321 st_dim.axis += dim_len - 1;
322 }
323 }
324 if let Some((id, st_dim)) = dims
325 .iter_mut()
326 .enumerate()
327 .find(|(_, dim)| dim.axis == axis)
328 {
329 let mut stride = st_dim.stride;
330 dims.remove(id);
331 let mut temp_axis = axis + dimensions.len();
332 for dim in dimensions.iter().copied().rev() {
333 temp_axis -= 1;
334 dims.insert(
335 id,
336 StridedDim {
337 axis: temp_axis,
338 dim,
339 stride,
340 },
341 );
342 stride *= dim;
343 }
344 }
345 for (axes, _) in padding.iter_mut() {
347 for a in axes {
348 if *a > axis {
349 *a += dim_len - 1;
350 }
351 }
352 }
353 if let Some((axes, _)) = padding.iter_mut().find(|(k, _)| k.contains(&axis)) {
355 for a in axis + 1..axis + dim_len {
357 if dims.iter().find(|dim| dim.axis == a).unwrap().dim != 1 {
358 axes.push(a);
359 }
360 }
361 axes.sort();
363 }
364 }
365 }
366 }
367}
368
369impl Display for View {
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 match self {
372 View::None => f.write_str("View::None"),
373 View::Strided(dims) => f.write_fmt(format_args!(
374 "V:S ax{:?} sh{:?} st{:?}",
375 dims.iter().map(|d| d.axis).collect::<Vec<Dimension>>(),
376 dims.iter().map(|d| d.dim).collect::<Vec<Dimension>>(),
377 dims.iter().map(|d| d.stride).collect::<Vec<Stride>>()
378 )),
379 View::Padded(dims, padding) => f.write_fmt(format_args!(
380 "V:P ax{:?} sh{:?} st{:?} pd{:?}",
381 dims.iter().map(|d| d.axis).collect::<Vec<Dimension>>(),
382 dims.iter().map(|d| d.dim).collect::<Vec<Dimension>>(),
383 dims.iter().map(|d| d.stride).collect::<Vec<Stride>>(),
384 padding,
385 )),
386 }
387 }
388}