// redstone_ml/linalg/matrix_ops.rs
1use crate::axis::AxisType;
2use crate::einsum::einsum_into_ptr;
3use crate::linalg::sum_of_products::SumOfProductsType;
4use crate::{Axis, IntegerDataType, NumericDataType, RawDataType, NdArray, StridedMemory, Constructors};
5use std::cmp::min;
6
7
8impl<'a, T: MatrixOps> NdArray<'a, T> {
9 /// Calculates the matrix product of two ndarrays.
10 ///
11 /// - If both ndarrays are 1D, then their dot product is returned.
12 /// - If both ndarrays are 2D, then their matrix product is returned.
13 /// - If the first ndarray is 2D and the second ndarray is 1D, then the matrix-vector product is returned.
14 ///
15 /// # Panics
16 /// - If the dimensions/shape of the ndarrays are incompatible
17 ///
18 /// # Example
19 /// ```
20 /// # use redstone_ml::*;
21 ///
22 /// let a = NdArray::new(vec![
23 /// [1, 2, 3],
24 /// [4, 5, 6],
25 /// ]);
26 ///
27 /// let b = NdArray::new(vec![
28 /// [7, 8],
29 /// [9, 10],
30 /// [11, 12],
31 /// ]);
32 ///
33 /// let result = a.matmul(&b);
34 /// assert_eq!(result, NdArray::new(vec![
35 /// [58, 64],
36 /// [139, 154],
37 /// ]));
38 /// ```
39 pub fn matmul<'r>(&self, other: impl AsRef<NdArray<'a, T>>) -> NdArray<'r, T> {
40 let other = other.as_ref();
41
42 if self.ndims() == 1 && other.ndims() == 1 {
43 return self.dot(other);
44 }
45
46 if self.ndims() == 2 && other.ndims() == 1 {
47 assert_eq!(self.shape()[1], other.shape()[0], "mismatched shape for matrix-vector product: {:?} and {:?})", self.shape(), other.shape());
48 return unsafe { <T as MatrixOps>::matrix_vector_product(self, other) };
49 }
50
51 if self.ndims() == 2 && other.ndims() == 2 {
52 assert_eq!(self.shape()[1], other.shape()[0], "mismatched shape for matrix-matrix product: {:?} and {:?})", self.shape(), other.shape());
53
54 let output_shape = [self.shape()[0], other.shape()[1]];
55
56 let result = NdArray::zeros(output_shape);
57 unsafe { <T as MatrixOps>::matrix_matrix_product(self, other, result.stride(), result.mut_ptr()) };
58 return result;
59 }
60
61 panic!("matmul requires a ndarray with 1 or 2 dimensions");
62 }
63
64 /// Performs batch matrix multiplication on 3D ndarrays.
65 ///
66 /// The shape of the resulting array will be `[batch_size, self.shape()[1], other.shape()[2]]`,
67 /// where `batch_size` is the shared first dimension of both input arrays.
68 ///
69 /// # Panics
70 /// - If either array is not 3D
71 /// - If the arrays do not have dimensions compatible for batch matrix multiplication.
72 ///
73 /// # Example
74 /// ```
75 /// # use redstone_ml::*;
76 ///
77 /// let arr1 = NdArray::<f32>::rand([3, 2, 4]); // 3 batches of 2x4 matrices
78 /// let arr2 = NdArray::<f32>::rand([3, 4, 5]); // 3 batches of 4x5 matrices
79 /// let result = arr1.bmm(&arr2);
80 /// assert_eq!(result.shape(), [3, 2, 5]); // result is 3 batches of 2x5 matrices
81 /// ```
82 pub fn bmm<'r>(&self, other: impl AsRef<NdArray<'a, T>>) -> NdArray<'r, T> {
83 let other = other.as_ref();
84 assert_eq!(self.ndims(), 3, "batch matrix multiplication requires 3D ndarrays");
85 assert_eq!(other.ndims(), 3, "batch matrix multiplication requires 3D ndarrays");
86 assert_eq!(self.len(), other.len(), "incompatible batch sizes for batch matrix multiplication: {:?} and {:?})", self.shape(), other.shape());
87
88 let output_shape = [self.len(), self.shape()[1], other.shape()[2]];
89
90 let result = NdArray::zeros(output_shape);
91 unsafe { <T as MatrixOps>::batch_matrix_matrix_product(self, other, result.stride(), result.mut_ptr()); }
92 result
93 }
94}
95
96impl<'a, T: SumOfProductsType> NdArray<'a, T> {
97 /// Calculates the dot product of two 1D arrays.
98 ///
99 /// # Panics
100 /// - Panics if either array is not 1D
101 /// - Panics if the lengths of the two arrays are not equal
102 ///
103 /// # Examples
104 /// ```
105 /// # use redstone_ml::*;
106 /// let arr1 = NdArray::new([1, 2, 3]);
107 /// let arr2 = NdArray::new([4, 5, 6]);
108 /// let result = arr1.dot(arr2);
109 /// assert_eq!(result.value(), 32); // 1*4 + 2*5 + 3*6 = 32
110 /// ```
111 pub fn dot<'b, 'r>(&self, other: impl AsRef<NdArray<'b, T>>) -> NdArray<'r, T> {
112 let other = other.as_ref();
113 assert_eq!(self.ndims(), 1, "dot product requires an array with 1 dimension");
114 assert_eq!(other.ndims(), 1, "dot product requires an array with 1 dimension");
115 assert_eq!(self.len(), other.len(), "dot product requires array with the same length");
116
117 let result = NdArray::scalar(T::default());
118
119 unsafe {
120 <T as SumOfProductsType>::sum_of_products_in_strides_n_n_out_stride_0(&[self.mut_ptr(), other.mut_ptr(), result.mut_ptr()],
121 &[self.stride()[0], other.stride()[0], 0],
122 self.len())
123 };
124
125 result
126 }
127}
128
impl<'a, T: NumericDataType> NdArray<'a, T> {
    /// Returns the trace of the ndarray along its first 2 axes.
    ///
    /// Equivalent to [`offset_trace`](Self::offset_trace) with an offset of 0.
    ///
    /// # Panics
    /// - if the ndarray has fewer than 2 dimensions.
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    /// let arr = NdArray::new([
    ///     [1, 2, 3],
    ///     [4, 5, 6],
    ///     [7, 8, 9]
    /// ]);
    ///
    /// assert_eq!(arr.trace(), NdArray::scalar(1 + 5 + 9));
    /// ```
    pub fn trace<'r>(&self) -> NdArray<'r, T> {
        self.offset_trace(0)
    }

    /// Returns the sum of an offset ndarray diagonal along its first 2 axes.
    ///
    /// Equivalent to [`offset_trace_along`](Self::offset_trace_along) over axes 0 and 1.
    ///
    /// # Panics
    /// - if the ndarray has fewer than 2 dimensions.
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    /// let arr = NdArray::new([
    ///     [1, 2, 3],
    ///     [4, 5, 6],
    ///     [7, 8, 9]
    /// ]);
    ///
    /// assert_eq!(arr.offset_trace(-1), NdArray::scalar(4 + 8));
    /// ```
    pub fn offset_trace<'r>(&self, offset: isize) -> NdArray<'r, T> {
        self.offset_trace_along(offset, 0, 1)
    }

    /// Returns the trace of an ndarray along the specified axes.
    ///
    /// Equivalent to [`offset_trace_along`](Self::offset_trace_along) with an offset of 0.
    ///
    /// # Panics
    /// - if the ndarray has fewer than 2 dimensions.
    /// - if `axis1` and `axis2` are the same or are out-of-bounds
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    /// let ndarray = NdArray::new([
    ///     [1, 2, 3],
    ///     [4, 5, 6],
    ///     [7, 8, 9]
    /// ]);
    ///
    /// assert_eq!(ndarray.trace_along(0, 1), NdArray::scalar(1 + 5 + 9));
    /// ```
    pub fn trace_along<'r>(&self, axis1: impl AxisType, axis2: impl AxisType) -> NdArray<'r, T> {
        self.offset_trace_along(0, axis1, axis2)
    }

    /// Returns the sum of an offset ndarray diagonal along the specified axes.
    ///
    /// # Panics
    /// - if the ndarray has fewer than 2 dimensions.
    /// - if `axis1` and `axis2` are the same or are out-of-bounds
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    /// let ndarray = NdArray::new([
    ///     [1, 2, 3],
    ///     [4, 5, 6],
    ///     [7, 8, 9]
    /// ]);
    ///
    /// assert_eq!(ndarray.offset_trace_along(1, 0, 1), NdArray::scalar(2 + 6));
    /// ```
    pub fn offset_trace_along<'r>(&self, offset: isize, axis1: impl AxisType, axis2: impl AxisType) -> NdArray<'r, T> {
        let diagonal = self.offset_diagonal_along(offset, axis1, axis2);
        // the diagonal view places the diagonal as its last axis, so summing
        // along -1 collapses exactly that axis
        diagonal.sum_along(-1)
    }
}
209
impl<'a, T: RawDataType> NdArray<'a, T> {
    /// Returns a diagonal view of the ndarray along its first 2 axes.
    ///
    /// Equivalent to [`diagonal_along`](Self::diagonal_along) over axes 0 and 1.
    ///
    /// # Panics
    /// - if the ndarray has fewer than 2 dimensions.
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    /// let ndarray = NdArray::new([
    ///     [1, 2, 3],
    ///     [4, 5, 6],
    ///     [7, 8, 9]
    /// ]);
    ///
    /// let diagonal = ndarray.diagonal();
    /// assert_eq!(diagonal, NdArray::new([1, 5, 9]));
    /// ```
    pub fn diagonal(&'a self) -> NdArray<'a, T> {
        self.diagonal_along(0, 1)
    }

    /// Returns an offset diagonal view of the ndarray along its first 2 axes.
    ///
    /// Equivalent to [`offset_diagonal_along`](Self::offset_diagonal_along) over axes 0 and 1.
    ///
    /// # Panics
    /// - if the ndarray has fewer than 2 dimensions.
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    /// let ndarray = NdArray::new([
    ///     [1, 2, 3],
    ///     [4, 5, 6],
    ///     [7, 8, 9]
    /// ]);
    ///
    /// let diagonal = ndarray.offset_diagonal(1);
    /// assert_eq!(diagonal, NdArray::new([2, 6]));
    /// ```
    pub fn offset_diagonal(&'a self, offset: isize) -> NdArray<'a, T> {
        self.offset_diagonal_along(offset, 0, 1)
    }

    /// Returns a diagonal view of the ndarray along the specified axes.
    ///
    /// Equivalent to [`offset_diagonal_along`](Self::offset_diagonal_along) with an offset of 0.
    ///
    /// # Panics
    /// - if the ndarray has fewer than 2 dimensions.
    /// - if `axis1` and `axis2` are the same or are out-of-bounds
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    /// let ndarray = NdArray::new([
    ///     [1, 2, 3],
    ///     [4, 5, 6],
    ///     [7, 8, 9]
    /// ]);
    ///
    /// let diagonal = ndarray.diagonal_along(Axis(0), Axis(1)); // or .diagonal_along(0, 1)
    /// assert_eq!(diagonal, NdArray::new([1, 5, 9]));
    /// ```
    pub fn diagonal_along(&'a self, axis1: impl AxisType, axis2: impl AxisType) -> NdArray<'a, T> {
        self.offset_diagonal_along(0, axis1, axis2)
    }

    /// Returns an offset diagonal view of the ndarray along the specified axes.
    ///
    /// No data is copied: the returned ndarray is a strided view into `self`,
    /// with the diagonal appended as its last axis.
    ///
    /// # Panics
    /// - if the ndarray has fewer than 2 dimensions.
    /// - if `axis1` and `axis2` are the same or are out-of-bounds
    /// - if `offset` is too large for the corresponding axis
    ///
    /// # Examples
    /// ```
    /// # use redstone_ml::*;
    /// let arr = NdArray::new([
    ///     [1, 2, 3],
    ///     [4, 5, 6],
    ///     [7, 8, 9]
    /// ]);
    ///
    /// let diagonal = arr.offset_diagonal_along(-1, Axis(0), Axis(1)); // or .offset_diagonal_along(-1, 0, 1)
    /// assert_eq!(diagonal, NdArray::new([4, 8]));
    /// ```
    pub fn offset_diagonal_along(&'a self, offset: isize, axis1: impl AxisType, axis2: impl AxisType) -> NdArray<'a, T> {
        assert!(self.ndims() >= 2, "diagonals require a ndarray with at least 2 dimensions");

        // normalise (possibly negative) axis specifiers to absolute indices
        let axis1 = axis1.as_absolute(self.ndims());
        let axis2 = axis2.as_absolute(self.ndims());

        assert_ne!(axis1, axis2, "axis1 and axis2 cannot be the same");


        // get the new dimensions and strides of the two axes

        let mut dim1 = self.shape()[axis1];
        let mut dim2 = self.shape()[axis2];

        let stride1 = self.stride()[axis1];
        let stride2 = self.stride()[axis2];


        // modify the dimensions and data pointer based on offset:
        // a positive offset starts the diagonal `offset` elements along axis2,
        // a negative offset starts it `|offset|` elements along axis1

        let ptr_offset = if offset >= 0 {
            let offset = offset as usize;
            if offset >= dim2 {
                panic!("invalid offset {} for axis with dimension {}", offset, dim2);
            }

            dim2 -= offset;
            offset * stride2
        } else {
            let offset = -offset as usize;
            if offset >= dim1 {
                panic!("invalid offset -{} for axis with dimension {}", offset, dim1);
            }

            dim1 -= offset;
            offset * stride1
        };


        // compute the resultant shape and stride: every axis other than
        // axis1/axis2 keeps its original extent and stride

        let mut result_shape = Vec::with_capacity(self.ndims() - 1);
        let mut result_stride = Vec::with_capacity(self.ndims() - 1);

        for axis in 0..self.ndims() {
            if axis == axis1 || axis == axis2 {
                continue;
            }

            result_shape.push(self.shape()[axis]);
            result_stride.push(self.stride()[axis]);
        }

        // the diagonal has min(dim1, dim2) elements; stepping by stride1 + stride2
        // advances one element along both axes at once
        result_shape.push(min(dim1, dim2));
        result_stride.push(stride1 + stride2);

        // create and return the diagonal view
        unsafe { self.reshaped_view_with_offset(ptr_offset, result_shape, result_stride) }
    }
}
349
350
/// Backend trait for matrix-product kernels. The default methods are generic
/// (einsum / sum-of-products based); `f32`/`f64` override them elsewhere in
/// this file with BLAS-accelerated versions when the `blas` cfg is enabled.
pub(crate) trait MatrixOps: SumOfProductsType {
    /// Performs an unchecked batched matrix-matrix product operation
    /// and writes the result to the given pointer
    ///
    /// # Safety
    ///
    /// - The dimensions of `lhs` and `rhs` must be `(b, i, j)` and `(b, j, k)`.
    /// - `result` must point to a valid data buffer with dimension `(b, i, k)`
    /// - `result_stride` must represent a valid layout for the results buffer with
    ///   the last 2 dimensions being contiguous.
    /// - `result` must not overlap with `lhs` or `rhs`.
    unsafe fn batch_matrix_matrix_product<'a>(lhs: &NdArray<'a, Self>,
                                              rhs: &NdArray<'a, Self>,
                                              result_stride: &[usize],
                                              mut result: *mut Self) {
        // views of the first matrix of each batch; advanced in place each iteration
        let mut lhs_slice = lhs.slice_along(Axis(0), 0);
        let mut rhs_slice = rhs.slice_along(Axis(0), 0);

        for _ in 0..lhs.len() {
            // multiply one pair of matrices; result_stride[1..] is the 2D layout
            Self::matrix_matrix_product(&lhs_slice, &rhs_slice, &result_stride[1..], result);

            // advance all three cursors by one batch (stride along axis 0)
            result = result.add(result_stride[0]);
            lhs_slice.offset_ptr(lhs.stride()[0] as isize);
            rhs_slice.offset_ptr(rhs.stride()[0] as isize);
        }
    }

    /// Performs an unchecked matrix-matrix product and writes the result to the given pointer
    ///
    /// # Safety
    ///
    /// - The dimensions of `lhs` and `rhs` must be `(i, j)` and `(j, k)`.
    /// - `result` must point to a valid data buffer with dimension `(i, k)`.
    /// - `result_stride` must represent a contiguous layout for the results buffer.
    /// - `result` must not overlap with `lhs` or `rhs`.
    unsafe fn matrix_matrix_product<'a>(lhs: &NdArray<'a, Self>,
                                        rhs: &NdArray<'a, Self>,
                                        result_stride: &[usize],
                                        result: *mut Self)
    {
        // generic fallback: contract the shared j axis via einsum
        einsum_into_ptr([lhs, rhs], (["ij", "jk"], "ik"), result_stride, result)
    }

    /// Performs an unchecked matrix-vector product and returns the result.
    ///
    /// # Safety
    ///
    /// - The dimensions of `lhs` and `rhs` must be `(i, j)` and `(j)`.
    unsafe fn matrix_vector_product<'a, 'b, 'r>(matrix: &NdArray<'a, Self>,
                                                vector: &NdArray<'b, Self>) -> NdArray<'r, Self> {
        let rows = matrix.shape()[0];
        let cols = matrix.shape()[1];
        // contiguous output buffer, one element per matrix row
        let mut result = vec![Self::default(); rows];

        // input strides for (row element, vector element); output stride 0 makes
        // the kernel accumulate each row's products into a single scalar
        let strides = &[matrix.stride()[1], vector.stride()[0], 0];

        let mut matrix_row = matrix.mut_ptr();
        let mut dst = result.as_mut_ptr();

        for _ in 0..rows {
            // *dst = sum over j of matrix_row[j] * vector[j]
            Self::sum_of_products_in_strides_n_n_out_stride_0(&[matrix_row, vector.mut_ptr(), dst], strides, cols);
            matrix_row = matrix_row.add(matrix.stride()[0]);
            dst = dst.add(1);
        }

        NdArray::from_contiguous_owned_buffer(vec![rows], result)
    }
}
419
// Integer element types have no BLAS routines, so they rely entirely on the
// generic default implementations of `MatrixOps`.
impl<T: IntegerDataType> MatrixOps for T {}
421
impl MatrixOps for f32 {
    /// BLAS-accelerated override using single-precision `cblas_sgemm`.
    /// Only compiled when the `blas` cfg is enabled.
    #[cfg(blas)]
    unsafe fn matrix_matrix_product<'a>(lhs: &NdArray<'a, Self>,
                                        rhs: &NdArray<'a, Self>,
                                        result_stride: &[usize],
                                        result: *mut Self) {
        use crate::acceleration::cblas::{cblas_sgemm, CBLAS_NO_TRANS, CBLAS_ROW_MAJOR};

        // BLAS does not support matrices that don't have contiguous rows,
        // so fall back to the generic einsum path in that case
        if lhs.stride()[1] != 1 || rhs.stride()[1] != 1 {
            return einsum_into_ptr([lhs, rhs], (["ij", "jk"], "ik"), result_stride, result);
        }

        let m = lhs.shape()[0];
        let n = rhs.shape()[1];

        unsafe {
            // C = 1.0 * A * B + 0.0 * C; leading dimensions are the row strides,
            // ldc = n because `result_stride` is contiguous (see trait contract)
            cblas_sgemm(CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS,
                        m as i32, n as i32, lhs.shape()[1] as i32,
                        1.0,
                        lhs.mut_ptr(), lhs.stride()[0] as i32,
                        rhs.mut_ptr(), rhs.stride()[0] as i32,
                        0.0, result, n as i32);
        }
    }

    /// BLAS-accelerated override using single-precision `cblas_sgemv`.
    /// Only used when `blas` is enabled and NEON SIMD is not available.
    #[cfg(all(blas, not(neon_simd)))]
    unsafe fn matrix_vector_product<'a, 'b, 'r>(matrix: &NdArray<'a, Self>,
                                                vector: &NdArray<'b, Self>) -> NdArray<'r, Self> {
        use crate::acceleration::cblas::{cblas_sgemv, CBLAS_NO_TRANS, CBLAS_ROW_MAJOR};
        use crate::einsum;

        // BLAS does not support matrices that don't have contiguous rows,
        // so fall back to the generic einsum path in that case
        if matrix.stride()[1] != 1 {
            return einsum([matrix, vector], (["ij", "j"], "i"));
        }

        let rows = matrix.shape()[0];
        let cols = matrix.shape()[1] as i32;
        // contiguous output buffer, one element per matrix row (incy = 1 below)
        let mut result = vec![Self::default(); rows];

        unsafe {
            // y = 1.0 * A * x + 0.0 * y
            cblas_sgemv(CBLAS_ROW_MAJOR, CBLAS_NO_TRANS,
                        rows as i32, cols, 1.0, matrix.ptr(), matrix.stride()[0] as i32,
                        vector.ptr(), vector.stride()[0] as i32,
                        0.0, result.as_mut_ptr(), 1
            );

            NdArray::from_contiguous_owned_buffer(vec![rows], result)
        }
    }
}
474
impl MatrixOps for f64 {
    /// BLAS-accelerated override using double-precision `cblas_dgemm`.
    /// Only compiled when the `blas` cfg is enabled.
    #[cfg(blas)]
    unsafe fn matrix_matrix_product<'a>(lhs: &NdArray<'a, Self>,
                                        rhs: &NdArray<'a, Self>,
                                        result_stride: &[usize],
                                        result: *mut Self) {
        use crate::acceleration::cblas::{cblas_dgemm, CBLAS_NO_TRANS, CBLAS_ROW_MAJOR};

        // BLAS does not support matrices that don't have contiguous rows,
        // so fall back to the generic einsum path in that case
        if lhs.stride()[1] != 1 || rhs.stride()[1] != 1 {
            return einsum_into_ptr([lhs, rhs], (["ij", "jk"], "ik"), result_stride, result);
        }

        let m = lhs.shape()[0];
        let n = rhs.shape()[1];

        unsafe {
            // C = 1.0 * A * B + 0.0 * C; leading dimensions are the row strides,
            // ldc = n because `result_stride` is contiguous (see trait contract)
            cblas_dgemm(CBLAS_ROW_MAJOR, CBLAS_NO_TRANS, CBLAS_NO_TRANS,
                        m as i32, n as i32, lhs.shape()[1] as i32,
                        1.0,
                        lhs.mut_ptr(), lhs.stride()[0] as i32,
                        rhs.mut_ptr(), rhs.stride()[0] as i32,
                        0.0, result, n as i32);
        }
    }

    /// BLAS-accelerated override using double-precision `cblas_dgemv`.
    /// Only used when `blas` is enabled and NEON SIMD is not available.
    #[cfg(all(blas, not(neon_simd)))]
    unsafe fn matrix_vector_product<'a, 'b, 'r>(matrix: &NdArray<'a, Self>,
                                                vector: &NdArray<'b, Self>) -> NdArray<'r, Self> {
        use crate::acceleration::cblas::{cblas_dgemv, CBLAS_NO_TRANS, CBLAS_ROW_MAJOR};
        use crate::einsum;

        // BLAS does not support matrices that don't have contiguous rows,
        // so fall back to the generic einsum path in that case
        if matrix.stride()[1] != 1 {
            return einsum([matrix, vector], (["ij", "j"], "i"));
        }

        let rows = matrix.shape()[0];
        let cols = matrix.shape()[1] as i32;
        // contiguous output buffer, one element per matrix row (incy = 1 below)
        let mut result = vec![Self::default(); rows];

        unsafe {
            // y = 1.0 * A * x + 0.0 * y
            cblas_dgemv(CBLAS_ROW_MAJOR, CBLAS_NO_TRANS,
                        rows as i32, cols, 1.0, matrix.ptr(), matrix.stride()[0] as i32,
                        vector.ptr(), vector.stride()[0] as i32,
                        0.0, result.as_mut_ptr(), 1
            );

            NdArray::from_contiguous_owned_buffer(vec![rows], result)
        }
    }
}