1use super::*;
2use crate::utils::bound::{Array, Dim, Partition};
3use crate::{ContiguousFwd, Idx, IdxInc};
4use equator::{assert, debug_assert};
5use faer_traits::Real;
6
7pub struct Ref<'a, T, Cols = usize, CStride = isize> {
9 pub(crate) trans: ColRef<'a, T, Cols, CStride>,
10}
11
12impl<T, Rows: Copy, CStride: Copy> Copy for Ref<'_, T, Rows, CStride> {}
13impl<T, Rows: Copy, CStride: Copy> Clone for Ref<'_, T, Rows, CStride> {
14 #[inline]
15 fn clone(&self) -> Self {
16 *self
17 }
18}
19
20impl<'short, T, Rows: Copy, CStride: Copy> Reborrow<'short> for Ref<'_, T, Rows, CStride> {
21 type Target = Ref<'short, T, Rows, CStride>;
22
23 #[inline]
24 fn rb(&'short self) -> Self::Target {
25 *self
26 }
27}
28impl<'short, T, Rows: Copy, CStride: Copy> ReborrowMut<'short> for Ref<'_, T, Rows, CStride> {
29 type Target = Ref<'short, T, Rows, CStride>;
30
31 #[inline]
32 fn rb_mut(&'short mut self) -> Self::Target {
33 *self
34 }
35}
36impl<'a, T, Rows: Copy, CStride: Copy> IntoConst for Ref<'a, T, Rows, CStride> {
37 type Target = Ref<'a, T, Rows, CStride>;
38
39 #[inline]
40 fn into_const(self) -> Self::Target {
41 self
42 }
43}
44
45unsafe impl<T: Sync, Rows: Sync, CStride: Sync> Sync for Ref<'_, T, Rows, CStride> {}
46unsafe impl<T: Sync, Rows: Send, CStride: Send> Send for Ref<'_, T, Rows, CStride> {}
47
48impl<'a, T> RowRef<'a, T> {
49 #[inline]
51 pub fn from_ref(value: &'a T) -> Self {
52 unsafe { RowRef::from_raw_parts(value as *const T, 1, 1) }
53 }
54
55 #[inline]
58 pub fn from_slice(slice: &'a [T]) -> Self {
59 let len = slice.len();
60 unsafe { Self::from_raw_parts(slice.as_ptr(), len, 1) }
61 }
62}
63
64impl<'a, T, Cols: Shape, CStride: Stride> RowRef<'a, T, Cols, CStride> {
65 #[inline(always)]
71 #[track_caller]
72 pub const unsafe fn from_raw_parts(ptr: *const T, ncols: Cols, col_stride: CStride) -> Self {
73 Self {
74 0: Ref {
75 trans: ColRef::from_raw_parts(ptr, ncols, col_stride),
76 },
77 }
78 }
79
80 #[inline(always)]
82 pub fn as_ptr(&self) -> *const T {
83 self.trans.as_ptr()
84 }
85
86 #[inline(always)]
88 pub fn nrows(&self) -> usize {
89 1
90 }
91
92 #[inline(always)]
94 pub fn ncols(&self) -> Cols {
95 self.trans.nrows()
96 }
97
98 #[inline(always)]
100 pub fn shape(&self) -> (usize, Cols) {
101 (self.nrows(), self.ncols())
102 }
103
104 #[inline(always)]
106 pub fn col_stride(&self) -> CStride {
107 self.trans.row_stride()
108 }
109
110 #[inline(always)]
112 pub fn ptr_at(&self, col: IdxInc<Cols>) -> *const T {
113 self.trans.ptr_at(col)
114 }
115
116 #[inline(always)]
123 #[track_caller]
124 pub unsafe fn ptr_inbounds_at(&self, col: Idx<Cols>) -> *const T {
125 debug_assert!(all(col < self.ncols()));
126 self.trans.ptr_inbounds_at(col)
127 }
128
129 #[inline]
138 #[track_caller]
139 pub fn split_at_col(self, col: IdxInc<Cols>) -> (RowRef<'a, T, usize, CStride>, RowRef<'a, T, usize, CStride>) {
140 assert!(all(col <= self.ncols()));
141 let rs = self.col_stride();
142
143 let top = self.as_ptr();
144 let bot = self.ptr_at(col);
145 unsafe {
146 (
147 RowRef::from_raw_parts(top, col.unbound(), rs),
148 RowRef::from_raw_parts(bot, self.ncols().unbound() - col.unbound(), rs),
149 )
150 }
151 }
152
153 #[inline(always)]
155 pub fn transpose(self) -> ColRef<'a, T, Cols, CStride> {
156 self.trans
157 }
158
159 #[inline(always)]
161 pub fn conjugate(self) -> RowRef<'a, T::Conj, Cols, CStride>
162 where
163 T: Conjugate,
164 {
165 RowRef {
166 0: Ref {
167 trans: self.trans.conjugate(),
168 },
169 }
170 }
171
172 #[inline(always)]
174 pub fn canonical(self) -> RowRef<'a, T::Canonical, Cols, CStride>
175 where
176 T: Conjugate,
177 {
178 RowRef {
179 0: Ref {
180 trans: self.trans.canonical(),
181 },
182 }
183 }
184
185 #[inline(always)]
187 pub fn adjoint(self) -> ColRef<'a, T::Conj, Cols, CStride>
188 where
189 T: Conjugate,
190 {
191 self.conjugate().transpose()
192 }
193
194 #[inline(always)]
195 #[track_caller]
196 pub(crate) fn at(self, col: Idx<Cols>) -> &'a T {
197 assert!(all(col < self.ncols()));
198 unsafe { self.at_unchecked(col) }
199 }
200
201 #[inline(always)]
202 #[track_caller]
203 pub(crate) unsafe fn at_unchecked(self, col: Idx<Cols>) -> &'a T {
204 &*self.ptr_inbounds_at(col)
205 }
206
207 #[track_caller]
214 #[inline(always)]
215 pub fn get<ColRange>(self, col: ColRange) -> <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
216 where
217 RowRef<'a, T, Cols, CStride>: RowIndex<ColRange>,
218 {
219 <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::get(self, col)
220 }
221
222 #[track_caller]
229 #[inline(always)]
230 pub unsafe fn get_unchecked<ColRange>(self, col: ColRange) -> <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
231 where
232 RowRef<'a, T, Cols, CStride>: RowIndex<ColRange>,
233 {
234 unsafe { <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::get_unchecked(self, col) }
235 }
236
237 #[inline]
239 pub fn reverse_cols(self) -> RowRef<'a, T, Cols, CStride::Rev> {
240 RowRef {
241 0: Ref {
242 trans: self.trans.reverse_rows(),
243 },
244 }
245 }
246
247 #[inline]
255 pub fn subcols<V: Shape>(self, col_start: IdxInc<Cols>, ncols: V) -> RowRef<'a, T, V, CStride> {
256 assert!(all(col_start <= self.ncols()));
257 {
258 let ncols = ncols.unbound();
259 let full_ncols = self.ncols().unbound();
260 let col_start = col_start.unbound();
261 assert!(all(ncols <= full_ncols - col_start));
262 }
263 let cs = self.col_stride();
264 unsafe { RowRef::from_raw_parts(self.ptr_at(col_start), ncols, cs) }
265 }
266
267 #[inline]
270 #[track_caller]
271 pub fn as_col_shape<V: Shape>(self, ncols: V) -> RowRef<'a, T, V, CStride> {
272 assert!(all(self.ncols().unbound() == ncols.unbound()));
273 unsafe { RowRef::from_raw_parts(self.as_ptr(), ncols, self.col_stride()) }
274 }
275
276 #[inline]
278 pub fn as_dyn_cols(self) -> RowRef<'a, T, usize, CStride> {
279 unsafe { RowRef::from_raw_parts(self.as_ptr(), self.ncols().unbound(), self.col_stride()) }
280 }
281
282 #[inline]
284 pub fn as_dyn_stride(self) -> RowRef<'a, T, Cols, isize> {
285 unsafe { RowRef::from_raw_parts(self.as_ptr(), self.ncols(), self.col_stride().element_stride()) }
286 }
287
288 #[inline]
290 pub fn iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a T>
291 where
292 Cols: 'a,
293 {
294 self.trans.iter()
295 }
296
297 #[inline]
299 #[cfg(feature = "rayon")]
300 pub fn par_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a T>
301 where
302 T: Sync,
303 Cols: 'a,
304 {
305 self.trans.par_iter()
306 }
307
308 #[inline]
313 #[track_caller]
314 #[cfg(feature = "rayon")]
315 pub fn par_partition(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = RowRef<'a, T, usize, CStride>>
316 where
317 T: Sync,
318 Cols: 'a,
319 {
320 use rayon::prelude::*;
321 self.transpose().par_partition(count).map(ColRef::transpose)
322 }
323
324 #[inline]
326 pub fn try_as_row_major(self) -> Option<RowRef<'a, T, Cols, ContiguousFwd>> {
327 if self.col_stride().element_stride() == 1 {
328 Some(unsafe { RowRef::from_raw_parts(self.as_ptr(), self.ncols(), ContiguousFwd) })
329 } else {
330 None
331 }
332 }
333
334 #[inline(always)]
335 #[doc(hidden)]
336 pub unsafe fn const_cast(self) -> RowMut<'a, T, Cols, CStride> {
337 RowMut {
338 0: Mut {
339 trans: self.trans.const_cast(),
340 },
341 }
342 }
343
344 #[inline]
346 pub fn as_mat(self) -> MatRef<'a, T, usize, Cols, isize, CStride> {
347 self.transpose().as_mat().transpose()
348 }
349
350 #[inline]
352 pub fn as_diagonal(self) -> DiagRef<'a, T, Cols, CStride> {
353 DiagRef {
354 0: crate::diag::Ref { inner: self.trans },
355 }
356 }
357
358 #[inline]
359 pub(crate) fn __at(self, i: Idx<Cols>) -> &'a T {
360 self.at(i)
361 }
362}
363
364impl<T, Cols: Shape, CStride: Stride, Inner: for<'short> Reborrow<'short, Target = Ref<'short, T, Cols, CStride>>> generic::Row<Inner> {
365 #[inline]
367 pub fn as_ref(&self) -> RowRef<'_, T, Cols, CStride> {
368 self.rb()
369 }
370
371 #[inline]
373 pub fn norm_max(&self) -> Real<T>
374 where
375 T: Conjugate,
376 {
377 self.rb().as_mat().norm_max()
378 }
379
380 #[inline]
382 pub fn norm_l2(&self) -> Real<T>
383 where
384 T: Conjugate,
385 {
386 self.rb().as_mat().norm_l2()
387 }
388
389 #[inline]
391 pub fn squared_norm_l2(&self) -> Real<T>
392 where
393 T: Conjugate,
394 {
395 self.rb().as_mat().squared_norm_l2()
396 }
397
398 #[inline]
400 pub fn norm_l1(&self) -> Real<T>
401 where
402 T: Conjugate,
403 {
404 self.rb().as_mat().norm_l1()
405 }
406
407 #[inline]
409 pub fn sum(&self) -> T::Canonical
410 where
411 T: Conjugate,
412 {
413 self.rb().as_mat().sum()
414 }
415
416 #[inline]
418 pub fn kron(&self, rhs: impl AsMatRef<T: Conjugate<Canonical = T::Canonical>>) -> Mat<T::Canonical>
419 where
420 T: Conjugate,
421 {
422 fn imp<T: ComplexField>(lhs: MatRef<impl Conjugate<Canonical = T>>, rhs: MatRef<impl Conjugate<Canonical = T>>) -> Mat<T> {
423 let mut out = Mat::zeros(lhs.nrows() * rhs.nrows(), lhs.ncols() * rhs.ncols());
424 linalg::kron::kron(out.rb_mut(), lhs, rhs);
425 out
426 }
427
428 imp(self.rb().as_mat().as_dyn().as_dyn_stride(), rhs.as_mat_ref().as_dyn().as_dyn_stride())
429 }
430
431 #[inline]
434 pub fn is_all_finite(&self) -> bool
435 where
436 T: Conjugate,
437 {
438 self.rb().transpose().is_all_finite()
439 }
440
441 #[inline]
444 pub fn has_nan(&self) -> bool
445 where
446 T: Conjugate,
447 {
448 self.rb().transpose().has_nan()
449 }
450
451 #[inline]
453 pub fn cloned(&self) -> Row<T, Cols>
454 where
455 T: Clone,
456 {
457 self.rb().transpose().cloned().into_transpose()
458 }
459
460 #[inline]
462 pub fn to_owned(&self) -> Row<T::Canonical, Cols>
463 where
464 T: Conjugate,
465 {
466 self.rb().transpose().to_owned().into_transpose()
467 }
468}
469
470impl<'a, T, Rows: Shape> RowRef<'a, T, Rows, ContiguousFwd> {
471 #[inline]
473 pub fn as_slice(self) -> &'a [T] {
474 self.transpose().as_slice()
475 }
476}
477
478impl<'a, 'ROWS, T> RowRef<'a, T, Dim<'ROWS>, ContiguousFwd> {
479 #[inline]
481 pub fn as_array(self) -> &'a Array<'ROWS, T> {
482 self.transpose().as_array()
483 }
484}
485
486impl<'COLS, 'a, T, CStride: Stride> RowRef<'a, T, Dim<'COLS>, CStride> {
487 #[doc(hidden)]
488 #[inline]
489 pub fn split_cols_with<'LEFT, 'RIGHT>(
490 self,
491 col: Partition<'LEFT, 'RIGHT, 'COLS>,
492 ) -> (RowRef<'a, T, Dim<'LEFT>, CStride>, RowRef<'a, T, Dim<'RIGHT>, CStride>) {
493 let (a, b) = self.split_at_col(col.midpoint());
494 (a.as_col_shape(col.head), b.as_col_shape(col.tail))
495 }
496}
497
498impl<T: core::fmt::Debug, Cols: Shape, CStride: Stride> core::fmt::Debug for Ref<'_, T, Cols, CStride> {
499 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
500 fn imp<T: core::fmt::Debug>(f: &mut core::fmt::Formatter<'_>, this: RowRef<'_, T, Dim<'_>>) -> core::fmt::Result {
501 f.debug_list()
502 .entries(this.ncols().indices().map(|j| crate::hacks::hijack_debug(this.at(j))))
503 .finish()
504 }
505
506 let this = generic::Row::from_inner_ref(self);
507
508 with_dim!(N, this.ncols().unbound());
509 imp(f, this.as_col_shape(N).as_dyn_stride())
510 }
511}
512
513impl<'a, T> RowRef<'a, T, usize, isize>
514where
515 T: RealField,
516{
517 pub(crate) fn internal_max(self) -> Option<T> {
519 if self.nrows().unbound() == 0 || self.ncols() == 0 {
520 return None;
521 }
522
523 let mut max_val = self.get(0);
524
525 self.iter().for_each(|val| {
526 if val > max_val {
527 max_val = val;
528 }
529 });
530
531 Some((*max_val).clone())
532 }
533
534 pub(crate) fn internal_min(self) -> Option<T> {
536 if self.nrows().unbound() == 0 || self.ncols() == 0 {
537 return None;
538 }
539
540 let mut min_val = self.get(0);
541
542 self.iter().for_each(|val| {
543 if val < min_val {
544 min_val = val;
545 }
546 });
547
548 Some((*min_val).clone())
549 }
550}
551
552impl<'a, T, Cols: Shape, CStride: Stride> RowRef<'a, T, Cols, CStride>
553where
554 T: RealField,
555{
556 pub fn max(&self) -> Option<T> {
558 self.as_dyn_cols().as_dyn_stride().internal_max()
559 }
560
561 pub fn min(&self) -> Option<T> {
563 self.as_dyn_cols().as_dyn_stride().internal_min()
564 }
565}
566
#[cfg(test)]
mod tests {
    use crate::Row;

    #[test]
    fn test_row_min() {
        let row: Row<f64> = Row::from_fn(5, |x| (x + 1) as f64);
        let rowref = row.as_ref();
        assert_eq!(rowref.min(), Some(1.0));

        // The minimum is neither the first nor the last entry. The ascending row above would
        // pass even if only the first entry were inspected; this case needs the full scan.
        let values = [3.0, -2.0, 7.0, -5.0, 0.5];
        let mixed: Row<f64> = Row::from_fn(values.len(), |j| values[j]);
        assert_eq!(mixed.as_ref().min(), Some(-5.0));
        // Same data through a column-reversed view.
        assert_eq!(mixed.as_ref().reverse_cols().min(), Some(-5.0));

        let single: Row<f64> = Row::from_fn(1, |_| 4.0);
        assert_eq!(single.as_ref().min(), Some(4.0));

        let empty: Row<f64> = Row::from_fn(0, |_| 0.0);
        let emptyref = empty.as_ref();
        assert_eq!(emptyref.min(), None);
    }

    #[test]
    fn test_row_max() {
        let row: Row<f64> = Row::from_fn(5, |x| (x + 1) as f64);
        let rowref = row.as_ref();
        assert_eq!(rowref.max(), Some(5.0));

        // The maximum sits in the middle of the row.
        let values = [3.0, -2.0, 7.0, -5.0, 0.5];
        let mixed: Row<f64> = Row::from_fn(values.len(), |j| values[j]);
        assert_eq!(mixed.as_ref().max(), Some(7.0));
        // Same data through a column-reversed view.
        assert_eq!(mixed.as_ref().reverse_cols().max(), Some(7.0));

        let single: Row<f64> = Row::from_fn(1, |_| -4.0);
        assert_eq!(single.as_ref().max(), Some(-4.0));

        let empty: Row<f64> = Row::from_fn(0, |_| 0.0);
        let emptyref = empty.as_ref();
        assert_eq!(emptyref.max(), None);
    }
}