ndarray/
arraytraits.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[cfg(not(feature = "std"))]
10use alloc::boxed::Box;
11#[cfg(not(feature = "std"))]
12use alloc::vec::Vec;
13use std::hash;
14use std::mem;
15use std::mem::size_of;
16use std::ops::{Index, IndexMut};
17use std::{iter::FromIterator, slice};
18
19use crate::imp_prelude::*;
20use crate::Arc;
21
22use crate::{
23    dimension,
24    iter::{Iter, IterMut},
25    numeric_util,
26    FoldWhile,
27    NdIndex,
28    OwnedArcRepr,
29    Zip,
30};
31
/// Diverging helper used on the failing branch of a checked index lookup.
///
/// It is marked `#[cold]` and `#[inline(never)]` so that the panic path stays
/// out of line at the call sites.
///
/// **Panics** unconditionally with the message `ndarray: index out of bounds`.
#[cold]
#[inline(never)]
pub(crate) fn array_out_of_bounds() -> !
{
    panic!("ndarray: index out of bounds")
}
38
39#[inline(always)]
40pub fn debug_bounds_check<A, D, I, T>(_a: &T, _index: &I)
41where
42    D: Dimension,
43    I: NdIndex<D>,
44    T: AsRef<LayoutRef<A, D>> + ?Sized,
45{
46    let _layout_ref = _a.as_ref();
47    debug_bounds_check_ref!(_layout_ref, *_index);
48}
49
50/// Access the element at **index**.
51///
52/// **Panics** if index is out of bounds.
53impl<A, D, I> Index<I> for ArrayRef<A, D>
54where
55    D: Dimension,
56    I: NdIndex<D>,
57{
58    type Output = A;
59
60    #[inline]
61    fn index(&self, index: I) -> &Self::Output
62    {
63        debug_bounds_check_ref!(self, index);
64        unsafe {
65            &*self._ptr().as_ptr().offset(
66                index
67                    .index_checked(self._dim(), self._strides())
68                    .unwrap_or_else(|| array_out_of_bounds()),
69            )
70        }
71    }
72}
73
74/// Access the element at **index** mutably.
75///
76/// **Panics** if index is out of bounds.
77impl<A, D, I> IndexMut<I> for ArrayRef<A, D>
78where
79    D: Dimension,
80    I: NdIndex<D>,
81{
82    #[inline]
83    fn index_mut(&mut self, index: I) -> &mut A
84    {
85        debug_bounds_check_ref!(self, index);
86        unsafe {
87            &mut *self.as_mut_ptr().offset(
88                index
89                    .index_checked(self._dim(), self._strides())
90                    .unwrap_or_else(|| array_out_of_bounds()),
91            )
92        }
93    }
94}
95
96/// Access the element at **index**.
97///
98/// **Panics** if index is out of bounds.
99impl<S, D, I> Index<I> for ArrayBase<S, D>
100where
101    D: Dimension,
102    I: NdIndex<D>,
103    S: Data,
104{
105    type Output = S::Elem;
106
107    #[inline]
108    fn index(&self, index: I) -> &S::Elem
109    {
110        Index::index(&**self, index)
111    }
112}
113
114/// Access the element at **index** mutably.
115///
116/// **Panics** if index is out of bounds.
117impl<S, D, I> IndexMut<I> for ArrayBase<S, D>
118where
119    D: Dimension,
120    I: NdIndex<D>,
121    S: DataMut,
122{
123    #[inline]
124    fn index_mut(&mut self, index: I) -> &mut S::Elem
125    {
126        IndexMut::index_mut(&mut (**self), index)
127    }
128}
129
130/// Return `true` if the array shapes and all elements of `self` and
131/// `rhs` are equal. Return `false` otherwise.
132impl<A, B, D> PartialEq<ArrayRef<B, D>> for ArrayRef<A, D>
133where
134    A: PartialEq<B>,
135    D: Dimension,
136{
137    fn eq(&self, rhs: &ArrayRef<B, D>) -> bool
138    {
139        if self.shape() != rhs.shape() {
140            return false;
141        }
142        if let Some(self_s) = self.as_slice() {
143            if let Some(rhs_s) = rhs.as_slice() {
144                return numeric_util::unrolled_eq(self_s, rhs_s);
145            }
146        }
147        Zip::from(self)
148            .and(rhs)
149            .fold_while(true, |_, a, b| {
150                if a != b {
151                    FoldWhile::Done(false)
152                } else {
153                    FoldWhile::Continue(true)
154                }
155            })
156            .into_inner()
157    }
158}
159
160/// Return `true` if the array shapes and all elements of `self` and
161/// `rhs` are equal. Return `false` otherwise.
162impl<A, B, D> PartialEq<&ArrayRef<B, D>> for ArrayRef<A, D>
163where
164    A: PartialEq<B>,
165    D: Dimension,
166{
167    fn eq(&self, rhs: &&ArrayRef<B, D>) -> bool
168    {
169        *self == **rhs
170    }
171}
172
173/// Return `true` if the array shapes and all elements of `self` and
174/// `rhs` are equal. Return `false` otherwise.
175impl<A, B, D> PartialEq<ArrayRef<B, D>> for &ArrayRef<A, D>
176where
177    A: PartialEq<B>,
178    D: Dimension,
179{
180    fn eq(&self, rhs: &ArrayRef<B, D>) -> bool
181    {
182        **self == *rhs
183    }
184}
185
186impl<A, D> Eq for ArrayRef<A, D>
187where
188    D: Dimension,
189    A: Eq,
190{
191}
192
193/// Return `true` if the array shapes and all elements of `self` and
194/// `rhs` are equal. Return `false` otherwise.
195impl<A, B, S, S2, D> PartialEq<ArrayBase<S2, D>> for ArrayBase<S, D>
196where
197    A: PartialEq<B>,
198    S: Data<Elem = A>,
199    S2: Data<Elem = B>,
200    D: Dimension,
201{
202    fn eq(&self, rhs: &ArrayBase<S2, D>) -> bool
203    {
204        PartialEq::eq(&**self, &**rhs)
205    }
206}
207
208/// Return `true` if the array shapes and all elements of `self` and
209/// `rhs` are equal. Return `false` otherwise.
210impl<A, B, S, S2, D> PartialEq<&ArrayBase<S2, D>> for ArrayBase<S, D>
211where
212    A: PartialEq<B>,
213    S: Data<Elem = A>,
214    S2: Data<Elem = B>,
215    D: Dimension,
216{
217    fn eq(&self, rhs: &&ArrayBase<S2, D>) -> bool
218    {
219        *self == **rhs
220    }
221}
222
223/// Return `true` if the array shapes and all elements of `self` and
224/// `rhs` are equal. Return `false` otherwise.
225impl<A, B, S, S2, D> PartialEq<ArrayBase<S2, D>> for &ArrayBase<S, D>
226where
227    A: PartialEq<B>,
228    S: Data<Elem = A>,
229    S2: Data<Elem = B>,
230    D: Dimension,
231{
232    fn eq(&self, rhs: &ArrayBase<S2, D>) -> bool
233    {
234        **self == *rhs
235    }
236}
237impl<S, D> Eq for ArrayBase<S, D>
238where
239    D: Dimension,
240    S: Data,
241    S::Elem: Eq,
242{
243}
244
245impl<A, B, S, D> PartialEq<ArrayRef<B, D>> for ArrayBase<S, D>
246where
247    S: Data<Elem = A>,
248    A: PartialEq<B>,
249    D: Dimension,
250{
251    fn eq(&self, other: &ArrayRef<B, D>) -> bool
252    {
253        **self == other
254    }
255}
256
257impl<A, B, S, D> PartialEq<&ArrayRef<B, D>> for ArrayBase<S, D>
258where
259    S: Data<Elem = A>,
260    A: PartialEq<B>,
261    D: Dimension,
262{
263    fn eq(&self, other: &&ArrayRef<B, D>) -> bool
264    {
265        **self == *other
266    }
267}
268
269impl<A, B, S, D> PartialEq<ArrayRef<B, D>> for &ArrayBase<S, D>
270where
271    S: Data<Elem = A>,
272    A: PartialEq<B>,
273    D: Dimension,
274{
275    fn eq(&self, other: &ArrayRef<B, D>) -> bool
276    {
277        **self == other
278    }
279}
280
281impl<A, B, S, D> PartialEq<ArrayBase<S, D>> for ArrayRef<A, D>
282where
283    S: Data<Elem = B>,
284    A: PartialEq<B>,
285    D: Dimension,
286{
287    fn eq(&self, other: &ArrayBase<S, D>) -> bool
288    {
289        self == **other
290    }
291}
292
293impl<A, B, S, D> PartialEq<&ArrayBase<S, D>> for ArrayRef<A, D>
294where
295    S: Data<Elem = B>,
296    A: PartialEq<B>,
297    D: Dimension,
298{
299    fn eq(&self, other: &&ArrayBase<S, D>) -> bool
300    {
301        self == ***other
302    }
303}
304
305impl<A, B, S, D> PartialEq<ArrayBase<S, D>> for &ArrayRef<A, D>
306where
307    S: Data<Elem = B>,
308    A: PartialEq<B>,
309    D: Dimension,
310{
311    fn eq(&self, other: &ArrayBase<S, D>) -> bool
312    {
313        *self == **other
314    }
315}
316
317impl<A, S> From<Box<[A]>> for ArrayBase<S, Ix1>
318where S: DataOwned<Elem = A>
319{
320    /// Create a one-dimensional array from a boxed slice (no copying needed).
321    ///
322    /// **Panics** if the length is greater than `isize::MAX`.
323    fn from(b: Box<[A]>) -> Self
324    {
325        Self::from_vec(b.into_vec())
326    }
327}
328
329impl<A, S> From<Vec<A>> for ArrayBase<S, Ix1>
330where S: DataOwned<Elem = A>
331{
332    /// Create a one-dimensional array from a vector (no copying needed).
333    ///
334    /// **Panics** if the length is greater than `isize::MAX`.
335    ///
336    /// ```rust
337    /// use ndarray::Array;
338    ///
339    /// let array = Array::from(vec![1., 2., 3., 4.]);
340    /// ```
341    fn from(v: Vec<A>) -> Self
342    {
343        Self::from_vec(v)
344    }
345}
346
347impl<A, S> FromIterator<A> for ArrayBase<S, Ix1>
348where S: DataOwned<Elem = A>
349{
350    /// Create a one-dimensional array from an iterable.
351    ///
352    /// **Panics** if the length is greater than `isize::MAX`.
353    ///
354    /// ```rust
355    /// use ndarray::{Array, arr1};
356    ///
357    /// // Either use `from_iter` directly or use `Iterator::collect`.
358    /// let array = Array::from_iter((0..5).map(|x| x * x));
359    /// assert!(array == arr1(&[0, 1, 4, 9, 16]))
360    /// ```
361    fn from_iter<I>(iterable: I) -> ArrayBase<S, Ix1>
362    where I: IntoIterator<Item = A>
363    {
364        Self::from_iter(iterable)
365    }
366}
367
368impl<'a, A, D> IntoIterator for &'a ArrayRef<A, D>
369where D: Dimension
370{
371    type Item = &'a A;
372
373    type IntoIter = Iter<'a, A, D>;
374
375    fn into_iter(self) -> Self::IntoIter
376    {
377        self.iter()
378    }
379}
380
381impl<'a, A, D> IntoIterator for &'a mut ArrayRef<A, D>
382where D: Dimension
383{
384    type Item = &'a mut A;
385
386    type IntoIter = IterMut<'a, A, D>;
387
388    fn into_iter(self) -> Self::IntoIter
389    {
390        self.iter_mut()
391    }
392}
393
394impl<'a, S, D> IntoIterator for &'a ArrayBase<S, D>
395where
396    D: Dimension,
397    S: Data,
398{
399    type Item = &'a S::Elem;
400    type IntoIter = Iter<'a, S::Elem, D>;
401
402    fn into_iter(self) -> Self::IntoIter
403    {
404        self.iter()
405    }
406}
407
408impl<'a, S, D> IntoIterator for &'a mut ArrayBase<S, D>
409where
410    D: Dimension,
411    S: DataMut,
412{
413    type Item = &'a mut S::Elem;
414    type IntoIter = IterMut<'a, S::Elem, D>;
415
416    fn into_iter(self) -> Self::IntoIter
417    {
418        self.iter_mut()
419    }
420}
421
422impl<'a, A, D> IntoIterator for ArrayView<'a, A, D>
423where D: Dimension
424{
425    type Item = &'a A;
426    type IntoIter = Iter<'a, A, D>;
427
428    fn into_iter(self) -> Self::IntoIter
429    {
430        Iter::new(self)
431    }
432}
433
434impl<'a, A, D> IntoIterator for ArrayViewMut<'a, A, D>
435where D: Dimension
436{
437    type Item = &'a mut A;
438    type IntoIter = IterMut<'a, A, D>;
439
440    fn into_iter(self) -> Self::IntoIter
441    {
442        IterMut::new(self)
443    }
444}
445
446impl<A, D> hash::Hash for ArrayRef<A, D>
447where
448    D: Dimension,
449    A: hash::Hash,
450{
451    // Note: elements are hashed in the logical order
452    fn hash<H: hash::Hasher>(&self, state: &mut H)
453    {
454        self.shape().hash(state);
455        if let Some(self_s) = self.as_slice() {
456            hash::Hash::hash_slice(self_s, state);
457        } else {
458            for row in self.rows() {
459                if let Some(row_s) = row.as_slice() {
460                    hash::Hash::hash_slice(row_s, state);
461                } else {
462                    for elt in row {
463                        elt.hash(state)
464                    }
465                }
466            }
467        }
468    }
469}
470
471impl<S, D> hash::Hash for ArrayBase<S, D>
472where
473    D: Dimension,
474    S: Data,
475    S::Elem: hash::Hash,
476{
477    // Note: elements are hashed in the logical order
478    fn hash<H: hash::Hasher>(&self, state: &mut H)
479    {
480        (**self).hash(state)
481    }
482}
483
484// NOTE: ArrayBase keeps an internal raw pointer that always
485// points into the storage. This is Sync & Send as long as we
486// follow the usual inherited mutability rules, as we do with
487// Vec, &[] and &mut []
488
489/// `ArrayBase` is `Sync` when the storage type is.
490unsafe impl<S, D> Sync for ArrayBase<S, D>
491where
492    S: Sync + Data,
493    D: Sync,
494{
495}
496
497/// `ArrayBase` is `Send` when the storage type is.
498unsafe impl<S, D> Send for ArrayBase<S, D>
499where
500    S: Send + Data,
501    D: Send,
502{
503}
504
505unsafe impl<A, D> Sync for ArrayRef<A, D> where A: Sync {}
506
507unsafe impl<A, D> Send for ArrayRef<A, D> where A: Send {}
508
/// Format version tag, only compiled with the `serde` feature.
///
/// A version number is used so that a packed format can be added later.
#[cfg(feature = "serde")]
pub const ARRAY_FORMAT_VERSION: u8 = 1;
512
513// use "raw" form instead of type aliases here so that they show up in docs
514/// Implementation of `ArrayView::from(&S)` where `S` is a slice or sliceable.
515///
516/// **Panics** if the length of the slice overflows `isize`. (This can only
517/// occur if `A` is zero-sized, because slices cannot contain more than
518/// `isize::MAX` number of bytes.)
519impl<'a, A, Slice: ?Sized> From<&'a Slice> for ArrayView<'a, A, Ix1>
520where Slice: AsRef<[A]>
521{
522    /// Create a one-dimensional read-only array view of the data in `slice`.
523    ///
524    /// **Panics** if the slice length is greater than `isize::MAX`.
525    fn from(slice: &'a Slice) -> Self
526    {
527        aview1(slice.as_ref())
528    }
529}
530
531/// Implementation of ArrayView2::from(&[[A; N]; M])
532///
533/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A
534/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes).
535/// **Panics** if N == 0 and the number of rows is greater than isize::MAX.
536impl<'a, A, const M: usize, const N: usize> From<&'a [[A; N]; M]> for ArrayView<'a, A, Ix2>
537{
538    /// Create a two-dimensional read-only array view of the data in `slice`
539    fn from(xs: &'a [[A; N]; M]) -> Self
540    {
541        Self::from(&xs[..])
542    }
543}
544
545/// Implementation of ArrayView2::from(&[[A; N]])
546///
547/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This
548/// can only occur if A is zero-sized or if `N` is zero, because slices cannot
549/// contain more than `isize::MAX` number of bytes.)
550impl<'a, A, const N: usize> From<&'a [[A; N]]> for ArrayView<'a, A, Ix2>
551{
552    /// Create a two-dimensional read-only array view of the data in `slice`
553    fn from(xs: &'a [[A; N]]) -> Self
554    {
555        aview2(xs)
556    }
557}
558
559/// Implementation of `ArrayView::from(&A)` where `A` is an array.
560impl<'a, A, S, D> From<&'a ArrayBase<S, D>> for ArrayView<'a, A, D>
561where
562    S: Data<Elem = A>,
563    D: Dimension,
564{
565    /// Create a read-only array view of the array.
566    fn from(array: &'a ArrayBase<S, D>) -> Self
567    {
568        array.view()
569    }
570}
571
572/// Implementation of `ArrayViewMut::from(&mut S)` where `S` is a slice or sliceable.
573impl<'a, A, Slice: ?Sized> From<&'a mut Slice> for ArrayViewMut<'a, A, Ix1>
574where Slice: AsMut<[A]>
575{
576    /// Create a one-dimensional read-write array view of the data in `slice`.
577    ///
578    /// **Panics** if the slice length is greater than `isize::MAX`.
579    fn from(slice: &'a mut Slice) -> Self
580    {
581        let xs = slice.as_mut();
582        if mem::size_of::<A>() == 0 {
583            assert!(
584                xs.len() <= isize::MAX as usize,
585                "Slice length must fit in `isize`.",
586            );
587        }
588        unsafe { Self::from_shape_ptr(xs.len(), xs.as_mut_ptr()) }
589    }
590}
591
592/// Implementation of ArrayViewMut2::from(&mut [[A; N]; M])
593///
594/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A
595/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes).
596/// **Panics** if N == 0 and the number of rows is greater than isize::MAX.
597impl<'a, A, const M: usize, const N: usize> From<&'a mut [[A; N]; M]> for ArrayViewMut<'a, A, Ix2>
598{
599    /// Create a two-dimensional read-write array view of the data in `slice`
600    fn from(xs: &'a mut [[A; N]; M]) -> Self
601    {
602        Self::from(&mut xs[..])
603    }
604}
605
606/// Implementation of ArrayViewMut2::from(&mut [[A; N]])
607///
608/// **Panics** if the product of non-zero axis lengths overflows `isize`. (This
609/// can only occur if `A` is zero-sized or if `N` is zero, because slices
610/// cannot contain more than `isize::MAX` number of bytes.)
611impl<'a, A, const N: usize> From<&'a mut [[A; N]]> for ArrayViewMut<'a, A, Ix2>
612{
613    /// Create a two-dimensional read-write array view of the data in `slice`
614    fn from(xs: &'a mut [[A; N]]) -> Self
615    {
616        let cols = N;
617        let rows = xs.len();
618        let dim = Ix2(rows, cols);
619        if size_of::<A>() == 0 {
620            dimension::size_of_shape_checked(&dim).expect("Product of non-zero axis lengths must not overflow isize.");
621        } else if N == 0 {
622            assert!(
623                xs.len() <= isize::MAX as usize,
624                "Product of non-zero axis lengths must not overflow isize.",
625            );
626        }
627
628        // `cols * rows` is guaranteed to fit in `isize` because we checked that it fits in
629        // `isize::MAX`
630        unsafe {
631            let data = slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut A, cols * rows);
632            ArrayViewMut::from_shape_ptr(dim, data.as_mut_ptr())
633        }
634    }
635}
636
637/// Implementation of `ArrayViewMut::from(&mut A)` where `A` is an array.
638impl<'a, A, S, D> From<&'a mut ArrayBase<S, D>> for ArrayViewMut<'a, A, D>
639where
640    S: DataMut<Elem = A>,
641    D: Dimension,
642{
643    /// Create a read-write array view of the array.
644    fn from(array: &'a mut ArrayBase<S, D>) -> Self
645    {
646        array.view_mut()
647    }
648}
649
650impl<A, D> From<Array<A, D>> for ArcArray<A, D>
651where D: Dimension
652{
653    fn from(arr: Array<A, D>) -> ArcArray<A, D>
654    {
655        let data = OwnedArcRepr(Arc::new(arr.data));
656        // safe because: equivalent unmoved data, ptr and dims remain valid
657        unsafe { ArrayBase::from_data_ptr(data, arr.parts.ptr).with_strides_dim(arr.parts.strides, arr.parts.dim) }
658    }
659}
660
661/// Argument conversion into an array view
662///
663/// The trait is parameterized over `A`, the element type, and `D`, the
664/// dimensionality of the array. `D` defaults to one-dimensional.
665///
666/// Use `.into()` to do the conversion.
667///
668/// ```
669/// use ndarray::AsArray;
670///
671/// fn sum<'a, V: AsArray<'a, f64>>(data: V) -> f64 {
672///     let array_view = data.into();
673///     array_view.sum()
674/// }
675///
676/// assert_eq!(
677///     sum(&[1., 2., 3.]),
678///     6.
679/// );
680///
681/// ```
682pub trait AsArray<'a, A: 'a, D = Ix1>: Into<ArrayView<'a, A, D>>
683where D: Dimension
684{
685}
686impl<'a, A: 'a, D, T> AsArray<'a, A, D> for T
687where
688    T: Into<ArrayView<'a, A, D>>,
689    D: Dimension,
690{
691}
692
693/// Create an owned array with a default state.
694///
695/// The array is created with dimension `D::default()`, which results
696/// in for example dimensions `0` and `(0, 0)` with zero elements for the
697/// one-dimensional and two-dimensional cases respectively.
698///
699/// The default dimension for `IxDyn` is `IxDyn(&[0])` (array has zero
700/// elements). And the default for the dimension `()` is `()` (array has
701/// one element).
702///
703/// Since arrays cannot grow, the intention is to use the default value as
704/// placeholder.
705impl<A, S, D> Default for ArrayBase<S, D>
706where
707    S: DataOwned<Elem = A>,
708    D: Dimension,
709    A: Default,
710{
711    // NOTE: We can implement Default for non-zero dimensional array views by
712    // using an empty slice, however we need a trait for nonzero Dimension.
713    fn default() -> Self
714    {
715        ArrayBase::default(D::default())
716    }
717}
718
#[cfg(test)]
mod tests
{
    use crate::array;
    use alloc::vec;

    /// Equality must hold for owned arrays and their dereferenced references,
    /// by value and by reference, in the combinations asserted below.
    #[test]
    fn test_eq_traits()
    {
        let lhs = array![1, 2, 3];
        let rhs = array![1, 2, 3];
        let lhs_ref = &*lhs;
        let rhs_ref = &*rhs;

        // owned vs owned
        assert_eq!(lhs, rhs);
        assert_eq!(lhs, &rhs);
        assert_eq!(&lhs, rhs);
        assert_eq!(&lhs, &rhs);

        // dereferenced vs dereferenced
        assert_eq!(lhs_ref, rhs_ref);
        assert_eq!(&lhs_ref, rhs_ref);
        assert_eq!(lhs_ref, &rhs_ref);
        assert_eq!(&lhs_ref, &rhs_ref);

        // dereferenced vs owned
        assert_eq!(lhs_ref, rhs);
        assert_eq!(lhs_ref, &rhs);
        assert_eq!(&lhs_ref, &rhs);

        // owned vs dereferenced
        assert_eq!(lhs, rhs_ref);
        assert_eq!(&lhs, rhs_ref);
        assert_eq!(&lhs, &rhs_ref);
    }
}