faer/row/rowref.rs

1use super::*;
2use crate::utils::bound::{Array, Dim, Partition};
3use crate::{ContiguousFwd, Idx, IdxInc};
4use equator::{assert, debug_assert};
5use faer_traits::Real;
6
7/// see [`super::RowRef`]
8pub struct Ref<'a, T, Cols = usize, CStride = isize> {
9	pub(crate) trans: ColRef<'a, T, Cols, CStride>,
10}
11
12impl<T, Rows: Copy, CStride: Copy> Copy for Ref<'_, T, Rows, CStride> {}
13impl<T, Rows: Copy, CStride: Copy> Clone for Ref<'_, T, Rows, CStride> {
14	#[inline]
15	fn clone(&self) -> Self {
16		*self
17	}
18}
19
20impl<'short, T, Rows: Copy, CStride: Copy> Reborrow<'short> for Ref<'_, T, Rows, CStride> {
21	type Target = Ref<'short, T, Rows, CStride>;
22
23	#[inline]
24	fn rb(&'short self) -> Self::Target {
25		*self
26	}
27}
28impl<'short, T, Rows: Copy, CStride: Copy> ReborrowMut<'short> for Ref<'_, T, Rows, CStride> {
29	type Target = Ref<'short, T, Rows, CStride>;
30
31	#[inline]
32	fn rb_mut(&'short mut self) -> Self::Target {
33		*self
34	}
35}
36impl<'a, T, Rows: Copy, CStride: Copy> IntoConst for Ref<'a, T, Rows, CStride> {
37	type Target = Ref<'a, T, Rows, CStride>;
38
39	#[inline]
40	fn into_const(self) -> Self::Target {
41		self
42	}
43}
44
45unsafe impl<T: Sync, Rows: Sync, CStride: Sync> Sync for Ref<'_, T, Rows, CStride> {}
46unsafe impl<T: Sync, Rows: Send, CStride: Send> Send for Ref<'_, T, Rows, CStride> {}
47
48impl<'a, T> RowRef<'a, T> {
49	/// creates a row view over the given element
50	#[inline]
51	pub fn from_ref(value: &'a T) -> Self {
52		unsafe { RowRef::from_raw_parts(value as *const T, 1, 1) }
53	}
54
55	/// creates a `RowRef` from slice views over the row vector data, the result has the same
56	/// number of columns as the length of the input slice
57	#[inline]
58	pub fn from_slice(slice: &'a [T]) -> Self {
59		let len = slice.len();
60		unsafe { Self::from_raw_parts(slice.as_ptr(), len, 1) }
61	}
62}
63
64impl<'a, T, Cols: Shape, CStride: Stride> RowRef<'a, T, Cols, CStride> {
65	/// creates a `RowRef` from pointers to the column vector data, number of rows, and row stride
66	///
67	/// # safety
68	/// this function has the same safety requirements as
69	/// [`MatRef::from_raw_parts(ptr, 1, ncols, 0, col_stride)`]
70	#[inline(always)]
71	#[track_caller]
72	pub const unsafe fn from_raw_parts(ptr: *const T, ncols: Cols, col_stride: CStride) -> Self {
73		Self {
74			0: Ref {
75				trans: ColRef::from_raw_parts(ptr, ncols, col_stride),
76			},
77		}
78	}
79
80	/// returns a pointer to the row data
81	#[inline(always)]
82	pub fn as_ptr(&self) -> *const T {
83		self.trans.as_ptr()
84	}
85
86	/// returns the number of rows of the row (always 1)
87	#[inline(always)]
88	pub fn nrows(&self) -> usize {
89		1
90	}
91
92	/// returns the number of columns of the row
93	#[inline(always)]
94	pub fn ncols(&self) -> Cols {
95		self.trans.nrows()
96	}
97
98	/// returns the number of rows and columns of the row
99	#[inline(always)]
100	pub fn shape(&self) -> (usize, Cols) {
101		(self.nrows(), self.ncols())
102	}
103
104	/// returns the column stride of the row
105	#[inline(always)]
106	pub fn col_stride(&self) -> CStride {
107		self.trans.row_stride()
108	}
109
110	/// returns a raw pointer to the element at the given index
111	#[inline(always)]
112	pub fn ptr_at(&self, col: IdxInc<Cols>) -> *const T {
113		self.trans.ptr_at(col)
114	}
115
116	/// returns a raw pointer to the element at the given index, assuming the provided index
117	/// is within the row bounds
118	///
119	/// # safety
120	/// the behavior is undefined if any of the following conditions are violated:
121	/// * `col < self.ncols()`
122	#[inline(always)]
123	#[track_caller]
124	pub unsafe fn ptr_inbounds_at(&self, col: Idx<Cols>) -> *const T {
125		debug_assert!(all(col < self.ncols()));
126		self.trans.ptr_inbounds_at(col)
127	}
128
129	/// splits the row vertically at the given column into two parts and returns an array of
130	/// each subrow, in the following order:
131	/// * left
132	/// * right
133	///
134	/// # panics
135	/// the function panics if the following condition is violated:
136	/// * `col <= self.ncols()`
137	#[inline]
138	#[track_caller]
139	pub fn split_at_col(self, col: IdxInc<Cols>) -> (RowRef<'a, T, usize, CStride>, RowRef<'a, T, usize, CStride>) {
140		assert!(all(col <= self.ncols()));
141		let rs = self.col_stride();
142
143		let top = self.as_ptr();
144		let bot = self.ptr_at(col);
145		unsafe {
146			(
147				RowRef::from_raw_parts(top, col.unbound(), rs),
148				RowRef::from_raw_parts(bot, self.ncols().unbound() - col.unbound(), rs),
149			)
150		}
151	}
152
153	/// returns a view over the transpose of `self`
154	#[inline(always)]
155	pub fn transpose(self) -> ColRef<'a, T, Cols, CStride> {
156		self.trans
157	}
158
159	/// returns a view over the conjugate of `self`
160	#[inline(always)]
161	pub fn conjugate(self) -> RowRef<'a, T::Conj, Cols, CStride>
162	where
163		T: Conjugate,
164	{
165		RowRef {
166			0: Ref {
167				trans: self.trans.conjugate(),
168			},
169		}
170	}
171
172	/// returns an unconjugated view over `self`
173	#[inline(always)]
174	pub fn canonical(self) -> RowRef<'a, T::Canonical, Cols, CStride>
175	where
176		T: Conjugate,
177	{
178		RowRef {
179			0: Ref {
180				trans: self.trans.canonical(),
181			},
182		}
183	}
184
185	/// returns a view over the conjugate transpose of `self`
186	#[inline(always)]
187	pub fn adjoint(self) -> ColRef<'a, T::Conj, Cols, CStride>
188	where
189		T: Conjugate,
190	{
191		self.conjugate().transpose()
192	}
193
194	#[inline(always)]
195	#[track_caller]
196	pub(crate) fn at(self, col: Idx<Cols>) -> &'a T {
197		assert!(all(col < self.ncols()));
198		unsafe { self.at_unchecked(col) }
199	}
200
201	#[inline(always)]
202	#[track_caller]
203	pub(crate) unsafe fn at_unchecked(self, col: Idx<Cols>) -> &'a T {
204		&*self.ptr_inbounds_at(col)
205	}
206
207	/// returns a reference to the element at the given index, or a subrow if
208	/// `col` is a range, with bound checks
209	///
210	/// # panics
211	/// the function panics if any of the following conditions are violated:
212	/// * `col` must be contained in `[0, self.ncols())`
213	#[track_caller]
214	#[inline(always)]
215	pub fn get<ColRange>(self, col: ColRange) -> <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
216	where
217		RowRef<'a, T, Cols, CStride>: RowIndex<ColRange>,
218	{
219		<RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::get(self, col)
220	}
221
222	/// returns a reference to the element at the given index, or a subrow if
223	/// `col` is a range, without bound checks
224	///
225	/// # panics
226	/// the behavior is undefined if any of the following conditions are violated:
227	/// * `col` must be contained in `[0, self.ncols())`
228	#[track_caller]
229	#[inline(always)]
230	pub unsafe fn get_unchecked<ColRange>(self, col: ColRange) -> <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
231	where
232		RowRef<'a, T, Cols, CStride>: RowIndex<ColRange>,
233	{
234		unsafe { <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::get_unchecked(self, col) }
235	}
236
237	/// returns a view over the `self`, with the columns in reversed order
238	#[inline]
239	pub fn reverse_cols(self) -> RowRef<'a, T, Cols, CStride::Rev> {
240		RowRef {
241			0: Ref {
242				trans: self.trans.reverse_rows(),
243			},
244		}
245	}
246
247	/// returns a view over the subrow starting at column `col_start`, and with number of
248	/// columns `ncols`
249	///
250	/// # panics
251	/// the function panics if any of the following conditions are violated:
252	/// * `col_start <= self.ncols()`
253	/// * `ncols <= self.ncols() - col_start`
254	#[inline]
255	pub fn subcols<V: Shape>(self, col_start: IdxInc<Cols>, ncols: V) -> RowRef<'a, T, V, CStride> {
256		assert!(all(col_start <= self.ncols()));
257		{
258			let ncols = ncols.unbound();
259			let full_ncols = self.ncols().unbound();
260			let col_start = col_start.unbound();
261			assert!(all(ncols <= full_ncols - col_start));
262		}
263		let cs = self.col_stride();
264		unsafe { RowRef::from_raw_parts(self.ptr_at(col_start), ncols, cs) }
265	}
266
267	/// returns the input row with the given column shape after checking that it matches the
268	/// current column shape
269	#[inline]
270	#[track_caller]
271	pub fn as_col_shape<V: Shape>(self, ncols: V) -> RowRef<'a, T, V, CStride> {
272		assert!(all(self.ncols().unbound() == ncols.unbound()));
273		unsafe { RowRef::from_raw_parts(self.as_ptr(), ncols, self.col_stride()) }
274	}
275
276	/// returns the input row with dynamic column shape
277	#[inline]
278	pub fn as_dyn_cols(self) -> RowRef<'a, T, usize, CStride> {
279		unsafe { RowRef::from_raw_parts(self.as_ptr(), self.ncols().unbound(), self.col_stride()) }
280	}
281
282	/// returns the input row with dynamic stride
283	#[inline]
284	pub fn as_dyn_stride(self) -> RowRef<'a, T, Cols, isize> {
285		unsafe { RowRef::from_raw_parts(self.as_ptr(), self.ncols(), self.col_stride().element_stride()) }
286	}
287
288	/// returns an iterator over the elements of the row
289	#[inline]
290	pub fn iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a T>
291	where
292		Cols: 'a,
293	{
294		self.trans.iter()
295	}
296
297	/// returns a parallel iterator over the elements of the row
298	#[inline]
299	#[cfg(feature = "rayon")]
300	pub fn par_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a T>
301	where
302		T: Sync,
303		Cols: 'a,
304	{
305		self.trans.par_iter()
306	}
307
308	/// returns a parallel iterator that provides exactly `count` successive chunks of the elements
309	/// of this row
310	///
311	/// only available with the `rayon` feature
312	#[inline]
313	#[track_caller]
314	#[cfg(feature = "rayon")]
315	pub fn par_partition(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = RowRef<'a, T, usize, CStride>>
316	where
317		T: Sync,
318		Cols: 'a,
319	{
320		use rayon::prelude::*;
321		self.transpose().par_partition(count).map(ColRef::transpose)
322	}
323
324	/// returns a view over the row with a static column stride equal to `+1`, or `None` otherwise
325	#[inline]
326	pub fn try_as_row_major(self) -> Option<RowRef<'a, T, Cols, ContiguousFwd>> {
327		if self.col_stride().element_stride() == 1 {
328			Some(unsafe { RowRef::from_raw_parts(self.as_ptr(), self.ncols(), ContiguousFwd) })
329		} else {
330			None
331		}
332	}
333
334	#[inline(always)]
335	#[doc(hidden)]
336	pub unsafe fn const_cast(self) -> RowMut<'a, T, Cols, CStride> {
337		RowMut {
338			0: Mut {
339				trans: self.trans.const_cast(),
340			},
341		}
342	}
343
344	/// returns a matrix view over `self`
345	#[inline]
346	pub fn as_mat(self) -> MatRef<'a, T, usize, Cols, isize, CStride> {
347		self.transpose().as_mat().transpose()
348	}
349
350	/// interprets the row as a diagonal matrix
351	#[inline]
352	pub fn as_diagonal(self) -> DiagRef<'a, T, Cols, CStride> {
353		DiagRef {
354			0: crate::diag::Ref { inner: self.trans },
355		}
356	}
357
358	#[inline]
359	pub(crate) fn __at(self, i: Idx<Cols>) -> &'a T {
360		self.at(i)
361	}
362}
363
364impl<T, Cols: Shape, CStride: Stride, Inner: for<'short> Reborrow<'short, Target = Ref<'short, T, Cols, CStride>>> generic::Row<Inner> {
365	/// returns a view over `self`
366	#[inline]
367	pub fn as_ref(&self) -> RowRef<'_, T, Cols, CStride> {
368		self.rb()
369	}
370
371	/// returns the maximum norm of `self`
372	#[inline]
373	pub fn norm_max(&self) -> Real<T>
374	where
375		T: Conjugate,
376	{
377		self.rb().as_mat().norm_max()
378	}
379
380	/// returns the l2 norm of `self`
381	#[inline]
382	pub fn norm_l2(&self) -> Real<T>
383	where
384		T: Conjugate,
385	{
386		self.rb().as_mat().norm_l2()
387	}
388
389	/// returns the squared l2 norm of `self`
390	#[inline]
391	pub fn squared_norm_l2(&self) -> Real<T>
392	where
393		T: Conjugate,
394	{
395		self.rb().as_mat().squared_norm_l2()
396	}
397
398	/// returns the l1 norm of `self`
399	#[inline]
400	pub fn norm_l1(&self) -> Real<T>
401	where
402		T: Conjugate,
403	{
404		self.rb().as_mat().norm_l1()
405	}
406
407	/// returns the sum of the elements of `self`
408	#[inline]
409	pub fn sum(&self) -> T::Canonical
410	where
411		T: Conjugate,
412	{
413		self.rb().as_mat().sum()
414	}
415
416	/// see [`Mat::kron`]
417	#[inline]
418	pub fn kron(&self, rhs: impl AsMatRef<T: Conjugate<Canonical = T::Canonical>>) -> Mat<T::Canonical>
419	where
420		T: Conjugate,
421	{
422		fn imp<T: ComplexField>(lhs: MatRef<impl Conjugate<Canonical = T>>, rhs: MatRef<impl Conjugate<Canonical = T>>) -> Mat<T> {
423			let mut out = Mat::zeros(lhs.nrows() * rhs.nrows(), lhs.ncols() * rhs.ncols());
424			linalg::kron::kron(out.rb_mut(), lhs, rhs);
425			out
426		}
427
428		imp(self.rb().as_mat().as_dyn().as_dyn_stride(), rhs.as_mat_ref().as_dyn().as_dyn_stride())
429	}
430
431	/// returns `true` if all of the elements of `self` are finite.
432	/// otherwise returns `false`.
433	#[inline]
434	pub fn is_all_finite(&self) -> bool
435	where
436		T: Conjugate,
437	{
438		self.rb().transpose().is_all_finite()
439	}
440
441	/// returns `true` if any of the elements of `self` is `NaN`.
442	/// otherwise returns `false`.
443	#[inline]
444	pub fn has_nan(&self) -> bool
445	where
446		T: Conjugate,
447	{
448		self.rb().transpose().has_nan()
449	}
450
451	/// returns a newly allocated row holding the cloned values of `self`
452	#[inline]
453	pub fn cloned(&self) -> Row<T, Cols>
454	where
455		T: Clone,
456	{
457		self.rb().transpose().cloned().into_transpose()
458	}
459
460	/// returns a newly allocated row holding the (possibly conjugated) values of `self`
461	#[inline]
462	pub fn to_owned(&self) -> Row<T::Canonical, Cols>
463	where
464		T: Conjugate,
465	{
466		self.rb().transpose().to_owned().into_transpose()
467	}
468}
469
470impl<'a, T, Rows: Shape> RowRef<'a, T, Rows, ContiguousFwd> {
471	/// returns a reference over the elements as a slice
472	#[inline]
473	pub fn as_slice(self) -> &'a [T] {
474		self.transpose().as_slice()
475	}
476}
477
478impl<'a, 'ROWS, T> RowRef<'a, T, Dim<'ROWS>, ContiguousFwd> {
479	/// returns a reference over the elements as a lifetime-bound slice
480	#[inline]
481	pub fn as_array(self) -> &'a Array<'ROWS, T> {
482		self.transpose().as_array()
483	}
484}
485
486impl<'COLS, 'a, T, CStride: Stride> RowRef<'a, T, Dim<'COLS>, CStride> {
487	#[doc(hidden)]
488	#[inline]
489	pub fn split_cols_with<'LEFT, 'RIGHT>(
490		self,
491		col: Partition<'LEFT, 'RIGHT, 'COLS>,
492	) -> (RowRef<'a, T, Dim<'LEFT>, CStride>, RowRef<'a, T, Dim<'RIGHT>, CStride>) {
493		let (a, b) = self.split_at_col(col.midpoint());
494		(a.as_col_shape(col.head), b.as_col_shape(col.tail))
495	}
496}
497
498impl<T: core::fmt::Debug, Cols: Shape, CStride: Stride> core::fmt::Debug for Ref<'_, T, Cols, CStride> {
499	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
500		fn imp<T: core::fmt::Debug>(f: &mut core::fmt::Formatter<'_>, this: RowRef<'_, T, Dim<'_>>) -> core::fmt::Result {
501			f.debug_list()
502				.entries(this.ncols().indices().map(|j| crate::hacks::hijack_debug(this.at(j))))
503				.finish()
504		}
505
506		let this = generic::Row::from_inner_ref(self);
507
508		with_dim!(N, this.ncols().unbound());
509		imp(f, this.as_col_shape(N).as_dyn_stride())
510	}
511}
512
513impl<'a, T> RowRef<'a, T, usize, isize>
514where
515	T: RealField,
516{
517	/// Returns the maximum element in the row, or `None` if the row is empty
518	pub(crate) fn internal_max(self) -> Option<T> {
519		if self.nrows().unbound() == 0 || self.ncols() == 0 {
520			return None;
521		}
522
523		let mut max_val = self.get(0);
524
525		self.iter().for_each(|val| {
526			if val > max_val {
527				max_val = val;
528			}
529		});
530
531		Some((*max_val).clone())
532	}
533
534	/// Returns the minimum element in the row, or `None` if the row is empty
535	pub(crate) fn internal_min(self) -> Option<T> {
536		if self.nrows().unbound() == 0 || self.ncols() == 0 {
537			return None;
538		}
539
540		let mut min_val = self.get(0);
541
542		self.iter().for_each(|val| {
543			if val < min_val {
544				min_val = val;
545			}
546		});
547
548		Some((*min_val).clone())
549	}
550}
551
552impl<'a, T, Cols: Shape, CStride: Stride> RowRef<'a, T, Cols, CStride>
553where
554	T: RealField,
555{
556	/// Returns the maximum element in the row, or `None` if the row is empty
557	pub fn max(&self) -> Option<T> {
558		self.as_dyn_cols().as_dyn_stride().internal_max()
559	}
560
561	/// Returns the minimum element in the row, or `None` if the row is empty
562	pub fn min(&self) -> Option<T> {
563		self.as_dyn_cols().as_dyn_stride().internal_min()
564	}
565}
566
#[cfg(test)]
mod tests {
	use crate::Row;

	#[test]
	fn test_row_min() {
		// 1.0, 2.0, ..., 5.0
		let row: Row<f64> = Row::from_fn(5, |x| (x + 1) as f64);
		let rowref = row.as_ref();
		assert_eq!(rowref.min(), Some(1.0));
		// reversed view exercises the non-unit (negative) stride path
		assert_eq!(rowref.reverse_cols().min(), Some(1.0));

		// extremum in the middle of the row
		let data = [3.0, -1.0, 4.0, 1.0, -5.0, 9.0];
		let mixed: Row<f64> = Row::from_fn(data.len(), |j| data[j]);
		assert_eq!(mixed.as_ref().min(), Some(-5.0));

		let single: Row<f64> = Row::from_fn(1, |_| 7.0);
		assert_eq!(single.as_ref().min(), Some(7.0));

		let empty: Row<f64> = Row::from_fn(0, |_| 0.0);
		let emptyref = empty.as_ref();
		assert_eq!(emptyref.min(), None);
	}

	#[test]
	fn test_row_max() {
		// 1.0, 2.0, ..., 5.0
		let row: Row<f64> = Row::from_fn(5, |x| (x + 1) as f64);
		let rowref = row.as_ref();
		assert_eq!(rowref.max(), Some(5.0));
		// reversed view exercises the non-unit (negative) stride path
		assert_eq!(rowref.reverse_cols().max(), Some(5.0));

		// extremum in the middle of the row
		let data = [3.0, -1.0, 14.0, 1.0, -5.0, 9.0];
		let mixed: Row<f64> = Row::from_fn(data.len(), |j| data[j]);
		assert_eq!(mixed.as_ref().max(), Some(14.0));

		let single: Row<f64> = Row::from_fn(1, |_| 7.0);
		assert_eq!(single.as_ref().max(), Some(7.0));

		let empty: Row<f64> = Row::from_fn(0, |_| 0.0);
		let emptyref = empty.as_ref();
		assert_eq!(emptyref.max(), None);
	}
}