1use either::Either;
2use log::debug;
3use num_traits::Zero;
4use yui_core::{Ring, RingOps};
5
6use super::*;
7
8cfg_if::cfg_if! {
9 if #[cfg(feature = "multithread")] {
10 use std::cell::RefCell;
11 use std::sync::Arc;
12 use thread_local::ThreadLocal;
13 use rayon::prelude::*;
14 }
15}
16
/// Size threshold for progress logging: `should_report` only enables progress
/// messages when both dimensions of the matrix it is given exceed this value.
/// It is also the number of solved columns between successive progress messages.
const LOG_THRESHOLD: usize = 10_000;
18
/// Which triangle of a matrix may hold nonzero entries.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TriangularType {
    /// Nonzero entries only on or above the diagonal.
    Upper,
    /// Nonzero entries only on or below the diagonal.
    Lower,
}

impl TriangularType {
    /// Returns `true` for [`TriangularType::Upper`], `false` for [`TriangularType::Lower`].
    pub fn is_upper(&self) -> bool {
        matches!(self, Self::Upper)
    }

    /// The triangle type of the transposed matrix: `Upper` becomes `Lower` and vice versa.
    pub fn transpose(&self) -> Self {
        match self {
            Self::Upper => Self::Lower,
            Self::Lower => Self::Upper,
        }
    }

    /// Misspelled alias of [`TriangularType::transpose`], kept so existing callers
    /// continue to compile. Prefer `transpose` in new code.
    pub fn tranpose(&self) -> Self {
        self.transpose()
    }
}
39
40pub fn inv_triangular<R>(t: TriangularType, a: &SpMat<R>) -> SpMat<R>
41where R: Ring, for<'x> &'x R: RingOps<R> {
42 let e = SpMat::id(a.nrows());
43 solve_triangular(t, a, &e)
44}
45
46pub fn solve_triangular<R>(t: TriangularType, a: &SpMat<R>, y: &SpMat<R>) -> SpMat<R>
48where R: Ring, for<'x> &'x R: RingOps<R> {
49 assert_eq!(a.nrows(), y.nrows());
50 debug_assert!(a.is_triang(t));
51
52 cfg_if::cfg_if! {
53 if #[cfg(feature = "multithread")] {
54 solve_triangular_m(t, a, y)
55 } else {
56 solve_triangular_s(t, a, y)
57 }
58 }
59}
60
61pub fn solve_triangular_left<R>(t: TriangularType, a: &SpMat<R>, y: &SpMat<R>) -> SpMat<R>
63where R: Ring, for<'x> &'x R: RingOps<R> {
64 solve_triangular(t.tranpose(), &a.transpose(), &y.transpose()).transpose()
65}
66
67pub fn solve_triangular_vec<R>(t: TriangularType, a: &SpMat<R>, b: &SpVec<R>) -> SpVec<R>
68where R: Ring, for<'x> &'x R: RingOps<R> {
69 assert_eq!(a.nrows(), b.dim());
70 debug_assert!(a.is_triang(t));
71
72 let diag = collect_diag(a);
73 let mut b = b.to_dense();
74
75 _solve_triangular(t, a, &diag, &mut b)
76}
77
78#[allow(unused)]
79fn solve_triangular_s<R>(t: TriangularType, a: &SpMat<R>, y: &SpMat<R>) -> SpMat<R>
80where R: Ring, for<'x> &'x R: RingOps<R> {
81 debug!("solve triangular, y: {:?}", y.shape());
82
83 let (n, k) = (a.nrows(), y.ncols());
84 let diag = collect_diag(a);
85 let mut b = vec![R::zero(); n];
86
87 let cols = (0..k).map(|j| {
88 copy_into(y.col_vec(j), &mut b);
89 _solve_triangular(t, a, &diag, &mut b)
90 });
91
92 SpMat::from_col_vecs(n, cols)
93}
94
/// Multithreaded back end of [`solve_triangular`] (feature `multithread`). The
/// columns of `y` are solved in parallel with rayon.
///
/// Each worker thread lazily allocates its own dense scratch buffer of length `n`
/// through `ThreadLocal`. It reuses that buffer for every column it processes,
/// which relies on `_solve_triangular` leaving the buffer zeroed. When
/// `should_report` allows it, a shared counter drives debug-level progress messages.
#[cfg(feature = "multithread")]
fn solve_triangular_m<R>(t: TriangularType, a: &SpMat<R>, y: &SpMat<R>) -> SpMat<R>
where R: Ring, for<'x> &'x R: RingOps<R> {
    use yui_core::util::sync::SyncCounter;

    debug!("solve triangular, y: {:?}", y.shape());

    let n = a.nrows();
    let k = y.ncols();
    let diag = collect_diag(a);

    let scratch = Arc::new(ThreadLocal::new());
    let report = should_report(y);
    let solved = SyncCounter::new();

    let cols: Vec<_> = (0..k).into_par_iter().map(|j| {
        let mut buf = scratch
            .get_or(|| RefCell::new(vec![R::zero(); n]))
            .borrow_mut();

        copy_into(y.col_vec(j), &mut buf);
        let col = _solve_triangular(t, a, &diag, &mut buf);

        if report {
            let c = solved.incr();
            let at_milestone = c > 0 && c % LOG_THRESHOLD == 0;
            if at_milestone || c == k {
                debug!(" solved {c}/{k}.");
            }
        }

        col
    }).collect();

    SpMat::from_col_vecs(n, cols)
}
129
130#[inline(never)] fn _solve_triangular<R>(t: TriangularType, a: &SpMat<R>, diag: &[&R], b: &mut [R]) -> SpVec<R>
132where R: Ring, for<'x> &'x R: RingOps<R> {
133 let mut entries = vec![];
134
135 let itr = diag.iter().enumerate();
136 let itr = if t.is_upper() {
137 Either::Left(itr.rev())
138 } else {
139 Either::Right(itr)
140 };
141
142 for (j, u) in itr { if b[j].is_zero() { continue }
144
145 let uinv = u.inv().unwrap();
146 let x_j = &b[j] * &uinv; for (i, a_ij) in a.col_vec(j).iter() {
149 if a_ij.is_zero() { continue }
150 b[i] -= a_ij * &x_j;
151 }
152
153 entries.push((j, x_j));
154 }
155
156 debug_assert!(b.iter().all(|b_i|
157 b_i.is_zero())
158 );
159
160 if t.is_upper() {
161 entries.reverse()
162 };
163
164 SpVec::from_sorted_entries(a.ncols(), entries)
165}
166
167fn collect_diag<'a, R>(a: &'a SpMat<R>) -> Vec<&'a R>
168where R: Ring, for<'x> &'x R: RingOps<R> {
169 a.iter().filter_map(|(i, j, a)|
170 if i == j { Some(a) } else { None }
171 ).collect()
172}
173
174fn copy_into<R>(vec: SpVec<R>, x: &mut [R])
175where R: Clone + Zero {
176 vec.iter().for_each(|(i, r)| x[i] = r.clone())
177}
178
179#[allow(unused)]
180fn should_report<R>(a: &SpMat<R>) -> bool {
181 usize::min(a.nrows(), a.ncols()) > LOG_THRESHOLD && log::max_level() >= log::LevelFilter::Debug
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use super::TriangularType::{Upper, Lower};

    /// 5×5 upper-triangular fixture with diagonal entries ±1 (invertible over the integers).
    fn upper_fixture() -> SpMat<i32> {
        SpMat::from_dense_data((5, 5), [
            1, -2, 1,  3, 5,
            0, -1, 4,  2, 1,
            0,  0, 1,  0, 3,
            0,  0, 0, -1, 5,
            0,  0, 0,  0, 1,
        ])
    }

    /// The transpose of `upper_fixture`: lower-triangular with diagonal entries ±1.
    fn lower_fixture() -> SpMat<i32> {
        SpMat::from_dense_data((5, 5), [
             1,  0, 0,  0, 0,
            -2, -1, 0,  0, 0,
             1,  4, 1,  0, 0,
             3,  2, 0, -1, 0,
             5,  1, 3,  5, 1,
        ])
    }

    #[test]
    fn solve_upper() {
        let u = upper_fixture();
        let b = SpVec::from(vec![37, 23, 18, 21, 5]);
        let expected = SpVec::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(solve_triangular_vec(Upper, &u, &b), expected);
    }

    #[test]
    fn inv_upper() {
        let u = upper_fixture();
        let u_inv = inv_triangular(Upper, &u);
        assert!((&u * &u_inv).is_id());
    }

    #[test]
    fn solve_lower() {
        let l = lower_fixture();
        let b = SpVec::from(vec![1, -4, 12, 3, 41]);
        let expected = SpVec::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(solve_triangular_vec(Lower, &l, &b), expected);
    }

    #[test]
    fn inv_lower() {
        let l = lower_fixture();
        let l_inv = inv_triangular(Lower, &l);
        assert!((&l * &l_inv).is_id());
    }
}