1#![allow(clippy::len_without_is_empty)]
3
4use crate::{
5 assert, constrained, debug_assert,
6 inner::{PermMut, PermOwn, PermRef},
7 seal::Seal,
8 temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Entity, MatMut, MatRef, Matrix,
9};
10use bytemuck::Pod;
11use core::fmt::Debug;
12use dyn_stack::{PodStack, SizeOverflow, StackReq};
13use reborrow::*;
14
// `Seal` is a supertrait of both `Index` and `SignedIndex`, so these are the only integer types
// that can be given an implementation of either index trait.
impl Seal for i32 {}
impl Seal for i64 {}
impl Seal for i128 {}
impl Seal for isize {}
impl Seal for u32 {}
impl Seal for u64 {}
impl Seal for u128 {}
impl Seal for usize {}
23
24pub trait Index:
28 Seal
29 + core::fmt::Debug
30 + core::ops::Not<Output = Self>
31 + core::ops::Add<Output = Self>
32 + core::ops::Sub<Output = Self>
33 + core::ops::AddAssign
34 + core::ops::SubAssign
35 + Pod
36 + Eq
37 + Ord
38 + Send
39 + Sync
40{
41 type FixedWidth: Index;
43 type Signed: SignedIndex;
45
46 #[must_use]
48 #[inline(always)]
49 fn truncate(value: usize) -> Self {
50 Self::from_signed(<Self::Signed as SignedIndex>::truncate(value))
51 }
52
53 #[must_use]
55 #[inline(always)]
56 fn zx(self) -> usize {
57 self.to_signed().zx()
58 }
59
60 #[inline(always)]
62 fn canonicalize(slice: &[Self]) -> &[Self::FixedWidth] {
63 bytemuck::cast_slice(slice)
64 }
65
66 #[inline(always)]
68 fn canonicalize_mut(slice: &mut [Self]) -> &mut [Self::FixedWidth] {
69 bytemuck::cast_slice_mut(slice)
70 }
71
72 #[inline(always)]
74 fn from_signed(value: Self::Signed) -> Self {
75 pulp::cast(value)
76 }
77
78 #[inline(always)]
80 fn to_signed(self) -> Self::Signed {
81 pulp::cast(self)
82 }
83
84 #[inline]
86 fn sum_nonnegative(slice: &[Self]) -> Option<Self> {
87 Self::Signed::sum_nonnegative(bytemuck::cast_slice(slice)).map(Self::from_signed)
88 }
89}
90
91pub trait SignedIndex:
95 Seal
96 + core::fmt::Debug
97 + core::ops::Neg<Output = Self>
98 + core::ops::Add<Output = Self>
99 + core::ops::Sub<Output = Self>
100 + core::ops::AddAssign
101 + core::ops::SubAssign
102 + Pod
103 + Eq
104 + Ord
105 + Send
106 + Sync
107{
108 const MAX: Self;
110
111 #[must_use]
113 fn truncate(value: usize) -> Self;
114
115 #[must_use]
117 fn zx(self) -> usize;
118 #[must_use]
120 fn sx(self) -> usize;
121
122 fn sum_nonnegative(slice: &[Self]) -> Option<Self> {
124 let mut acc = Self::zeroed();
125 for &i in slice {
126 if Self::MAX - i < acc {
127 return None;
128 }
129 acc += i;
130 }
131 Some(acc)
132 }
133}
134
// `u32` is an index type on targets whose pointers are 32, 64 or 128 bits wide.
#[cfg(any(
    target_pointer_width = "32",
    target_pointer_width = "64",
    target_pointer_width = "128",
))]
impl Index for u32 {
    // `u32` already has a fixed width, so it is its own canonical form.
    type FixedWidth = u32;
    type Signed = i32;
}
// `u64` is an index type only on targets whose pointers are 64 or 128 bits wide.
#[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
impl Index for u64 {
    // `u64` already has a fixed width, so it is its own canonical form.
    type FixedWidth = u64;
    type Signed = i64;
}
// `u128` is an index type only on targets whose pointers are 128 bits wide.
#[cfg(target_pointer_width = "128")]
impl Index for u128 {
    // `u128` already has a fixed width, so it is its own canonical form.
    type FixedWidth = u128;
    type Signed = i128;
}
154
impl Index for usize {
    // The canonical fixed-width form of `usize` is the unsigned integer type whose width matches
    // the target's pointer width.
    #[cfg(target_pointer_width = "32")]
    type FixedWidth = u32;
    #[cfg(target_pointer_width = "64")]
    type FixedWidth = u64;
    #[cfg(target_pointer_width = "128")]
    type FixedWidth = u128;

    type Signed = isize;
}
165
166#[cfg(any(
167 target_pointer_width = "32",
168 target_pointer_width = "64",
169 target_pointer_width = "128",
170))]
171impl SignedIndex for i32 {
172 const MAX: Self = Self::MAX;
173
174 #[inline(always)]
175 fn truncate(value: usize) -> Self {
176 #[allow(clippy::assertions_on_constants)]
177 const _: () = {
178 core::assert!(i32::BITS <= usize::BITS);
179 };
180 value as isize as Self
181 }
182
183 #[inline(always)]
184 fn zx(self) -> usize {
185 self as u32 as usize
186 }
187
188 #[inline(always)]
189 fn sx(self) -> usize {
190 self as isize as usize
191 }
192}
193
194#[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
195impl SignedIndex for i64 {
196 const MAX: Self = Self::MAX;
197
198 #[inline(always)]
199 fn truncate(value: usize) -> Self {
200 #[allow(clippy::assertions_on_constants)]
201 const _: () = {
202 core::assert!(i64::BITS <= usize::BITS);
203 };
204 value as isize as Self
205 }
206
207 #[inline(always)]
208 fn zx(self) -> usize {
209 self as u64 as usize
210 }
211
212 #[inline(always)]
213 fn sx(self) -> usize {
214 self as isize as usize
215 }
216}
217
// `i128` is a signed index type only on targets whose pointers are 128 bits wide.
#[cfg(target_pointer_width = "128")]
impl SignedIndex for i128 {
    const MAX: Self = i128::MAX;

    /// Converts `value` to `i128` by going through `isize`.
    #[inline(always)]
    fn truncate(value: usize) -> Self {
        // Compile-time guard: `i128` must not be wider than `usize`.
        #[allow(clippy::assertions_on_constants)]
        const _: () = {
            core::assert!(i128::BITS <= usize::BITS);
        };
        let pointer_sized = value as isize;
        pointer_sized as Self
    }

    /// Zero-extends `self` to `usize` by going through `u128`.
    #[inline(always)]
    fn zx(self) -> usize {
        let unsigned = self as u128;
        unsigned as usize
    }

    /// Sign-extends `self` to `usize` by going through `isize`.
    #[inline(always)]
    fn sx(self) -> usize {
        let pointer_sized = self as isize;
        pointer_sized as usize
    }
}
241
242impl SignedIndex for isize {
243 const MAX: Self = Self::MAX;
244
245 #[inline(always)]
246 fn truncate(value: usize) -> Self {
247 value as isize
248 }
249
250 #[inline(always)]
251 fn zx(self) -> usize {
252 self as usize
253 }
254
255 #[inline(always)]
256 fn sx(self) -> usize {
257 self as usize
258 }
259}
260
/// Swaps the columns at indices `a` and `b` of `mat`.
///
/// # Panics
///
/// Panics if `a` or `b` is not smaller than the number of columns of `mat`.
#[track_caller]
#[inline]
pub fn swap_cols<E: ComplexField>(mat: MatMut<'_, E>, a: usize, b: usize) {
    assert!(all(a < mat.ncols(), b < mat.ncols()));

    // Swapping a column with itself is a no-op. Returning early also guarantees that the two
    // column views created below refer to different columns.
    if a == b {
        return;
    }

    let mat = mat.into_const();
    let mat_a = mat.col(a);
    let mat_b = mat.col(b);

    // SAFETY (as far as this function shows): `mat` was received as a `MatMut`, and `a != b` at
    // this point, so the two views cast back to mutable are distinct columns of that matrix.
    unsafe {
        zipped!(
            mat_a.const_cast().as_2d_mut(),
            mat_b.const_cast().as_2d_mut(),
        )
    }
    .for_each(|unzipped!(mut a, mut b)| {
        // Read both entries before overwriting either of them.
        let (a_read, b_read) = (a.read(), b.read());
        a.write(b_read);
        b.write(a_read);
    });
}
315
316#[track_caller]
346#[inline]
347pub fn swap_rows<E: ComplexField>(mat: MatMut<'_, E>, a: usize, b: usize) {
348 swap_cols(mat.transpose_mut(), a, b)
349}
350
/// Permutation view backed by [`PermRef`], which borrows its index arrays immutably.
pub type PermutationRef<'a, I, E> = Matrix<PermRef<'a, I, E>>;
/// Permutation view backed by [`PermMut`], which borrows its index arrays mutably.
pub type PermutationMut<'a, I, E> = Matrix<PermMut<'a, I, E>>;
/// Permutation backed by [`PermOwn`], which owns its index arrays.
pub type Permutation<I, E> = Matrix<PermOwn<I, E>>;
357
358impl<I, E: Entity> Permutation<I, E> {
359 #[inline]
361 pub fn as_ref(&self) -> PermutationRef<'_, I, E> {
362 PermutationRef {
363 inner: PermRef {
364 forward: &self.inner.forward,
365 inverse: &self.inner.inverse,
366 __marker: core::marker::PhantomData,
367 },
368 }
369 }
370
371 #[inline]
373 pub fn as_mut(&mut self) -> PermutationMut<'_, I, E> {
374 PermutationMut {
375 inner: PermMut {
376 forward: &mut self.inner.forward,
377 inverse: &mut self.inner.inverse,
378 __marker: core::marker::PhantomData,
379 },
380 }
381 }
382}
383
384impl<I: Index, E: Entity> Permutation<I, E> {
385 #[inline]
393 #[track_caller]
394 pub fn new_checked(forward: alloc::boxed::Box<[I]>, inverse: alloc::boxed::Box<[I]>) -> Self {
395 PermutationRef::<'_, I, E>::new_checked(&forward, &inverse);
396 Self {
397 inner: PermOwn {
398 forward,
399 inverse,
400 __marker: core::marker::PhantomData,
401 },
402 }
403 }
404
405 #[inline]
412 #[track_caller]
413 pub unsafe fn new_unchecked(
414 forward: alloc::boxed::Box<[I]>,
415 inverse: alloc::boxed::Box<[I]>,
416 ) -> Self {
417 let n = forward.len();
418 assert!(all(
419 forward.len() == inverse.len(),
420 n <= I::Signed::MAX.zx(),
421 ));
422 Self {
423 inner: PermOwn {
424 forward,
425 inverse,
426 __marker: core::marker::PhantomData,
427 },
428 }
429 }
430
431 #[inline]
433 pub fn into_arrays(self) -> (alloc::boxed::Box<[I]>, alloc::boxed::Box<[I]>) {
434 (self.inner.forward, self.inner.inverse)
435 }
436
437 #[inline]
439 pub fn len(&self) -> usize {
440 self.inner.forward.len()
441 }
442
443 #[inline]
445 pub fn inverse(self) -> Self {
446 Self {
447 inner: PermOwn {
448 forward: self.inner.inverse,
449 inverse: self.inner.forward,
450 __marker: core::marker::PhantomData,
451 },
452 }
453 }
454
455 #[inline]
457 pub fn cast<T: Entity>(self) -> Permutation<I, T> {
458 Permutation {
459 inner: PermOwn {
460 forward: self.inner.forward,
461 inverse: self.inner.inverse,
462 __marker: core::marker::PhantomData,
463 },
464 }
465 }
466}
467
468impl<'a, I: Index, E: Entity> PermutationRef<'a, I, E> {
469 #[inline]
477 #[track_caller]
478 pub fn new_checked(forward: &'a [I], inverse: &'a [I]) -> Self {
479 #[track_caller]
480 fn check<I: Index>(forward: &[I], inverse: &[I]) {
481 let n = forward.len();
482 assert!(all(
483 forward.len() == inverse.len(),
484 n <= I::Signed::MAX.zx()
485 ));
486 for (i, &p) in forward.iter().enumerate() {
487 let p = p.to_signed().zx();
488 assert!(p < n);
489 assert!(inverse[p].to_signed().zx() == i);
490 }
491 }
492
493 check(I::canonicalize(forward), I::canonicalize(inverse));
494 Self {
495 inner: PermRef {
496 forward,
497 inverse,
498 __marker: core::marker::PhantomData,
499 },
500 }
501 }
502
503 #[inline]
510 #[track_caller]
511 pub unsafe fn new_unchecked(forward: &'a [I], inverse: &'a [I]) -> Self {
512 let n = forward.len();
513 assert!(all(
514 forward.len() == inverse.len(),
515 n <= I::Signed::MAX.zx(),
516 ));
517
518 Self {
519 inner: PermRef {
520 forward,
521 inverse,
522 __marker: core::marker::PhantomData,
523 },
524 }
525 }
526
527 #[inline]
529 pub fn into_arrays(self) -> (&'a [I], &'a [I]) {
530 (self.inner.forward, self.inner.inverse)
531 }
532
533 #[inline]
535 pub fn len(&self) -> usize {
536 debug_assert!(self.inner.inverse.len() == self.inner.forward.len());
537 self.inner.forward.len()
538 }
539
540 #[inline]
542 pub fn inverse(self) -> Self {
543 Self {
544 inner: PermRef {
545 forward: self.inner.inverse,
546 inverse: self.inner.forward,
547 __marker: core::marker::PhantomData,
548 },
549 }
550 }
551
552 #[inline]
554 pub fn cast<T: Entity>(self) -> PermutationRef<'a, I, T> {
555 PermutationRef {
556 inner: PermRef {
557 forward: self.inner.forward,
558 inverse: self.inner.inverse,
559 __marker: core::marker::PhantomData,
560 },
561 }
562 }
563
564 #[inline(always)]
566 pub fn canonicalize(self) -> PermutationRef<'a, I::FixedWidth, E> {
567 PermutationRef {
568 inner: PermRef {
569 forward: I::canonicalize(self.inner.forward),
570 inverse: I::canonicalize(self.inner.inverse),
571 __marker: core::marker::PhantomData,
572 },
573 }
574 }
575
576 #[inline(always)]
578 pub fn uncanonicalize<J: Index>(self) -> PermutationRef<'a, J, E> {
579 assert!(core::mem::size_of::<J>() == core::mem::size_of::<I>());
580 PermutationRef {
581 inner: PermRef {
582 forward: bytemuck::cast_slice(self.inner.forward),
583 inverse: bytemuck::cast_slice(self.inner.inverse),
584 __marker: core::marker::PhantomData,
585 },
586 }
587 }
588}
589
590impl<'a, I: Index, E: Entity> PermutationMut<'a, I, E> {
591 #[inline]
599 #[track_caller]
600 pub fn new_checked(forward: &'a mut [I], inverse: &'a mut [I]) -> Self {
601 PermutationRef::<'_, I, E>::new_checked(forward, inverse);
602 Self {
603 inner: PermMut {
604 forward,
605 inverse,
606 __marker: core::marker::PhantomData,
607 },
608 }
609 }
610
611 #[inline]
618 #[track_caller]
619 pub unsafe fn new_unchecked(forward: &'a mut [I], inverse: &'a mut [I]) -> Self {
620 let n = forward.len();
621 assert!(all(
622 forward.len() == inverse.len(),
623 n <= I::Signed::MAX.zx(),
624 ));
625
626 Self {
627 inner: PermMut {
628 forward,
629 inverse,
630 __marker: core::marker::PhantomData,
631 },
632 }
633 }
634
635 #[inline]
642 pub unsafe fn into_arrays(self) -> (&'a mut [I], &'a mut [I]) {
643 (self.inner.forward, self.inner.inverse)
644 }
645
646 #[inline]
648 pub fn len(&self) -> usize {
649 debug_assert!(self.inner.inverse.len() == self.inner.forward.len());
650 self.inner.forward.len()
651 }
652
653 #[inline]
655 pub fn inverse(self) -> Self {
656 Self {
657 inner: PermMut {
658 forward: self.inner.inverse,
659 inverse: self.inner.forward,
660 __marker: core::marker::PhantomData,
661 },
662 }
663 }
664
665 #[inline]
667 pub fn cast<T: Entity>(self) -> PermutationMut<'a, I, T> {
668 PermutationMut {
669 inner: PermMut {
670 forward: self.inner.forward,
671 inverse: self.inner.inverse,
672 __marker: core::marker::PhantomData,
673 },
674 }
675 }
676
677 #[inline(always)]
679 pub fn canonicalize(self) -> PermutationMut<'a, I::FixedWidth, E> {
680 PermutationMut {
681 inner: PermMut {
682 forward: I::canonicalize_mut(self.inner.forward),
683 inverse: I::canonicalize_mut(self.inner.inverse),
684 __marker: core::marker::PhantomData,
685 },
686 }
687 }
688
689 #[inline(always)]
691 pub fn uncanonicalize<J: Index>(self) -> PermutationMut<'a, J, E> {
692 assert!(core::mem::size_of::<J>() == core::mem::size_of::<I>());
693 PermutationMut {
694 inner: PermMut {
695 forward: bytemuck::cast_slice_mut(self.inner.forward),
696 inverse: bytemuck::cast_slice_mut(self.inner.inverse),
697 __marker: core::marker::PhantomData,
698 },
699 }
700 }
701}
702
// Immutable reborrow of an immutable view: the result is a copy of `self`, typed with the shorter
// lifetime.
impl<'short, 'a, I, E: Entity> Reborrow<'short> for PermutationRef<'a, I, E> {
    type Target = PermutationRef<'short, I, E>;

    #[inline]
    fn rb(&'short self) -> Self::Target {
        *self
    }
}
711
// Mutable reborrow of an immutable view: the target is still an immutable view, so the result is
// again a copy of `self`.
impl<'short, 'a, I, E: Entity> ReborrowMut<'short> for PermutationRef<'a, I, E> {
    type Target = PermutationRef<'short, I, E>;

    #[inline]
    fn rb_mut(&'short mut self) -> Self::Target {
        *self
    }
}
720
721impl<'short, 'a, I, E: Entity> Reborrow<'short> for PermutationMut<'a, I, E> {
722 type Target = PermutationRef<'short, I, E>;
723
724 #[inline]
725 fn rb(&'short self) -> Self::Target {
726 PermutationRef {
727 inner: PermRef {
728 forward: &*self.inner.forward,
729 inverse: &*self.inner.inverse,
730 __marker: core::marker::PhantomData,
731 },
732 }
733 }
734}
735
736impl<'short, 'a, I, E: Entity> ReborrowMut<'short> for PermutationMut<'a, I, E> {
737 type Target = PermutationMut<'short, I, E>;
738
739 #[inline]
740 fn rb_mut(&'short mut self) -> Self::Target {
741 PermutationMut {
742 inner: PermMut {
743 forward: &mut *self.inner.forward,
744 inverse: &mut *self.inner.inverse,
745 __marker: core::marker::PhantomData,
746 },
747 }
748 }
749}
750
751impl<'a, I: Debug, E: Entity> Debug for PermutationRef<'a, I, E> {
752 #[inline]
753 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
754 self.inner.fmt(f)
755 }
756}
757impl<'a, I: Debug, E: Entity> Debug for PermutationMut<'a, I, E> {
758 #[inline]
759 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
760 self.rb().fmt(f)
761 }
762}
763impl<'a, I: Debug, E: Entity> Debug for Permutation<I, E> {
764 #[inline]
765 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
766 self.as_ref().fmt(f)
767 }
768}
769
770#[inline]
778#[track_caller]
779pub fn permute_cols<I: Index, E: ComplexField>(
780 dst: MatMut<'_, E>,
781 src: MatRef<'_, E>,
782 perm_indices: PermutationRef<'_, I, E>,
783) {
784 assert!(all(
785 src.nrows() == dst.nrows(),
786 src.ncols() == dst.ncols(),
787 perm_indices.into_arrays().0.len() == src.ncols(),
788 ));
789
790 permute_rows(
791 dst.transpose_mut(),
792 src.transpose(),
793 perm_indices.canonicalize(),
794 );
795}
796
/// Writes the row-permuted `src` into `dst`: `dst[i, j] = src[perm_indices[i], j]`.
///
/// # Panics
///
/// Panics if `src` and `dst` differ in shape, or if the length of the permutation is not the
/// number of rows of `src`.
#[inline]
#[track_caller]
pub fn permute_rows<I: Index, E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    perm_indices: PermutationRef<'_, I, E>,
) {
    // The body lives in an inner function that is called with the canonical (fixed-width) form of
    // the permutation.
    #[track_caller]
    fn implementation<I: Index, E: ComplexField>(
        dst: MatMut<'_, E>,
        src: MatRef<'_, E>,
        perm_indices: PermutationRef<'_, I, E>,
    ) {
        assert!(all(
            src.nrows() == dst.nrows(),
            src.ncols() == dst.ncols(),
            perm_indices.into_arrays().0.len() == src.nrows(),
        ));

        // Wrap the dimensions and operands in their `constrained` counterparts; `perm` is the
        // forward array of the constrained permutation.
        constrained::Size::with2(src.nrows(), src.ncols(), |m, n| {
            let mut dst = constrained::MatMut::new(dst, m, n);
            let src = constrained::MatRef::new(src, m, n);
            let perm = constrained::permutation::PermutationRef::new(perm_indices, m)
                .into_arrays()
                .0;

            // Choose the traversal from the strides of `dst`.
            if dst.rb().into_inner().row_stride().unsigned_abs()
                < dst.rb().into_inner().col_stride().unsigned_abs()
            {
                // Row stride is the smaller one: go column by column, element by element.
                for j in n.indices() {
                    for i in m.indices() {
                        dst.rb_mut().write(i, j, src.read(perm[i].zx(), j));
                    }
                }
            } else {
                // Otherwise copy whole rows: row `perm[i]` of `src` becomes row `i` of `dst`.
                for i in m.indices() {
                    let src_i = src.into_inner().row(perm[i].zx().into_inner());
                    let mut dst_i = dst.rb_mut().into_inner().row_mut(i.into_inner());

                    dst_i.copy_from(src_i);
                }
            }
        });
    }

    implementation(dst, src, perm_indices.canonicalize())
}
851
/// Returns the stack requirement of a temporary matrix of `nrows` rows and `ncols` columns with
/// entries of type `E`, as computed by `temp_mat_req`.
///
/// Presumably intended to size the `stack` argument of [`permute_rows_in_place`], which allocates
/// one temporary matrix of the same shape as its input. The type parameter `I` does not affect
/// the result.
pub fn permute_rows_in_place_req<I: Index, E: Entity>(
    nrows: usize,
    ncols: usize,
) -> Result<StackReq, SizeOverflow> {
    temp_mat_req::<E>(nrows, ncols)
}
860
/// Returns the stack requirement of a temporary matrix of `nrows` rows and `ncols` columns with
/// entries of type `E`, as computed by `temp_mat_req`.
///
/// Presumably intended to size the `stack` argument of [`permute_cols_in_place`], which allocates
/// one temporary matrix of the same shape as its input. The type parameter `I` does not affect
/// the result.
pub fn permute_cols_in_place_req<I: Index, E: Entity>(
    nrows: usize,
    ncols: usize,
) -> Result<StackReq, SizeOverflow> {
    temp_mat_req::<E>(nrows, ncols)
}
869
870#[inline]
877#[track_caller]
878pub fn permute_rows_in_place<I: Index, E: ComplexField>(
879 matrix: MatMut<'_, E>,
880 perm_indices: PermutationRef<'_, I, E>,
881 stack: PodStack<'_>,
882) {
883 #[inline]
884 #[track_caller]
885 fn implementation<E: ComplexField, I: Index>(
886 matrix: MatMut<'_, E>,
887 perm_indices: PermutationRef<'_, I, E>,
888 stack: PodStack<'_>,
889 ) {
890 let mut matrix = matrix;
891 let (mut tmp, _) = temp_mat_uninit::<E>(matrix.nrows(), matrix.ncols(), stack);
892 tmp.rb_mut().copy_from(matrix.rb());
893 permute_rows(matrix.rb_mut(), tmp.rb(), perm_indices);
894 }
895
896 implementation(matrix, perm_indices.canonicalize(), stack)
897}
898
899#[inline]
906#[track_caller]
907pub fn permute_cols_in_place<I: Index, E: ComplexField>(
908 matrix: MatMut<'_, E>,
909 perm_indices: PermutationRef<'_, I, E>,
910 stack: PodStack<'_>,
911) {
912 #[inline]
913 #[track_caller]
914 fn implementation<I: Index, E: ComplexField>(
915 matrix: MatMut<'_, E>,
916 perm_indices: PermutationRef<'_, I, E>,
917 stack: PodStack<'_>,
918 ) {
919 let mut matrix = matrix;
920 let (mut tmp, _) = temp_mat_uninit::<E>(matrix.nrows(), matrix.ncols(), stack);
921 tmp.rb_mut().copy_from(matrix.rb());
922 permute_cols(matrix.rb_mut(), tmp.rb(), perm_indices);
923 }
924
925 implementation(matrix, perm_indices.canonicalize(), stack)
926}