// numrs/array/array.rs

1use crate::array::{DType, DTypeValue};
2use crate::ir::HloNode;
3use anyhow::Result;
4
/// A lightweight n-dimensional array with generic dtype support.
///
/// Arrays are memory contiguous (row-major) by default and can hold different
/// data types. The type parameter `T` must implement the `DTypeValue` trait.
///
/// Arrays also support views with explicit strides and an offset:
/// - `strides`: step size, in elements, for each dimension (`None` = C-contiguous)
/// - `offset`: starting offset into the data buffer (0 = start at beginning)
#[derive(Clone, Debug)]
pub struct Array<T: DTypeValue = f32> {
    /// Dimension sizes, outermost dimension first (row-major interpretation).
    pub shape: Vec<usize>,
    /// Runtime dtype tag; constructors set this to `T::dtype()`.
    pub dtype: DType,
    /// Flat element buffer; logical layout is determined by `strides`/`offset`.
    pub data: Vec<T>,
    /// Strides for each dimension (None = C-contiguous layout).
    /// A stride of 0 repeats the same element along that axis (broadcast views).
    pub strides: Option<Vec<isize>>,
    /// Offset into data buffer for view (0 = no offset)
    pub offset: usize,
}
23
24impl<T: DTypeValue> Array<T> {
25    /// Create a new array given a shape and a flat data vector.
26    /// The dtype is automatically inferred from type T.
27    pub fn new(shape: Vec<usize>, data: Vec<T>) -> Self {
28        assert_eq!(
29            shape.iter().product::<usize>(),
30            data.len(),
31            "shape length mismatch"
32        );
33        Self {
34            shape,
35            dtype: T::dtype(),
36            data,
37            strides: None, // C-contiguous by default
38            offset: 0,     // No offset
39        }
40    }
41
42    /// Create a new array with implicit type conversion from any DTypeValue.
43    ///
44    /// This allows you to create arrays with automatic casting from input data
45    /// to the target dtype using the promotion engine.
46    ///
47    /// # Examples
48    /// ```
49    /// use numrs::Array;
50    ///
51    /// // Create f32 array from i32 data - automatic conversion
52    /// let arr = Array::<f32>::from_vec(vec![3], vec![1i32, 2, 3]);
53    /// assert_eq!(arr.data, vec![1.0f32, 2.0, 3.0]);
54    ///
55    /// // Create i32 array from f32 data - automatic conversion
56    /// let arr = Array::<i32>::from_vec(vec![2], vec![1.5f32, 2.7]);
57    /// assert_eq!(arr.data, vec![1, 2]);
58    /// ```
59    pub fn from_vec<U: DTypeValue>(shape: Vec<usize>, data: Vec<U>) -> Self {
60        assert_eq!(
61            shape.iter().product::<usize>(),
62            data.len(),
63            "shape length mismatch"
64        );
65
66        // Si los tipos son iguales, no hay conversión (zero-cost)
67        if U::dtype() == T::dtype() {
68            // Verificamos explícitamente tamaño y alineación para evitar UB.
69            assert_eq!(std::mem::size_of::<U>(), std::mem::size_of::<T>());
70            assert_eq!(std::mem::align_of::<U>(), std::mem::align_of::<T>());
71
72            let mut data = std::mem::ManuallyDrop::new(data);
73            let (ptr, len, cap) = (data.as_mut_ptr(), data.len(), data.capacity());
74
75            let converted_data = unsafe { Vec::from_raw_parts(ptr as *mut T, len, cap) };
76            return Self {
77                shape,
78                dtype: T::dtype(),
79                data: converted_data,
80                strides: None,
81                offset: 0,
82            };
83        }
84
85        // Si los tipos son diferentes, usar el motor de promoción
86        let converted_data: Vec<T> = data.iter().map(|&val| T::from_f32(val.to_f32())).collect();
87
88        Self {
89            shape,
90            dtype: T::dtype(),
91            data: converted_data,
92            strides: None,
93            offset: 0,
94        }
95    }
96
97    /// Helper to convert any array to f32 (useful for tests and interoperability)
98    pub fn to_f32(&self) -> Array<f32> {
99        let new_data: Vec<f32> = self.data.iter().map(|v| v.to_f32()).collect();
100        Array {
101            shape: self.shape.clone(),
102            dtype: DType::F32,
103            data: new_data,
104            strides: self.strides.clone(),
105            offset: 0,
106        }
107    }
108
109    /// Create an array initialized with zeros for a shape.
110    pub fn zeros(shape: Vec<usize>) -> Self {
111        let len = shape.iter().product();
112        Self {
113            shape,
114            dtype: T::dtype(),
115            data: vec![T::default(); len],
116            strides: None,
117            offset: 0,
118        }
119    }
120
121    /// Create an array initialized with ones.
122    pub fn ones(shape: Vec<usize>) -> Self {
123        let len: usize = shape.iter().product();
124        let one = T::from_f32(1.0);
125        Self {
126            shape,
127            dtype: T::dtype(),
128            data: vec![one; len],
129            strides: None,
130            offset: 0,
131        }
132    }
133
134    /// Return shape as a slice
135    pub fn shape(&self) -> &[usize] {
136        &self.shape
137    }
138
139    /// Return the dtype of this array
140    pub fn dtype(&self) -> DType {
141        self.dtype
142    }
143
144    /// Helper to get length
145    pub fn len(&self) -> usize {
146        self.data.len()
147    }
148
149    /// Check if array is C-contiguous (no custom strides)
150    pub fn is_contiguous(&self) -> bool {
151        self.strides.is_none() && self.offset == 0
152    }
153
154    /// Compute C-contiguous strides for current shape
155    pub fn compute_default_strides(&self) -> Vec<isize> {
156        let mut strides = vec![1isize; self.shape.len()];
157        for i in (0..self.shape.len().saturating_sub(1)).rev() {
158            strides[i] = strides[i + 1] * (self.shape[i + 1] as isize);
159        }
160        strides
161    }
162
163    /// Get effective strides (either custom or computed default)
164    pub fn get_strides(&self) -> Vec<isize> {
165        self.strides
166            .clone()
167            .unwrap_or_else(|| self.compute_default_strides())
168    }
169
170    /// Create a broadcast view without copying data
171    ///
172    /// This creates a zero-copy view that logically broadcasts the array
173    /// to a new shape by adjusting strides (setting stride to 0 for broadcast dims).
174    ///
175    /// # Arguments
176    /// * `target_shape` - The shape to broadcast to
177    ///
178    /// # Returns
179    /// A new Array that shares the same data but appears to have the target shape
180    pub fn broadcast_view(&self, target_shape: &[usize]) -> Result<Self> {
181        // Validar que el broadcast sea posible
182        crate::ops::shape::broadcast_to::validate_broadcast_public(&self.shape, target_shape)?;
183
184        let src_ndim = self.shape.len();
185        let target_ndim = target_shape.len();
186
187        // Calcular nuevos strides
188        let src_strides = self.get_strides();
189        let mut new_strides = vec![0isize; target_ndim];
190
191        // Mapear desde el final (right-aligned)
192        for i in 0..src_ndim {
193            let src_idx = src_ndim - 1 - i;
194            let target_idx = target_ndim - 1 - i;
195
196            if self.shape[src_idx] == 1 {
197                // Dimensión broadcast: stride = 0 (reutilizar mismo valor)
198                new_strides[target_idx] = 0;
199            } else {
200                // Dimensión normal: usar stride original
201                new_strides[target_idx] = src_strides[src_idx];
202            }
203        }
204
205        // Dimensiones nuevas (padding izquierdo) tienen stride 0
206        for i in 0..(target_ndim - src_ndim) {
207            new_strides[i] = 0;
208        }
209
210        Ok(Self {
211            shape: target_shape.to_vec(),
212            dtype: self.dtype,
213            data: self.data.clone(), // Compartir referencia (Rc en futuro?)
214            strides: Some(new_strides),
215            offset: self.offset,
216        })
217    }
218
219    /// Materialize a view into a contiguous array
220    ///
221    /// If the array is already contiguous, returns a clone.
222    /// Otherwise, creates a new contiguous array with the data arranged properly.
223    pub fn to_contiguous(&self) -> Self {
224        if self.is_contiguous() {
225            return self.clone();
226        }
227
228        let size: usize = self.shape.iter().product();
229        let mut result = Vec::with_capacity(size);
230        let strides = self.get_strides();
231
232        // Iterar sobre todos los elementos en orden C-contiguous
233        let mut indices = vec![0usize; self.shape.len()];
234
235        for _ in 0..size {
236            // Calcular offset con strides
237            let mut flat_idx = self.offset as isize;
238            for (i, &idx) in indices.iter().enumerate() {
239                flat_idx += idx as isize * strides[i];
240            }
241
242            // Bounds check: asegurar que el índice es válido
243            let flat_idx_usize = flat_idx as usize;
244            if flat_idx_usize >= self.data.len() {
245                // Fallback: usar módulo para wrap around (broadcasting)
246                let safe_idx = flat_idx_usize % self.data.len().max(1);
247                result.push(self.data[safe_idx]);
248            } else {
249                result.push(self.data[flat_idx_usize]);
250            }
251
252            // Incrementar índices (orden C)
253            for i in (0..self.shape.len()).rev() {
254                indices[i] += 1;
255                if indices[i] < self.shape[i] {
256                    break;
257                }
258                indices[i] = 0;
259            }
260        }
261
262        Self {
263            shape: self.shape.clone(),
264            dtype: self.dtype,
265            data: result,
266            strides: None,
267            offset: 0,
268        }
269    }
270
271    /// Build an HLO graph node representing a constant array
272    pub fn to_hlo_const(&self) -> HloNode {
273        HloNode::const_node(self.shape.clone())
274    }
275}