1use crate::row_vector::RowVector;
2use crate::row_vector_view::RowVectorView;
3use crate::traits::DotProduct;
4use crate::vector_view::VectorView;
5use crate::vector_view_mut::VectorViewMut;
6use num_traits::Float;
7use std::fmt;
8use std::marker::PhantomData;
9use std::ops::{Index, IndexMut};
10
11#[derive(Debug)]
13pub struct RowVectorViewMut<'a, V, T, const N: usize, const M: usize> {
14 data: &'a mut V,
15 start: usize,
16 _phantom: PhantomData<T>,
17}
18
19impl<'a, V, T, const N: usize, const M: usize> RowVectorViewMut<'a, V, T, N, M> {
20 pub(super) fn new(data: &'a mut V, start: usize) -> Self {
21 Self {
22 data,
23 start,
24 _phantom: PhantomData,
25 }
26 }
27
28 #[inline]
42 pub fn shape(&self) -> usize {
43 M
44 }
45
46 #[inline]
60 pub fn capacity(&self) -> usize {
61 M
62 }
63
64 #[inline]
78 pub fn rows(&self) -> usize {
79 1
80 }
81
82 #[inline]
96 pub fn cols(&self) -> usize {
97 M
98 }
99
100 pub fn t(&'a self) -> VectorView<'a, V, T, N, M> {
115 VectorView::new(self.data, self.start)
116 }
117
118 pub fn t_mut(&'a mut self) -> VectorViewMut<'a, V, T, N, M> {
134 VectorViewMut::new(self.data, self.start)
135 }
136}
137
138impl<'a, V: Index<usize, Output = T>, T: Float, const N: usize, const M: usize>
139 RowVectorViewMut<'a, V, T, N, M>
140{
141 pub fn magnitude(&self) -> T {
153 self.dot(self).sqrt()
154 }
155}
156
157impl<T: PartialEq, const N: usize, V: Index<usize, Output = T>, const A: usize>
163 PartialEq<RowVector<T, N>> for RowVectorViewMut<'_, V, T, A, N>
164{
165 fn eq(&self, other: &RowVector<T, N>) -> bool {
166 (0..N).all(|i| self[i] == other[i])
167 }
168}
169
170impl<
172 T: PartialEq,
173 V1: Index<usize, Output = T>,
174 V2: Index<usize, Output = T>,
175 const A1: usize,
176 const A2: usize,
177 const N: usize,
178 > PartialEq<RowVectorView<'_, V2, T, A2, N>> for RowVectorViewMut<'_, V1, T, A1, N>
179{
180 fn eq(&self, other: &RowVectorView<'_, V2, T, A2, N>) -> bool {
181 (0..N).all(|i| self[i] == other[i])
182 }
183}
184
185impl<
187 T: PartialEq,
188 V1: Index<usize, Output = T>,
189 V2: Index<usize, Output = T>,
190 const A1: usize,
191 const A2: usize,
192 const N: usize,
193 > PartialEq<RowVectorViewMut<'_, V2, T, A2, N>> for RowVectorViewMut<'_, V1, T, A1, N>
194{
195 fn eq(&self, other: &RowVectorViewMut<'_, V2, T, A2, N>) -> bool {
196 (0..N).all(|i| self[i] == other[i])
197 }
198}
199
200impl<'a, V: Index<usize, Output = T>, T: Eq, const N: usize, const M: usize> Eq
201 for RowVectorViewMut<'a, V, T, N, M>
202{
203}
204
205impl<'a, V: Index<usize, Output = T>, T: fmt::Display, const N: usize, const M: usize> fmt::Display
210 for RowVectorViewMut<'a, V, T, N, M>
211{
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 if f.alternate() {
214 write!(f, "RowVectorViewMut(")?;
215 }
216
217 write!(f, "[")?;
218 for i in 0..M - 1 {
219 write!(f, "{}, ", self[i])?;
220 }
221 write!(f, "{}]", self[M - 1])?;
222
223 if f.alternate() {
224 write!(f, ", dtype={})", std::any::type_name::<T>())?;
225 }
226
227 Ok(())
228 }
229}
230
231impl<'a, V: Index<usize, Output = T>, T, const N: usize, const M: usize> Index<usize>
236 for RowVectorViewMut<'a, V, T, N, M>
237{
238 type Output = T;
239
240 fn index(&self, index: usize) -> &Self::Output {
241 &self.data[self.start + index]
242 }
243}
244
245impl<'a, V: IndexMut<usize, Output = T>, T, const N: usize, const M: usize> IndexMut<usize>
246 for RowVectorViewMut<'a, V, T, N, M>
247{
248 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
249 &mut self.data[self.start + index]
250 }
251}
252
253impl<V: Index<usize, Output = T>, T, const N: usize, const M: usize> Index<(usize, usize)>
254 for RowVectorViewMut<'_, V, T, N, M>
255{
256 type Output = T;
257
258 fn index(&self, index: (usize, usize)) -> &Self::Output {
259 if index.0 != 0 {
260 panic!("Index out of bounds");
261 }
262 &self.data[self.start + index.1]
263 }
264}
265
266impl<V: IndexMut<usize, Output = T>, T, const N: usize, const M: usize> IndexMut<(usize, usize)>
267 for RowVectorViewMut<'_, V, T, N, M>
268{
269 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
270 if index.0 != 0 {
271 panic!("Index out of bounds");
272 }
273 &mut self.data[self.start + index.1]
274 }
275}