1use std::ops::{Deref, Mul};
5
6use crate::dense_vector::{DenseVector, DenseVectorMut};
7use crate::indexing::SpIndex;
8use crate::sparse::{CompressedStorage, CsMatI, CsMatViewI};
9
/// Backing storage of a [`Permutation`].
///
/// `Identity` stores nothing. `FinitePerm` stores the forward index array
/// together with its inverse, so both directions can be looked up directly.
#[derive(Debug, Clone)]
enum PermStorage<I, IndStorage: Deref<Target = [I]>> {
    /// The identity permutation; no index arrays are kept.
    Identity,
    /// An explicit permutation and its inverse.
    FinitePerm {
        /// Forward index array.
        perm: IndStorage,
        /// Inverse of `perm`: `perm_inv[perm[i]] == i`.
        perm_inv: IndStorage,
    },
}
21
22use self::PermStorage::{FinitePerm, Identity};
23
24#[derive(Debug, Clone)]
25pub struct Permutation<I, IndStorage>
26where
27 IndStorage: Deref<Target = [I]>,
28{
29 dim: usize,
30 storage: PermStorage<I, IndStorage>,
31}
32
33pub type PermOwned = Permutation<usize, Vec<usize>>;
34pub type PermOwnedI<I> = Permutation<I, Vec<I>>;
35
36pub type PermView<'a> = Permutation<usize, &'a [usize]>;
37pub type PermViewI<'a, I> = Permutation<I, &'a [I]>;
38
39pub fn perm_is_valid<I: SpIndex>(perm: &[I]) -> bool {
40 let n = perm.len();
41 let mut seen = vec![false; n];
42 for i in perm {
43 if *i < I::zero() || *i >= I::from_usize(n) || seen[i.index()] {
44 return false;
45 }
46 seen[i.index()] = true;
47 }
48 true
49}
50
51impl<I: SpIndex> PermOwnedI<I> {
52 pub fn new(perm: Vec<I>) -> Self {
53 assert!(perm_is_valid(&perm));
54 Self::new_trusted(perm)
55 }
56
57 pub(crate) fn new_trusted(perm: Vec<I>) -> Self {
58 let mut perm_inv = perm.clone();
59 for (ind, val) in perm.iter().enumerate() {
60 perm_inv[val.index()] = I::from_usize(ind);
61 }
62 Self {
63 dim: perm.len(),
64 storage: FinitePerm { perm, perm_inv },
65 }
66 }
67}
68
69impl<I: SpIndex> PermViewI<'_, I> {
70 pub fn reborrow(&self) -> Self {
71 match self.storage {
72 Identity => Self {
73 dim: self.dim,
74 storage: Identity,
75 },
76 FinitePerm {
77 perm: p,
78 perm_inv: p_,
79 } => Self {
80 dim: self.dim,
81 storage: FinitePerm {
82 perm: &p[..],
83 perm_inv: &p_[..],
84 },
85 },
86 }
87 }
88
89 pub fn reborrow_inv(&self) -> Self {
90 match self.storage {
91 Identity => Self {
92 dim: self.dim,
93 storage: Identity,
94 },
95 FinitePerm {
96 perm: p,
97 perm_inv: p_,
98 } => Self {
99 dim: self.dim,
100 storage: FinitePerm {
101 perm: &p_[..],
102 perm_inv: &p[..],
103 },
104 },
105 }
106 }
107}
108
109impl<I: SpIndex, IndStorage> Permutation<I, IndStorage>
110where
111 IndStorage: Deref<Target = [I]>,
112{
113 pub fn identity(dim: usize) -> Self {
114 Self {
115 dim,
116 storage: Identity,
117 }
118 }
119
120 pub fn inv(&self) -> PermViewI<I> {
121 match self.storage {
122 Identity => PermViewI {
123 dim: self.dim,
124 storage: Identity,
125 },
126 FinitePerm {
127 perm: ref p,
128 perm_inv: ref p_,
129 } => PermViewI {
130 dim: self.dim,
131 storage: FinitePerm {
132 perm: &p_[..],
133 perm_inv: &p[..],
134 },
135 },
136 }
137 }
138
139 pub fn dim(&self) -> usize {
140 self.dim
141 }
142
143 pub fn is_identity(&self) -> bool {
145 match self.storage {
146 Identity => true,
147 FinitePerm {
148 perm: ref p,
149 perm_inv: ref _p_,
150 } => p.iter().enumerate().all(|(ind, x)| ind == x.index()),
151 }
152 }
153
154 pub fn view(&self) -> PermViewI<I> {
155 match self.storage {
156 Identity => PermViewI {
157 dim: self.dim,
158 storage: Identity,
159 },
160 FinitePerm {
161 perm: ref p,
162 perm_inv: ref p_,
163 } => PermViewI {
164 dim: self.dim,
165 storage: FinitePerm {
166 perm: &p[..],
167 perm_inv: &p_[..],
168 },
169 },
170 }
171 }
172
173 pub fn owned_clone(&self) -> PermOwnedI<I> {
174 match self.storage {
175 Identity => PermOwnedI {
176 dim: self.dim,
177 storage: Identity,
178 },
179 FinitePerm {
180 perm: ref p,
181 perm_inv: ref p_,
182 } => PermOwnedI {
183 dim: self.dim,
184 storage: FinitePerm {
185 perm: p.iter().copied().collect(),
186 perm_inv: p_.iter().copied().collect(),
187 },
188 },
189 }
190 }
191
192 pub fn at(&self, index: usize) -> usize {
193 assert!(index < self.dim);
194 match self.storage {
195 Identity => index,
196 FinitePerm { perm: ref p, .. } => p[index].index_unchecked(),
197 }
198 }
199
200 pub fn at_inv(&self, index: usize) -> usize {
201 assert!(index < self.dim);
202 match self.storage {
203 Identity => index,
204 FinitePerm {
205 perm_inv: ref p_, ..
206 } => p_[index].index_unchecked(),
207 }
208 }
209
210 pub fn vec(&self) -> Vec<I> {
212 match self.storage {
213 Identity => (0..self.dim).map(I::from_usize).collect(),
214 FinitePerm { perm: ref p, .. } => p.to_vec(),
215 }
216 }
217
218 pub fn inv_vec(&self) -> Vec<I> {
220 match self.storage {
221 Identity => (0..self.dim).map(I::from_usize).collect(),
222 FinitePerm {
223 perm_inv: ref p_, ..
224 } => p_.to_vec(),
225 }
226 }
227
228 pub fn to_other_idx_type<I2>(&self) -> PermOwnedI<I2>
229 where
230 I2: SpIndex,
231 {
232 match self.storage {
233 Identity => PermOwnedI::identity(self.dim),
234 FinitePerm {
235 perm: ref p,
236 perm_inv: ref p_,
237 } => {
238 let perm = p
239 .iter()
240 .map(|i| I2::from_usize(i.index_unchecked()))
241 .collect();
242 let perm_inv = p_
243 .iter()
244 .map(|i| I2::from_usize(i.index_unchecked()))
245 .collect();
246 PermOwnedI {
247 dim: self.dim,
248 storage: FinitePerm { perm, perm_inv },
249 }
250 }
251 }
252 }
253}
254
255impl<'b, V, I, IndStorage> Mul<V> for &'b Permutation<I, IndStorage>
256where
257 IndStorage: 'b + Deref<Target = [I]>,
258 V: DenseVector,
259 <V as DenseVector>::Owned:
260 DenseVectorMut + DenseVector<Scalar = <V as DenseVector>::Scalar>,
261 <V as DenseVector>::Scalar: Clone,
262 I: SpIndex,
263{
264 type Output = V::Owned;
265 fn mul(self, rhs: V) -> Self::Output {
266 assert_eq!(self.dim, rhs.dim());
267 let mut res = rhs.to_owned();
268 match self.storage {
269 Identity => res,
270 FinitePerm { perm: ref p, .. } => {
271 for (i, pi) in p.iter().enumerate() {
272 *res.index_mut(i) = rhs.index(pi.index_unchecked()).clone();
273 }
274 res
275 }
276 }
277 }
278}
279
280impl<V, I, IndStorage> Mul<V> for Permutation<I, IndStorage>
281where
282 IndStorage: Deref<Target = [I]>,
283 V: DenseVector,
284 <V as DenseVector>::Owned:
285 DenseVectorMut + DenseVector<Scalar = <V as DenseVector>::Scalar>,
286 <V as DenseVector>::Scalar: Clone,
287 I: SpIndex,
288{
289 type Output = V::Owned;
290 fn mul(self, rhs: V) -> Self::Output {
291 &self * rhs
292 }
293}
294
295fn permute_outer<N, I, Iptr>(
297 mat: CsMatViewI<N, I, Iptr>,
298 perm: PermViewI<I>,
299) -> CsMatI<N, I, Iptr>
300where
301 N: Clone + ::std::fmt::Debug,
302 I: SpIndex,
303 Iptr: SpIndex,
304{
305 assert!(mat.outer_dims() == perm.dim());
306 if mat.rows() == 0 || mat.cols() == 0 {
307 return mat.to_owned();
308 }
309
310 let mut indptr = Vec::with_capacity(mat.indptr().len());
311 let mut indices = Vec::with_capacity(mat.indices().len());
312 let mut data = Vec::with_capacity(mat.data().len());
313
314 let p = match perm.storage {
315 Identity => unreachable!(),
316 FinitePerm {
317 perm: p,
318 perm_inv: _,
319 } => p,
320 };
321
322 let mut nnz = Iptr::zero();
323 indptr.push(nnz);
324 let mut tmp = Vec::with_capacity(mat.max_outer_nnz());
325 for in_outer in p {
326 nnz += mat.indptr().nnz_in_outer(in_outer.index());
327 indptr.push(nnz);
328 tmp.clear();
329
330 let outer = mat.outer_view(in_outer.index()).unwrap();
331 for (ind, val) in outer.indices().iter().zip(outer.data()) {
332 tmp.push((*ind, val.clone()));
333 }
334 tmp.sort_by_key(|(ind, _)| *ind);
335 for (ind, val) in &tmp {
336 indices.push(*ind);
337 data.push(val.clone());
338 }
339 }
340
341 match mat.storage() {
342 CompressedStorage::CSR => {
343 CsMatI::new(mat.shape(), indptr, indices, data)
344 }
345 CompressedStorage::CSC => {
346 CsMatI::new_csc(mat.shape(), indptr, indices, data)
347 }
348 }
349}
350
351fn permute_inner<N, I, Iptr>(
353 mat: CsMatViewI<N, I, Iptr>,
354 perm: PermViewI<I>,
355) -> CsMatI<N, I, Iptr>
356where
357 N: Clone + ::std::fmt::Debug,
358 I: SpIndex,
359 Iptr: SpIndex,
360{
361 assert!(mat.inner_dims() == perm.dim());
362 if mat.rows() == 0 || mat.cols() == 0 {
363 return mat.to_owned();
364 }
365
366 let mut indptr = Vec::with_capacity(mat.indptr().len());
367 let mut indices = Vec::with_capacity(mat.indices().len());
368 let mut data = Vec::with_capacity(mat.data().len());
369 let p_ = match perm.storage {
370 Identity => unreachable!(),
371 FinitePerm {
372 perm: _,
373 perm_inv: p_,
374 } => p_,
375 };
376
377 let mut nnz = Iptr::zero();
378 indptr.push(nnz);
379 let mut tmp = Vec::with_capacity(mat.max_outer_nnz());
380 for in_outer in 0..mat.outer_dims() {
381 nnz += mat.indptr().nnz_in_outer(in_outer.index());
382 indptr.push(nnz);
383 tmp.clear();
384
385 let outer = mat.outer_view(in_outer.index()).unwrap();
386 for (ind, val) in outer.indices().iter().zip(outer.data()) {
387 tmp.push((p_[ind.index()], val.clone()));
388 }
389 tmp.sort_by_key(|(ind, _)| *ind);
390 for (ind, val) in &tmp {
391 indices.push(*ind);
392 data.push(val.clone());
393 }
394 }
395
396 match mat.storage() {
397 CompressedStorage::CSR => {
398 CsMatI::new(mat.shape(), indptr, indices, data)
399 }
400 CompressedStorage::CSC => {
401 CsMatI::new_csc(mat.shape(), indptr, indices, data)
402 }
403 }
404}
405
406pub fn permute_rows<N, I, Iptr>(
408 mat: CsMatViewI<N, I, Iptr>,
409 perm: PermViewI<I>,
410) -> CsMatI<N, I, Iptr>
411where
412 N: Clone + ::std::fmt::Debug,
413 I: SpIndex,
414 Iptr: SpIndex,
415{
416 match mat.storage {
417 CompressedStorage::CSC => permute_inner(mat, perm),
418 CompressedStorage::CSR => permute_outer(mat, perm),
419 }
420}
421
422pub fn permute_cols<N, I, Iptr>(
424 mat: CsMatViewI<N, I, Iptr>,
425 perm: PermViewI<I>,
426) -> CsMatI<N, I, Iptr>
427where
428 N: Clone + ::std::fmt::Debug,
429 I: SpIndex,
430 Iptr: SpIndex,
431{
432 match mat.storage {
433 CompressedStorage::CSC => permute_outer(mat, perm),
434 CompressedStorage::CSR => permute_inner(mat, perm),
435 }
436}
437
438pub fn transform_mat_papt<N, I, Iptr>(
440 mat: CsMatViewI<N, I, Iptr>,
441 perm: PermViewI<I>,
442) -> CsMatI<N, I, Iptr>
443where
444 N: Clone + ::std::fmt::Debug,
445 I: SpIndex,
446 Iptr: SpIndex,
447{
448 assert!(mat.rows() == mat.cols());
449 assert!(mat.rows() == perm.dim());
450 if perm.is_identity() || mat.rows() == 0 {
451 return mat.to_owned();
452 }
453 let mut indptr = Vec::with_capacity(mat.indptr().len());
456 let mut indices = Vec::with_capacity(mat.indices().len());
457 let mut data = Vec::with_capacity(mat.data().len());
458 let (p, p_) = match perm.storage {
459 Identity => unreachable!(),
460 FinitePerm {
461 perm: p,
462 perm_inv: p_,
463 } => (p, p_),
464 };
465 let mut nnz = Iptr::zero();
466 indptr.push(nnz);
467 let mut tmp = Vec::with_capacity(mat.max_outer_nnz());
468 for in_outer in p {
469 nnz += mat.indptr().nnz_in_outer(in_outer.index());
470 indptr.push(nnz);
471 tmp.clear();
472 let outer = mat.outer_view(in_outer.index()).unwrap();
473 for (ind, val) in outer.indices().iter().zip(outer.data()) {
474 tmp.push((p_[ind.index()], val.clone()));
475 }
476 tmp.sort_by_key(|(ind, _)| *ind);
477 for (ind, val) in &tmp {
478 indices.push(*ind);
479 data.push(val.clone());
480 }
481 }
482
483 match mat.storage() {
484 CompressedStorage::CSR => {
485 CsMatI::new(mat.shape(), indptr, indices, data)
486 }
487 CompressedStorage::CSC => {
488 CsMatI::new_csc(mat.shape(), indptr, indices, data)
489 }
490 }
491}
492
493pub fn transform_mat_paq<N, I, Iptr>(
497 mat: CsMatViewI<N, I, Iptr>,
498 row_perm: PermViewI<I>,
499 col_perm: PermViewI<I>,
500) -> CsMatI<N, I, Iptr>
501where
502 N: Clone + ::std::fmt::Debug,
503 I: SpIndex,
504 Iptr: SpIndex,
505{
506 assert!(mat.rows() == row_perm.dim());
507 assert!(mat.cols() == col_perm.dim());
508
509 if (row_perm.is_identity() && col_perm.is_identity())
510 || mat.rows() == 0
511 || mat.cols() == 0
512 {
513 return mat.to_owned();
514 }
515
516 let (p, p_) = match row_perm.storage {
522 Identity => {
523 return permute_cols(mat, col_perm);
525 }
526 FinitePerm {
527 perm: p,
528 perm_inv: p_,
529 } => (p, p_),
530 };
531
532 let (q, q_) = match col_perm.storage {
533 Identity => {
534 return permute_rows(mat, row_perm);
536 }
537 FinitePerm {
538 perm: q,
539 perm_inv: q_,
540 } => (q, q_),
541 };
542
543 let (outer_perm, inner_perm) = match mat.storage() {
545 CompressedStorage::CSR => (p, q_),
546 CompressedStorage::CSC => (q, p_),
547 };
548
549 let mut indptr = Vec::with_capacity(mat.indptr().len());
553 let mut indices = Vec::with_capacity(mat.indices().len());
554 let mut data = Vec::with_capacity(mat.data().len());
555 let mut nnz = Iptr::zero();
556 indptr.push(nnz);
557 let mut tmp = Vec::with_capacity(mat.max_outer_nnz());
558 for in_outer in outer_perm {
559 nnz += mat.indptr().nnz_in_outer(in_outer.index());
560 indptr.push(nnz);
561 tmp.clear();
562 let outer = mat.outer_view(in_outer.index()).unwrap();
563 for (ind, val) in outer.indices().iter().zip(outer.data()) {
564 tmp.push((inner_perm[ind.index()], val.clone()));
565 }
566 tmp.sort_by_key(|(ind, _)| *ind);
567 for (ind, val) in &tmp {
568 indices.push(*ind);
569 data.push(val.clone());
570 }
571 }
572
573 match mat.storage() {
574 CompressedStorage::CSR => {
575 CsMatI::new(mat.shape(), indptr, indices, data)
576 }
577 CompressedStorage::CSC => {
578 CsMatI::new_csc(mat.shape(), indptr, indices, data)
579 }
580 }
581}
582
#[cfg(test)]
mod test {
    use crate::sparse::CsMat;

    #[test]
    fn perm_mul() {
        let perm = super::PermOwned::new(vec![2, 1, 3, 0, 4]);

        // Multiplication gathers `x[perm[i]]` into slot `i`.
        let plain = vec![5, 1, 2, 3, 4];
        let permuted = &perm * &plain;
        assert_eq!(&permuted, &[2, 1, 3, 5, 4]);

        let arr = ndarray::arr1(&[5, 1, 2, 3, 4]);
        let arr_permuted = perm.view() * arr.view();
        assert_eq!(arr_permuted, ndarray::arr1(&[2, 1, 3, 5, 4]));
    }

    #[test]
    fn transform_mat_papt() {
        let perm = super::PermOwned::new(vec![2, 1, 3, 0, 4]);
        let mat = CsMat::new_csc(
            (5, 5),
            vec![0, 3, 4, 5, 8, 10],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
        );
        let expected = CsMat::new_csc(
            (5, 5),
            vec![0, 1, 2, 5, 8, 10],
            vec![2, 1, 0, 2, 3, 2, 3, 4, 3, 4],
            vec![1, 2, 1, 1, 3, 3, 1, 1, 1, 1],
        );

        let result = super::transform_mat_papt(mat.view(), perm.view());
        assert_eq!(expected, result);
    }

    #[test]
    fn transform_mat_paq() {
        let rows = super::PermOwned::new(vec![2, 1, 3, 0, 4]);
        let cols = super::PermOwned::new(vec![1, 2, 3, 0, 4]);
        let mat = CsMat::new_csc(
            (5, 5),
            vec![0, 3, 4, 5, 8, 10],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
        );
        let expected = CsMat::new_csc(
            (5, 5),
            vec![0, 1, 2, 5, 8, 10],
            vec![1, 2, 0, 2, 3, 2, 3, 4, 3, 4],
            vec![2, 1, 1, 1, 3, 3, 1, 1, 1, 1],
        )
        .to_dense();

        // Combined transform, CSC input.
        let from_csc =
            super::transform_mat_paq(mat.view(), rows.view(), cols.view());
        assert_eq!(expected, from_csc.to_dense());

        // Combined transform, CSR input.
        let csr = mat.to_other_storage();
        let from_csr =
            super::transform_mat_paq(csr.view(), rows.view(), cols.view());
        assert_eq!(expected, from_csr.to_dense());

        // Same result when the two sides are applied one after the other.
        let cols_done = super::permute_cols(mat.view(), cols.view());
        let both_done = super::permute_rows(cols_done.view(), rows.view());
        assert_eq!(expected, both_done.to_dense());
    }

    #[test]
    fn permute_rows() {
        let mat = CsMat::new_csc(
            (5, 4),
            vec![0, 3, 4, 5, 8],
            vec![0, 3, 4, 1, 3, 0, 2, 3],
            vec![1, 3, 1, 2, 1, 3, 1, 1],
        );
        // Swapping rows 0 and 1.
        let expected = CsMat::new_csc(
            (5, 4),
            vec![0, 3, 4, 5, 8],
            vec![1, 3, 4, 0, 3, 1, 2, 3],
            vec![1, 3, 1, 2, 1, 3, 1, 1],
        )
        .to_dense();

        let swap = super::PermOwned::new(vec![1, 0, 2, 3, 4]);
        let from_csc = super::permute_rows(mat.view(), swap.view());
        assert_eq!(expected, from_csc.to_dense());

        let swap = super::PermOwned::new(vec![1, 0, 2, 3, 4]);
        let csr = mat.to_other_storage();
        let from_csr = super::permute_rows(csr.view(), swap.view());
        assert_eq!(expected, from_csr.to_dense());
    }

    #[test]
    fn permute_cols() {
        let mat = CsMat::new_csc(
            (5, 4),
            vec![0, 3, 4, 5, 8],
            vec![0, 3, 4, 1, 3, 0, 2, 3],
            vec![1, 3, 1, 2, 1, 3, 1, 1],
        );
        // Permuting the columns of the transpose, then transposing back,
        // equals permuting the rows of the original.
        let expected = CsMat::new_csc(
            (5, 4),
            vec![0, 3, 4, 5, 8],
            vec![1, 3, 4, 0, 3, 1, 2, 3],
            vec![1, 3, 1, 2, 1, 3, 1, 1],
        )
        .to_dense();

        let swap = super::PermOwned::new(vec![1, 0, 2, 3, 4]);
        let via_csc = super::permute_cols(mat.transpose_view(), swap.view());
        assert_eq!(expected, via_csc.transpose_view().to_dense());

        let swap = super::PermOwned::new(vec![1, 0, 2, 3, 4]);
        let csr = mat.to_other_storage();
        let via_csr = super::permute_cols(csr.transpose_view(), swap.view());
        assert_eq!(expected, via_csr.transpose_view().to_dense());
    }

    #[test]
    fn perm_validity() {
        use super::perm_is_valid;
        // Every index of 0..5 appears exactly once.
        assert!(perm_is_valid(&[0, 1, 2, 3, 4]));
        assert!(perm_is_valid(&[1, 0, 3, 4, 2]));
        // Out-of-range entry.
        assert!(!perm_is_valid(&[0, 1, 2, 3, 5]));
        // Repeated entry.
        assert!(!perm_is_valid(&[0, 1, 2, 3, 3]));
    }
}