use crate::utils::get_mut_unchecked;
#[allow(unused_imports)]
use crate::{
IterSolverError, IterSolverResult,
ops::{Matrix, Vector},
utils::from_diagonal,
};
/// Builds a square matrix whose `offset`-th diagonal holds `data` and whose
/// remaining entries are zero.
///
/// * `offset == 0` delegates to `from_diagonal`.
/// * `offset > 0` writes `data[i]` at `(i, i + offset)`, i.e. above the main diagonal.
/// * `offset < 0` writes `data[i]` at `(i + |offset|, i)`, i.e. below the main diagonal.
///
/// For a non-zero offset the result has side length `data.len() + |offset|`.
/// An empty `data` always yields a `0 x 0` matrix, regardless of `offset`.
pub fn diagm(data: &[f64], offset: isize) -> Matrix<f64> {
    if data.is_empty() {
        #[cfg(feature = "ndarray")]
        {
            return Matrix::zeros((0, 0));
        }
        #[cfg(not(feature = "ndarray"))]
        {
            return Matrix::zeros(0, 0);
        }
    }

    if offset == 0 {
        return from_diagonal(data);
    }

    let shift = offset.unsigned_abs();
    let size = data.len() + shift;
    let mut result = {
        #[cfg(feature = "ndarray")]
        {
            Matrix::zeros((size, size))
        }
        #[cfg(not(feature = "ndarray"))]
        {
            Matrix::zeros(size, size)
        }
    };

    // A positive offset shifts the column index, a negative one shifts the row index.
    let (row_shift, col_shift) = if offset > 0 { (0, shift) } else { (shift, 0) };

    for (i, &value) in data.iter().enumerate() {
        // SAFETY: `i < data.len()`, hence `i + shift < size`; both indices are inside
        // the `size x size` matrix allocated above. (Assumes in-bounds indices are the
        // only requirement of `get_mut_unchecked` — confirm against its definition.)
        unsafe {
            *get_mut_unchecked(&mut result, i + row_shift, i + col_shift) = value;
        }
    }

    result
}
/// Assembles a tridiagonal matrix from its main, lower and upper diagonals.
///
/// `diagonal` is placed on the main diagonal, `lower` on offset `-1` and
/// `upper` on offset `+1`.
///
/// # Errors
///
/// Returns `IterSolverError::DimensionError` unless
/// `diagonal.len() == lower.len() + 1` and `lower.len() == upper.len()`.
pub fn tridiagonal(
    diagonal: &[f64],
    lower: &[f64],
    upper: &[f64],
) -> IterSolverResult<Matrix<f64>> {
    if diagonal.len() != lower.len() + 1 || lower.len() != upper.len() {
        return Err(IterSolverError::DimensionError(format!(
            "For tridiagonal matrix, the length of `diagonal` {}, the length of `lower` {} and `upper` {} do not match",
            diagonal.len(),
            lower.len(),
            upper.len()
        )));
    }

    let main = diagm(diagonal, 0);

    // 1x1 case: `diagm` returns a 0x0 matrix for an empty slice, so adding the
    // (empty) off-diagonals would combine matrices of different shapes. The
    // main diagonal alone is already the complete result.
    if lower.is_empty() {
        return Ok(main);
    }

    Ok(main + diagm(lower, -1) + diagm(upper, 1))
}
/// Assembles a symmetric tridiagonal matrix.
///
/// `sub_diagonal` is used for both the lower and the upper off-diagonal; the
/// work (including length validation) is delegated to `tridiagonal`.
///
/// # Errors
///
/// Propagates whatever error `tridiagonal` returns for the given slices.
pub fn symmetric_tridiagonal(
    diagonal: &[f64],
    sub_diagonal: &[f64],
) -> IterSolverResult<Matrix<f64>> {
    // Symmetry: the same slice serves as lower and upper off-diagonal.
    let off_diagonal = sub_diagonal;
    tridiagonal(diagonal, off_diagonal, off_diagonal)
}
/// Builds a square matrix from several diagonals, where `diagonals[i]` is
/// placed on the diagonal with offset `offsets[i]` (positive above the main
/// diagonal, negative below).
///
/// The side length `n` is taken from the first pair as
/// `diagonals[0].len() + |offsets[0]|`; every other pair must agree with it.
/// With no diagonals at all, a `0 x 0` matrix is returned.
///
/// # Errors
///
/// * `IterSolverError::DimensionError` if `diagonals` and `offsets` differ in
///   length, or if some diagonal's length plus `|offset|` is not `n`.
/// * `IterSolverError::InvalidInput` if an offset appears more than once.
pub fn diags(diagonals: Vec<Vec<f64>>, offsets: Vec<isize>) -> IterSolverResult<Matrix<f64>> {
    if diagonals.len() != offsets.len() {
        return Err(IterSolverError::DimensionError(format!(
            "The length of `diagonals` {} and `offsets` {} do not match",
            diagonals.len(),
            offsets.len()
        )));
    }
    if diagonals.is_empty() {
        #[cfg(feature = "ndarray")]
        {
            return Ok(Matrix::zeros((0, 0)));
        }
        #[cfg(not(feature = "ndarray"))]
        {
            return Ok(Matrix::zeros(0, 0));
        }
    }

    let n = diagonals[0].len() + offsets[0].unsigned_abs();

    // Validate every (diagonal, offset) pair before building anything.
    let mut unique_offsets = std::collections::HashSet::new();
    for (index, (diag, offset)) in diagonals.iter().zip(offsets.iter()).enumerate() {
        if !unique_offsets.insert(*offset) {
            return Err(IterSolverError::InvalidInput(
                "Duplicate offsets are not allowed".to_string(),
            ));
        }
        if diag.len() + offset.unsigned_abs() != n {
            return Err(IterSolverError::DimensionError(format!(
                "The {}th diagonal's length {} and its offset {} do not match",
                index,
                diag.len(),
                offset
            )));
        }
    }

    // Accumulate into an n x n zero matrix so that each diagonal is added
    // exactly once (seeding with the first diagonal and then looping over all
    // of them would count the first one twice).
    let mut res: Matrix<f64> = {
        #[cfg(feature = "ndarray")]
        {
            Matrix::zeros((n, n))
        }
        #[cfg(not(feature = "ndarray"))]
        {
            Matrix::zeros(n, n)
        }
    };
    for (diag, offset) in diagonals.iter().zip(offsets.iter()) {
        // `diagm` returns a 0x0 matrix for an empty slice, which does not match
        // the n x n accumulator; an empty diagonal contributes nothing anyway.
        if diag.is_empty() {
            continue;
        }
        res += &diagm(diag, *offset);
    }
    Ok(res)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: builds matrices for the main diagonal and for the first and
    /// second super-/sub-diagonals and prints each one. It makes no assertions;
    /// it only checks that `diagm` runs for these offsets.
    #[test]
    fn test_diag() {
        let data = [1.0, 2.0, 3.0];
        // Same offsets, in the same order, as the printed output is expected to follow.
        for offset in [0, 1, -1, 2, -2] {
            let mat = diagm(&data, offset);
            println!("{mat:?}");
        }
    }
}