1use sprs::PermView;
2use yui_core::{CloneAnd, Ring, RingOps};
3use crate::sparse::{SpMat, MatTrait, SpVec};
4
/// A linear transformation between sparse vector spaces, recorded as a
/// sequence of steps.
///
/// Each step is a pair of sparse matrices `(f, b)`. The forward matrix `f`
/// maps from the current target space to a new one. The backward matrix `b`
/// has the transposed shape and maps back. Forward matrices are applied in
/// the order they were appended; backward matrices are applied in reverse.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Trans<R>
where R: Ring, for <'x> &'x R: RingOps<R> {
    /// Dimension of the space the transformation maps from.
    src_dim: usize,
    /// Dimension of the space reached after all recorded steps.
    tgt_dim: usize,
    /// Forward matrices, one per step, applied first-to-last.
    f_mats: Vec<SpMat<R>>,
    /// Backward matrices, one per step (kept the same length as `f_mats`),
    /// applied last-to-first.
    b_mats: Vec<SpMat<R>>,
}
14
15impl<R> Trans<R>
16where R: Ring, for <'x> &'x R: RingOps<R> {
17 pub fn id(n: usize) -> Self {
18 Self {
19 src_dim: n,
20 tgt_dim: n,
21 f_mats: vec![],
22 b_mats: vec![]
23 }
24 }
25
26 pub fn zero() -> Self {
27 Self::id(0)
28 }
29
30 pub fn new(f: SpMat<R>, b: SpMat<R>) -> Self {
31 let mut t = Self::id(f.ncols());
32 t.append(f, b);
33 t
34 }
35
36 pub fn src_dim(&self) -> usize {
37 self.src_dim
38 }
39
40 pub fn tgt_dim(&self) -> usize {
41 self.tgt_dim
42 }
43
44 pub fn is_id(&self) -> bool {
45 self.f_mats.is_empty()
46 }
47
48 pub fn forward(&self, v: &SpVec<R>) -> SpVec<R> {
49 assert_eq!(v.dim(), self.src_dim);
50 self.f_mats.iter().fold(v.clone(), |v, f| f * v)
51 }
52
53 pub fn backward(&self, v: &SpVec<R>) -> SpVec<R> {
54 assert_eq!(v.dim(), self.tgt_dim);
55 self.b_mats.iter().rev().fold(v.clone(), |v, f| f * v)
56 }
57
58 pub fn append(&mut self, f: SpMat<R>, b: SpMat<R>) {
59 assert_eq!(f.ncols(), b.nrows());
60 assert_eq!(f.nrows(), b.ncols());
61 assert_eq!(f.ncols(), self.tgt_dim);
62
63 self.tgt_dim = f.nrows();
64 self.f_mats.push(f);
65 self.b_mats.push(b);
66 }
67
68 pub fn append_perm(&mut self, p: PermView) {
69 assert_eq!(p.dim(), self.tgt_dim);
70 let f = SpMat::from_row_perm(p.clone());
71 let b = SpMat::from_col_perm(p);
72 self.append(f, b)
73 }
74
75 pub fn merge(&mut self, mut other: Trans<R>) {
76 assert_eq!(self.tgt_dim, other.src_dim);
77
78 self.tgt_dim = other.tgt_dim;
79 self.f_mats.append(&mut other.f_mats);
80 self.b_mats.append(&mut other.b_mats);
81 }
82
83 pub fn merged(&self, other: &Trans<R>) -> Self {
84 self.clone_and(|t|
85 t.merge(other.clone())
86 )
87 }
88
89 pub fn forward_mat(&self) -> SpMat<R> {
90 if self.f_mats.len() == 1 {
92 self.f_mats[0].clone()
93 } else {
94 self.f_mats.iter().rev().fold(
95 SpMat::id(self.tgt_dim),
96 |res, f| res * f
97 )
98 }
99 }
100
101 pub fn backward_mat(&self) -> SpMat<R> {
102 if self.b_mats.len() == 1 {
104 self.b_mats[0].clone()
105 } else {
106 self.b_mats.iter().rev().fold(
107 SpMat::id(self.tgt_dim),
108 |res, b| b * res
109 )
110 }
111 }
112
113 pub fn reduce(&mut self) {
114 if self.f_mats.len() > 1 {
115 let f = self.forward_mat();
116 self.f_mats = vec![f];
117 }
118
119 if self.b_mats.len() > 1 {
120 let b = self.backward_mat();
121 self.b_mats = vec![b];
122 }
123 }
124
125 pub fn sub(&self, indices: &[usize]) -> Self {
126 let n = self.tgt_dim();
127 let p = indices.len();
128 let f = SpMat::from_entries(
129 (p, n),
130 indices.iter().enumerate().map(|(i, &j)|
131 (i, j, R::one())
132 )
133 );
134 let b = SpMat::from_entries(
135 (n, p),
136 indices.iter().enumerate().map(|(i, &j)|
137 (j, i, R::one())
138 )
139 );
140 self.clone_and(|sub|
141 sub.append(f, b)
142 )
143 }
144}
145
146impl<R> Default for Trans<R>
147where R: Ring, for <'x> &'x R: RingOps<R> {
148 fn default() -> Self {
149 Self::zero()
150 }
151}
152
#[cfg(test)]
mod tests {
    use sprs::PermOwned;

    use super::*;
    use crate::sparse::*;

    /// Projection of a 5-dimensional space onto its first three coordinates,
    /// with the matching inclusion as the backward map.
    fn proj_5_to_3() -> Trans<i32> {
        Trans::new(
            SpMat::id(5).submat_rows(0..3),
            SpMat::id(5).submat_cols(0..3),
        )
    }

    #[test]
    fn id() {
        let t = Trans::<i32>::id(5);

        let v = SpVec::from(vec![0,1,2,3,4]);
        let w = t.forward(&v);
        let x = t.backward(&v);

        assert_eq!(w, SpVec::from(vec![0,1,2,3,4]));
        assert_eq!(x, SpVec::from(vec![0,1,2,3,4]));
    }

    #[test]
    fn trans() {
        let t = proj_5_to_3();

        let v = SpVec::from(vec![0,1,2,3,4]);
        let w = t.forward(&v);
        let x = t.backward(&w);

        assert_eq!(w, SpVec::from(vec![0,1,2]));
        assert_eq!(x, SpVec::from(vec![0,1,2,0,0]));
    }

    #[test]
    fn append_perm() {
        let mut t = proj_5_to_3();
        t.append_perm(
            PermOwned::new(vec![1,2,0]).view()
        );

        let v = SpVec::from(vec![0,1,2,3,4]);
        let w = t.forward(&v);
        let x = t.backward(&w);

        assert_eq!(w.into_vec(), vec![2,0,1]);
        assert_eq!(x.into_vec(), vec![0,1,2,0,0]);
    }

    #[test]
    fn merge_composes_in_order() {
        // Same composition as `append_perm`, built by merging two Trans.
        let mut t = proj_5_to_3();
        let mut p = Trans::<i32>::id(3);
        p.append_perm(PermOwned::new(vec![1,2,0]).view());
        t.merge(p);

        assert_eq!(t.src_dim(), 5);
        assert_eq!(t.tgt_dim(), 3);

        let v = SpVec::from(vec![0,1,2,3,4]);
        let w = t.forward(&v);
        let x = t.backward(&w);

        assert_eq!(w.into_vec(), vec![2,0,1]);
        assert_eq!(x.into_vec(), vec![0,1,2,0,0]);
    }

    #[test]
    fn reduce_preserves_action() {
        let mut t = proj_5_to_3();
        t.append_perm(PermOwned::new(vec![1,2,0]).view());
        t.reduce();

        assert!(!t.is_id());
        assert_eq!(t.src_dim(), 5);
        assert_eq!(t.tgt_dim(), 3);

        let v = SpVec::from(vec![0,1,2,3,4]);
        let w = t.forward(&v);
        let x = t.backward(&w);

        assert_eq!(w.into_vec(), vec![2,0,1]);
        assert_eq!(x.into_vec(), vec![0,1,2,0,0]);
    }

    #[test]
    fn sub_selects_coords() {
        let t = Trans::<i32>::id(5);
        let s = t.sub(&[1, 3]);

        assert_eq!(s.src_dim(), 5);
        assert_eq!(s.tgt_dim(), 2);

        let v = SpVec::from(vec![0,1,2,3,4]);
        let w = s.forward(&v);
        let x = s.backward(&w);

        assert_eq!(w.into_vec(), vec![1,3]);
        assert_eq!(x.into_vec(), vec![0,1,0,3,0]);
    }

    #[test]
    fn is_id() {
        let t = Trans::<i64>::id(10);
        assert!(t.is_id());

        // Any appended step makes the transformation non-identity structurally.
        assert!(!proj_5_to_3().is_id());
    }
}