jax_rs/
array.rs

1//! Core Array type for n-dimensional numeric arrays.
2
3use crate::{buffer::Buffer, DType, Device, Shape};
4use std::fmt;
5use std::sync::atomic::{AtomicUsize, Ordering};
6
/// Process-wide source of array IDs; every call to [`next_array_id`] consumes one value.
static ARRAY_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Hand out a fresh array ID.
///
/// IDs increase by one per call and are distinct within the process until the
/// counter wraps at `usize::MAX`. `Relaxed` ordering is enough: only the
/// uniqueness of the returned value matters, not ordering relative to other memory.
fn next_array_id() -> usize {
    let issued = ARRAY_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
    issued
}
14
15/// A multidimensional numeric array.
16///
17/// This is the core data type of jax-rs, equivalent to `jax.Array` in JAX
18/// or `torch.Tensor` in PyTorch. Unlike jax-js which uses manual reference
19/// counting (`.ref` and `.dispose()`), Rust's ownership system provides
20/// automatic memory management.
21///
22/// # Memory Model
23///
24/// Arrays own their data through an `Arc<Buffer>`, allowing cheap cloning
25/// and zero-copy views. When the last reference to a buffer is dropped,
26/// the memory is automatically freed.
27///
28/// # Examples
29///
30/// ```
31/// # use jax_rs::{Array, DType, Shape};
32/// // Create a 2x3 array of zeros
33/// let a = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
34/// assert_eq!(a.shape().as_slice(), &[2, 3]);
35/// ```
36#[derive(Debug, Clone)]
37pub struct Array {
38    /// Underlying data buffer
39    buffer: Buffer,
40    /// Shape of the array
41    shape: Shape,
42    /// Strides for indexing (in elements, not bytes)
43    strides: Vec<usize>,
44    /// Offset into the buffer (in elements)
45    offset: usize,
46    /// Unique ID for tracing (pointer address)
47    id: usize,
48}
49
50impl Array {
51    /// Create a new array filled with zeros.
52    ///
53    /// # Examples
54    ///
55    /// ```
56    /// # use jax_rs::{Array, DType, Shape, Device, default_device};
57    /// let a = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
58    /// assert_eq!(a.shape().as_slice(), &[2, 3]);
59    /// assert_eq!(a.dtype(), DType::Float32);
60    /// ```
61    pub fn zeros(shape: Shape, dtype: DType) -> Self {
62        let device = crate::default_device();
63        let size = shape.size();
64        let buffer = Buffer::zeros(size, dtype, device);
65        let strides = shape.default_strides();
66        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
67    }
68
69    /// Create a new array filled with ones.
70    pub fn ones(shape: Shape, dtype: DType) -> Self {
71        let device = crate::default_device();
72        let size = shape.size();
73        let buffer = Buffer::filled(1.0, size, dtype, device);
74        let strides = shape.default_strides();
75        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
76    }
77
78    /// Create a new array filled with a specific value.
79    pub fn full(value: f32, shape: Shape, dtype: DType) -> Self {
80        let device = crate::default_device();
81        let size = shape.size();
82        let buffer = Buffer::filled(value, size, dtype, device);
83        let strides = shape.default_strides();
84        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
85    }
86
87    /// Create an array from a flat Vec<f32> and shape.
88    ///
89    /// # Panics
90    ///
91    /// Panics if the shape size doesn't match the data length.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// # use jax_rs::{Array, Shape};
97    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
98    /// let a = Array::from_vec(data, Shape::new(vec![2, 3]));
99    /// assert_eq!(a.shape().as_slice(), &[2, 3]);
100    /// ```
101    pub fn from_vec(data: Vec<f32>, shape: Shape) -> Self {
102        assert_eq!(
103            data.len(),
104            shape.size(),
105            "Data length must match shape size"
106        );
107        let device = crate::default_device();
108        let buffer = Buffer::from_f32(data, device);
109        let strides = shape.default_strides();
110        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
111    }
112
113    /// Create an array from a Vec<i32>.
114    pub fn from_vec_i32(data: Vec<i32>, shape: Shape) -> Self {
115        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
116        let device = crate::default_device();
117        let buffer = Buffer::from_i32(data, device);
118        let strides = shape.default_strides();
119        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
120    }
121
122    /// Create an array from a Vec<i8>.
123    pub fn from_vec_i8(data: Vec<i8>, shape: Shape) -> Self {
124        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
125        let device = crate::default_device();
126        let buffer = Buffer::from_i8(data, device);
127        let strides = shape.default_strides();
128        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
129    }
130
131    /// Create an array from a Vec<u8>.
132    pub fn from_vec_u8(data: Vec<u8>, shape: Shape) -> Self {
133        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
134        let device = crate::default_device();
135        let buffer = Buffer::from_u8(data, device);
136        let strides = shape.default_strides();
137        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
138    }
139
140    /// Create an array from a Vec<i16>.
141    pub fn from_vec_i16(data: Vec<i16>, shape: Shape) -> Self {
142        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
143        let device = crate::default_device();
144        let buffer = Buffer::from_i16(data, device);
145        let strides = shape.default_strides();
146        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
147    }
148
149    /// Create an array from a Vec<u16>.
150    pub fn from_vec_u16(data: Vec<u16>, shape: Shape) -> Self {
151        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
152        let device = crate::default_device();
153        let buffer = Buffer::from_u16(data, device);
154        let strides = shape.default_strides();
155        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
156    }
157
158    /// Create an array from a Vec<i64>.
159    pub fn from_vec_i64(data: Vec<i64>, shape: Shape) -> Self {
160        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
161        let device = crate::default_device();
162        let buffer = Buffer::from_i64(data, device);
163        let strides = shape.default_strides();
164        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
165    }
166
167    /// Create an array from a Vec<u32>.
168    pub fn from_vec_u32(data: Vec<u32>, shape: Shape) -> Self {
169        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
170        let device = crate::default_device();
171        let buffer = Buffer::from_u32(data, device);
172        let strides = shape.default_strides();
173        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
174    }
175
176    /// Create an array from a Vec<u64>.
177    pub fn from_vec_u64(data: Vec<u64>, shape: Shape) -> Self {
178        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
179        let device = crate::default_device();
180        let buffer = Buffer::from_u64(data, device);
181        let strides = shape.default_strides();
182        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
183    }
184
185    /// Create an array from a Vec<f64>.
186    pub fn from_vec_f64(data: Vec<f64>, shape: Shape) -> Self {
187        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
188        let device = crate::default_device();
189        let buffer = Buffer::from_f64(data, device);
190        let strides = shape.default_strides();
191        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
192    }
193
194    /// Create an array from a Vec<bool>.
195    pub fn from_vec_bool(data: Vec<bool>, shape: Shape) -> Self {
196        assert_eq!(data.len(), shape.size(), "Data length must match shape size");
197        let device = crate::default_device();
198        let buffer = Buffer::from_bool(data, device);
199        let strides = shape.default_strides();
200        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
201    }
202
203    /// Create an array from a buffer and shape (internal use).
204    pub(crate) fn from_buffer(buffer: Buffer, shape: Shape) -> Self {
205        let strides = shape.default_strides();
206        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
207    }
208
209    /// Get the shape of the array.
210    #[inline]
211    pub fn shape(&self) -> &Shape {
212        &self.shape
213    }
214
215    /// Get the data type of the array.
216    #[inline]
217    pub fn dtype(&self) -> DType {
218        self.buffer.dtype()
219    }
220
221    /// Get the device where this array lives.
222    #[inline]
223    pub fn device(&self) -> Device {
224        self.buffer.device()
225    }
226
227    /// Get reference to the underlying buffer (internal use).
228    #[inline]
229    pub(crate) fn buffer(&self) -> &Buffer {
230        &self.buffer
231    }
232
233    /// Get the number of dimensions.
234    #[inline]
235    pub fn ndim(&self) -> usize {
236        self.shape.ndim()
237    }
238
239    /// Get the total number of elements.
240    #[inline]
241    pub fn size(&self) -> usize {
242        self.shape.size()
243    }
244
245    /// Get the unique ID of this array (for tracing).
246    #[inline]
247    pub fn id(&self) -> usize {
248        self.id
249    }
250
251    /// Check if this is a scalar (0-dimensional array).
252    #[inline]
253    pub fn is_scalar(&self) -> bool {
254        self.shape.is_scalar()
255    }
256
257    /// Copy data to a Vec<f32> (synchronous).
258    ///
259    /// This materializes the array and copies all data to CPU memory.
260    /// Data is converted to f32 regardless of the array's dtype.
261    pub fn to_vec(&self) -> Vec<f32> {
262        // Fast path: contiguous array
263        if self.offset == 0 && self.strides == self.shape.default_strides() {
264            return self.buffer.to_f32_vec_converted();
265        }
266
267        // Slow path: strided/sliced array
268        // Get raw buffer data and iterate through logical indices
269        let raw_data = self.buffer.to_f32_vec_converted();
270        let size = self.size();
271        let ndim = self.ndim();
272
273        if ndim == 0 {
274            // Scalar
275            return vec![raw_data[self.offset]];
276        }
277
278        let shape = self.shape.as_slice();
279        let mut result = Vec::with_capacity(size);
280
281        // Iterate through all logical indices in row-major order
282        let mut indices = vec![0usize; ndim];
283        for _ in 0..size {
284            // Compute physical offset for current logical index
285            let physical_idx: usize = self.offset
286                + indices
287                    .iter()
288                    .zip(self.strides.iter())
289                    .map(|(&i, &s)| i * s)
290                    .sum::<usize>();
291
292            result.push(raw_data[physical_idx]);
293
294            // Increment indices (row-major order, last dimension increments first)
295            for d in (0..ndim).rev() {
296                indices[d] += 1;
297                if indices[d] < shape[d] {
298                    break;
299                }
300                indices[d] = 0;
301            }
302        }
303
304        result
305    }
306
307    /// Copy data to a Vec<bool> (for Bool dtype arrays).
308    pub fn to_bool_vec(&self) -> Vec<bool> {
309        assert_eq!(self.dtype(), DType::Bool, "to_bool_vec requires Bool dtype");
310
311        // Fast path: contiguous array
312        if self.offset == 0 && self.strides == self.shape.default_strides() {
313            return self.buffer.to_bool_vec();
314        }
315
316        // Slow path: strided/sliced array
317        let raw_data = self.buffer.to_bool_vec();
318        let size = self.size();
319        let ndim = self.ndim();
320
321        if ndim == 0 {
322            return vec![raw_data[self.offset]];
323        }
324
325        let shape = self.shape.as_slice();
326        let mut result = Vec::with_capacity(size);
327        let mut indices = vec![0usize; ndim];
328
329        for _ in 0..size {
330            let physical_idx: usize = self.offset
331                + indices
332                    .iter()
333                    .zip(self.strides.iter())
334                    .map(|(&i, &s)| i * s)
335                    .sum::<usize>();
336
337            result.push(raw_data[physical_idx]);
338
339            for d in (0..ndim).rev() {
340                indices[d] += 1;
341                if indices[d] < shape[d] {
342                    break;
343                }
344                indices[d] = 0;
345            }
346        }
347
348        result
349    }
350
351    /// Cast array to a different dtype.
352    ///
353    /// # Examples
354    ///
355    /// ```
356    /// # use jax_rs::{Array, DType, Shape};
357    /// let a = Array::from_vec(vec![1.0, 2.5, 3.9], Shape::new(vec![3]));
358    /// let b = a.astype(DType::Int32);
359    /// assert_eq!(b.dtype(), DType::Int32);
360    /// let data = b.to_vec();
361    /// assert_eq!(data, vec![1.0, 2.0, 3.0]); // truncated to integers
362    /// ```
363    pub fn astype(&self, dtype: DType) -> Self {
364        if self.dtype() == dtype {
365            return self.clone();
366        }
367
368        // Read current data as f32
369        let data = self.to_vec();
370
371        // Create new array with target dtype, casting values
372        let device = self.device();
373        let shape = self.shape.clone();
374
375        let buffer = match dtype {
376            DType::Float32 => Buffer::from_f32(data, device),
377            DType::Float64 => {
378                let casted: Vec<f64> = data.iter().map(|&x| x as f64).collect();
379                Buffer::from_f64(casted, device)
380            }
381            DType::Float16 => {
382                // Float16 is stored as f32 internally for now
383                Buffer::from_f32_as_dtype(data, DType::Float16, device)
384            }
385            DType::Int8 => {
386                let casted: Vec<i8> = data.iter().map(|&x| x as i8).collect();
387                Buffer::from_i8(casted, device)
388            }
389            DType::Int16 => {
390                let casted: Vec<i16> = data.iter().map(|&x| x as i16).collect();
391                Buffer::from_i16(casted, device)
392            }
393            DType::Int32 => {
394                let casted: Vec<i32> = data.iter().map(|&x| x as i32).collect();
395                Buffer::from_i32(casted, device)
396            }
397            DType::Int64 => {
398                let casted: Vec<i64> = data.iter().map(|&x| x as i64).collect();
399                Buffer::from_i64(casted, device)
400            }
401            DType::Uint8 => {
402                let casted: Vec<u8> = data.iter().map(|&x| x as u8).collect();
403                Buffer::from_u8(casted, device)
404            }
405            DType::Uint16 => {
406                let casted: Vec<u16> = data.iter().map(|&x| x as u16).collect();
407                Buffer::from_u16(casted, device)
408            }
409            DType::Uint32 => {
410                let casted: Vec<u32> = data.iter().map(|&x| x as u32).collect();
411                Buffer::from_u32(casted, device)
412            }
413            DType::Uint64 => {
414                let casted: Vec<u64> = data.iter().map(|&x| x as u64).collect();
415                Buffer::from_u64(casted, device)
416            }
417            DType::Bool => {
418                let casted: Vec<bool> = data.iter().map(|&x| x != 0.0).collect();
419                Buffer::from_bool(casted, device)
420            }
421        };
422
423        let strides = shape.default_strides();
424        Self { buffer, shape, strides, offset: 0, id: next_array_id() }
425    }
426
427    /// Transfer array to a different device.
428    ///
429    /// If the array is already on the target device, returns a clone.
430    /// Otherwise, transfers the data to the new device.
431    ///
432    /// # Examples
433    ///
434    /// ```rust,no_run
435    /// # use jax_rs::{Array, Device, Shape, DType};
436    /// let cpu_arr = Array::zeros(Shape::new(vec![10]), DType::Float32);
437    /// let gpu_arr = cpu_arr.to_device(Device::WebGpu);
438    /// assert_eq!(gpu_arr.device(), Device::WebGpu);
439    /// ```
440    pub fn to_device(&self, device: Device) -> Array {
441        if self.device() == device {
442            return self.clone();
443        }
444
445        match (self.device(), device) {
446            (Device::Cpu, Device::WebGpu) => {
447                // Upload to GPU
448                let data = self.to_vec();
449                let buffer = Buffer::from_f32(data, Device::WebGpu);
450                Array::from_buffer(buffer, self.shape().clone())
451            }
452            (Device::WebGpu, Device::Cpu) => {
453                // Download from GPU
454                let data = self.buffer().to_f32_vec();
455                let buffer = Buffer::from_f32(data, Device::Cpu);
456                Array::from_buffer(buffer, self.shape().clone())
457            }
458            (Device::Cpu, Device::Wasm) | (Device::Wasm, Device::Cpu) => {
459                // CPU <-> WASM transfer
460                let data = self.to_vec();
461                let buffer = Buffer::from_f32(data, device);
462                Array::from_buffer(buffer, self.shape().clone())
463            }
464            (Device::WebGpu, Device::Wasm) | (Device::Wasm, Device::WebGpu) => {
465                // Go through CPU as intermediate
466                let cpu = self.to_device(Device::Cpu);
467                cpu.to_device(device)
468            }
469            // Same device (already handled by early return, but need exhaustive match)
470            _ => self.clone()
471        }
472    }
473
474    /// Reshape the array to a new shape.
475    ///
476    /// # Panics
477    ///
478    /// Panics if the total size doesn't match.
479    pub fn reshape(&self, new_shape: Shape) -> Self {
480        assert_eq!(
481            self.shape.size(),
482            new_shape.size(),
483            "Cannot reshape array of size {} into shape of size {}",
484            self.shape.size(),
485            new_shape.size()
486        );
487        // For now, require contiguous data for reshape
488        assert_eq!(self.offset, 0);
489        assert_eq!(self.strides, self.shape.default_strides());
490
491        Self {
492            buffer: self.buffer.clone(),
493            shape: new_shape.clone(),
494            strides: new_shape.default_strides(),
495            offset: 0,
496            id: next_array_id(),
497        }
498    }
499
500    /// Remove axes of length one from the array.
501    ///
502    /// # Examples
503    ///
504    /// ```
505    /// # use jax_rs::{Array, Shape};
506    /// let a = Array::zeros(Shape::new(vec![1, 3, 1, 4]), jax_rs::DType::Float32);
507    /// let b = a.squeeze();
508    /// assert_eq!(b.shape().as_slice(), &[3, 4]);
509    /// ```
510    pub fn squeeze(&self) -> Self {
511        let new_dims: Vec<usize> = self
512            .shape
513            .as_slice()
514            .iter()
515            .filter(|&&dim| dim != 1)
516            .copied()
517            .collect();
518
519        let new_shape = if new_dims.is_empty() {
520            Shape::scalar()
521        } else {
522            Shape::new(new_dims)
523        };
524
525        self.reshape(new_shape)
526    }
527
528    /// Remove a single dimension at the specified axis.
529    ///
530    /// The dimension at the given axis must be 1.
531    pub fn squeeze_axis(&self, axis: usize) -> Self {
532        let dims = self.shape.as_slice();
533        assert!(axis < dims.len(), "Axis {} out of bounds", axis);
534        assert_eq!(dims[axis], 1, "Can only squeeze axis with size 1");
535
536        let mut new_dims = dims.to_vec();
537        new_dims.remove(axis);
538
539        let new_shape = if new_dims.is_empty() {
540            Shape::scalar()
541        } else {
542            Shape::new(new_dims)
543        };
544
545        self.reshape(new_shape)
546    }
547
548    /// Expand the shape of an array by inserting a new axis.
549    ///
550    /// # Arguments
551    ///
552    /// * `axis` - Position where new axis is placed
553    ///
554    /// # Examples
555    ///
556    /// ```
557    /// # use jax_rs::{Array, Shape};
558    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
559    /// let b = a.expand_dims(0);
560    /// assert_eq!(b.shape().as_slice(), &[1, 3]);
561    /// let c = a.expand_dims(1);
562    /// assert_eq!(c.shape().as_slice(), &[3, 1]);
563    /// ```
564    pub fn expand_dims(&self, axis: usize) -> Self {
565        let mut new_dims = self.shape.as_slice().to_vec();
566        assert!(
567            axis <= new_dims.len(),
568            "Axis {} out of bounds for array with {} dimensions",
569            axis,
570            new_dims.len()
571        );
572        new_dims.insert(axis, 1);
573        self.reshape(Shape::new(new_dims))
574    }
575}
576
577impl fmt::Display for Array {
578    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
579        write!(f, "Array:{}{}", self.dtype(), self.shape())
580    }
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_zeros() {
        let a = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
        assert_eq!(a.shape().as_slice(), &[2, 3]);
        assert_eq!(a.dtype(), DType::Float32);
        assert_eq!(a.size(), 6);
        assert_eq!(a.ndim(), 2);
        let data = a.to_vec();
        assert_eq!(data.len(), 6);
        assert!(data.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_array_ones() {
        let a = Array::ones(Shape::new(vec![3, 2]), DType::Float32);
        assert_eq!(a.shape().as_slice(), &[3, 2]);
        let data = a.to_vec();
        assert!(data.iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_array_full() {
        let a = Array::full(5.0, Shape::new(vec![2, 2]), DType::Float32);
        let data = a.to_vec();
        assert_eq!(data, vec![5.0, 5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_array_from_vec() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = Array::from_vec(data.clone(), Shape::new(vec![2, 3]));
        assert_eq!(a.shape().as_slice(), &[2, 3]);
        assert_eq!(a.to_vec(), data);
    }

    #[test]
    fn test_array_reshape() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = Array::from_vec(data.clone(), Shape::new(vec![2, 3]));
        let b = a.reshape(Shape::new(vec![3, 2]));
        assert_eq!(b.shape().as_slice(), &[3, 2]);
        assert_eq!(b.to_vec(), data);

        let c = a.reshape(Shape::new(vec![6]));
        assert_eq!(c.shape().as_slice(), &[6]);
    }

    #[test]
    fn test_array_display() {
        let a = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
        let s = a.to_string();
        assert!(s.contains("float32"));
        assert!(s.contains("2"));
        assert!(s.contains("3"));
    }

    #[test]
    fn test_array_clone() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = a.clone();
        assert_eq!(a.to_vec(), b.to_vec());
        assert_eq!(a.shape(), b.shape());
    }

    #[test]
    #[should_panic(expected = "Data length must match shape size")]
    fn test_array_from_vec_size_mismatch() {
        let _a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![3]));
    }

    #[test]
    #[should_panic(expected = "Cannot reshape")]
    fn test_array_reshape_size_mismatch() {
        let a = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
        let _b = a.reshape(Shape::new(vec![2, 2]));
    }

    #[test]
    fn test_array_zeros_all_dtypes() {
        // Test zeros with all dtypes
        let dtypes = [
            DType::Float32, DType::Float64, DType::Float16,
            DType::Int8, DType::Int16, DType::Int32, DType::Int64,
            DType::Uint8, DType::Uint16, DType::Uint32, DType::Uint64,
            DType::Bool,
        ];
        for dtype in dtypes {
            let a = Array::zeros(Shape::new(vec![2, 3]), dtype);
            assert_eq!(a.dtype(), dtype);
            assert_eq!(a.shape().as_slice(), &[2, 3]);
            let data = a.to_vec();
            assert!(data.iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_array_ones_all_dtypes() {
        let dtypes = [
            DType::Float32, DType::Float64, DType::Float16,
            DType::Int8, DType::Int16, DType::Int32, DType::Int64,
            DType::Uint8, DType::Uint16, DType::Uint32, DType::Uint64,
        ];
        for dtype in dtypes {
            let a = Array::ones(Shape::new(vec![3]), dtype);
            assert_eq!(a.dtype(), dtype);
            let data = a.to_vec();
            assert!(data.iter().all(|&x| x == 1.0));
        }
    }

    #[test]
    fn test_array_from_vec_typed() {
        // Test i32
        let a = Array::from_vec_i32(vec![1, 2, 3], Shape::new(vec![3]));
        assert_eq!(a.dtype(), DType::Int32);
        assert_eq!(a.to_vec(), vec![1.0, 2.0, 3.0]);

        // Test i8
        let b = Array::from_vec_i8(vec![-1, 0, 127], Shape::new(vec![3]));
        assert_eq!(b.dtype(), DType::Int8);
        assert_eq!(b.to_vec(), vec![-1.0, 0.0, 127.0]);

        // Test u8
        let c = Array::from_vec_u8(vec![0, 128, 255], Shape::new(vec![3]));
        assert_eq!(c.dtype(), DType::Uint8);
        assert_eq!(c.to_vec(), vec![0.0, 128.0, 255.0]);

        // Test bool
        let d = Array::from_vec_bool(vec![true, false, true], Shape::new(vec![3]));
        assert_eq!(d.dtype(), DType::Bool);
        assert_eq!(d.to_vec(), vec![1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_array_astype() {
        let a = Array::from_vec(vec![1.0, 2.5, 3.9], Shape::new(vec![3]));

        // Cast to Int32 (truncates)
        let b = a.astype(DType::Int32);
        assert_eq!(b.dtype(), DType::Int32);
        assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0]);

        // Cast to Bool
        let c = Array::from_vec(vec![0.0, 1.0, 5.0], Shape::new(vec![3]));
        let d = c.astype(DType::Bool);
        assert_eq!(d.dtype(), DType::Bool);
        assert_eq!(d.to_vec(), vec![0.0, 1.0, 1.0]);

        // Cast same dtype returns clone
        let e = a.astype(DType::Float32);
        assert_eq!(e.dtype(), DType::Float32);
        assert_eq!(e.to_vec(), a.to_vec());
    }

    #[test]
    fn test_array_to_bool_vec() {
        let a = Array::from_vec_bool(vec![true, false, true, false], Shape::new(vec![4]));
        let data = a.to_bool_vec();
        assert_eq!(data, vec![true, false, true, false]);
    }

    #[test]
    fn test_strided_to_vec_transposed() {
        // Create a 2x3 array and simulate a transpose by using reversed strides
        // Original: [[1, 2, 3], [4, 5, 6]] stored as [1, 2, 3, 4, 5, 6]
        // Transposed view: [[1, 4], [2, 5], [3, 6]] with shape [3, 2] and strides [1, 3]
        let buffer = Buffer::from_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Device::Cpu);
        let shape = Shape::new(vec![3, 2]);
        let strides = vec![1, 3]; // Transposed strides
        let arr = Array {
            buffer,
            shape,
            strides,
            offset: 0,
            id: next_array_id(),
        };

        // to_vec should return elements in row-major order of the transposed view
        // Row 0: [1, 4], Row 1: [2, 5], Row 2: [3, 6]
        let result = arr.to_vec();
        assert_eq!(result, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_strided_to_vec_with_offset() {
        // Create a buffer and access a slice with offset
        // Buffer: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        // View: 2x3 starting at offset 2, so [[2, 3, 4], [5, 6, 7]]
        let buffer = Buffer::from_f32(
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            Device::Cpu,
        );
        let shape = Shape::new(vec![2, 3]);
        let strides = vec![3, 1]; // Default strides for 2x3
        let arr = Array {
            buffer,
            shape,
            strides,
            offset: 2,
            id: next_array_id(),
        };

        let result = arr.to_vec();
        assert_eq!(result, vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    }

    #[test]
    fn test_strided_to_vec_every_other() {
        // Create a view that takes every other element
        // Buffer: [0, 1, 2, 3, 4, 5, 6, 7]
        // View: shape [4], stride [2] -> [0, 2, 4, 6]
        let buffer = Buffer::from_f32(
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            Device::Cpu,
        );
        let shape = Shape::new(vec![4]);
        let strides = vec![2]; // Every other element
        let arr = Array {
            buffer,
            shape,
            strides,
            offset: 0,
            id: next_array_id(),
        };

        let result = arr.to_vec();
        assert_eq!(result, vec![0.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_strided_to_vec_3d() {
        // Create a 3D strided view
        // Buffer: 0..24
        // Original shape: [2, 3, 4]
        // View as transposed [4, 3, 2] with strides [1, 4, 12]
        let buffer = Buffer::from_f32((0..24).map(|x| x as f32).collect(), Device::Cpu);
        let shape = Shape::new(vec![4, 3, 2]);
        let strides = vec![1, 4, 12]; // Transposed strides
        let arr = Array {
            buffer,
            shape,
            strides,
            offset: 0,
            id: next_array_id(),
        };

        // First few elements should be:
        // [0][0][0] -> 0*1 + 0*4 + 0*12 = 0
        // [0][0][1] -> 0*1 + 0*4 + 1*12 = 12
        // [0][1][0] -> 0*1 + 1*4 + 0*12 = 4
        // [0][1][1] -> 0*1 + 1*4 + 1*12 = 16
        // [0][2][0] -> 0*1 + 2*4 + 0*12 = 8
        // [0][2][1] -> 0*1 + 2*4 + 1*12 = 20
        // [1][0][0] -> 1*1 + 0*4 + 0*12 = 1
        // etc.
        let result = arr.to_vec();
        assert_eq!(
            result,
            vec![
                0.0, 12.0, 4.0, 16.0, 8.0, 20.0, // [0][*][*]
                1.0, 13.0, 5.0, 17.0, 9.0, 21.0, // [1][*][*]
                2.0, 14.0, 6.0, 18.0, 10.0, 22.0, // [2][*][*]
                3.0, 15.0, 7.0, 19.0, 11.0, 23.0  // [3][*][*]
            ]
        );
    }

    #[test]
    fn test_strided_to_bool_vec_transposed() {
        // Buffer holds [[T, F, T], [F, F, T]] row-major; view it transposed as 3x2.
        let buffer =
            Buffer::from_bool(vec![true, false, true, false, false, true], Device::Cpu);
        let arr = Array {
            buffer,
            shape: Shape::new(vec![3, 2]),
            strides: vec![1, 3],
            offset: 0,
            id: next_array_id(),
        };

        // Row-major over the view reads buffer indices 0, 3, 1, 4, 2, 5.
        assert_eq!(
            arr.to_bool_vec(),
            vec![true, false, false, false, true, true]
        );
    }

    #[test]
    fn test_strided_to_vec_scalar_with_offset() {
        // A rank-0 view reads the single element at `offset`.
        let buffer = Buffer::from_f32(vec![0.0, 1.0, 2.0, 3.0], Device::Cpu);
        let arr = Array {
            buffer,
            shape: Shape::scalar(),
            strides: vec![],
            offset: 3,
            id: next_array_id(),
        };

        assert_eq!(arr.to_vec(), vec![3.0]);
    }

    #[test]
    #[should_panic]
    fn test_reshape_non_contiguous_panics() {
        // A transposed view cannot be reinterpreted in place, so reshape must refuse.
        let buffer = Buffer::from_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Device::Cpu);
        let arr = Array {
            buffer,
            shape: Shape::new(vec![3, 2]),
            strides: vec![1, 3],
            offset: 0,
            id: next_array_id(),
        };
        let _ = arr.reshape(Shape::new(vec![6]));
    }

    #[test]
    fn test_expand_dims_squeeze_axis_roundtrip() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = a.expand_dims(0);
        assert_eq!(b.shape().as_slice(), &[1, 3]);

        let c = b.squeeze_axis(0);
        assert_eq!(c.shape().as_slice(), &[3]);
        assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0]);
    }
}