1use crate::error::{Error, Result};
4use crate::numeric::Float;
5use crate::view2::validate_view;
6
/// An owned N-dimensional array backed by a single `Vec`.
///
/// Every constructor in this module lays `data` out in row-major (C) order, with `strides`
/// counted in elements rather than bytes. The derived `PartialEq` compares storage, shape
/// and strides.
#[derive(Clone, Debug, PartialEq)]
pub struct ArrayN<T> {
    /// Element storage; for arrays built here its length equals the product of `shape`.
    data: Vec<T>,
    /// Extent of each axis.
    shape: Vec<usize>,
    /// Element step of each axis, parallel to `shape`.
    strides: Vec<isize>,
}
14
/// A read-only, strided view into a borrowed buffer.
///
/// The element at `index` lives at `offset + Σ index[k] * strides[k]` within `data`.
#[derive(Debug)]
pub struct ArrayViewN<'a, T> {
    /// Borrowed storage; the view may address only part of it.
    data: &'a [T],
    /// Extent of each axis.
    shape: Vec<usize>,
    /// Signed element step of each axis, parallel to `shape`.
    strides: Vec<isize>,
    /// Position in `data` of the element whose index is all zeros.
    offset: isize,
}

// Manual impl: `#[derive(Clone)]` would demand `T: Clone`, but cloning a view only copies the
// slice reference and the layout vectors — the elements themselves are never cloned.
impl<T> Clone for ArrayViewN<'_, T> {
    fn clone(&self) -> Self {
        Self {
            data: self.data,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: self.offset,
        }
    }
}
23
/// A mutable, strided view into a borrowed buffer.
///
/// The element at `index` lives at `offset + Σ index[k] * strides[k]` within `data`.
#[derive(Debug)]
pub struct ArrayViewMutN<'a, T> {
    /// Mutably borrowed storage; the view may address only part of it.
    data: &'a mut [T],
    /// Extent of each axis.
    shape: Vec<usize>,
    /// Signed element step of each axis, parallel to `shape`.
    strides: Vec<isize>,
    /// Position in `data` of the element whose index is all zeros.
    offset: isize,
}
32
33impl<T> ArrayN<T> {
34 pub fn from_vec(shape: Vec<usize>, data: Vec<T>) -> Result<Self> {
36 let expected = checked_len(&shape)?;
37 if data.len() != expected {
38 return Err(Error::shape(vec![expected], vec![data.len()]));
39 }
40 let strides = row_major_strides(&shape);
41 Ok(Self {
42 data,
43 shape,
44 strides,
45 })
46 }
47
48 pub fn shape(&self) -> &[usize] {
50 &self.shape
51 }
52
53 pub fn strides(&self) -> &[isize] {
55 &self.strides
56 }
57
58 pub fn ndim(&self) -> usize {
60 self.shape.len()
61 }
62
63 pub fn len(&self) -> usize {
65 self.data.len()
66 }
67
68 pub fn is_empty(&self) -> bool {
70 self.data.is_empty()
71 }
72
73 pub fn as_slice(&self) -> &[T] {
75 &self.data
76 }
77
78 pub fn as_mut_slice(&mut self) -> &mut [T] {
80 &mut self.data
81 }
82
83 pub fn view(&self) -> ArrayViewN<'_, T> {
85 ArrayViewN {
86 data: &self.data,
87 shape: self.shape.clone(),
88 strides: self.strides.clone(),
89 offset: 0,
90 }
91 }
92
93 pub fn view_mut(&mut self) -> ArrayViewMutN<'_, T> {
95 ArrayViewMutN {
96 data: &mut self.data,
97 shape: self.shape.clone(),
98 strides: self.strides.clone(),
99 offset: 0,
100 }
101 }
102
103 pub fn get(&self, index: &[usize]) -> Option<&T> {
105 self.linear_index(index).map(|idx| &self.data[idx])
106 }
107
108 pub fn slice_axis(&self, axis: usize, index: usize) -> Result<ArrayViewN<'_, T>> {
110 self.view().slice_axis(axis, index)
111 }
112
113 pub fn slice_axis_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMutN<'_, T>> {
115 if axis >= self.ndim() {
116 return Err(Error::AxisOutOfBounds {
117 axis,
118 ndim: self.ndim(),
119 });
120 }
121 if index >= self.shape[axis] {
122 return Err(Error::IndexOutOfBounds);
123 }
124 let mut shape = self.shape.clone();
125 let mut strides = self.strides.clone();
126 let offset = index as isize * strides[axis];
127 shape.remove(axis);
128 strides.remove(axis);
129 Ok(ArrayViewMutN {
130 data: &mut self.data,
131 shape,
132 strides,
133 offset,
134 })
135 }
136
137 pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
139 self.linear_index(index).map(|idx| &mut self.data[idx])
140 }
141
142 fn linear_index(&self, index: &[usize]) -> Option<usize> {
143 if index.len() != self.ndim() {
144 return None;
145 }
146 let mut linear = 0usize;
147 for ((&idx, &dim), &stride) in index.iter().zip(&self.shape).zip(&self.strides) {
148 if idx >= dim {
149 return None;
150 }
151 linear += idx * stride as usize;
152 }
153 Some(linear)
154 }
155}
156
157impl<T: Clone> ArrayN<T> {
158 pub fn filled(shape: Vec<usize>, value: T) -> Self {
160 let len = shape.iter().product();
161 let strides = row_major_strides(&shape);
162 Self {
163 data: vec![value; len],
164 shape,
165 strides,
166 }
167 }
168
169 pub fn try_filled(shape: Vec<usize>, value: T) -> Result<Self> {
171 let len = checked_len(&shape)?;
172 let strides = row_major_strides(&shape);
173 let mut data = Vec::new();
174 data.try_reserve_exact(len)
175 .map_err(|_| Error::AllocationFailed)?;
176 data.resize(len, value);
177 Ok(Self {
178 data,
179 shape,
180 strides,
181 })
182 }
183}
184
185impl<T: Float> ArrayN<T> {
186 pub fn zeros(shape: Vec<usize>) -> Self {
188 Self::filled(shape, T::zero())
189 }
190
191 pub fn try_zeros(shape: Vec<usize>) -> Result<Self> {
193 Self::try_filled(shape, T::zero())
194 }
195
196 pub fn ones(shape: Vec<usize>) -> Self {
198 Self::filled(shape, T::one())
199 }
200
201 pub fn try_ones(shape: Vec<usize>) -> Result<Self> {
203 Self::try_filled(shape, T::one())
204 }
205}
206
207impl<'a, T> ArrayViewN<'a, T> {
208 pub fn new(
210 data: &'a [T],
211 shape: &'a [usize],
212 strides: &'a [isize],
213 offset: isize,
214 ) -> Result<Self> {
215 validate_view(data.len(), shape, strides, offset)?;
216 Ok(Self {
217 data,
218 shape: shape.to_vec(),
219 strides: strides.to_vec(),
220 offset,
221 })
222 }
223
224 pub fn shape(&self) -> &[usize] {
226 &self.shape
227 }
228
229 pub fn strides(&self) -> &[isize] {
231 &self.strides
232 }
233
234 pub fn ndim(&self) -> usize {
236 self.shape.len()
237 }
238
239 pub fn len(&self) -> usize {
241 self.shape.iter().product()
242 }
243
244 pub fn is_empty(&self) -> bool {
246 self.len() == 0
247 }
248
249 pub fn is_contiguous(&self) -> bool {
251 is_compact_row_major(&self.shape, &self.strides)
252 }
253
254 pub fn as_slice(&self) -> Option<&'a [T]> {
256 if !self.is_contiguous() {
257 return None;
258 }
259 let start = self.offset as usize;
260 let end = start + self.len();
261 Some(&self.data[start..end])
262 }
263
264 pub fn get(&self, index: &[usize]) -> Option<&'a T> {
266 self.linear_index(index).map(|idx| &self.data[idx])
267 }
268
269 pub fn slice_axis(&self, axis: usize, index: usize) -> Result<Self> {
271 if axis >= self.ndim() {
272 return Err(Error::AxisOutOfBounds {
273 axis,
274 ndim: self.ndim(),
275 });
276 }
277 if index >= self.shape[axis] {
278 return Err(Error::IndexOutOfBounds);
279 }
280 let mut shape = self.shape.clone();
281 let mut strides = self.strides.clone();
282 let offset = self.offset + index as isize * strides[axis];
283 shape.remove(axis);
284 strides.remove(axis);
285 Ok(Self {
286 data: self.data,
287 shape,
288 strides,
289 offset,
290 })
291 }
292
293 fn linear_index(&self, index: &[usize]) -> Option<usize> {
294 if index.len() != self.ndim() {
295 return None;
296 }
297 let mut linear = self.offset;
298 for ((&idx, &dim), &stride) in index.iter().zip(&self.shape).zip(&self.strides) {
299 if idx >= dim {
300 return None;
301 }
302 linear += idx as isize * stride;
303 }
304 (linear >= 0).then_some(linear as usize)
305 }
306}
307
308impl<'a, T> ArrayViewMutN<'a, T> {
309 pub fn new(
311 data: &'a mut [T],
312 shape: Vec<usize>,
313 strides: Vec<isize>,
314 offset: isize,
315 ) -> Result<Self> {
316 validate_view(data.len(), &shape, &strides, offset)?;
317 Ok(Self {
318 data,
319 shape,
320 strides,
321 offset,
322 })
323 }
324
325 pub fn shape(&self) -> &[usize] {
327 &self.shape
328 }
329
330 pub fn strides(&self) -> &[isize] {
332 &self.strides
333 }
334
335 pub fn ndim(&self) -> usize {
337 self.shape.len()
338 }
339
340 pub fn len(&self) -> usize {
342 self.shape.iter().product()
343 }
344
345 pub fn is_empty(&self) -> bool {
347 self.len() == 0
348 }
349
350 pub fn is_contiguous(&self) -> bool {
352 is_compact_row_major(&self.shape, &self.strides)
353 }
354
355 pub fn as_mut_slice(&mut self) -> Option<&mut [T]> {
357 if !self.is_contiguous() {
358 return None;
359 }
360 let start = self.offset as usize;
361 let end = start + self.len();
362 Some(&mut self.data[start..end])
363 }
364
365 pub fn as_view(&self) -> ArrayViewN<'_, T> {
367 ArrayViewN {
368 data: self.data,
369 shape: self.shape.clone(),
370 strides: self.strides.clone(),
371 offset: self.offset,
372 }
373 }
374
375 pub fn get(&self, index: &[usize]) -> Option<&T> {
377 self.linear_index(index).map(|idx| &self.data[idx])
378 }
379
380 pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
382 let linear = self.linear_index(index)?;
383 Some(&mut self.data[linear])
384 }
385
386 pub fn slice_axis_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMutN<'_, T>> {
388 if axis >= self.ndim() {
389 return Err(Error::AxisOutOfBounds {
390 axis,
391 ndim: self.ndim(),
392 });
393 }
394 if index >= self.shape[axis] {
395 return Err(Error::IndexOutOfBounds);
396 }
397 let mut shape = self.shape.clone();
398 let mut strides = self.strides.clone();
399 let offset = self.offset + index as isize * strides[axis];
400 shape.remove(axis);
401 strides.remove(axis);
402 Ok(ArrayViewMutN {
403 data: &mut *self.data,
404 shape,
405 strides,
406 offset,
407 })
408 }
409
410 fn linear_index(&self, index: &[usize]) -> Option<usize> {
411 if index.len() != self.ndim() {
412 return None;
413 }
414 let mut linear = self.offset;
415 for ((&idx, &dim), &stride) in index.iter().zip(&self.shape).zip(&self.strides) {
416 if idx >= dim {
417 return None;
418 }
419 linear += idx as isize * stride;
420 }
421 (linear >= 0).then_some(linear as usize)
422 }
423}
424
425fn checked_len(shape: &[usize]) -> Result<usize> {
426 shape
427 .iter()
428 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
429 .ok_or(Error::DimensionTooLarge)
430}
431
/// Row-major (C-order) strides, in elements, for `shape`.
///
/// The last axis has stride 1, and each earlier axis steps over the product of all later
/// extents. The extent of axis 0 never contributes to any stride, so it is not multiplied in.
/// Multiplying it in could overflow `isize` even when every stride is representable.
fn row_major_strides(shape: &[usize]) -> Vec<isize> {
    let mut strides = vec![1isize; shape.len()];
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis] as isize;
    }
    strides
}
441
/// Reports whether `shape`/`strides` describe a dense row-major layout.
///
/// Walking from the last axis to the first, every axis with extent greater than 1 must have
/// a stride equal to the product of the extents after it. Axes of extent 1 may carry any
/// stride. A layout containing a zero-length axis holds no elements and is always compact.
fn is_compact_row_major(shape: &[usize], strides: &[isize]) -> bool {
    if shape.iter().any(|&dim| dim == 0) {
        return true;
    }
    shape
        .iter()
        .zip(strides)
        .rev()
        .try_fold(1isize, |expected, (&dim, &stride)| {
            (dim <= 1 || stride == expected).then(|| expected * dim as isize)
        })
        .is_some()
}