use crate::numeric::Numeric;
#[cfg(feature = "rayon")]
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
#[cfg(any(feature = "openblas", feature = "intel-mkl"))]
use ndarray::{Array1, Array2};
#[cfg(any(feature = "openblas", feature = "intel-mkl"))]
use ndarray_linalg::{Factorize, Solve};
/// Solves the square linear system `mat * x = rhs` using [`lu_decomposition`]
/// (LU with partial pivoting) followed by forward and backward substitution.
///
/// The entries of `rhs` are generic [`Numeric`] values. They are combined with the
/// scalar coefficients of `mat` via `multiply_scalar`, `divide_scalar` and `subtract`.
/// An empty system (`mat` and `rhs` both empty) yields an empty solution.
///
/// # Errors
///
/// Returns `Err` when `mat.len() != rhs.len()`, when [`lu_decomposition`] reports a
/// failure (its message is forwarded unchanged), or when a NaN appears in any entry
/// of the solution during backward substitution.
pub fn lu_linear_solver<T>(mat: &[Vec<f64>], rhs: &[T]) -> Result<Vec<T>, String>
where
    T: Numeric,
{
    let mat_rows = mat.len();
    if mat_rows != rhs.len() {
        return Err(String::from(
            "Incompatible design matrix and right-hand side sizes!",
        ));
    }
    // Nothing to solve. Returning early also guards the `rhs[p[0]]` and
    // `mat_rows - 1` indexing below, which would panic on an empty system.
    if mat_rows == 0 {
        return Ok(Vec::new());
    }
    let (lu, p) = match lu_decomposition(mat) {
        LuDecompResult::Success { lu, p } => (lu, p),
        LuDecompResult::Failure(err) => return Err(err),
    };

    // Forward substitution: solve L y = P b. L has an implicit unit diagonal and its
    // multipliers are stored strictly below the diagonal of `lu`.
    let mut y = T::zeros(mat_rows, &rhs[p[0]]);
    // `rhs[p[0]] - 0` yields a value equal to `rhs[p[0]]` (presumably the way to
    // obtain an owned copy through the `Numeric` API).
    y[0] = rhs[p[0]].subtract(&T::zero(&rhs[p[0]]));
    for i in 1..mat_rows {
        let sum = T::sum((0..i).map(|j| y[j].multiply_scalar(lu[i][j])));
        y[i] = rhs[p[i]].subtract(&sum);
    }

    // Backward substitution: solve U x = y, where U is the upper triangle of `lu`.
    // Every entry, including the last one, is checked for NaN.
    let nan_error = |i: usize| format!("NaN detected in backward substitution at index {}", i);
    let last = mat_rows - 1;
    let mut x = T::zeros(mat_rows, &y[last]);
    x[last] = y[last].divide_scalar(lu[last][last]);
    if x[last].is_instance_nan() {
        return Err(nan_error(last));
    }
    for i in (0..last).rev() {
        let sum = T::sum((i + 1..mat_rows).map(|j| x[j].multiply_scalar(lu[i][j])));
        x[i] = y[i].subtract(&sum).divide_scalar(lu[i][i]);
        if x[i].is_instance_nan() {
            return Err(nan_error(i));
        }
    }
    Ok(x)
}
/// Outcome of [`lu_decomposition`].
#[derive(Debug, Clone, PartialEq)]
pub enum LuDecompResult {
    /// Packed LU factors plus the row permutation.
    ///
    /// `lu` holds the multipliers of the unit lower-triangular `L` strictly below the
    /// diagonal, and `U` on and above it. `p[i]` is the index of the original row that
    /// ended up in row `i`, so that `P * A = L * U`.
    Success { lu: Vec<Vec<f64>>, p: Vec<usize> },
    /// Human-readable reason the factorization could not be completed.
    Failure(String),
}

/// Computes a Doolittle LU factorization of the square matrix `mat` with partial
/// (row) pivoting. The input is copied and not modified.
///
/// With the `rayon` feature enabled, the elimination step for a column runs in
/// parallel across rows once more than 100 rows remain below the pivot.
///
/// Returns [`LuDecompResult::Failure`] in two cases:
/// - `mat` is not square.
/// - Any pivot, including the final diagonal entry, is exactly zero, i.e. the
///   matrix is numerically singular.
///
/// Near-zero pivots are not rejected. An empty matrix factors trivially.
pub fn lu_decomposition(mat: &[Vec<f64>]) -> LuDecompResult {
    let n = mat.len();
    // Reject ragged or rectangular input up front. Otherwise a short row panics on
    // indexing, and a long row yields a factorization of only the leading n x n block.
    if let Some((row_idx, row)) = mat.iter().enumerate().find(|(_, row)| row.len() != n) {
        return LuDecompResult::Failure(format!(
            "Matrix is not square: row {} has {} entries, expected {}",
            row_idx,
            row.len(),
            n
        ));
    }
    let mut lu = mat.to_vec();
    let mut p = (0..n).collect::<Vec<usize>>();
    #[cfg(feature = "rayon")]
    let parallel_threshold = 100;
    // `saturating_sub` keeps an empty matrix from underflowing `n - 1`.
    for k in 0..n.saturating_sub(1) {
        // Partial pivoting: choose the row with the largest |entry| in column k.
        let mut max_row = k;
        for i in k + 1..n {
            if lu[i][k].abs() > lu[max_row][k].abs() {
                max_row = i;
            }
        }
        if lu[max_row][k] == 0.0 {
            return LuDecompResult::Failure(format!(
                "Exact zero pivot encountered at column {}",
                k
            ));
        }
        if max_row != k {
            lu.swap(k, max_row);
            p.swap(k, max_row);
        }
        let pivot = lu[k][k];
        // Split so the pivot row can be read while the rows below it are mutated.
        let (head, tail) = lu.split_at_mut(k + 1);
        let pivot_row = &head[k];
        #[cfg(feature = "rayon")]
        {
            if n - k - 1 > parallel_threshold {
                tail.par_iter_mut().for_each(|row| {
                    update_row(row, pivot_row, k, pivot);
                });
                continue;
            }
        }
        for row in tail.iter_mut() {
            update_row(row, pivot_row, k, pivot);
        }
    }
    // The elimination loop never pivots on the last column, so check the final
    // diagonal entry explicitly. A zero there would make back substitution
    // divide by zero.
    if let Some(last) = n.checked_sub(1).filter(|&last| lu[last][last] == 0.0) {
        return LuDecompResult::Failure(format!(
            "Exact zero pivot encountered at column {}",
            last
        ));
    }
    LuDecompResult::Success { lu, p }
}

/// Performs one Gaussian-elimination step on `row` against the pivot row `upper_row_k`.
///
/// It stores the multiplier `row[k] / pivot` in column `k`, which becomes the `L`
/// entry. It then subtracts `multiplier * upper_row_k[j]` from every column
/// `j > k`, which updates the `U` part.
fn update_row(row: &mut [f64], upper_row_k: &[f64], k: usize, pivot: f64) {
    let factor = row[k] / pivot;
    row[k] = factor;
    row[k + 1..]
        .iter_mut()
        .zip(&upper_row_k[k + 1..])
        .for_each(|(entry, &upper)| *entry -= factor * upper);
}
/// Returns the Frobenius norm of `mat`: the square root of the sum of the squares
/// of every entry. Rows may have different lengths; every entry present is counted.
pub fn matrix_frobenius_norm(mat: &[Vec<f64>]) -> f64 {
    // Sum all entries as one flat sequence, in row-major order.
    let squares = mat.iter().flatten().map(|entry| entry * entry);
    let total: f64 = squares.sum();
    total.sqrt()
}
/// Builds the `x0.len() x x1.len()` kernel matrix whose entry `(i, j)` is
/// `kernel(d, epsilon)`, with `d = max(x0[i].squared_distance(&x1[j]), f64::EPSILON)`.
///
/// Note that the kernel receives the *squared* distance, not the distance.
pub fn build_design_matrix<X>(
    x0: &[X],
    x1: &[X],
    kernel: &fn(f64, f64) -> f64,
    epsilon: f64,
) -> Vec<Vec<f64>>
where
    X: Numeric,
{
    let mut design = Vec::with_capacity(x0.len());
    for row_point in x0 {
        let row: Vec<f64> = x1
            .iter()
            .map(|col_point| {
                // Floor the squared distance at f64::EPSILON so the kernel never sees an
                // exact zero. Presumably this keeps kernels that are singular at r = 0
                // finite; confirm.
                let squared = row_point.squared_distance(col_point).max(f64::EPSILON);
                kernel(squared, epsilon)
            })
            .collect();
        design.push(row);
    }
    design
}
/// Solves `design_matrix * w = rhs` with LAPACK through `ndarray-linalg`.
///
/// Each `rhs` value is flattened with `to_flattened()`, and each flattened component
/// is solved as an independent right-hand-side column. The design matrix is
/// LU-factorized once, and that factorization is reused for every component.
/// Each row of the weight matrix is then rebuilt into a `T` with `T::from_flattened`.
/// An empty system yields an empty solution.
///
/// # Errors
///
/// Returns `Err` in any of these cases:
/// - The lengths of `design_matrix` and `rhs` differ.
/// - The design matrix is not square.
/// - The `rhs` values flatten to different lengths.
/// - The factorization or a per-component solve fails (e.g. a singular matrix).
/// - `T::from_flattened` rejects a row.
#[cfg(any(feature = "openblas", feature = "intel-mkl"))]
pub fn ndarray_linear_solver<T: Numeric>(
    design_matrix: &[Vec<f64>],
    rhs: &[T],
) -> Result<Vec<T>, String> {
    let n = design_matrix.len();
    if rhs.len() != n {
        return Err(format!(
            "Design matrix and rhs have different lengths: {} vs {}",
            n,
            rhs.len()
        ));
    }
    // Nothing to solve. This also guards the `flattened[0]` access below.
    if n == 0 {
        return Ok(Vec::new());
    }
    // Check every row individually. Otherwise ragged rows whose lengths happen to
    // sum to n * n would pass `from_shape_vec` and be silently misinterpreted.
    if let Some((row_idx, row)) = design_matrix
        .iter()
        .enumerate()
        .find(|(_, row)| row.len() != n)
    {
        return Err(format!(
            "Design matrix is not square: row {} has {} entries, expected {}",
            row_idx,
            row.len(),
            n
        ));
    }
    let design_array =
        Array2::from_shape_vec((n, n), design_matrix.iter().flatten().copied().collect())
            .map_err(|e| format!("Design matrix conversion failed: {}", e))?;

    // Flatten every rhs value exactly once, instead of once per component.
    let flattened: Vec<_> = rhs.iter().map(|val| val.to_flattened()).collect();
    let dim = flattened[0].len();
    for (i, val) in flattened.iter().enumerate() {
        if val.len() != dim {
            return Err(format!(
                "Inconsistent dimension at index {}: expected {}, got {}",
                i,
                dim,
                val.len()
            ));
        }
    }

    // Factorize once; each component then needs only a cheap triangular solve.
    let lu = design_array
        .factorize()
        .map_err(|e| format!("LU factorization of design matrix failed: {}", e))?;
    let mut weights_matrix = Array2::zeros((n, dim));
    for d in 0..dim {
        let rhs_column: Array1<f64> = Array1::from_iter(flattened.iter().map(|val| val[d]));
        let w = lu
            .solve(&rhs_column)
            .map_err(|e| format!("Linear solve failed for dimension {}: {}", d, e))?;
        weights_matrix.column_mut(d).assign(&w);
    }
    weights_matrix
        .outer_iter()
        .map(|row| T::from_flattened(row.to_vec()))
        .collect()
}