use crate::common::IntegrateFloat;
use crate::dae::utils::linear_solvers::solve_linear_system;
use crate::error::{IntegrateError, IntegrateResult};
use crate::ode::types::{MassMatrix, MassMatrixType};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Solves `M(t, y) · x = b` for `x`.
///
/// An identity mass matrix short-circuits to a copy of `b`. Every other mass-matrix type is
/// evaluated at `(t, y)` and the resulting dense system is handed to `solve_matrix_system`.
///
/// # Errors
/// `ComputationError` if the mass matrix cannot be evaluated or the linear solve fails.
#[allow(dead_code)]
pub fn solve_mass_system<F>(
    mass: &MassMatrix<F>,
    t: F,
    y: ArrayView1<F>,
    b: ArrayView1<F>,
) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
{
    if matches!(mass.matrix_type, MassMatrixType::Identity) {
        return Ok(b.to_owned());
    }
    let evaluated = mass.evaluate(t, y).ok_or_else(|| {
        IntegrateError::ComputationError("Failed to evaluate mass matrix".to_string())
    })?;
    solve_matrix_system(evaluated.view(), b)
}
/// Solves the dense linear system `matrix · x = b`.
///
/// # Errors
/// Any failure reported by `solve_linear_system` is wrapped in `ComputationError`.
#[allow(dead_code)]
fn solve_matrix_system<F>(matrix: ArrayView2<F>, b: ArrayView1<F>) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
{
    solve_linear_system(&matrix, &b).map_err(|err| {
        // Message previously read "mass _matrix system" (stray underscore).
        IntegrateError::ComputationError(format!("Failed to solve mass matrix system: {err}"))
    })
}
/// Computes the product `M(t, y) · v`.
///
/// An identity mass matrix returns a copy of `v`. Every other mass-matrix type is evaluated at
/// `(t, y)` and multiplied with `v`.
///
/// # Errors
/// * `ComputationError` if the mass matrix cannot be evaluated.
/// * `ValueError` if the evaluated matrix's column count differs from `v.len()`. Without this
///   check, ndarray's `dot` would panic.
#[allow(dead_code)]
pub fn apply_mass<F>(
    mass: &MassMatrix<F>,
    t: F,
    y: ArrayView1<F>,
    v: ArrayView1<F>,
) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
{
    match mass.matrix_type {
        MassMatrixType::Identity => Ok(v.to_owned()),
        _ => {
            let matrix = mass.evaluate(t, y).ok_or_else(|| {
                IntegrateError::ComputationError("Failed to evaluate mass matrix".to_string())
            })?;
            if matrix.ncols() != v.len() {
                return Err(IntegrateError::ValueError(format!(
                    "Mass matrix has {} columns but the vector has length {}",
                    matrix.ncols(),
                    v.len()
                )));
            }
            Ok(matrix.dot(&v))
        }
    }
}
/// Dense LU factorisation with partial (row) pivoting: `P · A = L · U`.
///
/// `lu` stores both factors packed into one matrix. The strictly lower triangle holds the
/// multipliers of the unit lower-triangular `L`. The upper triangle, including the diagonal,
/// holds `U`. `pivots[i]` is the index of the row of the original matrix that ended up in row
/// `i` after pivoting.
#[allow(dead_code)]
struct LUDecomposition<F> {
    lu: Array2<F>,
    pivots: Vec<usize>,
}
#[allow(dead_code)]
impl<F: IntegrateFloat> LUDecomposition<F> {
    /// Factorises a square matrix.
    ///
    /// # Errors
    /// * `ValueError` if `matrix` is not square.
    /// * `ComputationError` if, at some elimination step, the largest remaining entry of the
    ///   pivot column has magnitude below the absolute tolerance `1e-14`.
    fn new(matrix: ArrayView2<F>) -> IntegrateResult<Self> {
        let (n, m) = matrix.dim();
        if n != m {
            return Err(IntegrateError::ValueError(
                "Matrix must be square for LU decomposition".to_string(),
            ));
        }
        let tiny = F::from_f64(1e-14).expect("Operation failed");
        let mut lu = matrix.to_owned();
        let mut pivots: Vec<usize> = (0..n).collect();
        for k in 0..n {
            // Partial pivoting: choose the row with the largest |entry| in column k.
            let mut max_row = k;
            let mut max_val = lu[[k, k]].abs();
            for i in (k + 1)..n {
                let val = lu[[i, k]].abs();
                if val > max_val {
                    max_val = val;
                    max_row = i;
                }
            }
            if max_val < tiny {
                return Err(IntegrateError::ComputationError(
                    "Matrix is singular or nearly singular".to_string(),
                ));
            }
            if max_row != k {
                // Swap whole rows, including multipliers already stored in columns < k, so
                // that the packed L stays consistent with the permutation.
                pivots.swap(k, max_row);
                for j in 0..n {
                    lu.swap([k, j], [max_row, j]);
                }
            }
            let pivot = lu[[k, k]];
            for i in (k + 1)..n {
                let factor = lu[[i, k]] / pivot;
                lu[[i, k]] = factor;
                for j in (k + 1)..n {
                    let upper = lu[[k, j]];
                    lu[[i, j]] -= factor * upper;
                }
            }
        }
        Ok(LUDecomposition { lu, pivots })
    }

    /// Solves `A · x = b` using the stored factors.
    ///
    /// The row permutation is applied to `b`. Then `L` is solved by forward substitution with a
    /// unit diagonal, and `U` by back substitution. The packed `lu` matrix is NOT the original
    /// `A`, so it must never be handed to a general solver directly.
    ///
    /// # Errors
    /// `ValueError` if `b.len()` differs from the order of the factorised matrix.
    fn solve(&self, b: ArrayView1<F>) -> IntegrateResult<Array1<F>> {
        let n = self.lu.nrows();
        if b.len() != n {
            return Err(IntegrateError::ValueError(format!(
                "Right-hand side length ({}) does not match LU factor dimension ({n})",
                b.len()
            )));
        }
        // x = P · b
        let mut x: Array1<F> = self.pivots.iter().map(|&row| b[row]).collect();
        // L · y = P · b  (unit lower triangular)
        for i in 0..n {
            let mut acc = x[i];
            for j in 0..i {
                acc -= self.lu[[i, j]] * x[j];
            }
            x[i] = acc;
        }
        // U · x = y  (upper triangular)
        for i in (0..n).rev() {
            let mut acc = x[i];
            for j in (i + 1)..n {
                acc -= self.lu[[i, j]] * x[j];
            }
            x[i] = acc / self.lu[[i, i]];
        }
        Ok(x)
    }
}
/// Checks that the mass matrix evaluated at `(t, y)` is square with order `y.len()`.
///
/// An identity mass matrix is always accepted without being evaluated.
///
/// # Errors
/// * `ComputationError` if the mass matrix cannot be evaluated.
/// * `ValueError` if its dimensions differ from `(y.len(), y.len())`.
#[allow(dead_code)]
pub fn check_mass_compatibility<F>(
    mass: &MassMatrix<F>,
    t: F,
    y: ArrayView1<F>,
) -> IntegrateResult<()>
where
    F: IntegrateFloat,
{
    let n = y.len();
    if let MassMatrixType::Identity = mass.matrix_type {
        return Ok(());
    }
    let evaluated = mass.evaluate(t, y).ok_or_else(|| {
        IntegrateError::ComputationError("Failed to evaluate mass matrix".to_string())
    })?;
    match evaluated.dim() {
        (rows, cols) if rows == n && cols == n => Ok(()),
        (rows, cols) => Err(IntegrateError::ValueError(format!(
            "Mass matrix dimensions ({rows},{cols}) do not match state vector length ({n})"
        ))),
    }
}
/// Turns a right-hand side `f` of `M · y' = f(t, y)` into the explicit form `y' = M⁻¹ · f(t, y)`.
///
/// The mass matrix is cloned into the returned closure. Each call evaluates `f(t, y)` and
/// passes the result to `solve_mass_system`; any error from that solve is returned unchanged.
#[allow(dead_code)]
pub fn transform_to_standard_form<F, Func>(
    f: Func,
    mass: &MassMatrix<F>,
) -> impl Fn(F, ArrayView1<F>) -> IntegrateResult<Array1<F>> + Clone
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F> + Clone,
{
    let owned_mass = mass.clone();
    move |t: F, y: ArrayView1<F>| solve_mass_system(&owned_mass, t, y, f(t, y).view())
}
/// Heuristically decides whether `matrix` is singular.
///
/// * Non-square matrices are always reported as singular.
/// * For order ≤ 3 the determinant is computed explicitly and compared against the absolute
///   tolerance `1e-14`. `threshold` is not consulted in this case.
/// * For larger matrices the estimated condition number is compared against `threshold`
///   (default `1e14`).
#[allow(dead_code)]
pub fn is_singular<F>(matrix: ArrayView2<F>, threshold: Option<F>) -> bool
where
    F: IntegrateFloat,
{
    let cond_limit = match threshold {
        Some(limit) => limit,
        None => F::from_f64(1e14).expect("Operation failed"),
    };
    let (rows, cols) = matrix.dim();
    if rows != cols {
        return true;
    }
    if rows > 3 {
        return estimate_condition_number(&matrix) > cond_limit;
    }
    compute_determinant(&matrix).abs() < F::from_f64(1e-14).expect("Operation failed")
}
/// Determinant of a square matrix.
///
/// Orders 1–3 use explicit cofactor formulas, and the empty (0×0) matrix has determinant 1.
/// Larger matrices fall back to Gaussian elimination with partial pivoting. The previous code
/// returned 0 for them, which silently reported every such matrix as singular. The column
/// count is not checked; callers are expected to pass a square matrix.
#[allow(dead_code)]
fn compute_determinant<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
    let (n, _) = matrix.dim();
    match n {
        0 => F::one(),
        1 => matrix[[0, 0]],
        2 => matrix[[0, 0]] * matrix[[1, 1]] - matrix[[0, 1]] * matrix[[1, 0]],
        3 => {
            matrix[[0, 0]] * (matrix[[1, 1]] * matrix[[2, 2]] - matrix[[1, 2]] * matrix[[2, 1]])
                - matrix[[0, 1]]
                    * (matrix[[1, 0]] * matrix[[2, 2]] - matrix[[1, 2]] * matrix[[2, 0]])
                + matrix[[0, 2]]
                    * (matrix[[1, 0]] * matrix[[2, 1]] - matrix[[1, 1]] * matrix[[2, 0]])
        }
        _ => determinant_by_elimination(matrix),
    }
}

/// Determinant via Gaussian elimination with partial pivoting; an exactly-zero pivot yields 0.
#[allow(dead_code)]
fn determinant_by_elimination<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
    let n = matrix.nrows();
    let mut a = matrix.to_owned();
    let mut det = F::one();
    for k in 0..n {
        // Row with the largest |a[i, k]| among the rows not yet eliminated.
        let mut pivot_row = k;
        for i in (k + 1)..n {
            if a[[i, k]].abs() > a[[pivot_row, k]].abs() {
                pivot_row = i;
            }
        }
        let pivot = a[[pivot_row, k]];
        if pivot == F::zero() {
            return F::zero();
        }
        if pivot_row != k {
            // Columns < k are already eliminated, so only the trailing part needs swapping.
            for j in k..n {
                a.swap([k, j], [pivot_row, j]);
            }
            // Each row exchange flips the sign of the determinant.
            det = -det;
        }
        det = det * pivot;
        for i in (k + 1)..n {
            let factor = a[[i, k]] / pivot;
            for j in (k + 1)..n {
                let upper = a[[k, j]];
                a[[i, j]] -= factor * upper;
            }
        }
    }
    det
}
/// Estimates the 2-norm condition number `σ_max / σ_min` of `matrix`.
///
/// The singular values are the square roots of `estimate_largest_eigenvalue_ata` and
/// `estimate_smallest_eigenvalue_ata`. When `σ_min` is below `1e-14` the matrix is treated
/// as singular and `1e16` is returned instead of dividing.
#[allow(dead_code)]
fn estimate_condition_number<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
    let sigma_max = estimate_largest_eigenvalue_ata(matrix).sqrt();
    let sigma_min = estimate_smallest_eigenvalue_ata(matrix).sqrt();
    let singular_floor = F::from_f64(1e-14).expect("Operation failed");
    // Keep the `<` comparison: a NaN σ_min must fall through to the division, as before.
    if sigma_min < singular_floor {
        return F::from_f64(1e16).expect("Operation failed");
    }
    sigma_max / sigma_min
}
#[allow(dead_code)]
fn estimate_largest_eigenvalue_ata<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
let n = matrix.nrows();
let max_iterations = 10;
let mut v = Array1::<F>::from_elem(n, F::one());
let mut norm = (v.dot(&v)).sqrt();
if norm > F::from_f64(1e-14).expect("Operation failed") {
v = &v / norm;
}
let mut eigenvalue = F::zero();
for _ in 0..max_iterations {
let mut av = Array1::<F>::zeros(n);
for i in 0..n {
for j in 0..n {
av[i] += matrix[[i, j]] * v[j];
}
}
let mut atav = Array1::<F>::zeros(n);
for i in 0..n {
for j in 0..n {
atav[i] += matrix[[j, i]] * av[j];
}
}
let new_eigenvalue = v.dot(&atav);
norm = (atav.dot(&atav)).sqrt();
if norm > F::from_f64(1e-14).expect("Operation failed") {
v = &atav / norm;
}
eigenvalue = new_eigenvalue;
}
eigenvalue.abs()
}
#[allow(dead_code)]
fn estimate_smallest_eigenvalue_ata<F: IntegrateFloat>(matrix: &ArrayView2<F>) -> F {
let n = matrix.nrows();
let mut min_diag = F::from_f64(f64::INFINITY).expect("Operation failed");
for i in 0..n {
let mut diag_elem = F::zero();
for k in 0..n {
diag_elem += matrix[[k, i]] * matrix[[k, i]];
}
if diag_elem < min_diag {
min_diag = diag_elem;
}
}
min_diag.max(F::from_f64(1e-16).expect("Operation failed"))
}