numrst/core/ndarray/
mod.rs

1mod construct;
2mod indexer;
3mod iter;
4mod display;
5mod shape;
6mod arith;
7mod matmul;
8mod reduce;
9mod broadcast;
10mod convert;
11mod condition;
12
13use std::sync::Arc;
14pub use indexer::{Range, IndexOp};
15use crate::{Error, Result};
16use super::{view::{AsMatrixView, AsMatrixViewMut, AsVectorView, AsVectorViewMut, MatrixView, MatrixViewMut, MatrixViewUsf, VectorView, VectorViewMut, VectorViewUsf}, DType, Dim, DimCoordinates, DimNCoordinates, Layout, NumDType, Shape, Storage, StorageArc, StorageIndices, StorageMut, StorageRef, WithDType};
17pub use iter::*;
18pub use indexer::*;
19
20#[derive(Clone)]
21pub struct NdArray<D>(pub(crate) Arc<NdArrayImpl<D>>);
22
/// Unique identifier handed out to array implementation objects.
///
/// Values come from a single process-wide counter, so two calls to
/// [`NdArrayId::new`] never return the same id (until the `usize` counter
/// wraps around).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NdArrayId(usize);

impl NdArrayId {
    /// Allocates the next identifier from the global counter.
    ///
    /// The first identifier returned is `1`; each subsequent call returns the
    /// previous value plus one.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
        // Only the atomicity of the increment is needed for uniqueness; the
        // counter does not order any other memory accesses.
        Self(NEXT_ID.fetch_add(1, Ordering::Relaxed))
    }
}

impl Default for NdArrayId {
    /// Same as [`NdArrayId::new`]: every default value is a fresh identifier.
    fn default() -> Self {
        Self::new()
    }
}
33
34pub struct NdArrayImpl<T> {
35    pub(crate) id: NdArrayId,
36    pub(crate) storage: StorageArc<T>,
37    pub(crate) layout: Layout,
38}
39
40impl<T: WithDType> NdArray<T> {
41    pub fn is_scalar(&self) -> bool {
42        self.shape().is_scalar()
43    }
44
45    pub fn check_scalar(&self) -> Result<()> {
46        if !self.is_scalar() {
47            Err(Error::NotScalar)
48        } else {
49            Ok(())
50        }
51    }
52
53    pub fn to_scalar(&self) -> Result<T> {
54        self.check_scalar()?;
55        let v = self.storage_ref(self.layout().start_offset()).get_unchecked(0);
56        Ok(v)
57    }
58
59    pub fn set_scalar(&self, val: T) -> Result<()> {
60        self.check_scalar()?;
61        self.storage_mut(self.layout().start_offset()).set_unchecked(0, val);
62        Ok(())
63    }
64
65    #[inline]
66    pub fn storage_ref<'a>(&'a self, start_offset: usize) -> StorageRef<'a, T> {
67        self.0.storage.get_ref(start_offset)
68    }
69
70    #[inline]
71    pub fn storage_mut<'a>(&'a self, start_offset: usize) -> StorageMut<'a, T> {
72        self.0.storage.get_mut(start_offset)
73    }
74
75    #[inline]
76    pub fn storage_ptr(&self, start_offset: usize) -> *mut T {
77        self.0.storage.get_ptr(start_offset)
78    }
79}
80
81impl<T: WithDType> NdArray<T> {
82    pub fn id(&self) -> usize {
83        self.0.id.0
84    }
85
86    pub fn shape(&self) -> &Shape {
87        self.0.layout.shape()
88    }
89
90    pub fn dtype(&self) -> DType {
91        T::DTYPE
92    }
93
94    pub fn layout(&self) -> &Layout {
95        &self.0.layout
96    }
97
98    pub fn dims(&self) -> &[usize] {
99        self.shape().dims()
100    }
101
102    pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
103        let dim = dim.to_index(self.shape(), "dim")?;
104        Ok(self.dims()[dim])
105    }
106
107    pub fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage<T>> {
108        self.0.storage.0.read().unwrap()
109    }
110
111    pub fn element_count(&self) -> usize {
112        self.shape().element_count()
113    }
114
115    pub fn is_contiguous(&self) -> bool {
116        self.layout().is_contiguous()
117    }
118
119    pub fn rank(&self) -> usize {
120        self.shape().rank()
121    }
122
123    pub fn to_vec(&self) -> Vec<T> {
124        self.iter().collect()
125    }
126
127    /// Returns an iterator over **storage indices**.
128    ///
129    /// This iterator yields the linear (flat) indices as they are laid out
130    /// in the underlying storage buffer. The order depends on the memory
131    /// layout (e.g., row-major / column-major / with strides).
132    ///
133    /// Example for shape = (2, 2) in row-major layout:
134    /// yields: `0, 1, 2, 3`
135    pub fn storage_indices(&self) -> StorageIndices {
136        self.layout().storage_indices()
137    }
138
139    /// Returns an iterator over **dimension coordinates**.
140    ///
141    /// This iterator yields the multi-dimensional coordinates
142    /// (e.g., `[i, j, k, ...]`) of each element in the array, independent
143    /// of the physical storage layout.
144    ///
145    /// Example for shape = (2, 2):
146    /// yields: `[0, 0], [0, 1], [1, 0], [1, 1]`
147    pub fn dim_coordinates(&self) -> DimCoordinates {
148        self.shape().dim_coordinates()
149    }
150
151    pub fn dims_coordinates<const N: usize>(&self) -> Result<DimNCoordinates<N>> {
152        self.shape().dims_coordinates::<N>()
153    }
154
155    pub fn dim2_coordinates(&self) -> Result<DimNCoordinates<2>> {
156        self.shape().dim2_coordinates()
157    }
158
159    pub fn dim3_coordinates(&self) -> Result<DimNCoordinates<3>> {
160        self.shape().dim3_coordinates()
161    }
162
163    pub fn dim4_coordinates(&self) -> Result<DimNCoordinates<4>> {
164        self.shape().dim4_coordinates()
165    }
166
167    pub fn dim5_coordinates(&self) -> Result<DimNCoordinates<5>> {
168        self.shape().dim5_coordinates()
169    }
170}
171
172impl<T: WithDType> NdArray<T> {
173    pub fn matrix_view_unsafe(&self) -> Result<MatrixViewUsf<'_, T>> {
174        MatrixViewUsf::from_ndarray(self)
175    }
176
177    pub fn vector_view_unsafe(&self) -> Result<VectorViewUsf<'_, T>> {
178        VectorViewUsf::from_ndarray(self)
179    }
180
181    pub fn matrix_view<'a>(&'a self) -> Result<MatrixView<'a, T>> {
182        MatrixView::from_ndarray(self)
183    }
184
185    pub fn matrix_view_mut<'a>(&'a mut self) -> Result<MatrixViewMut<'a, T>> {
186        MatrixViewMut::from_ndarray_mut(self)
187    }
188
189    pub fn vector_view<'a>(&'a self) -> Result<VectorView<'a, T>> {
190        VectorView::from_ndarray(self)
191    }
192
193    pub fn vector_view_mut<'a>(&'a mut self) -> Result<VectorViewMut<'a, T>> {
194        VectorViewMut::from_ndarray_mut(self)
195    }
196}
197
198impl<T: NumDType> NdArray<T> {
199    pub fn allclose(&self, other: &Self, rtol: f64, atol: f64) -> bool {
200        if self.shape() != other.shape() {
201            return false;
202        }
203        self.iter().zip(other.iter()).all(|(a, b)| a.close(b, rtol, atol))
204    }
205}