use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::{fmt::Debug, iter::Sum};
use super::StructuredMatrix;
use crate::error::{LinalgError, LinalgResult};
/// Compact representation of an `n × n` circulant matrix.
///
/// A circulant matrix is fully determined by its first row `c`: entry
/// `(i, j)` is `c[(j - i) mod n]`, so every row is the row above it rotated
/// one position to the right. Only `c` is stored, never the dense matrix.
#[derive(Debug, Clone)]
pub struct CirculantMatrix<
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
> {
    /// The defining first row `c`; its length is the matrix dimension.
    first_row: Array1<A>,
    /// Matrix dimension; the constructors set it to `first_row.len()`.
    n: usize,
}
impl<A> CirculantMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Builds a circulant matrix from its first row.
    ///
    /// Row `i` of the resulting `n × n` matrix (with `n = first_row.len()`)
    /// is `first_row` cyclically shifted `i` positions to the right.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::InvalidInputError`] if `first_row` is empty.
    pub fn new(first_row: ArrayView1<A>) -> LinalgResult<Self> {
        if first_row.is_empty() {
            return Err(LinalgError::InvalidInputError(
                "First row must not be empty".to_string(),
            ));
        }
        Ok(CirculantMatrix {
            first_row: first_row.to_owned(),
            n: first_row.len(),
        })
    }

    /// Builds a circulant matrix whose first *column* is `kernel`.
    ///
    /// Entry `(i, j)` of the result is `kernel[(i - j) mod n]`. Multiplying
    /// the matrix by a vector therefore computes the cyclic convolution of
    /// `kernel` with that vector.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::InvalidInputError`] if `kernel` is empty.
    pub fn from_kernel(kernel: ArrayView1<A>) -> LinalgResult<Self> {
        if kernel.is_empty() {
            return Err(LinalgError::InvalidInputError(
                "Kernel must not be empty".to_string(),
            ));
        }
        let n = kernel.len();
        // First-row entry j must equal kernel[(0 - j) mod n]: element 0 stays
        // in place and the remaining kernel entries are read in reverse.
        let first_row = Array1::from_shape_fn(n, |j| kernel[(n - j) % n]);
        Ok(CirculantMatrix { first_row, n })
    }

    /// Returns a read-only view of the defining first row.
    pub fn first_row(&self) -> ArrayView1<'_, A> {
        self.first_row.view()
    }

    /// Returns the dimension `n` of this `n × n` matrix.
    pub fn size(&self) -> usize {
        self.n
    }
}
impl<A> StructuredMatrix<A> for CirculantMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Number of rows (the matrix is square, so this equals `n`).
    fn nrows(&self) -> usize {
        self.n
    }

    /// Number of columns (the matrix is square, so this equals `n`).
    fn ncols(&self) -> usize {
        self.n
    }

    /// Returns entry `(i, j)`, which for a circulant matrix is
    /// `first_row[(j - i) mod n]`.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::IndexError`] if `i` or `j` is not below `n`.
    fn get(&self, i: usize, j: usize) -> LinalgResult<A> {
        if i >= self.n || j >= self.n {
            return Err(LinalgError::IndexError(format!(
                "Index out of bounds: ({}, {}) for matrix of shape {}x{}",
                i, j, self.n, self.n
            )));
        }
        // Both indices are < n, so `j + n - i` cannot underflow; this is the
        // same unsigned modular index used by `matvec`, without signed casts.
        Ok(self.first_row[(j + self.n - i) % self.n])
    }

    /// Computes `C x` without forming the dense matrix:
    /// `(C x)[i] = Σ_j first_row[(j - i) mod n] · x[j]`, in O(n²) time.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `x.len() != n`.
    fn matvec(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.n {
            return Err(LinalgError::ShapeError(format!(
                "Input vector has wrong length: expected {}, got {}",
                self.n,
                x.len()
            )));
        }
        let n = self.n;
        let c = &self.first_row;
        Ok(Array1::from_shape_fn(n, |i| {
            (0..n).map(|j| c[(j + n - i) % n] * x[j]).sum::<A>()
        }))
    }

    /// Computes `Cᵀ x` without forming the dense matrix:
    /// `(Cᵀ x)[j] = Σ_i C[i, j] · x[i] = Σ_i first_row[(j - i) mod n] · x[i]`,
    /// in O(n²) time.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `x.len() != n`.
    fn matvec_transpose(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.n {
            return Err(LinalgError::ShapeError(format!(
                "Input vector has wrong length: expected {}, got {}",
                self.n,
                x.len()
            )));
        }
        let n = self.n;
        let c = &self.first_row;
        // The first-row index must be (j - i) — output index minus summation
        // index. Using (i - j) instead reads row j of C rather than column j,
        // which silently reproduces `matvec`.
        Ok(Array1::from_shape_fn(n, |j| {
            (0..n).map(|i| c[(j + n - i) % n] * x[i]).sum::<A>()
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_circulant_creation() {
        let first_row = array![1.0, 2.0, 3.0, 4.0];
        let circulant = CirculantMatrix::new(first_row.view()).expect("Operation failed");
        assert_eq!(circulant.nrows(), 4);
        assert_eq!(circulant.ncols(), 4);
        assert_relative_eq!(circulant.get(0, 0).expect("Operation failed"), 1.0);
        assert_relative_eq!(circulant.get(0, 1).expect("Operation failed"), 2.0);
        assert_relative_eq!(circulant.get(0, 2).expect("Operation failed"), 3.0);
        assert_relative_eq!(circulant.get(0, 3).expect("Operation failed"), 4.0);
        assert_relative_eq!(circulant.get(1, 0).expect("Operation failed"), 4.0);
        assert_relative_eq!(circulant.get(1, 1).expect("Operation failed"), 1.0);
        assert_relative_eq!(circulant.get(1, 2).expect("Operation failed"), 2.0);
        assert_relative_eq!(circulant.get(1, 3).expect("Operation failed"), 3.0);
        assert_relative_eq!(circulant.get(2, 0).expect("Operation failed"), 3.0);
        assert_relative_eq!(circulant.get(2, 1).expect("Operation failed"), 4.0);
        assert_relative_eq!(circulant.get(2, 2).expect("Operation failed"), 1.0);
        assert_relative_eq!(circulant.get(2, 3).expect("Operation failed"), 2.0);
        assert_relative_eq!(circulant.get(3, 0).expect("Operation failed"), 2.0);
        assert_relative_eq!(circulant.get(3, 1).expect("Operation failed"), 3.0);
        assert_relative_eq!(circulant.get(3, 2).expect("Operation failed"), 4.0);
        assert_relative_eq!(circulant.get(3, 3).expect("Operation failed"), 1.0);
    }

    #[test]
    fn test_circulant_kernel() {
        let kernel = array![5.0, 1.0, 2.0, 3.0];
        let circulant = CirculantMatrix::from_kernel(kernel.view()).expect("Operation failed");
        assert_relative_eq!(circulant.first_row[0], 5.0);
        assert_relative_eq!(circulant.first_row[1], 3.0);
        assert_relative_eq!(circulant.first_row[2], 2.0);
        assert_relative_eq!(circulant.first_row[3], 1.0);
    }

    #[test]
    fn test_circulant_kernel_is_first_column() {
        let kernel = array![5.0, 1.0, 2.0, 3.0];
        let circulant = CirculantMatrix::from_kernel(kernel.view()).expect("Operation failed");
        for i in 0..kernel.len() {
            assert_relative_eq!(circulant.get(i, 0).expect("Operation failed"), kernel[i]);
        }
    }

    #[test]
    fn test_circulant_matvec() {
        let first_row = array![1.0, 2.0, 3.0, 4.0];
        let circulant = CirculantMatrix::new(first_row.view()).expect("Operation failed");
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = circulant.matvec(&x.view()).expect("Operation failed");
        assert_eq!(y.len(), 4);
        assert_relative_eq!(y[0], 30.0);
        assert_relative_eq!(y[1], 24.0);
        assert_relative_eq!(y[2], 22.0);
        assert_relative_eq!(y[3], 24.0);
    }

    #[test]
    fn test_circulant_matvec_transpose() {
        // Cᵀ has rows [1,4,3,2], [2,1,4,3], [3,2,1,4], [4,3,2,1].
        let first_row = array![1.0, 2.0, 3.0, 4.0];
        let circulant = CirculantMatrix::new(first_row.view()).expect("Operation failed");
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = circulant
            .matvec_transpose(&x.view())
            .expect("Operation failed");
        assert_eq!(y.len(), 4);
        assert_relative_eq!(y[0], 26.0);
        assert_relative_eq!(y[1], 28.0);
        assert_relative_eq!(y[2], 26.0);
        assert_relative_eq!(y[3], 20.0);
    }

    #[test]
    fn test_circulant_matvec_transpose_matches_entrywise_transpose() {
        let first_row = array![1.0, -2.0, 0.5, 3.0, 7.0];
        let circulant = CirculantMatrix::new(first_row.view()).expect("Operation failed");
        let x = array![2.0, 1.0, -1.0, 4.0, 0.25];
        let y = circulant
            .matvec_transpose(&x.view())
            .expect("Operation failed");
        let n = circulant.size();
        for j in 0..n {
            let expected: f64 = (0..n)
                .map(|i| circulant.get(i, j).expect("Operation failed") * x[i])
                .sum();
            assert_relative_eq!(y[j], expected, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_circulant_to_dense() {
        let first_row = array![1.0, 2.0, 3.0, 4.0];
        let circulant = CirculantMatrix::new(first_row.view()).expect("Operation failed");
        let dense = circulant.to_dense().expect("Operation failed");
        assert_eq!(dense.shape(), &[4, 4]);
        assert_relative_eq!(dense[[0, 0]], 1.0);
        assert_relative_eq!(dense[[0, 1]], 2.0);
        assert_relative_eq!(dense[[0, 2]], 3.0);
        assert_relative_eq!(dense[[0, 3]], 4.0);
        assert_relative_eq!(dense[[1, 0]], 4.0);
        assert_relative_eq!(dense[[1, 1]], 1.0);
        assert_relative_eq!(dense[[1, 2]], 2.0);
        assert_relative_eq!(dense[[1, 3]], 3.0);
        assert_relative_eq!(dense[[2, 0]], 3.0);
        assert_relative_eq!(dense[[2, 1]], 4.0);
        assert_relative_eq!(dense[[2, 2]], 1.0);
        assert_relative_eq!(dense[[2, 3]], 2.0);
        assert_relative_eq!(dense[[3, 0]], 2.0);
        assert_relative_eq!(dense[[3, 1]], 3.0);
        assert_relative_eq!(dense[[3, 2]], 4.0);
        assert_relative_eq!(dense[[3, 3]], 1.0);
    }

    #[test]
    fn test_invalid_inputs() {
        let first_row = array![];
        let result = CirculantMatrix::<f64>::new(first_row.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_out_of_bounds_and_wrong_length() {
        let first_row = array![1.0, 2.0];
        let circulant = CirculantMatrix::new(first_row.view()).expect("Operation failed");
        assert!(circulant.get(2, 0).is_err());
        assert!(circulant.get(0, 2).is_err());
        let x = array![1.0, 2.0, 3.0];
        assert!(circulant.matvec(&x.view()).is_err());
        assert!(circulant.matvec_transpose(&x.view()).is_err());
    }
}