use crate::{error::*, layout::MatrixLayout, *};
use cauchy::*;
use num_traits::{ToPrimitive, Zero};
/// LU decomposition of a general (possibly non-square) matrix.
pub trait LuImpl: Scalar {
    /// Factorizes `a`, whose storage is described by `l`, in place and returns
    /// the pivot indices reported by LAPACK.
    ///
    /// The implementations in this module call LAPACK `?getrf`. They panic if
    /// `a.len()` does not match the element count implied by `l.size()`, and
    /// return an empty pivot vector for a matrix with zero rows or columns.
    fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot>;
}
/// Implements [`LuImpl`] for `$scalar` on top of the LAPACK routine `$getrf`.
macro_rules! impl_lu {
    ($scalar:ty, $getrf:path) => {
        impl LuImpl for $scalar {
            fn lu(l: MatrixLayout, a: &mut [Self]) -> Result<Pivot> {
                let (row, col) = l.size();
                // Compare in 64-bit arithmetic. `a.len() as i32` would truncate,
                // and `row * col` in `i32` overflows for large but valid matrices.
                // A slice length never exceeds `isize::MAX`, so the cast to `i64`
                // is lossless.
                assert_eq!(
                    a.len() as i64,
                    i64::from(row) * i64::from(col),
                    "lu: slice length does not match the matrix layout"
                );
                if row == 0 || col == 0 {
                    // Nothing to factorize, so there are no pivots either.
                    return Ok(Vec::new());
                }
                let k = ::std::cmp::min(row, col);
                let mut ipiv = vec_uninit(k as usize);
                let mut info = 0;
                // The buffer is handed to LAPACK as a column-major matrix with
                // `lda()` rows and `len()` columns (for C layout that is the
                // transpose, which `SolveImpl` accounts for).
                // SAFETY: `a` holds exactly `row * col` elements (asserted above)
                // and `ipiv` has room for the `min(row, col)` pivots written.
                unsafe {
                    $getrf(
                        &l.lda(),
                        &l.len(),
                        AsPtr::as_mut_ptr(a),
                        &l.lda(),
                        AsPtr::as_mut_ptr(&mut ipiv),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                // SAFETY: `?getrf` completed without error, so it filled every
                // entry of `ipiv`.
                let ipiv = unsafe { ipiv.assume_init() };
                Ok(ipiv)
            }
        }
    };
}
impl_lu!(c64, lapack_sys::zgetrf_);
impl_lu!(c32, lapack_sys::cgetrf_);
impl_lu!(f64, lapack_sys::dgetrf_);
impl_lu!(f32, lapack_sys::sgetrf_);
#[cfg_attr(doc, katexit::katexit)]
/// Solve a square linear system from its LU factorization.
///
/// Depending on `t`, this solves $A x = b$ (`Transpose::No`), $A^T x = b$
/// (`Transpose::Transpose`) or $A^H x = b$ (`Transpose::Hermite`). The factors
/// of $A$ and the pivots are presumably those produced by [`LuImpl::lu`] for
/// the same layout. The implementations in this module call LAPACK `?getrs`.
pub trait SolveImpl: Scalar {
    /// Solves for a single right-hand side, overwriting `b` with the solution $x$.
    ///
    /// - `l`: layout of the factorized matrix `a`
    /// - `t`: which of $A$, $A^T$, $A^H$ the system uses
    /// - `a`: LU factors of $A$
    /// - `p`: pivot indices from the factorization
    /// - `b`: right-hand side on input, solution on output
    fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
}
/// Implements [`SolveImpl`] for `$scalar` on top of the LAPACK routine `$getrs`.
macro_rules! impl_solve {
    ($scalar:ty, $getrs:path) => {
        impl SolveImpl for $scalar {
            fn solve(
                l: MatrixLayout,
                t: Transpose,
                a: &[Self],
                ipiv: &Pivot,
                b: &mut [Self],
            ) -> Result<()> {
                // LAPACK expects column-major storage. A C-layout buffer read as
                // column-major is the transpose of the matrix, so "no transpose"
                // and "transpose" swap. For the Hermitian case:
                //
                //   A^H x = b  <=>  conj(A^T) x = b  <=>  A^T conj(x) = conj(b)
                //
                // So we solve the "no transpose" system on the reinterpreted
                // buffer with `b` conjugated, then conjugate the result back.
                let (t, conj) = match l {
                    MatrixLayout::C { .. } => match t {
                        Transpose::No => (Transpose::Transpose, false),
                        Transpose::Transpose => (Transpose::No, false),
                        Transpose::Hermite => (Transpose::No, true),
                    },
                    MatrixLayout::F { .. } => (t, false),
                };
                let (n, _) = l.size();
                let nrhs = 1;
                let lda = l.lda();
                let ldb = lda;

                // `$getrs` only receives raw pointers. Check that every slice
                // covers what LAPACK touches: an `n x n` matrix with leading
                // dimension `lda`, `n` pivots and `n` right-hand-side entries.
                // Otherwise a short slice would be accessed out of bounds.
                // Negative sizes are still left for LAPACK to reject via `info`.
                let n_len = n.to_usize().unwrap_or(0);
                let lda_len = lda.to_usize().unwrap_or(0);
                let a_needed = if n_len == 0 {
                    0
                } else {
                    lda_len.saturating_mul(n_len - 1).saturating_add(n_len)
                };
                assert!(
                    a.len() >= a_needed,
                    "solve: matrix slice too short for its layout"
                );
                assert!(ipiv.len() >= n_len, "solve: pivot slice too short");
                assert!(b.len() >= n_len, "solve: right-hand side too short");

                let mut info = 0;
                if conj {
                    for b_elem in &mut *b {
                        *b_elem = b_elem.conj();
                    }
                }
                // SAFETY: the slice lengths were validated above against the
                // sizes passed to LAPACK.
                unsafe {
                    $getrs(
                        t.as_ptr(),
                        &n,
                        &nrhs,
                        AsPtr::as_ptr(a),
                        &lda,
                        ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(b),
                        &ldb,
                        &mut info,
                    )
                };
                // Undo the conjugation even if LAPACK reported an error, so `b`
                // is never left in conjugated form.
                if conj {
                    for b_elem in &mut *b {
                        *b_elem = b_elem.conj();
                    }
                }
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
}
impl_solve!(f64, lapack_sys::dgetrs_);
impl_solve!(f32, lapack_sys::sgetrs_);
impl_solve!(c64, lapack_sys::zgetrs_);
impl_solve!(c32, lapack_sys::cgetrs_);
/// Workspace for computing a matrix inverse from its LU factors.
///
/// Build it with [`InvWorkImpl::new`], which sizes `work` via a LAPACK `?getri`
/// workspace query. Then call [`InvWorkImpl::calc`], which takes `&mut self`
/// so the buffer can be reused, for matrices with this layout.
pub struct InvWork<T: Scalar> {
    /// Layout of the matrices this workspace was sized for.
    pub layout: MatrixLayout,
    /// Uninitialized scratch buffer handed to `?getri` as its `work` argument.
    pub work: Vec<MaybeUninit<T>>,
}
/// Matrix inversion from LU factors with a reusable LAPACK `?getri` workspace.
pub trait InvWorkImpl: Sized {
    /// Scalar type of the matrix elements.
    type Elem: Scalar;
    /// Queries LAPACK for the workspace size needed for `layout` and allocates it.
    fn new(layout: MatrixLayout) -> Result<Self>;
    /// Runs `?getri` on `a`, which should hold LU factors with pivots `p`
    /// (presumably from [`LuImpl::lu`]). On success LAPACK overwrites `a` with
    /// the inverse. Does nothing when the layout is empty.
    fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>;
}
/// Implements [`InvWorkImpl`] for `InvWork<$s>` on top of the LAPACK routine `$tri`.
macro_rules! impl_inv_work {
    ($s:ty, $tri:path) => {
        impl InvWorkImpl for InvWork<$s> {
            type Elem = $s;

            fn new(layout: MatrixLayout) -> Result<Self> {
                // Query with the same order `n` that `calc` passes to `$tri`,
                // so the workspace is sized for the call that actually uses it.
                let n = layout.len();
                let mut info = 0;
                let mut work_size = [Self::Elem::zero()];
                // SAFETY: `lwork == -1` requests a workspace query. LAPACK only
                // writes the optimal size into `work_size[0]` and does not access
                // the matrix or pivot arguments, so null pointers are acceptable.
                unsafe {
                    $tri(
                        &n,
                        std::ptr::null_mut(),
                        &layout.lda(),
                        std::ptr::null(),
                        AsPtr::as_mut_ptr(&mut work_size),
                        &(-1),
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                let lwork = work_size[0]
                    .to_usize()
                    .expect("LAPACK workspace query returned a non-representable size");
                let work = vec_uninit(lwork);
                Ok(InvWork { layout, work })
            }

            fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
                let n = self.layout.len();
                if n == 0 {
                    // Nothing to invert.
                    return Ok(());
                }
                let lda = self.layout.lda();

                // `$tri` only receives raw pointers. Check that `a` covers an
                // `n x n` matrix with leading dimension `lda` and that `ipiv`
                // holds `n` pivots. Otherwise LAPACK would access memory out of
                // bounds. Negative sizes are still left for LAPACK to reject.
                let n_len = n.to_usize().unwrap_or(0);
                let lda_len = lda.to_usize().unwrap_or(0);
                let a_needed = if n_len == 0 {
                    0
                } else {
                    lda_len.saturating_mul(n_len - 1).saturating_add(n_len)
                };
                assert!(
                    a.len() >= a_needed,
                    "inv: matrix slice too short for its layout"
                );
                assert!(ipiv.len() >= n_len, "inv: pivot slice too short");

                let lwork = self.work.len().to_i32().unwrap();
                let mut info = 0;
                // SAFETY: `a` and `ipiv` were validated above, and `lwork` is the
                // true length of `self.work`.
                unsafe {
                    $tri(
                        &n,
                        AsPtr::as_mut_ptr(a),
                        &lda,
                        ipiv.as_ptr(),
                        AsPtr::as_mut_ptr(&mut self.work),
                        &lwork,
                        &mut info,
                    )
                };
                info.as_lapack_result()?;
                Ok(())
            }
        }
    };
}
impl_inv_work!(c64, lapack_sys::zgetri_);
impl_inv_work!(c32, lapack_sys::cgetri_);
impl_inv_work!(f64, lapack_sys::dgetri_);
impl_inv_work!(f32, lapack_sys::sgetri_);