redstone_ml/common/
constructors.rs

1use crate::util::flatten::Flatten;
2use crate::util::nested::Nested;
3use crate::util::shape::Shape;
4use crate::util::to_vec::ToVec;
5use crate::{FloatDataType, NumericDataType, RawDataType, StridedMemory};
6use num::NumCast;
7
8pub trait Constructors<T: RawDataType>: StridedMemory {
9    /// Constructs a new ndarray from the given data buffer and shape assuming a contiguous layout
10    ///
11    /// # Parameters
12    /// - `shape`: A vector that defines the dimensions of the ndarray.
13    /// - `data`: The underlying buffer that holds the ndarray's elements.
14    /// - `requires_grad`: If gradients need to be computed for this ndarray.
15    ///
16    /// # Safety
17    /// - `data` must remain valid and not be used elsewhere after being passed to this function.
18    /// - `shape.iter().product()` must equal `data.len()`
19    unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self;
20
21    /// Constructs an n-dimensional `NdArray` from input data such as a vector or array.
22    ///
23    /// # Parameters
24    /// - `data`: a nested array or vector of valid data types (floats, integers, bools)
25    ///
26    /// # Panics
27    ///   - If the input data has inhomogeneous dimensions, i.e., nested arrays do not have consistent sizes.
28    ///   - If the input data is empty (cannot create a zero-length ndarray)
29    ///
30    /// # Example
31    /// ```
32    /// # use redstone_ml::*;
33    ///
34    /// let ndarray : NdArray<i32> = NdArray::new([[1, 2], [3, 4]]);
35    /// assert_eq!(ndarray.shape(), &[2, 2]);
36    ///
37    /// let ndarray = NdArray::new(vec![1f32, 2.0, 3.0, 4.0, 5.0]);
38    /// assert_eq!(ndarray.ndims(), 1);
39    /// ```
40    fn new<const D: usize>(data: impl Flatten<T> + Shape + Nested<{ D }>) -> Self {
41        assert!(data.check_homogenous(), "from() failed, found inhomogeneous dimensions");
42
43        let shape = data.shape();
44        let data = data.flatten();
45
46        assert!(!data.is_empty(), "from() failed, cannot create data buffer from empty data");
47
48        unsafe { Self::from_contiguous_owned_buffer(shape, data) }
49    }
50
51    /// Creates an ndarray filled with a specified value and given shape.
52    ///
53    /// # Parameters
54    ///
55    /// * `n` - The value to fill the ndarray with (can be any valid data type like float, integer, or bool).
56    /// * `shape` - An array or vector representing the shape of the ndarray (e.g. `[2, 3, 5]`).
57    ///
58    /// # Panics
59    /// This function panics if the provided shape is empty.
60    ///
61    /// # Examples
62    ///
63    /// ```
64    /// # use redstone_ml::*;
65    ///
66    /// let ndarray = NdArray::full(5i32, [2, 3]); // creates a 2x3 ndarray filled with the value 5.
67    /// let ndarray = NdArray::full(true, [2, 3, 5]); // creates a 2x3x5 ndarray filled with 'true'
68    /// ```
69    fn full(n: T, shape: impl ToVec<usize>) -> Self {
70        let shape = shape.to_vec();
71
72        let data = vec![n; shape.iter().product()];
73        assert!(!data.is_empty(), "cannot create an empty ndarray!");
74
75        unsafe { Self::from_contiguous_owned_buffer(shape, data) }
76    }
77
78    /// Creates a new ndarray filled with zeros with the given shape.
79    ///
80    /// # Parameters
81    /// - `shape`: An array or vector representing the shape of the ndarray (e.g. `[2, 3, 5]`).
82    ///
83    /// # Panics
84    /// This function panics if the provided shape is empty.
85    ///
86    /// # Examples
87    /// ```
88    /// # use redstone_ml::*;
89    ///
90    /// let ndarray = NdArray::<i32>::zeros([2, 3]);
91    /// let ndarray = NdArray::<bool>::zeros([2, 3]);  // creates an ndarray filled with 'false'
92    /// ```
93    fn zeros(shape: impl ToVec<usize>) -> Self
94    where
95        T: From<bool>
96    {
97        Self::full(false.into(), shape)
98    }
99
100    /// Creates a new ndarray filled with ones with the given shape.
101    ///
102    /// # Parameters
103    /// - `shape`: An array or vector representing the shape of the ndarray (e.g. `[2, 3, 5]`).
104    ///
105    /// # Panics
106    /// This function panics if the provided shape is empty.
107    ///
108    /// # Examples
109    /// ```
110    /// # use redstone_ml::*;
111    ///
112    /// let ndarray = NdArray::<i32>::ones([2, 3]);
113    /// let ndarray = NdArray::<bool>::ones([2, 3]);  // creates an ndarray filled with 'true'
114    /// ```
115    fn ones(shape: impl ToVec<usize>) -> Self
116    where
117        T: From<bool>
118    {
119        Self::full(true.into(), shape)
120    }
121
122    /// Creates a 0-dimensional (shapeless) ndarray containing a single value.
123    ///
124    /// # Parameters
125    /// - `n`: The value to be stored in the scalar ndarray.
126    ///
127    /// # Example
128    /// ```
129    /// # use redstone_ml::*;
130    ///
131    /// let scalar_array = NdArray::scalar(42);
132    /// assert_eq!(scalar_array.shape(), []);
133    /// assert_eq!(scalar_array.value(), 42);
134    /// ```
135    fn scalar(n: T) -> Self {
136        Self::full(n, [])
137    }
138
139    /// Generates a 1D ndarray with evenly spaced values within a specified range.
140    ///
141    /// # Arguments
142    ///
143    /// * `start` - The starting value of the sequence, inclusive.
144    /// * `stop` - The ending value of the sequence, exclusive.
145    ///
146    /// # Returns
147    ///
148    /// An `NdArray` containing values starting from `start` and ending before `stop`,
149    /// with a step-size of 1.
150    ///
151    /// # Examples
152    ///
153    /// ```
154    /// # use redstone_ml::*;
155    /// let ndarray = NdArray::arange(0i32, 5); // [0, 1, 2, 3, 4].
156    /// ```
157    fn arange(start: T, stop: T) -> Self
158    where
159        T: NumericDataType
160    {
161        Self::arange_with_step(start, stop, T::one())
162    }
163
164    /// Generates a 1D ndarray with evenly spaced values within a specified range.
165    ///
166    /// # Arguments
167    ///
168    /// * `start` - The starting value of the sequence, inclusive.
169    /// * `stop` - The ending value of the sequence, exclusive.
170    /// * `step` - The interval between each consecutive value
171    ///
172    /// # Examples
173    ///
174    /// ```
175    /// # use redstone_ml::*;
176    /// let ndarray = NdArray::arange_with_step(0i32, 5, 2); // [0, 2, 4].
177    /// ```
178    fn arange_with_step(start: T, stop: T, step: T) -> Self
179    where
180        T: NumericDataType
181    {
182        let n = ((stop - start).to_float() / step.to_float()).ceil();
183        let n = NumCast::from(n).unwrap();
184
185        let mut data: Vec<T> = vec![T::default(); n];
186        for (i, item) in data.iter_mut().enumerate() {
187            *item = <T as NumCast>::from(i).unwrap() * step + start;
188        }
189
190        unsafe { Self::from_contiguous_owned_buffer(vec![data.len()], data) }
191    }
192
193    /// Generates a 1-dimensional ndarray with `num `evenly spaced values between `start` and `stop`
194    /// (inclusive).
195    ///
196    /// # Arguments
197    ///
198    /// * `start` - The starting value of the sequence.
199    /// * `stop` - The ending value of the sequence. The value is inclusive in the range.
200    /// * `num` - The number of evenly spaced values to generate. Must be greater than 0.
201    ///
202    /// # Panic
203    ///
204    /// Panics if `num` is 0.
205    ///
206    /// # Example
207    ///
208    /// ```
209    /// # use redstone_ml::*;
210    /// let result = NdArray::linspace(0f32, 1.0, 5);  // [0.0, 0.25, 0.5, 0.75, 1.0]
211    /// assert_eq!(result, NdArray::new([0f32, 0.25, 0.5, 0.75, 1.0]));
212    /// ```
213    fn linspace(start: T, stop: T, num: usize) -> Self
214    where
215        T: FloatDataType
216    {
217        assert!(num > 0);
218
219        if num == 1 {
220            return unsafe { Self::from_contiguous_owned_buffer(vec![1], vec![start]) };
221        }
222
223        let step = (stop - start) / (<T as NumCast>::from(num).unwrap() - T::one());
224
225        // from start to (stop + step) to make the range inclusive
226        Self::arange_with_step(start, stop + step, step)
227    }
228
229    /// Generates a 1-dimensional ndarray with `num `evenly spaced values between `start` and `stop`
230    /// (exclusive).
231    ///
232    /// # Arguments
233    ///
234    /// * `start` - The starting value of the sequence.
235    /// * `stop` - The ending value of the sequence. The value is exclusive in the range.
236    /// * `num` - The number of evenly spaced values to generate. Must be greater than 0.
237    ///
238    /// # Panic
239    ///
240    /// Panics if `num` is 0.
241    ///
242    /// # Example
243    ///
244    /// ```
245    /// # use redstone_ml::*;
246    /// let result = NdArray::linspace_exclusive(0.0f32, 1.0, 5);
247    /// assert_eq!(result, NdArray::new([0f32, 0.2, 0.4, 0.6, 0.8]));
248    /// ```
249    fn linspace_exclusive(start: T, stop: T, num: usize) -> Self
250    where
251        T: FloatDataType
252    {
253        assert!(num > 0);
254
255        if num == 1 {
256            return unsafe { Self::from_contiguous_owned_buffer(vec![1], vec![start]) };
257        }
258
259        let step = (stop - start) / <T as NumCast>::from(num).unwrap();
260        Self::arange_with_step(start, stop, step)
261    }
262}