arr_rs/core/operations/split.rs
1use crate::{
2 core::prelude::*,
3 errors::prelude::*,
4 extensions::prelude::*,
5 validators::prelude::*,
6};
7use crate::prelude::Numeric;
8
9/// `ArrayTrait` - Array Split functions
10pub trait ArraySplit<T: ArrayElement> where Self: Sized + Clone {
11
12 /// Split an array into multiple sub-arrays
13 ///
14 /// # Arguments
15 ///
16 /// * `parts` - indices defining how to split the array
17 /// * `axis` - the axis along which to split. optional, defaults to 0
18 ///
19 /// # Examples
20 /// ```
21 /// use arr_rs::prelude::*;
22 ///
23 /// let arr = array_arange!(i32, 0, 7);
24 /// let split = arr.array_split(3, None).unwrap();
25 /// assert_eq!(vec![array_flat!(i32, 0, 1, 2).unwrap(), array_flat!(i32, 3, 4, 5).unwrap(), array_flat!(i32, 6, 7).unwrap()], split);
26 ///
27 /// let arr = array_arange!(i32, 0, 8);
28 /// let split = arr.array_split(4, None).unwrap();
29 /// assert_eq!(vec![array_flat!(i32, 0, 1, 2).unwrap(), array_flat!(i32, 3, 4).unwrap(), array_flat!(i32, 5, 6).unwrap(), array_flat!(i32, 7, 8).unwrap()], split);
30 /// ```
31 ///
32 /// # Errors
33 ///
34 /// may returns `ArrayError`
35 fn array_split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Array<T>>, ArrayError>;
36
37 /// Split an array into multiple sub-arrays of equal size
38 ///
39 /// # Arguments
40 ///
41 /// * `parts` - indices defining how to split the array
42 /// * `axis` - the axis along which to split. optional, defaults to 0
43 ///
44 /// # Examples
45 /// ```
46 /// use arr_rs::prelude::*;
47 ///
48 /// let arr = array_arange!(i32, 0, 8);
49 /// let split = arr.split(3, None).unwrap();
50 /// assert_eq!(vec![array_flat!(i32, 0, 1, 2).unwrap(), array_flat!(i32, 3, 4, 5).unwrap(), array_flat!(i32, 6, 7, 8).unwrap()], split);
51 ///
52 /// let arr = array_arange!(i32, 0, 7);
53 /// let split = arr.split(4, None).unwrap();
54 /// assert_eq!(vec![array_flat!(i32, 0, 1).unwrap(), array_flat!(i32, 2, 3).unwrap(), array_flat!(i32, 4, 5).unwrap(), array_flat!(i32, 6, 7).unwrap()], split);
55 /// ```
56 ///
57 /// # Errors
58 ///
59 /// may returns `ArrayError`
60 fn split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Array<T>>, ArrayError>;
61
62 /// Split an array into multiple sub-arrays of equal size by axis
63 ///
64 /// # Arguments
65 ///
66 /// * `axis` - the axis along which to split
67 ///
68 /// # Examples
69 /// ```
70 /// use arr_rs::prelude::*;
71 ///
72 /// let arr = array_arange!(i32, 0, 3).reshape(&[2, 2]);
73 /// let split = arr.split_axis(0).unwrap();
74 /// assert_eq!(vec![array!(i32, [[0, 1]]).unwrap(), array!(i32, [[2, 3]]).unwrap()], split);
75 ///
76 /// let arr = array_arange!(i32, 0, 7).reshape(&[2, 2, 2]);
77 /// let split = arr.split_axis(1).unwrap();
78 /// assert_eq!(vec![array!(i32, [[[0, 1]], [[4, 5]]]).unwrap(), array!(i32, [[[2, 3]], [[6, 7]]]).unwrap()], split);
79 /// ```
80 ///
81 /// # Errors
82 ///
83 /// may returns `ArrayError`
84 fn split_axis(&self, axis: usize) -> Result<Vec<Array<T>>, ArrayError>;
85
86 /// Split an array into multiple sub-arrays horizontally (column-wise)
87 ///
88 /// # Arguments
89 ///
90 /// * `parts` - indices defining how to split the array
91 ///
92 /// # Examples
93 /// ```
94 /// use arr_rs::prelude::*;
95 ///
96 /// let arr = array_arange!(i32, 0, 7).reshape(&[2, 2, 2]).unwrap();
97 /// let split = arr.hsplit(2).unwrap();
98 /// assert_eq!(vec![array!(i32, [[[0, 1]], [[4, 5]]]).unwrap(), array!(i32, [[[2, 3]], [[6, 7]]]).unwrap()], split);
99 /// ```
100 ///
101 /// # Errors
102 ///
103 /// may returns `ArrayError`
104 fn hsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError>;
105
106 /// Split an array into multiple sub-arrays vertically (row-wise)
107 ///
108 /// # Arguments
109 ///
110 /// * `parts` - indices defining how to split the array
111 ///
112 /// # Examples
113 /// ```
114 /// use arr_rs::prelude::*;
115 ///
116 /// let arr = array_arange!(i32, 0, 7).reshape(&[2, 2, 2]).unwrap();
117 /// let split = arr.vsplit(2).unwrap();
118 /// assert_eq!(vec![array!(i32, [[[0, 1], [2, 3]]]).unwrap(), array!(i32, [[[4, 5], [6, 7]]]).unwrap()], split);
119 /// ```
120 ///
121 /// # Errors
122 ///
123 /// may returns `ArrayError`
124 fn vsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError>;
125
126 /// Split an array into multiple sub-arrays along the 3rd axis (depth)
127 ///
128 /// # Arguments
129 ///
130 /// * `parts` - indices defining how to split the array
131 ///
132 /// # Examples
133 /// ```
134 /// use arr_rs::prelude::*;
135 ///
136 /// let arr = array_arange!(i32, 0, 7).reshape(&[2, 2, 2]).unwrap();
137 /// let split = arr.dsplit(2).unwrap();
138 /// assert_eq!(vec![array!(i32, [[[0], [2]], [[4], [6]]]).unwrap(), array!(i32, [[[1], [3]], [[5], [7]]]).unwrap()], split);
139 /// ```
140 ///
141 /// # Errors
142 ///
143 /// may returns `ArrayError`
144 fn dsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError>;
145}
146
147impl <T: ArrayElement> ArraySplit<T> for Array<T> {
148
149 fn array_split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Self>, ArrayError> {
150 if parts == 0 { return Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", }) }
151 self.axis_opt_in_bounds(axis)?;
152 if self.is_empty()? { return Ok(vec![self.clone()]) }
153
154 let axis = axis.unwrap_or(0);
155 let n_total = self.len()?;
156
157 let (sections, extras) = (n_total / parts, n_total % parts);
158 let section_sizes = std::iter::repeat(sections + 1)
159 .take(extras)
160 .chain(std::iter::repeat(sections).take(parts - extras))
161 .collect::<Vec<usize>>()
162 .insert_at(0, 0);
163 let mut div_points = vec![0; section_sizes.len()];
164 (0..div_points.len()).for_each(|i| {
165 div_points[i] = section_sizes.clone()[0 ..= i]
166 .iter_mut()
167 .fold(0, |acc, x| { *x += acc; *x });
168 });
169
170 let arr = self.rollaxis(axis.to_isize(), None);
171 arr.clone().map_or_else(|_| Err(arr.err().unwrap()), |arr| {
172 let result = div_points
173 .windows(2)
174 .map(|w| arr.clone().into_iter()
175 .skip(w[0]).take(w[1] - w[0])
176 .collect::<Self>())
177 .map(|m| {
178 if self.ndim()? == 1 { Ok(m) }
179 else {
180 let mut new_shape = self.get_shape()?;
181 new_shape[axis] /= parts;
182 m.reshape(&new_shape)
183 }
184 })
185 .collect::<Vec<Result<Self, _>>>();
186 if let Err(err) = result.has_error() { Err(err) }
187 else { Ok(result.into_iter().map(Result::unwrap).collect()) }
188 })
189 }
190
191 fn split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Self>, ArrayError> {
192 self.axis_opt_in_bounds(axis)?;
193 if parts == 0 {
194 Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", })
195 } else {
196 if self.is_empty()? { return Ok(vec![self.clone()]) }
197 let n_total = self.shape[axis.unwrap_or(0)];
198
199 if n_total % parts == 0 {
200 self.array_split(parts, axis)
201 } else {
202 Err(ArrayError::ParameterError { param: "parts", message: "array split does not result in an equal division", })
203 }
204 }
205 }
206
207 fn split_axis(&self, axis: usize) -> Result<Vec<Self>, ArrayError> {
208 self.axis_in_bounds(axis)?;
209 if self.is_empty()? || self.ndim()? == 1 { Ok(vec![self.clone()]) }
210 else { self.array_split(self.shape[axis], Some(axis)) }
211 }
212
213 fn hsplit(&self, parts: usize) -> Result<Vec<Self>, ArrayError> {
214 self.is_dim_unsupported(&[0])?;
215 if parts == 0 {
216 Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", })
217 } else {
218 match self.ndim()? {
219 1 => self.split(parts, Some(0)),
220 _ => self.split(parts, Some(1)),
221 }
222 }
223 }
224
225 fn vsplit(&self, parts: usize) -> Result<Vec<Self>, ArrayError> {
226 self.is_dim_unsupported(&[0, 1])?;
227 if parts == 0 {
228 Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", })
229 } else {
230 self.split(parts, Some(0))
231 }
232 }
233
234 fn dsplit(&self, parts: usize) -> Result<Vec<Self>, ArrayError> {
235 self.is_dim_unsupported(&[0, 1, 2])?;
236 if parts == 0 {
237 Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", })
238 } else {
239 self.split(parts, Some(2))
240 }
241 }
242}
243
244impl <T: ArrayElement> ArraySplit<T> for Result<Array<T>, ArrayError> {
245
246 fn array_split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Array<T>>, ArrayError> {
247 self.clone()?.array_split(parts, axis)
248 }
249
250 fn split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Array<T>>, ArrayError> {
251 self.clone()?.split(parts, axis)
252 }
253
254 fn split_axis(&self, axis: usize) -> Result<Vec<Array<T>>, ArrayError> {
255 self.clone()?.split_axis(axis)
256 }
257
258 fn hsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError> {
259 self.clone()?.hsplit(parts)
260 }
261
262 fn vsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError> {
263 self.clone()?.vsplit(parts)
264 }
265
266 fn dsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError> {
267 self.clone()?.dsplit(parts)
268 }
269}