ferrix/
vector_view.rs

1use crate::row_vector_view::RowVectorView;
2use crate::traits::DotProduct;
3use crate::vector::Vector;
4use crate::vector_view_mut::VectorViewMut;
5use num_traits::Float;
6use std::fmt;
7use std::marker::PhantomData;
8use std::ops::Index;
9
/// A column vector view of a [`Vector`] or a [`RowVector`](crate::row_vector::RowVector) (transposed view).
///
/// The view reads `M` consecutive elements of the backing storage `V`, beginning at
/// offset `start`. `N` is not used by any field; presumably it records the length of
/// the backing vector (TODO confirm against the constructors in `Vector`).
///
/// `Clone` and `Copy` are implemented by hand rather than derived: the view holds only
/// a shared reference, so copying it must not require `V` or `T` to be `Clone`
/// (a derive would add exactly those bounds).
#[derive(Debug)]
pub struct VectorView<'a, V, T, const N: usize, const M: usize> {
    /// Backing storage the view reads from.
    data: &'a V,
    /// Offset into `data` of the view's first element.
    start: usize,
    /// Associates the element type `T` with the view without storing a `T`.
    _phantom: PhantomData<T>,
}

impl<V, T, const N: usize, const M: usize> Clone for VectorView<'_, V, T, N, M> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

// Every field (`&V`, `usize`, `PhantomData<T>`) is `Copy` for any `V`/`T`.
impl<V, T, const N: usize, const M: usize> Copy for VectorView<'_, V, T, N, M> {}
17
18impl<'a, V, T, const N: usize, const M: usize> VectorView<'a, V, T, N, M> {
19    pub(super) fn new(data: &'a V, start: usize) -> Self {
20        Self {
21            data,
22            start,
23            _phantom: PhantomData,
24        }
25    }
26
27    /// Returns the shape of the [`VectorView`].
28    ///
29    /// The shape is always equal to `M`.
30    ///
31    /// # Examples
32    ///
33    /// ```
34    /// use ferrix::Vector;
35    ///
36    /// let vec = Vector::from([1, 2, 3, 4, 5]);
37    /// let view = vec.view::<3>(1).unwrap();
38    /// assert_eq!(view.shape(), 3);
39    /// ```
40    #[inline]
41    pub fn shape(&self) -> usize {
42        M
43    }
44
45    /// Returns the total number of elements in the [`VectorView`].
46    ///
47    /// The total number of elements is always equal to `M`.
48    ///
49    /// # Examples
50    ///
51    /// ```
52    /// use ferrix::Vector;
53    ///
54    /// let vec = Vector::from([1, 2, 3, 4, 5]);
55    /// let view = vec.view::<3>(1).unwrap();
56    /// assert_eq!(view.capacity(), 3);
57    /// ```
58    #[inline]
59    pub fn capacity(&self) -> usize {
60        M
61    }
62
63    /// Returns the number of rows in the [`VectorView`].
64    ///
65    /// The number of rows is always equal to `M`.
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// use ferrix::Vector;
71    ///
72    /// let vec = Vector::from([1, 2, 3, 4, 5]);
73    /// let view = vec.view::<3>(1).unwrap();
74    /// assert_eq!(view.rows(), 3);
75    /// ```
76    #[inline]
77    pub fn rows(&self) -> usize {
78        M
79    }
80
81    /// Returns the number of columns in the [`VectorView`].
82    ///
83    /// The number of columns is always `1`.
84    ///
85    /// # Examples
86    ///
87    /// ```
88    /// use ferrix::Vector;
89    ///
90    /// let vec = Vector::from([1, 2, 3, 4, 5]);
91    /// let view = vec.view::<3>(1).unwrap();
92    /// assert_eq!(view.cols(), 1);
93    /// ```
94    #[inline]
95    pub fn cols(&self) -> usize {
96        1
97    }
98
99    /// Returns a transposed view of the [`VectorView`].
100    ///
101    /// This method returns a [`RowVectorView`], which is a read-only view of the [`VectorView`] as a row vector.
102    ///
103    /// # Examples
104    ///
105    /// ```
106    /// use ferrix::{Vector, RowVector};
107    ///
108    /// let vec = Vector::from([1, 2, 3, 4, 5]);
109    /// let view = vec.view::<3>(1).unwrap();
110    /// let row_view = view.t();
111    /// assert_eq!(row_view, RowVector::from([2, 3, 4]));
112    /// ```
113    pub fn t(&'a self) -> RowVectorView<'a, V, T, N, M> {
114        RowVectorView::new(self.data, self.start)
115    }
116}
117
118impl<'a, V: Index<usize, Output = T>, T: Float, const N: usize, const M: usize>
119    VectorView<'a, V, T, N, M>
120{
121    /// Calculates the magnitude (Euclidean norm) of the [`VectorView`].
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// use ferrix::Vector;
127    ///
128    /// let vec = Vector::from([1.0, 2.0, 3.0, 4.0, 5.0]);
129    /// let view = vec.view::<2>(2).unwrap();
130    /// assert_eq!(view.magnitude(), 5.0);
131    /// ```
132    pub fn magnitude(&self) -> T {
133        self.dot(self).sqrt()
134    }
135}
136
137//////////////////////////////////////
138//  Equality Trait Implementations  //
139//////////////////////////////////////
140
141// VectorView == Vector
142impl<T: PartialEq, const N: usize, V: Index<usize, Output = T>, const A: usize>
143    PartialEq<Vector<T, N>> for VectorView<'_, V, T, A, N>
144{
145    fn eq(&self, other: &Vector<T, N>) -> bool {
146        (0..N).all(|i| self[i] == other[i])
147    }
148}
149
150// VectorView == VectorView
151impl<
152        T: PartialEq,
153        V1: Index<usize, Output = T>,
154        V2: Index<usize, Output = T>,
155        const A1: usize,
156        const A2: usize,
157        const N: usize,
158    > PartialEq<VectorView<'_, V2, T, A2, N>> for VectorView<'_, V1, T, A1, N>
159{
160    fn eq(&self, other: &VectorView<'_, V2, T, A2, N>) -> bool {
161        (0..N).all(|i| self[i] == other[i])
162    }
163}
164
165// VectorView == VectorViewMut
166impl<
167        T: PartialEq,
168        V1: Index<usize, Output = T>,
169        V2: Index<usize, Output = T>,
170        const A1: usize,
171        const A2: usize,
172        const N: usize,
173    > PartialEq<VectorViewMut<'_, V2, T, A2, N>> for VectorView<'_, V1, T, A1, N>
174{
175    fn eq(&self, other: &VectorViewMut<'_, V2, T, A2, N>) -> bool {
176        (0..N).all(|i| self[i] == other[i])
177    }
178}
179
180impl<'a, V: Index<usize, Output = T>, T: Eq, const N: usize, const M: usize> Eq
181    for VectorView<'a, V, T, N, M>
182{
183}
184
185/////////////////////////////////////
186//  Display Trait Implementations  //
187/////////////////////////////////////
188
189impl<'a, V: Index<usize, Output = T>, T: fmt::Display, const N: usize, const M: usize> fmt::Display
190    for VectorView<'a, V, T, N, M>
191{
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        if f.alternate() {
194            write!(f, "VectorView(")?;
195        }
196
197        write!(f, "[")?;
198        for i in 0..M - 1 {
199            if i > 0 {
200                write!(f, " ")?;
201                if f.alternate() {
202                    write!(f, "           ")?;
203                }
204            }
205            writeln!(f, "{}", self[i])?;
206        }
207        if f.alternate() {
208            write!(f, "           ")?;
209        }
210
211        write!(f, " {}]", self[M - 1])?;
212
213        if f.alternate() {
214            write!(f, ", dtype={})", std::any::type_name::<T>())?;
215        }
216
217        Ok(())
218    }
219}
220
221///////////////////////////////////
222//  Index Trait Implementations  //
223///////////////////////////////////
224
225impl<'a, V: Index<usize, Output = T>, T, const N: usize, const M: usize> Index<usize>
226    for VectorView<'a, V, T, N, M>
227{
228    type Output = T;
229
230    fn index(&self, index: usize) -> &Self::Output {
231        &self.data[self.start + index]
232    }
233}
234
235impl<V: Index<usize, Output = T>, T, const N: usize, const M: usize> Index<(usize, usize)>
236    for VectorView<'_, V, T, N, M>
237{
238    type Output = T;
239
240    fn index(&self, index: (usize, usize)) -> &Self::Output {
241        if index.1 != 0 {
242            panic!("Index out of bounds");
243        }
244        &self.data[self.start + index.0]
245    }
246}