chela/ndarray/constructors.rs
1use crate::dtype::NumericDataType;
2use crate::ndarray::flags::NdArrayFlags;
3use crate::ndarray::NdArray;
4use crate::util::flatten::Flatten;
5use crate::util::nested::Nested;
6use crate::util::shape::Shape;
7use crate::util::to_vec::ToVec;
8use crate::{FloatDataType, RawDataType};
9use num::NumCast;
10use std::mem::ManuallyDrop;
11use std::ptr::NonNull;
12
/// Computes the strides of an ndarray from its shape, assuming a contiguous row-major layout.
///
/// The stride of an axis is the number of elements that must be skipped in memory to move one
/// step along that axis. In a contiguous row-major layout it equals the product of the sizes of
/// all axes that follow it, so the last axis always has stride 1.
///
/// # Arguments
///
/// * `shape` - The size of each dimension of the ndarray.
///
/// # Returns
///
/// A `Vec<usize>` with one stride per dimension (same length as `shape`). An empty `shape`
/// (a 0-dimensional ndarray) yields an empty stride vector.
///
/// # Example
///
/// ```text
/// shape  = [5, 3, 2, 1]
/// stride = [6, 2, 1, 1]
///
/// axis 0 (size 5): stride = 3 * 2 * 1 = 6
/// axis 1 (size 3): stride = 2 * 1     = 2
/// axis 2 (size 2): stride = 1         = 1
/// axis 3 (size 1): last axis          = 1
/// ```
pub(crate) fn stride_from_shape(shape: &[usize]) -> Vec<usize> {
    let mut stride = vec![1; shape.len()];

    // Walk backwards from the second-to-last axis: each stride is the following axis' stride
    // times that axis' size. Unlike a running product over every axis, this never multiplies in
    // `shape[0]`, which no stride needs and which could overflow for very large shapes.
    for axis in (0..shape.len().saturating_sub(1)).rev() {
        stride[axis] = stride[axis + 1] * shape[axis + 1];
    }

    stride
}
53
54impl<'a, T: RawDataType> NdArray<'a, T> {
55 /// Constructs a new ndarray from the given data buffer and shape assuming a contiguous layout
56 ///
57 /// # Parameters
58 /// - `shape`: A vector that defines the dimensions of the ndarray.
59 /// - `data`: The underlying buffer that holds the ndarray's elements.
60 /// - `requires_grad`: If gradients need to be computed for this ndarray.
61 ///
62 /// # Safety
63 /// - `data` must remain valid and not be used elsewhere after being passed to this function.
64 /// - `shape.iter().product()` must equal `data.len()`
65 pub(crate) unsafe fn from_contiguous_owned_buffer(shape: Vec<usize>, data: Vec<T>) -> Self {
66 let flags = NdArrayFlags::Owned | NdArrayFlags::Contiguous | NdArrayFlags::UniformStride | NdArrayFlags::Writeable;
67
68 // take control of the data so that Rust doesn't drop it once the vector goes out of scope
69 let mut data = ManuallyDrop::new(data);
70 let stride = stride_from_shape(&shape);
71
72 Self {
73 ptr: NonNull::new_unchecked(data.as_mut_ptr()),
74 len: data.len(),
75 capacity: data.capacity(),
76
77 shape,
78 stride,
79 flags,
80
81 _marker: Default::default(),
82 }
83 }
84
85 /// Constructs an n-dimensional `NdArray` from input data such as a vector or array.
86 ///
87 /// # Parameters
88 /// - `data`: a nested array or vector of valid data types (floats, integers, bools)
89 ///
90 /// # Panics
91 /// - If the input data has inhomogeneous dimensions, i.e., nested arrays do not have consistent sizes.
92 /// - If the input data is empty (cannot create a zero-length ndarray )
93 ///
94 /// # Example
95 /// ```
96 /// # use chela::*;
97 ///
98 /// let ndarray : NdArray<i32> = NdArray::from([[1, 2], [3, 4]]);
99 /// assert_eq!(ndarray.shape(), &[2, 2]);
100 ///
101 /// let ndarray = NdArray::from(vec![1f32, 2.0, 3.0, 4.0, 5.0]);
102 /// assert_eq!(ndarray.ndims(), 1);
103 /// ```
104 pub fn from<const D: usize>(data: impl Flatten<T> + Shape + Nested<{ D }>) -> Self {
105 assert!(data.check_homogenous(), "Tensor::from() failed, found inhomogeneous dimensions");
106
107 let shape = data.shape();
108 let data = data.flatten();
109
110 assert!(!data.is_empty(), "Tensor::from() failed, cannot create data buffer from empty data");
111
112 unsafe { NdArray::from_contiguous_owned_buffer(shape, data) }
113 }
114
115 /// Creates an ndarray filled with a specified value and given shape.
116 ///
117 /// # Parameters
118 ///
119 /// * `n` - The value to fill the ndarray with (can be any valid data type like float, integer, or bool).
120 /// * `shape` - An array or vector representing the shape of the ndarray (e.g. `[2, 3, 5]`).
121 ///
122 /// # Panics
123 /// This function panics if the provided shape is empty.
124 ///
125 /// # Examples
126 ///
127 /// ```
128 /// # use chela::*;
129 ///
130 /// let ndarray = NdArray::full(5i32, [2, 3]); // creates a 2x3 ndarray filled with the value 5.
131 /// let ndarray = NdArray::full(true, [2, 3, 5]); // creates a 2x3x5 ndarray filled with 'true'
132 /// ```
133 pub fn full(n: T, shape: impl ToVec<usize>) -> Self {
134 let shape = shape.to_vec();
135
136 let data = vec![n; shape.iter().product()];
137 assert!(!data.is_empty(), "Cannot create an empty tensor!");
138
139 unsafe { NdArray::from_contiguous_owned_buffer(shape, data) }
140 }
141
142 /// Creates a new ndarray filled with zeros with the given shape.
143 ///
144 /// # Parameters
145 /// - `shape`: An array or vector representing the shape of the ndarray (e.g. `[2, 3, 5]`).
146 ///
147 /// # Panics
148 /// This function panics if the provided shape is empty.
149 ///
150 /// # Examples
151 /// ```
152 /// # use chela::*;
153 ///
154 /// let ndarray = NdArray::<i32>::zeros([2, 3]);
155 /// let ndarray = NdArray::<bool>::zeros([2, 3]); // creates an ndarray filled with 'false'
156 /// ```
157 pub fn zeros(shape: impl ToVec<usize>) -> Self
158 where
159 T: From<bool>,
160 {
161 Self::full(false.into(), shape)
162 }
163
164 /// Creates a new ndarray filled with ones with the given shape.
165 ///
166 /// # Parameters
167 /// - `shape`: An array or vector representing the shape of the ndarray (e.g. `[2, 3, 5]`).
168 ///
169 /// # Panics
170 /// This function panics if the provided shape is empty.
171 ///
172 /// # Examples
173 /// ```
174 /// # use chela::*;
175 ///
176 /// let ndarray = NdArray::<i32>::ones([2, 3]);
177 /// let ndarray = NdArray::<bool>::ones([2, 3]); // creates an ndarray filled with 'true'
178 /// ```
179 pub fn ones(shape: impl ToVec<usize>) -> Self
180 where
181 T: From<bool>,
182 {
183 Self::full(true.into(), shape)
184 }
185
186 /// Creates a 0-dimensional (shapeless) ndarray containing a single value.
187 ///
188 /// # Parameters
189 /// - `n`: The value to be stored in the scalar ndarray.
190 ///
191 /// # Example
192 /// ```rust
193 /// # use chela::*;
194 ///
195 /// let scalar_array = NdArray::scalar(42);
196 /// assert_eq!(scalar_array.shape(), []);
197 /// assert_eq!(scalar_array.value(), 42);
198 /// ```
199 pub fn scalar(n: T) -> Self {
200 NdArray::full(n, [])
201 }
202
203 // Maybe we should support empty arrays one day.
204 // pub fn empty() -> Self {
205 // unsafe { NdArray::from_contiguous_owned_buffer(vec![0], vec![]) }
206 // }
207}
208
209impl<T: NumericDataType> NdArray<'_, T> {
210 /// Generates a 1D ndarray with evenly spaced values within a specified range.
211 ///
212 /// # Arguments
213 ///
214 /// * `start` - The starting value of the sequence, inclusive.
215 /// * `stop` - The ending value of the sequence, exclusive.
216 ///
217 /// # Returns
218 ///
219 /// An `NdArray` containing values starting from `start` and ending before `stop`,
220 /// with a step-size of 1.
221 ///
222 /// # Examples
223 ///
224 /// ```rust
225 /// # use chela::*;
226 /// let ndarray = NdArray::arange(0i32, 5); // [0, 1, 2, 3, 4].
227 /// ```
228 pub fn arange(start: T, stop: T) -> NdArray<'static, T> {
229 Self::arange_with_step(start, stop, T::one())
230 }
231
232 /// Generates a 1D ndarray with evenly spaced values within a specified range.
233 ///
234 /// # Arguments
235 ///
236 /// * `start` - The starting value of the sequence, inclusive.
237 /// * `stop` - The ending value of the sequence, exclusive.
238 /// * `step` - The interval between each consecutive value
239 ///
240 /// # Examples
241 ///
242 /// ```rust
243 /// # use chela::*;
244 /// let ndarray = NdArray::arange_with_step(0i32, 5, 2); // [0, 2, 4].
245 /// ```
246 pub fn arange_with_step(start: T, stop: T, step: T) -> NdArray<'static, T> {
247 let n = ((stop - start).to_float() / step.to_float()).ceil();
248 let n = NumCast::from(n).unwrap();
249
250 let mut data: Vec<T> = vec![T::default(); n];
251 for (i, item) in data.iter_mut().enumerate() {
252 *item = <T as NumCast>::from(i).unwrap() * step + start;
253 }
254
255 unsafe { NdArray::from_contiguous_owned_buffer(vec![data.len()], data) }
256 }
257}
258
259impl<T: FloatDataType> NdArray<'_, T> {
260 /// Generates a 1-dimensional ndarray with `num `evenly spaced values between `start` and `stop`
261 /// (inclusive).
262 ///
263 /// # Arguments
264 ///
265 /// * `start` - The starting value of the sequence.
266 /// * `stop` - The ending value of the sequence. The value is inclusive in the range.
267 /// * `num` - The number of evenly spaced values to generate. Must be greater than 0.
268 ///
269 /// # Panic
270 ///
271 /// Panics if `num` is 0.
272 ///
273 /// # Example
274 ///
275 /// ```
276 /// # use chela::*;
277 /// let result = NdArray::linspace(0f32, 1.0, 5); // [0.0, 0.25, 0.5, 0.75, 1.0]
278 /// assert_eq!(result, NdArray::from([0f32, 0.25, 0.5, 0.75, 1.0]));
279 /// ```
280 pub fn linspace(start: T, stop: T, num: usize) -> NdArray<'static, T> {
281 assert!(num > 0);
282
283 if num == 1 {
284 return unsafe { NdArray::from_contiguous_owned_buffer(vec![1], vec![start]) };
285 }
286
287 let step = (stop - start) / (<T as NumCast>::from(num).unwrap() - T::one());
288
289 // from start to (stop + step) to make the range inclusive
290 NdArray::arange_with_step(start, stop + step, step)
291 }
292
293 /// Generates a 1-dimensional ndarray with `num `evenly spaced values between `start` and `stop`
294 /// (exclusive).
295 ///
296 /// # Arguments
297 ///
298 /// * `start` - The starting value of the sequence.
299 /// * `stop` - The ending value of the sequence. The value is exclusive in the range.
300 /// * `num` - The number of evenly spaced values to generate. Must be greater than 0.
301 ///
302 /// # Panic
303 ///
304 /// Panics if `num` is 0.
305 ///
306 /// # Example
307 ///
308 /// ```
309 /// # use chela::*;
310 /// let result = NdArray::linspace_exclusive(0.0f32, 1.0, 5);
311 /// assert_eq!(result, NdArray::from([0f32, 0.2, 0.4, 0.6, 0.8]));
312 /// ```
313 pub fn linspace_exclusive(start: T, stop: T, num: usize) -> NdArray<'static, T> {
314 assert!(num > 0);
315
316 if num == 1 {
317 return unsafe { NdArray::from_contiguous_owned_buffer(vec![1], vec![start]) };
318 }
319
320 let step = (stop - start) / <T as NumCast>::from(num).unwrap();
321 NdArray::arange_with_step(start, stop, step)
322 }
323}
324
325impl<T: RawDataType> Drop for NdArray<'_, T> {
326 /// This method is implicitly invoked when the ndarray is deleted to clean up its memory if
327 /// the ndarray owns its data (i.e. it is not a view into another ndarray ).
328 ///
329 /// Resets `self.len` and `self.capacity` to 0.
330 fn drop(&mut self) {
331 if self.flags.contains(NdArrayFlags::Owned) {
332 // drops the data
333 unsafe { Vec::from_raw_parts(self.mut_ptr(), self.len, self.capacity) };
334 }
335
336 self.len = 0;
337 self.capacity = 0;
338 }
339}