use std::collections::HashMap;
use super::super::traits::validate_triangular_solve_dims;
use super::validate_cpu_dtype;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::sparse::CsrData;
use crate::tensor::Tensor;
pub fn sparse_solve_triangular_cpu<R: Runtime<DType = DType>>(
l_or_u: &CsrData<R>,
b: &Tensor<R>,
lower: bool,
unit_diagonal: bool,
) -> Result<Tensor<R>> {
let (n, nrhs) = validate_triangular_solve_dims(l_or_u.shape, b.shape())?;
let dtype = l_or_u.values().dtype();
validate_cpu_dtype(dtype)?;
if b.dtype() != dtype {
return Err(Error::DTypeMismatch {
lhs: dtype,
rhs: b.dtype(),
});
}
let row_ptrs: Vec<i64> = l_or_u.row_ptrs().to_vec();
let col_indices: Vec<i64> = l_or_u.col_indices().to_vec();
let mat_values: Vec<f64> = match dtype {
DType::F32 => l_or_u
.values()
.to_vec::<f32>()
.iter()
.map(|&x| x as f64)
.collect(),
DType::F64 => l_or_u.values().to_vec(),
_ => {
return Err(Error::UnsupportedDType {
dtype,
op: "sparse_solve_triangular",
});
}
};
let b_flat: Vec<f64> = match dtype {
DType::F32 => b.to_vec::<f32>().iter().map(|&x| x as f64).collect(),
DType::F64 => b.to_vec(),
_ => unreachable!(),
};
let b_matrix: Vec<Vec<f64>> = if nrhs == 1 {
b_flat.into_iter().map(|x| vec![x]).collect()
} else {
(0..n)
.map(|i| (0..nrhs).map(|j| b_flat[i * nrhs + j]).collect())
.collect()
};
let col_to_idx = build_col_index_map(n, &row_ptrs, &col_indices);
let x_matrix = if lower {
solve_lower(
n,
nrhs,
&row_ptrs,
&col_indices,
&mat_values,
&b_matrix,
&col_to_idx,
unit_diagonal,
)?
} else {
solve_upper(
n,
nrhs,
&row_ptrs,
&col_indices,
&mat_values,
&b_matrix,
&col_to_idx,
)?
};
let x_flat: Vec<f64> = if nrhs == 1 {
x_matrix.into_iter().map(|v| v[0]).collect()
} else {
x_matrix.into_iter().flatten().collect()
};
let device = b.device();
let shape = if nrhs == 1 { vec![n] } else { vec![n, nrhs] };
let result = match dtype {
DType::F32 => {
let f32_vals: Vec<f32> = x_flat.iter().map(|&x| x as f32).collect();
Tensor::<R>::from_slice(&f32_vals, &shape, device)
}
DType::F64 => Tensor::<R>::from_slice(&x_flat, &shape, device),
_ => unreachable!(),
};
Ok(result)
}
fn build_col_index_map(
n: usize,
row_ptrs: &[i64],
col_indices: &[i64],
) -> Vec<HashMap<usize, usize>> {
let mut col_to_idx: Vec<HashMap<usize, usize>> = vec![HashMap::new(); n];
for i in 0..n {
let start = row_ptrs[i] as usize;
let end = row_ptrs[i + 1] as usize;
for idx in start..end {
let j = col_indices[idx] as usize;
col_to_idx[i].insert(j, idx);
}
}
col_to_idx
}
fn solve_lower(
n: usize,
nrhs: usize,
row_ptrs: &[i64],
col_indices: &[i64],
mat_values: &[f64],
b_matrix: &[Vec<f64>],
col_to_idx: &[HashMap<usize, usize>],
unit_diagonal: bool,
) -> Result<Vec<Vec<f64>>> {
let mut x_matrix: Vec<Vec<f64>> = vec![vec![0.0; nrhs]; n];
for i in 0..n {
for k in 0..nrhs {
let mut sum = b_matrix[i][k];
let start = row_ptrs[i] as usize;
let end = row_ptrs[i + 1] as usize;
for idx in start..end {
let j = col_indices[idx] as usize;
if j < i {
sum -= mat_values[idx] * x_matrix[j][k];
}
}
if !unit_diagonal {
let diag_idx = col_to_idx[i].get(&i);
match diag_idx {
Some(&idx) => {
let diag = mat_values[idx];
if diag.abs() < 1e-15 {
return Err(Error::Internal(format!(
"Zero diagonal at row {} in triangular solve",
i
)));
}
sum /= diag;
}
None => {
return Err(Error::Internal(format!(
"Missing diagonal at row {} in triangular solve",
i
)));
}
}
}
x_matrix[i][k] = sum;
}
}
Ok(x_matrix)
}
fn solve_upper(
n: usize,
nrhs: usize,
row_ptrs: &[i64],
col_indices: &[i64],
mat_values: &[f64],
b_matrix: &[Vec<f64>],
col_to_idx: &[HashMap<usize, usize>],
) -> Result<Vec<Vec<f64>>> {
let mut x_matrix: Vec<Vec<f64>> = vec![vec![0.0; nrhs]; n];
for i in (0..n).rev() {
for k in 0..nrhs {
let mut sum = b_matrix[i][k];
let start = row_ptrs[i] as usize;
let end = row_ptrs[i + 1] as usize;
for idx in start..end {
let j = col_indices[idx] as usize;
if j > i {
sum -= mat_values[idx] * x_matrix[j][k];
}
}
let diag_idx = col_to_idx[i].get(&i);
match diag_idx {
Some(&idx) => {
let diag = mat_values[idx];
if diag.abs() < 1e-15 {
return Err(Error::Internal(format!(
"Zero diagonal at row {} in triangular solve",
i
)));
}
sum /= diag;
}
None => {
return Err(Error::Internal(format!(
"Missing diagonal at row {} in triangular solve",
i
)));
}
}
x_matrix[i][k] = sum;
}
}
Ok(x_matrix)
}