use super::Matrix;
use crate::StrError;
/// Converts a dense matrix into the BLAS band storage layout
///
/// Every entry `(i, j)` of `dense` that lies inside the band (at most `ml` rows
/// below and `mu` rows above the main diagonal) is copied to `band(i + mu - j, j)`.
/// Band slots that have no counterpart in `dense` are set to zero, so the result
/// does not depend on the previous contents of `band`.
///
/// # Input
///
/// * `band` -- the output matrix; must be `(ml + mu + 1)` by `n` and is fully overwritten
/// * `dense` -- the `m` by `n` source matrix
/// * `ml` -- number of diagonals kept below the main diagonal
/// * `mu` -- number of diagonals kept above the main diagonal
///
/// # Errors
///
/// Returns an error message if `band` is not `(ml + mu + 1)` by `n`.
pub fn mat_convert_to_blas_band(band: &mut Matrix, dense: &Matrix, ml: usize, mu: usize) -> Result<(), StrError> {
    let (m, n) = dense.dims();
    let (mb, nb) = band.dims();
    if mb != ml + mu + 1 || nb != n {
        return Err("the resulting matrix must be ml + mu + 1 by n");
    }
    for j in 0..n {
        for r in 0..mb {
            // Band row `r` of column `j` mirrors dense row `i = r + j - mu`.
            // Because `r <= ml + mu`, the lower-band limit `i <= j + ml` always holds;
            // only `i >= 0` and `i < m` need to be checked.
            let value = match (r + j).checked_sub(mu) {
                Some(i) if i < m => dense.get(i, j),
                // No dense counterpart: write zero so that stale data never leaks through
                _ => 0.0,
            };
            band.set(r, j, value);
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::mat_convert_to_blas_band;
    use crate::Matrix;

    /// A wrong row count and a wrong column count of `band` must each be rejected on its own
    #[test]
    fn mat_convert_to_blas_band_captures_errors() {
        let dense = Matrix::from(&[[1.0, 2.0], [3.0, 4.0]]);
        let n = dense.dims().1;
        let (ml, mu) = (1, 1);
        // one row short, correct number of columns
        let mut band_wrong = Matrix::new(ml + mu, n);
        assert_eq!(
            mat_convert_to_blas_band(&mut band_wrong, &dense, ml, mu).err(),
            Some("the resulting matrix must be ml + mu + 1 by n")
        );
        // correct number of rows, one column too many (isolates the column check)
        let mut band_wrong = Matrix::new(ml + mu + 1, n + 1);
        assert_eq!(
            mat_convert_to_blas_band(&mut band_wrong, &dense, ml, mu).err(),
            Some("the resulting matrix must be ml + mu + 1 by n")
        );
    }

    /// Dense matrix with more rows than columns (9 by 8), ml = 2 and mu = 3
    #[test]
    fn mat_convert_to_blas_band_works_m_gt_n() {
        #[rustfmt::skip]
        let dense = Matrix::from(&[
            [11.0, 12.0, 13.0, 14.0,  0.0,  0.0,  0.0,  0.0],
            [21.0, 22.0, 23.0, 24.0, 25.0,  0.0,  0.0,  0.0],
            [31.0, 32.0, 33.0, 34.0, 35.0, 36.0,  0.0,  0.0],
            [ 0.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0,  0.0],
            [ 0.0,  0.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0],
            [ 0.0,  0.0,  0.0, 64.0, 65.0, 66.0, 67.0, 68.0],
            [ 0.0,  0.0,  0.0,  0.0, 75.0, 76.0, 77.0, 78.0],
            [ 0.0,  0.0,  0.0,  0.0,  0.0, 86.0, 87.0, 88.0],
            [ 0.0,  0.0,  0.0,  0.0,  0.0,  0.0, 97.0, 98.0],
        ]);
        // each dense diagonal becomes one row of the band matrix
        #[rustfmt::skip]
        let band_correct = Matrix::from(&[
            [ 0.0,  0.0,  0.0, 14.0, 25.0, 36.0, 47.0, 58.0],
            [ 0.0,  0.0, 13.0, 24.0, 35.0, 46.0, 57.0, 68.0],
            [ 0.0, 12.0, 23.0, 34.0, 45.0, 56.0, 67.0, 78.0],
            [11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0],
            [21.0, 32.0, 43.0, 54.0, 65.0, 76.0, 87.0, 98.0],
            [31.0, 42.0, 53.0, 64.0, 75.0, 86.0, 97.0,  0.0],
        ]);
        let n = dense.dims().1;
        let (ml, mu) = (2, 3);
        let mut band = Matrix::new(ml + mu + 1, n);
        mat_convert_to_blas_band(&mut band, &dense, ml, mu).unwrap();
        assert_eq!(band.as_data(), band_correct.as_data());
    }

    /// Dense matrix with more columns than rows (7 by 9), ml = 2 and mu = 3
    #[test]
    fn mat_convert_to_blas_band_works_n_gt_m() {
        #[rustfmt::skip]
        let dense = Matrix::from(&[
            [11.0, 12.0, 13.0, 14.0,  0.0,  0.0,  0.0,  0.0,  0.0],
            [21.0, 22.0, 23.0, 24.0, 25.0,  0.0,  0.0,  0.0,  0.0],
            [31.0, 32.0, 33.0, 34.0, 35.0, 36.0,  0.0,  0.0,  0.0],
            [ 0.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0,  0.0,  0.0],
            [ 0.0,  0.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0,  0.0],
            [ 0.0,  0.0,  0.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0],
            [ 0.0,  0.0,  0.0,  0.0, 75.0, 76.0, 77.0, 78.0, 79.0],
        ]);
        // trailing columns run past the last dense row, hence the zeros on the right
        #[rustfmt::skip]
        let band_correct = Matrix::from(&[
            [ 0.0,  0.0,  0.0, 14.0, 25.0, 36.0, 47.0, 58.0, 69.0],
            [ 0.0,  0.0, 13.0, 24.0, 35.0, 46.0, 57.0, 68.0, 79.0],
            [ 0.0, 12.0, 23.0, 34.0, 45.0, 56.0, 67.0, 78.0,  0.0],
            [11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0,  0.0,  0.0],
            [21.0, 32.0, 43.0, 54.0, 65.0, 76.0,  0.0,  0.0,  0.0],
            [31.0, 42.0, 53.0, 64.0, 75.0,  0.0,  0.0,  0.0,  0.0],
        ]);
        let n = dense.dims().1;
        let (ml, mu) = (2, 3);
        let mut band = Matrix::new(ml + mu + 1, n);
        mat_convert_to_blas_band(&mut band, &dense, ml, mu).unwrap();
        assert_eq!(band.as_data(), band_correct.as_data());
    }
}