arr_rs/core/operations/
broadcast.rs

1use crate::{
2    core::prelude::*,
3    errors::prelude::*,
4    validators::prelude::*,
5};
6
7/// `ArrayTrait` - Array Broadcast functions
8pub trait ArrayBroadcast<T: ArrayElement> where Self: Sized + Clone {
9
10    /// Broadcast an array to a new shape
11    ///
12    /// # Arguments
13    ///
14    /// * `other` - other array for broadcasting
15    ///
16    /// # Examples
17    ///
18    /// ```
19    /// use arr_rs::prelude::*;
20    ///
21    /// let expected = Array::new(vec![
22    ///     (1, 4), (1, 5), (1, 6),
23    ///     (2, 4), (2, 5), (2, 6),
24    ///     (3, 4), (3, 5), (3, 6)
25    /// ].into_iter().map(Tuple2::from_tuple).collect(), vec![3, 3]).unwrap();
26    ///
27    /// let arr_1 = array!(i32, [[1], [2], [3]]).unwrap();
28    /// let arr_2 = array!(i32, [[4, 5, 6]]).unwrap();
29    ///
30    /// let broadcast = arr_1.broadcast(&arr_2).unwrap();
31    /// assert_eq!(expected, broadcast);
32    /// ```
33    ///
34    /// # Errors
35    ///
36    /// may returns `ArrayError`
37    fn broadcast(&self, other: &Array<T>) -> Result<Array<Tuple2<T, T>>, ArrayError>;
38
39    /// Broadcast an array to a new shape
40    ///
41    /// # Arguments
42    ///
43    /// * `other` - other array for broadcasting
44    ///
45    /// # Examples
46    ///
47    /// ```
48    /// use arr_rs::prelude::*;
49    ///
50    /// let expected = Array::new(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], vec![3, 3]).unwrap();
51    /// let arr_1 = array!(i32, [[1], [2], [3]]).unwrap();
52    ///
53    /// let broadcast = arr_1.broadcast_to(vec![3, 3]).unwrap();
54    /// assert_eq!(expected, broadcast);
55    /// ```
56    ///
57    /// # Errors
58    ///
59    /// may returns `ArrayError`
60    fn broadcast_to(&self, shape: Vec<usize>) -> Result<Array<T>, ArrayError>;
61
62    /// Broadcast a list of arrays to a common shape
63    ///
64    /// # Arguments
65    ///
66    /// * `arrays` - list of arrays for broadcasting
67    ///
68    /// # Examples
69    ///
70    /// ```
71    /// use arr_rs::prelude::*;
72    ///
73    /// let expected = vec![
74    ///     Array::new(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], vec![3, 3]).unwrap(),
75    ///     Array::new(vec![4, 5, 6, 4, 5, 6, 4, 5, 6], vec![3, 3]).unwrap(),
76    /// ];
77    /// let arr_1 = array!(i32, [[1], [2], [3]]).unwrap();
78    /// let arr_2 = array!(i32, [4, 5, 6]).unwrap();
79    ///
80    /// let broadcast = Array::broadcast_arrays(vec![arr_1 ,arr_2]).unwrap();
81    /// assert_eq!(expected, broadcast);
82    /// ```
83    ///
84    /// # Errors
85    ///
86    /// may returns `ArrayError`
87    fn broadcast_arrays(arrays: Vec<Array<T>>) -> Result<Vec<Array<T>>, ArrayError>;
88}
89
90impl <T: ArrayElement> ArrayBroadcast<T> for Array<T> {
91
92    fn broadcast(&self, other: &Self) -> Result<Array<Tuple2<T, T>>, ArrayError> {
93        self.get_shape()?.is_broadcastable(&other.get_shape()?)?;
94        if self.get_shape()? == other.get_shape()? {
95            return self.get_elements()?.into_iter()
96                .zip(other.get_elements()?)
97                .map(|(a, b)| Tuple2(a, b))
98                .collect::<Array<Tuple2<T, T>>>()
99                .reshape(&self.get_shape()?);
100        }
101
102        let final_shape = self.broadcast_shape(&other.get_shape()?)?;
103
104        let inner_arrays_self = self.extract_inner_arrays();
105        let inner_arrays_other = other.extract_inner_arrays();
106
107        let output_elements = inner_arrays_self.iter().cycle()
108            .zip(inner_arrays_other.iter().cycle())
109            .flat_map( | (inner_self, inner_other) | match (inner_self.len(), inner_other.len()) {
110                (1, _) => inner_self.iter().cycle()
111                    .zip(inner_other.iter())
112                    .take(final_shape[final_shape.len() - 1])
113                    .map(|(a, b) | Tuple2(a.clone(), b.clone()))
114                    .collect::< Vec < _ > > (),
115                (_, 1) => inner_self.iter()
116                    .zip(inner_other.iter().cycle())
117                    .take(final_shape[final_shape.len() - 1])
118                    .map(|(a, b) | Tuple2(a.clone(), b.clone()))
119                    .collect::<Vec < _ > > (),
120                _ => inner_self.iter().cycle()
121                    .zip(inner_other.iter().cycle())
122                    .take(final_shape[final_shape.len() - 1])
123                    .map(|(a, b) | Tuple2(a.clone(), b.clone()))
124                    .collect::< Vec< _ > > (),
125            })
126            .take(final_shape.iter().product())
127            .collect:: < Vec<_ > > ();
128
129        Array::new(output_elements, final_shape)
130    }
131
132    fn broadcast_to(&self, shape: Vec<usize>) -> Result<Self, ArrayError> {
133        self.get_shape()?.is_broadcastable(&shape)?;
134
135        if self.get_shape()?.iter().product::<usize>() == shape.iter().product::<usize>() {
136            self.reshape(&shape)
137        } else {
138            let output_elements: Vec<T> = self.elements
139                .chunks_exact(self.shape[self.shape.len() - 1])
140                .flat_map(|inner| inner.iter()
141                    .cycle()
142                    .take(shape[shape.len() - 1])
143                    .cloned())
144                .cycle()
145                .take(shape.iter().product())
146                .collect();
147
148            Self::new(output_elements, shape)
149        }
150    }
151
152    fn broadcast_arrays(arrays: Vec<Self>) -> Result<Vec<Self>, ArrayError> {
153        arrays.iter()
154            .map(Self::get_shape)
155            .collect::<Vec<Result<Vec<usize>, ArrayError>>>()
156            .has_error()?;
157        let shapes = arrays.iter()
158            .map(|array| array.get_shape().unwrap())
159            .collect::<Vec<_>>();
160
161        let common_shape = Self::common_broadcast_shape(&shapes);
162        if let Ok(common_shape) = common_shape {
163            let result = arrays.iter()
164                .map(|array| array.broadcast_to(common_shape.clone()))
165                .collect::<Vec<Result<Self, _>>>()
166                .has_error()?
167                .into_iter().map(Result::unwrap)
168                .collect();
169            Ok(result)
170        } else {
171            Err(common_shape.err().unwrap())
172        }
173    }
174}
175
176impl <T: ArrayElement> ArrayBroadcast<T> for Result<Array<T>, ArrayError> {
177
178    fn broadcast(&self, other: &Array<T>) -> Result<Array<Tuple2<T, T>>, ArrayError> {
179        self.clone()?.broadcast(other)
180    }
181
182    fn broadcast_to(&self, shape: Vec<usize>) -> Self {
183        self.clone()?.broadcast_to(shape)
184    }
185
186    fn broadcast_arrays(arrays: Vec<Array<T>>) -> Result<Vec<Array<T>>, ArrayError> {
187        Array::broadcast_arrays(arrays)
188    }
189}
190
191impl <T: ArrayElement> Array<T> {
192
193    fn broadcast_shape(&self, shape: &[usize]) -> Result<Vec<usize>, ArrayError> {
194        let max_dim = self.shape.len().max(shape.len());
195        let shape1_padded = self.shape.iter().rev()
196            .copied().chain(std::iter::repeat(1))
197            .take(max_dim);
198        let shape2_padded = shape.iter().rev()
199            .copied().chain(std::iter::repeat(1))
200            .take(max_dim);
201
202        let zipped = shape1_padded.zip(shape2_padded);
203        let result = zipped
204            .map(|(dim1, dim2)| {
205                if dim1 == 1 { Ok(dim2) }
206                else if dim2 == 1 || dim1 == dim2 { Ok(dim1) }
207                else { Err(ArrayError::BroadcastShapeMismatch) }
208            })
209            .collect::<Vec<Result<usize, ArrayError>>>()
210            .has_error()?.iter()
211            .map(|a| *a.as_ref().unwrap())
212            .collect();
213        Ok(result)
214    }
215
216    fn common_broadcast_shape(shapes: &[Vec<usize>]) -> Result<Vec<usize>, ArrayError> {
217        let max_dim = shapes.iter()
218            .map(Vec::len)
219            .max().unwrap_or(0);
220
221        let shapes_padded: Vec<_> = shapes
222            .iter()
223            .map(|shape| shape.iter().rev().copied()
224                .chain(std::iter::repeat(1))
225                .take(max_dim)
226                .collect::<Vec<_>>()
227            )
228            .collect();
229
230        let common_shape: Vec<usize> = (0..max_dim)
231            .map(|dim_idx| shapes_padded.iter()
232                .map(|shape| shape[dim_idx])
233                .max().unwrap_or(1)
234            )
235            .collect();
236
237        let is_compatible = shapes_padded.iter()
238            .all(|shape| common_shape.iter().enumerate()
239                .all(|(dim_idx, &common_dim)| {
240                    let dim = shape[dim_idx];
241                    dim == common_dim || dim == 1 || common_dim == 1
242                })
243            );
244
245        if is_compatible { Ok(common_shape.into_iter().rev().collect()) }
246        else { Err(ArrayError::BroadcastShapeMismatch) }
247    }
248
249    fn extract_inner_arrays(&self) -> Vec<Vec<T>> {
250        match self.shape.len() {
251            1 => vec![self.elements.clone()],
252            _ => self.elements
253                .chunks_exact(*self.shape.last().unwrap())
254                .map(Vec::from)
255                .collect(),
256        }
257    }
258
259    pub(crate) fn broadcast_h2<S: ArrayElement>(&self, other: &Array<S>) -> Result<TupleH2<T, S>, ArrayError> {
260        let tmp_other = Self::single(T::zero()).broadcast_to(other.get_shape()?)?;
261        let tmp_array = self.broadcast(&tmp_other)?;
262
263        let array = tmp_array.clone().into_iter()
264            .map(|t| t.0).collect::<Self>()
265            .reshape(&tmp_array.get_shape()?)?;
266        let other = other.broadcast_to(array.get_shape()?)?;
267
268        Ok((array, other))
269    }
270
271    pub(crate) fn broadcast_h3<S: ArrayElement, Q: ArrayElement>(&self, other_1: &Array<S>, other_2: &Array<Q>) -> Result<TupleH3<T, S, Q>, ArrayError> {
272        let tmp_other_1 = Self::single(T::zero()).broadcast_to(other_1.get_shape()?)?;
273        let tmp_other_2 = Self::single(T::zero()).broadcast_to(other_2.get_shape()?)?;
274        let broadcasted = Self::broadcast_arrays(vec![self.clone(), tmp_other_1, tmp_other_2])?;
275
276        let array = broadcasted[0].clone();
277        let other_1 = other_1.broadcast_to(array.get_shape()?)?;
278        let other_2 = other_2.broadcast_to(array.get_shape()?)?;
279
280        Ok((array, other_1, other_2))
281    }
282}