redstone_ml/
ndarray.rs

1//! # N-dimensional Array for Linear Algebra & Tensor Computations
2//!
3//! An `NdArray` is a fixed-size multidimensional array container defined by its `shape`
4//! and datatype. 1D (vectors) and 2D (matrices) arrays are often of special interest
5//! and can be used in various linear algebra computations
6//! including dot products, matrix products, batch matrix multiplications, einsums, and more.
7//!
8//! `NdArrays` can be iterated over (along configurable dimensions), reshaped, sliced, indexed,
9//! reduced, and more.
10//!
11//! This struct is heavily modeled after NumPy's `ndarray` and supports many of the same methods.
12//!
13//! Example:
14//!
15//! ```rust
16//! use redstone_ml::*;
17//!
18//! let matrix_a = NdArray::new([[1, 3, 2], [-1, 0, -1]]); // shape [2, 3]
19//! let matrix_b = NdArray::randint([3, 7], -5, 3);
20//!
21//! let matrix_view = matrix_b.slice_along(Axis(1), 0..2); // shape [3, 2]
22//! let matrix_c = matrix_a.matmul(matrix_view);
23//!
24//! let result = matrix_c.sum();
25//! ```
26//!
27//! ## NdArray Views & Lifetimes
28//!
29//! There are two ways to create `NdArray` views, by borrowing or by consuming:
30//! ```rust
31//! # use redstone_ml::*;
32//! let data = NdArray::<f64>::rand([9]);
33//! let matrix = (&data).reshape([3, 3]); // by borrowing (data remains alive after)
34//!
35//! let data = NdArray::<f64>::rand([9]);
36//! let matrix = data.reshape([3, 3]); // by consuming data
37//! ```
38//!
39//! The consuming syntax allows us to chain operations without worrying about lifetimes:
40//! ```rust
41//! # use redstone_ml::*;
42//! // a reshaped and transposed random matrix
43//! let matrix = NdArray::<f64>::rand([9]).reshape([3, 3]).T();
44//! ```
45//!
46//! Operations like `reshape`, `view`, `diagonal`, `squeeze`, `unsqueeze`, `T`, `transpose`, and
47//! `ravel` do not create new NdArrays by duplicating memory (which would be slow).
48//! They always return `NdArray` views which share memory with the source `NdArray`.
49//! `NdArray::clone()` or `NdArray::flatten()` can be used to duplicate the underlying `NdArray`.
50//!
51//! This means that all `NdArray` views have a lifetime at most as long as the source `NdArray`.
52//!
53//! ## Linear Algebra, Broadcasting, and Reductions
54//!
55//! We currently support the core linear algebra operations including dot products,
56//! matrix-vector and matrix-matrix multiplications, batched matrix multiplications, and trace.
57//!
58//! ```rust
59//! # use redstone_ml::*;
60//! # let matrix = NdArray::<f64>::randn([3, 3]);
61//! # let matrix1 = NdArray::<f64>::randn([3, 3]);
62//! # let matrix2 = NdArray::<f64>::randn([3, 3]);
63//! # let batch_matrices1 = NdArray::<f64>::randn([2, 3, 3]);
64//! # let batch_matrices2 = NdArray::<f64>::randn([2, 3, 3]);
65//! # let vector = NdArray::<f64>::randn([3]);
66//! # let vector1 = NdArray::<f64>::randn([3]);
67//! # let vector2 = NdArray::<f64>::randn([3]);
68//! vector1.dot(vector2);
69//!
70//! matrix.trace(); // also trace_along/offset_trace
71//! matrix.diagonal(); // also diagonal_along/offset_diagonal
72//! matrix.matmul(vector);
73//! matrix1.matmul(matrix2);
74//!
75//! batch_matrices1.bmm(batch_matrices2);
76//! ```
77//!
78//! We can also perform various reductions including `sum`, `product`, `min`, `max`,
79//! `min_magnitude`, and `max_magnitude`. Each of these is accelerated with various libraries
80//! including vDSP, Arm64 NEON SIMD, and BLAS.
81//!
82//! ```rust
83//! # use redstone_ml::*;
84//! # let ndarray = NdArray::<f64>::zeros([5, 5, 5]);
85//! let sum = ndarray.sum();
86//! let sum_along = ndarray.sum_along([0, -1]); // sum along first and last axes
87//! ```
88//!
89//! `NdArrays` can be used in arithmetic operations using the usual binary operators including
90//! addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`),
91//! and bitwise operations (`&`, `|`, `<<`, `>>`).
92//!
93//! ```rust
94//! # use redstone_ml::*;
95//! # let arr1 = NdArray::<f64>::zeros([2, 2, 2]);
96//! # let arr2 = NdArray::<f64>::zeros([2, 2, 2]);
97//! let result = &arr1 + &arr2; // non-consuming
98//! let result = &arr1 + arr2;  // consumes RHS
99//! # let arr2 = NdArray::<f64>::zeros([2, 2, 2]);
100//! let result = arr1 + arr2;   // consumes both
101//! ```
102//!
103//! `NdArrays` are automatically broadcast using the exact same rules as NumPy
104//! to perform efficient computations with different-dimensional (yet compatible) data.
105//!
106//! ## Slicing, Indexing, and Iterating
107//!
108//! Slicing and indexing an `NdArray` always return a view. This is how we can access various
109//! elements of vectors, columns/rows of matrices, and more.
110//!
111//! ```rust
112//! # use redstone_ml::*;
113//! let arr = NdArray::<f32>::rand([2, 4, 3, 5]); // 4D NdArray
114//! let slice1 = arr.slice(s![.., 0, ..=2]);      // use s! to specify a slice
115//! let slice2 = arr.slice_along(Axis(-2), 0);    // 0th element along second-to-last axis
116//! let el = arr[[0, 3, 2, 4]];
117//! ```
118//!
119//! One can also iterate over an `NdArray` in various ways:
120//! ```rust
121//! # use redstone_ml::*;
122//! # let arr = NdArray::<f32>::rand([2, 4, 3, 5]); // 4D NdArray
123//! for subarray in arr.iter() { /* 4x3x5 subarrays */ }
124//! for subarray in arr.iter_along(Axis(2)) { /* 2x4x5 subarrays */ }
125//! for el in arr.flatiter() { /* element-wise iteration */ }
126//! ```
127//!
128//! ## Other Constructors
129//!
130//! ```rust
131//! # use redstone_ml::*;
132//! let ndarray = NdArray::arange(0i32, 5); // [0, 1, 2, 3, 4]
133//! let ndarray = NdArray::linspace(0f32, 1.0, 5); // [0.0, 0.25, 0.5, 0.75, 1.0]
134//! ```
135//!
136//! ```rust
137//! # use redstone_ml::*;
138//! let ndarray = NdArray::full(5.0, [5, 4, 2]);
139//! let falses = NdArray::<bool>::zeros([5, 4, 2]);
140//! ```
141//!
142//! A scalar `NdArray` is dimensionless and contains a single value.
143//! It is often the return value for reduction methods like `sum`, `product`, `min`, and `max`.
144//! ```rust
145//! # use redstone_ml::*;
146//! let ten = NdArray::scalar(10u8);
147//! ```
148//!
149//! In many cases, one desires randomized multidimensional arrays with a specified shape.
150//! ```rust
151//! # use redstone_ml::*;
152//! let rand = NdArray::<f32>::randn([3, 4]);
153//! let rand = NdArray::<f32>::rand([3, 4]);
154//! let rand = NdArray::randint([3, 4], -5, 3);
155//! ```
156
157
158use std::marker::PhantomData;
159use std::ptr::NonNull;
160
161pub mod methods;
162
163pub mod iterator;
164pub use iterator::*;
165
166pub mod reshape;
167
168pub(crate) mod flags;
169use flags::NdArrayFlags;
170
171pub mod reduce;
172
173pub mod constructors;
174pub mod index_impl;
175pub mod slice;
176pub mod fill;
177pub mod clone;
178pub mod equals;
179pub mod broadcast;
180pub mod binary_ops;
181pub mod astype;
182
183mod print;
184mod unary_ops;
185mod assign_ops;
186
// NOTE(review): presumably the upper bound on the number of dimensions an `NdArray`
// may have — confirm against usages elsewhere in the crate.
187pub(crate) const MAX_DIMS: usize = 32;
// NOTE(review): presumably the upper bound on the number of operands accepted by
// multi-argument operations — confirm against usages elsewhere in the crate.
188pub(crate) const MAX_ARGS: usize = 16;
189
190use crate::dtype::RawDataType;
191
/// A multidimensional array container described by its `shape` and element type `T`.
///
/// The same struct represents both source arrays and views: operations such as
/// `reshape`, `transpose`, and slicing return `NdArray` views that share memory
/// with the array they were created from (see the module-level documentation).
///
/// The `'a` lifetime parameter is carried by `_marker`, so the compiler treats an
/// `NdArray<'a, T>` as if it held a `&'a T`.
///
/// NOTE(review): the per-field notes marked "presumably" are inferred from names and
/// types only; confirm them against the constructors and the `Drop` implementation.
192pub struct NdArray<'a, T: RawDataType> {
    /// Non-null raw pointer to elements of type `T`.
    /// NOTE(review): presumably points at the first element of the data buffer — confirm.
193    pub(crate) ptr: NonNull<T>,
    /// NOTE(review): presumably the number of elements reachable through `ptr` — confirm.
194    len: usize,
    /// NOTE(review): presumably the allocated capacity of the buffer (as in `Vec`);
    /// how it is set for views is not visible here — confirm.
195    capacity: usize,
196

    /// Size of the array along each dimension.
197    shape: Vec<usize>,
    /// Per-dimension stride.
    /// NOTE(review): presumably measured in elements rather than bytes — confirm.
198    stride: Vec<usize>,
    /// Flags describing this array, defined in the `flags` module.
199    pub(crate) flags: NdArrayFlags,
200

    /// Zero-sized marker that ties the struct to the lifetime `'a` without storing a reference.
201    _marker: PhantomData<&'a T>,
202}