use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::{fmt::Debug, iter::Sum};
use super::StructuredMatrix;
use crate::error::{LinalgError, LinalgResult};
/// A Hankel matrix stored compactly by its first column and last row.
///
/// Entry `(i, j)` depends only on `i + j` (the matrix is constant along its
/// anti-diagonals), so only the first column and the last row are kept instead
/// of the full dense `nrows x ncols` array.
#[derive(Debug, Clone)]
pub struct HankelMatrix<
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
> {
    /// First column of the matrix (its length is the row count).
    first_col: Array1<A>,
    /// Last row of the matrix (its length is the column count); element 0 is the
    /// same corner value as the last element of `first_col`.
    last_row: Array1<A>,
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
}
impl<A> HankelMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Creates a Hankel matrix from its first column and last row.
    ///
    /// The result has `first_col.len()` rows and `last_row.len()` columns. The two
    /// inputs share the bottom-left corner: the last element of `first_col` must
    /// equal the first element of `last_row`.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::InvalidInputError`] if either input is empty, or if the
    /// shared corner values differ by more than `A::epsilon()` (absolute tolerance).
    pub fn new(first_col: ArrayView1<A>, last_row: ArrayView1<A>) -> LinalgResult<Self> {
        if first_col.is_empty() || last_row.is_empty() {
            return Err(LinalgError::InvalidInputError(
                "Column and row must not be empty".to_string(),
            ));
        }
        if (first_col[first_col.len() - 1] - last_row[0]).abs() > A::epsilon() {
            return Err(LinalgError::InvalidInputError(
                "Last element of first column must be the same as first element of last row"
                    .to_string(),
            ));
        }
        Ok(HankelMatrix {
            first_col: first_col.to_owned(),
            last_row: last_row.to_owned(),
            nrows: first_col.len(),
            ncols: last_row.len(),
        })
    }

    /// Builds an `n_rows x n_cols` Hankel matrix whose anti-diagonals are read from
    /// `sequence`, i.e. entry `(i, j)` equals `sequence[i + j]`.
    ///
    /// Only the first `n_rows + n_cols - 1` elements of `sequence` are used; any
    /// trailing elements are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::InvalidInputError`] if either dimension is zero, if
    /// `n_rows + n_cols - 1` overflows `usize`, or if `sequence` is shorter than
    /// `n_rows + n_cols - 1`.
    pub fn from_sequence(
        sequence: ArrayView1<A>,
        n_rows: usize,
        n_cols: usize,
    ) -> LinalgResult<Self> {
        // Reject zero dimensions up front: besides matching `new` (which rejects empty
        // inputs), this guarantees the `- 1` subtractions below cannot underflow.
        if n_rows == 0 || n_cols == 0 {
            return Err(LinalgError::InvalidInputError(format!(
                "Matrix dimensions must be positive, got {n_rows}x{n_cols}"
            )));
        }
        let required = n_rows.checked_add(n_cols - 1).ok_or_else(|| {
            LinalgError::InvalidInputError(format!(
                "Matrix dimensions {n_rows}x{n_cols} are too large"
            ))
        })?;
        if sequence.len() < required {
            return Err(LinalgError::InvalidInputError(format!(
                "Sequence length must be at least nrows + ncols - 1 = {}, got {}",
                required,
                sequence.len()
            )));
        }
        let first_col = sequence
            .slice(scirs2_core::ndarray::s![0..n_rows])
            .to_owned();
        // The last row starts at the shared corner element `sequence[n_rows - 1]`.
        let last_row = sequence
            .slice(scirs2_core::ndarray::s![(n_rows - 1)..required])
            .to_owned();
        Ok(HankelMatrix {
            first_col,
            last_row,
            nrows: n_rows,
            ncols: n_cols,
        })
    }
}
impl<A> StructuredMatrix<A> for HankelMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Number of rows (the length of the stored first column).
    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns (the length of the stored last row).
    fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns the entry at row `i`, column `j`.
    ///
    /// The entry depends only on `k = i + j`: for `k < nrows` it is `first_col[k]`,
    /// otherwise it is `last_row[k - nrows + 1]` (index 0 of the last row is the
    /// corner value already covered by the first column).
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::IndexError`] if `i >= nrows` or `j >= ncols`.
    fn get(&self, i: usize, j: usize) -> LinalgResult<A> {
        if i >= self.nrows || j >= self.ncols {
            return Err(LinalgError::IndexError(format!(
                "Index out of bounds: ({}, {}) for matrix of shape {}x{}",
                i, j, self.nrows, self.ncols
            )));
        }
        let sum_idx = i + j;
        if sum_idx < self.nrows {
            Ok(self.first_col[sum_idx])
        } else {
            let j_idx = sum_idx - self.nrows + 1;
            if j_idx < self.ncols {
                Ok(self.last_row[j_idx])
            } else {
                // Defensive only: with the bounds check above, j_idx <= ncols - 1.
                Err(LinalgError::IndexError(format!(
                    "Index out of bounds: sum index {sum_idx} exceeds matrix dimensions"
                )))
            }
        }
    }

    /// Computes the matrix-vector product `H * x`.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `x.len() != ncols`. Any error from
    /// [`get`](Self::get) is propagated rather than panicking.
    fn matvec(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.ncols {
            return Err(LinalgError::ShapeError(format!(
                "Input vector has wrong length: expected {}, got {}",
                self.ncols,
                x.len()
            )));
        }
        let mut result = Array1::<A>::zeros(self.nrows);
        for (i, out) in result.iter_mut().enumerate() {
            // Accumulate in column order, starting from zero.
            for (j, &xj) in x.iter().enumerate() {
                *out += self.get(i, j)? * xj;
            }
        }
        Ok(result)
    }

    /// Computes the transposed product `H^T * x`.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `x.len() != nrows`. Any error from
    /// [`get`](Self::get) is propagated rather than panicking.
    fn matvec_transpose(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.nrows {
            return Err(LinalgError::ShapeError(format!(
                "Input vector has wrong length: expected {}, got {}",
                self.nrows,
                x.len()
            )));
        }
        let mut result = Array1::<A>::zeros(self.ncols);
        for (j, out) in result.iter_mut().enumerate() {
            // Accumulate in row order, starting from zero.
            for (i, &xi) in x.iter().enumerate() {
                *out += self.get(i, j)? * xi;
            }
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Checks the shape of `hankel` and compares every entry with the row-major
    /// `expected` table.
    fn assert_entries<const N: usize>(hankel: &HankelMatrix<f64>, expected: &[[f64; N]]) {
        assert_eq!(hankel.nrows(), expected.len());
        assert_eq!(hankel.ncols(), N);
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_relative_eq!(hankel.get(i, j).expect("index is in bounds"), value);
            }
        }
    }

    /// Compares a computed vector with the expected values element by element.
    fn assert_vector(actual: &Array1<f64>, expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (&a, &e) in actual.iter().zip(expected) {
            assert_relative_eq!(a, e);
        }
    }

    /// The square matrix [[1, 2, 3], [2, 3, 4], [3, 4, 5]].
    fn square_hankel() -> HankelMatrix<f64> {
        let first_col = array![1.0, 2.0, 3.0];
        let last_row = array![3.0, 4.0, 5.0];
        HankelMatrix::new(first_col.view(), last_row.view()).expect("valid Hankel inputs")
    }

    /// The rectangular matrix [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]].
    fn rectangular_hankel() -> HankelMatrix<f64> {
        let first_col = array![1.0, 2.0, 3.0, 4.0];
        let last_row = array![4.0, 5.0, 6.0];
        HankelMatrix::new(first_col.view(), last_row.view()).expect("valid Hankel inputs")
    }

    const SQUARE: [[f64; 3]; 3] = [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]];

    #[test]
    fn test_hankel_creation() {
        assert_entries(&square_hankel(), &SQUARE);
    }

    #[test]
    fn test_hankel_from_sequence() {
        let sequence = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let hankel =
            HankelMatrix::from_sequence(sequence.view(), 3, 3).expect("sequence is long enough");
        assert_entries(&hankel, &SQUARE);
    }

    #[test]
    fn test_hankel_from_longer_sequence() {
        // Elements beyond nrows + ncols - 1 are ignored.
        let sequence = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let hankel =
            HankelMatrix::from_sequence(sequence.view(), 2, 3).expect("sequence is long enough");
        assert_entries(&hankel, &[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]);
    }

    #[test]
    fn test_hankel_rectangular() {
        assert_entries(
            &rectangular_hankel(),
            &[
                [1.0, 2.0, 3.0],
                [2.0, 3.0, 4.0],
                [3.0, 4.0, 5.0],
                [4.0, 5.0, 6.0],
            ],
        );
    }

    #[test]
    fn test_hankel_matvec() {
        let x = array![1.0, 2.0, 3.0];
        let y = square_hankel().matvec(&x.view()).expect("length matches ncols");
        assert_vector(&y, &[14.0, 20.0, 26.0]);
    }

    #[test]
    fn test_hankel_matvec_transpose() {
        let x = array![1.0, 2.0, 3.0];
        let y = square_hankel()
            .matvec_transpose(&x.view())
            .expect("length matches nrows");
        assert_vector(&y, &[14.0, 20.0, 26.0]);
    }

    #[test]
    fn test_hankel_rectangular_matvec() {
        let hankel = rectangular_hankel();
        let ones_cols = array![1.0, 1.0, 1.0];
        let y = hankel.matvec(&ones_cols.view()).expect("length matches ncols");
        assert_vector(&y, &[6.0, 9.0, 12.0, 15.0]);
        let ones_rows = array![1.0, 1.0, 1.0, 1.0];
        let z = hankel
            .matvec_transpose(&ones_rows.view())
            .expect("length matches nrows");
        assert_vector(&z, &[10.0, 14.0, 18.0]);
    }

    #[test]
    fn test_hankel_to_dense() {
        let dense = square_hankel().to_dense().expect("dense conversion succeeds");
        assert_eq!(dense.shape(), &[3, 3]);
        for (i, row) in SQUARE.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_relative_eq!(dense[[i, j]], value);
            }
        }
    }

    #[test]
    fn test_get_out_of_bounds() {
        let hankel = square_hankel();
        assert!(hankel.get(3, 0).is_err());
        assert!(hankel.get(0, 3).is_err());
    }

    #[test]
    fn test_matvec_wrong_length() {
        let hankel = square_hankel();
        let x = array![1.0, 2.0];
        assert!(hankel.matvec(&x.view()).is_err());
        assert!(hankel.matvec_transpose(&x.view()).is_err());
    }

    #[test]
    fn test_invalid_inputs() {
        // Corner mismatch: last of first column (4.0) != first of last row (3.0).
        let first_col = array![1.0, 2.0, 4.0];
        let last_row = array![3.0, 4.0, 5.0];
        let result = HankelMatrix::<f64>::new(first_col.view(), last_row.view());
        assert!(result.is_err());

        // Empty inputs are rejected.
        let first_col = array![];
        let last_row = array![];
        let result = HankelMatrix::<f64>::new(first_col.view(), last_row.view());
        assert!(result.is_err());

        // A 2x3 matrix needs at least 4 sequence elements.
        let sequence = array![1.0, 2.0, 3.0];
        let result = HankelMatrix::from_sequence(sequence.view(), 2, 3);
        assert!(result.is_err());
    }
}