#![cfg(feature = "backend-faer")]
use crate::error::KError;
use crate::utils::permutation::{Permutation, cuthill_mckee_from_adj};
use faer::Mat;
/// Reordering strategy requested from [`preprocess_matrix`].
///
/// The chosen permutation is applied to rows and columns alike.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReorderingMethod {
    /// Keep the original ordering.
    None,
    /// Reverse Cuthill–McKee: the Cuthill–McKee ordering read back to front.
    Rcm,
    /// Cuthill–McKee ordering of the symmetrized nonzero pattern.
    CuthillMckee,
    /// COLAMD ordering; `preprocess_matrix` currently returns `KError::NotImplemented`.
    Colamd,
    /// AMD ordering; `preprocess_matrix` currently returns `KError::NotImplemented`.
    Amd,
}
/// Scaling strategy requested from [`preprocess_matrix`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScalingMethod {
    /// No scaling.
    None,
    /// Scale row and column `i` by `1 / sqrt(|a_ii|)` of the permuted matrix
    /// (or by `1.0` when `|a_ii| <= 1e-15`).
    Diagonal,
    /// Not yet implemented; `preprocess_matrix` returns `KError::NotImplemented`.
    Symmetric,
}
/// Record of the reordering and scaling applied to a system, used to map
/// matrices and vectors between the original and the transformed ordering.
#[derive(Clone, Debug)]
pub struct MatrixPreprocessing {
    /// Row/column permutation: `permutation.p[i]` is the original index that is
    /// placed at position `i` of the transformed system.
    pub permutation: Permutation,
    /// Optional row scaling factors, indexed in the permuted ordering.
    pub left_scaling: Option<Vec<f64>>,
    /// Optional column scaling factors, indexed in the permuted ordering. When
    /// this is `None` but `left_scaling` is set, `apply_to_matrix` applies
    /// `left_scaling` to the columns as well.
    pub right_scaling: Option<Vec<f64>>,
    /// `true` when neither reordering nor scaling is in effect; `apply_to_matrix`
    /// then returns an unmodified copy of its input.
    pub is_identity: bool,
}
impl MatrixPreprocessing {
pub fn identity(n: usize) -> Self {
Self {
permutation: Permutation::identity(n),
left_scaling: None,
right_scaling: None,
is_identity: true,
}
}
pub fn apply_to_matrix(&self, a: &Mat<f64>) -> Result<Mat<f64>, KError> {
let n = a.nrows();
if self.is_identity {
return Ok(a.clone());
}
let mut result = Mat::zeros(n, n);
for i in 0..n {
for j in 0..n {
let orig_i = self.permutation.p[i];
let orig_j = self.permutation.p[j];
result[(i, j)] = a[(orig_i, orig_j)];
}
}
if let (Some(left), Some(right)) = (&self.left_scaling, &self.right_scaling) {
for i in 0..n {
for j in 0..n {
result[(i, j)] = left[i] * result[(i, j)] * right[j];
}
}
} else if let Some(diag) = &self.left_scaling {
for i in 0..n {
for j in 0..n {
result[(i, j)] = diag[i] * result[(i, j)] * diag[j];
}
}
}
Ok(result)
}
pub fn transform_vector(&self, x: &[f64]) -> Vec<f64> {
let n = x.len();
let mut result = vec![0.0; n];
for i in 0..n {
result[i] = x[self.permutation.p[i]];
}
if let Some(left) = &self.left_scaling {
for i in 0..n {
result[i] *= left[i];
}
}
result
}
pub fn untransform_vector(&self, y: &[f64]) -> Vec<f64> {
let n = y.len();
let mut temp = y.to_vec();
if let Some(left) = &self.left_scaling {
for i in 0..n {
temp[i] /= left[i];
}
}
let mut result = vec![0.0; n];
for i in 0..n {
result[self.permutation.pinv[i]] = temp[i];
}
result
}
}
/// Reorders and optionally scales a dense square matrix.
///
/// Returns the transformed matrix together with the [`MatrixPreprocessing`]
/// record describing the transformation. Let `p` be the permutation (position
/// `i` of the result takes original index `p[i]`). Let `d` be the scaling
/// vector for [`ScalingMethod::Diagonal`], or all ones otherwise. Then entry
/// `(i, j)` of the returned matrix is `d[i] * a[(p[i], p[j])] * d[j]`.
///
/// # Errors
///
/// - [`KError::SolveError`] if `a` is not square, or if the reordering routine
///   does not yield a permutation of `0..n`.
/// - [`KError::NotImplemented`] for [`ReorderingMethod::Colamd`],
///   [`ReorderingMethod::Amd`] and [`ScalingMethod::Symmetric`].
pub fn preprocess_matrix(
    a: &Mat<f64>,
    reorder_method: ReorderingMethod,
    scaling_method: ScalingMethod,
) -> Result<(Mat<f64>, MatrixPreprocessing), KError> {
    let n = a.nrows();
    if n != a.ncols() {
        return Err(KError::SolveError(
            "Matrix must be square for preprocessing".to_string(),
        ));
    }
    let perm_vec: Vec<usize> = match reorder_method {
        ReorderingMethod::None => (0..n).collect(),
        ReorderingMethod::Rcm => reverse_cuthill_mckee(a)?,
        ReorderingMethod::CuthillMckee => cuthill_mckee(a)?,
        ReorderingMethod::Colamd | ReorderingMethod::Amd => {
            return Err(KError::NotImplemented(format!(
                "{reorder_method:?} reordering not yet implemented"
            )));
        }
    };
    // Validate before any indexing. A malformed ordering would otherwise panic
    // (index >= n) or silently corrupt the matrix and `pinv` (duplicate index).
    let pinv = inverse_permutation(&perm_vec, n)?;
    let (left_scaling, right_scaling) = match scaling_method {
        ScalingMethod::None => (None, None),
        ScalingMethod::Diagonal => {
            // Diagonal entry i of the permuted matrix is a[(p[i], p[i])], so the
            // scaling is computed without materialising the permuted matrix.
            let scaling: Vec<f64> = perm_vec
                .iter()
                .map(|&k| {
                    let diag_val = a[(k, k)].abs();
                    if diag_val > 1e-15 {
                        1.0 / diag_val.sqrt()
                    } else {
                        // Leave (near-)zero diagonal entries unscaled instead of
                        // dividing by ~0.
                        1.0
                    }
                })
                .collect();
            (Some(scaling.clone()), Some(scaling))
        }
        ScalingMethod::Symmetric => {
            return Err(KError::NotImplemented(
                "Symmetric scaling not yet implemented".to_string(),
            ));
        }
    };
    let preprocessing = MatrixPreprocessing {
        permutation: Permutation { p: perm_vec, pinv },
        left_scaling,
        right_scaling,
        is_identity: reorder_method == ReorderingMethod::None
            && scaling_method == ScalingMethod::None,
    };
    // Single implementation of "permute, then scale". This keeps the returned
    // matrix consistent with `MatrixPreprocessing::apply_to_matrix`.
    let result = preprocessing.apply_to_matrix(a)?;
    Ok((result, preprocessing))
}

/// Inverts `perm` (new index -> original index) into `pinv` with
/// `pinv[perm[i]] == i`, after checking that `perm` is a permutation of `0..n`.
fn inverse_permutation(perm: &[usize], n: usize) -> Result<Vec<usize>, KError> {
    let invalid =
        || KError::SolveError("Reordering did not produce a valid permutation".to_string());
    if perm.len() != n {
        return Err(invalid());
    }
    // `usize::MAX` marks slots not yet assigned; seeing an assigned slot again
    // means `perm` contains a duplicate index.
    let mut pinv = vec![usize::MAX; n];
    for (new_idx, &old_idx) in perm.iter().enumerate() {
        match pinv.get_mut(old_idx) {
            Some(slot) if *slot == usize::MAX => *slot = new_idx,
            _ => return Err(invalid()),
        }
    }
    Ok(pinv)
}
/// Reverse Cuthill–McKee ordering: the Cuthill–McKee ordering of `a`, reversed.
///
/// Entry `i` of the returned vector is the original index placed at position `i`.
fn reverse_cuthill_mckee(a: &Mat<f64>) -> Result<Vec<usize>, KError> {
    let mut order = cuthill_mckee(a)?;
    order.reverse();
    Ok(order)
}
/// Cuthill–McKee ordering of the symmetrized nonzero pattern of `a`.
///
/// Entries with magnitude at most `1e-15` count as zeros. Indices `i != j` are
/// adjacent when either `a[(i, j)]` or `a[(j, i)]` is nonzero. Each adjacency
/// list is in ascending order. The ordering itself is delegated to
/// `cuthill_mckee_from_adj`. The caller is expected to pass a square matrix.
fn cuthill_mckee(a: &Mat<f64>) -> Result<Vec<usize>, KError> {
    const DROP_TOL: f64 = 1e-15;
    let n = a.nrows();
    let is_nonzero = |row: usize, col: usize| a[(row, col)].abs() > DROP_TOL;
    let mut adj: Vec<Vec<usize>> = (0..n)
        .map(|i| {
            (0..n)
                .filter(|&j| j != i && (is_nonzero(i, j) || is_nonzero(j, i)))
                .collect()
        })
        .collect();
    Ok(cuthill_mckee_from_adj(&mut adj))
}
#[cfg(test)]
mod tests {
    use super::*;
    use faer::Mat;

    /// Symmetric 4x4 fixture with diagonal 2, 3, 4, 5 and unit off-diagonal
    /// entries linking each index to its neighbours on the cycle 0-1-2-3-0.
    fn create_test_matrix() -> Mat<f64> {
        const DIAGONAL: [f64; 4] = [2.0, 3.0, 4.0, 5.0];
        Mat::from_fn(4, 4, |i, j| {
            let on_cycle = (i + 1) % 4 == j || (j + 1) % 4 == i;
            if i == j {
                DIAGONAL[i]
            } else if on_cycle {
                1.0
            } else {
                0.0
            }
        })
    }

    #[test]
    fn test_identity_preprocessing() {
        let a = create_test_matrix();
        let (result, info) =
            preprocess_matrix(&a, ReorderingMethod::None, ScalingMethod::None).unwrap();
        let natural: Vec<usize> = (0..4).collect();

        assert!(info.is_identity);
        assert_eq!(info.permutation.p, natural);
        assert_eq!(info.permutation.pinv, natural);
        assert!(info.left_scaling.is_none());
        assert!(info.right_scaling.is_none());

        let entries_match = (0..4)
            .flat_map(|i| (0..4).map(move |j| (i, j)))
            .all(|(i, j)| (result[(i, j)] - a[(i, j)]).abs() < 1e-15);
        assert!(entries_match);
    }

    #[test]
    fn test_cuthill_mckee_reordering() {
        let a = create_test_matrix();
        let (_result, info) =
            preprocess_matrix(&a, ReorderingMethod::CuthillMckee, ScalingMethod::None).unwrap();
        assert!(!info.is_identity);
        assert_eq!(info.permutation.len(), 4);

        // `p` must contain every index in 0..4 exactly once.
        let mut seen = [false; 4];
        for &old_idx in &info.permutation.p {
            assert!(old_idx < 4);
            assert!(!std::mem::replace(&mut seen[old_idx], true));
        }
        assert!(seen.iter().all(|&hit| hit));

        // `pinv` must undo `p`.
        for (new_idx, &old_idx) in info.permutation.p.iter().enumerate() {
            assert_eq!(info.permutation.pinv[old_idx], new_idx);
        }
    }

    #[test]
    fn test_diagonal_scaling() {
        let a = create_test_matrix();
        let (_result, info) =
            preprocess_matrix(&a, ReorderingMethod::None, ScalingMethod::Diagonal).unwrap();
        assert!(!info.is_identity);
        assert!(info.left_scaling.is_some());
        assert!(info.right_scaling.is_some());

        let scaling = info.left_scaling.as_deref().unwrap();
        for i in 0..4 {
            let scaled_diag = scaling[i] * a[(i, i)] * scaling[i];
            assert!(
                (scaled_diag - 1.0).abs() < 1e-12,
                "Diagonal entry {} not scaled to 1: {}",
                i,
                scaled_diag
            );
        }
    }

    #[test]
    fn test_vector_transformation() {
        let a = create_test_matrix();
        let x = [1.0, 2.0, 3.0, 4.0];
        let (_, info) =
            preprocess_matrix(&a, ReorderingMethod::CuthillMckee, ScalingMethod::Diagonal).unwrap();

        let recovered = info.untransform_vector(&info.transform_vector(&x));
        for (i, &expected) in x.iter().enumerate() {
            assert!(
                (recovered[i] - expected).abs() < 1e-12,
                "Vector transformation not invertible"
            );
        }
    }
}