ferrix/
row_vector_view_mut.rs

1use crate::row_vector::RowVector;
2use crate::row_vector_view::RowVectorView;
3use crate::traits::DotProduct;
4use crate::vector_view::VectorView;
5use crate::vector_view_mut::VectorViewMut;
6use num_traits::Float;
7use std::fmt;
8use std::marker::PhantomData;
9use std::ops::{Index, IndexMut};
10
/// A mutable row vector view of a [`Vector`](crate::vector::Vector) (transposed view) or a
/// [`RowVector`].
///
/// The view exposes `M` consecutive elements of the underlying storage, beginning at offset
/// `start`, and allows them to be read and modified in place.
#[derive(Debug)]
pub struct RowVectorViewMut<'a, V, T, const N: usize, const M: usize> {
    /// Mutably borrowed backing storage (presumably the length-`N` parent vector).
    data: &'a mut V,
    /// Offset into `data` of the view's first element.
    start: usize,
    /// Records the element type `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
18
19impl<'a, V, T, const N: usize, const M: usize> RowVectorViewMut<'a, V, T, N, M> {
20    pub(super) fn new(data: &'a mut V, start: usize) -> Self {
21        Self {
22            data,
23            start,
24            _phantom: PhantomData,
25        }
26    }
27
28    /// Returns the shape of the [`RowVectorViewMut`].
29    ///
30    /// The shape is always equal to `M`.
31    ///
32    /// # Examples
33    ///
34    /// ```
35    /// use ferrix::RowVector;
36    ///
37    /// let mut vec = RowVector::from([1, 2, 3, 4, 5]);
38    /// let view = vec.view_mut::<3>(1).unwrap();
39    /// assert_eq!(view.shape(), 3);
40    /// ```
41    #[inline]
42    pub fn shape(&self) -> usize {
43        M
44    }
45
46    /// Returns the total number of elements in the [`RowVectorViewMut`].
47    ///
48    /// The total number of elements is always equal to `M`.
49    ///
50    /// # Examples
51    ///
52    /// ```
53    /// use ferrix::RowVector;
54    ///
55    /// let mut vec = RowVector::from([1, 2, 3, 4, 5]);
56    /// let view = vec.view_mut::<3>(1).unwrap();
57    /// assert_eq!(view.capacity(), 3);
58    /// ```
59    #[inline]
60    pub fn capacity(&self) -> usize {
61        M
62    }
63
64    /// Returns the number of rows in the [`RowVectorViewMut`].
65    ///
66    /// The number of rows is always `1`.
67    ///
68    /// # Examples
69    ///
70    /// ```
71    /// use ferrix::RowVector;
72    ///
73    /// let mut vec = RowVector::from([1, 2, 3, 4, 5]);
74    /// let view = vec.view_mut::<3>(1).unwrap();
75    /// assert_eq!(view.rows(), 1);
76    /// ```
77    #[inline]
78    pub fn rows(&self) -> usize {
79        1
80    }
81
82    /// Returns the number of columns in the [`RowVectorViewMut`].
83    ///
84    /// The number of columns is always equal to `M`.
85    ///
86    /// # Examples
87    ///
88    /// ```
89    /// use ferrix::RowVector;
90    ///
91    /// let mut vec = RowVector::from([1, 2, 3, 4, 5]);
92    /// let view = vec.view_mut::<3>(1).unwrap();
93    /// assert_eq!(view.cols(), 3);
94    /// ```
95    #[inline]
96    pub fn cols(&self) -> usize {
97        M
98    }
99
100    /// Returns a transposed view of the [`RowVectorViewMut`].
101    ///
102    /// This method returns a [`VectorView`], which is a read-only view of the [`RowVectorViewMut`] as a column vector.
103    ///
104    /// # Examples
105    ///
106    /// ```
107    /// use ferrix::{Vector, RowVector};
108    ///
109    /// let mut vec = RowVector::from([1, 2, 3, 4, 5]);
110    /// let view = vec.view_mut::<3>(1).unwrap();
111    /// let col_view = view.t();
112    /// assert_eq!(col_view, Vector::from([2, 3, 4]));
113    /// ```
114    pub fn t(&'a self) -> VectorView<'a, V, T, N, M> {
115        VectorView::new(self.data, self.start)
116    }
117
118    /// Returns a mutable transposed view of the [`RowVectorViewMut`].
119    ///
120    /// This method returns a [`VectorViewMut`], which is a mutable view of the [`RowVectorViewMut`] as a column vector.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// use ferrix::RowVector;
126    ///
127    /// let mut vec = RowVector::from([1, 2, 3, 4, 5]);
128    /// let mut view = vec.view_mut::<3>(1).unwrap();
129    /// let mut col_view = view.t_mut();
130    /// col_view[1] = 10;
131    /// assert_eq!(vec, RowVector::from([1, 2, 10, 4, 5]));
132    /// ```
133    pub fn t_mut(&'a mut self) -> VectorViewMut<'a, V, T, N, M> {
134        VectorViewMut::new(self.data, self.start)
135    }
136}
137
138impl<'a, V: Index<usize, Output = T>, T: Float, const N: usize, const M: usize>
139    RowVectorViewMut<'a, V, T, N, M>
140{
141    /// Calculates the magnitude of the [`RowVectorViewMut`].
142    ///
143    /// # Examples
144    ///
145    /// ```
146    /// use ferrix::RowVector;
147    ///
148    /// let mut vec = RowVector::from([1.0, 2.0, 3.0, 4.0, 5.0]);
149    /// let view = vec.view_mut::<2>(2).unwrap();
150    /// assert_eq!(view.magnitude(), 5.0);
151    /// ```
152    pub fn magnitude(&self) -> T {
153        self.dot(self).sqrt()
154    }
155}
156
157//////////////////////////////////////
158//  Equality Trait Implementations  //
159//////////////////////////////////////
160
161// RowVectorViewMut == RowVector
162impl<T: PartialEq, const N: usize, V: Index<usize, Output = T>, const A: usize>
163    PartialEq<RowVector<T, N>> for RowVectorViewMut<'_, V, T, A, N>
164{
165    fn eq(&self, other: &RowVector<T, N>) -> bool {
166        (0..N).all(|i| self[i] == other[i])
167    }
168}
169
170// RowVectorViewMut == RowVectorView
171impl<
172        T: PartialEq,
173        V1: Index<usize, Output = T>,
174        V2: Index<usize, Output = T>,
175        const A1: usize,
176        const A2: usize,
177        const N: usize,
178    > PartialEq<RowVectorView<'_, V2, T, A2, N>> for RowVectorViewMut<'_, V1, T, A1, N>
179{
180    fn eq(&self, other: &RowVectorView<'_, V2, T, A2, N>) -> bool {
181        (0..N).all(|i| self[i] == other[i])
182    }
183}
184
185// RowVectorViewMut == RowVectorViewMut
186impl<
187        T: PartialEq,
188        V1: Index<usize, Output = T>,
189        V2: Index<usize, Output = T>,
190        const A1: usize,
191        const A2: usize,
192        const N: usize,
193    > PartialEq<RowVectorViewMut<'_, V2, T, A2, N>> for RowVectorViewMut<'_, V1, T, A1, N>
194{
195    fn eq(&self, other: &RowVectorViewMut<'_, V2, T, A2, N>) -> bool {
196        (0..N).all(|i| self[i] == other[i])
197    }
198}
199
200impl<'a, V: Index<usize, Output = T>, T: Eq, const N: usize, const M: usize> Eq
201    for RowVectorViewMut<'a, V, T, N, M>
202{
203}
204
205/////////////////////////////////////
206//  Display Trait Implementations  //
207/////////////////////////////////////
208
209impl<'a, V: Index<usize, Output = T>, T: fmt::Display, const N: usize, const M: usize> fmt::Display
210    for RowVectorViewMut<'a, V, T, N, M>
211{
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        if f.alternate() {
214            write!(f, "RowVectorViewMut(")?;
215        }
216
217        write!(f, "[")?;
218        for i in 0..M - 1 {
219            write!(f, "{}, ", self[i])?;
220        }
221        write!(f, "{}]", self[M - 1])?;
222
223        if f.alternate() {
224            write!(f, ", dtype={})", std::any::type_name::<T>())?;
225        }
226
227        Ok(())
228    }
229}
230
231///////////////////////////////////
232//  Index Trait Implementations  //
233///////////////////////////////////
234
235impl<'a, V: Index<usize, Output = T>, T, const N: usize, const M: usize> Index<usize>
236    for RowVectorViewMut<'a, V, T, N, M>
237{
238    type Output = T;
239
240    fn index(&self, index: usize) -> &Self::Output {
241        &self.data[self.start + index]
242    }
243}
244
245impl<'a, V: IndexMut<usize, Output = T>, T, const N: usize, const M: usize> IndexMut<usize>
246    for RowVectorViewMut<'a, V, T, N, M>
247{
248    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
249        &mut self.data[self.start + index]
250    }
251}
252
253impl<V: Index<usize, Output = T>, T, const N: usize, const M: usize> Index<(usize, usize)>
254    for RowVectorViewMut<'_, V, T, N, M>
255{
256    type Output = T;
257
258    fn index(&self, index: (usize, usize)) -> &Self::Output {
259        if index.0 != 0 {
260            panic!("Index out of bounds");
261        }
262        &self.data[self.start + index.1]
263    }
264}
265
266impl<V: IndexMut<usize, Output = T>, T, const N: usize, const M: usize> IndexMut<(usize, usize)>
267    for RowVectorViewMut<'_, V, T, N, M>
268{
269    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
270        if index.0 != 0 {
271            panic!("Index out of bounds");
272        }
273        &mut self.data[self.start + index.1]
274    }
275}