hpt_iterator/
shape_manipulate.rs1use std::sync::Arc;
2
3use hpt_common::{
4 axis::axis::{process_axes, Axis},
5 error::shape::ShapeError,
6 shape::shape::Shape,
7 shape::shape_utils::{get_broadcast_axes_from, mt_intervals, try_pad_shape},
8 strides::strides::Strides,
9};
10
11use crate::iterator_traits::{ParStridedHelper, StridedHelper};
12
13#[track_caller]
14pub(crate) fn par_reshape<S: Into<Shape>, T: ParStridedHelper>(mut iterator: T, shape: S) -> T {
15 let tmp = shape.into();
16 let res_shape = tmp;
17 if iterator._layout().shape() == &res_shape {
18 return iterator;
19 }
20 let size = res_shape.size() as usize;
21 let inner_loop_size = res_shape[res_shape.len() - 1] as usize;
22 let outer_loop_size = size / inner_loop_size;
23 let num_threads;
24 if outer_loop_size < rayon::current_num_threads() {
25 num_threads = outer_loop_size;
26 } else {
27 num_threads = rayon::current_num_threads();
28 }
29 let intervals = mt_intervals(outer_loop_size, num_threads);
30 let len = intervals.len();
31 iterator._set_intervals(Arc::new(intervals));
32 iterator._set_end_index(len);
33 let self_size = iterator._layout().size();
34
35 if size > (self_size as usize) {
36 let self_shape = try_pad_shape(iterator._layout().shape(), res_shape.len());
37
38 let axes_to_broadcast =
39 get_broadcast_axes_from(&self_shape, &res_shape).expect("Cannot broadcast shapes");
40
41 let mut new_strides = vec![0; res_shape.len()];
42 new_strides
43 .iter_mut()
44 .rev()
45 .zip(iterator._layout().strides().iter().rev())
46 .for_each(|(a, b)| {
47 *a = *b;
48 });
49 for &axis in axes_to_broadcast.iter() {
50 assert_eq!(self_shape[axis], 1);
51 new_strides[axis] = 0;
52 }
53 iterator._set_last_strides(new_strides[new_strides.len() - 1]);
54 iterator._set_strides(new_strides.into());
55 } else {
56 ShapeError::check_size_match(iterator._layout().shape().size(), res_shape.size())
57 .expect("Cannot reshape iterator");
58 if let Some(new_strides) = iterator._layout().is_reshape_possible(&res_shape) {
59 iterator._set_strides(new_strides);
60 iterator._set_last_strides(
61 iterator._layout().strides()[iterator._layout().strides().len() - 1],
62 );
63 } else {
64 panic!("Cannot reshape iterator");
65 }
66 }
67
68 iterator._set_shape(res_shape.clone());
69 iterator
70}
71
72#[track_caller]
73pub(crate) fn par_transpose<AXIS: Into<Axis>, T: ParStridedHelper>(
74 mut iterator: T,
75 axes: AXIS,
76) -> T {
77 let axes = process_axes(axes, iterator._layout().shape().len()).unwrap();
78
79 let mut new_shape = iterator._layout().shape().to_vec();
80 for i in axes.iter() {
81 new_shape[*i] = iterator._layout().shape()[axes[*i]];
82 }
83 let mut new_strides = iterator._layout().strides().to_vec();
84 for i in axes.iter() {
85 new_strides[*i] = iterator._layout().strides()[axes[*i]];
86 }
87 let new_strides: Strides = new_strides.into();
88 let new_shape = Arc::new(new_shape);
89 let outer_loop_size =
90 (new_shape.iter().product::<i64>() as usize) / (new_shape[new_shape.len() - 1] as usize);
91 let num_threads;
92 if outer_loop_size < rayon::current_num_threads() {
93 num_threads = outer_loop_size;
94 } else {
95 num_threads = rayon::current_num_threads();
96 }
97 let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
98 let len = intervals.len();
99 iterator._set_intervals(intervals.clone());
100 iterator._set_end_index(len);
101
102 iterator._set_last_strides(new_strides[new_strides.len() - 1]);
103 iterator._set_strides(new_strides);
104 iterator._set_shape(Shape::from(new_shape));
105 iterator
106}
107
108#[track_caller]
109pub(crate) fn par_expand<S: Into<Shape>, T: ParStridedHelper>(mut iterator: T, shape: S) -> T {
110 let res_shape = shape.into();
111
112 let new_strides = iterator
113 ._layout()
114 .expand_strides(&res_shape)
115 .expect("Cannot expand iterator");
116
117 let outer_loop_size =
118 (res_shape.iter().product::<i64>() as usize) / (res_shape[res_shape.len() - 1] as usize);
119 let num_threads;
120 if outer_loop_size < rayon::current_num_threads() {
121 num_threads = outer_loop_size;
122 } else {
123 num_threads = rayon::current_num_threads();
124 }
125 let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
126 let len = intervals.len();
127 iterator._set_intervals(intervals.clone());
128 iterator._set_end_index(len);
129 iterator._set_shape(res_shape.clone());
130 iterator._set_strides(new_strides);
131 iterator
132}
133
134#[track_caller]
135pub(crate) fn reshape<S: Into<Shape>, T: StridedHelper>(mut iterator: T, shape: S) -> T {
136 let tmp = shape.into();
137 let res_shape = tmp;
138 if iterator._layout().shape() == &res_shape {
139 return iterator;
140 }
141 let size = res_shape.size() as usize;
142 let self_size = iterator._layout().size();
143
144 if size > (self_size as usize) {
145 let self_shape = try_pad_shape(iterator._layout().shape(), res_shape.len());
146
147 let axes_to_broadcast =
148 get_broadcast_axes_from(&self_shape, &res_shape).expect("Cannot broadcast shapes");
149
150 let mut new_strides = vec![0; res_shape.len()];
151 new_strides
152 .iter_mut()
153 .rev()
154 .zip(iterator._layout().strides().iter().rev())
155 .for_each(|(a, b)| {
156 *a = *b;
157 });
158 for &axis in axes_to_broadcast.iter() {
159 assert_eq!(self_shape[axis], 1);
160 new_strides[axis] = 0;
161 }
162 iterator._set_last_strides(new_strides[new_strides.len() - 1]);
163 iterator._set_strides(new_strides.into());
164 } else {
165 ShapeError::check_size_match(
166 iterator._layout().shape().inner().iter().product(),
167 res_shape.size(),
168 )
169 .expect("Cannot reshape iterator");
170 if let Some(new_strides) = iterator._layout().is_reshape_possible(&res_shape) {
171 iterator._set_strides(new_strides);
172 iterator._set_last_strides(
173 iterator._layout().strides()[iterator._layout().strides().len() - 1],
174 );
175 } else {
176 panic!("Cannot reshape iterator");
177 }
178 }
179
180 iterator._set_shape(res_shape.clone());
181 iterator
182}
183
184#[track_caller]
185pub(crate) fn expand<T: StridedHelper, S: Into<Shape>>(mut iterator: T, shape: S) -> T {
186 let res_shape: Shape = shape.into();
187 let new_strides = iterator
188 ._layout()
189 .expand_strides(&res_shape)
190 .expect("Cannot expand iterator");
191 iterator._set_shape(res_shape.clone());
192 iterator._set_strides(new_strides);
193 iterator
194}
195
196#[track_caller]
197pub(crate) fn transpose<T: StridedHelper, AXIS: Into<Axis>>(mut iterator: T, axes: AXIS) -> T {
198 let axes = process_axes(axes, iterator._layout().shape().len()).unwrap();
200
201 let mut new_shape = iterator._layout().shape().to_vec();
202 for i in axes.iter() {
203 new_shape[*i] = iterator._layout().shape()[axes[*i]];
204 }
205 let mut new_strides = iterator._layout().strides().to_vec();
206 for i in axes.iter() {
207 new_strides[*i] = iterator._layout().strides()[axes[*i]];
208 }
209 let new_strides: Strides = new_strides.into();
210 let new_shape = Arc::new(new_shape);
211
212 iterator._set_last_strides(new_strides[new_strides.len() - 1]);
213 iterator._set_strides(new_strides);
214 iterator._set_shape(Shape::from(new_shape));
215 iterator
216}