use crate::algebra::abstr::AbsDiffEq;
use crate::algebra::abstr::RelativeEq;
use crate::algebra::linear::matrix::substitute::SubstituteBackward;
use crate::algebra::linear::matrix::substitute::SubstituteForward;
use crate::algebra::{
abstr::{Field, Scalar},
linear::{
matrix::{General, Inverse, Solve, UnitLowerTriangular, UpperTriangular},
vector::Vector,
},
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::clone::Clone;
/// Types that can be factored into an [`LUDec`] (LU decomposition).
pub trait LUDecomposition<T> {
    /// Computes the LU decomposition of `self`.
    ///
    /// # Errors
    /// Returns `Err(())` when the decomposition cannot be computed. The exact
    /// failure conditions are defined by each implementor (not visible here —
    /// NOTE(review): e.g. singular/non-square input, confirm per implementation).
    fn dec_lu(&self) -> Result<LUDec<T>, ()>;
}
/// Result of an LU decomposition.
///
/// Holds a unit lower triangular factor `l`, an upper triangular factor `u`, and
/// a matrix `p` that is multiplied onto right-hand sides before the triangular
/// solves. `p` is presumably the row permutation of `P A = L U`; confirm this
/// against the `dec_lu` implementations.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LUDec<T> {
    /// Unit lower triangular factor.
    l: UnitLowerTriangular<T>,
    /// Upper triangular factor.
    u: UpperTriangular<T>,
    /// Matrix applied to the right-hand side before forward substitution.
    p: General<T>,
}
impl<T> LUDec<T> {
    /// Bundles already-computed factors into an `LUDec`.
    ///
    /// Visible only to the parent module (`pub(super)`).
    pub(super) fn new(
        l: UnitLowerTriangular<T>,
        u: UpperTriangular<T>,
        p: General<T>,
    ) -> LUDec<T> {
        Self { l, u, p }
    }

    /// Consumes the decomposition and returns the unit lower triangular factor.
    pub fn l(self) -> UnitLowerTriangular<T> {
        let LUDec { l, .. } = self;
        l
    }

    /// Consumes the decomposition and returns the upper triangular factor.
    pub fn u(self) -> UpperTriangular<T> {
        let LUDec { u, .. } = self;
        u
    }

    /// Consumes the decomposition and returns the matrix `p`.
    pub fn p(self) -> General<T> {
        let LUDec { p, .. } = self;
        p
    }

    /// Consumes the decomposition and returns all three factors as `(l, u, p)`.
    pub fn lup(self) -> (UnitLowerTriangular<T>, UpperTriangular<T>, General<T>) {
        let LUDec { l, u, p } = self;
        (l, u, p)
    }
}
impl<T> Inverse<T> for LUDec<T>
where
    T: Field + Scalar + AbsDiffEq,
{
    type Output = General<T>;

    /// Computes the inverse by solving against the identity matrix.
    ///
    /// The identity has `self.p.nrows()` rows. It is passed to the
    /// `Solve<General<T>>` implementation, so every column is solved for at once.
    ///
    /// # Errors
    /// Returns `Err(())` if `solve` fails, for example in one of the substitution steps.
    fn inv(&self) -> Result<General<T>, ()> {
        let dim = self.p.nrows();
        let identity: General<T> = General::one(dim);
        self.solve(&identity)
    }
}
impl<T> Solve<Vector<T>> for LUDec<T>
where
    T: Field + Scalar + AbsDiffEq<Epsilon = T> + RelativeEq,
{
    /// Solves for a vector `x` using the stored factors.
    ///
    /// The right-hand side is first multiplied by `p`. The result is then
    /// forward-substituted through `l` and back-substituted through `u`.
    ///
    /// # Errors
    /// Returns `Err(())` if either substitution step fails.
    fn solve(&self, rhs: &Vector<T>) -> Result<Vector<T>, ()> {
        let permuted_rhs: Vector<T> = &self.p * rhs;
        self.l
            .substitute_forward(permuted_rhs)
            .and_then(|intermediate: Vector<T>| self.u.substitute_backward(intermediate))
    }
}
impl<T> Solve<General<T>> for LUDec<T>
where
    T: Field + Scalar + AbsDiffEq,
{
    /// Solves for a matrix `X` (all right-hand-side columns at once) using the
    /// stored factors.
    ///
    /// The right-hand side is first multiplied by `p`. The result is then
    /// forward-substituted through `l` and back-substituted through `u`.
    ///
    /// # Errors
    /// Returns `Err(())` if either substitution step fails.
    fn solve(&self, rhs: &General<T>) -> Result<General<T>, ()> {
        let permuted_rhs: General<T> = &self.p * rhs;
        self.l
            .substitute_forward(permuted_rhs)
            .and_then(|intermediate: General<T>| self.u.substitute_backward(intermediate))
    }
}