use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::{fmt::Debug, iter::Sum};
use super::StructuredMatrix;
use crate::error::{LinalgError, LinalgResult};
#[derive(Debug, Clone)]
pub struct ToeplitzMatrix<A>
where
A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
first_row: Array1<A>,
first_col: Array1<A>,
nrows: usize,
ncols: usize,
}
impl<A> ToeplitzMatrix<A>
where
A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
pub fn new(_first_row: ArrayView1<A>, firstcol: ArrayView1<A>) -> LinalgResult<Self> {
if _first_row.is_empty() || firstcol.is_empty() {
return Err(LinalgError::InvalidInputError(
"Row and column must not be empty".to_string(),
));
}
if (_first_row[0] - firstcol[0]).abs() > A::epsilon() {
return Err(LinalgError::InvalidInputError(
"First element of row and column must be the same".to_string(),
));
}
Ok(ToeplitzMatrix {
first_row: _first_row.to_owned(),
first_col: firstcol.to_owned(),
nrows: firstcol.len(),
ncols: _first_row.len(),
})
}
pub fn new_symmetric(_firstrow: ArrayView1<A>) -> LinalgResult<Self> {
let n = _firstrow.len();
let mut first_col = Array1::zeros(n);
for i in 0..n {
first_col[i] = _firstrow[i];
}
Ok(ToeplitzMatrix {
first_row: _firstrow.to_owned(),
first_col,
nrows: n,
ncols: n,
})
}
pub fn from_parts(c: A, r: ArrayView1<A>, l: ArrayView1<A>) -> LinalgResult<Self> {
let ncols = r.len() + 1;
let nrows = l.len() + 1;
let mut first_row = Array1::zeros(ncols);
let mut first_col = Array1::zeros(nrows);
first_row[0] = c;
first_col[0] = c;
for (i, &val) in r.iter().enumerate() {
first_row[i + 1] = val;
}
for (i, &val) in l.iter().enumerate() {
first_col[i + 1] = val;
}
Ok(ToeplitzMatrix {
first_row,
first_col,
nrows,
ncols,
})
}
}
impl<A> StructuredMatrix<A> for ToeplitzMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Number of rows (the length of the stored first column).
    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns (the length of the stored first row).
    fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns entry `(i, j)`: `first_row[j - i]` on or above the main diagonal and
    /// `first_col[i - j]` below it.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::IndexError`] if `i >= nrows` or `j >= ncols`.
    fn get(&self, i: usize, j: usize) -> LinalgResult<A> {
        let in_bounds = i < self.nrows && j < self.ncols;
        if !in_bounds {
            return Err(LinalgError::IndexError(format!(
                "Index out of bounds: ({}, {}) for matrix of shape {}x{}",
                i, j, self.nrows, self.ncols
            )));
        }
        // `j - i` is representable exactly when (i, j) is on or above the diagonal.
        let value = match j.checked_sub(i) {
            Some(offset) => self.first_row[offset],
            None => self.first_col[i - j],
        };
        Ok(value)
    }

    /// Computes `T * x` directly from the stored row and column, in `O(nrows * ncols)`.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `x.len() != ncols`.
    fn matvec(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.ncols {
            return Err(LinalgError::ShapeError(format!(
                "Input vector has wrong length: expected {}, got {}",
                self.ncols,
                x.len()
            )));
        }
        // Dot product of one matrix row with `x`, accumulated left to right from zero.
        let row_dot = |row: usize| -> A {
            x.iter().enumerate().fold(A::zero(), |acc, (col, &xv)| {
                let coeff = if col >= row {
                    self.first_row[col - row]
                } else {
                    self.first_col[row - col]
                };
                acc + coeff * xv
            })
        };
        Ok((0..self.nrows).map(row_dot).collect())
    }

    /// Computes `T^T * x` directly from the stored row and column, in `O(nrows * ncols)`.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `x.len() != nrows`.
    fn matvec_transpose(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.nrows {
            return Err(LinalgError::ShapeError(format!(
                "Input vector has wrong length: expected {}, got {}",
                self.nrows,
                x.len()
            )));
        }
        // Dot product of one matrix column with `x`, accumulated top to bottom from zero.
        let col_dot = |col: usize| -> A {
            x.iter().enumerate().fold(A::zero(), |acc, (row, &xv)| {
                let coeff = if row > col {
                    self.first_col[row - col]
                } else {
                    self.first_row[col - row]
                };
                acc + coeff * xv
            })
        };
        Ok((0..self.ncols).map(col_dot).collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;
    #[test]
    fn test_toeplitz_creation() {
        let first_row = array![1.0, 2.0, 3.0];
        let first_col = array![1.0, 4.0, 5.0];
        let toeplitz =
            ToeplitzMatrix::new(first_row.view(), first_col.view()).expect("Operation failed");
        assert_eq!(toeplitz.nrows(), 3);
        assert_eq!(toeplitz.ncols(), 3);
        assert_relative_eq!(toeplitz.get(0, 0).expect("Operation failed"), 1.0);
        assert_relative_eq!(toeplitz.get(0, 1).expect("Operation failed"), 2.0);
        assert_relative_eq!(toeplitz.get(0, 2).expect("Operation failed"), 3.0);
        assert_relative_eq!(toeplitz.get(1, 0).expect("Operation failed"), 4.0);
        assert_relative_eq!(toeplitz.get(1, 1).expect("Operation failed"), 1.0);
        assert_relative_eq!(toeplitz.get(1, 2).expect("Operation failed"), 2.0);
        assert_relative_eq!(toeplitz.get(2, 0).expect("Operation failed"), 5.0);
        assert_relative_eq!(toeplitz.get(2, 1).expect("Operation failed"), 4.0);
        assert_relative_eq!(toeplitz.get(2, 2).expect("Operation failed"), 1.0);
    }
    #[test]
    fn test_toeplitz_symmetric() {
        let first_row = array![1.0, 2.0, 3.0];
        let toeplitz = ToeplitzMatrix::new_symmetric(first_row.view()).expect("Operation failed");
        assert_eq!(toeplitz.nrows(), 3);
        assert_eq!(toeplitz.ncols(), 3);
        assert_relative_eq!(toeplitz.get(0, 0).expect("Operation failed"), 1.0);
        assert_relative_eq!(toeplitz.get(0, 1).expect("Operation failed"), 2.0);
        assert_relative_eq!(toeplitz.get(0, 2).expect("Operation failed"), 3.0);
        assert_relative_eq!(toeplitz.get(1, 0).expect("Operation failed"), 2.0);
        assert_relative_eq!(toeplitz.get(1, 1).expect("Operation failed"), 1.0);
        assert_relative_eq!(toeplitz.get(1, 2).expect("Operation failed"), 2.0);
        assert_relative_eq!(toeplitz.get(2, 0).expect("Operation failed"), 3.0);
        assert_relative_eq!(toeplitz.get(2, 1).expect("Operation failed"), 2.0);
        assert_relative_eq!(toeplitz.get(2, 2).expect("Operation failed"), 1.0);
    }
    #[test]
    fn test_toeplitz_from_parts() {
        let r = array![2.0, 3.0];
        let l = array![4.0, 5.0];
        let c = 1.0;
        let toeplitz = ToeplitzMatrix::from_parts(c, r.view(), l.view()).expect("Operation failed");
        assert_eq!(toeplitz.nrows(), 3);
        assert_eq!(toeplitz.ncols(), 3);
        assert_relative_eq!(toeplitz.get(0, 0).expect("Operation failed"), 1.0);
        assert_relative_eq!(toeplitz.get(0, 1).expect("Operation failed"), 2.0);
        assert_relative_eq!(toeplitz.get(0, 2).expect("Operation failed"), 3.0);
        assert_relative_eq!(toeplitz.get(1, 0).expect("Operation failed"), 4.0);
        assert_relative_eq!(toeplitz.get(1, 1).expect("Operation failed"), 1.0);
        assert_relative_eq!(toeplitz.get(1, 2).expect("Operation failed"), 2.0);
        assert_relative_eq!(toeplitz.get(2, 0).expect("Operation failed"), 5.0);
        assert_relative_eq!(toeplitz.get(2, 1).expect("Operation failed"), 4.0);
        assert_relative_eq!(toeplitz.get(2, 2).expect("Operation failed"), 1.0);
    }
    #[test]
    fn test_toeplitz_matvec() {
        let first_row = array![1.0, 2.0, 3.0];
        let first_col = array![1.0, 4.0, 5.0];
        let toeplitz =
            ToeplitzMatrix::new(first_row.view(), first_col.view()).expect("Operation failed");
        let x = array![1.0, 2.0, 3.0];
        let y = toeplitz.matvec(&x.view()).expect("Operation failed");
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 14.0);
        assert_relative_eq!(y[1], 12.0);
        assert_relative_eq!(y[2], 16.0);
    }
    #[test]
    fn test_toeplitz_matvec_transpose() {
        let first_row = array![1.0, 2.0, 3.0];
        let first_col = array![1.0, 4.0, 5.0];
        let toeplitz =
            ToeplitzMatrix::new(first_row.view(), first_col.view()).expect("Operation failed");
        let x = array![1.0, 2.0, 3.0];
        let y = toeplitz
            .matvec_transpose(&x.view())
            .expect("Operation failed");
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 24.0);
        assert_relative_eq!(y[1], 16.0);
        assert_relative_eq!(y[2], 10.0);
    }
    #[test]
    fn test_toeplitz_to_dense() {
        let first_row = array![1.0, 2.0, 3.0];
        let first_col = array![1.0, 4.0, 5.0];
        let toeplitz =
            ToeplitzMatrix::new(first_row.view(), first_col.view()).expect("Operation failed");
        let dense = toeplitz.to_dense().expect("Operation failed");
        assert_eq!(dense.shape(), &[3, 3]);
        assert_relative_eq!(dense[[0, 0]], 1.0);
        assert_relative_eq!(dense[[0, 1]], 2.0);
        assert_relative_eq!(dense[[0, 2]], 3.0);
        assert_relative_eq!(dense[[1, 0]], 4.0);
        assert_relative_eq!(dense[[1, 1]], 1.0);
        assert_relative_eq!(dense[[1, 2]], 2.0);
        assert_relative_eq!(dense[[2, 0]], 5.0);
        assert_relative_eq!(dense[[2, 1]], 4.0);
        assert_relative_eq!(dense[[2, 2]], 1.0);
    }
    /// Square matrices cannot distinguish row/column index mix-ups, so exercise a
    /// rectangular 2 x 4 matrix:
    ///
    /// ```text
    /// [[1, 2, 3, 4],
    ///  [5, 1, 2, 3]]
    /// ```
    #[test]
    fn test_toeplitz_non_square() {
        let r = array![2.0, 3.0, 4.0];
        let l = array![5.0];
        let toeplitz =
            ToeplitzMatrix::from_parts(1.0, r.view(), l.view()).expect("Operation failed");
        assert_eq!(toeplitz.nrows(), 2);
        assert_eq!(toeplitz.ncols(), 4);
        assert_relative_eq!(toeplitz.get(0, 3).expect("Operation failed"), 4.0);
        assert_relative_eq!(toeplitz.get(1, 0).expect("Operation failed"), 5.0);
        assert_relative_eq!(toeplitz.get(1, 1).expect("Operation failed"), 1.0);
        assert_relative_eq!(toeplitz.get(1, 3).expect("Operation failed"), 3.0);

        let x = array![1.0, 1.0, 1.0, 1.0];
        let y = toeplitz.matvec(&x.view()).expect("Operation failed");
        assert_eq!(y.len(), 2);
        assert_relative_eq!(y[0], 10.0);
        assert_relative_eq!(y[1], 11.0);

        let z = array![1.0, 2.0];
        let w = toeplitz
            .matvec_transpose(&z.view())
            .expect("Operation failed");
        assert_eq!(w.len(), 4);
        assert_relative_eq!(w[0], 11.0);
        assert_relative_eq!(w[1], 4.0);
        assert_relative_eq!(w[2], 7.0);
        assert_relative_eq!(w[3], 10.0);
    }
    #[test]
    fn test_out_of_bounds_and_shape_errors() {
        let r = array![2.0, 3.0, 4.0];
        let l = array![5.0];
        let toeplitz =
            ToeplitzMatrix::from_parts(1.0, r.view(), l.view()).expect("Operation failed");
        // Shape is 2 x 4.
        assert!(toeplitz.get(2, 0).is_err());
        assert!(toeplitz.get(0, 4).is_err());

        let too_short = array![1.0, 2.0];
        assert!(toeplitz.matvec(&too_short.view()).is_err());
        let too_long = array![1.0, 2.0, 3.0];
        assert!(toeplitz.matvec_transpose(&too_long.view()).is_err());
    }
    #[test]
    fn test_invalid_inputs() {
        let first_row = array![1.0, 2.0, 3.0];
        let first_col = array![2.0, 4.0, 5.0];
        let result = ToeplitzMatrix::<f64>::new(first_row.view(), first_col.view());
        assert!(result.is_err());
        let first_row = array![];
        let first_col = array![];
        let result = ToeplitzMatrix::<f64>::new(first_row.view(), first_col.view());
        assert!(result.is_err());
    }
}