use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::Float;
use scirs2_core::SparseElement;
use std::collections::HashMap;
use std::fmt::Debug;
/// Configuration for building and cycling an [`AMGPreconditioner`].
#[derive(Debug, Clone)]
pub struct AMGOptions {
    /// Upper bound on the total number of levels, counting the finest (original) level.
    pub max_levels: usize,
    /// Strength-of-connection threshold: an off-diagonal entry of row `i` is treated as
    /// strong when its magnitude is at least `theta` times the largest off-diagonal
    /// magnitude in that row.
    pub theta: f64,
    /// Coarsening stops once the current level has at most this many rows.
    pub max_coarse_size: usize,
    /// Interpolation scheme selector.
    ///
    /// NOTE(review): this field is not read when the preconditioner builds its
    /// prolongation operators in this module — confirm before relying on it.
    pub interpolation: InterpolationType,
    /// Smoother used for pre- and post-smoothing on every level except the coarsest
    /// (the coarsest level is always relaxed with Gauss–Seidel).
    pub smoother: SmootherType,
    /// Number of smoothing sweeps performed before the residual is restricted.
    pub pre_smooth_steps: usize,
    /// Number of smoothing sweeps performed after the coarse-grid correction is added.
    pub post_smooth_steps: usize,
    /// Recursion pattern used for each multigrid cycle.
    pub cycle_type: CycleType,
}

impl Default for AMGOptions {
    /// Default configuration: a V-cycle with one Gauss–Seidel sweep before and after the
    /// coarse-grid correction, `theta = 0.25`, at most 10 levels, coarsening stopped at
    /// 50 rows, and `Classical` interpolation selected.
    fn default() -> Self {
        AMGOptions {
            cycle_type: CycleType::V,
            smoother: SmootherType::GaussSeidel,
            pre_smooth_steps: 1,
            post_smooth_steps: 1,
            interpolation: InterpolationType::Classical,
            theta: 0.25,
            max_coarse_size: 50,
            max_levels: 10,
        }
    }
}
/// Interpolation scheme selector stored in `AMGOptions::interpolation`.
///
/// NOTE(review): no variant is currently consulted when the preconditioner builds its
/// prolongation operators, so every choice yields the same hierarchy — confirm before
/// relying on the distinction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterpolationType {
    /// Classical interpolation (the value chosen by `AMGOptions::default`).
    Classical,
    /// Direct interpolation.
    Direct,
    /// Standard interpolation.
    Standard,
}
/// Relaxation method applied before and after each coarse-grid correction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmootherType {
    /// Forward Gauss–Seidel: rows are updated in order, each using the latest values.
    GaussSeidel,
    /// Jacobi: every row is updated from the previous iterate (no damping factor).
    Jacobi,
    /// Successive over-relaxation with a fixed relaxation factor of 1.2.
    SOR,
}
/// Recursion pattern of a multigrid cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CycleType {
    /// V-cycle: the next coarser level is visited once per cycle.
    V,
    /// W-cycle: the next coarser level is visited twice per cycle.
    W,
    /// F-cycle variant of the multigrid recursion.
    F,
}
/// Algebraic multigrid (AMG) preconditioner built from a CSR matrix.
///
/// Level 0 holds a copy of the input matrix. Each coarser level `k + 1` holds the Galerkin
/// product `R_k · A_k · P_k`, where the restriction `R_k` is the transpose of the
/// prolongation `P_k`.
#[derive(Debug)]
pub struct AMGPreconditioner<T: Float + SparseElement + Debug + Copy + 'static> {
    /// Operator of every level, finest first.
    operators: Vec<CsrArray<T>>,
    /// `prolongations[k]` maps a level-`k + 1` vector to level `k`.
    prolongations: Vec<CsrArray<T>>,
    /// `restrictions[k]` maps a level-`k` vector to level `k + 1`; it is the transpose
    /// of `prolongations[k]`.
    restrictions: Vec<CsrArray<T>>,
    /// Options used both while building the hierarchy and while cycling.
    options: AMGOptions,
    /// Number of levels; always equal to `operators.len()`.
    num_levels: usize,
}
impl<T> AMGPreconditioner<T>
where
T: Float + SparseElement + Debug + Copy + 'static,
{
pub fn new(matrix: &CsrArray<T>, options: AMGOptions) -> SparseResult<Self> {
let mut amg = AMGPreconditioner {
operators: vec![matrix.clone()],
prolongations: Vec::new(),
restrictions: Vec::new(),
options,
num_levels: 1,
};
amg.build_hierarchy()?;
Ok(amg)
}
fn build_hierarchy(&mut self) -> SparseResult<()> {
let mut level = 0;
while level < self.options.max_levels - 1 {
let currentmatrix = &self.operators[level];
let (rows, _) = currentmatrix.shape();
if rows <= self.options.max_coarse_size {
break;
}
let (coarsematrix, prolongation, restriction) = self.coarsen_level(currentmatrix)?;
let (coarse_rows, _) = coarsematrix.shape();
if coarse_rows >= rows {
break;
}
self.operators.push(coarsematrix);
self.prolongations.push(prolongation);
self.restrictions.push(restriction);
self.num_levels += 1;
level += 1;
}
Ok(())
}
fn coarsen_level(
&self,
matrix: &CsrArray<T>,
) -> SparseResult<(CsrArray<T>, CsrArray<T>, CsrArray<T>)> {
let (n, _) = matrix.shape();
let strong_connections = self.detect_strong_connections(matrix)?;
let (c_points, f_points) = self.classical_cf_splitting(matrix, &strong_connections)?;
let mut fine_to_coarse = HashMap::new();
for (coarse_idx, &fine_idx) in c_points.iter().enumerate() {
fine_to_coarse.insert(fine_idx, coarse_idx);
}
let coarse_size = c_points.len();
let prolongation = self.build_prolongation(matrix, &fine_to_coarse, coarse_size)?;
let restriction_box = prolongation.transpose()?;
let restriction = restriction_box
.as_any()
.downcast_ref::<CsrArray<T>>()
.ok_or_else(|| {
SparseError::ValueError("Failed to downcast restriction to CsrArray".to_string())
})?
.clone();
let temp_box = restriction.dot(matrix)?;
let temp = temp_box
.as_any()
.downcast_ref::<CsrArray<T>>()
.ok_or_else(|| {
SparseError::ValueError("Failed to downcast temp to CsrArray".to_string())
})?;
let coarsematrix_box = temp.dot(&prolongation)?;
let coarsematrix = coarsematrix_box
.as_any()
.downcast_ref::<CsrArray<T>>()
.ok_or_else(|| {
SparseError::ValueError("Failed to downcast coarsematrix to CsrArray".to_string())
})?
.clone();
Ok((coarsematrix, prolongation, restriction))
}
fn detect_strong_connections(&self, matrix: &CsrArray<T>) -> SparseResult<Vec<Vec<usize>>> {
let (n, _) = matrix.shape();
let mut strong_connections = vec![Vec::new(); n];
#[allow(clippy::needless_range_loop)]
for i in 0..n {
let row_start = matrix.get_indptr()[i];
let row_end = matrix.get_indptr()[i + 1];
let mut max_off_diag = T::sparse_zero();
for j in row_start..row_end {
let col = matrix.get_indices()[j];
if col != i {
let val = matrix.get_data()[j].abs();
if val > max_off_diag {
max_off_diag = val;
}
}
}
let threshold = T::from(self.options.theta).expect("Operation failed") * max_off_diag;
for j in row_start..row_end {
let col = matrix.get_indices()[j];
if col != i {
let val = matrix.get_data()[j].abs();
if val >= threshold {
strong_connections[i].push(col);
}
}
}
}
Ok(strong_connections)
}
fn classical_cf_splitting(
&self,
matrix: &CsrArray<T>,
strong_connections: &[Vec<usize>],
) -> SparseResult<(Vec<usize>, Vec<usize>)> {
let (n, _) = matrix.shape();
let mut influence = vec![0; n];
for i in 0..n {
influence[i] = strong_connections[i].len();
}
let mut point_type = vec![0; n];
let mut c_points = Vec::new();
let mut f_points = Vec::new();
let mut sorted_points: Vec<usize> = (0..n).collect();
sorted_points.sort_by(|&a, &b| influence[b].cmp(&influence[a]));
for &i in &sorted_points {
if point_type[i] != 0 {
continue; }
let mut needs_coarse = false;
for &j in &strong_connections[i] {
if point_type[j] == 2 {
let mut has_coarse_interp = false;
for &k in &strong_connections[j] {
if point_type[k] == 1 {
has_coarse_interp = true;
break;
}
}
if !has_coarse_interp {
needs_coarse = true;
break;
}
}
}
if needs_coarse || influence[i] > 2 {
point_type[i] = 1;
c_points.push(i);
for &j in &strong_connections[i] {
if point_type[j] == 0 {
point_type[j] = 2;
f_points.push(j);
}
}
}
}
#[allow(clippy::needless_range_loop)]
for i in 0..n {
if point_type[i] == 0 {
point_type[i] = 2;
f_points.push(i);
}
}
Ok((c_points, f_points))
}
fn build_prolongation(
&self,
matrix: &CsrArray<T>,
fine_to_coarse: &HashMap<usize, usize>,
coarse_size: usize,
) -> SparseResult<CsrArray<T>> {
let (n, _) = matrix.shape();
let mut prolongation_data = Vec::new();
let mut prolongation_indices = Vec::new();
let mut prolongation_indptr = vec![0];
let strong_connections = self.detect_strong_connections(matrix)?;
#[allow(clippy::needless_range_loop)]
for i in 0..n {
if let Some(&coarse_idx) = fine_to_coarse.get(&i) {
prolongation_data.push(T::sparse_one());
prolongation_indices.push(coarse_idx);
} else {
let interp_weights = self.compute_interpolation_weights(
i,
matrix,
&strong_connections[i],
fine_to_coarse,
)?;
if interp_weights.is_empty() {
prolongation_data.push(T::sparse_one());
prolongation_indices.push(0);
} else {
for (coarse_idx, weight) in interp_weights {
prolongation_data.push(weight);
prolongation_indices.push(coarse_idx);
}
}
}
prolongation_indptr.push(prolongation_data.len());
}
CsrArray::new(
prolongation_data.into(),
prolongation_indptr.into(),
prolongation_indices.into(),
(n, coarse_size),
)
}
fn compute_interpolation_weights(
&self,
fine_point: usize,
matrix: &CsrArray<T>,
strong_neighbors: &[usize],
fine_to_coarse: &HashMap<usize, usize>,
) -> SparseResult<Vec<(usize, T)>> {
let mut weights = Vec::new();
let mut coarse_neighbors = Vec::new();
let mut coarse_weights = Vec::new();
for &neighbor in strong_neighbors {
if let Some(&coarse_idx) = fine_to_coarse.get(&neighbor) {
coarse_neighbors.push(neighbor);
coarse_weights.push(coarse_idx);
}
}
if coarse_neighbors.is_empty() {
return Ok(weights);
}
let mut a_ii = T::sparse_zero();
let row_start = matrix.get_indptr()[fine_point];
let row_end = matrix.get_indptr()[fine_point + 1];
for j in row_start..row_end {
let col = matrix.get_indices()[j];
if col == fine_point {
a_ii = matrix.get_data()[j];
break;
}
}
if SparseElement::is_zero(&a_ii) {
return Ok(weights);
}
let mut total_weight = T::sparse_zero();
let mut temp_weights = Vec::new();
for &coarse_neighbor in &coarse_neighbors {
let mut a_ij = T::sparse_zero();
for j in row_start..row_end {
let col = matrix.get_indices()[j];
if col == coarse_neighbor {
a_ij = matrix.get_data()[j];
break;
}
}
if !SparseElement::is_zero(&a_ij) {
let weight = -a_ij / a_ii;
temp_weights.push(weight);
total_weight = total_weight + weight;
} else {
temp_weights.push(T::sparse_zero());
}
}
if !SparseElement::is_zero(&total_weight) {
for (i, &coarse_idx) in coarse_weights.iter().enumerate() {
let normalized_weight = temp_weights[i] / total_weight;
if !SparseElement::is_zero(&normalized_weight) {
weights.push((coarse_idx, normalized_weight));
}
}
}
Ok(weights)
}
pub fn apply(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
let (n, _) = self.operators[0].shape();
if b.len() != n {
return Err(SparseError::DimensionMismatch {
expected: n,
found: b.len(),
});
}
let mut x = Array1::zeros(n);
self.mg_cycle(&mut x, b, 0)?;
Ok(x)
}
fn mg_cycle(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
if level == self.num_levels - 1 {
self.coarse_solve(x, b, level)?;
return Ok(());
}
let matrix = &self.operators[level];
for _ in 0..self.options.pre_smooth_steps {
self.smooth(x, b, matrix)?;
}
let ax = matrix_vector_multiply(matrix, &x.view())?;
let residual = b - &ax;
let restriction = &self.restrictions[level];
let coarse_residual = matrix_vector_multiply(restriction, &residual.view())?;
let coarse_size = coarse_residual.len();
let mut coarse_correction = Array1::zeros(coarse_size);
match self.options.cycle_type {
CycleType::V => {
self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
}
CycleType::W => {
self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
}
CycleType::F => {
self.mg_cycle(&mut coarse_correction, &coarse_residual.view(), level + 1)?;
}
}
let prolongation = &self.prolongations[level];
let fine_correction = matrix_vector_multiply(prolongation, &coarse_correction.view())?;
for i in 0..x.len() {
x[i] = x[i] + fine_correction[i];
}
for _ in 0..self.options.post_smooth_steps {
self.smooth(x, b, matrix)?;
}
Ok(())
}
fn smooth(
&self,
x: &mut Array1<T>,
b: &ArrayView1<T>,
matrix: &CsrArray<T>,
) -> SparseResult<()> {
match self.options.smoother {
SmootherType::GaussSeidel => self.gauss_seidel_smooth(x, b, matrix),
SmootherType::Jacobi => self.jacobi_smooth(x, b, matrix),
SmootherType::SOR => {
self.sor_smooth(x, b, matrix, T::from(1.2).expect("Operation failed"))
}
}
}
fn gauss_seidel_smooth(
&self,
x: &mut Array1<T>,
b: &ArrayView1<T>,
matrix: &CsrArray<T>,
) -> SparseResult<()> {
let n = x.len();
for i in 0..n {
let row_start = matrix.get_indptr()[i];
let row_end = matrix.get_indptr()[i + 1];
let mut sum = T::sparse_zero();
let mut diag_val = T::sparse_zero();
for j in row_start..row_end {
let col = matrix.get_indices()[j];
let val = matrix.get_data()[j];
if col == i {
diag_val = val;
} else {
sum = sum + val * x[col];
}
}
if !SparseElement::is_zero(&diag_val) {
x[i] = (b[i] - sum) / diag_val;
}
}
Ok(())
}
fn jacobi_smooth(
&self,
x: &mut Array1<T>,
b: &ArrayView1<T>,
matrix: &CsrArray<T>,
) -> SparseResult<()> {
let n = x.len();
let mut x_new = x.clone();
for i in 0..n {
let row_start = matrix.get_indptr()[i];
let row_end = matrix.get_indptr()[i + 1];
let mut sum = T::sparse_zero();
let mut diag_val = T::sparse_zero();
for j in row_start..row_end {
let col = matrix.get_indices()[j];
let val = matrix.get_data()[j];
if col == i {
diag_val = val;
} else {
sum = sum + val * x[col];
}
}
if !SparseElement::is_zero(&diag_val) {
x_new[i] = (b[i] - sum) / diag_val;
}
}
*x = x_new;
Ok(())
}
fn sor_smooth(
&self,
x: &mut Array1<T>,
b: &ArrayView1<T>,
matrix: &CsrArray<T>,
omega: T,
) -> SparseResult<()> {
let n = x.len();
for i in 0..n {
let row_start = matrix.get_indptr()[i];
let row_end = matrix.get_indptr()[i + 1];
let mut sum = T::sparse_zero();
let mut diag_val = T::sparse_zero();
for j in row_start..row_end {
let col = matrix.get_indices()[j];
let val = matrix.get_data()[j];
if col == i {
diag_val = val;
} else {
sum = sum + val * x[col];
}
}
if !SparseElement::is_zero(&diag_val) {
let x_gs = (b[i] - sum) / diag_val;
x[i] = (T::sparse_one() - omega) * x[i] + omega * x_gs;
}
}
Ok(())
}
fn coarse_solve(&self, x: &mut Array1<T>, b: &ArrayView1<T>, level: usize) -> SparseResult<()> {
let matrix = &self.operators[level];
for _ in 0..10 {
self.gauss_seidel_smooth(x, b, matrix)?;
}
Ok(())
}
pub fn num_levels(&self) -> usize {
self.num_levels
}
pub fn level_size(&self, level: usize) -> Option<(usize, usize)> {
if level < self.num_levels {
Some(self.operators[level].shape())
} else {
None
}
}
}
/// Computes the sparse matrix–vector product `matrix · x` for a CSR matrix.
///
/// # Errors
/// Returns `SparseError::DimensionMismatch` when `x.len()` differs from the number of
/// columns of `matrix`.
fn matrix_vector_multiply<T>(matrix: &CsrArray<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static,
{
    let (rows, cols) = matrix.shape();
    let found = x.len();
    if found != cols {
        return Err(SparseError::DimensionMismatch {
            expected: cols,
            found,
        });
    }
    let indptr = matrix.get_indptr();
    let indices = matrix.get_indices();
    let data = matrix.get_data();
    let mut result: Array1<T> = Array1::zeros(rows);
    for (row, out) in result.iter_mut().enumerate() {
        // Accumulate the row's stored entries in storage order, starting from the zero
        // already in `out`.
        *out = (indptr[row]..indptr[row + 1]).fold(*out, |acc, k| acc + data[k] * x[indices[k]]);
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Assembles a CSR matrix of the given shape from COO triplets.
    fn csr_from_coo(
        rows: Vec<usize>,
        cols: Vec<usize>,
        data: Vec<f64>,
        shape: (usize, usize),
    ) -> CsrArray<f64> {
        CsrArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed")
    }

    /// The 3x3 matrix `[[2, -1, 0], [-1, 2, 0], [0, -1, 2]]`.
    fn small_test_matrix() -> CsrArray<f64> {
        csr_from_coo(
            vec![0, 0, 1, 1, 2, 2],
            vec![0, 1, 0, 1, 1, 2],
            vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0],
            (3, 3),
        )
    }

    #[test]
    fn test_amg_preconditioner_creation() {
        let matrix = small_test_matrix();
        let amg =
            AMGPreconditioner::new(&matrix, AMGOptions::default()).expect("Operation failed");
        assert!(amg.num_levels() >= 1);
        assert_eq!(amg.level_size(0), Some((3, 3)));
    }

    #[test]
    fn test_amg_apply() {
        let diagonal = csr_from_coo(vec![0, 1, 2], vec![0, 1, 2], vec![2.0, 3.0, 4.0], (3, 3));
        let amg =
            AMGPreconditioner::new(&diagonal, AMGOptions::default()).expect("Operation failed");
        let rhs = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        let solution = amg.apply(&rhs.view()).expect("Operation failed");
        // The exact solution is all ones.
        for &value in solution.iter() {
            assert!(value > 0.5 && value < 1.5);
        }
    }

    #[test]
    fn test_amg_options() {
        let options = AMGOptions {
            cycle_type: CycleType::W,
            smoother: SmootherType::Jacobi,
            theta: 0.5,
            max_levels: 5,
            ..Default::default()
        };
        assert_eq!(options.max_levels, 5);
        assert_eq!(options.theta, 0.5);
        assert!(matches!(options.smoother, SmootherType::Jacobi));
        assert!(matches!(options.cycle_type, CycleType::W));
    }

    #[test]
    fn test_gauss_seidel_smoother() {
        let matrix = small_test_matrix();
        let amg =
            AMGPreconditioner::new(&matrix, AMGOptions::default()).expect("Operation failed");
        let mut x = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let rhs = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        amg.gauss_seidel_smooth(&mut x, &rhs.view(), &matrix)
            .expect("Operation failed");
        assert!(x.iter().any(|&value| value.abs() > 1e-10));
    }

    #[test]
    fn test_enhanced_amg_coarsening() {
        let matrix = csr_from_coo(
            vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4],
            vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 3, 4, 0],
            vec![
                4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, -1.0, 4.0, -1.0,
            ],
            (5, 5),
        );
        let options = AMGOptions {
            theta: 0.25,
            ..Default::default()
        };
        let amg = AMGPreconditioner::new(&matrix, options).expect("Operation failed");
        assert!(amg.num_levels() >= 1);
        let rhs = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let solution = amg.apply(&rhs.view()).expect("Operation failed");
        assert_eq!(solution.len(), 5);
        assert!(solution.iter().any(|&value| value.abs() > 1e-10));
    }

    #[test]
    fn test_strong_connection_detection() {
        let matrix = csr_from_coo(
            vec![0, 0, 1, 1, 2, 2],
            vec![0, 1, 0, 1, 1, 2],
            vec![4.0, -2.0, -2.0, 4.0, -2.0, 4.0],
            (3, 3),
        );
        let options = AMGOptions {
            theta: 0.25,
            ..Default::default()
        };
        let amg = AMGPreconditioner::new(&matrix, options).expect("Operation failed");
        let strong = amg
            .detect_strong_connections(&matrix)
            .expect("Operation failed");
        assert!(!strong[0].is_empty());
        assert!(!strong[1].is_empty());
        // If 0 -> 1 is strong, the symmetric entry 1 -> 0 must be strong as well.
        assert!(!strong[0].contains(&1) || strong[1].contains(&0));
    }
}