1use crate::{buffer::Buffer, DType, Device, Shape};
4use std::fmt;
5use std::sync::atomic::{AtomicUsize, Ordering};
6
/// Process-wide counter from which array identifiers are drawn.
static ARRAY_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Returns the current counter value and advances the counter by one.
///
/// `Relaxed` ordering is used: the atomic read-modify-write alone keeps the
/// returned values distinct, and no other memory is synchronized through it.
fn next_array_id() -> usize {
    AtomicUsize::fetch_add(&ARRAY_ID_COUNTER, 1, Ordering::Relaxed)
}
14
/// An n-dimensional array described by a backing [`Buffer`] plus the layout
/// (shape, strides, offset) used to address elements inside it.
///
/// The derived `Clone` copies every field, including `id`, so a clone reports
/// the same identifier as the array it was cloned from.
#[derive(Debug, Clone)]
pub struct Array {
    /// Backing storage; also the source of the array's dtype and device.
    buffer: Buffer,
    /// Logical extent of each axis.
    shape: Shape,
    /// Per-axis multipliers, in elements: the physical index of a logical
    /// index is `offset + sum(index[d] * strides[d])`.
    strides: Vec<usize>,
    /// Element index in `buffer` at which the logical array starts.
    offset: usize,
    /// Identifier taken from `next_array_id()` when the array is constructed.
    id: usize,
}
49
50impl Array {
51 pub fn zeros(shape: Shape, dtype: DType) -> Self {
62 let device = crate::default_device();
63 let size = shape.size();
64 let buffer = Buffer::zeros(size, dtype, device);
65 let strides = shape.default_strides();
66 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
67 }
68
69 pub fn ones(shape: Shape, dtype: DType) -> Self {
71 let device = crate::default_device();
72 let size = shape.size();
73 let buffer = Buffer::filled(1.0, size, dtype, device);
74 let strides = shape.default_strides();
75 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
76 }
77
78 pub fn full(value: f32, shape: Shape, dtype: DType) -> Self {
80 let device = crate::default_device();
81 let size = shape.size();
82 let buffer = Buffer::filled(value, size, dtype, device);
83 let strides = shape.default_strides();
84 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
85 }
86
87 pub fn from_vec(data: Vec<f32>, shape: Shape) -> Self {
102 assert_eq!(
103 data.len(),
104 shape.size(),
105 "Data length must match shape size"
106 );
107 let device = crate::default_device();
108 let buffer = Buffer::from_f32(data, device);
109 let strides = shape.default_strides();
110 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
111 }
112
113 pub fn from_vec_i32(data: Vec<i32>, shape: Shape) -> Self {
115 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
116 let device = crate::default_device();
117 let buffer = Buffer::from_i32(data, device);
118 let strides = shape.default_strides();
119 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
120 }
121
122 pub fn from_vec_i8(data: Vec<i8>, shape: Shape) -> Self {
124 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
125 let device = crate::default_device();
126 let buffer = Buffer::from_i8(data, device);
127 let strides = shape.default_strides();
128 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
129 }
130
131 pub fn from_vec_u8(data: Vec<u8>, shape: Shape) -> Self {
133 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
134 let device = crate::default_device();
135 let buffer = Buffer::from_u8(data, device);
136 let strides = shape.default_strides();
137 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
138 }
139
140 pub fn from_vec_i16(data: Vec<i16>, shape: Shape) -> Self {
142 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
143 let device = crate::default_device();
144 let buffer = Buffer::from_i16(data, device);
145 let strides = shape.default_strides();
146 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
147 }
148
149 pub fn from_vec_u16(data: Vec<u16>, shape: Shape) -> Self {
151 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
152 let device = crate::default_device();
153 let buffer = Buffer::from_u16(data, device);
154 let strides = shape.default_strides();
155 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
156 }
157
158 pub fn from_vec_i64(data: Vec<i64>, shape: Shape) -> Self {
160 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
161 let device = crate::default_device();
162 let buffer = Buffer::from_i64(data, device);
163 let strides = shape.default_strides();
164 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
165 }
166
167 pub fn from_vec_u32(data: Vec<u32>, shape: Shape) -> Self {
169 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
170 let device = crate::default_device();
171 let buffer = Buffer::from_u32(data, device);
172 let strides = shape.default_strides();
173 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
174 }
175
176 pub fn from_vec_u64(data: Vec<u64>, shape: Shape) -> Self {
178 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
179 let device = crate::default_device();
180 let buffer = Buffer::from_u64(data, device);
181 let strides = shape.default_strides();
182 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
183 }
184
185 pub fn from_vec_f64(data: Vec<f64>, shape: Shape) -> Self {
187 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
188 let device = crate::default_device();
189 let buffer = Buffer::from_f64(data, device);
190 let strides = shape.default_strides();
191 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
192 }
193
194 pub fn from_vec_bool(data: Vec<bool>, shape: Shape) -> Self {
196 assert_eq!(data.len(), shape.size(), "Data length must match shape size");
197 let device = crate::default_device();
198 let buffer = Buffer::from_bool(data, device);
199 let strides = shape.default_strides();
200 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
201 }
202
203 pub(crate) fn from_buffer(buffer: Buffer, shape: Shape) -> Self {
205 let strides = shape.default_strides();
206 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
207 }
208
209 #[inline]
211 pub fn shape(&self) -> &Shape {
212 &self.shape
213 }
214
215 #[inline]
217 pub fn dtype(&self) -> DType {
218 self.buffer.dtype()
219 }
220
221 #[inline]
223 pub fn device(&self) -> Device {
224 self.buffer.device()
225 }
226
227 #[inline]
229 pub(crate) fn buffer(&self) -> &Buffer {
230 &self.buffer
231 }
232
233 #[inline]
235 pub fn ndim(&self) -> usize {
236 self.shape.ndim()
237 }
238
239 #[inline]
241 pub fn size(&self) -> usize {
242 self.shape.size()
243 }
244
245 #[inline]
247 pub fn id(&self) -> usize {
248 self.id
249 }
250
251 #[inline]
253 pub fn is_scalar(&self) -> bool {
254 self.shape.is_scalar()
255 }
256
257 pub fn to_vec(&self) -> Vec<f32> {
262 if self.offset == 0 && self.strides == self.shape.default_strides() {
264 return self.buffer.to_f32_vec_converted();
265 }
266
267 let raw_data = self.buffer.to_f32_vec_converted();
270 let size = self.size();
271 let ndim = self.ndim();
272
273 if ndim == 0 {
274 return vec![raw_data[self.offset]];
276 }
277
278 let shape = self.shape.as_slice();
279 let mut result = Vec::with_capacity(size);
280
281 let mut indices = vec![0usize; ndim];
283 for _ in 0..size {
284 let physical_idx: usize = self.offset
286 + indices
287 .iter()
288 .zip(self.strides.iter())
289 .map(|(&i, &s)| i * s)
290 .sum::<usize>();
291
292 result.push(raw_data[physical_idx]);
293
294 for d in (0..ndim).rev() {
296 indices[d] += 1;
297 if indices[d] < shape[d] {
298 break;
299 }
300 indices[d] = 0;
301 }
302 }
303
304 result
305 }
306
307 pub fn to_bool_vec(&self) -> Vec<bool> {
309 assert_eq!(self.dtype(), DType::Bool, "to_bool_vec requires Bool dtype");
310
311 if self.offset == 0 && self.strides == self.shape.default_strides() {
313 return self.buffer.to_bool_vec();
314 }
315
316 let raw_data = self.buffer.to_bool_vec();
318 let size = self.size();
319 let ndim = self.ndim();
320
321 if ndim == 0 {
322 return vec![raw_data[self.offset]];
323 }
324
325 let shape = self.shape.as_slice();
326 let mut result = Vec::with_capacity(size);
327 let mut indices = vec![0usize; ndim];
328
329 for _ in 0..size {
330 let physical_idx: usize = self.offset
331 + indices
332 .iter()
333 .zip(self.strides.iter())
334 .map(|(&i, &s)| i * s)
335 .sum::<usize>();
336
337 result.push(raw_data[physical_idx]);
338
339 for d in (0..ndim).rev() {
340 indices[d] += 1;
341 if indices[d] < shape[d] {
342 break;
343 }
344 indices[d] = 0;
345 }
346 }
347
348 result
349 }
350
351 pub fn astype(&self, dtype: DType) -> Self {
364 if self.dtype() == dtype {
365 return self.clone();
366 }
367
368 let data = self.to_vec();
370
371 let device = self.device();
373 let shape = self.shape.clone();
374
375 let buffer = match dtype {
376 DType::Float32 => Buffer::from_f32(data, device),
377 DType::Float64 => {
378 let casted: Vec<f64> = data.iter().map(|&x| x as f64).collect();
379 Buffer::from_f64(casted, device)
380 }
381 DType::Float16 => {
382 Buffer::from_f32_as_dtype(data, DType::Float16, device)
384 }
385 DType::Int8 => {
386 let casted: Vec<i8> = data.iter().map(|&x| x as i8).collect();
387 Buffer::from_i8(casted, device)
388 }
389 DType::Int16 => {
390 let casted: Vec<i16> = data.iter().map(|&x| x as i16).collect();
391 Buffer::from_i16(casted, device)
392 }
393 DType::Int32 => {
394 let casted: Vec<i32> = data.iter().map(|&x| x as i32).collect();
395 Buffer::from_i32(casted, device)
396 }
397 DType::Int64 => {
398 let casted: Vec<i64> = data.iter().map(|&x| x as i64).collect();
399 Buffer::from_i64(casted, device)
400 }
401 DType::Uint8 => {
402 let casted: Vec<u8> = data.iter().map(|&x| x as u8).collect();
403 Buffer::from_u8(casted, device)
404 }
405 DType::Uint16 => {
406 let casted: Vec<u16> = data.iter().map(|&x| x as u16).collect();
407 Buffer::from_u16(casted, device)
408 }
409 DType::Uint32 => {
410 let casted: Vec<u32> = data.iter().map(|&x| x as u32).collect();
411 Buffer::from_u32(casted, device)
412 }
413 DType::Uint64 => {
414 let casted: Vec<u64> = data.iter().map(|&x| x as u64).collect();
415 Buffer::from_u64(casted, device)
416 }
417 DType::Bool => {
418 let casted: Vec<bool> = data.iter().map(|&x| x != 0.0).collect();
419 Buffer::from_bool(casted, device)
420 }
421 };
422
423 let strides = shape.default_strides();
424 Self { buffer, shape, strides, offset: 0, id: next_array_id() }
425 }
426
427 pub fn to_device(&self, device: Device) -> Array {
441 if self.device() == device {
442 return self.clone();
443 }
444
445 match (self.device(), device) {
446 (Device::Cpu, Device::WebGpu) => {
447 let data = self.to_vec();
449 let buffer = Buffer::from_f32(data, Device::WebGpu);
450 Array::from_buffer(buffer, self.shape().clone())
451 }
452 (Device::WebGpu, Device::Cpu) => {
453 let data = self.buffer().to_f32_vec();
455 let buffer = Buffer::from_f32(data, Device::Cpu);
456 Array::from_buffer(buffer, self.shape().clone())
457 }
458 (Device::Cpu, Device::Wasm) | (Device::Wasm, Device::Cpu) => {
459 let data = self.to_vec();
461 let buffer = Buffer::from_f32(data, device);
462 Array::from_buffer(buffer, self.shape().clone())
463 }
464 (Device::WebGpu, Device::Wasm) | (Device::Wasm, Device::WebGpu) => {
465 let cpu = self.to_device(Device::Cpu);
467 cpu.to_device(device)
468 }
469 _ => self.clone()
471 }
472 }
473
474 pub fn reshape(&self, new_shape: Shape) -> Self {
480 assert_eq!(
481 self.shape.size(),
482 new_shape.size(),
483 "Cannot reshape array of size {} into shape of size {}",
484 self.shape.size(),
485 new_shape.size()
486 );
487 assert_eq!(self.offset, 0);
489 assert_eq!(self.strides, self.shape.default_strides());
490
491 Self {
492 buffer: self.buffer.clone(),
493 shape: new_shape.clone(),
494 strides: new_shape.default_strides(),
495 offset: 0,
496 id: next_array_id(),
497 }
498 }
499
500 pub fn squeeze(&self) -> Self {
511 let new_dims: Vec<usize> = self
512 .shape
513 .as_slice()
514 .iter()
515 .filter(|&&dim| dim != 1)
516 .copied()
517 .collect();
518
519 let new_shape = if new_dims.is_empty() {
520 Shape::scalar()
521 } else {
522 Shape::new(new_dims)
523 };
524
525 self.reshape(new_shape)
526 }
527
528 pub fn squeeze_axis(&self, axis: usize) -> Self {
532 let dims = self.shape.as_slice();
533 assert!(axis < dims.len(), "Axis {} out of bounds", axis);
534 assert_eq!(dims[axis], 1, "Can only squeeze axis with size 1");
535
536 let mut new_dims = dims.to_vec();
537 new_dims.remove(axis);
538
539 let new_shape = if new_dims.is_empty() {
540 Shape::scalar()
541 } else {
542 Shape::new(new_dims)
543 };
544
545 self.reshape(new_shape)
546 }
547
548 pub fn expand_dims(&self, axis: usize) -> Self {
565 let mut new_dims = self.shape.as_slice().to_vec();
566 assert!(
567 axis <= new_dims.len(),
568 "Axis {} out of bounds for array with {} dimensions",
569 axis,
570 new_dims.len()
571 );
572 new_dims.insert(axis, 1);
573 self.reshape(Shape::new(new_dims))
574 }
575}
576
577impl fmt::Display for Array {
578 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
579 write!(f, "Array:{}{}", self.dtype(), self.shape())
580 }
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles an `Array` field by field over an f32 CPU buffer so a test
    /// can pick arbitrary strides and offset.
    fn strided_view(
        data: Vec<f32>,
        dims: Vec<usize>,
        strides: Vec<usize>,
        offset: usize,
    ) -> Array {
        Array {
            buffer: Buffer::from_f32(data, Device::Cpu),
            shape: Shape::new(dims),
            strides,
            offset,
            id: next_array_id(),
        }
    }

    #[test]
    fn test_array_zeros() {
        let zeros = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);

        assert_eq!(zeros.shape().as_slice(), &[2, 3]);
        assert_eq!(zeros.dtype(), DType::Float32);
        assert_eq!(zeros.size(), 6);
        assert_eq!(zeros.ndim(), 2);

        let values = zeros.to_vec();
        assert_eq!(values.len(), 6);
        assert!(values.iter().all(|v| *v == 0.0));
    }

    #[test]
    fn test_array_ones() {
        let ones = Array::ones(Shape::new(vec![3, 2]), DType::Float32);
        assert_eq!(ones.shape().as_slice(), &[3, 2]);

        let values = ones.to_vec();
        assert!(values.iter().all(|v| *v == 1.0));
    }

    #[test]
    fn test_array_full() {
        let fives = Array::full(5.0, Shape::new(vec![2, 2]), DType::Float32);
        let values = fives.to_vec();
        assert_eq!(values, vec![5.0, 5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_array_from_vec() {
        let expected = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let array = Array::from_vec(expected.clone(), Shape::new(vec![2, 3]));

        assert_eq!(array.shape().as_slice(), &[2, 3]);
        assert_eq!(array.to_vec(), expected);
    }

    #[test]
    fn test_array_reshape() {
        let expected = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let original = Array::from_vec(expected.clone(), Shape::new(vec![2, 3]));

        // Reshaping changes the shape but not the element order.
        let tall = original.reshape(Shape::new(vec![3, 2]));
        assert_eq!(tall.shape().as_slice(), &[3, 2]);
        assert_eq!(tall.to_vec(), expected);

        let flat = original.reshape(Shape::new(vec![6]));
        assert_eq!(flat.shape().as_slice(), &[6]);
    }

    #[test]
    fn test_array_display() {
        let array = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
        let rendered = array.to_string();

        for needle in ["float32", "2", "3"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_array_clone() {
        let original = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let copy = original.clone();

        assert_eq!(original.to_vec(), copy.to_vec());
        assert_eq!(original.shape(), copy.shape());
    }

    #[test]
    #[should_panic(expected = "Data length must match shape size")]
    fn test_array_from_vec_size_mismatch() {
        // Two values cannot fill a shape of three elements.
        let _array = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![3]));
    }

    #[test]
    #[should_panic(expected = "Cannot reshape")]
    fn test_array_reshape_size_mismatch() {
        let six = Array::zeros(Shape::new(vec![2, 3]), DType::Float32);
        let _four = six.reshape(Shape::new(vec![2, 2]));
    }

    #[test]
    fn test_array_zeros_all_dtypes() {
        let every_dtype = [
            DType::Float32,
            DType::Float64,
            DType::Float16,
            DType::Int8,
            DType::Int16,
            DType::Int32,
            DType::Int64,
            DType::Uint8,
            DType::Uint16,
            DType::Uint32,
            DType::Uint64,
            DType::Bool,
        ];

        for dtype in every_dtype {
            let zeros = Array::zeros(Shape::new(vec![2, 3]), dtype);
            assert_eq!(zeros.dtype(), dtype);
            assert_eq!(zeros.shape().as_slice(), &[2, 3]);

            let values = zeros.to_vec();
            assert!(values.iter().all(|v| *v == 0.0));
        }
    }

    #[test]
    fn test_array_ones_all_dtypes() {
        // Bool is not part of this list.
        let numeric_dtypes = [
            DType::Float32,
            DType::Float64,
            DType::Float16,
            DType::Int8,
            DType::Int16,
            DType::Int32,
            DType::Int64,
            DType::Uint8,
            DType::Uint16,
            DType::Uint32,
            DType::Uint64,
        ];

        for dtype in numeric_dtypes {
            let ones = Array::ones(Shape::new(vec![3]), dtype);
            assert_eq!(ones.dtype(), dtype);

            let values = ones.to_vec();
            assert!(values.iter().all(|v| *v == 1.0));
        }
    }

    #[test]
    fn test_array_from_vec_typed() {
        let ints = Array::from_vec_i32(vec![1, 2, 3], Shape::new(vec![3]));
        assert_eq!(ints.dtype(), DType::Int32);
        assert_eq!(ints.to_vec(), vec![1.0, 2.0, 3.0]);

        let signed_bytes = Array::from_vec_i8(vec![-1, 0, 127], Shape::new(vec![3]));
        assert_eq!(signed_bytes.dtype(), DType::Int8);
        assert_eq!(signed_bytes.to_vec(), vec![-1.0, 0.0, 127.0]);

        let bytes = Array::from_vec_u8(vec![0, 128, 255], Shape::new(vec![3]));
        assert_eq!(bytes.dtype(), DType::Uint8);
        assert_eq!(bytes.to_vec(), vec![0.0, 128.0, 255.0]);

        let flags = Array::from_vec_bool(vec![true, false, true], Shape::new(vec![3]));
        assert_eq!(flags.dtype(), DType::Bool);
        assert_eq!(flags.to_vec(), vec![1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_array_astype() {
        let floats = Array::from_vec(vec![1.0, 2.5, 3.9], Shape::new(vec![3]));

        // Float -> Int32 drops the fractional part.
        let as_int = floats.astype(DType::Int32);
        assert_eq!(as_int.dtype(), DType::Int32);
        assert_eq!(as_int.to_vec(), vec![1.0, 2.0, 3.0]);

        // Float -> Bool maps every non-zero value to true.
        let mixed = Array::from_vec(vec![0.0, 1.0, 5.0], Shape::new(vec![3]));
        let as_bool = mixed.astype(DType::Bool);
        assert_eq!(as_bool.dtype(), DType::Bool);
        assert_eq!(as_bool.to_vec(), vec![0.0, 1.0, 1.0]);

        // Converting to Float32 leaves the values untouched.
        let same = floats.astype(DType::Float32);
        assert_eq!(same.dtype(), DType::Float32);
        assert_eq!(same.to_vec(), floats.to_vec());
    }

    #[test]
    fn test_array_to_bool_vec() {
        let flags = Array::from_vec_bool(vec![true, false, true, false], Shape::new(vec![4]));
        let values = flags.to_bool_vec();
        assert_eq!(values, vec![true, false, true, false]);
    }

    #[test]
    fn test_strided_to_vec_transposed() {
        // Six values viewed as 3x2 with strides [1, 3].
        let view = strided_view(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![3, 2],
            vec![1, 3],
            0,
        );

        let gathered = view.to_vec();
        assert_eq!(gathered, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_strided_to_vec_with_offset() {
        // A 2x3 window starting at element 2 of a ten-element buffer.
        let view = strided_view(
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            vec![2, 3],
            vec![3, 1],
            2,
        );

        let gathered = view.to_vec();
        assert_eq!(gathered, vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    }

    #[test]
    fn test_strided_to_vec_every_other() {
        // Stride 2 picks every second element.
        let view = strided_view(
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![4],
            vec![2],
            0,
        );

        let gathered = view.to_vec();
        assert_eq!(gathered, vec![0.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_strided_to_vec_3d() {
        // Values 0..24 viewed as 4x3x2 with strides [1, 4, 12].
        let view = strided_view(
            (0..24).map(|x| x as f32).collect(),
            vec![4, 3, 2],
            vec![1, 4, 12],
            0,
        );

        let gathered = view.to_vec();
        let expected = vec![
            0.0, 12.0, 4.0, 16.0, 8.0, 20.0, // first index 0
            1.0, 13.0, 5.0, 17.0, 9.0, 21.0, // first index 1
            2.0, 14.0, 6.0, 18.0, 10.0, 22.0, // first index 2
            3.0, 15.0, 7.0, 19.0, 11.0, 23.0, // first index 3
        ];
        assert_eq!(gathered, expected);
    }
}