//! numrs/array/array.rs — lightweight n-dimensional array with generic dtype support.
use crate::array::{DType, DTypeValue};
use crate::ir::HloNode;
use anyhow::Result;

5/// A lightweight n-dimensional array with generic dtype support.
6///
7/// Arrays are memory contiguous (row-major) and can hold different data types.
8/// The type parameter `T` must implement `DTypeValue` trait.
9///
10/// Arrays now support zero-copy views with strides and offset:
11/// - `strides`: Step size in each dimension (None = C-contiguous)
12/// - `offset`: Starting offset in data buffer (0 = start at beginning)
13#[derive(Clone, Debug)]
14pub struct Array<T: DTypeValue = f32> {
15 pub shape: Vec<usize>,
16 pub dtype: DType,
17 pub data: Vec<T>,
18 /// Strides for each dimension (None = C-contiguous layout)
19 pub strides: Option<Vec<isize>>,
20 /// Offset into data buffer for view (0 = no offset)
21 pub offset: usize,
22}
23
24impl<T: DTypeValue> Array<T> {
25 /// Create a new array given a shape and a flat data vector.
26 /// The dtype is automatically inferred from type T.
27 pub fn new(shape: Vec<usize>, data: Vec<T>) -> Self {
28 assert_eq!(
29 shape.iter().product::<usize>(),
30 data.len(),
31 "shape length mismatch"
32 );
33 Self {
34 shape,
35 dtype: T::dtype(),
36 data,
37 strides: None, // C-contiguous by default
38 offset: 0, // No offset
39 }
40 }
41
42 /// Create a new array with implicit type conversion from any DTypeValue.
43 ///
44 /// This allows you to create arrays with automatic casting from input data
45 /// to the target dtype using the promotion engine.
46 ///
47 /// # Examples
48 /// ```
49 /// use numrs::Array;
50 ///
51 /// // Create f32 array from i32 data - automatic conversion
52 /// let arr = Array::<f32>::from_vec(vec![3], vec![1i32, 2, 3]);
53 /// assert_eq!(arr.data, vec![1.0f32, 2.0, 3.0]);
54 ///
55 /// // Create i32 array from f32 data - automatic conversion
56 /// let arr = Array::<i32>::from_vec(vec![2], vec![1.5f32, 2.7]);
57 /// assert_eq!(arr.data, vec![1, 2]);
58 /// ```
59 pub fn from_vec<U: DTypeValue>(shape: Vec<usize>, data: Vec<U>) -> Self {
60 assert_eq!(
61 shape.iter().product::<usize>(),
62 data.len(),
63 "shape length mismatch"
64 );
65
66 // Si los tipos son iguales, no hay conversión (zero-cost)
67 if U::dtype() == T::dtype() {
68 // Verificamos explícitamente tamaño y alineación para evitar UB.
69 assert_eq!(std::mem::size_of::<U>(), std::mem::size_of::<T>());
70 assert_eq!(std::mem::align_of::<U>(), std::mem::align_of::<T>());
71
72 let mut data = std::mem::ManuallyDrop::new(data);
73 let (ptr, len, cap) = (data.as_mut_ptr(), data.len(), data.capacity());
74
75 let converted_data = unsafe { Vec::from_raw_parts(ptr as *mut T, len, cap) };
76 return Self {
77 shape,
78 dtype: T::dtype(),
79 data: converted_data,
80 strides: None,
81 offset: 0,
82 };
83 }
84
85 // Si los tipos son diferentes, usar el motor de promoción
86 let converted_data: Vec<T> = data.iter().map(|&val| T::from_f32(val.to_f32())).collect();
87
88 Self {
89 shape,
90 dtype: T::dtype(),
91 data: converted_data,
92 strides: None,
93 offset: 0,
94 }
95 }
96
97 /// Helper to convert any array to f32 (useful for tests and interoperability)
98 pub fn to_f32(&self) -> Array<f32> {
99 let new_data: Vec<f32> = self.data.iter().map(|v| v.to_f32()).collect();
100 Array {
101 shape: self.shape.clone(),
102 dtype: DType::F32,
103 data: new_data,
104 strides: self.strides.clone(),
105 offset: 0,
106 }
107 }
108
109 /// Create an array initialized with zeros for a shape.
110 pub fn zeros(shape: Vec<usize>) -> Self {
111 let len = shape.iter().product();
112 Self {
113 shape,
114 dtype: T::dtype(),
115 data: vec![T::default(); len],
116 strides: None,
117 offset: 0,
118 }
119 }
120
121 /// Create an array initialized with ones.
122 pub fn ones(shape: Vec<usize>) -> Self {
123 let len: usize = shape.iter().product();
124 let one = T::from_f32(1.0);
125 Self {
126 shape,
127 dtype: T::dtype(),
128 data: vec![one; len],
129 strides: None,
130 offset: 0,
131 }
132 }
133
134 /// Return shape as a slice
135 pub fn shape(&self) -> &[usize] {
136 &self.shape
137 }
138
139 /// Return the dtype of this array
140 pub fn dtype(&self) -> DType {
141 self.dtype
142 }
143
144 /// Helper to get length
145 pub fn len(&self) -> usize {
146 self.data.len()
147 }
148
149 /// Check if array is C-contiguous (no custom strides)
150 pub fn is_contiguous(&self) -> bool {
151 self.strides.is_none() && self.offset == 0
152 }
153
154 /// Compute C-contiguous strides for current shape
155 pub fn compute_default_strides(&self) -> Vec<isize> {
156 let mut strides = vec![1isize; self.shape.len()];
157 for i in (0..self.shape.len().saturating_sub(1)).rev() {
158 strides[i] = strides[i + 1] * (self.shape[i + 1] as isize);
159 }
160 strides
161 }
162
163 /// Get effective strides (either custom or computed default)
164 pub fn get_strides(&self) -> Vec<isize> {
165 self.strides
166 .clone()
167 .unwrap_or_else(|| self.compute_default_strides())
168 }
169
170 /// Create a broadcast view without copying data
171 ///
172 /// This creates a zero-copy view that logically broadcasts the array
173 /// to a new shape by adjusting strides (setting stride to 0 for broadcast dims).
174 ///
175 /// # Arguments
176 /// * `target_shape` - The shape to broadcast to
177 ///
178 /// # Returns
179 /// A new Array that shares the same data but appears to have the target shape
180 pub fn broadcast_view(&self, target_shape: &[usize]) -> Result<Self> {
181 // Validar que el broadcast sea posible
182 crate::ops::shape::broadcast_to::validate_broadcast_public(&self.shape, target_shape)?;
183
184 let src_ndim = self.shape.len();
185 let target_ndim = target_shape.len();
186
187 // Calcular nuevos strides
188 let src_strides = self.get_strides();
189 let mut new_strides = vec![0isize; target_ndim];
190
191 // Mapear desde el final (right-aligned)
192 for i in 0..src_ndim {
193 let src_idx = src_ndim - 1 - i;
194 let target_idx = target_ndim - 1 - i;
195
196 if self.shape[src_idx] == 1 {
197 // Dimensión broadcast: stride = 0 (reutilizar mismo valor)
198 new_strides[target_idx] = 0;
199 } else {
200 // Dimensión normal: usar stride original
201 new_strides[target_idx] = src_strides[src_idx];
202 }
203 }
204
205 // Dimensiones nuevas (padding izquierdo) tienen stride 0
206 for i in 0..(target_ndim - src_ndim) {
207 new_strides[i] = 0;
208 }
209
210 Ok(Self {
211 shape: target_shape.to_vec(),
212 dtype: self.dtype,
213 data: self.data.clone(), // Compartir referencia (Rc en futuro?)
214 strides: Some(new_strides),
215 offset: self.offset,
216 })
217 }
218
219 /// Materialize a view into a contiguous array
220 ///
221 /// If the array is already contiguous, returns a clone.
222 /// Otherwise, creates a new contiguous array with the data arranged properly.
223 pub fn to_contiguous(&self) -> Self {
224 if self.is_contiguous() {
225 return self.clone();
226 }
227
228 let size: usize = self.shape.iter().product();
229 let mut result = Vec::with_capacity(size);
230 let strides = self.get_strides();
231
232 // Iterar sobre todos los elementos en orden C-contiguous
233 let mut indices = vec![0usize; self.shape.len()];
234
235 for _ in 0..size {
236 // Calcular offset con strides
237 let mut flat_idx = self.offset as isize;
238 for (i, &idx) in indices.iter().enumerate() {
239 flat_idx += idx as isize * strides[i];
240 }
241
242 // Bounds check: asegurar que el índice es válido
243 let flat_idx_usize = flat_idx as usize;
244 if flat_idx_usize >= self.data.len() {
245 // Fallback: usar módulo para wrap around (broadcasting)
246 let safe_idx = flat_idx_usize % self.data.len().max(1);
247 result.push(self.data[safe_idx]);
248 } else {
249 result.push(self.data[flat_idx_usize]);
250 }
251
252 // Incrementar índices (orden C)
253 for i in (0..self.shape.len()).rev() {
254 indices[i] += 1;
255 if indices[i] < self.shape[i] {
256 break;
257 }
258 indices[i] = 0;
259 }
260 }
261
262 Self {
263 shape: self.shape.clone(),
264 dtype: self.dtype,
265 data: result,
266 strides: None,
267 offset: 0,
268 }
269 }
270
271 /// Build an HLO graph node representing a constant array
272 pub fn to_hlo_const(&self) -> HloNode {
273 HloNode::const_node(self.shape.clone())
274 }
275}