use crate::algorithm::sparse_linalg::{
IluDecomposition, IlukDecomposition, SparseLinAlgAlgorithms,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use crate::runtime::Runtime;
use crate::tensor::Tensor;
/// Absolute tolerance named for breakdown detection (`1e-40`).
///
/// NOTE(review): not referenced in this part of the file; its exact role is
/// presumably defined by the solvers that import it — confirm there.
pub const BREAKDOWN_TOL: f64 = 1.0e-40;

/// Absolute tolerance named for invariant-subspace detection (`1e-14`).
///
/// NOTE(review): not referenced in this part of the file — confirm its use in
/// the importing solvers.
pub const INVARIANT_SUBSPACE_TOL: f64 = 1.0e-14;

/// Tolerance named for reorthogonalization (`1e-15`).
///
/// Also referenced by `accumulate_basis_combination` when deciding which
/// coefficients are small enough to skip.
pub const REORTH_TOL: f64 = 1.0e-15;
pub fn vector_norm<R, C>(client: &C, v: &Tensor<R>) -> Result<f64>
where
R: Runtime<DType = DType>,
C: BinaryOps<R> + UnaryOps<R> + ReduceOps<R>,
{
let v_sq = client.mul(v, v)?;
let ndim = v_sq.ndim();
let dims: Vec<usize> = (0..ndim).collect();
let sum_sq = client.sum(&v_sq, &dims, false)?;
let norm_tensor = client.sqrt(&sum_sq)?;
match norm_tensor.dtype() {
DType::F32 => Ok(norm_tensor.item::<f32>()? as f64),
DType::F64 => Ok(norm_tensor.item::<f64>()?),
dtype => Err(Error::UnsupportedDType {
dtype,
op: "vector_norm",
}),
}
}
pub fn vector_dot<R, C>(client: &C, u: &Tensor<R>, v: &Tensor<R>) -> Result<f64>
where
R: Runtime<DType = DType>,
C: BinaryOps<R> + ReduceOps<R>,
{
let uv = client.mul(u, v)?;
let ndim = uv.ndim();
let dims: Vec<usize> = (0..ndim).collect();
let dot_tensor = client.sum(&uv, &dims, false)?;
match dot_tensor.dtype() {
DType::F32 => Ok(dot_tensor.item::<f32>()? as f64),
DType::F64 => Ok(dot_tensor.item::<f64>()?),
dtype => Err(Error::UnsupportedDType {
dtype,
op: "vector_dot",
}),
}
}
/// Applies an optional ILU(0) preconditioner to `v`.
///
/// With `Some(factors)` it solves with `factors.l` first and then with
/// `factors.u`, i.e. `U⁻¹ (L⁻¹ v)`. With `None` it returns a clone of `v`
/// (identity preconditioner).
///
/// NOTE(review): the boolean flags `(true, true)` / `(false, false)` are passed
/// straight to `sparse_solve_triangular`; presumably they mean
/// `(lower, unit_diagonal)` — confirm against that method.
///
/// # Errors
///
/// Propagates any error from either triangular solve.
pub fn apply_ilu0_preconditioner<R, C>(
    client: &C,
    precond: &Option<IluDecomposition<R>>,
    v: &Tensor<R>,
) -> Result<Tensor<R>>
where
    R: Runtime,
    C: SparseLinAlgAlgorithms<R>,
{
    if let Some(factors) = precond {
        let forward = client.sparse_solve_triangular(&factors.l, v, true, true)?;
        client.sparse_solve_triangular(&factors.u, &forward, false, false)
    } else {
        Ok(v.clone())
    }
}
/// Applies an ILU(k) preconditioner to `v`: solves with `ilu.l`, then with
/// `ilu.u`, i.e. returns `U⁻¹ (L⁻¹ v)`.
///
/// NOTE(review): the boolean flags are passed straight to
/// `sparse_solve_triangular`; presumably `(lower, unit_diagonal)` — confirm.
///
/// # Errors
///
/// Propagates any error from either triangular solve.
pub fn apply_iluk_preconditioner<R, C>(
    client: &C,
    ilu: &IlukDecomposition<R>,
    v: &Tensor<R>,
) -> Result<Tensor<R>>
where
    R: Runtime,
    C: SparseLinAlgAlgorithms<R>,
{
    client
        .sparse_solve_triangular(&ilu.l, v, true, true)
        .and_then(|forward| client.sparse_solve_triangular(&ilu.u, &forward, false, false))
}
/// Computes a Givens rotation `(c, s, r)` such that
///
/// ```text
/// [  c  s ] [ a ]   [ r ]
/// [ -s  c ] [ b ] = [ 0 ]
/// ```
///
/// with `c² + s² = 1` and `r >= 0`.
///
/// Degenerate inputs are detected by exact comparison with zero, not by an
/// absolute tolerance. This keeps the rotation scale-invariant: `(3e-20, 4e-20)`
/// produces the same `(c, s)` as `(3, 4)`. An absolute cut-off would treat every
/// small-magnitude pair as "already rotated" and leave `b` un-annihilated.
/// `f64::hypot` forms `r` without intermediate overflow/underflow.
///
/// Special cases:
/// * `b == 0`: `(1, 0, a)` when `a >= 0`, otherwise `(-1, 0, -a)`; in particular
///   `(0, 0)` maps to `(1, 0, 0)`.
/// * `a == 0` (and `b != 0`): `(0, sign(b), |b|)`.
#[inline]
pub fn givens_rotation(a: f64, b: f64) -> (f64, f64, f64) {
    if b == 0.0 {
        // Already in the target form; only flip the sign so that r >= 0.
        if a >= 0.0 {
            (1.0, 0.0, a)
        } else {
            (-1.0, 0.0, -a)
        }
    } else if a == 0.0 {
        (0.0, b.signum(), b.abs())
    } else {
        let r = a.hypot(b);
        (a / r, b / r, r)
    }
}
/// Back-substitution for an upper-triangular system `R y = g`.
///
/// `h_matrix` is indexed **`[column][row]`**: entry `R(i, j)` is read as
/// `h_matrix[j][i]`. Only the leading `m × m` block is used, where
/// `m = g.len()`, so `h_matrix` must have at least `m` columns, each with at
/// least `m` rows (otherwise indexing panics).
///
/// A diagonal entry whose magnitude is not above `1e-15 × max_i |R(i, i)|` is
/// treated as singular, and its component of `y` is left at `0.0`. The
/// threshold is relative to the largest diagonal entry so that uniformly
/// small (but well-conditioned) systems are still solved. A fixed absolute
/// cut-off would zero every component when all entries are below it.
/// If that maximum is not finite, only exact zeros (and NaNs) are skipped.
///
/// Returns `y` with `y.len() == m`; an empty `g` yields an empty vector.
pub fn solve_upper_triangular(h_matrix: &[Vec<f64>], g: &[f64]) -> Vec<f64> {
    // Singularity threshold relative to the scale of the diagonal.
    const REL_TOL: f64 = 1e-15;

    let m = g.len();
    let scale = (0..m)
        .map(|i| h_matrix[i][i].abs())
        .fold(0.0_f64, f64::max);
    let tol = if scale.is_finite() { REL_TOL * scale } else { 0.0 };

    let mut y = vec![0.0; m];
    for i in (0..m).rev() {
        let diag = h_matrix[i][i];
        if diag.abs() > tol {
            // g[i] - sum_{j > i} R(i, j) * y[j]
            let rhs = ((i + 1)..m).fold(g[i], |acc, j| acc - h_matrix[j][i] * y[j]);
            y[i] = rhs / diag;
        }
    }
    y
}
/// Returns `x + Σ_j y[j] · z_basis[j]`.
///
/// The correction is accumulated in a separate zero tensor shaped like `x`
/// and added to `x` once at the end.
///
/// Coefficients whose magnitude is not above `1e-15 × max_j |y[j]|` are
/// skipped to save work. The cut-off is relative so that problems whose
/// coefficients are all small (e.g. a tiny right-hand side) still get their
/// update. A fixed absolute cut-off would drop every term there and leave
/// `x` unchanged. NaN coefficients are skipped. If the largest magnitude is
/// not finite, only exact zeros (and NaNs) are skipped.
///
/// # Panics
///
/// Panics if a coefficient that is *not* skipped has no matching entry in
/// `z_basis` (i.e. `z_basis` is shorter than the significant part of `y`).
///
/// # Errors
///
/// Propagates errors from `mul_scalar` and `add`.
pub fn update_solution<R, C>(
    client: &C,
    x: &Tensor<R>,
    z_basis: &[Tensor<R>],
    y: &[f64],
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: BinaryOps<R> + ScalarOps<R>,
{
    const REL_SKIP_TOL: f64 = 1e-15;

    let largest = y.iter().fold(0.0_f64, |acc, c| acc.max(c.abs()));
    let cutoff = if largest.is_finite() {
        REL_SKIP_TOL * largest
    } else {
        0.0
    };

    let mut delta = Tensor::<R>::zeros(x.shape(), x.dtype(), x.device());
    for (j, &y_j) in y.iter().enumerate() {
        if y_j.abs() > cutoff {
            // Index (rather than zip) so a significant coefficient without a
            // basis vector is still reported as a caller bug.
            let scaled_z = client.mul_scalar(&z_basis[j], y_j)?;
            delta = client.add(&delta, &scaled_z)?;
        }
    }
    client.add(x, &delta)
}
/// Reports whether the residual has stopped decreasing fast enough.
///
/// Returns `true` when the newest residual is greater than
/// `reduction_factor` times the residual recorded `window_size` iterations
/// earlier. Returns `false` while fewer than
/// `min_iterations + window_size` entries exist, and also whenever there are
/// not enough entries to look `window_size` steps back (`len <= window_size`).
/// Without that second check the lookback index underflows (and panics) when
/// `min_iterations == 0`.
pub fn detect_stagnation(
    residual_history: &[f64],
    params: &super::types::StagnationParams,
) -> bool {
    let len = residual_history.len();
    let window = params.window_size;

    // Warm-up period (saturating so huge settings cannot overflow).
    if len < params.min_iterations.saturating_add(window) {
        return false;
    }
    // Looking `window` steps back from the newest entry needs `window + 1` entries.
    if len <= window {
        return false;
    }

    let current = residual_history[len - 1];
    let past = residual_history[len - 1 - window];
    current > params.reduction_factor * past
}
/// Forms the linear combination `Σ_j coefficients[j] · basis[j]` as a new
/// length-`n` tensor of the given `dtype` on `device`.
///
/// Pairs are taken up to the shorter of `basis` and `coefficients`; any
/// surplus entries in the longer slice are ignored.
///
/// Coefficients whose magnitude is not above `REORTH_TOL × max_j |coefficients[j]|`
/// are skipped to save work. The cut-off is relative to the largest
/// coefficient so that combinations whose coefficients are all small are not
/// silently reduced to zero, as a fixed absolute cut-off would do. NaN
/// coefficients are skipped. If the largest magnitude is not finite, only
/// exact zeros (and NaNs) are skipped.
///
/// # Errors
///
/// Propagates errors from `mul_scalar` and `add`.
pub fn accumulate_basis_combination<R, C>(
    client: &C,
    basis: &[Tensor<R>],
    coefficients: &[f64],
    n: usize,
    dtype: DType,
    device: &R::Device,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: BinaryOps<R> + ScalarOps<R>,
{
    let paired = &coefficients[..basis.len().min(coefficients.len())];
    let largest = paired.iter().fold(0.0_f64, |acc, c| acc.max(c.abs()));
    let cutoff = if largest.is_finite() {
        REORTH_TOL * largest
    } else {
        0.0
    };

    let mut result = Tensor::<R>::zeros(&[n], dtype, device);
    for (vj, &coeff) in basis.iter().zip(paired) {
        if coeff.abs() > cutoff {
            let scaled = client.mul_scalar(vj, coeff)?;
            result = client.add(&result, &scaled)?;
        }
    }
    Ok(result)
}
/// Element-wise reciprocal of the main diagonal of the CSR matrix `a`
/// (`1 / a[i][i]`). This is suitable, e.g., as a diagonal (Jacobi-style) scaling.
///
/// The numerator is a ones tensor of length `a.shape[0]` with the dtype and
/// device of `a.values()`. Zero diagonal entries are not guarded: the result
/// is whatever `client.div` yields for division by zero.
///
/// # Errors
///
/// Propagates errors from `diagonal_with_client` and `div`.
pub fn extract_diagonal_inv<R, C>(client: &C, a: &crate::sparse::CsrData<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: UnaryOps<R> + BinaryOps<R> + ScalarOps<R> + crate::sparse::SparseOps<R>,
{
    let rows = a.shape[0];
    let values = a.values();
    let diagonal = a.diagonal_with_client(client)?;
    let numerator = Tensor::<R>::ones(&[rows], values.dtype(), values.device());
    client.div(&numerator, &diagonal)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_givens_rotation() {
        // General case: the rotation maps (a, b) onto (r, 0).
        let (c, s, r) = givens_rotation(3.0, 4.0);
        assert!((c * c + s * s - 1.0).abs() < 1e-10, "c^2 + s^2 = 1");
        assert!((r - 5.0).abs() < 1e-10, "r = 5");
        assert!(r >= 0.0, "r must be non-negative");
        assert!((c * 3.0 + s * 4.0 - 5.0).abs() < 1e-10, "rotation works");
        assert!((-s * 3.0 + c * 4.0).abs() < 1e-10, "zeroes out b");

        // b == 0, a > 0: identity rotation.
        let (c, s, r) = givens_rotation(5.0, 0.0);
        assert_eq!(c, 1.0);
        assert_eq!(s, 0.0);
        assert_eq!(r, 5.0);
        assert!(r >= 0.0);

        // b == 0, a < 0: sign flip keeps r non-negative.
        let (c, s, r) = givens_rotation(-5.0, 0.0);
        assert_eq!(c, -1.0);
        assert_eq!(s, 0.0);
        assert_eq!(r, 5.0);
        assert!(r >= 0.0, "r must be non-negative even for negative a");
        assert!((c * (-5.0) + s * 0.0 - r).abs() < 1e-10);

        // a == 0: pure swap with the sign of b.
        let (c, s, r) = givens_rotation(0.0, 3.0);
        assert_eq!(c, 0.0);
        assert_eq!(s, 1.0);
        assert_eq!(r, 3.0);
        assert!(r >= 0.0);

        let (c, s, r) = givens_rotation(0.0, -3.0);
        assert_eq!(c, 0.0);
        assert_eq!(s, -1.0);
        assert_eq!(r, 3.0);
        assert!(r >= 0.0, "r must be non-negative for negative b");

        // Both negative.
        let (c, s, r) = givens_rotation(-3.0, -4.0);
        assert!((c * c + s * s - 1.0).abs() < 1e-10, "c^2 + s^2 = 1");
        assert!((r - 5.0).abs() < 1e-10, "r = 5");
        assert!(r >= 0.0, "r must be non-negative");
        assert!(
            (c * (-3.0) + s * (-4.0) - r).abs() < 1e-10,
            "rotation gives r"
        );
        assert!((-s * (-3.0) + c * (-4.0)).abs() < 1e-10, "zeroes out b");

        // Both zero: identity rotation with r = 0.
        assert_eq!(givens_rotation(0.0, 0.0), (1.0, 0.0, 0.0));
    }

    #[test]
    fn test_solve_upper_triangular() {
        // `h_matrix` is indexed [column][row]: this is R = [[2, 1], [0, 3]].
        let h_matrix = vec![vec![2.0, 0.0], vec![1.0, 3.0]];
        let g = vec![3.0, 6.0];
        let y = solve_upper_triangular(&h_matrix, &g);
        assert!((y[0] - 0.5).abs() < 1e-10);
        assert!((y[1] - 2.0).abs() < 1e-10);

        // A zero diagonal entry leaves that component at 0.
        let h_matrix = vec![vec![0.0, 0.0], vec![1.0, 3.0]];
        let y = solve_upper_triangular(&h_matrix, &g);
        assert_eq!(y[0], 0.0);
        assert!((y[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_stagnation_detection() {
        let params = super::super::types::StagnationParams {
            reduction_factor: 0.5,
            window_size: 3,
            min_iterations: 2,
        };
        // Fewer than min_iterations + window_size entries: never stagnating.
        let history = vec![1.0, 0.9];
        assert!(!detect_stagnation(&history, &params));
        // Residual dropped by more than the reduction factor over the window.
        let history = vec![1.0, 0.8, 0.6, 0.4, 0.2];
        assert!(!detect_stagnation(&history, &params));
        // 0.72 > 0.5 * 0.85: progress over the window is too slow.
        let history = vec![1.0, 0.9, 0.85, 0.8, 0.75, 0.72];
        assert!(detect_stagnation(&history, &params));
    }
}