1use crate::get_global_parallelism;
2use crate::internal_prelude_sp::*;
3use crate::linalg::solvers::{ShapeCore, SolveCore, SolveLstsqCore};
4use linalg_sp::{LltError, LuError};
5
6#[derive(Debug, Clone)]
8pub struct SymbolicLlt<I> {
9 inner: alloc::sync::Arc<linalg_sp::cholesky::SymbolicCholesky<I>>,
10}
11
12#[derive(Debug, Clone)]
14pub struct Llt<I, T> {
15 symbolic: SymbolicLlt<I>,
16 numeric: alloc::vec::Vec<T>,
17}
18
19#[derive(Debug, Clone)]
21pub struct SymbolicQr<I> {
22 inner: alloc::sync::Arc<linalg_sp::qr::SymbolicQr<I>>,
23}
24
25#[derive(Debug, Clone)]
27pub struct Qr<I, T> {
28 symbolic: SymbolicQr<I>,
29 indices: alloc::vec::Vec<I>,
30 numeric: alloc::vec::Vec<T>,
31}
32
33#[derive(Debug, Clone)]
35pub struct SymbolicLu<I> {
36 inner: alloc::sync::Arc<linalg_sp::lu::SymbolicLu<I>>,
37}
38
39#[derive(Debug, Clone)]
41pub struct Lu<I, T> {
42 symbolic: SymbolicLu<I>,
43 numeric: linalg_sp::lu::NumericLu<I, T>,
44}
45
46impl<I: Index> SymbolicLlt<I> {
47 #[track_caller]
51 pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>, side: Side) -> Result<Self, FaerError> {
52 Ok(Self {
53 inner: alloc::sync::Arc::new(linalg_sp::cholesky::factorize_symbolic_cholesky(
54 mat,
55 side,
56 Default::default(),
57 Default::default(),
58 )?),
59 })
60 }
61}
62
63impl<I: Index> SymbolicQr<I> {
64 #[track_caller]
66 pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>) -> Result<Self, FaerError> {
67 Ok(Self {
68 inner: alloc::sync::Arc::new(linalg_sp::qr::factorize_symbolic_qr(mat, Default::default())?),
69 })
70 }
71}
72
73impl<I: Index> SymbolicLu<I> {
74 #[track_caller]
76 pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>) -> Result<Self, FaerError> {
77 Ok(Self {
78 inner: alloc::sync::Arc::new(linalg_sp::lu::factorize_symbolic_lu(mat, Default::default())?),
79 })
80 }
81}
82
83impl<I: Index, T: ComplexField> Llt<I, T> {
84 #[track_caller]
89 pub fn try_new_with_symbolic(symbolic: SymbolicLlt<I>, mat: SparseColMatRef<'_, I, T>, side: Side) -> Result<Self, LltError> {
90 let len_val = symbolic.inner.len_val();
91 let mut numeric = alloc::vec::Vec::new();
92 numeric.try_reserve_exact(len_val).map_err(|_| FaerError::OutOfMemory)?;
93 numeric.resize(len_val, zero::<T>());
94 let par = get_global_parallelism();
95 symbolic.inner.factorize_numeric_llt::<T>(
96 &mut numeric,
97 mat,
98 side,
99 Default::default(),
100 par,
101 MemStack::new(&mut MemBuffer::try_new(
102 symbolic.inner.factorize_numeric_llt_scratch::<T>(par, Default::default()),
103 )?),
104 Default::default(),
105 )?;
106 Ok(Self { symbolic, numeric })
107 }
108}
109
110impl<I: Index, T: ComplexField> Lu<I, T> {
111 #[track_caller]
114 pub fn try_new_with_symbolic(symbolic: SymbolicLu<I>, mat: SparseColMatRef<'_, I, T>) -> Result<Self, LuError> {
115 let mut numeric = linalg_sp::lu::NumericLu::new();
116 let par = get_global_parallelism();
117 symbolic.inner.factorize_numeric_lu::<T>(
118 &mut numeric,
119 mat,
120 par,
121 MemStack::new(&mut MemBuffer::try_new(
122 symbolic.inner.factorize_numeric_lu_scratch::<T>(par, Default::default()),
123 )?),
124 Default::default(),
125 )?;
126 Ok(Self { symbolic, numeric })
127 }
128}
129
130impl<I: Index, T: ComplexField> Qr<I, T> {
131 #[track_caller]
134 pub fn try_new_with_symbolic(symbolic: SymbolicQr<I>, mat: SparseColMatRef<'_, I, T>) -> Result<Self, FaerError> {
135 let len_val = symbolic.inner.len_val();
136 let len_idx = symbolic.inner.len_idx();
137
138 let mut indices = alloc::vec::Vec::new();
139 let mut numeric = alloc::vec::Vec::new();
140 numeric.try_reserve_exact(len_val).map_err(|_| FaerError::OutOfMemory)?;
141 numeric.resize(len_val, zero::<T>());
142
143 indices.try_reserve_exact(len_idx).map_err(|_| FaerError::OutOfMemory)?;
144 indices.resize(len_idx, I::truncate(0));
145 let par = get_global_parallelism();
146
147 symbolic.inner.factorize_numeric_qr::<T>(
148 &mut indices,
149 &mut numeric,
150 mat,
151 par,
152 MemStack::new(&mut MemBuffer::try_new(
153 symbolic.inner.factorize_numeric_qr_scratch::<T>(par, Default::default()),
154 )?),
155 Default::default(),
156 );
157 Ok(Self { symbolic, indices, numeric })
158 }
159}
160
161impl<I: Index, T: ComplexField> ShapeCore for Llt<I, T> {
162 #[track_caller]
163 fn nrows(&self) -> usize {
164 self.symbolic.inner.nrows()
165 }
166
167 #[track_caller]
168 fn ncols(&self) -> usize {
169 self.symbolic.inner.ncols()
170 }
171}
172
173impl<I: Index, T: ComplexField> ShapeCore for Qr<I, T> {
174 #[track_caller]
175 fn nrows(&self) -> usize {
176 self.symbolic.inner.nrows()
177 }
178
179 #[track_caller]
180 fn ncols(&self) -> usize {
181 self.symbolic.inner.ncols()
182 }
183}
184
185impl<I: Index, T: ComplexField> ShapeCore for Lu<I, T> {
186 #[track_caller]
187 fn nrows(&self) -> usize {
188 self.symbolic.inner.nrows()
189 }
190
191 #[track_caller]
192 fn ncols(&self) -> usize {
193 self.symbolic.inner.ncols()
194 }
195}
196
197impl<I: Index, T: ComplexField> SolveCore<T> for Llt<I, T> {
198 #[track_caller]
199 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
200 let par = get_global_parallelism();
201 let rhs_ncols = rhs.ncols();
202 linalg_sp::cholesky::LltRef::<'_, I, T>::new(&self.symbolic.inner, &self.numeric).solve_in_place_with_conj(
203 conj,
204 rhs,
205 par,
206 MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
207 );
208 }
209
210 #[track_caller]
211 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
212 let par = get_global_parallelism();
213 let rhs_ncols = rhs.ncols();
214 linalg_sp::cholesky::LltRef::<'_, I, T>::new(&self.symbolic.inner, &self.numeric).solve_in_place_with_conj(
215 conj.compose(Conj::Yes),
216 rhs,
217 par,
218 MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
219 );
220 }
221}
222
223impl<I: Index, T: ComplexField> SolveCore<T> for Qr<I, T> {
224 #[track_caller]
225 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
226 let par = get_global_parallelism();
227 let rhs_ncols = rhs.ncols();
228 unsafe { linalg_sp::qr::QrRef::<'_, I, T>::new_unchecked(&self.symbolic.inner, &self.indices, &self.numeric) }.solve_in_place_with_conj(
229 conj,
230 rhs,
231 par,
232 MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
233 );
234 }
235
236 #[track_caller]
237 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
238 _ = conj;
239 _ = rhs;
240 panic!("the sparse QR decomposition doesn't support solve_transpose.\nconsider using the sparse LU or Cholesky instead");
241 }
242}
243
244impl<I: Index, T: ComplexField> SolveLstsqCore<T> for Qr<I, T> {
245 #[track_caller]
246 fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
247 let par = get_global_parallelism();
248 let rhs_ncols = rhs.ncols();
249 unsafe { linalg_sp::qr::QrRef::<'_, I, T>::new_unchecked(&self.symbolic.inner, &self.indices, &self.numeric) }.solve_in_place_with_conj(
250 conj,
251 rhs,
252 par,
253 MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
254 );
255 }
256}
257
258impl<I: Index, T: ComplexField> SolveCore<T> for Lu<I, T> {
259 #[track_caller]
260 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
261 let par = get_global_parallelism();
262 let rhs_ncols = rhs.ncols();
263 unsafe { linalg_sp::lu::LuRef::<'_, I, T>::new_unchecked(&self.symbolic.inner, &self.numeric) }.solve_in_place_with_conj(
264 conj,
265 rhs,
266 par,
267 MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
268 );
269 }
270
271 #[track_caller]
272 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
273 let par = get_global_parallelism();
274 let rhs_ncols = rhs.ncols();
275 unsafe { linalg_sp::lu::LuRef::<'_, I, T>::new_unchecked(&self.symbolic.inner, &self.numeric) }.solve_transpose_in_place_with_conj(
276 conj,
277 rhs,
278 par,
279 MemStack::new(&mut MemBuffer::new(
280 self.symbolic.inner.solve_transpose_in_place_scratch::<T>(rhs_ncols, par),
281 )),
282 );
283 }
284}
285
286impl<I: Index, T: ComplexField, Inner: for<'short> Reborrow<'short, Target = csc_numeric::Ref<'short, I, T>>>
287 csc_numeric::generic::SparseColMat<Inner>
288{
289 #[track_caller]
296 pub fn sp_solve_lower_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
297 linalg_sp::triangular_solve::solve_lower_triangular_in_place(
298 self.rb(),
299 Conj::No,
300 rhs.as_mat_mut().as_dyn_cols_mut(),
301 get_global_parallelism(),
302 );
303 }
304
305 #[track_caller]
312 pub fn sp_solve_upper_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
313 linalg_sp::triangular_solve::solve_upper_triangular_in_place(
314 self.rb(),
315 Conj::No,
316 rhs.as_mat_mut().as_dyn_cols_mut(),
317 get_global_parallelism(),
318 );
319 }
320
321 #[track_caller]
328 pub fn sp_solve_unit_lower_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
329 linalg_sp::triangular_solve::solve_unit_lower_triangular_in_place(
330 self.rb(),
331 Conj::No,
332 rhs.as_mat_mut().as_dyn_cols_mut(),
333 get_global_parallelism(),
334 );
335 }
336
337 #[track_caller]
344 pub fn sp_solve_unit_upper_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
345 linalg_sp::triangular_solve::solve_unit_upper_triangular_in_place(
346 self.rb(),
347 Conj::No,
348 rhs.as_mat_mut().as_dyn_cols_mut(),
349 get_global_parallelism(),
350 );
351 }
352
353 #[track_caller]
355 #[doc(alias = "sp_llt")]
356 pub fn sp_cholesky(&self, side: Side) -> Result<Llt<I, T>, LltError> {
357 let this = self.rb();
358 Llt::try_new_with_symbolic(SymbolicLlt::try_new(this.symbolic(), side)?, this, side)
359 }
360
361 #[track_caller]
363 pub fn sp_lu(&self) -> Result<Lu<I, T>, LuError> {
364 let this = self.rb();
365 Lu::try_new_with_symbolic(SymbolicLu::try_new(this.symbolic())?, this)
366 }
367
368 #[track_caller]
370 pub fn sp_qr(&self) -> Result<Qr<I, T>, FaerError> {
371 let this = self.rb();
372 Qr::try_new_with_symbolic(SymbolicQr::try_new(this.symbolic())?, this)
373 }
374}
375
376impl<I: Index, T: ComplexField, Inner: for<'short> Reborrow<'short, Target = csr_numeric::Ref<'short, I, T>>>
377 csr_numeric::generic::SparseRowMat<Inner>
378{
379 #[track_caller]
386 pub fn sp_solve_lower_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
387 linalg_sp::triangular_solve::solve_upper_triangular_transpose_in_place(
388 self.rb().transpose(),
389 Conj::No,
390 rhs.as_mat_mut().as_dyn_cols_mut(),
391 get_global_parallelism(),
392 );
393 }
394
395 #[track_caller]
402 pub fn sp_solve_upper_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
403 linalg_sp::triangular_solve::solve_lower_triangular_transpose_in_place(
404 self.rb().transpose(),
405 Conj::No,
406 rhs.as_mat_mut().as_dyn_cols_mut(),
407 get_global_parallelism(),
408 );
409 }
410
411 #[track_caller]
418 pub fn sp_solve_unit_lower_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
419 linalg_sp::triangular_solve::solve_unit_upper_triangular_transpose_in_place(
420 self.rb().transpose(),
421 Conj::No,
422 rhs.as_mat_mut().as_dyn_cols_mut(),
423 get_global_parallelism(),
424 );
425 }
426
427 #[track_caller]
434 pub fn sp_solve_unit_upper_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
435 linalg_sp::triangular_solve::solve_unit_lower_triangular_transpose_in_place(
436 self.rb().transpose(),
437 Conj::No,
438 rhs.as_mat_mut().as_dyn_cols_mut(),
439 get_global_parallelism(),
440 );
441 }
442
443 #[track_caller]
445 #[doc(alias = "sp_llt")]
446 pub fn sp_cholesky(&self, side: Side) -> Result<Llt<I, T>, LltError> {
447 let this = self.rb().to_col_major()?;
448 let this = this.rb();
449 Llt::try_new_with_symbolic(SymbolicLlt::try_new(this.symbolic(), side)?, this, side)
450 }
451
452 #[track_caller]
454 pub fn sp_lu(&self) -> Result<Lu<I, T>, LuError> {
455 let this = self.rb().to_col_major()?;
456 let this = this.rb();
457 Lu::try_new_with_symbolic(SymbolicLu::try_new(this.symbolic())?, this)
458 }
459
460 #[track_caller]
462 pub fn sp_qr(&self) -> Result<Qr<I, T>, FaerError> {
463 let this = self.rb().to_col_major()?;
464 let this = this.rb();
465 Qr::try_new_with_symbolic(SymbolicQr::try_new(this.symbolic())?, this)
466 }
467}