use std::ops::{Index, IndexMut};
use cauchy::Scalar;
use ndarray::*;
use num_traits::One;
use super::convert::*;
use super::error::*;
use super::lapack::*;
use super::layout::*;
/// A tridiagonal matrix stored as its three diagonals.
///
/// Element `(i, i)` is `d[i]`, element `(i + 1, i)` is `dl[i]` and element
/// `(i, i + 1)` is `du[i]`; every other element is implicitly zero.
#[derive(Clone, PartialEq)]
pub struct Tridiagonal<A: Scalar> {
    /// Layout of the matrix; its first dimension is used as the size `n` when indexing.
    pub l: MatrixLayout,
    /// Sub-diagonal: `dl[i]` is element `(i + 1, i)` (`n - 1` entries when built by extraction).
    pub dl: Vec<A>,
    /// Main diagonal: `d[i]` is element `(i, i)` (`n` entries).
    pub d: Vec<A>,
    /// Super-diagonal: `du[i]` is element `(i, i + 1)` (`n - 1` entries when built by extraction).
    pub du: Vec<A>,
}
/// Index types accepted by the `Index`/`IndexMut` implementations of `Tridiagonal`.
pub trait TridiagIndex {
    /// Returns the `(row, col)` coordinates of the index as `i32` values.
    fn to_tuple(&self) -> (i32, i32);
}
impl TridiagIndex for [Ix; 2] {
    /// Converts `[row, col]` into `i32` coordinates.
    ///
    /// A coordinate that does not fit in `i32` saturates to `i32::MAX` instead of
    /// wrapping around. A wrapped value could be negative or even alias a valid
    /// in-range element. A saturated value can never satisfy the `< n` bounds
    /// check (the matrix size `n` is itself an `i32`), so it is reported as out
    /// of bounds.
    fn to_tuple(&self) -> (i32, i32) {
        let clamp = |i: Ix| i.min(i32::MAX as Ix) as i32;
        (clamp(self[0]), clamp(self[1]))
    }
}
impl<A, I> Index<I> for Tridiagonal<A>
where
    A: Scalar,
    I: TridiagIndex,
{
    type Output = A;

    /// Returns a reference to the element at `[row, col]`.
    ///
    /// # Panics
    ///
    /// Panics if either coordinate is negative or not smaller than the matrix
    /// size, or if `[row, col]` does not lie on one of the three stored diagonals.
    #[inline]
    fn index(&self, index: I) -> &A {
        let (n, _) = self.l.size();
        let (row, col) = index.to_tuple();
        // Check the lower bound as well. Otherwise a negative coordinate slips past
        // the `< n` test and fails later with an opaque `Vec` index or overflow panic.
        assert!(
            std::cmp::min(row, col) >= 0 && std::cmp::max(row, col) < n,
            "ndarray: index {:?} is out of bounds for array of shape {}",
            [row, col],
            n
        );
        match row - col {
            0 => &self.d[row as usize],
            // sub-diagonal: element (i + 1, i) is stored at dl[i]
            1 => &self.dl[col as usize],
            // super-diagonal: element (i, i + 1) is stored at du[i]
            -1 => &self.du[row as usize],
            _ => panic!(
                "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
                [row, col]
            ),
        }
    }
}
impl<A, I> IndexMut<I> for Tridiagonal<A>
where
    A: Scalar,
    I: TridiagIndex,
{
    /// Returns a mutable reference to the element at `[row, col]`.
    ///
    /// # Panics
    ///
    /// Panics if either coordinate is negative or not smaller than the matrix
    /// size, or if `[row, col]` does not lie on one of the three stored diagonals.
    #[inline]
    fn index_mut(&mut self, index: I) -> &mut A {
        let (n, _) = self.l.size();
        let (row, col) = index.to_tuple();
        // Check the lower bound as well. Otherwise a negative coordinate slips past
        // the `< n` test and fails later with an opaque `Vec` index or overflow panic.
        assert!(
            std::cmp::min(row, col) >= 0 && std::cmp::max(row, col) < n,
            "ndarray: index {:?} is out of bounds for array of shape {}",
            [row, col],
            n
        );
        match row - col {
            0 => &mut self.d[row as usize],
            // sub-diagonal: element (i + 1, i) is stored at dl[i]
            1 => &mut self.dl[col as usize],
            // super-diagonal: element (i, i + 1) is stored at du[i]
            -1 => &mut self.du[row as usize],
            _ => panic!(
                "ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
                [row, col]
            ),
        }
    }
}
/// Extraction of the tridiagonal part of a dense matrix.
pub trait ExtractTridiagonal<A: Scalar> {
    /// Copies the sub-, main and super-diagonals of `self` into a new `Tridiagonal`.
    fn extract_tridiagonal(&self) -> Result<Tridiagonal<A>>;
}
impl<A, S> ExtractTridiagonal<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Copies the three central diagonals of a square matrix into a `Tridiagonal`.
    ///
    /// Entries outside the three diagonals are ignored; they are not checked to be zero.
    ///
    /// # Errors
    ///
    /// - Propagates any error from `square_layout` (presumably a non-square or
    ///   unsupported-stride input).
    /// - Returns `LinalgError::NotStandardShape` if the matrix has fewer than two
    ///   rows. The error reports the actual size.
    fn extract_tridiagonal(&self) -> Result<Tridiagonal<A>> {
        let l = self.square_layout()?;
        let (n, _) = l.size();
        if n < 2 {
            // Report the real (square) shape, which may also be 0x0.
            return Err(LinalgError::NotStandardShape {
                obj: "Tridiagonal",
                rows: n,
                cols: n,
            });
        }
        let dl = self.slice(s![1..n, 0..n - 1]).diag().to_vec();
        let d = self.diag().to_vec();
        let du = self.slice(s![0..n - 1, 1..n]).diag().to_vec();
        Ok(Tridiagonal { l, dl, d, du })
    }
}
/// Solving `A x = b` (and its transposed / conjugate-transposed variants) where
/// the tridiagonal coefficient matrix `A` is given by `self`.
///
/// `D` is the dimensionality of the right-hand side: `Ix1` for a single vector,
/// `Ix2` for a matrix whose columns are right-hand sides.
pub trait SolveTridiagonal<A: Scalar, D: Dimension> {
    /// Solves `A x = b`, returning the solution in a newly allocated array.
    fn solve_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
    /// Solves `A x = b`, overwriting `b` with the solution and returning it.
    fn solve_tridiagonal_into<S: DataMut<Elem = A>>(
        &self,
        b: ArrayBase<S, D>,
    ) -> Result<ArrayBase<S, D>>;
    /// Solves `A^T x = b`, returning the solution in a newly allocated array.
    fn solve_t_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
    /// Solves `A^T x = b`, overwriting `b` with the solution and returning it.
    fn solve_t_tridiagonal_into<S: DataMut<Elem = A>>(
        &self,
        b: ArrayBase<S, D>,
    ) -> Result<ArrayBase<S, D>>;
    /// Solves `A^H x = b` (conjugate transpose), returning a newly allocated array.
    fn solve_h_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, D>) -> Result<Array<A, D>>;
    /// Solves `A^H x = b` (conjugate transpose), overwriting `b` and returning it.
    fn solve_h_tridiagonal_into<S: DataMut<Elem = A>>(
        &self,
        b: ArrayBase<S, D>,
    ) -> Result<ArrayBase<S, D>>;
}
/// In-place variants of `SolveTridiagonal`: the right-hand side is overwritten
/// with the solution, and the same mutable reference is handed back.
pub trait SolveTridiagonalInplace<A: Scalar, D: Dimension> {
    /// Solves `A x = b` in place.
    fn solve_tridiagonal_inplace<'a, S: DataMut<Elem = A>>(
        &self,
        b: &'a mut ArrayBase<S, D>,
    ) -> Result<&'a mut ArrayBase<S, D>>;
    /// Solves `A^T x = b` in place.
    fn solve_t_tridiagonal_inplace<'a, S: DataMut<Elem = A>>(
        &self,
        b: &'a mut ArrayBase<S, D>,
    ) -> Result<&'a mut ArrayBase<S, D>>;
    /// Solves `A^H x = b` (conjugate transpose) in place.
    fn solve_h_tridiagonal_inplace<'a, S: DataMut<Elem = A>>(
        &self,
        b: &'a mut ArrayBase<S, D>,
    ) -> Result<&'a mut ArrayBase<S, D>>;
}
/// LU factorization of a tridiagonal matrix, produced by `lu_tridiagonal` and
/// reused for repeated solves and condition-number estimates.
#[derive(Clone)]
pub struct LUFactorizedTridiagonal<A: Scalar> {
    /// The matrix after `lu_tridiagonal` has run on it; presumably its diagonals now
    /// hold the LU factors (LAPACK `?gttrf` convention) — confirm in the lapack module.
    pub a: Tridiagonal<A>,
    /// Extra vector returned by `lu_tridiagonal`; presumably the second
    /// super-diagonal fill-in (`DU2` of `?gttrf`) — confirm.
    pub du2: Vec<A>,
    /// Real value returned by `lu_tridiagonal`; presumably the norm of the original
    /// matrix, needed by the reciprocal-condition-number estimate — confirm.
    pub anom: A::Real,
    /// Pivot information returned by `lu_tridiagonal`.
    pub ipiv: Pivot,
}
impl<A> SolveTridiagonal<A, Ix2> for LUFactorizedTridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Solves `A X = B` using the stored factors; `b` is left untouched.
    fn solve_tridiagonal<Sb: Data<Elem = A>>(
        &self,
        b: &ArrayBase<Sb, Ix2>,
    ) -> Result<Array<A, Ix2>> {
        let owned: Array<A, Ix2> = replicate(b);
        self.solve_tridiagonal_into(owned)
    }

    /// Solves `A X = B`, reusing the storage of `b` for the solution.
    fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<Sb, Ix2>,
    ) -> Result<ArrayBase<Sb, Ix2>> {
        self.solve_tridiagonal_inplace(&mut b)?;
        Ok(b)
    }

    /// Solves `A^T X = B` using the stored factors; `b` is left untouched.
    fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
        &self,
        b: &ArrayBase<Sb, Ix2>,
    ) -> Result<Array<A, Ix2>> {
        let owned: Array<A, Ix2> = replicate(b);
        self.solve_t_tridiagonal_into(owned)
    }

    /// Solves `A^T X = B`, reusing the storage of `b` for the solution.
    fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<Sb, Ix2>,
    ) -> Result<ArrayBase<Sb, Ix2>> {
        self.solve_t_tridiagonal_inplace(&mut b)?;
        Ok(b)
    }

    /// Solves `A^H X = B` using the stored factors; `b` is left untouched.
    fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
        &self,
        b: &ArrayBase<Sb, Ix2>,
    ) -> Result<Array<A, Ix2>> {
        let owned: Array<A, Ix2> = replicate(b);
        self.solve_h_tridiagonal_into(owned)
    }

    /// Solves `A^H X = B`, reusing the storage of `b` for the solution.
    fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<Sb, Ix2>,
    ) -> Result<ArrayBase<Sb, Ix2>> {
        self.solve_h_tridiagonal_inplace(&mut b)?;
        Ok(b)
    }
}
impl<A> SolveTridiagonal<A, Ix2> for Tridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Factorizes `self` and solves `A X = B`; `b` is left untouched.
    fn solve_tridiagonal<Sb: Data<Elem = A>>(
        &self,
        b: &ArrayBase<Sb, Ix2>,
    ) -> Result<Array<A, Ix2>> {
        let owned: Array<A, Ix2> = replicate(b);
        self.solve_tridiagonal_into(owned)
    }

    /// Factorizes `self` and solves `A X = B` in the storage of `b`.
    fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<Sb, Ix2>,
    ) -> Result<ArrayBase<Sb, Ix2>> {
        self.solve_tridiagonal_inplace(&mut b)?;
        Ok(b)
    }

    /// Factorizes `self` and solves `A^T X = B`; `b` is left untouched.
    fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
        &self,
        b: &ArrayBase<Sb, Ix2>,
    ) -> Result<Array<A, Ix2>> {
        let owned: Array<A, Ix2> = replicate(b);
        self.solve_t_tridiagonal_into(owned)
    }

    /// Factorizes `self` and solves `A^T X = B` in the storage of `b`.
    fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<Sb, Ix2>,
    ) -> Result<ArrayBase<Sb, Ix2>> {
        self.solve_t_tridiagonal_inplace(&mut b)?;
        Ok(b)
    }

    /// Factorizes `self` and solves `A^H X = B`; `b` is left untouched.
    fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
        &self,
        b: &ArrayBase<Sb, Ix2>,
    ) -> Result<Array<A, Ix2>> {
        let owned: Array<A, Ix2> = replicate(b);
        self.solve_h_tridiagonal_into(owned)
    }

    /// Factorizes `self` and solves `A^H X = B` in the storage of `b`.
    fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<Sb, Ix2>,
    ) -> Result<ArrayBase<Sb, Ix2>> {
        self.solve_h_tridiagonal_inplace(&mut b)?;
        Ok(b)
    }
}
impl<A, S> SolveTridiagonal<A, Ix2> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Treats `self` as tridiagonal and solves `A X = B`; `b` is left untouched.
    fn solve_tridiagonal<Sb: Data<Elem = A>>(
        &self,
        b: &ArrayBase<Sb, Ix2>,
    ) -> Result<Array<A, Ix2>> {
        let owned: Array<A, Ix2> = replicate(b);
        self.solve_tridiagonal_into(owned)
    }

    /// Treats `self` as tridiagonal and solves `A X = B` in the storage of `b`.
    fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<Sb, Ix2>,
    ) -> Result<ArrayBase<Sb, Ix2>> {
        self.solve_tridiagonal_inplace(&mut b)?;
        Ok(b)
    }

    /// Treats `self` as tridiagonal and solves `A^T X = B`; `b` is left untouched.
    fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
        &self,
        b: &ArrayBase<Sb, Ix2>,
    ) -> Result<Array<A, Ix2>> {
        let owned: Array<A, Ix2> = replicate(b);
        self.solve_t_tridiagonal_into(owned)
    }

    /// Treats `self` as tridiagonal and solves `A^T X = B` in the storage of `b`.
    fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<Sb, Ix2>,
    ) -> Result<ArrayBase<Sb, Ix2>> {
        self.solve_t_tridiagonal_inplace(&mut b)?;
        Ok(b)
    }

    /// Treats `self` as tridiagonal and solves `A^H X = B`; `b` is left untouched.
    fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
        &self,
        b: &ArrayBase<Sb, Ix2>,
    ) -> Result<Array<A, Ix2>> {
        let owned: Array<A, Ix2> = replicate(b);
        self.solve_h_tridiagonal_into(owned)
    }

    /// Treats `self` as tridiagonal and solves `A^H X = B` in the storage of `b`.
    fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
        &self,
        mut b: ArrayBase<Sb, Ix2>,
    ) -> Result<ArrayBase<Sb, Ix2>> {
        self.solve_h_tridiagonal_inplace(&mut b)?;
        Ok(b)
    }
}
impl<A> SolveTridiagonalInplace<A, Ix2> for LUFactorizedTridiagonal<A>
where
A: Scalar + Lapack,
{
fn solve_tridiagonal_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix2>,
) -> Result<&'a mut ArrayBase<Sb, Ix2>>
where
Sb: DataMut<Elem = A>,
{
unsafe {
A::solve_tridiagonal(
&self,
rhs.layout()?,
Transpose::No,
rhs.as_slice_mut().unwrap(),
)?
};
Ok(rhs)
}
fn solve_t_tridiagonal_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix2>,
) -> Result<&'a mut ArrayBase<Sb, Ix2>>
where
Sb: DataMut<Elem = A>,
{
unsafe {
A::solve_tridiagonal(
&self,
rhs.layout()?,
Transpose::Transpose,
rhs.as_slice_mut().unwrap(),
)?
};
Ok(rhs)
}
fn solve_h_tridiagonal_inplace<'a, Sb>(
&self,
rhs: &'a mut ArrayBase<Sb, Ix2>,
) -> Result<&'a mut ArrayBase<Sb, Ix2>>
where
Sb: DataMut<Elem = A>,
{
unsafe {
A::solve_tridiagonal(
&self,
rhs.layout()?,
Transpose::Hermite,
rhs.as_slice_mut().unwrap(),
)?
};
Ok(rhs)
}
}
impl<A> SolveTridiagonalInplace<A, Ix2> for Tridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Factorizes a copy of `self`, then solves `A X = B` in place.
    fn solve_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?
            .solve_tridiagonal_inplace(rhs)
    }

    /// Factorizes a copy of `self`, then solves `A^T X = B` in place.
    fn solve_t_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?
            .solve_t_tridiagonal_inplace(rhs)
    }

    /// Factorizes a copy of `self`, then solves `A^H X = B` in place.
    fn solve_h_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?
            .solve_h_tridiagonal_inplace(rhs)
    }
}
impl<A, S> SolveTridiagonalInplace<A, Ix2> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Extracts and factorizes the tridiagonal part of `self`, then solves `A X = B` in place.
    fn solve_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?
            .solve_tridiagonal_inplace(rhs)
    }

    /// Extracts and factorizes the tridiagonal part of `self`, then solves `A^T X = B` in place.
    fn solve_t_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?
            .solve_t_tridiagonal_inplace(rhs)
    }

    /// Extracts and factorizes the tridiagonal part of `self`, then solves `A^H X = B` in place.
    fn solve_h_tridiagonal_inplace<'a, Sb>(
        &self,
        rhs: &'a mut ArrayBase<Sb, Ix2>,
    ) -> Result<&'a mut ArrayBase<Sb, Ix2>>
    where
        Sb: DataMut<Elem = A>,
    {
        self.factorize_tridiagonal()?
            .solve_h_tridiagonal_inplace(rhs)
    }
}
impl<A> SolveTridiagonal<A, Ix1> for LUFactorizedTridiagonal<A>
where
A: Scalar + Lapack,
{
fn solve_tridiagonal<S: Data<Elem = A>>(&self, b: &ArrayBase<S, Ix1>) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_tridiagonal_into(b)
}
fn solve_tridiagonal_into<S: DataMut<Elem = A>>(
&self,
b: ArrayBase<S, Ix1>,
) -> Result<ArrayBase<S, Ix1>> {
let b = into_col(b);
let b = self.solve_tridiagonal_into(b)?;
Ok(flatten(b))
}
fn solve_t_tridiagonal<S: Data<Elem = A>>(
&self,
b: &ArrayBase<S, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_t_tridiagonal_into(b)
}
fn solve_t_tridiagonal_into<S: DataMut<Elem = A>>(
&self,
b: ArrayBase<S, Ix1>,
) -> Result<ArrayBase<S, Ix1>> {
let b = into_col(b);
let b = self.solve_t_tridiagonal_into(b)?;
Ok(flatten(b))
}
fn solve_h_tridiagonal<S: Data<Elem = A>>(
&self,
b: &ArrayBase<S, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_h_tridiagonal_into(b)
}
fn solve_h_tridiagonal_into<S: DataMut<Elem = A>>(
&self,
b: ArrayBase<S, Ix1>,
) -> Result<ArrayBase<S, Ix1>> {
let b = into_col(b);
let b = self.solve_h_tridiagonal_into(b)?;
Ok(flatten(b))
}
}
impl<A> SolveTridiagonal<A, Ix1> for Tridiagonal<A>
where
A: Scalar + Lapack,
{
fn solve_tridiagonal<Sb: Data<Elem = A>>(
&self,
b: &ArrayBase<Sb, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_tridiagonal_into(b)
}
fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
&self,
b: ArrayBase<Sb, Ix1>,
) -> Result<ArrayBase<Sb, Ix1>> {
let b = into_col(b);
let f = self.factorize_tridiagonal()?;
let b = f.solve_tridiagonal_into(b)?;
Ok(flatten(b))
}
fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
&self,
b: &ArrayBase<Sb, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_t_tridiagonal_into(b)
}
fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
&self,
b: ArrayBase<Sb, Ix1>,
) -> Result<ArrayBase<Sb, Ix1>> {
let b = into_col(b);
let f = self.factorize_tridiagonal()?;
let b = f.solve_t_tridiagonal_into(b)?;
Ok(flatten(b))
}
fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
&self,
b: &ArrayBase<Sb, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_h_tridiagonal_into(b)
}
fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
&self,
b: ArrayBase<Sb, Ix1>,
) -> Result<ArrayBase<Sb, Ix1>> {
let b = into_col(b);
let f = self.factorize_tridiagonal()?;
let b = f.solve_h_tridiagonal_into(b)?;
Ok(flatten(b))
}
}
impl<A, S> SolveTridiagonal<A, Ix1> for ArrayBase<S, Ix2>
where
A: Scalar + Lapack,
S: Data<Elem = A>,
{
fn solve_tridiagonal<Sb: Data<Elem = A>>(
&self,
b: &ArrayBase<Sb, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_tridiagonal_into(b)
}
fn solve_tridiagonal_into<Sb: DataMut<Elem = A>>(
&self,
b: ArrayBase<Sb, Ix1>,
) -> Result<ArrayBase<Sb, Ix1>> {
let b = into_col(b);
let f = self.factorize_tridiagonal()?;
let b = f.solve_tridiagonal_into(b)?;
Ok(flatten(b))
}
fn solve_t_tridiagonal<Sb: Data<Elem = A>>(
&self,
b: &ArrayBase<Sb, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_t_tridiagonal_into(b)
}
fn solve_t_tridiagonal_into<Sb: DataMut<Elem = A>>(
&self,
b: ArrayBase<Sb, Ix1>,
) -> Result<ArrayBase<Sb, Ix1>> {
let b = into_col(b);
let f = self.factorize_tridiagonal()?;
let b = f.solve_t_tridiagonal_into(b)?;
Ok(flatten(b))
}
fn solve_h_tridiagonal<Sb: Data<Elem = A>>(
&self,
b: &ArrayBase<Sb, Ix1>,
) -> Result<Array<A, Ix1>> {
let b = b.to_owned();
self.solve_h_tridiagonal_into(b)
}
fn solve_h_tridiagonal_into<Sb: DataMut<Elem = A>>(
&self,
b: ArrayBase<Sb, Ix1>,
) -> Result<ArrayBase<Sb, Ix1>> {
let b = into_col(b);
let f = self.factorize_tridiagonal()?;
let b = f.solve_h_tridiagonal_into(b)?;
Ok(flatten(b))
}
}
/// LU factorization of a tridiagonal matrix without consuming the input.
pub trait FactorizeTridiagonal<A: Scalar> {
    /// Computes the LU factorization of `self`, leaving `self` unchanged.
    fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>>;
}
/// LU factorization of a tridiagonal matrix that consumes (and reuses) the input.
pub trait FactorizeTridiagonalInto<A: Scalar> {
    /// Computes the LU factorization, taking ownership of `self`.
    fn factorize_tridiagonal_into(self) -> Result<LUFactorizedTridiagonal<A>>;
}
impl<A> FactorizeTridiagonalInto<A> for Tridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Consumes the matrix and computes its LU factorization.
    ///
    /// The matrix is handed mutably to `A::lu_tridiagonal` and then stored as the
    /// `a` field of the result (presumably now holding the LU factors).
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `A::lu_tridiagonal`.
    fn factorize_tridiagonal_into(mut self) -> Result<LUFactorizedTridiagonal<A>> {
        // SAFETY: `self` is an owned, fully initialized `Tridiagonal`; the exact
        // contract of `lu_tridiagonal` lives in the lapack module — confirm there.
        let (du2, anom, ipiv) = unsafe { A::lu_tridiagonal(&mut self)? };
        Ok(LUFactorizedTridiagonal {
            a: self,
            du2,
            anom,
            ipiv,
        })
    }
}
impl<A> FactorizeTridiagonal<A> for Tridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Computes the LU factorization of a copy of `self`; `self` is left untouched.
    fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
        // The consuming variant does the work; it receives a private clone.
        self.clone().factorize_tridiagonal_into()
    }
}
impl<A, S> FactorizeTridiagonal<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Extracts the tridiagonal part of `self` and computes its LU factorization.
    ///
    /// Errors from extraction or factorization are propagated.
    fn factorize_tridiagonal(&self) -> Result<LUFactorizedTridiagonal<A>> {
        self.extract_tridiagonal()?.factorize_tridiagonal_into()
    }
}
/// Evaluates the three-term (continuant) recurrence for the leading principal
/// minors of `tridiag`.
///
/// Returns `f` with `n + 1` entries (`n = tridiag.d.len()`). Here `f[0] = 1`, and
/// `f[k]` is the determinant of the leading `k x k` block, so `f[n]` is the
/// determinant of the whole matrix. An empty matrix yields `[1]`, the
/// determinant of a 0x0 matrix.
fn rec_rel<A: Scalar>(tridiag: &Tridiagonal<A>) -> Vec<A> {
    let n = tridiag.d.len();
    let mut f: Vec<A> = Vec::with_capacity(n + 1);
    f.push(One::one());
    // Without this guard `d[0]` below would panic on an empty matrix.
    if n == 0 {
        return f;
    }
    f.push(tridiag.d[0]);
    for i in 1..n {
        // f[i + 1] = d[i] * f[i] - dl[i - 1] * du[i - 1] * f[i - 1]
        let next = tridiag.d[i] * f[i] - tridiag.dl[i - 1] * tridiag.du[i - 1] * f[i - 1];
        f.push(next);
    }
    f
}
/// Determinant of a tridiagonal matrix.
pub trait DeterminantTridiagonal<A: Scalar> {
    /// Returns the determinant, computed from the three diagonals only.
    fn det_tridiagonal(&self) -> Result<A>;
}
impl<A> DeterminantTridiagonal<A> for Tridiagonal<A>
where
    A: Scalar,
{
    /// Returns the determinant via the leading-principal-minor recurrence.
    fn det_tridiagonal(&self) -> Result<A> {
        let minors = rec_rel(self);
        // `minors[k]` is the k-th leading principal minor; index `n` is det(A).
        Ok(minors[self.d.len()])
    }
}
impl<A, S> DeterminantTridiagonal<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Extracts the tridiagonal part of `self` and returns its determinant.
    ///
    /// Errors from the extraction are propagated.
    fn det_tridiagonal(&self) -> Result<A> {
        self.extract_tridiagonal()?.det_tridiagonal()
    }
}
/// Estimate of the reciprocal condition number of a tridiagonal matrix.
pub trait ReciprocalConditionNumTridiagonal<A: Scalar> {
    /// Returns the estimated reciprocal condition number, borrowing `self`.
    fn rcond_tridiagonal(&self) -> Result<A::Real>;
}
/// Consuming variant of `ReciprocalConditionNumTridiagonal`.
pub trait ReciprocalConditionNumTridiagonalInto<A: Scalar> {
    /// Returns the estimated reciprocal condition number, consuming `self`.
    fn rcond_tridiagonal_into(self) -> Result<A::Real>;
}
impl<A> ReciprocalConditionNumTridiagonal<A> for LUFactorizedTridiagonal<A>
where
    A: Scalar + Lapack,
{
    /// Estimates the reciprocal condition number from the stored factorization.
    fn rcond_tridiagonal(&self) -> Result<A::Real> {
        // SAFETY: `self` was produced by `lu_tridiagonal`, so its factors, `anom` and
        // pivots are mutually consistent; the exact contract lives in the lapack module.
        unsafe { A::rcond_tridiagonal(self) }
    }
}
impl<A> ReciprocalConditionNumTridiagonalInto<A> for LUFactorizedTridiagonal<A>
where
A: Scalar + Lapack,
{
fn rcond_tridiagonal_into(self) -> Result<A::Real> {
self.rcond_tridiagonal()
}
}
impl<A, S> ReciprocalConditionNumTridiagonal<A> for ArrayBase<S, Ix2>
where
    A: Scalar + Lapack,
    S: Data<Elem = A>,
{
    /// Factorizes the tridiagonal part of `self` and estimates its reciprocal
    /// condition number.
    fn rcond_tridiagonal(&self) -> Result<A::Real> {
        let lu = self.factorize_tridiagonal()?;
        lu.rcond_tridiagonal_into()
    }
}