arr_rs/linalg/operations/products.rs

1use crate::{
2    core::prelude::*,
3    errors::prelude::*,
4    extensions::prelude::*,
5    linalg::prelude::*,
6    math::prelude::*,
7    numeric::prelude::*,
8    validators::prelude::*,
9};
10
11/// `ArrayTrait` - Array Linalg Products functions
12pub trait ArrayLinalgProducts<N: NumericOps> where Self: Sized + Clone {
13
14    /// Dot product of two arrays
15    ///
16    /// # Arguments
17    ///
18    /// * `other` - other array to perform operations with
19    ///
20    /// # Examples
21    ///
22    /// ```
23    /// use arr_rs::prelude::*;
24    ///
25    /// assert_eq!(Array::single(12), Array::single(3).dot(&Array::single(4).unwrap()));
26    /// assert_eq!(Array::single(20), Array::flat(vec![1, 2, 3]).dot(&Array::flat(vec![2, 3, 4]).unwrap()));
27    /// assert_eq!(Array::new(vec![4, 1, 2, 2], vec![2, 2]), Array::new(vec![1, 0, 0, 1], vec![2, 2]).dot(&Array::new(vec![4, 1, 2, 2], vec![2, 2]).unwrap()));
28    /// ```
29    ///
30    /// # Errors
31    ///
32    /// may returns `ArrayError`
33    fn dot(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
34
35    /// Dot product of two vectors. If input is an array, it will be raveled
36    ///
37    /// # Arguments
38    ///
39    /// * `other` - other array to perform operations with
40    ///
41    /// # Examples
42    ///
43    /// ```
44    /// use arr_rs::prelude::*;
45    ///
46    /// assert_eq!(Array::single(20), Array::flat(vec![1, 2, 3]).vdot(&Array::flat(vec![2, 3, 4]).unwrap()));
47    /// assert_eq!(Array::single(30), Array::new(vec![1, 4, 5, 6], vec![2, 2]).vdot(&Array::new(vec![4, 1, 2, 2], vec![2, 2]).unwrap()));
48    /// ```
49    ///
50    /// # Errors
51    ///
52    /// may returns `ArrayError`
53    fn vdot(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
54
55    /// Inner product of two arrays
56    ///
57    /// # Arguments
58    ///
59    /// * `other` - other array to perform operations with
60    ///
61    /// # Examples
62    ///
63    /// ```
64    /// use arr_rs::prelude::*;
65    ///
66    /// assert_eq!(Array::single(20), Array::flat(vec![1, 2, 3, 4]).inner(&Array::flat(vec![4, 3, 2, 1]).unwrap()));
67    /// assert_eq!(Array::new(vec![10, 4, 24, 10], vec![2, 2]), Array::new(vec![1, 2, 3, 4], vec![2, 2]).inner(&Array::new(vec![4, 3, 2, 1], vec![2, 2]).unwrap()));
68    /// ```
69    ///
70    /// # Errors
71    ///
72    /// may returns `ArrayError`
73    fn inner(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
74
75    /// Outer product of two arrays
76    ///
77    /// # Arguments
78    ///
79    /// * `other` - other array to perform operations with
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// use arr_rs::prelude::*;
85    ///
86    /// assert_eq!(Array::new(vec![4, 3, 8, 6], vec![2, 2]), Array::flat(vec![1, 2]).outer(&Array::flat(vec![4, 3]).unwrap()));
87    /// assert_eq!(Array::new(vec![4, 3, 2, 1, 8, 6, 4, 2, 12, 9, 6, 3, 16, 12, 8, 4], vec![4, 4]), Array::new(vec![1, 2, 3, 4], vec![2, 2]).outer(&Array::new(vec![4, 3, 2, 1], vec![2, 2]).unwrap()));
88    /// ```
89    ///
90    /// # Errors
91    ///
92    /// may returns `ArrayError`
93    fn outer(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
94
95    /// Matrix product of two arrays
96    ///
97    /// # Arguments
98    ///
99    /// * `other` - other array to perform operations with
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// use arr_rs::prelude::*;
105    ///
106    /// assert_eq!(Array::single(5), Array::flat(vec![1, 2]).matmul(&Array::flat(vec![1, 2]).unwrap()));
107    /// assert_eq!(Array::new(vec![5, 8, 8, 13], vec![2, 2]), Array::new(vec![1, 2, 2, 3], vec![2, 2]).matmul(&Array::new(vec![1, 2, 2, 3], vec![2, 2]).unwrap()));
108    /// ```
109    ///
110    /// # Errors
111    ///
112    /// may returns `ArrayError`
113    fn matmul(&self, other: &Array<N>) -> Result<Array<N>, ArrayError>;
114}
115
116impl <N: NumericOps> ArrayLinalgProducts<N> for Array<N> {
117
118    fn dot(&self, other: &Self) -> Result<Self, ArrayError> {
119        if self.len()? == 1 || other.len()? == 1 {
120            self.multiply(other)
121        } else if self.ndim()? == 1 && other.ndim()? == 1 {
122            self.vdot(other)
123        } else if self.ndim()? == 2 && other.ndim()? == 2 {
124            self.matmul(other)
125        } else if self.ndim()? == 1 || other.ndim()? == 1 {
126            Self::dot_1d(self, other)
127        } else {
128            Self::dot_nd(self, other)
129        }
130    }
131
132    fn vdot(&self, other: &Self) -> Result<Self, ArrayError> {
133        self.len()?.is_equal(&other.len()?)?;
134        let result = self.ravel()?.zip(&other.ravel()?)?
135            .map(|tuple| tuple.0.to_f64() * tuple.1.to_f64())?
136            .fold(0., |a, b| a + b)?;
137        Self::single(N::from(result))
138    }
139
140    fn inner(&self, other: &Self) -> Result<Self, ArrayError> {
141        if self.ndim()? == 1 && other.ndim()? == 1 {
142            self.shapes_align(0, &other.get_shape()?, 0)?;
143            self.zip(other)?
144                .map(|i| i.0.to_f64() * i.1.to_f64())
145                .sum(None)?
146                .to_array_num()
147        } else {
148            self.shapes_align(self.ndim()? - 1, &other.get_shape()?, other.ndim()? - 1)?;
149            Self::inner_nd(self, other)
150        }
151    }
152
153    fn outer(&self, other: &Self) -> Result<Self, ArrayError> {
154        self.into_iter().flat_map(|a| other.into_iter()
155            .map(|b| N::from(a.to_f64() * b.to_f64()))
156            .collect::<Self>())
157            .collect::<Self>()
158            .reshape(&[self.len()?, other.len()?])
159    }
160
161    fn matmul(&self, other: &Self) -> Result<Self, ArrayError> {
162        if self.ndim()? == 1 && other.ndim()? == 1 {
163            self.vdot(other)
164        } else if self.ndim()? == 1 || other.ndim()? == 1 {
165            if self.ndim()? == 1 { self.shapes_align(0, &other.get_shape()?, other.ndim()? - 1)?; }
166            else { self.shapes_align(self.ndim()? - 1, &other.get_shape()?, 0)?; }
167            Self::matmul_1d_nd(self, other)
168        } else if self.ndim()? == 2 && other.ndim()? == 2 {
169            self.shapes_align(0, &other.get_shape()?, 1)?;
170            Self::matmul_iterate(self, other)
171        } else {
172            Self::matmul_nd(self, other)
173        }
174    }
175}
176
177impl <N: NumericOps> ArrayLinalgProducts<N> for Result<Array<N>, ArrayError> {
178
179    fn dot(&self, other: &Array<N>) -> Self {
180        self.clone()?.dot(other)
181    }
182
183    fn vdot(&self, other: &Array<N>) -> Self {
184        self.clone()?.vdot(other)
185    }
186
187    fn inner(&self, other: &Array<N>) -> Self {
188        self.clone()?.inner(other)
189    }
190
191    fn outer(&self, other: &Array<N>) -> Self {
192        self.clone()?.outer(other)
193    }
194
195    fn matmul(&self, other: &Array<N>) -> Self {
196        self.clone()?.matmul(other)
197    }
198}
199
200trait ProductsHelper<N: NumericOps> {
201
202    fn dot_split_array(arr: &Array<N>, axis: usize) -> Result<Vec<Array<N>>, ArrayError> {
203        arr.split_axis(axis)?
204            .into_iter().flatten()
205            .collect::<Array<N>>()
206            .split(arr.get_shape()?.remove_at(axis).iter().product(), None)
207    }
208
209    fn dot_iterate(v_arr_1: &[Array<N>], v_arr_2: &[Array<N>]) -> Result<Array<N>, ArrayError> {
210        v_arr_1.iter().flat_map(|a| {
211            v_arr_2.iter().map(move |b| a.vdot(b))
212        })
213            .collect::<Vec<Result<Array<N>, _>>>()
214            .has_error()?.into_iter()
215            .flat_map(Result::unwrap)
216            .collect::<Array<N>>()
217            .ravel()
218    }
219
220    fn dot_1d(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
221        let arr_1 = if arr_1.ndim()? > 1 { arr_1.get_rows()? } else { vec![arr_1.clone()] };
222        let arr_2 = if arr_2.ndim()? > 1 { arr_2.get_columns()? } else { vec![arr_2.clone()] };
223        Self::dot_iterate(&arr_1, &arr_2)
224    }
225
226    fn dot_nd(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
227        arr_1.shapes_align(arr_1.ndim()? - 1, &arr_2.get_shape()?, arr_2.ndim()? - 2)?;
228        let mut new_shape = arr_1.get_shape()?.remove_at(arr_1.ndim()? - 2);
229        new_shape.extend_from_slice(&arr_2.get_shape()?.remove_at(arr_2.ndim()? - 1));
230        let v_arr_1 = Self::dot_split_array(arr_1, arr_1.ndim()? - 2)?;
231        let v_arr_2 = Self::dot_split_array(arr_2, arr_2.ndim()? - 1)?;
232
233        let rev = arr_2.len()? > arr_1.len()?;
234        let pairs = (0..new_shape.len().to_isize())
235            .collect::<Vec<isize>>()
236            .reverse_if(rev)
237            .into_iter()
238            .step_by(2)
239            .map(|item|
240                if rev { if item <= 1 { vec![item] } else { vec![item, item - 1] } }
241                else if new_shape.len().to_isize() > item + 1 { vec![item + 1, item] }
242                else { vec![item] })
243            .collect::<Vec<Vec<isize>>>()
244            .reverse_if(rev)
245            .into_iter()
246            .flatten()
247            .collect::<Vec<isize>>();
248        Self::dot_iterate(&v_arr_1, &v_arr_2)
249            .reshape(&new_shape)
250            .transpose(Some(pairs))
251    }
252
253    fn inner_nd(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
254        fn inner_split<N: NumericOps>(arr: &Array<N>) -> Result<Vec<Array<N>>, ArrayError> {
255            let r_arr = arr.ravel()?;
256            r_arr.split(arr.get_shape()?.remove_at(arr.ndim()? - 1).iter().product(), None)
257        }
258
259        let mut new_shape = vec![];
260        new_shape.extend_from_slice(&arr_1.get_shape()?.remove_at(arr_1.ndim()? - 1));
261        new_shape.extend_from_slice(&arr_2.get_shape()?.remove_at(arr_2.ndim()? - 1));
262
263        let v_arr_1 = inner_split(arr_1)?;
264        let v_arr_2 = inner_split(arr_2)?;
265
266        v_arr_1.iter()
267            .flat_map(|v_a1| v_arr_2.iter()
268                .map(|v_a2| v_a1.inner(v_a2))
269                .collect::<Vec<Result<Array<N>, ArrayError>>>())
270            .collect::<Vec<Result<Array<N>, ArrayError>>>()
271            .has_error()?.into_iter()
272            .flat_map(Result::unwrap)
273            .collect::<Array<N>>()
274            .reshape(&new_shape)
275    }
276
277    fn matmul_iterate(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
278        let (shape_1, shape_2) = (&arr_1.get_shape()?, &arr_2.get_shape()?);
279        (0..shape_1[0])
280            .flat_map(|i| (0..shape_2[1])
281                .map(move |j| (0..shape_1[1])
282                    .fold(0., |acc, k| arr_1[i * shape_1[1] + k].to_f64().mul_add(arr_2[k * shape_2[1] + j].to_f64(), acc))))
283            .map(N::from_f64)
284            .collect::<Array<N>>()
285            .reshape(&[shape_1[0], shape_2[1]])
286    }
287
288    fn matmul_1d_nd(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
289        if arr_1.ndim()? == 1 {
290            if arr_2.ndim()? > 2 {
291                let new_shape = arr_2.get_shape()?.remove_at(0);
292                arr_2.split_axis(0)?.into_iter()
293                    .map(|arr| Self::matmul_1d_nd(arr_1, &arr.reshape(&new_shape).unwrap()))
294                    .collect::<Vec<Result<Array<N>, _>>>()
295                    .has_error()?
296                    .into_iter()
297                    .flat_map(Result::unwrap)
298                    .collect::<Array<N>>()
299                    .reshape(&new_shape)
300            } else {
301                let result = arr_1
302                    .get_elements()?
303                    .into_iter()
304                    .zip(&arr_2.split_axis(0)?)
305                    .map(|(a, b)| b.into_iter()
306                        .map(|item| a.to_f64() * item.to_f64())
307                        .sum::<f64>())
308                    .map(N::from)
309                    .collect::<Array<N>>();
310                Ok(result)
311            }
312        } else if arr_1.ndim()? > 2 {
313            let new_shape = arr_1.get_shape()?.remove_at(0);
314            arr_1.split_axis(0)?
315                .into_iter()
316                .map(|arr| Self::matmul_1d_nd(&arr.reshape(&new_shape).unwrap(), arr_2))
317                .collect::<Vec<Result<Array<N>, _>>>()
318                .has_error()?
319                .into_iter()
320                .flat_map(Result::unwrap)
321                .collect::<Array<N>>()
322                .reshape(&new_shape)
323        } else {
324            let result = arr_1
325                .split_axis(0)?
326                .iter()
327                .map(|arr| (0..arr.shape[arr.shape.len() - 1])
328                    .map(|idx| arr[idx].to_f64() * arr_2[idx].to_f64())
329                    .sum::<f64>())
330                .map(N::from)
331                .collect::<Array<N>>();
332            Ok(result)
333        }
334    }
335
336    fn matmul_nd(arr_1: &Array<N>, arr_2: &Array<N>) -> Result<Array<N>, ArrayError> {
337        fn matmul_split<N: NumericOps>(arr: &Array<N>, len: usize, chunk_len: usize) -> Result<Vec<Array<N>>, ArrayError> {
338            let shape_last = arr.get_shape()?
339                .into_iter()
340                .skip(arr.ndim()? - 2)
341                .take(2)
342                .collect::<Vec<usize>>();
343            let result = arr.split(arr.len()? / chunk_len, Some(0))?
344                .into_iter().cycle().take(len)
345                .map(|arr| arr.reshape(&shape_last).unwrap())
346                .collect::<Vec<Array<N>>>();
347            Ok(result)
348        }
349
350        let mut new_shape =
351            if arr_1.ndim()? >= arr_2.ndim()? { arr_1.get_shape()? }
352            else { arr_2.get_shape()? };
353        let shape_len = new_shape.len();
354        new_shape[shape_len - 2] = arr_1.get_shape()?[arr_1.ndim()? - 2];
355        new_shape[shape_len - 1] = arr_2.get_shape()?[arr_2.ndim()? - 1];
356        let chunk_len = arr_1.get_shape()?[arr_1.ndim()? - 2 ..].iter().product::<usize>();
357        let len = std::cmp::max(arr_1.len()?, arr_2.len()?) / chunk_len;
358        matmul_split(arr_1, len, chunk_len)?
359            .into_iter()
360            .zip(&matmul_split(arr_2, len, chunk_len)?)
361            .map(|(a, b)| a.matmul(b))
362            .collect::<Vec<Result<Array<N>, _>>>()
363            .has_error()?
364            .into_iter()
365            .flat_map(Result::unwrap)
366            .collect::<Array<N>>()
367            .reshape(&new_shape)
368    }
369}
370
371impl <N: NumericOps> ProductsHelper<N> for Array<N> {}