numrst/core/ndarray/
mod.rs1mod construct;
2mod indexer;
3mod iter;
4mod display;
5mod shape;
6mod arith;
7mod matmul;
8mod reduce;
9mod broadcast;
10mod convert;
11mod condition;
12
13use std::sync::Arc;
14pub use indexer::{Range, IndexOp};
15use crate::{Error, Result};
16use super::{view::{AsMatrixView, AsMatrixViewMut, AsVectorView, AsVectorViewMut, MatrixView, MatrixViewMut, MatrixViewUsf, VectorView, VectorViewMut, VectorViewUsf}, DType, Dim, DimCoordinates, DimNCoordinates, Layout, NumDType, Shape, Storage, StorageArc, StorageIndices, StorageMut, StorageRef, WithDType};
17pub use iter::*;
18pub use indexer::*;
19
20#[derive(Clone)]
21pub struct NdArray<D>(pub(crate) Arc<NdArrayImpl<D>>);
22
/// Identifier handed out to arrays, drawn from a process-wide counter.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NdArrayId(usize);

impl NdArrayId {
    /// Allocates the next identifier.
    ///
    /// Identifiers start at 1 and increase by one per call. The counter is a
    /// single `static`, so values are distinct within a process until the
    /// `usize` counter wraps around. `Relaxed` ordering suffices: each
    /// `fetch_add` returns a distinct value, and the counter does not need to
    /// order any other memory accesses.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);

        let allocated = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(allocated)
    }
}
33
34pub struct NdArrayImpl<T> {
35 pub(crate) id: NdArrayId,
36 pub(crate) storage: StorageArc<T>,
37 pub(crate) layout: Layout,
38}
39
40impl<T: WithDType> NdArray<T> {
41 pub fn is_scalar(&self) -> bool {
42 self.shape().is_scalar()
43 }
44
45 pub fn check_scalar(&self) -> Result<()> {
46 if !self.is_scalar() {
47 Err(Error::NotScalar)
48 } else {
49 Ok(())
50 }
51 }
52
53 pub fn to_scalar(&self) -> Result<T> {
54 self.check_scalar()?;
55 let v = self.storage_ref(self.layout().start_offset()).get_unchecked(0);
56 Ok(v)
57 }
58
59 pub fn set_scalar(&self, val: T) -> Result<()> {
60 self.check_scalar()?;
61 self.storage_mut(self.layout().start_offset()).set_unchecked(0, val);
62 Ok(())
63 }
64
65 #[inline]
66 pub fn storage_ref<'a>(&'a self, start_offset: usize) -> StorageRef<'a, T> {
67 self.0.storage.get_ref(start_offset)
68 }
69
70 #[inline]
71 pub fn storage_mut<'a>(&'a self, start_offset: usize) -> StorageMut<'a, T> {
72 self.0.storage.get_mut(start_offset)
73 }
74
75 #[inline]
76 pub fn storage_ptr(&self, start_offset: usize) -> *mut T {
77 self.0.storage.get_ptr(start_offset)
78 }
79}
80
81impl<T: WithDType> NdArray<T> {
82 pub fn id(&self) -> usize {
83 self.0.id.0
84 }
85
86 pub fn shape(&self) -> &Shape {
87 self.0.layout.shape()
88 }
89
90 pub fn dtype(&self) -> DType {
91 T::DTYPE
92 }
93
94 pub fn layout(&self) -> &Layout {
95 &self.0.layout
96 }
97
98 pub fn dims(&self) -> &[usize] {
99 self.shape().dims()
100 }
101
102 pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
103 let dim = dim.to_index(self.shape(), "dim")?;
104 Ok(self.dims()[dim])
105 }
106
107 pub fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage<T>> {
108 self.0.storage.0.read().unwrap()
109 }
110
111 pub fn element_count(&self) -> usize {
112 self.shape().element_count()
113 }
114
115 pub fn is_contiguous(&self) -> bool {
116 self.layout().is_contiguous()
117 }
118
119 pub fn rank(&self) -> usize {
120 self.shape().rank()
121 }
122
123 pub fn to_vec(&self) -> Vec<T> {
124 self.iter().collect()
125 }
126
127 pub fn storage_indices(&self) -> StorageIndices {
136 self.layout().storage_indices()
137 }
138
139 pub fn dim_coordinates(&self) -> DimCoordinates {
148 self.shape().dim_coordinates()
149 }
150
151 pub fn dims_coordinates<const N: usize>(&self) -> Result<DimNCoordinates<N>> {
152 self.shape().dims_coordinates::<N>()
153 }
154
155 pub fn dim2_coordinates(&self) -> Result<DimNCoordinates<2>> {
156 self.shape().dim2_coordinates()
157 }
158
159 pub fn dim3_coordinates(&self) -> Result<DimNCoordinates<3>> {
160 self.shape().dim3_coordinates()
161 }
162
163 pub fn dim4_coordinates(&self) -> Result<DimNCoordinates<4>> {
164 self.shape().dim4_coordinates()
165 }
166
167 pub fn dim5_coordinates(&self) -> Result<DimNCoordinates<5>> {
168 self.shape().dim5_coordinates()
169 }
170}
171
172impl<T: WithDType> NdArray<T> {
173 pub fn matrix_view_unsafe(&self) -> Result<MatrixViewUsf<'_, T>> {
174 MatrixViewUsf::from_ndarray(self)
175 }
176
177 pub fn vector_view_unsafe(&self) -> Result<VectorViewUsf<'_, T>> {
178 VectorViewUsf::from_ndarray(self)
179 }
180
181 pub fn matrix_view<'a>(&'a self) -> Result<MatrixView<'a, T>> {
182 MatrixView::from_ndarray(self)
183 }
184
185 pub fn matrix_view_mut<'a>(&'a mut self) -> Result<MatrixViewMut<'a, T>> {
186 MatrixViewMut::from_ndarray_mut(self)
187 }
188
189 pub fn vector_view<'a>(&'a self) -> Result<VectorView<'a, T>> {
190 VectorView::from_ndarray(self)
191 }
192
193 pub fn vector_view_mut<'a>(&'a mut self) -> Result<VectorViewMut<'a, T>> {
194 VectorViewMut::from_ndarray_mut(self)
195 }
196}
197
198impl<T: NumDType> NdArray<T> {
199 pub fn allclose(&self, other: &Self, rtol: f64, atol: f64) -> bool {
200 if self.shape() != other.shape() {
201 return false;
202 }
203 self.iter().zip(other.iter()).all(|(a, b)| a.close(b, rtol, atol))
204 }
205}