// faer/col/colref.rs

1use super::*;
2use crate::utils::bound::{Array, Dim, Partition};
3use crate::{ContiguousFwd, Idx, IdxInc};
4use core::marker::PhantomData;
5use core::ptr::NonNull;
6use equator::assert;
7use faer_traits::Real;
8use generativity::Guard;
9
10/// see [`super::ColRef`]
11pub struct Ref<'a, T, Rows = usize, RStride = isize> {
12	pub(super) imp: ColView<T, Rows, RStride>,
13	pub(super) __marker: PhantomData<&'a T>,
14}
15
16impl<T, Rows: Copy, RStride: Copy> Copy for Ref<'_, T, Rows, RStride> {}
17impl<T, Rows: Copy, RStride: Copy> Clone for Ref<'_, T, Rows, RStride> {
18	#[inline]
19	fn clone(&self) -> Self {
20		*self
21	}
22}
23
24impl<'short, T, Rows: Copy, RStride: Copy> Reborrow<'short> for Ref<'_, T, Rows, RStride> {
25	type Target = Ref<'short, T, Rows, RStride>;
26
27	#[inline]
28	fn rb(&'short self) -> Self::Target {
29		*self
30	}
31}
32impl<'short, T, Rows: Copy, RStride: Copy> ReborrowMut<'short> for Ref<'_, T, Rows, RStride> {
33	type Target = Ref<'short, T, Rows, RStride>;
34
35	#[inline]
36	fn rb_mut(&'short mut self) -> Self::Target {
37		*self
38	}
39}
40impl<'a, T, Rows: Copy, RStride: Copy> IntoConst for Ref<'a, T, Rows, RStride> {
41	type Target = Ref<'a, T, Rows, RStride>;
42
43	#[inline]
44	fn into_const(self) -> Self::Target {
45		self
46	}
47}
48
49unsafe impl<T: Sync, Rows: Sync, RStride: Sync> Sync for Ref<'_, T, Rows, RStride> {}
50unsafe impl<T: Sync, Rows: Send, RStride: Send> Send for Ref<'_, T, Rows, RStride> {}
51
52impl<'a, T> ColRef<'a, T> {
53	/// creates a column view over the given element
54	#[inline]
55	pub fn from_ref(value: &'a T) -> Self {
56		unsafe { ColRef::from_raw_parts(value as *const T, 1, 1) }
57	}
58
59	/// creates a `ColRef` from slice views over the column vector data, the result has the same
60	/// number of rows as the length of the input slice
61	#[inline]
62	pub fn from_slice(slice: &'a [T]) -> Self {
63		let len = slice.len();
64		unsafe { Self::from_raw_parts(slice.as_ptr(), len, 1) }
65	}
66}
67
68impl<'a, T, Rows: Shape, RStride: Stride> ColRef<'a, T, Rows, RStride> {
69	/// creates a `ColRef` from pointers to the column vector data, number of rows, and row stride
70	///
71	/// # safety
72	/// this function has the same safety requirements as
73	/// [`MatRef::from_raw_parts(ptr, nrows, 1, row_stride, 0)`]
74	#[inline(always)]
75	#[track_caller]
76	pub const unsafe fn from_raw_parts(ptr: *const T, nrows: Rows, row_stride: RStride) -> Self {
77		Self {
78			0: Ref {
79				imp: ColView {
80					ptr: NonNull::new_unchecked(ptr as *mut T),
81					nrows,
82					row_stride,
83				},
84				__marker: PhantomData,
85			},
86		}
87	}
88
89	/// returns a pointer to the column data
90	#[inline(always)]
91	pub fn as_ptr(&self) -> *const T {
92		self.imp.ptr.as_ptr() as *const T
93	}
94
95	/// returns the number of rows of the column
96	#[inline(always)]
97	pub fn nrows(&self) -> Rows {
98		self.imp.nrows
99	}
100
101	/// returns the number of columns of the column (always `1`)
102	#[inline(always)]
103	pub fn ncols(&self) -> usize {
104		1
105	}
106
107	/// returns the number of rows and columns of the column
108	#[inline(always)]
109	pub fn shape(&self) -> (Rows, usize) {
110		(self.nrows(), self.ncols())
111	}
112
113	/// returns the row stride of the column, specified in number of elements, not in bytes
114	#[inline(always)]
115	pub fn row_stride(&self) -> RStride {
116		self.imp.row_stride
117	}
118
119	/// returns a raw pointer to the element at the given index
120	#[inline(always)]
121	pub fn ptr_at(&self, row: IdxInc<Rows>) -> *const T {
122		let ptr = self.as_ptr();
123
124		if row >= self.nrows() {
125			ptr
126		} else {
127			ptr.wrapping_offset(row.unbound() as isize * self.row_stride().element_stride())
128		}
129	}
130
131	/// returns a raw pointer to the element at the given index, assuming the provided index
132	/// is within the column bounds
133	///
134	/// # safety
135	/// the behavior is undefined if any of the following conditions are violated:
136	/// * `row < self.nrows()`
137	#[inline(always)]
138	#[track_caller]
139	pub unsafe fn ptr_inbounds_at(&self, row: Idx<Rows>) -> *const T {
140		self.as_ptr().offset(row.unbound() as isize * self.row_stride().element_stride())
141	}
142
143	/// splits the column horizontally at the given row into two parts and returns an array of
144	/// each submatrix, in the following order:
145	/// * top
146	/// * bottom
147	///
148	/// # panics
149	/// the function panics if the following condition is violated:
150	/// * `row <= self.nrows()`
151	#[inline]
152	#[track_caller]
153	pub fn split_at_row(self, row: IdxInc<Rows>) -> (ColRef<'a, T, usize, RStride>, ColRef<'a, T, usize, RStride>) {
154		assert!(all(row <= self.nrows()));
155		let rs = self.row_stride();
156
157		let top = self.as_ptr();
158		let bot = self.ptr_at(row);
159		unsafe {
160			(
161				ColRef::from_raw_parts(top, row.unbound(), rs),
162				ColRef::from_raw_parts(bot, self.nrows().unbound() - row.unbound(), rs),
163			)
164		}
165	}
166
167	/// returns a view over the transpose of `self`
168	#[inline(always)]
169	pub fn transpose(self) -> RowRef<'a, T, Rows, RStride> {
170		RowRef {
171			0: crate::row::Ref { trans: self },
172		}
173	}
174
175	/// returns a view over the conjugate of `self`
176	#[inline(always)]
177	pub fn conjugate(self) -> ColRef<'a, T::Conj, Rows, RStride>
178	where
179		T: Conjugate,
180	{
181		unsafe { ColRef::from_raw_parts(self.as_ptr() as *const T::Conj, self.nrows(), self.row_stride()) }
182	}
183
184	/// returns an unconjugated view over `self`
185	#[inline(always)]
186	pub fn canonical(self) -> ColRef<'a, T::Canonical, Rows, RStride>
187	where
188		T: Conjugate,
189	{
190		unsafe { ColRef::from_raw_parts(self.as_ptr() as *const T::Canonical, self.nrows(), self.row_stride()) }
191	}
192
193	/// returns a view over the conjugate transpose of `self`
194	#[inline(always)]
195	pub fn adjoint(self) -> RowRef<'a, T::Conj, Rows, RStride>
196	where
197		T: Conjugate,
198	{
199		self.conjugate().transpose()
200	}
201
202	#[inline(always)]
203	#[track_caller]
204	pub(crate) fn at(self, row: Idx<Rows>) -> &'a T {
205		assert!(all(row < self.nrows()));
206		unsafe { self.at_unchecked(row) }
207	}
208
209	#[inline(always)]
210	#[track_caller]
211	pub(crate) unsafe fn at_unchecked(self, row: Idx<Rows>) -> &'a T {
212		&*self.ptr_inbounds_at(row)
213	}
214
215	/// returns a reference to the element at the given index, or a subcolumn if `row` is a range
216	///
217	/// # panics
218	/// the function panics if any of the following conditions are violated:
219	/// * `row` must be contained in `[0, self.nrows())`
220	#[track_caller]
221	#[inline(always)]
222	pub fn get<RowRange>(self, row: RowRange) -> <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
223	where
224		ColRef<'a, T, Rows, RStride>: ColIndex<RowRange>,
225	{
226		<ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::get(self, row)
227	}
228
229	/// returns a reference to the element at the given index, or a subcolumn if `row` is a range,
230	/// without bound checks
231	///
232	/// # safety
233	/// the behavior is undefined if any of the following conditions are violated:
234	/// * `row` must be contained in `[0, self.nrows())`
235	#[track_caller]
236	#[inline(always)]
237	pub unsafe fn get_unchecked<RowRange>(self, row: RowRange) -> <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
238	where
239		ColRef<'a, T, Rows, RStride>: ColIndex<RowRange>,
240	{
241		unsafe { <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::get_unchecked(self, row) }
242	}
243
244	/// returns a view over the `self`, with the rows in reversed order
245	#[inline]
246	pub fn reverse_rows(self) -> ColRef<'a, T, Rows, RStride::Rev> {
247		let row = unsafe { IdxInc::<Rows>::new_unbound(self.nrows().unbound().saturating_sub(1)) };
248		let ptr = self.ptr_at(row);
249		unsafe { ColRef::from_raw_parts(ptr, self.nrows(), self.row_stride().rev()) }
250	}
251
252	/// returns a view over the column starting at row `row_start`, and with number of rows
253	/// `nrows`
254	///
255	/// # panics
256	/// the function panics if any of the following conditions are violated:
257	/// * `row_start <= self.nrows()`
258	/// * `nrows <= self.nrows() - row_start`
259	#[inline]
260	#[track_caller]
261	pub fn subrows<V: Shape>(self, row_start: IdxInc<Rows>, nrows: V) -> ColRef<'a, T, V, RStride> {
262		assert!(all(row_start <= self.nrows()));
263		{
264			let nrows = nrows.unbound();
265			let full_nrows = self.nrows().unbound();
266			let row_start = row_start.unbound();
267			assert!(all(nrows <= full_nrows - row_start));
268		}
269		let rs = self.row_stride();
270
271		unsafe { ColRef::from_raw_parts(self.ptr_at(row_start), nrows, rs) }
272	}
273
274	/// returns the input column with the given row shape after checking that it matches the
275	/// current row shape
276	#[inline]
277	#[track_caller]
278	pub fn as_row_shape<V: Shape>(self, nrows: V) -> ColRef<'a, T, V, RStride> {
279		assert!(all(self.nrows().unbound() == nrows.unbound()));
280		unsafe { ColRef::from_raw_parts(self.as_ptr(), nrows, self.row_stride()) }
281	}
282
283	/// returns the input column with dynamic row shape
284	#[inline]
285	pub fn as_dyn_rows(self) -> ColRef<'a, T, usize, RStride> {
286		unsafe { ColRef::from_raw_parts(self.as_ptr(), self.nrows().unbound(), self.row_stride()) }
287	}
288
289	/// returns the input column with dynamic stride
290	#[inline]
291	pub fn as_dyn_stride(self) -> ColRef<'a, T, Rows, isize> {
292		unsafe { ColRef::from_raw_parts(self.as_ptr(), self.nrows(), self.row_stride().element_stride()) }
293	}
294
295	/// returns an iterator over the elements of the column
296	#[inline]
297	pub fn iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a T>
298	where
299		Rows: 'a,
300	{
301		Rows::indices(Rows::start(), self.nrows().end()).map(move |j| unsafe { self.at_unchecked(j) })
302	}
303
304	/// returns a parallel iterator over the elements of the column
305	#[inline]
306	#[cfg(feature = "rayon")]
307	pub fn par_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a T>
308	where
309		T: Sync,
310		Rows: 'a,
311	{
312		use rayon::prelude::*;
313		(0..self.nrows().unbound())
314			.into_par_iter()
315			.map(move |j| unsafe { self.at_unchecked(Idx::<Rows>::new_unbound(j)) })
316	}
317
318	/// returns a parallel iterator that provides exactly `count` successive chunks of the elements
319	/// of this column
320	///
321	/// only available with the `rayon` feature
322	#[inline]
323	#[track_caller]
324	#[cfg(feature = "rayon")]
325	pub fn par_partition(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColRef<'a, T, usize, RStride>>
326	where
327		T: Sync,
328		Rows: 'a,
329	{
330		use rayon::prelude::*;
331
332		let this = self.as_dyn_rows();
333
334		assert!(count > 0);
335		(0..count).into_par_iter().map(move |chunk_idx| {
336			let (start, len) = crate::utils::thread::par_split_indices(this.nrows(), chunk_idx, count);
337			this.subrows(start, len)
338		})
339	}
340
341	/// returns a view over the column with a static row stride equal to `+1`, or `None` otherwise
342	#[inline]
343	pub fn try_as_col_major(self) -> Option<ColRef<'a, T, Rows, ContiguousFwd>> {
344		if self.row_stride().element_stride() == 1 {
345			Some(unsafe { ColRef::from_raw_parts(self.as_ptr(), self.nrows(), ContiguousFwd) })
346		} else {
347			None
348		}
349	}
350
351	#[inline(always)]
352	#[doc(hidden)]
353	pub unsafe fn const_cast(self) -> ColMut<'a, T, Rows, RStride> {
354		ColMut::from_raw_parts_mut(self.as_ptr() as *mut T, self.nrows(), self.row_stride())
355	}
356
357	/// returns a matrix view over `self`
358	#[inline]
359	pub fn as_mat(self) -> MatRef<'a, T, Rows, usize, RStride, isize> {
360		unsafe { MatRef::from_raw_parts(self.as_ptr(), self.nrows(), self.ncols(), self.row_stride(), 0) }
361	}
362
363	#[inline]
364	#[doc(hidden)]
365	pub fn bind_r<'N>(self, row: Guard<'N>) -> ColRef<'a, T, Dim<'N>, RStride> {
366		unsafe { ColRef::from_raw_parts(self.as_ptr(), self.nrows().bind(row), self.row_stride()) }
367	}
368
369	#[inline(always)]
370	#[track_caller]
371	pub(crate) fn read(&self, row: Idx<Rows>) -> T
372	where
373		T: Clone,
374	{
375		self.at(row).clone()
376	}
377
378	#[inline]
379	#[track_caller]
380	pub(crate) fn __at(self, i: Idx<Rows>) -> &'a T {
381		self.at(i)
382	}
383
384	/// interprets the column as a diagonal matrix
385	#[inline]
386	pub fn as_diagonal(self) -> DiagRef<'a, T, Rows, RStride> {
387		DiagRef {
388			0: crate::diag::Ref { inner: self },
389		}
390	}
391}
392
393impl<T, Rows: Shape, RStride: Stride, Inner: for<'short> Reborrow<'short, Target = Ref<'short, T, Rows, RStride>>> generic::Col<Inner> {
394	/// returns a newly allocated column holding the cloned values of `self`
395	#[inline]
396	pub fn cloned(&self) -> Col<T, Rows>
397	where
398		T: Clone,
399	{
400		fn imp<'M, T: Clone, RStride: Stride>(this: ColRef<'_, T, Dim<'M>, RStride>) -> Col<T, Dim<'M>> {
401			Col::from_fn(this.nrows(), |i| this.at(i).clone())
402		}
403
404		let this = self.rb();
405		with_dim!(M, this.nrows().unbound());
406		imp(this.as_row_shape(M)).into_row_shape(this.nrows())
407	}
408
409	/// returns a newly allocated column holding the (possibly conjugated) values of `self`
410	#[inline]
411	pub fn to_owned(&self) -> Col<T::Canonical, Rows>
412	where
413		T: Conjugate,
414	{
415		fn imp<'M, T, RStride: Stride>(this: ColRef<'_, T, Dim<'M>, RStride>) -> Col<T::Canonical, Dim<'M>>
416		where
417			T: Conjugate,
418		{
419			Col::from_fn(this.nrows(), |i| Conj::apply::<T>(this.at(i)))
420		}
421
422		let this = self.rb();
423		with_dim!(M, this.nrows().unbound());
424		imp(this.as_row_shape(M)).into_row_shape(this.nrows())
425	}
426
427	/// returns the maximum norm of `self`
428	#[inline]
429	pub fn norm_max(&self) -> Real<T>
430	where
431		T: Conjugate,
432	{
433		self.rb().as_mat().norm_max()
434	}
435
436	/// returns the l2 norm of `self`
437	#[inline]
438	pub fn norm_l2(&self) -> Real<T>
439	where
440		T: Conjugate,
441	{
442		self.rb().as_mat().norm_l2()
443	}
444
445	/// returns the squared l2 norm of `self`
446	#[inline]
447	pub fn squared_norm_l2(&self) -> Real<T>
448	where
449		T: Conjugate,
450	{
451		self.rb().as_mat().squared_norm_l2()
452	}
453
454	/// returns the l1 norm of `self`
455	#[inline]
456	pub fn norm_l1(&self) -> Real<T>
457	where
458		T: Conjugate,
459	{
460		self.rb().as_mat().norm_l1()
461	}
462
463	/// returns the sum of the elements of `self`
464	#[inline]
465	pub fn sum(&self) -> T::Canonical
466	where
467		T: Conjugate,
468	{
469		self.rb().as_mat().sum()
470	}
471
472	/// returns a view over `self`
473	#[inline]
474	pub fn as_ref(&self) -> ColRef<'_, T, Rows, RStride> {
475		self.rb()
476	}
477
478	/// see [`Mat::kron`]
479	#[inline]
480	pub fn kron(&self, rhs: impl AsMatRef<T: Conjugate<Canonical = T::Canonical>>) -> Mat<T::Canonical>
481	where
482		T: Conjugate,
483	{
484		fn imp<T: ComplexField>(lhs: MatRef<impl Conjugate<Canonical = T>>, rhs: MatRef<impl Conjugate<Canonical = T>>) -> Mat<T> {
485			let mut out = Mat::zeros(lhs.nrows() * rhs.nrows(), lhs.ncols() * rhs.ncols());
486			linalg::kron::kron(out.rb_mut(), lhs, rhs);
487			out
488		}
489
490		imp(self.rb().as_mat().as_dyn().as_dyn_stride(), rhs.as_mat_ref().as_dyn().as_dyn_stride())
491	}
492
493	/// returns `true` if all of the elements of `self` are finite.
494	/// otherwise returns `false`.
495	#[inline]
496	pub fn is_all_finite(&self) -> bool
497	where
498		T: Conjugate,
499	{
500		fn imp<T: ComplexField>(A: ColRef<'_, T>) -> bool {
501			with_dim!({
502				let M = A.nrows();
503			});
504
505			let A = A.as_row_shape(M);
506
507			for i in M.indices() {
508				if !is_finite(&A[i]) {
509					return false;
510				}
511			}
512
513			true
514		}
515
516		imp(self.rb().as_dyn_rows().as_dyn_stride().canonical())
517	}
518
519	/// returns `true` if any of the elements of `self` is `NaN`.
520	/// otherwise returns `false`.
521	#[inline]
522	pub fn has_nan(&self) -> bool
523	where
524		T: Conjugate,
525	{
526		fn imp<T: ComplexField>(A: ColRef<'_, T>) -> bool {
527			with_dim!({
528				let M = A.nrows();
529			});
530
531			let A = A.as_row_shape(M);
532
533			for i in M.indices() {
534				if is_nan(&A[i]) {
535					return true;
536				}
537			}
538
539			false
540		}
541
542		imp(self.rb().as_dyn_rows().as_dyn_stride().canonical())
543	}
544}
545
546impl<'a, T, Rows: Shape> ColRef<'a, T, Rows, ContiguousFwd> {
547	/// returns a reference over the elements as a slice
548	#[inline]
549	pub fn as_slice(self) -> &'a [T] {
550		unsafe { core::slice::from_raw_parts(self.as_ptr(), self.nrows().unbound()) }
551	}
552}
553
554impl<'a, 'ROWS, T> ColRef<'a, T, Dim<'ROWS>, ContiguousFwd> {
555	/// returns a reference over the elements as a lifetime-bound slice
556	#[inline]
557	pub fn as_array(self) -> &'a Array<'ROWS, T> {
558		unsafe { &*(self.as_slice() as *const [_] as *const Array<'ROWS, T>) }
559	}
560}
561
562impl<'ROWS, 'a, T, RStride: Stride> ColRef<'a, T, Dim<'ROWS>, RStride> {
563	#[doc(hidden)]
564	#[inline]
565	pub fn split_rows_with<'TOP, 'BOT>(
566		self,
567		row: Partition<'TOP, 'BOT, 'ROWS>,
568	) -> (ColRef<'a, T, Dim<'TOP>, RStride>, ColRef<'a, T, Dim<'BOT>, RStride>) {
569		let (a, b) = self.split_at_row(row.midpoint());
570		(a.as_row_shape(row.head), b.as_row_shape(row.tail))
571	}
572}
573
574impl<T: core::fmt::Debug, Rows: Shape, RStride: Stride> core::fmt::Debug for Ref<'_, T, Rows, RStride> {
575	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
576		generic::Col::from_inner_ref(self).transpose().fmt(f)
577	}
578}
579
580impl<'a, T> ColRef<'a, T, usize, isize>
581where
582	T: RealField,
583{
584	pub(crate) fn internal_max(self) -> Option<T> {
585		if self.nrows().unbound() == 0 || self.ncols() == 0 {
586			return None;
587		}
588
589		let mut max_val = self.get(0);
590
591		self.iter().for_each(|val| {
592			if val > max_val {
593				max_val = val;
594			}
595		});
596
597		Some((*max_val).clone())
598	}
599
600	/// Returns the minimum element in the column, or `None` if the column is empty
601	pub(crate) fn internal_min(self) -> Option<T> {
602		if self.nrows().unbound() == 0 || self.ncols() == 0 {
603			return None;
604		}
605
606		let mut min_val = self.get(0);
607
608		self.iter().for_each(|val| {
609			if val < min_val {
610				min_val = val;
611			}
612		});
613
614		Some((*min_val).clone())
615	}
616}
617
618impl<'a, T, Rows: Shape, RStride: Stride> ColRef<'a, T, Rows, RStride>
619where
620	T: RealField,
621{
622	/// Returns the maximum element in the column, or `None` if the column is empty
623	pub fn max(&self) -> Option<T> {
624		self.as_dyn_rows().as_dyn_stride().internal_max()
625	}
626
627	/// Returns the minimum element in the column, or `None` if the column is empty
628	pub fn min(&self) -> Option<T> {
629		self.as_dyn_rows().as_dyn_stride().internal_min()
630	}
631}
632
#[cfg(test)]
mod tests {
	use crate::Col;

	#[test]
	fn test_col_min() {
		let col: Col<f64> = Col::from_fn(5, |x| (x + 1) as f64);
		let colref = col.as_ref();
		assert_eq!(colref.min(), Some(1.0));

		let empty: Col<f64> = Col::from_fn(0, |_| 0.0);
		let emptyref = empty.as_ref();
		assert_eq!(emptyref.min(), None);
	}

	#[test]
	fn test_col_max() {
		let col: Col<f64> = Col::from_fn(5, |x| (x + 1) as f64);
		let colref = col.as_ref();
		assert_eq!(colref.max(), Some(5.0));

		let empty: Col<f64> = Col::from_fn(0, |_| 0.0);
		let emptyref = empty.as_ref();
		assert_eq!(emptyref.max(), None);
	}

	// the extremes sit in the middle of the data, so an implementation that only looks at
	// the first or last element would fail here
	#[test]
	fn test_col_min_max_unordered() {
		let values = [3.0, -2.5, 7.0, 0.0, 7.0, -2.5];
		let col: Col<f64> = Col::from_fn(values.len(), |i| values[i]);
		let colref = col.as_ref();
		assert_eq!(colref.min(), Some(-2.5));
		assert_eq!(colref.max(), Some(7.0));

		// a reversed view walks the data with a negative row stride
		let reversed = colref.reverse_rows();
		assert_eq!(reversed.min(), Some(-2.5));
		assert_eq!(reversed.max(), Some(7.0));
	}

	#[test]
	fn test_col_min_max_single() {
		let col: Col<f64> = Col::from_fn(1, |_| -4.0);
		let colref = col.as_ref();
		assert_eq!(colref.min(), Some(-4.0));
		assert_eq!(colref.max(), Some(-4.0));
	}
}