use std::collections::HashMap;
use super::super::traits::validate_square_sparse;
use super::super::types::{IcDecomposition, IcOptions};
use super::validate_cpu_dtype;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::sparse::CsrData;
use crate::tensor::Tensor;
pub fn ic0_cpu<R: Runtime<DType = DType>>(
a: &CsrData<R>,
options: IcOptions,
) -> Result<IcDecomposition<R>> {
let n = validate_square_sparse(a.shape)?;
let dtype = a.values().dtype();
validate_cpu_dtype(dtype)?;
let row_ptrs: Vec<i64> = a.row_ptrs().to_vec();
let col_indices: Vec<i64> = a.col_indices().to_vec();
let values: Vec<f64> = match dtype {
DType::F32 => a
.values()
.to_vec::<f32>()
.iter()
.map(|&x| x as f64)
.collect(),
DType::F64 => a.values().to_vec(),
_ => return Err(Error::UnsupportedDType { dtype, op: "ic0" }),
};
let mut l_values = values.clone();
let mut col_to_idx: Vec<HashMap<usize, usize>> = vec![HashMap::new(); n];
for i in 0..n {
let start = row_ptrs[i] as usize;
let end = row_ptrs[i + 1] as usize;
for idx in start..end {
let j = col_indices[idx] as usize;
if j <= i {
col_to_idx[i].insert(j, idx);
}
}
}
for i in 0..n {
let i_start = row_ptrs[i] as usize;
let i_end = row_ptrs[i + 1] as usize;
for idx_ik in i_start..i_end {
let k = col_indices[idx_ik] as usize;
if k >= i {
break; }
let k_start = row_ptrs[k] as usize;
let k_end = row_ptrs[k + 1] as usize;
let mut sum = l_values[idx_ik];
for idx_kj in k_start..k_end {
let j = col_indices[idx_kj] as usize;
if j >= k {
break; }
if let Some(&idx_ij) = col_to_idx[i].get(&j) {
sum -= l_values[idx_ij] * l_values[idx_kj];
}
}
let diag_idx = match col_to_idx[k].get(&k) {
Some(&idx) => idx,
None => {
return Err(Error::Internal(format!(
"Zero diagonal at row {} in IC(0)",
k
)));
}
};
l_values[idx_ik] = sum / l_values[diag_idx];
}
let diag_idx = match col_to_idx[i].get(&i) {
Some(&idx) => idx,
None => {
return Err(Error::Internal(format!(
"Missing diagonal at row {} in IC(0)",
i
)));
}
};
let mut sum = l_values[diag_idx] + options.diagonal_shift;
for idx_ij in i_start..i_end {
let j = col_indices[idx_ij] as usize;
if j >= i {
break;
}
sum -= l_values[idx_ij] * l_values[idx_ij];
}
if sum <= 0.0 {
if options.diagonal_shift > 0.0 {
sum = options.diagonal_shift;
} else {
return Err(Error::Internal(format!(
"Matrix not positive definite at row {} (sum = {})",
i, sum
)));
}
}
l_values[diag_idx] = sum.sqrt();
}
if options.drop_tolerance > 0.0 {
for val in &mut l_values {
if val.abs() < options.drop_tolerance {
*val = 0.0;
}
}
}
let l = extract_lower_triangle::<R>(
n,
&row_ptrs,
&col_indices,
&l_values,
dtype,
a.values().device(),
)?;
Ok(IcDecomposition { l })
}
/// Builds a CSR matrix holding the lower triangle (diagonal included) of the
/// factored values.
///
/// Walks each of the `n` rows described by `row_ptrs` / `col_indices` and
/// keeps an entry when its column is not greater than its row and its
/// magnitude is at least `1e-15`. The surviving entries are packed into fresh
/// row-pointer, column-index and value tensors created on `device`.
///
/// Values are stored as `f32` when `dtype` is `DType::F32` and as `f64` when
/// it is `DType::F64`.
///
/// # Panics
///
/// Panics (via `unreachable!`) if `dtype` is neither `F32` nor `F64`.
fn extract_lower_triangle<R: Runtime<DType = DType>>(
    n: usize,
    row_ptrs: &[i64],
    col_indices: &[i64],
    l_values: &[f64],
    dtype: DType,
    device: &R::Device,
) -> Result<CsrData<R>> {
    // Row pointers are the running count of kept entries, starting at zero.
    let mut packed_row_ptrs: Vec<i64> = Vec::with_capacity(n + 1);
    packed_row_ptrs.push(0);
    let mut packed_cols: Vec<i64> = Vec::new();
    let mut packed_vals: Vec<f64> = Vec::new();

    for row in 0..n {
        let begin = row_ptrs[row] as usize;
        let finish = row_ptrs[row + 1] as usize;

        for pos in begin..finish {
            let col = col_indices[pos] as usize;
            // Strictly-upper entries are skipped without reading their value.
            if col <= row {
                let entry = l_values[pos];
                if entry.abs() >= 1e-15 {
                    packed_cols.push(col as i64);
                    packed_vals.push(entry);
                }
            }
        }

        packed_row_ptrs.push(packed_cols.len() as i64);
    }

    let row_ptr_tensor = Tensor::<R>::from_slice(&packed_row_ptrs, &[n + 1], device);
    let col_tensor = Tensor::<R>::from_slice(&packed_cols, &[packed_cols.len()], device);
    let value_tensor = match dtype {
        DType::F64 => Tensor::<R>::from_slice(&packed_vals, &[packed_vals.len()], device),
        DType::F32 => {
            // Narrow back to the caller's single-precision dtype.
            let narrowed: Vec<f32> = packed_vals.iter().map(|&v| v as f32).collect();
            Tensor::<R>::from_slice(&narrowed, &[narrowed.len()], device)
        }
        _ => unreachable!(),
    };

    CsrData::new(row_ptr_tensor, col_tensor, value_tensor, [n, n])
}