1use super::*;
2use crate::utils::bound::{Array, Dim, Partition};
3use crate::{ContiguousFwd, Idx, IdxInc};
4use core::marker::PhantomData;
5use core::ptr::NonNull;
6use equator::assert;
7use faer_traits::Real;
8use generativity::Guard;
9
/// Borrowed, immutable view over a column of `T` values with `Rows` rows and row stride
/// `RStride`.
///
/// The view does not own its data: `imp` stores the base pointer, the row count and the row
/// stride, while `__marker` ties the view to the lifetime `'a` as if it held a `&'a T`.
pub struct Ref<'a, T, Rows = usize, RStride = isize> {
    /// Raw pointer, row count and row stride of the viewed column.
    pub(super) imp: ColView<T, Rows, RStride>,
    /// Borrows `T` for `'a` (and makes the type covariant in `'a` like a shared reference).
    pub(super) __marker: PhantomData<&'a T>,
}
15
16impl<T, Rows: Copy, RStride: Copy> Copy for Ref<'_, T, Rows, RStride> {}
17impl<T, Rows: Copy, RStride: Copy> Clone for Ref<'_, T, Rows, RStride> {
18 #[inline]
19 fn clone(&self) -> Self {
20 *self
21 }
22}
23
24impl<'short, T, Rows: Copy, RStride: Copy> Reborrow<'short> for Ref<'_, T, Rows, RStride> {
25 type Target = Ref<'short, T, Rows, RStride>;
26
27 #[inline]
28 fn rb(&'short self) -> Self::Target {
29 *self
30 }
31}
32impl<'short, T, Rows: Copy, RStride: Copy> ReborrowMut<'short> for Ref<'_, T, Rows, RStride> {
33 type Target = Ref<'short, T, Rows, RStride>;
34
35 #[inline]
36 fn rb_mut(&'short mut self) -> Self::Target {
37 *self
38 }
39}
40impl<'a, T, Rows: Copy, RStride: Copy> IntoConst for Ref<'a, T, Rows, RStride> {
41 type Target = Ref<'a, T, Rows, RStride>;
42
43 #[inline]
44 fn into_const(self) -> Self::Target {
45 self
46 }
47}
48
// SAFETY: these impls are written by hand because the view stores its pointer as a `NonNull`
// (see `ColView { ptr: NonNull::new_unchecked(..) }` in `from_raw_parts`), and `NonNull` is
// neither `Send` nor `Sync`. The view only ever hands out shared `&T` access (its marker is
// `PhantomData<&'a T>`), so sharing it across threads is sound when `T: Sync`, exactly like
// `&T`. The shape and stride values are shared by reference here, so they must be `Sync`.
unsafe impl<T: Sync, Rows: Sync, RStride: Sync> Sync for Ref<'_, T, Rows, RStride> {}
// SAFETY: sending the view lets another thread read `T` through a shared reference, which is
// sound when `T: Sync` (the same condition under which `&T: Send`). The shape and stride values
// are moved along with the view, so they must be `Send`.
unsafe impl<T: Sync, Rows: Send, RStride: Send> Send for Ref<'_, T, Rows, RStride> {}
51
52impl<'a, T> ColRef<'a, T> {
53 #[inline]
55 pub fn from_ref(value: &'a T) -> Self {
56 unsafe { ColRef::from_raw_parts(value as *const T, 1, 1) }
57 }
58
59 #[inline]
62 pub fn from_slice(slice: &'a [T]) -> Self {
63 let len = slice.len();
64 unsafe { Self::from_raw_parts(slice.as_ptr(), len, 1) }
65 }
66}
67
impl<'a, T, Rows: Shape, RStride: Stride> ColRef<'a, T, Rows, RStride> {
    /// Creates a column view from a base pointer, a row count and a row stride.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null: it is wrapped with `NonNull::new_unchecked` without a check.
    /// NOTE(review): the caller presumably must also guarantee that every element
    /// `ptr.offset(i * row_stride)` for `i < nrows` is valid for reads for `'a` and is not
    /// mutated while the view is alive. Confirm against the crate's public docs.
    #[inline(always)]
    #[track_caller]
    pub const unsafe fn from_raw_parts(ptr: *const T, nrows: Rows, row_stride: RStride) -> Self {
        Self {
            0: Ref {
                imp: ColView {
                    ptr: NonNull::new_unchecked(ptr as *mut T),
                    nrows,
                    row_stride,
                },
                __marker: PhantomData,
            },
        }
    }

    /// Returns the base pointer of the view (the address of row 0).
    #[inline(always)]
    pub fn as_ptr(&self) -> *const T {
        self.imp.ptr.as_ptr() as *const T
    }

    /// Returns the number of rows.
    #[inline(always)]
    pub fn nrows(&self) -> Rows {
        self.imp.nrows
    }

    /// Returns the number of columns, which is always `1` for a column view.
    #[inline(always)]
    pub fn ncols(&self) -> usize {
        1
    }

    /// Returns `(nrows, ncols)`, where `ncols` is always `1`.
    #[inline(always)]
    pub fn shape(&self) -> (Rows, usize) {
        (self.nrows(), self.ncols())
    }

    /// Returns the row stride, i.e. the distance between consecutive elements.
    #[inline(always)]
    pub fn row_stride(&self) -> RStride {
        self.imp.row_stride
    }

    /// Returns a pointer to the element at `row`.
    ///
    /// When `row >= nrows` (for example the one-past-the-end index), the base pointer is
    /// returned instead of an out-of-range address. Otherwise the offset is computed with
    /// `wrapping_offset`, so this function is safe to call for any `row`.
    #[inline(always)]
    pub fn ptr_at(&self, row: IdxInc<Rows>) -> *const T {
        let ptr = self.as_ptr();

        if row >= self.nrows() {
            ptr
        } else {
            ptr.wrapping_offset(row.unbound() as isize * self.row_stride().element_stride())
        }
    }

    /// Returns a pointer to the element at `row` using `pointer::offset`.
    ///
    /// # Safety
    ///
    /// The computed offset must stay within the viewed allocation, as required by
    /// `pointer::offset`. Presumably this holds whenever `row < nrows`; TODO confirm whether
    /// `Idx<Rows>` guarantees that bound by construction.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn ptr_inbounds_at(&self, row: Idx<Rows>) -> *const T {
        self.as_ptr().offset(row.unbound() as isize * self.row_stride().element_stride())
    }

    /// Splits the column into a top part with rows `[0, row)` and a bottom part with rows
    /// `[row, nrows)`. Both parts keep the original row stride.
    ///
    /// # Panics
    ///
    /// Panics if `row > nrows`.
    #[inline]
    #[track_caller]
    pub fn split_at_row(self, row: IdxInc<Rows>) -> (ColRef<'a, T, usize, RStride>, ColRef<'a, T, usize, RStride>) {
        assert!(all(row <= self.nrows()));
        let rs = self.row_stride();

        let top = self.as_ptr();
        // `ptr_at` returns the base pointer when `row == nrows`; the bottom part is then empty.
        let bot = self.ptr_at(row);
        unsafe {
            (
                ColRef::from_raw_parts(top, row.unbound(), rs),
                ColRef::from_raw_parts(bot, self.nrows().unbound() - row.unbound(), rs),
            )
        }
    }

    /// Reinterprets the column as a row view over the same data.
    #[inline(always)]
    pub fn transpose(self) -> RowRef<'a, T, Rows, RStride> {
        RowRef {
            0: crate::row::Ref { trans: self },
        }
    }

    /// Reinterprets the elements as `T::Conj` by casting the base pointer; no data is touched.
    ///
    /// NOTE(review): this relies on `T` and `T::Conj` sharing a memory layout, presumably
    /// guaranteed by the `Conjugate` trait. Confirm.
    #[inline(always)]
    pub fn conjugate(self) -> ColRef<'a, T::Conj, Rows, RStride>
    where
        T: Conjugate,
    {
        unsafe { ColRef::from_raw_parts(self.as_ptr() as *const T::Conj, self.nrows(), self.row_stride()) }
    }

    /// Reinterprets the elements as `T::Canonical` by casting the base pointer; no data is
    /// touched. The same layout assumption as `conjugate` applies.
    #[inline(always)]
    pub fn canonical(self) -> ColRef<'a, T::Canonical, Rows, RStride>
    where
        T: Conjugate,
    {
        unsafe { ColRef::from_raw_parts(self.as_ptr() as *const T::Canonical, self.nrows(), self.row_stride()) }
    }

    /// Returns the conjugate transpose: `conjugate()` followed by `transpose()`.
    #[inline(always)]
    pub fn adjoint(self) -> RowRef<'a, T::Conj, Rows, RStride>
    where
        T: Conjugate,
    {
        self.conjugate().transpose()
    }

    /// Returns a reference to the element at `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows`.
    #[inline(always)]
    #[track_caller]
    pub(crate) fn at(self, row: Idx<Rows>) -> &'a T {
        assert!(all(row < self.nrows()));
        unsafe { self.at_unchecked(row) }
    }

    /// Returns a reference to the element at `row` without bounds checking.
    ///
    /// # Safety
    ///
    /// Same requirements as `ptr_inbounds_at`.
    #[inline(always)]
    #[track_caller]
    pub(crate) unsafe fn at_unchecked(self, row: Idx<Rows>) -> &'a T {
        &*self.ptr_inbounds_at(row)
    }

    /// Indexes the column with `row` (a single index or a range, depending on the
    /// `ColIndex` impl) and returns the corresponding target.
    ///
    /// Bounds checking is delegated to the `ColIndex::get` implementation.
    #[track_caller]
    #[inline(always)]
    pub fn get<RowRange>(self, row: RowRange) -> <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
    where
        ColRef<'a, T, Rows, RStride>: ColIndex<RowRange>,
    {
        <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::get(self, row)
    }

    /// Unchecked counterpart of `get`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `ColIndex::get_unchecked` for `row`
    /// (presumably that `row` is in bounds).
    #[track_caller]
    #[inline(always)]
    pub unsafe fn get_unchecked<RowRange>(self, row: RowRange) -> <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
    where
        ColRef<'a, T, Rows, RStride>: ColIndex<RowRange>,
    {
        unsafe { <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::get_unchecked(self, row) }
    }

    /// Returns a view of the same elements in reverse order.
    ///
    /// The new base pointer is the last element (or the original base when the column is
    /// empty, thanks to `saturating_sub`), and the stride is reversed with `Stride::rev`.
    #[inline]
    pub fn reverse_rows(self) -> ColRef<'a, T, Rows, RStride::Rev> {
        let row = unsafe { IdxInc::<Rows>::new_unbound(self.nrows().unbound().saturating_sub(1)) };
        let ptr = self.ptr_at(row);
        unsafe { ColRef::from_raw_parts(ptr, self.nrows(), self.row_stride().rev()) }
    }

    /// Returns the sub-column of `nrows` rows starting at `row_start`.
    ///
    /// # Panics
    ///
    /// Panics if `row_start > self.nrows()` or if `row_start + nrows > self.nrows()`.
    #[inline]
    #[track_caller]
    pub fn subrows<V: Shape>(self, row_start: IdxInc<Rows>, nrows: V) -> ColRef<'a, T, V, RStride> {
        assert!(all(row_start <= self.nrows()));
        {
            let nrows = nrows.unbound();
            let full_nrows = self.nrows().unbound();
            let row_start = row_start.unbound();
            // Cannot underflow: `row_start <= full_nrows` was asserted above.
            assert!(all(nrows <= full_nrows - row_start));
        }
        let rs = self.row_stride();

        unsafe { ColRef::from_raw_parts(self.ptr_at(row_start), nrows, rs) }
    }

    /// Re-types the row count as `V`, keeping pointer and stride.
    ///
    /// # Panics
    ///
    /// Panics if `nrows.unbound()` differs from the current row count.
    #[inline]
    #[track_caller]
    pub fn as_row_shape<V: Shape>(self, nrows: V) -> ColRef<'a, T, V, RStride> {
        assert!(all(self.nrows().unbound() == nrows.unbound()));
        unsafe { ColRef::from_raw_parts(self.as_ptr(), nrows, self.row_stride()) }
    }

    /// Erases the row-count type, turning it into a plain `usize`.
    #[inline]
    pub fn as_dyn_rows(self) -> ColRef<'a, T, usize, RStride> {
        unsafe { ColRef::from_raw_parts(self.as_ptr(), self.nrows().unbound(), self.row_stride()) }
    }

    /// Erases the stride type, turning it into a plain `isize` element stride.
    #[inline]
    pub fn as_dyn_stride(self) -> ColRef<'a, T, Rows, isize> {
        unsafe { ColRef::from_raw_parts(self.as_ptr(), self.nrows(), self.row_stride().element_stride()) }
    }

    /// Returns a double-ended, exact-size iterator over references to the elements, in row
    /// order.
    #[inline]
    pub fn iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a T>
    where
        Rows: 'a,
    {
        Rows::indices(Rows::start(), self.nrows().end()).map(move |j| unsafe { self.at_unchecked(j) })
    }

    /// Returns a rayon indexed parallel iterator over references to the elements.
    #[inline]
    #[cfg(feature = "rayon")]
    pub fn par_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a T>
    where
        T: Sync,
        Rows: 'a,
    {
        use rayon::prelude::*;
        (0..self.nrows().unbound())
            .into_par_iter()
            .map(move |j| unsafe { self.at_unchecked(Idx::<Rows>::new_unbound(j)) })
    }

    /// Splits the column into `count` contiguous chunks and returns them as a rayon indexed
    /// parallel iterator. Chunk boundaries come from `crate::utils::thread::par_split_indices`.
    ///
    /// # Panics
    ///
    /// Panics if `count == 0`.
    #[inline]
    #[track_caller]
    #[cfg(feature = "rayon")]
    pub fn par_partition(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColRef<'a, T, usize, RStride>>
    where
        T: Sync,
        Rows: 'a,
    {
        use rayon::prelude::*;

        let this = self.as_dyn_rows();

        assert!(count > 0);
        (0..count).into_par_iter().map(move |chunk_idx| {
            let (start, len) = crate::utils::thread::par_split_indices(this.nrows(), chunk_idx, count);
            this.subrows(start, len)
        })
    }

    /// Returns the view with a statically contiguous (`ContiguousFwd`) stride if its element
    /// stride is exactly `1`, otherwise `None`.
    #[inline]
    pub fn try_as_col_major(self) -> Option<ColRef<'a, T, Rows, ContiguousFwd>> {
        if self.row_stride().element_stride() == 1 {
            Some(unsafe { ColRef::from_raw_parts(self.as_ptr(), self.nrows(), ContiguousFwd) })
        } else {
            None
        }
    }

    /// Casts this immutable view to a mutable one over the same memory.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller must presumably ensure that the memory may actually be written
    /// and that no other live reference aliases it for `'a`. Confirm the precise contract.
    #[inline(always)]
    #[doc(hidden)]
    pub unsafe fn const_cast(self) -> ColMut<'a, T, Rows, RStride> {
        ColMut::from_raw_parts_mut(self.as_ptr() as *mut T, self.nrows(), self.row_stride())
    }

    /// Views the column as a matrix with `nrows` rows and a single column.
    ///
    /// The column stride is set to `0`, which is never used to step because there is only one
    /// column.
    #[inline]
    pub fn as_mat(self) -> MatRef<'a, T, Rows, usize, RStride, isize> {
        unsafe { MatRef::from_raw_parts(self.as_ptr(), self.nrows(), self.ncols(), self.row_stride(), 0) }
    }

    /// Binds the row count to the branded dimension `Dim<'N>` derived from `row` via
    /// `Shape::bind`. See that method for whether the binding is validated.
    #[inline]
    #[doc(hidden)]
    pub fn bind_r<'N>(self, row: Guard<'N>) -> ColRef<'a, T, Dim<'N>, RStride> {
        unsafe { ColRef::from_raw_parts(self.as_ptr(), self.nrows().bind(row), self.row_stride()) }
    }

    /// Returns a clone of the element at `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` (via `at`).
    #[inline(always)]
    #[track_caller]
    pub(crate) fn read(&self, row: Idx<Rows>) -> T
    where
        T: Clone,
    {
        self.at(row).clone()
    }

    /// Bounds-checked element access; forwards to `at`.
    #[inline]
    #[track_caller]
    pub(crate) fn __at(self, i: Idx<Rows>) -> &'a T {
        self.at(i)
    }

    /// Views the column as the diagonal of a square diagonal matrix.
    #[inline]
    pub fn as_diagonal(self) -> DiagRef<'a, T, Rows, RStride> {
        DiagRef {
            0: crate::diag::Ref { inner: self },
        }
    }
}
392
393impl<T, Rows: Shape, RStride: Stride, Inner: for<'short> Reborrow<'short, Target = Ref<'short, T, Rows, RStride>>> generic::Col<Inner> {
394 #[inline]
396 pub fn cloned(&self) -> Col<T, Rows>
397 where
398 T: Clone,
399 {
400 fn imp<'M, T: Clone, RStride: Stride>(this: ColRef<'_, T, Dim<'M>, RStride>) -> Col<T, Dim<'M>> {
401 Col::from_fn(this.nrows(), |i| this.at(i).clone())
402 }
403
404 let this = self.rb();
405 with_dim!(M, this.nrows().unbound());
406 imp(this.as_row_shape(M)).into_row_shape(this.nrows())
407 }
408
409 #[inline]
411 pub fn to_owned(&self) -> Col<T::Canonical, Rows>
412 where
413 T: Conjugate,
414 {
415 fn imp<'M, T, RStride: Stride>(this: ColRef<'_, T, Dim<'M>, RStride>) -> Col<T::Canonical, Dim<'M>>
416 where
417 T: Conjugate,
418 {
419 Col::from_fn(this.nrows(), |i| Conj::apply::<T>(this.at(i)))
420 }
421
422 let this = self.rb();
423 with_dim!(M, this.nrows().unbound());
424 imp(this.as_row_shape(M)).into_row_shape(this.nrows())
425 }
426
427 #[inline]
429 pub fn norm_max(&self) -> Real<T>
430 where
431 T: Conjugate,
432 {
433 self.rb().as_mat().norm_max()
434 }
435
436 #[inline]
438 pub fn norm_l2(&self) -> Real<T>
439 where
440 T: Conjugate,
441 {
442 self.rb().as_mat().norm_l2()
443 }
444
445 #[inline]
447 pub fn squared_norm_l2(&self) -> Real<T>
448 where
449 T: Conjugate,
450 {
451 self.rb().as_mat().squared_norm_l2()
452 }
453
454 #[inline]
456 pub fn norm_l1(&self) -> Real<T>
457 where
458 T: Conjugate,
459 {
460 self.rb().as_mat().norm_l1()
461 }
462
463 #[inline]
465 pub fn sum(&self) -> T::Canonical
466 where
467 T: Conjugate,
468 {
469 self.rb().as_mat().sum()
470 }
471
472 #[inline]
474 pub fn as_ref(&self) -> ColRef<'_, T, Rows, RStride> {
475 self.rb()
476 }
477
478 #[inline]
480 pub fn kron(&self, rhs: impl AsMatRef<T: Conjugate<Canonical = T::Canonical>>) -> Mat<T::Canonical>
481 where
482 T: Conjugate,
483 {
484 fn imp<T: ComplexField>(lhs: MatRef<impl Conjugate<Canonical = T>>, rhs: MatRef<impl Conjugate<Canonical = T>>) -> Mat<T> {
485 let mut out = Mat::zeros(lhs.nrows() * rhs.nrows(), lhs.ncols() * rhs.ncols());
486 linalg::kron::kron(out.rb_mut(), lhs, rhs);
487 out
488 }
489
490 imp(self.rb().as_mat().as_dyn().as_dyn_stride(), rhs.as_mat_ref().as_dyn().as_dyn_stride())
491 }
492
493 #[inline]
496 pub fn is_all_finite(&self) -> bool
497 where
498 T: Conjugate,
499 {
500 fn imp<T: ComplexField>(A: ColRef<'_, T>) -> bool {
501 with_dim!({
502 let M = A.nrows();
503 });
504
505 let A = A.as_row_shape(M);
506
507 for i in M.indices() {
508 if !is_finite(&A[i]) {
509 return false;
510 }
511 }
512
513 true
514 }
515
516 imp(self.rb().as_dyn_rows().as_dyn_stride().canonical())
517 }
518
519 #[inline]
522 pub fn has_nan(&self) -> bool
523 where
524 T: Conjugate,
525 {
526 fn imp<T: ComplexField>(A: ColRef<'_, T>) -> bool {
527 with_dim!({
528 let M = A.nrows();
529 });
530
531 let A = A.as_row_shape(M);
532
533 for i in M.indices() {
534 if is_nan(&A[i]) {
535 return true;
536 }
537 }
538
539 false
540 }
541
542 imp(self.rb().as_dyn_rows().as_dyn_stride().canonical())
543 }
544}
545
546impl<'a, T, Rows: Shape> ColRef<'a, T, Rows, ContiguousFwd> {
547 #[inline]
549 pub fn as_slice(self) -> &'a [T] {
550 unsafe { core::slice::from_raw_parts(self.as_ptr(), self.nrows().unbound()) }
551 }
552}
553
554impl<'a, 'ROWS, T> ColRef<'a, T, Dim<'ROWS>, ContiguousFwd> {
555 #[inline]
557 pub fn as_array(self) -> &'a Array<'ROWS, T> {
558 unsafe { &*(self.as_slice() as *const [_] as *const Array<'ROWS, T>) }
559 }
560}
561
562impl<'ROWS, 'a, T, RStride: Stride> ColRef<'a, T, Dim<'ROWS>, RStride> {
563 #[doc(hidden)]
564 #[inline]
565 pub fn split_rows_with<'TOP, 'BOT>(
566 self,
567 row: Partition<'TOP, 'BOT, 'ROWS>,
568 ) -> (ColRef<'a, T, Dim<'TOP>, RStride>, ColRef<'a, T, Dim<'BOT>, RStride>) {
569 let (a, b) = self.split_at_row(row.midpoint());
570 (a.as_row_shape(row.head), b.as_row_shape(row.tail))
571 }
572}
573
574impl<T: core::fmt::Debug, Rows: Shape, RStride: Stride> core::fmt::Debug for Ref<'_, T, Rows, RStride> {
575 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
576 generic::Col::from_inner_ref(self).transpose().fmt(f)
577 }
578}
579
580impl<'a, T> ColRef<'a, T, usize, isize>
581where
582 T: RealField,
583{
584 pub(crate) fn internal_max(self) -> Option<T> {
585 if self.nrows().unbound() == 0 || self.ncols() == 0 {
586 return None;
587 }
588
589 let mut max_val = self.get(0);
590
591 self.iter().for_each(|val| {
592 if val > max_val {
593 max_val = val;
594 }
595 });
596
597 Some((*max_val).clone())
598 }
599
600 pub(crate) fn internal_min(self) -> Option<T> {
602 if self.nrows().unbound() == 0 || self.ncols() == 0 {
603 return None;
604 }
605
606 let mut min_val = self.get(0);
607
608 self.iter().for_each(|val| {
609 if val < min_val {
610 min_val = val;
611 }
612 });
613
614 Some((*min_val).clone())
615 }
616}
617
618impl<'a, T, Rows: Shape, RStride: Stride> ColRef<'a, T, Rows, RStride>
619where
620 T: RealField,
621{
622 pub fn max(&self) -> Option<T> {
624 self.as_dyn_rows().as_dyn_stride().internal_max()
625 }
626
627 pub fn min(&self) -> Option<T> {
629 self.as_dyn_rows().as_dyn_stride().internal_min()
630 }
631}
632
#[cfg(test)]
mod tests {
    use crate::Col;

    /// Builds the column `[1.0, 2.0, ..., len as f64]` (empty when `len == 0`).
    fn ramp(len: usize) -> Col<f64> {
        Col::from_fn(len, |row| (row + 1) as f64)
    }

    #[test]
    fn test_col_min() {
        let filled = ramp(5);
        assert_eq!(filled.as_ref().min(), Some(1.0));

        let empty = ramp(0);
        assert_eq!(empty.as_ref().min(), None);
    }

    #[test]
    fn test_col_max() {
        let filled = ramp(5);
        assert_eq!(filled.as_ref().max(), Some(5.0));

        let empty = ramp(0);
        assert_eq!(empty.as_ref().max(), None);
    }
}