redstone_ml/linalg/
matrix_ops.rs

1use crate::axis::AxisType;
2use crate::einsum::einsum_into_ptr;
3use crate::linalg::sum_of_products::SumOfProductsType;
4use crate::{Axis, IntegerDataType, NumericDataType, RawDataType, NdArray, StridedMemory, Constructors};
5use std::cmp::min;
6
7
8impl<'a, T: MatrixOps> NdArray<'a, T> {
9    /// Calculates the matrix product of two ndarrays.
10    ///
11    /// - If both ndarrays are 1D, then their dot product is returned.
12    /// - If both ndarrays are 2D, then their matrix product is returned.
13    /// - If the first ndarray is 2D and the second ndarray is 1D, then the matrix-vector product is returned.
14    ///
15    /// # Panics
16    /// - If the dimensions/shape of the ndarrays are incompatible
17    ///
18    /// # Example
19    /// ```
20    /// # use redstone_ml::*;
21    ///
22    /// let a = NdArray::new(vec![
23    ///     [1, 2, 3],
24    ///     [4, 5, 6],
25    /// ]);
26    ///
27    /// let b = NdArray::new(vec![
28    ///     [7, 8],
29    ///     [9, 10],
30    ///     [11, 12],
31    /// ]);
32    ///
33    /// let result = a.matmul(&b);
34    /// assert_eq!(result, NdArray::new(vec![
35    ///     [58, 64],
36    ///     [139, 154],
37    /// ]));
38    /// ```
39    pub fn matmul<'r>(&self, other: impl AsRef<NdArray<'a, T>>) -> NdArray<'r, T> {
40        let other = other.as_ref();
41
42        if self.ndims() == 1 && other.ndims() == 1 {
43            return self.dot(other);
44        }
45
46        if self.ndims() == 2 && other.ndims() == 1 {
47            assert_eq!(self.shape()[1], other.shape()[0], "mismatched shape for matrix-vector product: {:?} and {:?})", self.shape(), other.shape());
48            return unsafe { <T as MatrixOps>::matrix_vector_product(self, other) };
49        }
50
51        if self.ndims() == 2 && other.ndims() == 2 {
52            assert_eq!(self.shape()[1], other.shape()[0], "mismatched shape for matrix-matrix product: {:?} and {:?})", self.shape(), other.shape());
53
54            let output_shape = [self.shape()[0], other.shape()[1]];
55
56            let result = NdArray::zeros(output_shape);
57            unsafe { <T as MatrixOps>::matrix_matrix_product(self, other, result.stride(), result.mut_ptr()) };
58            return result;
59        }
60
61        panic!("matmul requires a ndarray with 1 or 2 dimensions");
62    }
63
64    /// Performs batch matrix multiplication on 3D ndarrays.
65    ///
66    /// The shape of the resulting array will be `[batch_size, self.shape()[1], other.shape()[2]]`,
67    /// where `batch_size` is the shared first dimension of both input arrays.
68    ///
69    /// # Panics
70    /// - If either array is not 3D
71    /// - If the arrays do not have dimensions compatible for batch matrix multiplication.
72    ///
73    /// # Example
74    /// ```
75    /// # use redstone_ml::*;
76    ///
77    /// let arr1 = NdArray::<f32>::rand([3, 2, 4]); // 3 batches of 2x4 matrices
78    /// let arr2 = NdArray::<f32>::rand([3, 4, 5]); // 3 batches of 4x5 matrices
79    /// let result = arr1.bmm(&arr2);
80    /// assert_eq!(result.shape(), [3, 2, 5]); // result is 3 batches of 2x5 matrices
81    /// ```
82    pub fn bmm<'r>(&self, other: impl AsRef<NdArray<'a, T>>) -> NdArray<'r, T> {
83        let other = other.as_ref();
84        assert_eq!(self.ndims(), 3, "batch matrix multiplication requires 3D ndarrays");
85        assert_eq!(other.ndims(), 3, "batch matrix multiplication requires 3D ndarrays");
86        assert_eq!(self.len(), other.len(), "incompatible batch sizes for batch matrix multiplication: {:?} and {:?})", self.shape(), other.shape());
87
88        let output_shape = [self.len(), self.shape()[1], other.shape()[2]];
89
90        let result = NdArray::zeros(output_shape);
91        unsafe { <T as MatrixOps>::batch_matrix_matrix_product(self, other, result.stride(), result.mut_ptr()); }
92        result
93    }
94}
95
96impl<'a, T: SumOfProductsType> NdArray<'a, T> {
97    /// Calculates the dot product of two 1D arrays.
98    ///
99    /// # Panics
100    /// - Panics if either array is not 1D
101    /// - Panics if the lengths of the two arrays are not equal
102    ///
103    /// # Examples
104    /// ```
105    /// # use redstone_ml::*;
106    /// let arr1 = NdArray::new([1, 2, 3]);
107    /// let arr2 = NdArray::new([4, 5, 6]);
108    /// let result = arr1.dot(arr2);
109    /// assert_eq!(result.value(), 32); // 1*4 + 2*5 + 3*6 = 32
110    /// ```
111    pub fn dot<'b, 'r>(&self, other: impl AsRef<NdArray<'b, T>>) -> NdArray<'r, T> {
112        let other = other.as_ref();
113        assert_eq!(self.ndims(), 1, "dot product requires an array with 1 dimension");
114        assert_eq!(other.ndims(), 1, "dot product requires an array with 1 dimension");
115        assert_eq!(self.len(), other.len(), "dot product requires array with the same length");
116
117        let result = NdArray::scalar(T::default());
118
119        unsafe {
120            <T as SumOfProductsType>::sum_of_products_in_strides_n_n_out_stride_0(&[self.mut_ptr(), other.mut_ptr(), result.mut_ptr()],
121                                                                                  &[self.stride()[0], other.stride()[0], 0],
122                                                                                  self.len())
123        };
124
125        result
126    }
127}
128
129impl<'a, T: NumericDataType> NdArray<'a, T> {
130    /// Returns the trace of the ndarray along its first 2 axes.
131    ///
132    /// # Panics
133    /// - if the ndarray has fewer than 2 dimensions.
134    ///
135    /// # Examples
136    /// ```
137    /// # use redstone_ml::*;
138    /// let arr = NdArray::new([
139    ///     [1, 2, 3],
140    ///     [4, 5, 6],
141    ///     [7, 8, 9]
142    /// ]);
143    ///
144    /// assert_eq!(arr.trace(), NdArray::scalar(1 + 5 + 9));
145    pub fn trace<'r>(&self) -> NdArray<'r, T> {
146        self.offset_trace(0)
147    }
148
149    /// Returns the sum of an offset ndarray diagonal along its first 2 axes.
150    ///
151    /// # Panics
152    /// - if the ndarray has fewer than 2 dimensions.
153    ///
154    /// # Examples
155    /// ```
156    /// # use redstone_ml::*;
157    /// let arr = NdArray::new([
158    ///     [1, 2, 3],
159    ///     [4, 5, 6],
160    ///     [7, 8, 9]
161    /// ]);
162    ///
163    /// assert_eq!(arr.offset_trace(-1), NdArray::scalar(4 + 8));
164    pub fn offset_trace<'r>(&self, offset: isize) -> NdArray<'r, T> {
165        self.offset_trace_along(offset, 0, 1)
166    }
167
168    /// Returns the trace of an ndarray along the specified axes.
169    ///
170    /// # Panics
171    /// - if the ndarray has fewer than 2 dimensions.
172    /// - if `axis1` and `axis2` are the same or are out-of-bounds
173    ///
174    /// # Examples
175    /// ```
176    /// # use redstone_ml::*;
177    /// let ndarray = NdArray::new([
178    ///     [1, 2, 3],
179    ///     [4, 5, 6],
180    ///     [7, 8, 9]
181    /// ]);
182    ///
183    /// assert_eq!(ndarray.trace_along(0, 1), NdArray::scalar(1 + 5 + 9));
184    pub fn trace_along<'r>(&self, axis1: impl AxisType, axis2: impl AxisType) -> NdArray<'r, T> {
185        self.offset_trace_along(0, axis1, axis2)
186    }
187
188    /// Returns the sum of an offset ndarray diagonal along the specified axes.
189    ///
190    /// # Panics
191    /// - if the ndarray has fewer than 2 dimensions.
192    /// - if `axis1` and `axis2` are the same or are out-of-bounds
193    ///
194    /// # Examples
195    /// ```
196    /// # use redstone_ml::*;
197    /// let ndarray = NdArray::new([
198    ///     [1, 2, 3],
199    ///     [4, 5, 6],
200    ///     [7, 8, 9]
201    /// ]);
202    ///
203    /// assert_eq!(ndarray.offset_trace_along(1, 0, 1), NdArray::scalar(2 + 6));
204    pub fn offset_trace_along<'r>(&self, offset: isize, axis1: impl AxisType, axis2: impl AxisType) -> NdArray<'r, T> {
205        let diagonal = self.offset_diagonal_along(offset, axis1, axis2);
206        diagonal.sum_along(-1)
207    }
208}
209
210impl<'a, T: RawDataType> NdArray<'a, T> {
211    /// Returns a diagonal view of the ndarray along its first 2 axes.
212    ///
213    /// # Panics
214    /// - if the ndarray has fewer than 2 dimensions.
215    ///
216    /// # Examples
217    /// ```
218    /// # use redstone_ml::*;
219    /// let ndarray = NdArray::new([
220    ///     [1, 2, 3],
221    ///     [4, 5, 6],
222    ///     [7, 8, 9]
223    /// ]);
224    ///
225    /// let diagonal = ndarray.diagonal();
226    /// assert_eq!(diagonal, NdArray::new([1, 5, 9]));
227    pub fn diagonal(&'a self) -> NdArray<'a, T> {
228        self.diagonal_along(0, 1)
229    }
230
231    /// Returns an offset diagonal view of the ndarray along its first 2 axes.
232    ///
233    /// # Panics
234    /// - if the ndarray has fewer than 2 dimensions.
235    ///
236    /// # Examples
237    /// ```
238    /// # use redstone_ml::*;
239    /// let ndarray = NdArray::new([
240    ///     [1, 2, 3],
241    ///     [4, 5, 6],
242    ///     [7, 8, 9]
243    /// ]);
244    ///
245    /// let diagonal = ndarray.offset_diagonal(1);
246    /// assert_eq!(diagonal, NdArray::new([2, 6]));
247    pub fn offset_diagonal(&'a self, offset: isize) -> NdArray<'a, T> {
248        self.offset_diagonal_along(offset, 0, 1)
249    }
250
251    /// Returns a diagonal view of the ndarray along the specified axes.
252    ///
253    /// # Panics
254    /// - if the ndarray has fewer than 2 dimensions.
255    /// - if `axis1` and `axis2` are the same or are out-of-bounds
256    ///
257    /// # Examples
258    /// ```
259    /// # use redstone_ml::*;
260    /// let ndarray = NdArray::new([
261    ///     [1, 2, 3],
262    ///     [4, 5, 6],
263    ///     [7, 8, 9]
264    /// ]);
265    ///
266    /// let diagonal = ndarray.diagonal_along(Axis(0), Axis(1));  // or .diagonal_along(0, 1)
267    /// assert_eq!(diagonal, NdArray::new([1, 5, 9]));
268    pub fn diagonal_along(&'a self, axis1: impl AxisType, axis2: impl AxisType) -> NdArray<'a, T> {
269        self.offset_diagonal_along(0, axis1, axis2)
270    }
271
272    /// Returns an offset diagonal view of the ndarray along the specified axes.
273    ///
274    /// # Panics
275    /// - if the ndarray has fewer than 2 dimensions.
276    /// - if `axis1` and `axis2` are the same or are out-of-bounds
277    ///
278    /// # Examples
279    /// ```
280    /// # use redstone_ml::*;
281    /// let arr = NdArray::new([
282    ///     [1, 2, 3],
283    ///     [4, 5, 6],
284    ///     [7, 8, 9]
285    /// ]);
286    ///
287    /// let diagonal = arr.offset_diagonal_along(-1, Axis(0), Axis(1));  // or .offset_diagonal_along(-1, 0, 1)
288    /// assert_eq!(diagonal, NdArray::new([4, 8]));
289    pub fn offset_diagonal_along(&'a self, offset: isize, axis1: impl AxisType, axis2: impl AxisType) -> NdArray<'a, T> {
290        assert!(self.ndims() >= 2, "diagonals require a ndarray with at least 2 dimensions");
291
292        let axis1 = axis1.as_absolute(self.ndims());
293        let axis2 = axis2.as_absolute(self.ndims());
294
295        assert_ne!(axis1, axis2, "axis1 and axis2 cannot be the same");
296
297
298        // get the new dimensions and strides of the two axes
299
300        let mut dim1 = self.shape()[axis1];
301        let mut dim2 = self.shape()[axis2];
302
303        let stride1 = self.stride()[axis1];
304        let stride2 = self.stride()[axis2];
305
306
307        // modify the dimensions and data pointer based on offset
308
309        let ptr_offset = if offset >= 0 {
310            let offset = offset as usize;
311            if offset >= dim2 {
312                panic!("invalid offset {} for axis with dimension {}", offset, dim2);
313            }
314
315            dim2 -= offset;
316            offset * stride2
317        } else {
318            let offset = -offset as usize;
319            if offset >= dim1 {
320                panic!("invalid offset -{} for axis with dimension {}", offset, dim1);
321            }
322
323            dim1 -= offset;
324            offset * stride1
325        };
326
327
328        // compute the resultant shape and stride
329
330        let mut result_shape = Vec::with_capacity(self.ndims() - 1);
331        let mut result_stride = Vec::with_capacity(self.ndims() - 1);
332
333        for axis in 0..self.ndims() {
334            if axis == axis1 || axis == axis2 {
335                continue;
336            }
337
338            result_shape.push(self.shape()[axis]);
339            result_stride.push(self.stride()[axis]);
340        }
341
342        result_shape.push(min(dim1, dim2));
343        result_stride.push(stride1 + stride2);
344
345        // create and return the diagonal view
346        unsafe { self.reshaped_view_with_offset(ptr_offset, result_shape, result_stride) }
347    }
348}
349
350
351pub(crate) trait MatrixOps: SumOfProductsType {
352    /// Performs an unchecked batched matrix-matrix product operation
353    /// and writes the result to the given pointer
354    ///
355    /// # Safety
356    ///
357    /// - The dimensions of `lhs` and `rhs` must be `(b, i, j)` and `(b, j, k)`.
358    /// - `result` must point to a valid data buffer with dimension `(b, i, k)`
359    /// - `result_stride` must represent a valid layout for the results buffer with 
360    ///   the last 2 dimensions being contiguous.
361    /// - `result` must not overlap with `lhs` or `rhs`.
362    unsafe fn batch_matrix_matrix_product<'a>(lhs: &NdArray<'a, Self>,
363                                              rhs: &NdArray<'a, Self>,
364                                              result_stride: &[usize],
365                                              mut result: *mut Self) {
366        let mut lhs_slice = lhs.slice_along(Axis(0), 0);
367        let mut rhs_slice = rhs.slice_along(Axis(0), 0);
368
369        for _ in 0..lhs.len() {
370            Self::matrix_matrix_product(&lhs_slice, &rhs_slice, &result_stride[1..], result);
371
372            result = result.add(result_stride[0]);
373            lhs_slice.offset_ptr(lhs.stride()[0] as isize);
374            rhs_slice.offset_ptr(rhs.stride()[0] as isize);
375        }
376    }
377
378    /// Performs an unchecked matrix-matrix product and writes the result to the given pointer
379    ///
380    /// # Safety
381    ///
382    /// - The dimensions of `lhs` and `rhs` must be `(i, j)` and `(j, k)`.
383    /// - `result` must point to a valid data buffer with dimension `(i, k)`.
384    /// - `result_stride` must represent a contiguous layout for the results buffer.
385    /// - `result` must not overlap with `lhs` or `rhs`.
386    unsafe fn matrix_matrix_product<'a>(lhs: &NdArray<'a, Self>,
387                                        rhs: &NdArray<'a, Self>,
388                                        result_stride: &[usize],
389                                        result: *mut Self)
390    {
391        einsum_into_ptr([lhs, rhs], (["ij", "jk"], "ik"), result_stride, result)
392    }
393
394    /// Performs an unchecked matrix-vector product and returns the result.
395    ///
396    /// # Safety
397    ///
398    /// - The dimensions of `lhs` and `rhs` must be `(i, j)` and `(j)`.
399    unsafe fn matrix_vector_product<'a, 'b, 'r>(matrix: &NdArray<'a, Self>,
400                                                vector: &NdArray<'b, Self>) -> NdArray<'r, Self> {
401        let rows = matrix.shape()[0];
402        let cols = matrix.shape()[1];
403        let mut result = vec![Self::default(); rows];
404
405        let strides = &[matrix.stride()[1], vector.stride()[0], 0];
406
407        let mut matrix_row = matrix.mut_ptr();
408        let mut dst = result.as_mut_ptr();
409
410        for _ in 0..rows {
411            Self::sum_of_products_in_strides_n_n_out_stride_0(&[matrix_row, vector.mut_ptr(), dst], strides, cols);
412            matrix_row = matrix_row.add(matrix.stride()[0]);
413            dst = dst.add(1);
414        }
415
416        NdArray::from_contiguous_owned_buffer(vec![rows], result)
417    }
418}
419
420impl<T: IntegerDataType> MatrixOps for T {}
421
422impl MatrixOps for f32 {
423    #[cfg(blas)]
424    unsafe fn matrix_matrix_product<'a>(lhs: &NdArray<'a, Self>,
425                                        rhs: &NdArray<'a, Self>,
426                                        result_stride: &[usize],
427                                        result: *mut Self) {
428        use crate::acceleration::cblas::{cblas_sgemm, CBLAS_NO_TRANS, CBLAS_ROW_MAJOR};
429
430        // BLAS does not support matrices that don't have contiguous rows
431        if lhs.stride()[1] != 1 || rhs.stride()[1] != 1 {
432            return einsum_into_ptr([lhs, rhs], (["ij", "jk"], "ik"), result_stride, result);
433        }
434
435        let m = lhs.shape()[0];
436        let n = rhs.shape()[1];
437
438        unsafe {
439            cblas_sgemm(CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS,
440                        m as i32, n as i32, lhs.shape()[1] as i32,
441                        1.0,
442                        lhs.mut_ptr(), lhs.stride()[0] as i32,
443                        rhs.mut_ptr(), rhs.stride()[0] as i32,
444                        0.0, result, n as i32);
445        }
446    }
447
448    #[cfg(all(blas, not(neon_simd)))]
449    unsafe fn matrix_vector_product<'a, 'b, 'r>(matrix: &NdArray<'a, Self>,
450                                                vector: &NdArray<'b, Self>) -> NdArray<'r, Self> {
451        use crate::acceleration::cblas::{cblas_sgemv, CBLAS_NO_TRANS, CBLAS_ROW_MAJOR};
452        use crate::einsum;
453
454        // BLAS does not support matrices that don't have contiguous rows
455        if matrix.stride()[1] != 1 {
456            return einsum([matrix, vector], (["ij", "j"], "i"));
457        }
458
459        let rows = matrix.shape()[0];
460        let cols = matrix.shape()[1] as i32;
461        let mut result = vec![Self::default(); rows];
462
463        unsafe {
464            cblas_sgemv(CBLAS_ROW_MAJOR, CBLAS_NO_TRANS,
465                        rows as i32, cols, 1.0, matrix.ptr(), matrix.stride()[0] as i32,
466                        vector.ptr(), vector.stride()[0] as i32,
467                        0.0, result.as_mut_ptr(), 1
468            );
469
470            NdArray::from_contiguous_owned_buffer(vec![rows], result)
471        }
472    }
473}
474
475impl MatrixOps for f64 {
476    #[cfg(blas)]
477    unsafe fn matrix_matrix_product<'a>(lhs: &NdArray<'a, Self>,
478                                        rhs: &NdArray<'a, Self>,
479                                        result_stride: &[usize],
480                                        result: *mut Self) {
481        use crate::acceleration::cblas::{cblas_dgemm, CBLAS_NO_TRANS, CBLAS_ROW_MAJOR};
482
483        // BLAS does not support matrices that don't have contiguous rows
484        if lhs.stride()[1] != 1 || rhs.stride()[1] != 1 {
485            return einsum_into_ptr([lhs, rhs], (["ij", "jk"], "ik"), result_stride, result);
486        }
487
488        let m = lhs.shape()[0];
489        let n = rhs.shape()[1];
490
491        unsafe {
492            cblas_dgemm(CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS,
493                        m as i32, n as i32, lhs.shape()[1] as i32,
494                        1.0,
495                        lhs.mut_ptr(), lhs.stride()[0] as i32,
496                        rhs.mut_ptr(), rhs.stride()[0] as i32,
497                        0.0, result, n as i32);
498        }
499    }
500
501    #[cfg(all(blas, not(neon_simd)))]
502    unsafe fn matrix_vector_product<'a, 'b, 'r>(matrix: &NdArray<'a, Self>,
503                                                vector: &NdArray<'b, Self>) -> NdArray<'r, Self> {
504        use crate::acceleration::cblas::{cblas_dgemv, CBLAS_NO_TRANS, CBLAS_ROW_MAJOR};
505        use crate::einsum;
506
507        // BLAS does not support matrices that don't have contiguous rows
508        if matrix.stride()[1] != 1 {
509            return einsum([matrix, vector], (["ij", "j"], "i"));
510        }
511
512        let rows = matrix.shape()[0];
513        let cols = matrix.shape()[1] as i32;
514        let mut result = vec![Self::default(); rows];
515
516        unsafe {
517            cblas_dgemv(CBLAS_ROW_MAJOR, CBLAS_NO_TRANS,
518                        rows as i32, cols, 1.0, matrix.ptr(), matrix.stride()[0] as i32,
519                        vector.ptr(), vector.stride()[0] as i32,
520                        0.0, result.as_mut_ptr(), 1
521            );
522
523            NdArray::from_contiguous_owned_buffer(vec![rows], result)
524        }
525    }
526}