redstone_ml/common/constructors.rs
1use crate::util::flatten::Flatten;
2use crate::util::nested::Nested;
3use crate::util::shape::Shape;
4use crate::util::to_vec::ToVec;
5use crate::{FloatDataType, NumericDataType, RawDataType, StridedMemory};
6use num::NumCast;
7
8pub trait Constructors<T: RawDataType>: StridedMemory {
9 /// Constructs a new ndarray from the given data buffer and shape assuming a contiguous layout
10 ///
11 /// # Parameters
12 /// - `shape`: A vector that defines the dimensions of the ndarray.
13 /// - `data`: The underlying buffer that holds the ndarray's elements.
14 /// - `requires_grad`: If gradients need to be computed for this ndarray.
15 ///
16 /// # Safety
17 /// - `data` must remain valid and not be used elsewhere after being passed to this function.
18 /// - `shape.iter().product()` must equal `data.len()`
19 unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self;
20
21 /// Constructs an n-dimensional `NdArray` from input data such as a vector or array.
22 ///
23 /// # Parameters
24 /// - `data`: a nested array or vector of valid data types (floats, integers, bools)
25 ///
26 /// # Panics
27 /// - If the input data has inhomogeneous dimensions, i.e., nested arrays do not have consistent sizes.
28 /// - If the input data is empty (cannot create a zero-length ndarray)
29 ///
30 /// # Example
31 /// ```
32 /// # use redstone_ml::*;
33 ///
34 /// let ndarray : NdArray<i32> = NdArray::new([[1, 2], [3, 4]]);
35 /// assert_eq!(ndarray.shape(), &[2, 2]);
36 ///
37 /// let ndarray = NdArray::new(vec![1f32, 2.0, 3.0, 4.0, 5.0]);
38 /// assert_eq!(ndarray.ndims(), 1);
39 /// ```
40 fn new<const D: usize>(data: impl Flatten<T> + Shape + Nested<{ D }>) -> Self {
41 assert!(data.check_homogenous(), "from() failed, found inhomogeneous dimensions");
42
43 let shape = data.shape();
44 let data = data.flatten();
45
46 assert!(!data.is_empty(), "from() failed, cannot create data buffer from empty data");
47
48 unsafe { Self::from_contiguous_owned_buffer(shape, data) }
49 }
50
51 /// Creates an ndarray filled with a specified value and given shape.
52 ///
53 /// # Parameters
54 ///
55 /// * `n` - The value to fill the ndarray with (can be any valid data type like float, integer, or bool).
56 /// * `shape` - An array or vector representing the shape of the ndarray (e.g. `[2, 3, 5]`).
57 ///
58 /// # Panics
59 /// This function panics if the provided shape is empty.
60 ///
61 /// # Examples
62 ///
63 /// ```
64 /// # use redstone_ml::*;
65 ///
66 /// let ndarray = NdArray::full(5i32, [2, 3]); // creates a 2x3 ndarray filled with the value 5.
67 /// let ndarray = NdArray::full(true, [2, 3, 5]); // creates a 2x3x5 ndarray filled with 'true'
68 /// ```
69 fn full(n: T, shape: impl ToVec<usize>) -> Self {
70 let shape = shape.to_vec();
71
72 let data = vec![n; shape.iter().product()];
73 assert!(!data.is_empty(), "cannot create an empty ndarray!");
74
75 unsafe { Self::from_contiguous_owned_buffer(shape, data) }
76 }
77
78 /// Creates a new ndarray filled with zeros with the given shape.
79 ///
80 /// # Parameters
81 /// - `shape`: An array or vector representing the shape of the ndarray (e.g. `[2, 3, 5]`).
82 ///
83 /// # Panics
84 /// This function panics if the provided shape is empty.
85 ///
86 /// # Examples
87 /// ```
88 /// # use redstone_ml::*;
89 ///
90 /// let ndarray = NdArray::<i32>::zeros([2, 3]);
91 /// let ndarray = NdArray::<bool>::zeros([2, 3]); // creates an ndarray filled with 'false'
92 /// ```
93 fn zeros(shape: impl ToVec<usize>) -> Self
94 where
95 T: From<bool>
96 {
97 Self::full(false.into(), shape)
98 }
99
100 /// Creates a new ndarray filled with ones with the given shape.
101 ///
102 /// # Parameters
103 /// - `shape`: An array or vector representing the shape of the ndarray (e.g. `[2, 3, 5]`).
104 ///
105 /// # Panics
106 /// This function panics if the provided shape is empty.
107 ///
108 /// # Examples
109 /// ```
110 /// # use redstone_ml::*;
111 ///
112 /// let ndarray = NdArray::<i32>::ones([2, 3]);
113 /// let ndarray = NdArray::<bool>::ones([2, 3]); // creates an ndarray filled with 'true'
114 /// ```
115 fn ones(shape: impl ToVec<usize>) -> Self
116 where
117 T: From<bool>
118 {
119 Self::full(true.into(), shape)
120 }
121
122 /// Creates a 0-dimensional (shapeless) ndarray containing a single value.
123 ///
124 /// # Parameters
125 /// - `n`: The value to be stored in the scalar ndarray.
126 ///
127 /// # Example
128 /// ```
129 /// # use redstone_ml::*;
130 ///
131 /// let scalar_array = NdArray::scalar(42);
132 /// assert_eq!(scalar_array.shape(), []);
133 /// assert_eq!(scalar_array.value(), 42);
134 /// ```
135 fn scalar(n: T) -> Self {
136 Self::full(n, [])
137 }
138
139 /// Generates a 1D ndarray with evenly spaced values within a specified range.
140 ///
141 /// # Arguments
142 ///
143 /// * `start` - The starting value of the sequence, inclusive.
144 /// * `stop` - The ending value of the sequence, exclusive.
145 ///
146 /// # Returns
147 ///
148 /// An `NdArray` containing values starting from `start` and ending before `stop`,
149 /// with a step-size of 1.
150 ///
151 /// # Examples
152 ///
153 /// ```
154 /// # use redstone_ml::*;
155 /// let ndarray = NdArray::arange(0i32, 5); // [0, 1, 2, 3, 4].
156 /// ```
157 fn arange(start: T, stop: T) -> Self
158 where
159 T: NumericDataType
160 {
161 Self::arange_with_step(start, stop, T::one())
162 }
163
164 /// Generates a 1D ndarray with evenly spaced values within a specified range.
165 ///
166 /// # Arguments
167 ///
168 /// * `start` - The starting value of the sequence, inclusive.
169 /// * `stop` - The ending value of the sequence, exclusive.
170 /// * `step` - The interval between each consecutive value
171 ///
172 /// # Examples
173 ///
174 /// ```
175 /// # use redstone_ml::*;
176 /// let ndarray = NdArray::arange_with_step(0i32, 5, 2); // [0, 2, 4].
177 /// ```
178 fn arange_with_step(start: T, stop: T, step: T) -> Self
179 where
180 T: NumericDataType
181 {
182 let n = ((stop - start).to_float() / step.to_float()).ceil();
183 let n = NumCast::from(n).unwrap();
184
185 let mut data: Vec<T> = vec![T::default(); n];
186 for (i, item) in data.iter_mut().enumerate() {
187 *item = <T as NumCast>::from(i).unwrap() * step + start;
188 }
189
190 unsafe { Self::from_contiguous_owned_buffer(vec![data.len()], data) }
191 }
192
193 /// Generates a 1-dimensional ndarray with `num `evenly spaced values between `start` and `stop`
194 /// (inclusive).
195 ///
196 /// # Arguments
197 ///
198 /// * `start` - The starting value of the sequence.
199 /// * `stop` - The ending value of the sequence. The value is inclusive in the range.
200 /// * `num` - The number of evenly spaced values to generate. Must be greater than 0.
201 ///
202 /// # Panic
203 ///
204 /// Panics if `num` is 0.
205 ///
206 /// # Example
207 ///
208 /// ```
209 /// # use redstone_ml::*;
210 /// let result = NdArray::linspace(0f32, 1.0, 5); // [0.0, 0.25, 0.5, 0.75, 1.0]
211 /// assert_eq!(result, NdArray::new([0f32, 0.25, 0.5, 0.75, 1.0]));
212 /// ```
213 fn linspace(start: T, stop: T, num: usize) -> Self
214 where
215 T: FloatDataType
216 {
217 assert!(num > 0);
218
219 if num == 1 {
220 return unsafe { Self::from_contiguous_owned_buffer(vec![1], vec![start]) };
221 }
222
223 let step = (stop - start) / (<T as NumCast>::from(num).unwrap() - T::one());
224
225 // from start to (stop + step) to make the range inclusive
226 Self::arange_with_step(start, stop + step, step)
227 }
228
229 /// Generates a 1-dimensional ndarray with `num `evenly spaced values between `start` and `stop`
230 /// (exclusive).
231 ///
232 /// # Arguments
233 ///
234 /// * `start` - The starting value of the sequence.
235 /// * `stop` - The ending value of the sequence. The value is exclusive in the range.
236 /// * `num` - The number of evenly spaced values to generate. Must be greater than 0.
237 ///
238 /// # Panic
239 ///
240 /// Panics if `num` is 0.
241 ///
242 /// # Example
243 ///
244 /// ```
245 /// # use redstone_ml::*;
246 /// let result = NdArray::linspace_exclusive(0.0f32, 1.0, 5);
247 /// assert_eq!(result, NdArray::new([0f32, 0.2, 0.4, 0.6, 0.8]));
248 /// ```
249 fn linspace_exclusive(start: T, stop: T, num: usize) -> Self
250 where
251 T: FloatDataType
252 {
253 assert!(num > 0);
254
255 if num == 1 {
256 return unsafe { Self::from_contiguous_owned_buffer(vec![1], vec![start]) };
257 }
258
259 let step = (stop - start) / <T as NumCast>::from(num).unwrap();
260 Self::arange_with_step(start, stop, step)
261 }
262}