arr_rs/core/operations/
broadcast.rs1use crate::{
2 core::prelude::*,
3 errors::prelude::*,
4 validators::prelude::*,
5};
6
7pub trait ArrayBroadcast<T: ArrayElement> where Self: Sized + Clone {
9
10 fn broadcast(&self, other: &Array<T>) -> Result<Array<Tuple2<T, T>>, ArrayError>;
38
39 fn broadcast_to(&self, shape: Vec<usize>) -> Result<Array<T>, ArrayError>;
61
62 fn broadcast_arrays(arrays: Vec<Array<T>>) -> Result<Vec<Array<T>>, ArrayError>;
88}
89
90impl <T: ArrayElement> ArrayBroadcast<T> for Array<T> {
91
92 fn broadcast(&self, other: &Self) -> Result<Array<Tuple2<T, T>>, ArrayError> {
93 self.get_shape()?.is_broadcastable(&other.get_shape()?)?;
94 if self.get_shape()? == other.get_shape()? {
95 return self.get_elements()?.into_iter()
96 .zip(other.get_elements()?)
97 .map(|(a, b)| Tuple2(a, b))
98 .collect::<Array<Tuple2<T, T>>>()
99 .reshape(&self.get_shape()?);
100 }
101
102 let final_shape = self.broadcast_shape(&other.get_shape()?)?;
103
104 let inner_arrays_self = self.extract_inner_arrays();
105 let inner_arrays_other = other.extract_inner_arrays();
106
107 let output_elements = inner_arrays_self.iter().cycle()
108 .zip(inner_arrays_other.iter().cycle())
109 .flat_map( | (inner_self, inner_other) | match (inner_self.len(), inner_other.len()) {
110 (1, _) => inner_self.iter().cycle()
111 .zip(inner_other.iter())
112 .take(final_shape[final_shape.len() - 1])
113 .map(|(a, b) | Tuple2(a.clone(), b.clone()))
114 .collect::< Vec < _ > > (),
115 (_, 1) => inner_self.iter()
116 .zip(inner_other.iter().cycle())
117 .take(final_shape[final_shape.len() - 1])
118 .map(|(a, b) | Tuple2(a.clone(), b.clone()))
119 .collect::<Vec < _ > > (),
120 _ => inner_self.iter().cycle()
121 .zip(inner_other.iter().cycle())
122 .take(final_shape[final_shape.len() - 1])
123 .map(|(a, b) | Tuple2(a.clone(), b.clone()))
124 .collect::< Vec< _ > > (),
125 })
126 .take(final_shape.iter().product())
127 .collect:: < Vec<_ > > ();
128
129 Array::new(output_elements, final_shape)
130 }
131
132 fn broadcast_to(&self, shape: Vec<usize>) -> Result<Self, ArrayError> {
133 self.get_shape()?.is_broadcastable(&shape)?;
134
135 if self.get_shape()?.iter().product::<usize>() == shape.iter().product::<usize>() {
136 self.reshape(&shape)
137 } else {
138 let output_elements: Vec<T> = self.elements
139 .chunks_exact(self.shape[self.shape.len() - 1])
140 .flat_map(|inner| inner.iter()
141 .cycle()
142 .take(shape[shape.len() - 1])
143 .cloned())
144 .cycle()
145 .take(shape.iter().product())
146 .collect();
147
148 Self::new(output_elements, shape)
149 }
150 }
151
152 fn broadcast_arrays(arrays: Vec<Self>) -> Result<Vec<Self>, ArrayError> {
153 arrays.iter()
154 .map(Self::get_shape)
155 .collect::<Vec<Result<Vec<usize>, ArrayError>>>()
156 .has_error()?;
157 let shapes = arrays.iter()
158 .map(|array| array.get_shape().unwrap())
159 .collect::<Vec<_>>();
160
161 let common_shape = Self::common_broadcast_shape(&shapes);
162 if let Ok(common_shape) = common_shape {
163 let result = arrays.iter()
164 .map(|array| array.broadcast_to(common_shape.clone()))
165 .collect::<Vec<Result<Self, _>>>()
166 .has_error()?
167 .into_iter().map(Result::unwrap)
168 .collect();
169 Ok(result)
170 } else {
171 Err(common_shape.err().unwrap())
172 }
173 }
174}
175
176impl <T: ArrayElement> ArrayBroadcast<T> for Result<Array<T>, ArrayError> {
177
178 fn broadcast(&self, other: &Array<T>) -> Result<Array<Tuple2<T, T>>, ArrayError> {
179 self.clone()?.broadcast(other)
180 }
181
182 fn broadcast_to(&self, shape: Vec<usize>) -> Self {
183 self.clone()?.broadcast_to(shape)
184 }
185
186 fn broadcast_arrays(arrays: Vec<Array<T>>) -> Result<Vec<Array<T>>, ArrayError> {
187 Array::broadcast_arrays(arrays)
188 }
189}
190
191impl <T: ArrayElement> Array<T> {
192
193 fn broadcast_shape(&self, shape: &[usize]) -> Result<Vec<usize>, ArrayError> {
194 let max_dim = self.shape.len().max(shape.len());
195 let shape1_padded = self.shape.iter().rev()
196 .copied().chain(std::iter::repeat(1))
197 .take(max_dim);
198 let shape2_padded = shape.iter().rev()
199 .copied().chain(std::iter::repeat(1))
200 .take(max_dim);
201
202 let zipped = shape1_padded.zip(shape2_padded);
203 let result = zipped
204 .map(|(dim1, dim2)| {
205 if dim1 == 1 { Ok(dim2) }
206 else if dim2 == 1 || dim1 == dim2 { Ok(dim1) }
207 else { Err(ArrayError::BroadcastShapeMismatch) }
208 })
209 .collect::<Vec<Result<usize, ArrayError>>>()
210 .has_error()?.iter()
211 .map(|a| *a.as_ref().unwrap())
212 .collect();
213 Ok(result)
214 }
215
216 fn common_broadcast_shape(shapes: &[Vec<usize>]) -> Result<Vec<usize>, ArrayError> {
217 let max_dim = shapes.iter()
218 .map(Vec::len)
219 .max().unwrap_or(0);
220
221 let shapes_padded: Vec<_> = shapes
222 .iter()
223 .map(|shape| shape.iter().rev().copied()
224 .chain(std::iter::repeat(1))
225 .take(max_dim)
226 .collect::<Vec<_>>()
227 )
228 .collect();
229
230 let common_shape: Vec<usize> = (0..max_dim)
231 .map(|dim_idx| shapes_padded.iter()
232 .map(|shape| shape[dim_idx])
233 .max().unwrap_or(1)
234 )
235 .collect();
236
237 let is_compatible = shapes_padded.iter()
238 .all(|shape| common_shape.iter().enumerate()
239 .all(|(dim_idx, &common_dim)| {
240 let dim = shape[dim_idx];
241 dim == common_dim || dim == 1 || common_dim == 1
242 })
243 );
244
245 if is_compatible { Ok(common_shape.into_iter().rev().collect()) }
246 else { Err(ArrayError::BroadcastShapeMismatch) }
247 }
248
249 fn extract_inner_arrays(&self) -> Vec<Vec<T>> {
250 match self.shape.len() {
251 1 => vec![self.elements.clone()],
252 _ => self.elements
253 .chunks_exact(*self.shape.last().unwrap())
254 .map(Vec::from)
255 .collect(),
256 }
257 }
258
259 pub(crate) fn broadcast_h2<S: ArrayElement>(&self, other: &Array<S>) -> Result<TupleH2<T, S>, ArrayError> {
260 let tmp_other = Self::single(T::zero()).broadcast_to(other.get_shape()?)?;
261 let tmp_array = self.broadcast(&tmp_other)?;
262
263 let array = tmp_array.clone().into_iter()
264 .map(|t| t.0).collect::<Self>()
265 .reshape(&tmp_array.get_shape()?)?;
266 let other = other.broadcast_to(array.get_shape()?)?;
267
268 Ok((array, other))
269 }
270
271 pub(crate) fn broadcast_h3<S: ArrayElement, Q: ArrayElement>(&self, other_1: &Array<S>, other_2: &Array<Q>) -> Result<TupleH3<T, S, Q>, ArrayError> {
272 let tmp_other_1 = Self::single(T::zero()).broadcast_to(other_1.get_shape()?)?;
273 let tmp_other_2 = Self::single(T::zero()).broadcast_to(other_2.get_shape()?)?;
274 let broadcasted = Self::broadcast_arrays(vec![self.clone(), tmp_other_1, tmp_other_2])?;
275
276 let array = broadcasted[0].clone();
277 let other_1 = other_1.broadcast_to(array.get_shape()?)?;
278 let other_2 = other_2.broadcast_to(array.get_shape()?)?;
279
280 Ok((array, other_1, other_2))
281 }
282}