arr_rs/linalg/operations/
products.rs1use crate::{
2 core::prelude::*,
3 errors::prelude::*,
4 extensions::prelude::*,
5 linalg::prelude::*,
6 math::prelude::*,
7 numeric::prelude::*,
8 validators::prelude::*,
9};
10
11pub trait ArrayLinalgProducts<N: NumericOps> where Self: Sized + Clone {
13
14 fn dot(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
34
35 fn vdot(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
54
55 fn inner(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
74
75 fn outer(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
94
95 fn matmul(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
114}
115
116impl <N: NumericOps> ArrayLinalgProducts<N> for Array<N> {
117
118 fn dot(&self, other: &Self) -> Result<Self, ArrayError> {
119 if self.len()? == 1 || other.len()? == 1 {
120 self.multiply(other)
121 } else if self.ndim()? == 1 && other.ndim()? == 1 {
122 self.vdot(other)
123 } else if self.ndim()? == 2 && other.ndim()? == 2 {
124 self.matmul(other)
125 } else if self.ndim()? == 1 || other.ndim()? == 1 {
126 Self::dot_1d(self, other)
127 } else {
128 Self::dot_nd(self, other)
129 }
130 }
131
132 fn vdot(&self, other: &Self) -> Result<Self, ArrayError> {
133 self.len()?.is_equal(&other.len()?)?;
134 let result = self.ravel()?.zip(&other.ravel()?)?
135 .map(|tuple| tuple.0.to_f64() * tuple.1.to_f64())?
136 .fold(0., |a, b| a + b)?;
137 Self::single(N::from(result))
138 }
139
140 fn inner(&self, other: &Self) -> Result<Self, ArrayError> {
141 if self.ndim()? == 1 && other.ndim()? == 1 {
142 self.shapes_align(0, &other.get_shape()?, 0)?;
143 self.zip(other)?
144 .map(|i| i.0.to_f64() * i.1.to_f64())
145 .sum(None)?
146 .to_array_num()
147 } else {
148 self.shapes_align(self.ndim()? - 1, &other.get_shape()?, other.ndim()? - 1)?;
149 Self::inner_nd(self, other)
150 }
151 }
152
153 fn outer(&self, other: &Self) -> Result<Self, ArrayError> {
154 self.into_iter().flat_map(|a| other.into_iter()
155 .map(|b| N::from(a.to_f64() * b.to_f64()))
156 .collect::<Self>())
157 .collect::<Self>()
158 .reshape(&[self.len()?, other.len()?])
159 }
160
161 fn matmul(&self, other: &Self) -> Result<Self, ArrayError> {
162 if self.ndim()? == 1 && other.ndim()? == 1 {
163 self.vdot(other)
164 } else if self.ndim()? == 1 || other.ndim()? == 1 {
165 if self.ndim()? == 1 { self.shapes_align(0, &other.get_shape()?, other.ndim()? - 1)?; }
166 else { self.shapes_align(self.ndim()? - 1, &other.get_shape()?, 0)?; }
167 Self::matmul_1d_nd(self, other)
168 } else if self.ndim()? == 2 && other.ndim()? == 2 {
169 self.shapes_align(0, &other.get_shape()?, 1)?;
170 Self::matmul_iterate(self, other)
171 } else {
172 Self::matmul_nd(self, other)
173 }
174 }
175}
176
177impl <N: NumericOps> ArrayLinalgProducts<N> for Result<Array<N>, ArrayError> {
178
179 fn dot(&self, other: &Array<N>) -> Self {
180 self.clone()?.dot(other)
181 }
182
183 fn vdot(&self, other: &Array<N>) -> Self {
184 self.clone()?.vdot(other)
185 }
186
187 fn inner(&self, other: &Array<N>) -> Self {
188 self.clone()?.inner(other)
189 }
190
191 fn outer(&self, other: &Array<N>) -> Self {
192 self.clone()?.outer(other)
193 }
194
195 fn matmul(&self, other: &Array<N>) -> Self {
196 self.clone()?.matmul(other)
197 }
198}
199
200trait ProductsHelper<N: NumericOps> {
201
202 fn dot_split_array(arr: &Array<N>, axis: usize) -> Result<Vec<Array<N>>, ArrayError> {
203 arr.split_axis(axis)?
204 .into_iter().flatten()
205 .collect::<Array<N>>()
206 .split(arr.get_shape()?.remove_at(axis).iter().product(), None)
207 }
208
209 fn dot_iterate(v_arr_1: &[Array<N>], v_arr_2: &[Array<N>]) -> Result<Array<N>, ArrayError> {
210 v_arr_1.iter().flat_map(|a| {
211 v_arr_2.iter().map(move |b| a.vdot(b))
212 })
213 .collect::<Vec<Result<Array<N>, _>>>()
214 .has_error()?.into_iter()
215 .flat_map(Result::unwrap)
216 .collect::<Array<N>>()
217 .ravel()
218 }
219
220 fn dot_1d(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
221 let arr_1 = if arr_1.ndim()? > 1 { arr_1.get_rows()? } else { vec![arr_1.clone()] };
222 let arr_2 = if arr_2.ndim()? > 1 { arr_2.get_columns()? } else { vec![arr_2.clone()] };
223 Self::dot_iterate(&arr_1, &arr_2)
224 }
225
226 fn dot_nd(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
227 arr_1.shapes_align(arr_1.ndim()? - 1, &arr_2.get_shape()?, arr_2.ndim()? - 2)?;
228 let mut new_shape = arr_1.get_shape()?.remove_at(arr_1.ndim()? - 2);
229 new_shape.extend_from_slice(&arr_2.get_shape()?.remove_at(arr_2.ndim()? - 1));
230 let v_arr_1 = Self::dot_split_array(arr_1, arr_1.ndim()? - 2)?;
231 let v_arr_2 = Self::dot_split_array(arr_2, arr_2.ndim()? - 1)?;
232
233 let rev = arr_2.len()? > arr_1.len()?;
234 let pairs = (0..new_shape.len().to_isize())
235 .collect::<Vec<isize>>()
236 .reverse_if(rev)
237 .into_iter()
238 .step_by(2)
239 .map(|item|
240 if rev { if item <= 1 { vec![item] } else { vec![item, item - 1] } }
241 else if new_shape.len().to_isize() > item + 1 { vec![item + 1, item] }
242 else { vec![item] })
243 .collect::<Vec<Vec<isize>>>()
244 .reverse_if(rev)
245 .into_iter()
246 .flatten()
247 .collect::<Vec<isize>>();
248 Self::dot_iterate(&v_arr_1, &v_arr_2)
249 .reshape(&new_shape)
250 .transpose(Some(pairs))
251 }
252
253 fn inner_nd(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
254 fn inner_split<N: NumericOps>(arr: &Array<N>) -> Result<Vec<Array<N>>, ArrayError> {
255 let r_arr = arr.ravel()?;
256 r_arr.split(arr.get_shape()?.remove_at(arr.ndim()? - 1).iter().product(), None)
257 }
258
259 let mut new_shape = vec![];
260 new_shape.extend_from_slice(&arr_1.get_shape()?.remove_at(arr_1.ndim()? - 1));
261 new_shape.extend_from_slice(&arr_2.get_shape()?.remove_at(arr_2.ndim()? - 1));
262
263 let v_arr_1 = inner_split(arr_1)?;
264 let v_arr_2 = inner_split(arr_2)?;
265
266 v_arr_1.iter()
267 .flat_map(|v_a1| v_arr_2.iter()
268 .map(|v_a2| v_a1.inner(v_a2))
269 .collect::<Vec<Result<Array<N>, ArrayError>>>())
270 .collect::<Vec<Result<Array<N>, ArrayError>>>()
271 .has_error()?.into_iter()
272 .flat_map(Result::unwrap)
273 .collect::<Array<N>>()
274 .reshape(&new_shape)
275 }
276
277 fn matmul_iterate(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
278 let (shape_1, shape_2) = (&arr_1.get_shape()?, &arr_2.get_shape()?);
279 (0..shape_1[0])
280 .flat_map(|i| (0..shape_2[1])
281 .map(move |j| (0..shape_1[1])
282 .fold(0., |acc, k| arr_1[i * shape_1[1] + k].to_f64().mul_add(arr_2[k * shape_2[1] + j].to_f64(), acc))))
283 .map(N::from_f64)
284 .collect::<Array<N>>()
285 .reshape(&[shape_1[0], shape_2[1]])
286 }
287
288 fn matmul_1d_nd(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
289 if arr_1.ndim()? == 1 {
290 if arr_2.ndim()? > 2 {
291 let new_shape = arr_2.get_shape()?.remove_at(0);
292 arr_2.split_axis(0)?.into_iter()
293 .map(|arr| Self::matmul_1d_nd(arr_1, &arr.reshape(&new_shape).unwrap()))
294 .collect::<Vec<Result<Array<N>, _>>>()
295 .has_error()?
296 .into_iter()
297 .flat_map(Result::unwrap)
298 .collect::<Array<N>>()
299 .reshape(&new_shape)
300 } else {
301 let result = arr_1
302 .get_elements()?
303 .into_iter()
304 .zip(&arr_2.split_axis(0)?)
305 .map(|(a, b)| b.into_iter()
306 .map(|item| a.to_f64() * item.to_f64())
307 .sum::<f64>())
308 .map(N::from)
309 .collect::<Array<N>>();
310 Ok(result)
311 }
312 } else if arr_1.ndim()? > 2 {
313 let new_shape = arr_1.get_shape()?.remove_at(0);
314 arr_1.split_axis(0)?
315 .into_iter()
316 .map(|arr| Self::matmul_1d_nd(&arr.reshape(&new_shape).unwrap(), arr_2))
317 .collect::<Vec<Result<Array<N>, _>>>()
318 .has_error()?
319 .into_iter()
320 .flat_map(Result::unwrap)
321 .collect::<Array<N>>()
322 .reshape(&new_shape)
323 } else {
324 let result = arr_1
325 .split_axis(0)?
326 .iter()
327 .map(|arr| (0..arr.shape[arr.shape.len() - 1])
328 .map(|idx| arr[idx].to_f64() * arr_2[idx].to_f64())
329 .sum::<f64>())
330 .map(N::from)
331 .collect::<Array<N>>();
332 Ok(result)
333 }
334 }
335
336 fn matmul_nd(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
337 fn matmul_split<N: NumericOps>(arr: &Array<N>, len: usize, chunk_len: usize) -> Result<Vec<Array<N>>, ArrayError> {
338 let shape_last = arr.get_shape()?
339 .into_iter()
340 .skip(arr.ndim()? - 2)
341 .take(2)
342 .collect::<Vec<usize>>();
343 let result = arr.split(arr.len()? / chunk_len, Some(0))?
344 .into_iter().cycle().take(len)
345 .map(|arr| arr.reshape(&shape_last).unwrap())
346 .collect::<Vec<Array<N>>>();
347 Ok(result)
348 }
349
350 let mut new_shape =
351 if arr_1.ndim()? >= arr_2.ndim()? { arr_1.get_shape()? }
352 else { arr_2.get_shape()? };
353 let shape_len = new_shape.len();
354 new_shape[shape_len - 2] = arr_1.get_shape()?[arr_1.ndim()? - 2];
355 new_shape[shape_len - 1] = arr_2.get_shape()?[arr_2.ndim()? - 1];
356 let chunk_len = arr_1.get_shape()?[arr_1.ndim()? - 2 ..].iter().product::<usize>();
357 let len = std::cmp::max(arr_1.len()?, arr_2.len()?) / chunk_len;
358 matmul_split(arr_1, len, chunk_len)?
359 .into_iter()
360 .zip(&matmul_split(arr_2, len, chunk_len)?)
361 .map(|(a, b)| a.matmul(b))
362 .collect::<Vec<Result<Array<N>, _>>>()
363 .has_error()?
364 .into_iter()
365 .flat_map(Result::unwrap)
366 .collect::<Array<N>>()
367 .reshape(&new_shape)
368 }
369}
370
371impl <N: NumericOps> ProductsHelper<N> for Array<N> {}