zyx/runtime/
view.rs

//! View handles movement operations on nodes.
//! It is the middle layer between the graph and the IR representation of movement ops.
3
4use std::{collections::BTreeMap, fmt::Display};
5
6use crate::shape::{Axis, Dimension};
7
8pub(super) type Stride = usize;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, bitcode::Encode, bitcode::Decode)]
11pub(super) struct StridedDim {
12    pub(super) axis: Axis,
13    pub(super) dim: Dimension,
14    pub(super) stride: Stride,
15}
16
17#[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord, bitcode::Encode, bitcode::Decode)]
18pub(super) enum View {
19    None,
20    //Contiguous(Vec<StridedDim>), // TODO perhaps later, mainly for cpu and perhaps wide loads on gpu
21    Strided(Vec<StridedDim>),
22    // First is typical strided, second is group of axes and their padding. Very ugly, but works.
23    // If you can make it nicer, please do.
24    Padded(Vec<StridedDim>, Vec<(Vec<Axis>, (isize, isize))>),
25    //Reshaped(), // TODO perhaps for some weird optimizations, but it may actually reduce performace
26    // since then loads are very unpredictable
27}
28
29impl View {
30    pub(super) fn new(shape: &[usize]) -> Self {
31        let mut stride = 1;
32        let mut view: Vec<StridedDim> = shape
33            .iter()
34            .enumerate()
35            .rev()
36            .map(|(axis, dim)| {
37                let temp = stride;
38                stride *= dim;
39                StridedDim {
40                    axis,
41                    stride: temp,
42                    dim: *dim,
43                }
44            })
45            .collect();
46        view.reverse();
47        return View::Strided(view);
48    }
49
50    /// Creates view binded to specific axes
51    pub(super) fn binded(shape: &[usize], axes: &[usize]) -> Self {
52        assert_eq!(shape.len(), axes.len());
53        let mut stride = 1;
54        let mut view: Vec<StridedDim> = shape
55            .iter()
56            .zip(axes)
57            .rev()
58            .map(|(&dim, &axis)| {
59                let temp = stride;
60                stride *= dim;
61                StridedDim {
62                    axis,
63                    stride: temp,
64                    dim,
65                }
66            })
67            .collect();
68        view.reverse();
69        return View::Strided(view);
70    }
71
72    pub(super) fn shape(&self) -> Vec<usize> {
73        match self {
74            View::None => vec![1],
75            View::Strided(dims) => dims.iter().map(|dim| dim.dim).collect(),
76            View::Padded(dims, _) => dims.iter().map(|dim| dim.dim).collect(),
77        }
78    }
79
80    pub(super) fn rank(&self) -> usize {
81        match self {
82            View::None => 1,
83            View::Strided(dims) => dims.len(),
84            View::Padded(dims, _) => dims.len(),
85        }
86    }
87
88    /// Returns sorted used axes
89    pub(super) fn used_axes(&self) -> Vec<Axis> {
90        match self {
91            View::None => Vec::new(),
92            View::Strided(dims) | View::Padded(dims, _) => {
93                let mut res: Vec<Axis> = dims
94                    .iter()
95                    .flat_map(|x| if x.stride != 0 { Some(x.axis) } else { None })
96                    .collect();
97                res.sort();
98                res
99            }
100        }
101    }
102
103    pub(super) fn requires_conditional_padding(&self) -> bool {
104        // View requires conditional padding if any padding is more than zero
105        if let View::Padded(_, padding) = self {
106            return padding.iter().any(|(_, (lp, rp))| *lp > 0 || *rp > 0);
107        }
108        false
109    }
110
111    pub(super) fn original_numel(&self) -> usize {
112        //println!("Original numel {self}");
113        match self {
114            View::None => 1,
115            View::Strided(dims) => dims
116                .iter()
117                .map(|dim| if dim.stride != 0 { dim.dim } else { 1 })
118                .product(),
119            View::Padded(dims, axes) => axes
120                .iter()
121                .map(|(axes, (lp, rp))| {
122                    let numel: usize = dims
123                        .iter()
124                        .filter_map(|StridedDim { axis, dim, .. }| {
125                            if axes.contains(axis) {
126                                Some(*dim)
127                            } else {
128                                None
129                            }
130                        })
131                        .product();
132                    //println!("{numel}, {lp}, {rp}");
133                    (numel as isize - lp - rp) as usize
134                })
135                .product(),
136        }
137    }
138
139    pub(super) fn permute(&mut self, axes: &[usize]) {
140        //println!("Permuting {self} by {axes:?}");
141        assert_eq!(self.rank(), axes.len());
142        match self {
143            View::None => {}
144            View::Strided(dims) => {
145                *dims = axes.iter().map(|axis| dims[*axis]).collect();
146                for (a, dim) in dims.iter_mut().enumerate() {
147                    dim.axis = a;
148                }
149            }
150            View::Padded(dims, padding) => {
151                *dims = axes.iter().map(|axis| dims[*axis]).collect();
152                for (a, dim) in dims.iter_mut().enumerate() {
153                    dim.axis = a;
154                }
155                // TODO is this correct?
156                let axes_map: BTreeMap<usize, usize> =
157                    (0..axes.len()).zip(axes.iter().copied()).collect();
158                for (axes, _) in padding {
159                    for d in axes {
160                        *d = axes_map[d];
161                    }
162                }
163            }
164        }
165    }
166
167    //pub(super) fn arbitrary_permute(&mut self, axes: &[usize]) { todo!() }
168
169    pub(super) fn pad_axis(&mut self, axis: Axis, left_pad: isize, right_pad: isize) {
170        //println!("Padding {axis} with {left_pad}, {right_pad}");
171        let paxis = axis;
172        match self {
173            View::None => {}
174            View::Strided(dims) => {
175                if dims.iter().any(|&StridedDim { axis, .. }| axis == paxis) {
176                    *self = View::Padded(
177                        dims.iter()
178                            .map(|&StridedDim { axis, dim, stride }| {
179                                if axis == paxis {
180                                    StridedDim {
181                                        axis,
182                                        dim: (dim as isize + left_pad + right_pad) as usize,
183                                        stride,
184                                    }
185                                } else {
186                                    StridedDim { axis, dim, stride }
187                                }
188                            })
189                            .collect(),
190                        vec![(vec![axis], (left_pad, right_pad))],
191                    );
192                }
193            }
194            View::Padded(dims, padding) => {
195                if let Some(StridedDim { dim, .. }) = dims
196                    .iter_mut()
197                    .find(|StridedDim { axis, .. }| *axis == paxis)
198                {
199                    //println!("Padding axis {axis}, dim {dim} with {left_pad}, {right_pad}");
200                    *dim = (*dim as isize + left_pad + right_pad) as usize;
201                    if let Some((_, (lp, rp))) =
202                        padding.iter_mut().find(|(axes, _)| axes.contains(&axis))
203                    {
204                        *lp += left_pad;
205                        *rp += right_pad;
206                    } else {
207                        padding.push((vec![axis], (left_pad, right_pad)));
208                        padding.sort();
209                    }
210                }
211            }
212        }
213        //println!("Result {self}");
214    }
215
216    pub(super) fn expand(&mut self, axis: Axis, dimension: Dimension) {
217        // TODO probably instead of changing stride to 0, we can simply
218        // remove the dimension alltogether
219        /*let _ = dimension;
220        match self {
221            View::None => {}
222            View::Strided(dims) => {
223                dims.retain(|x| x.axis != axis);
224            }
225            View::Padded(dims, pa) => {
226                dims.retain(|x| x.axis != axis);
227                pa.axes.iter_mut().for_each(|(v, _)| v.retain(|a| *a != axis));
228                pa.axes.retain(|(axes, _)| !axes.is_empty());
229            }
230        }*/
231        match self {
232            View::None => {}
233            View::Strided(dims) => {
234                for StridedDim {
235                    axis: paxis,
236                    dim,
237                    stride,
238                    ..
239                } in dims.iter_mut()
240                {
241                    if axis == *paxis {
242                        assert_eq!(*dim, 1);
243                        *stride = 0;
244                        *dim = dimension;
245                    }
246                }
247            }
248            View::Padded(dims, padding) => {
249                for StridedDim {
250                    axis: paxis,
251                    dim,
252                    stride,
253                    ..
254                } in dims.iter_mut()
255                {
256                    if axis == *paxis {
257                        assert_eq!(*dim, 1);
258                        *stride = 0;
259                        *dim = dimension;
260                        // Remove expanded axes from padding
261                        for (axes, _) in padding.iter_mut() {
262                            if let Some(id) = axes.iter().position(|&a| a == axis) {
263                                axes.remove(id);
264                            }
265                        }
266                    }
267                }
268            }
269        }
270    }
271
272    pub(super) fn numel(&self) -> usize {
273        match self {
274            View::None => 0,
275            View::Strided(dims) => dims.iter().map(|dim| dim.dim).product(),
276            View::Padded(dims, _) => dims.iter().map(|dim| dim.dim).product(),
277        }
278    }
279
280    pub(super) fn is_contiguous(&self) -> bool {
281        &View::new(&self.shape()) == self
282    }
283
284    pub(super) fn split_axis(&mut self, axis: Axis, dimensions: &[usize]) {
285        //println!("{axis}, {dimensions:?}");
286        match self {
287            View::None => {}
288            View::Strided(dims) => {
289                // Rename all following axes
290                for st_dim in dims.iter_mut() {
291                    if axis < st_dim.axis {
292                        st_dim.axis += dimensions.len() - 1;
293                    }
294                }
295                if let Some((id, st_dim)) = dims
296                    .iter_mut()
297                    .enumerate()
298                    .find(|(_, dim)| dim.axis == axis)
299                {
300                    let mut stride = st_dim.stride;
301                    dims.remove(id);
302                    let mut temp_axis = axis + dimensions.len();
303                    for dim in dimensions.iter().copied().rev() {
304                        temp_axis -= 1;
305                        dims.insert(
306                            id,
307                            StridedDim {
308                                axis: temp_axis,
309                                dim,
310                                stride,
311                            },
312                        );
313                        stride *= dim;
314                    }
315                }
316            }
317            View::Padded(dims, padding) => {
318                let dim_len = dimensions.len();
319                for st_dim in dims.iter_mut() {
320                    if axis < st_dim.axis {
321                        st_dim.axis += dim_len - 1;
322                    }
323                }
324                if let Some((id, st_dim)) = dims
325                    .iter_mut()
326                    .enumerate()
327                    .find(|(_, dim)| dim.axis == axis)
328                {
329                    let mut stride = st_dim.stride;
330                    dims.remove(id);
331                    let mut temp_axis = axis + dimensions.len();
332                    for dim in dimensions.iter().copied().rev() {
333                        temp_axis -= 1;
334                        dims.insert(
335                            id,
336                            StridedDim {
337                                axis: temp_axis,
338                                dim,
339                                stride,
340                            },
341                        );
342                        stride *= dim;
343                    }
344                }
345                // If key in padding axes is greater than axis, then add dim_len - 1 to it
346                for (axes, _) in padding.iter_mut() {
347                    for a in axes {
348                        if *a > axis {
349                            *a += dim_len - 1;
350                        }
351                    }
352                }
353                // Split padding
354                if let Some((axes, _)) = padding.iter_mut().find(|(k, _)| k.contains(&axis)) {
355                    //std::println!("Original: {axes:?} splitting into: {axis}..{}", axis+dim_len);
356                    for a in axis + 1..axis + dim_len {
357                        if dims.iter().find(|dim| dim.axis == a).unwrap().dim != 1 {
358                            axes.push(a);
359                        }
360                    }
361                    // Would not be needed on btreeset
362                    axes.sort();
363                }
364            }
365        }
366    }
367}
368
369impl Display for View {
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        match self {
372            View::None => f.write_str("View::None"),
373            View::Strided(dims) => f.write_fmt(format_args!(
374                "V:S ax{:?} sh{:?} st{:?}",
375                dims.iter().map(|d| d.axis).collect::<Vec<Dimension>>(),
376                dims.iter().map(|d| d.dim).collect::<Vec<Dimension>>(),
377                dims.iter().map(|d| d.stride).collect::<Vec<Stride>>()
378            )),
379            View::Padded(dims, padding) => f.write_fmt(format_args!(
380                "V:P ax{:?} sh{:?} st{:?} pd{:?}",
381                dims.iter().map(|d| d.axis).collect::<Vec<Dimension>>(),
382                dims.iter().map(|d| d.dim).collect::<Vec<Dimension>>(),
383                dims.iter().map(|d| d.stride).collect::<Vec<Stride>>(),
384                padding,
385            )),
386        }
387    }
388}