hpt_iterator/
shape_manipulate.rs

1use std::sync::Arc;
2
3use hpt_common::{
4    axis::axis::{process_axes, Axis},
5    error::shape::ShapeError,
6    shape::shape::Shape,
7    shape::shape_utils::{get_broadcast_axes_from, mt_intervals, try_pad_shape},
8    strides::strides::Strides,
9};
10
11use crate::iterator_traits::{ParStridedHelper, StridedHelper};
12
13#[track_caller]
14pub(crate) fn par_reshape<S: Into<Shape>, T: ParStridedHelper>(mut iterator: T, shape: S) -> T {
15    let tmp = shape.into();
16    let res_shape = tmp;
17    if iterator._layout().shape() == &res_shape {
18        return iterator;
19    }
20    let size = res_shape.size() as usize;
21    let inner_loop_size = res_shape[res_shape.len() - 1] as usize;
22    let outer_loop_size = size / inner_loop_size;
23    let num_threads;
24    if outer_loop_size < rayon::current_num_threads() {
25        num_threads = outer_loop_size;
26    } else {
27        num_threads = rayon::current_num_threads();
28    }
29    let intervals = mt_intervals(outer_loop_size, num_threads);
30    let len = intervals.len();
31    iterator._set_intervals(Arc::new(intervals));
32    iterator._set_end_index(len);
33    let self_size = iterator._layout().size();
34
35    if size > (self_size as usize) {
36        let self_shape = try_pad_shape(iterator._layout().shape(), res_shape.len());
37
38        let axes_to_broadcast =
39            get_broadcast_axes_from(&self_shape, &res_shape).expect("Cannot broadcast shapes");
40
41        let mut new_strides = vec![0; res_shape.len()];
42        new_strides
43            .iter_mut()
44            .rev()
45            .zip(iterator._layout().strides().iter().rev())
46            .for_each(|(a, b)| {
47                *a = *b;
48            });
49        for &axis in axes_to_broadcast.iter() {
50            assert_eq!(self_shape[axis], 1);
51            new_strides[axis] = 0;
52        }
53        iterator._set_last_strides(new_strides[new_strides.len() - 1]);
54        iterator._set_strides(new_strides.into());
55    } else {
56        ShapeError::check_size_match(iterator._layout().shape().size(), res_shape.size())
57            .expect("Cannot reshape iterator");
58        if let Some(new_strides) = iterator._layout().is_reshape_possible(&res_shape) {
59            iterator._set_strides(new_strides);
60            iterator._set_last_strides(
61                iterator._layout().strides()[iterator._layout().strides().len() - 1],
62            );
63        } else {
64            panic!("Cannot reshape iterator");
65        }
66    }
67
68    iterator._set_shape(res_shape.clone());
69    iterator
70}
71
72#[track_caller]
73pub(crate) fn par_transpose<AXIS: Into<Axis>, T: ParStridedHelper>(
74    mut iterator: T,
75    axes: AXIS,
76) -> T {
77    let axes = process_axes(axes, iterator._layout().shape().len()).unwrap();
78
79    let mut new_shape = iterator._layout().shape().to_vec();
80    for i in axes.iter() {
81        new_shape[*i] = iterator._layout().shape()[axes[*i]];
82    }
83    let mut new_strides = iterator._layout().strides().to_vec();
84    for i in axes.iter() {
85        new_strides[*i] = iterator._layout().strides()[axes[*i]];
86    }
87    let new_strides: Strides = new_strides.into();
88    let new_shape = Arc::new(new_shape);
89    let outer_loop_size =
90        (new_shape.iter().product::<i64>() as usize) / (new_shape[new_shape.len() - 1] as usize);
91    let num_threads;
92    if outer_loop_size < rayon::current_num_threads() {
93        num_threads = outer_loop_size;
94    } else {
95        num_threads = rayon::current_num_threads();
96    }
97    let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
98    let len = intervals.len();
99    iterator._set_intervals(intervals.clone());
100    iterator._set_end_index(len);
101
102    iterator._set_last_strides(new_strides[new_strides.len() - 1]);
103    iterator._set_strides(new_strides);
104    iterator._set_shape(Shape::from(new_shape));
105    iterator
106}
107
108#[track_caller]
109pub(crate) fn par_expand<S: Into<Shape>, T: ParStridedHelper>(mut iterator: T, shape: S) -> T {
110    let res_shape = shape.into();
111
112    let new_strides = iterator
113        ._layout()
114        .expand_strides(&res_shape)
115        .expect("Cannot expand iterator");
116
117    let outer_loop_size =
118        (res_shape.iter().product::<i64>() as usize) / (res_shape[res_shape.len() - 1] as usize);
119    let num_threads;
120    if outer_loop_size < rayon::current_num_threads() {
121        num_threads = outer_loop_size;
122    } else {
123        num_threads = rayon::current_num_threads();
124    }
125    let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
126    let len = intervals.len();
127    iterator._set_intervals(intervals.clone());
128    iterator._set_end_index(len);
129    iterator._set_shape(res_shape.clone());
130    iterator._set_strides(new_strides);
131    iterator
132}
133
134#[track_caller]
135pub(crate) fn reshape<S: Into<Shape>, T: StridedHelper>(mut iterator: T, shape: S) -> T {
136    let tmp = shape.into();
137    let res_shape = tmp;
138    if iterator._layout().shape() == &res_shape {
139        return iterator;
140    }
141    let size = res_shape.size() as usize;
142    let self_size = iterator._layout().size();
143
144    if size > (self_size as usize) {
145        let self_shape = try_pad_shape(iterator._layout().shape(), res_shape.len());
146
147        let axes_to_broadcast =
148            get_broadcast_axes_from(&self_shape, &res_shape).expect("Cannot broadcast shapes");
149
150        let mut new_strides = vec![0; res_shape.len()];
151        new_strides
152            .iter_mut()
153            .rev()
154            .zip(iterator._layout().strides().iter().rev())
155            .for_each(|(a, b)| {
156                *a = *b;
157            });
158        for &axis in axes_to_broadcast.iter() {
159            assert_eq!(self_shape[axis], 1);
160            new_strides[axis] = 0;
161        }
162        iterator._set_last_strides(new_strides[new_strides.len() - 1]);
163        iterator._set_strides(new_strides.into());
164    } else {
165        ShapeError::check_size_match(
166            iterator._layout().shape().inner().iter().product(),
167            res_shape.size(),
168        )
169        .expect("Cannot reshape iterator");
170        if let Some(new_strides) = iterator._layout().is_reshape_possible(&res_shape) {
171            iterator._set_strides(new_strides);
172            iterator._set_last_strides(
173                iterator._layout().strides()[iterator._layout().strides().len() - 1],
174            );
175        } else {
176            panic!("Cannot reshape iterator");
177        }
178    }
179
180    iterator._set_shape(res_shape.clone());
181    iterator
182}
183
184#[track_caller]
185pub(crate) fn expand<T: StridedHelper, S: Into<Shape>>(mut iterator: T, shape: S) -> T {
186    let res_shape: Shape = shape.into();
187    let new_strides = iterator
188        ._layout()
189        .expand_strides(&res_shape)
190        .expect("Cannot expand iterator");
191    iterator._set_shape(res_shape.clone());
192    iterator._set_strides(new_strides);
193    iterator
194}
195
196#[track_caller]
197pub(crate) fn transpose<T: StridedHelper, AXIS: Into<Axis>>(mut iterator: T, axes: AXIS) -> T {
198    // ErrHandler::check_axes_in_range(self.shape().len(), axes).unwrap();
199    let axes = process_axes(axes, iterator._layout().shape().len()).unwrap();
200
201    let mut new_shape = iterator._layout().shape().to_vec();
202    for i in axes.iter() {
203        new_shape[*i] = iterator._layout().shape()[axes[*i]];
204    }
205    let mut new_strides = iterator._layout().strides().to_vec();
206    for i in axes.iter() {
207        new_strides[*i] = iterator._layout().strides()[axes[*i]];
208    }
209    let new_strides: Strides = new_strides.into();
210    let new_shape = Arc::new(new_shape);
211
212    iterator._set_last_strides(new_strides[new_strides.len() - 1]);
213    iterator._set_strides(new_strides);
214    iterator._set_shape(Shape::from(new_shape));
215    iterator
216}