use crate::algorithm::sparse_linalg::SparseLinAlgAlgorithms;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use crate::runtime::Runtime;
use crate::sparse::{CsrData, SparseOps};
use crate::tensor::Tensor;
use super::super::helpers::vector_norm;
use super::super::types::{SorOptions, SorResult};
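/// Solves the sparse linear system `A x = b` with the method of successive
/// over-relaxation (SOR).
///
/// Each sweep applies the fixed-point update `x ← x + M⁻¹ r`, where
/// `r = b - A x` is the current residual and `M = (D + ωL) / ω` is formed from
/// the diagonal `D` and strictly lower-triangular part `L` of `A`. Setting
/// `ω = 1` recovers Gauss-Seidel; for symmetric positive-definite `A` the
/// iteration converges for any `0 < ω < 2` (Ostrowski-Reich).
///
/// A minimal usage sketch (`client`, `csr`, and `rhs_vec` are placeholders for
/// a concrete backend client, an assembled matrix, and a right-hand side; it
/// also assumes `SorOptions` implements `Default`):
///
/// ```ignore
/// let options = SorOptions { omega: 1.5, ..SorOptions::default() };
/// let result = sor_impl(&client, &csr, &rhs_vec, None, options)?;
/// if result.converged {
///     println!("SOR converged in {} iterations", result.iterations);
/// }
/// ```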
pub fn sor_impl<R, C>(
client: &C,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: SorOptions,
) -> Result<SorResult<R>>
where
R: Runtime<DType = DType>,
R::Client: SparseOps<R>,
C: SparseLinAlgAlgorithms<R>
+ SparseOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>,
{
let n = super::super::traits::validate_iterative_inputs(a.shape, b, x0)?;
let dtype = b.dtype();
let device = b.device();
if !matches!(dtype, DType::F32 | DType::F64) {
return Err(Error::UnsupportedDType { dtype, op: "sor" });
}
    let b_norm = vector_norm(client, b)?;
    // A (numerically) zero right-hand side is solved exactly by the zero
    // vector; returning a nonzero x0 here would misreport the residual.
    if b_norm < options.atol {
        return Ok(SorResult {
            solution: Tensor::<R>::zeros(&[n], dtype, device),
            iterations: 0,
            residual_norm: b_norm,
            converged: true,
        });
    }
    let mut x = match x0 {
        Some(x0) => x0.clone(),
        None => Tensor::<R>::zeros(&[n], dtype, device),
    };
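    // Assemble the triangular factor D + ωL once; every sweep below reuses it
    // in a forward substitution.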
let omega = options.omega;
let lower_tri = build_sor_lower_triangular::<R>(a, omega, device)?;
for iter in 0..options.max_iter {
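        // One SOR sweep in delta form: r = b - A x, solve (D + ωL) Δ = ω r,
        // then x ← x + Δ.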
let ax = a.spmv(&x)?;
let r = client.sub(b, &ax)?;
let rhs = client.mul_scalar(&r, omega)?;
let delta = client.sparse_solve_triangular(&lower_tri, &rhs, true, false)?;
x = client.add(&x, &delta)?;
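        // Recompute the residual at the updated iterate so the convergence
        // test matches the solution that is returned.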
let ax = a.spmv(&x)?;
let r_check = client.sub(b, &ax)?;
let res_norm = vector_norm(client, &r_check)?;
if res_norm < options.atol || res_norm / b_norm < options.rtol {
return Ok(SorResult {
solution: x,
iterations: iter + 1,
residual_norm: res_norm,
converged: true,
});
}
}
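    // Iteration budget exhausted: report the residual of the last iterate
    // without claiming convergence.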
let ax = a.spmv(&x)?;
let r_final = client.sub(b, &ax)?;
let final_residual = vector_norm(client, &r_final)?;
Ok(SorResult {
solution: x,
iterations: options.max_iter,
residual_norm: final_residual,
converged: false,
})
}
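/// Builds `D + ωL` in CSR form from `a`: diagonal entries are kept as-is,
/// strictly lower-triangular entries are scaled by `omega`, and entries above
/// the diagonal are dropped. Each SOR sweep inverts this factor by forward
/// substitution.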
fn build_sor_lower_triangular<R: Runtime<DType = DType>>(
a: &CsrData<R>,
omega: f64,
device: &R::Device,
) -> Result<CsrData<R>> {
let n = a.shape[0];
let dtype = a.values().dtype();
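    // Pull the CSR structure to the host; the factor is assembled row by row.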
let row_ptrs: Vec<i64> = a.row_ptrs().to_vec();
let col_indices: Vec<i64> = a.col_indices().to_vec();
    // Widen the values to f64 on the host; the assembly below is then
    // dtype-agnostic. Reject anything other than F32/F64 rather than panic,
    // since only the shapes are validated before this helper runs.
    let values: Vec<f64> = match dtype {
        DType::F32 => a
            .values()
            .to_vec::<f32>()
            .iter()
            .map(|&v| v as f64)
            .collect(),
        DType::F64 => a.values().to_vec::<f64>(),
        dtype => return Err(Error::UnsupportedDType { dtype, op: "sor" }),
    };
let mut lt_rp = Vec::with_capacity(n + 1);
let mut lt_ci = Vec::new();
let mut lt_vv = Vec::new();
lt_rp.push(0i64);
for i in 0..n {
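        // Scan row i of `a`, keeping the diagonal entry and scaling the
        // strictly lower-triangular entries by ω.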
let start = row_ptrs[i] as usize;
let end = row_ptrs[i + 1] as usize;
let mut row_entries: Vec<(i64, f64)> = Vec::new();
for idx in start..end {
let j = col_indices[idx] as usize;
let v = values[idx];
if j < i {
row_entries.push((j as i64, omega * v));
} else if j == i {
row_entries.push((j as i64, v));
}
}
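        // Keep the row's columns in ascending order (canonical CSR form)
        // before appending.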
row_entries.sort_by_key(|&(c, _)| c);
for (c, v) in row_entries {
lt_ci.push(c);
lt_vv.push(v);
}
lt_rp.push(lt_ci.len() as i64);
}
    let rp_t = Tensor::<R>::from_slice(&lt_rp, &[lt_rp.len()], device);
    let ci_t = Tensor::<R>::from_slice(&lt_ci, &[lt_ci.len()], device);
    let vv_t = Tensor::<R>::from_slice(&lt_vv, &[lt_vv.len()], device);
CsrData::new(rp_t, ci_t, vv_t, [n, n])
}