// faer/sparse/solvers.rs

1use crate::get_global_parallelism;
2use crate::internal_prelude_sp::*;
3use crate::linalg::solvers::{ShapeCore, SolveCore, SolveLstsqCore};
4use linalg_sp::{LltError, LuError};
5
6/// reference-counted sparse symbolic $LL^\top$ factorization
7#[derive(Debug, Clone)]
8pub struct SymbolicLlt<I> {
9	inner: alloc::sync::Arc<linalg_sp::cholesky::SymbolicCholesky<I>>,
10}
11
12/// sparse $LL^\top$ factorization
13#[derive(Debug, Clone)]
14pub struct Llt<I, T> {
15	symbolic: SymbolicLlt<I>,
16	numeric: alloc::vec::Vec<T>,
17}
18
19/// reference-counted sparse symbolic $QR$ factorization
20#[derive(Debug, Clone)]
21pub struct SymbolicQr<I> {
22	inner: alloc::sync::Arc<linalg_sp::qr::SymbolicQr<I>>,
23}
24
25/// sparse $QR$ factorization
26#[derive(Debug, Clone)]
27pub struct Qr<I, T> {
28	symbolic: SymbolicQr<I>,
29	indices: alloc::vec::Vec<I>,
30	numeric: alloc::vec::Vec<T>,
31}
32
33/// reference-counted sparse symbolic $LU$ factorization
34#[derive(Debug, Clone)]
35pub struct SymbolicLu<I> {
36	inner: alloc::sync::Arc<linalg_sp::lu::SymbolicLu<I>>,
37}
38
39/// sparse $QR$ factorization
40#[derive(Debug, Clone)]
41pub struct Lu<I, T> {
42	symbolic: SymbolicLu<I>,
43	numeric: linalg_sp::lu::NumericLu<I, T>,
44}
45
46impl<I: Index> SymbolicLlt<I> {
47	/// returns the symbolic $LL^\top$ factorization of the input matrix
48	///
49	/// only the provided side is accessed
50	#[track_caller]
51	pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>, side: Side) -> Result<Self, FaerError> {
52		Ok(Self {
53			inner: alloc::sync::Arc::new(linalg_sp::cholesky::factorize_symbolic_cholesky(
54				mat,
55				side,
56				Default::default(),
57				Default::default(),
58			)?),
59		})
60	}
61}
62
63impl<I: Index> SymbolicQr<I> {
64	/// returns the symbolic $QR$ factorization of the input matrix
65	#[track_caller]
66	pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>) -> Result<Self, FaerError> {
67		Ok(Self {
68			inner: alloc::sync::Arc::new(linalg_sp::qr::factorize_symbolic_qr(mat, Default::default())?),
69		})
70	}
71}
72
73impl<I: Index> SymbolicLu<I> {
74	/// returns the symbolic $LU$ factorization of the input matrix
75	#[track_caller]
76	pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>) -> Result<Self, FaerError> {
77		Ok(Self {
78			inner: alloc::sync::Arc::new(linalg_sp::lu::factorize_symbolic_lu(mat, Default::default())?),
79		})
80	}
81}
82
83impl<I: Index, T: ComplexField> Llt<I, T> {
84	/// returns the $LL^\top$ factorization of the input matrix with the same sparsity pattern as
85	/// the original one used to construct the symbolic factorization
86	///
87	/// only the provided side is accessed
88	#[track_caller]
89	pub fn try_new_with_symbolic(symbolic: SymbolicLlt<I>, mat: SparseColMatRef<'_, I, T>, side: Side) -> Result<Self, LltError> {
90		let len_val = symbolic.inner.len_val();
91		let mut numeric = alloc::vec::Vec::new();
92		numeric.try_reserve_exact(len_val).map_err(|_| FaerError::OutOfMemory)?;
93		numeric.resize(len_val, zero::<T>());
94		let par = get_global_parallelism();
95		symbolic.inner.factorize_numeric_llt::<T>(
96			&mut numeric,
97			mat,
98			side,
99			Default::default(),
100			par,
101			MemStack::new(&mut MemBuffer::try_new(
102				symbolic.inner.factorize_numeric_llt_scratch::<T>(par, Default::default()),
103			)?),
104			Default::default(),
105		)?;
106		Ok(Self { symbolic, numeric })
107	}
108}
109
110impl<I: Index, T: ComplexField> Lu<I, T> {
111	/// returns the $LU$ factorization of the input matrix with the same sparsity pattern as the
112	/// original one used to construct the symbolic factorization
113	#[track_caller]
114	pub fn try_new_with_symbolic(symbolic: SymbolicLu<I>, mat: SparseColMatRef<'_, I, T>) -> Result<Self, LuError> {
115		let mut numeric = linalg_sp::lu::NumericLu::new();
116		let par = get_global_parallelism();
117		symbolic.inner.factorize_numeric_lu::<T>(
118			&mut numeric,
119			mat,
120			par,
121			MemStack::new(&mut MemBuffer::try_new(
122				symbolic.inner.factorize_numeric_lu_scratch::<T>(par, Default::default()),
123			)?),
124			Default::default(),
125		)?;
126		Ok(Self { symbolic, numeric })
127	}
128}
129
130impl<I: Index, T: ComplexField> Qr<I, T> {
131	/// returns the $QR$ factorization of the input matrix with the same sparsity pattern as the
132	/// original one used to construct the symbolic factorization
133	#[track_caller]
134	pub fn try_new_with_symbolic(symbolic: SymbolicQr<I>, mat: SparseColMatRef<'_, I, T>) -> Result<Self, FaerError> {
135		let len_val = symbolic.inner.len_val();
136		let len_idx = symbolic.inner.len_idx();
137
138		let mut indices = alloc::vec::Vec::new();
139		let mut numeric = alloc::vec::Vec::new();
140		numeric.try_reserve_exact(len_val).map_err(|_| FaerError::OutOfMemory)?;
141		numeric.resize(len_val, zero::<T>());
142
143		indices.try_reserve_exact(len_idx).map_err(|_| FaerError::OutOfMemory)?;
144		indices.resize(len_idx, I::truncate(0));
145		let par = get_global_parallelism();
146
147		symbolic.inner.factorize_numeric_qr::<T>(
148			&mut indices,
149			&mut numeric,
150			mat,
151			par,
152			MemStack::new(&mut MemBuffer::try_new(
153				symbolic.inner.factorize_numeric_qr_scratch::<T>(par, Default::default()),
154			)?),
155			Default::default(),
156		);
157		Ok(Self { symbolic, indices, numeric })
158	}
159}
160
161impl<I: Index, T: ComplexField> ShapeCore for Llt<I, T> {
162	#[track_caller]
163	fn nrows(&self) -> usize {
164		self.symbolic.inner.nrows()
165	}
166
167	#[track_caller]
168	fn ncols(&self) -> usize {
169		self.symbolic.inner.ncols()
170	}
171}
172
173impl<I: Index, T: ComplexField> ShapeCore for Qr<I, T> {
174	#[track_caller]
175	fn nrows(&self) -> usize {
176		self.symbolic.inner.nrows()
177	}
178
179	#[track_caller]
180	fn ncols(&self) -> usize {
181		self.symbolic.inner.ncols()
182	}
183}
184
185impl<I: Index, T: ComplexField> ShapeCore for Lu<I, T> {
186	#[track_caller]
187	fn nrows(&self) -> usize {
188		self.symbolic.inner.nrows()
189	}
190
191	#[track_caller]
192	fn ncols(&self) -> usize {
193		self.symbolic.inner.ncols()
194	}
195}
196
197impl<I: Index, T: ComplexField> SolveCore<T> for Llt<I, T> {
198	#[track_caller]
199	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
200		let par = get_global_parallelism();
201		let rhs_ncols = rhs.ncols();
202		linalg_sp::cholesky::LltRef::<'_, I, T>::new(&self.symbolic.inner, &self.numeric).solve_in_place_with_conj(
203			conj,
204			rhs,
205			par,
206			MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
207		);
208	}
209
210	#[track_caller]
211	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
212		let par = get_global_parallelism();
213		let rhs_ncols = rhs.ncols();
214		linalg_sp::cholesky::LltRef::<'_, I, T>::new(&self.symbolic.inner, &self.numeric).solve_in_place_with_conj(
215			conj.compose(Conj::Yes),
216			rhs,
217			par,
218			MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
219		);
220	}
221}
222
223impl<I: Index, T: ComplexField> SolveCore<T> for Qr<I, T> {
224	#[track_caller]
225	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
226		let par = get_global_parallelism();
227		let rhs_ncols = rhs.ncols();
228		unsafe { linalg_sp::qr::QrRef::<'_, I, T>::new_unchecked(&self.symbolic.inner, &self.indices, &self.numeric) }.solve_in_place_with_conj(
229			conj,
230			rhs,
231			par,
232			MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
233		);
234	}
235
236	#[track_caller]
237	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
238		_ = conj;
239		_ = rhs;
240		panic!("the sparse QR decomposition doesn't support solve_transpose.\nconsider using the sparse LU or Cholesky instead");
241	}
242}
243
244impl<I: Index, T: ComplexField> SolveLstsqCore<T> for Qr<I, T> {
245	#[track_caller]
246	fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
247		let par = get_global_parallelism();
248		let rhs_ncols = rhs.ncols();
249		unsafe { linalg_sp::qr::QrRef::<'_, I, T>::new_unchecked(&self.symbolic.inner, &self.indices, &self.numeric) }.solve_in_place_with_conj(
250			conj,
251			rhs,
252			par,
253			MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
254		);
255	}
256}
257
258impl<I: Index, T: ComplexField> SolveCore<T> for Lu<I, T> {
259	#[track_caller]
260	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
261		let par = get_global_parallelism();
262		let rhs_ncols = rhs.ncols();
263		unsafe { linalg_sp::lu::LuRef::<'_, I, T>::new_unchecked(&self.symbolic.inner, &self.numeric) }.solve_in_place_with_conj(
264			conj,
265			rhs,
266			par,
267			MemStack::new(&mut MemBuffer::new(self.symbolic.inner.solve_in_place_scratch::<T>(rhs_ncols, par))),
268		);
269	}
270
271	#[track_caller]
272	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
273		let par = get_global_parallelism();
274		let rhs_ncols = rhs.ncols();
275		unsafe { linalg_sp::lu::LuRef::<'_, I, T>::new_unchecked(&self.symbolic.inner, &self.numeric) }.solve_transpose_in_place_with_conj(
276			conj,
277			rhs,
278			par,
279			MemStack::new(&mut MemBuffer::new(
280				self.symbolic.inner.solve_transpose_in_place_scratch::<T>(rhs_ncols, par),
281			)),
282		);
283	}
284}
285
286impl<I: Index, T: ComplexField, Inner: for<'short> Reborrow<'short, Target = csc_numeric::Ref<'short, I, T>>>
287	csc_numeric::generic::SparseColMat<Inner>
288{
289	/// assuming `self` is a lower triangular matrix, solves the equation `self * x = rhs`, and
290	/// stores the result in `rhs`
291	///
292	/// # note
293	/// the matrix indices need not be sorted, but
294	/// the diagonal element is assumed to be the first stored element in each column
295	#[track_caller]
296	pub fn sp_solve_lower_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
297		linalg_sp::triangular_solve::solve_lower_triangular_in_place(
298			self.rb(),
299			Conj::No,
300			rhs.as_mat_mut().as_dyn_cols_mut(),
301			get_global_parallelism(),
302		);
303	}
304
305	/// assuming `self` is an upper triangular matrix, solves the equation `self * x = rhs`, and
306	/// stores the result in `rhs`
307	///
308	/// # note
309	/// the matrix indices need not be sorted, but
310	/// the diagonal element is assumed to be the last stored element in each column
311	#[track_caller]
312	pub fn sp_solve_upper_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
313		linalg_sp::triangular_solve::solve_upper_triangular_in_place(
314			self.rb(),
315			Conj::No,
316			rhs.as_mat_mut().as_dyn_cols_mut(),
317			get_global_parallelism(),
318		);
319	}
320
321	/// assuming `self` is a unit lower triangular matrix, solves the equation `self * x = rhs`,
322	/// and stores the result in `rhs`
323	///
324	/// # note
325	/// the matrix indices need not be sorted, but
326	/// the diagonal element is assumed to be the first stored element in each column
327	#[track_caller]
328	pub fn sp_solve_unit_lower_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
329		linalg_sp::triangular_solve::solve_unit_lower_triangular_in_place(
330			self.rb(),
331			Conj::No,
332			rhs.as_mat_mut().as_dyn_cols_mut(),
333			get_global_parallelism(),
334		);
335	}
336
337	/// assuming `self` is a unit upper triangular matrix, solves the equation `self * x = rhs`,
338	/// and stores the result in `rhs`
339	///
340	/// # note
341	/// the matrix indices need not be sorted, but
342	/// the diagonal element is assumed to be the last stored element in each column
343	#[track_caller]
344	pub fn sp_solve_unit_upper_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
345		linalg_sp::triangular_solve::solve_unit_upper_triangular_in_place(
346			self.rb(),
347			Conj::No,
348			rhs.as_mat_mut().as_dyn_cols_mut(),
349			get_global_parallelism(),
350		);
351	}
352
353	/// returns the $LL^\top$ decomposition of `self`. only the provided side is accessed
354	#[track_caller]
355	#[doc(alias = "sp_llt")]
356	pub fn sp_cholesky(&self, side: Side) -> Result<Llt<I, T>, LltError> {
357		let this = self.rb();
358		Llt::try_new_with_symbolic(SymbolicLlt::try_new(this.symbolic(), side)?, this, side)
359	}
360
361	/// returns the $LU$ decomposition of `self` with partial (row) pivoting
362	#[track_caller]
363	pub fn sp_lu(&self) -> Result<Lu<I, T>, LuError> {
364		let this = self.rb();
365		Lu::try_new_with_symbolic(SymbolicLu::try_new(this.symbolic())?, this)
366	}
367
368	/// returns the $QR$ decomposition of `self`
369	#[track_caller]
370	pub fn sp_qr(&self) -> Result<Qr<I, T>, FaerError> {
371		let this = self.rb();
372		Qr::try_new_with_symbolic(SymbolicQr::try_new(this.symbolic())?, this)
373	}
374}
375
376impl<I: Index, T: ComplexField, Inner: for<'short> Reborrow<'short, Target = csr_numeric::Ref<'short, I, T>>>
377	csr_numeric::generic::SparseRowMat<Inner>
378{
379	/// assuming `self` is an upper triangular matrix, solves the equation `self * x = rhs`, and
380	/// stores the result in `rhs`
381	///
382	/// # note
383	/// the matrix indices need not be sorted, but
384	/// the diagonal element is assumed to be the last stored element in each row
385	#[track_caller]
386	pub fn sp_solve_lower_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
387		linalg_sp::triangular_solve::solve_upper_triangular_transpose_in_place(
388			self.rb().transpose(),
389			Conj::No,
390			rhs.as_mat_mut().as_dyn_cols_mut(),
391			get_global_parallelism(),
392		);
393	}
394
395	/// assuming `self` is an upper triangular matrix, solves the equation `self * x = rhs`, and
396	/// stores the result in `rhs`
397	///
398	/// # note
399	/// the matrix indices need not be sorted, but
400	/// the diagonal element is assumed to be the first stored element in each row
401	#[track_caller]
402	pub fn sp_solve_upper_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
403		linalg_sp::triangular_solve::solve_lower_triangular_transpose_in_place(
404			self.rb().transpose(),
405			Conj::No,
406			rhs.as_mat_mut().as_dyn_cols_mut(),
407			get_global_parallelism(),
408		);
409	}
410
411	/// assuming `self` is a unit lower triangular matrix, solves the equation `self * x = rhs`,
412	/// and stores the result in `rhs`
413	///
414	/// # note
415	/// the matrix indices need not be sorted, but
416	/// the diagonal element is assumed to be the last stored element in each row
417	#[track_caller]
418	pub fn sp_solve_unit_lower_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
419		linalg_sp::triangular_solve::solve_unit_upper_triangular_transpose_in_place(
420			self.rb().transpose(),
421			Conj::No,
422			rhs.as_mat_mut().as_dyn_cols_mut(),
423			get_global_parallelism(),
424		);
425	}
426
427	/// assuming `self` is a unit upper triangular matrix, solves the equation `self * x = rhs`,
428	/// and stores the result in `rhs`
429	///
430	/// # note
431	/// the matrix indices need not be sorted, but
432	/// the diagonal element is assumed to be the first stored element in each row
433	#[track_caller]
434	pub fn sp_solve_unit_upper_triangular_in_place(&self, mut rhs: impl AsMatMut<T = T, Rows = usize>) {
435		linalg_sp::triangular_solve::solve_unit_lower_triangular_transpose_in_place(
436			self.rb().transpose(),
437			Conj::No,
438			rhs.as_mat_mut().as_dyn_cols_mut(),
439			get_global_parallelism(),
440		);
441	}
442
443	/// returns the $LL^\top$ decomposition of `self`. only the provided side is accessed
444	#[track_caller]
445	#[doc(alias = "sp_llt")]
446	pub fn sp_cholesky(&self, side: Side) -> Result<Llt<I, T>, LltError> {
447		let this = self.rb().to_col_major()?;
448		let this = this.rb();
449		Llt::try_new_with_symbolic(SymbolicLlt::try_new(this.symbolic(), side)?, this, side)
450	}
451
452	/// returns the $LU$ decomposition of `self` with partial (row) pivoting
453	#[track_caller]
454	pub fn sp_lu(&self) -> Result<Lu<I, T>, LuError> {
455		let this = self.rb().to_col_major()?;
456		let this = this.rb();
457		Lu::try_new_with_symbolic(SymbolicLu::try_new(this.symbolic())?, this)
458	}
459
460	/// returns the $QR$ decomposition of `self`
461	#[track_caller]
462	pub fn sp_qr(&self) -> Result<Qr<I, T>, FaerError> {
463		let this = self.rb().to_col_major()?;
464		let this = this.rb();
465		Qr::try_new_with_symbolic(SymbolicQr::try_new(this.symbolic())?, this)
466	}
467}