arr_rs/core/operations/
split.rs

1use crate::{
2    core::prelude::*,
3    errors::prelude::*,
4    extensions::prelude::*,
5    validators::prelude::*,
6};
7use crate::prelude::Numeric;
8
9/// `ArrayTrait` - Array Split functions
10pub trait ArraySplit<T: ArrayElement> where Self: Sized + Clone {
11
12    /// Split an array into multiple sub-arrays
13    ///
14    /// # Arguments
15    ///
16    /// * `parts` - indices defining how to split the array
17    /// * `axis` - the axis along which to split. optional, defaults to 0
18    ///
19    /// # Examples
20    /// ```
21    /// use arr_rs::prelude::*;
22    ///
23    /// let arr = array_arange!(i32, 0, 7);
24    /// let split = arr.array_split(3, None).unwrap();
25    /// assert_eq!(vec![array_flat!(i32, 0, 1, 2).unwrap(), array_flat!(i32, 3, 4, 5).unwrap(), array_flat!(i32, 6, 7).unwrap()], split);
26    ///
27    /// let arr = array_arange!(i32, 0, 8);
28    /// let split = arr.array_split(4, None).unwrap();
29    /// assert_eq!(vec![array_flat!(i32, 0, 1, 2).unwrap(), array_flat!(i32, 3, 4).unwrap(), array_flat!(i32, 5, 6).unwrap(), array_flat!(i32, 7, 8).unwrap()], split);
30    /// ```
31    ///
32    /// # Errors
33    ///
34    /// may returns `ArrayError`
35    fn array_split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Array<T>>, ArrayError>;
36
37    /// Split an array into multiple sub-arrays of equal size
38    ///
39    /// # Arguments
40    ///
41    /// * `parts` - indices defining how to split the array
42    /// * `axis` - the axis along which to split. optional, defaults to 0
43    ///
44    /// # Examples
45    /// ```
46    /// use arr_rs::prelude::*;
47    ///
48    /// let arr = array_arange!(i32, 0, 8);
49    /// let split = arr.split(3, None).unwrap();
50    /// assert_eq!(vec![array_flat!(i32, 0, 1, 2).unwrap(), array_flat!(i32, 3, 4, 5).unwrap(), array_flat!(i32, 6, 7, 8).unwrap()], split);
51    ///
52    /// let arr = array_arange!(i32, 0, 7);
53    /// let split = arr.split(4, None).unwrap();
54    /// assert_eq!(vec![array_flat!(i32, 0, 1).unwrap(), array_flat!(i32, 2, 3).unwrap(), array_flat!(i32, 4, 5).unwrap(), array_flat!(i32, 6, 7).unwrap()], split);
55    /// ```
56    ///
57    /// # Errors
58    ///
59    /// may returns `ArrayError`
60    fn split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Array<T>>, ArrayError>;
61
62    /// Split an array into multiple sub-arrays of equal size by axis
63    ///
64    /// # Arguments
65    ///
66    /// * `axis` - the axis along which to split
67    ///
68    /// # Examples
69    /// ```
70    /// use arr_rs::prelude::*;
71    ///
72    /// let arr = array_arange!(i32, 0, 3).reshape(&[2, 2]);
73    /// let split = arr.split_axis(0).unwrap();
74    /// assert_eq!(vec![array!(i32, [[0, 1]]).unwrap(), array!(i32, [[2, 3]]).unwrap()], split);
75    ///
76    /// let arr = array_arange!(i32, 0, 7).reshape(&[2, 2, 2]);
77    /// let split = arr.split_axis(1).unwrap();
78    /// assert_eq!(vec![array!(i32, [[[0, 1]], [[4, 5]]]).unwrap(), array!(i32, [[[2, 3]], [[6, 7]]]).unwrap()], split);
79    /// ```
80    ///
81    /// # Errors
82    ///
83    /// may returns `ArrayError`
84    fn split_axis(&self, axis: usize) -> Result<Vec<Array<T>>, ArrayError>;
85
86    /// Split an array into multiple sub-arrays horizontally (column-wise)
87    ///
88    /// # Arguments
89    ///
90    /// * `parts` - indices defining how to split the array
91    ///
92    /// # Examples
93    /// ```
94    /// use arr_rs::prelude::*;
95    ///
96    /// let arr = array_arange!(i32, 0, 7).reshape(&[2, 2, 2]).unwrap();
97    /// let split = arr.hsplit(2).unwrap();
98    /// assert_eq!(vec![array!(i32, [[[0, 1]], [[4, 5]]]).unwrap(), array!(i32, [[[2, 3]], [[6, 7]]]).unwrap()], split);
99    /// ```
100    ///
101    /// # Errors
102    ///
103    /// may returns `ArrayError`
104    fn hsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError>;
105
106    /// Split an array into multiple sub-arrays vertically (row-wise)
107    ///
108    /// # Arguments
109    ///
110    /// * `parts` - indices defining how to split the array
111    ///
112    /// # Examples
113    /// ```
114    /// use arr_rs::prelude::*;
115    ///
116    /// let arr = array_arange!(i32, 0, 7).reshape(&[2, 2, 2]).unwrap();
117    /// let split = arr.vsplit(2).unwrap();
118    /// assert_eq!(vec![array!(i32, [[[0, 1], [2, 3]]]).unwrap(), array!(i32, [[[4, 5], [6, 7]]]).unwrap()], split);
119    /// ```
120    ///
121    /// # Errors
122    ///
123    /// may returns `ArrayError`
124    fn vsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError>;
125
126    /// Split an array into multiple sub-arrays along the 3rd axis (depth)
127    ///
128    /// # Arguments
129    ///
130    /// * `parts` - indices defining how to split the array
131    ///
132    /// # Examples
133    /// ```
134    /// use arr_rs::prelude::*;
135    ///
136    /// let arr = array_arange!(i32, 0, 7).reshape(&[2, 2, 2]).unwrap();
137    /// let split = arr.dsplit(2).unwrap();
138    /// assert_eq!(vec![array!(i32, [[[0], [2]], [[4], [6]]]).unwrap(), array!(i32, [[[1], [3]], [[5], [7]]]).unwrap()], split);
139    /// ```
140    ///
141    /// # Errors
142    ///
143    /// may returns `ArrayError`
144    fn dsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError>;
145}
146
147impl <T: ArrayElement> ArraySplit<T> for Array<T> {
148
149    fn array_split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Self>, ArrayError> {
150        if parts == 0 { return Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", }) }
151        self.axis_opt_in_bounds(axis)?;
152        if self.is_empty()? { return Ok(vec![self.clone()]) }
153
154        let axis = axis.unwrap_or(0);
155        let n_total = self.len()?;
156
157        let (sections, extras) = (n_total / parts, n_total % parts);
158        let section_sizes = std::iter::repeat(sections + 1)
159            .take(extras)
160            .chain(std::iter::repeat(sections).take(parts - extras))
161            .collect::<Vec<usize>>()
162            .insert_at(0, 0);
163        let mut div_points = vec![0; section_sizes.len()];
164        (0..div_points.len()).for_each(|i| {
165            div_points[i] = section_sizes.clone()[0 ..= i]
166                .iter_mut()
167                .fold(0, |acc, x| { *x += acc; *x });
168        });
169
170        let arr = self.rollaxis(axis.to_isize(), None);
171        arr.clone().map_or_else(|_| Err(arr.err().unwrap()), |arr| {
172            let result = div_points
173                .windows(2)
174                .map(|w| arr.clone().into_iter()
175                    .skip(w[0]).take(w[1] - w[0])
176                    .collect::<Self>())
177                .map(|m| {
178                    if self.ndim()? == 1 { Ok(m) }
179                    else {
180                        let mut new_shape = self.get_shape()?;
181                        new_shape[axis] /= parts;
182                        m.reshape(&new_shape)
183                    }
184                })
185                .collect::<Vec<Result<Self, _>>>();
186            if let Err(err) = result.has_error() { Err(err) }
187            else { Ok(result.into_iter().map(Result::unwrap).collect()) }
188        })
189    }
190
191    fn split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Self>, ArrayError> {
192        self.axis_opt_in_bounds(axis)?;
193        if parts == 0 {
194            Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", })
195        } else {
196            if self.is_empty()? { return Ok(vec![self.clone()]) }
197            let n_total = self.shape[axis.unwrap_or(0)];
198
199            if n_total % parts == 0 {
200                self.array_split(parts, axis)
201            } else {
202                Err(ArrayError::ParameterError { param: "parts", message: "array split does not result in an equal division", })
203            }
204        }
205    }
206
207    fn split_axis(&self, axis: usize) -> Result<Vec<Self>, ArrayError> {
208        self.axis_in_bounds(axis)?;
209        if self.is_empty()? || self.ndim()? == 1 { Ok(vec![self.clone()]) }
210        else { self.array_split(self.shape[axis], Some(axis)) }
211    }
212
213    fn hsplit(&self, parts: usize) -> Result<Vec<Self>, ArrayError> {
214        self.is_dim_unsupported(&[0])?;
215        if parts == 0 {
216            Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", })
217        } else {
218            match self.ndim()? {
219                1 => self.split(parts, Some(0)),
220                _ => self.split(parts, Some(1)),
221            }
222        }
223    }
224
225    fn vsplit(&self, parts: usize) -> Result<Vec<Self>, ArrayError> {
226        self.is_dim_unsupported(&[0, 1])?;
227        if parts == 0 {
228            Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", })
229        } else {
230            self.split(parts, Some(0))
231        }
232    }
233
234    fn dsplit(&self, parts: usize) -> Result<Vec<Self>, ArrayError> {
235        self.is_dim_unsupported(&[0, 1, 2])?;
236        if parts == 0 {
237            Err(ArrayError::ParameterError { param: "parts", message: "number of sections must be larger than 0", })
238        } else {
239            self.split(parts, Some(2))
240        }
241    }
242}
243
244impl <T: ArrayElement> ArraySplit<T> for Result<Array<T>, ArrayError> {
245
246    fn array_split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Array<T>>, ArrayError> {
247        self.clone()?.array_split(parts, axis)
248    }
249
250    fn split(&self, parts: usize, axis: Option<usize>) -> Result<Vec<Array<T>>, ArrayError> {
251        self.clone()?.split(parts, axis)
252    }
253
254    fn split_axis(&self, axis: usize) -> Result<Vec<Array<T>>, ArrayError> {
255        self.clone()?.split_axis(axis)
256    }
257
258    fn hsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError> {
259        self.clone()?.hsplit(parts)
260    }
261
262    fn vsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError> {
263        self.clone()?.vsplit(parts)
264    }
265
266    fn dsplit(&self, parts: usize) -> Result<Vec<Array<T>>, ArrayError> {
267        self.clone()?.dsplit(parts)
268    }
269}