// faer 0.24.0
//
// linear algebra library
// Documentation
use crate::assert;
use crate::internal_prelude::*;
use crate::perm::{permute_rows_in_place, permute_rows_in_place_scratch};
/// Returns the workspace needed by [`solve_in_place_with_conj`] and [`solve_in_place`].
///
/// The only scratch memory those solves use is the buffer for applying the row
/// permutation to the right-hand side. The requirement therefore depends only on the
/// factor dimension `LU_dim` and the number of right-hand-side columns `rhs_ncols`.
pub fn solve_in_place_scratch<I: Index, T: ComplexField>(
	LU_dim: usize,
	rhs_ncols: usize,
	par: Par,
) -> StackReq {
	// `par` is accepted for signature uniformity; it does not change the requirement.
	let _: Par = par;
	permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
}
/// Returns the workspace needed by [`solve_transpose_in_place_with_conj`] and
/// [`solve_transpose_in_place`].
///
/// As in the non-transposed case, the workspace is only used to apply the (inverse)
/// row permutation to the right-hand side. It is sized by `LU_dim` and `rhs_ncols`.
pub fn solve_transpose_in_place_scratch<I: Index, T: ComplexField>(
	LU_dim: usize,
	rhs_ncols: usize,
	par: Par,
) -> StackReq {
	// Parallelism does not affect how much memory the permutation step needs.
	let _: Par = par;
	permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
}
/// Overwrites `rhs` with `U⁻¹ · L⁻¹ · P · rhs`, where `P` is `row_perm`.
///
/// When `P · A = L · U`, this solves `A · X = rhs` in place. `L` is used as a
/// unit-lower-triangular factor and `U` as an upper-triangular factor. Both may be
/// views of the same packed LU matrix. `conj_LU` is forwarded to both triangular
/// solves and selects whether the factors are conjugated.
///
/// `stack` is only used for the row permutation. It should provide at least
/// `solve_in_place_scratch::<I, T>(n, rhs.ncols(), par)`.
///
/// # Panics
///
/// Panics if `L` or `U` is not `n × n`, if `row_perm.len() != n`, or if
/// `rhs.nrows() != n`, where `n = L.nrows()`.
#[track_caller]
pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
	L: MatRef<'_, T>,
	U: MatRef<'_, T>,
	row_perm: PermRef<'_, I>,
	conj_LU: Conj,
	mut rhs: MatMut<'_, T>,
	par: Par,
	stack: &mut MemStack,
) {
	let n = L.nrows();
	assert!(all(
		L.nrows() == n,
		L.ncols() == n,
		U.nrows() == n,
		U.ncols() == n,
		row_perm.len() == n,
		rhs.nrows() == n,
	));

	// 1) rhs <- P * rhs
	permute_rows_in_place(rhs.rb_mut(), row_perm, stack);
	// 2) rhs <- L^{-1} * rhs, treating L as unit lower triangular
	linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(
		L,
		conj_LU,
		rhs.rb_mut(),
		par,
	);
	// 3) rhs <- U^{-1} * rhs; last use of `rhs`, so it is moved in directly
	linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(
		U, conj_LU, rhs, par,
	);
}
/// Overwrites `rhs` with `P⁻¹ · (Lᵀ)⁻¹ · (Uᵀ)⁻¹ · rhs`, where `P` is `row_perm`.
///
/// When `P · A = L · U`, this solves `Aᵀ · X = rhs` in place. With `conj_LU`
/// requesting conjugation, it solves the adjoint system instead. The factor
/// conventions match [`solve_in_place_with_conj`]: `L` is used as unit lower
/// triangular and `U` as upper triangular.
///
/// `stack` is only used for the final inverse row permutation. It should provide at
/// least `solve_transpose_in_place_scratch::<I, T>(n, rhs.ncols(), par)`.
///
/// # Panics
///
/// Panics if `L` or `U` is not `n × n`, if `row_perm.len() != n`, or if
/// `rhs.nrows() != n`, where `n = L.nrows()`.
#[track_caller]
pub fn solve_transpose_in_place_with_conj<I: Index, T: ComplexField>(
	L: MatRef<'_, T>,
	U: MatRef<'_, T>,
	row_perm: PermRef<'_, I>,
	conj_LU: Conj,
	mut rhs: MatMut<'_, T>,
	par: Par,
	stack: &mut MemStack,
) {
	let n = L.nrows();
	assert!(all(
		L.nrows() == n,
		L.ncols() == n,
		U.nrows() == n,
		U.ncols() == n,
		row_perm.len() == n,
		rhs.nrows() == n,
	));

	// The factors are applied in reverse order relative to the non-transposed solve.
	// 1) rhs <- (U^T)^{-1} * rhs, where U^T is lower triangular
	linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(
		U.transpose(),
		conj_LU,
		rhs.rb_mut(),
		par,
	);
	// 2) rhs <- (L^T)^{-1} * rhs, where L^T is unit upper triangular
	linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(
		L.transpose(),
		conj_LU,
		rhs.rb_mut(),
		par,
	);
	// 3) rhs <- P^{-1} * rhs; last use of `rhs`, so it is moved in directly
	permute_rows_in_place(rhs, row_perm.inverse(), stack);
}
/// Solves `A · X = rhs` in place from the LU factors of `A`.
///
/// This is the conjugation-generic front end of [`solve_in_place_with_conj`]. The
/// factor views may have element type `C`, any type whose canonical form is `T`.
/// They are reduced to canonical `T` views plus a [`Conj`] flag derived from `C`.
/// If `L`/`U` are conjugated views, the system solved is `conj(A) · X = rhs`.
///
/// # Panics
///
/// Panics under the same dimension mismatches as [`solve_in_place_with_conj`].
#[track_caller]
pub fn solve_in_place<
	I: Index,
	T: ComplexField,
	C: Conjugate<Canonical = T>,
>(
	L: MatRef<'_, C>,
	U: MatRef<'_, C>,
	row_perm: PermRef<'_, I>,
	rhs: MatMut<'_, T>,
	par: Par,
	stack: &mut MemStack,
) {
	// Separate "which data" (canonical views) from "whether to conjugate" (flag).
	let conj_LU = Conj::get::<C>();
	let (L, U) = (L.canonical(), U.canonical());
	solve_in_place_with_conj(L, U, row_perm, conj_LU, rhs, par, stack)
}
/// Solves `Aᵀ · X = rhs` in place from the LU factors of `A`.
///
/// This is the conjugation-generic front end of [`solve_transpose_in_place_with_conj`].
/// The views are canonicalized to `T` and the conjugation flag is taken from `C`.
/// Passing conjugated views of the factors therefore solves the adjoint system
/// `Aᴴ · X = rhs`.
///
/// # Panics
///
/// Panics under the same dimension mismatches as
/// [`solve_transpose_in_place_with_conj`].
#[track_caller]
pub fn solve_transpose_in_place<
	I: Index,
	T: ComplexField,
	C: Conjugate<Canonical = T>,
>(
	L: MatRef<'_, C>,
	U: MatRef<'_, C>,
	row_perm: PermRef<'_, I>,
	rhs: MatMut<'_, T>,
	par: Par,
	stack: &mut MemStack,
) {
	// Separate "which data" (canonical views) from "whether to conjugate" (flag).
	let conj_LU = Conj::get::<C>();
	let (L, U) = (L.canonical(), U.canonical());
	solve_transpose_in_place_with_conj(L, U, row_perm, conj_LU, rhs, par, stack)
}
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::partial_pivoting::*;

	/// Draws an `nrows × ncols` matrix whose entries are complex numbers with
	/// independent standard-normal real and imaginary parts.
	fn random_c64(rng: &mut StdRng, nrows: usize, ncols: usize) -> Mat<c64> {
		CwiseMatDistribution {
			nrows,
			ncols,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng)
	}

	/// Checks all four solve variants (plain, transpose, conjugate, adjoint) against a
	/// random complex system. Each solve's workspace is sized by its own matching
	/// `*_scratch` function.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let k = 3;

		let A = random_c64(rng, n, n);
		let B = random_c64(rng, n, k);

		// Factor A in place. The same `LU` view is passed below as both `L` and `U`.
		let mut LU = A.to_owned();
		let row_perm_fwd = &mut *vec![0usize; n];
		let row_perm_bwd = &mut *vec![0usize; n];
		let row_perm = factor::lu_in_place(
			LU.as_mut(),
			row_perm_fwd,
			row_perm_bwd,
			Par::Seq,
			MemStack::new(&mut {
				MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(
					n,
					n,
					Par::Seq,
					default(),
				))
			}),
			default(),
		)
		.1;

		// Tolerance grows linearly with the problem size. `approx_eq` is picked up by
		// the `~` comparisons in `assert!` below.
		let approx_eq = CwiseMat(ApproxEq::eps() * 8.0 * (n as f64));

		// A * X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(
					solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq),
				)),
			);
			assert!(&A * &X ~ B);
		}
		// A^T * X = B: use the transpose solve's own scratch requirement.
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(
					solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq),
				)),
			);
			assert!(A.transpose() * &X ~ B);
		}
		// conj(A) * X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(
					solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq),
				)),
			);
			assert!(A.conjugate() * &X ~ B);
		}
		// A^H * X = B: use the transpose solve's own scratch requirement.
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(
					solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq),
				)),
			);
			assert!(A.adjoint() * &X ~ B);
		}
	}
}