use scirs2_core::ndarray::{Array1, Array2};
use crate::error::{GraphError, Result};
use crate::spectral_graph::graph_laplacian;
/// Computes all eigenpairs of a real symmetric matrix with the classical
/// (largest-pivot) Jacobi rotation method.
///
/// Returns `(eigenvalues, eigenvectors)`:
/// - eigenvalues are sorted in ascending order;
/// - column `k` of `eigenvectors` is the eigenvector paired with `eigenvalues[k]`.
///
/// The eigenvector matrix is a product of plane rotations applied to the identity,
/// so it is orthogonal up to rounding.
///
/// Only the strict upper triangle is scanned for pivots. Symmetry of `a` is NOT
/// checked. NOTE(review): an asymmetric input produces meaningless results.
///
/// # Errors
/// Returns [`GraphError::InvalidGraph`] when the matrix is empty, not square, or
/// contains NaN/infinite entries.
fn symmetric_eigen(a: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
    // Convergence threshold on the largest off-diagonal magnitude. It is relative
    // to the largest entry once that entry exceeds 1.
    const TOL: f64 = 1e-12;
    // Rotation budget in sweeps, where one sweep is about one rotation per
    // off-diagonal pair. A fixed rotation count would silently stop before
    // convergence on larger graphs.
    const MAX_SWEEPS: usize = 50;
    // Never allow fewer rotations than the historical fixed budget.
    const MIN_ROTATIONS: usize = 500;

    let n = a.nrows();
    if n == 0 {
        return Err(GraphError::InvalidGraph("empty matrix".into()));
    }
    if a.ncols() != n {
        return Err(GraphError::InvalidGraph("matrix must be square".into()));
    }
    // NaN entries can never win the pivot comparison below, so they would be
    // silently ignored; reject them up front.
    if a.iter().any(|x| !x.is_finite()) {
        return Err(GraphError::InvalidGraph(
            "matrix contains non-finite entries".into(),
        ));
    }

    // Scale the threshold with the largest entry. Otherwise heavily weighted
    // graphs could never reach an absolute 1e-12 and would exhaust the budget.
    let scale = a.iter().fold(0.0_f64, |acc, x| acc.max(x.abs()));
    let tol = TOL * scale.max(1.0);
    let pairs = n * (n - 1) / 2;
    let max_rotations = MAX_SWEEPS.saturating_mul(pairs).max(MIN_ROTATIONS);

    let mut m = a.clone();
    let mut v = Array2::<f64>::eye(n);
    for _ in 0..max_rotations {
        // Pivot: the off-diagonal entry of largest magnitude (upper triangle).
        let mut max_val = 0.0_f64;
        let mut p = 0_usize;
        let mut q = 1_usize;
        for i in 0..n {
            for j in (i + 1)..n {
                let v_ij = m[[i, j]].abs();
                if v_ij > max_val {
                    max_val = v_ij;
                    p = i;
                    q = j;
                }
            }
        }
        if max_val < tol {
            break;
        }

        let app = m[[p, p]];
        let aqq = m[[q, q]];
        let apq = m[[p, q]];
        // Angle that annihilates m[p, q]: tan(2θ) = 2·a_pq / (a_qq − a_pp), |θ| ≤ π/4.
        let theta = if (aqq - app).abs() < tol {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * ((2.0 * apq) / (aqq - app)).atan()
        };
        let cos_t = theta.cos();
        let sin_t = theta.sin();

        // Apply Jᵀ·m·J in place. For r ∉ {p, q}, each iteration reads (r,p) and
        // (r,q) and writes only row/column r, so no later read sees a new value.
        for r in 0..n {
            if r == p || r == q {
                continue;
            }
            let arp = m[[r, p]];
            let arq = m[[r, q]];
            let new_rp = cos_t * arp - sin_t * arq;
            let new_rq = sin_t * arp + cos_t * arq;
            m[[r, p]] = new_rp;
            m[[p, r]] = new_rp;
            m[[r, q]] = new_rq;
            m[[q, r]] = new_rq;
        }
        m[[p, p]] = cos_t * cos_t * app - 2.0 * sin_t * cos_t * apq + sin_t * sin_t * aqq;
        m[[q, q]] = sin_t * sin_t * app + 2.0 * sin_t * cos_t * apq + cos_t * cos_t * aqq;
        m[[p, q]] = 0.0;
        m[[q, p]] = 0.0;

        // Accumulate the rotation into the eigenvector matrix: v ← v·J.
        for r in 0..n {
            let vrp = v[[r, p]];
            let vrq = v[[r, q]];
            v[[r, p]] = cos_t * vrp - sin_t * vrq;
            v[[r, q]] = sin_t * vrp + cos_t * vrq;
        }
    }

    // The diagonal now holds the eigenvalues; sort ascending and permute columns.
    let diag: Vec<f64> = (0..n).map(|i| m[[i, i]]).collect();
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&x, &y| {
        diag[x]
            .partial_cmp(&diag[y])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let eigenvalues = Array1::from_iter(order.iter().map(|&i| diag[i]));
    let mut eigenvectors = Array2::<f64>::zeros((n, n));
    for (dst, &src) in order.iter().enumerate() {
        eigenvectors.column_mut(dst).assign(&v.column(src));
    }
    Ok((eigenvalues, eigenvectors))
}
/// Graph Fourier transform (GFT) built from the eigendecomposition of a graph Laplacian.
///
/// Constructors:
/// - [`GraphFourierTransform::from_adjacency`]
/// - [`GraphFourierTransform::from_laplacian`]
///
/// When built through either constructor, `eigenvalues` are sorted in ascending
/// order. Column `k` of `eigenvectors` is the eigenvector paired with
/// `eigenvalues[k]`, and the eigenvectors are orthonormal up to rounding.
///
/// The forward transform is `x̂ = Uᵀ x` and the inverse is `x = U x̂`.
#[derive(Debug, Clone)]
pub struct GraphFourierTransform {
    /// Laplacian eigenvalues ("graph frequencies").
    pub eigenvalues: Array1<f64>,
    /// Eigenvectors stored column-wise; column `k` pairs with `eigenvalues[k]`.
    pub eigenvectors: Array2<f64>,
}

impl GraphFourierTransform {
    /// Builds the GFT of the graph whose adjacency matrix is `adj`.
    ///
    /// The Laplacian is computed by `crate::spectral_graph::graph_laplacian` and
    /// then diagonalized.
    ///
    /// # Errors
    /// Returns [`GraphError::InvalidGraph`] when:
    /// - `adj` is empty or not square;
    /// - the Laplacian cannot be diagonalized (e.g. it has non-finite entries).
    pub fn from_adjacency(adj: &Array2<f64>) -> Result<Self> {
        let n = adj.nrows();
        if n == 0 {
            return Err(GraphError::InvalidGraph("empty adjacency matrix".into()));
        }
        // Validate the shape before handing the matrix to `graph_laplacian`.
        if adj.ncols() != n {
            return Err(GraphError::InvalidGraph(
                "adjacency matrix must be square".into(),
            ));
        }
        let lap = graph_laplacian(adj);
        Self::from_laplacian(&lap)
    }

    /// Builds the GFT directly from a (symmetric) Laplacian matrix.
    ///
    /// # Errors
    /// Returns [`GraphError::InvalidGraph`] if `laplacian` is empty, not square,
    /// or has non-finite entries.
    pub fn from_laplacian(laplacian: &Array2<f64>) -> Result<Self> {
        let (eigenvalues, eigenvectors) = symmetric_eigen(laplacian)?;
        Ok(Self {
            eigenvalues,
            eigenvectors,
        })
    }

    /// Number of graph nodes (equal to the number of spectral components).
    pub fn num_nodes(&self) -> usize {
        self.eigenvalues.len()
    }

    /// Ensures a vector of length `len` matches [`Self::num_nodes`].
    fn check_len(&self, len: usize, param: &'static str, context: &'static str) -> Result<()> {
        let n = self.num_nodes();
        if len == n {
            return Ok(());
        }
        Err(GraphError::InvalidParameter {
            param: param.into(),
            value: len.to_string(),
            expected: n.to_string(),
            context: context.into(),
        })
    }

    /// Forward GFT: returns `x̂[k] = Σ_i U[i, k] · signal[i]`.
    ///
    /// # Errors
    /// `InvalidParameter` if `signal.len()` differs from [`Self::num_nodes`].
    pub fn transform(&self, signal: &Array1<f64>) -> Result<Array1<f64>> {
        self.check_len(signal.len(), "signal.len()", "GFT forward transform")?;
        let n = self.num_nodes();
        let mut x_hat = Array1::<f64>::zeros(n);
        for k in 0..n {
            let mut acc = 0.0_f64;
            for i in 0..n {
                acc += self.eigenvectors[[i, k]] * signal[i];
            }
            x_hat[k] = acc;
        }
        Ok(x_hat)
    }

    /// Inverse GFT: returns `x[i] = Σ_k U[i, k] · freq_signal[k]`.
    ///
    /// # Errors
    /// `InvalidParameter` if `freq_signal.len()` differs from [`Self::num_nodes`].
    pub fn inverse(&self, freq_signal: &Array1<f64>) -> Result<Array1<f64>> {
        self.check_len(freq_signal.len(), "freq_signal.len()", "GFT inverse transform")?;
        let n = self.num_nodes();
        let mut x = Array1::<f64>::zeros(n);
        for i in 0..n {
            let mut acc = 0.0_f64;
            for k in 0..n {
                acc += self.eigenvectors[[i, k]] * freq_signal[k];
            }
            x[i] = acc;
        }
        Ok(x)
    }
}
/// A linear graph filter defined by one gain per spectral component.
pub trait GraphFilter {
    /// Filters `signal` in the graph spectral domain.
    ///
    /// The default implementation runs three steps:
    /// 1. compute the GFT coefficients of `signal`;
    /// 2. multiply them element-wise by [`GraphFilter::frequency_response`];
    /// 3. transform the result back.
    ///
    /// Implementors normally only need to provide `frequency_response`.
    ///
    /// # Errors
    /// Propagates the length-mismatch errors of [`GraphFourierTransform::transform`]
    /// and [`GraphFourierTransform::inverse`].
    fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
        let coeffs = gft.transform(signal)?;
        let response = self.frequency_response(gft);
        let filtered: Array1<f64> = coeffs
            .iter()
            .zip(response.iter())
            .map(|(&c, &g)| c * g)
            .collect();
        gft.inverse(&filtered)
    }

    /// Per-component gains.
    ///
    /// Expected to have length `gft.num_nodes()` and to follow the ordering of
    /// `gft.eigenvalues`.
    fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64>;
}
/// Ideal low-pass graph filter that keeps spectral components with index `< k`.
///
/// Components are selected by their position in `gft.eigenvalues`, not by the
/// eigenvalue's magnitude. NOTE(review): if eigenvalues repeat across the cutoff,
/// the output depends on the basis chosen within that eigenspace.
#[derive(Debug, Clone)]
pub struct IdealLowPass {
    /// Number of lowest-index spectral components passed through unchanged.
    pub k: usize,
}

impl IdealLowPass {
    /// Creates a low-pass filter that retains components `0..k`.
    pub fn new(k: usize) -> Self {
        IdealLowPass { k }
    }
}

impl GraphFilter for IdealLowPass {
    fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
        let cutoff = self.k;
        (0..gft.num_nodes())
            .map(|idx| if idx < cutoff { 1.0 } else { 0.0 })
            .collect()
    }

    fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
        let mut coeffs = gft.transform(signal)?;
        let gains = self.frequency_response(gft);
        // Both vectors have gft.num_nodes() entries; scale coefficients in place.
        for (c, &g) in coeffs.iter_mut().zip(gains.iter()) {
            *c *= g;
        }
        gft.inverse(&coeffs)
    }
}
/// Ideal high-pass graph filter that zeroes spectral components with index `< k`.
///
/// Components are selected by their position in `gft.eigenvalues`. With the
/// ascending order produced by the GFT constructors, `k = 1` removes the
/// lowest-frequency component.
#[derive(Debug, Clone)]
pub struct IdealHighPass {
    /// Number of lowest-index spectral components that are blocked.
    pub k: usize,
}

impl IdealHighPass {
    /// Creates a high-pass filter that blocks components `0..k`.
    pub fn new(k: usize) -> Self {
        IdealHighPass { k }
    }
}

impl GraphFilter for IdealHighPass {
    fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
        let blocked = self.k;
        let gains: Vec<f64> = (0..gft.num_nodes())
            .map(|idx| if idx >= blocked { 1.0 } else { 0.0 })
            .collect();
        Array1::from_vec(gains)
    }

    fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
        let spectrum = gft.transform(signal)?;
        let gains = self.frequency_response(gft);
        let filtered: Array1<f64> = spectrum
            .iter()
            .zip(gains.iter())
            .map(|(&c, &g)| c * g)
            .collect();
        gft.inverse(&filtered)
    }
}
/// Ideal band-pass graph filter that keeps spectral components with index in
/// the half-open range `low_k..high_k`.
///
/// If `low_k >= high_k` the band is empty and every component is removed.
#[derive(Debug, Clone)]
pub struct GraphBandpass {
    /// First spectral index kept (inclusive).
    pub low_k: usize,
    /// End of the kept band (exclusive).
    pub high_k: usize,
}

impl GraphBandpass {
    /// Creates a band-pass filter over component indices `low_k..high_k`.
    pub fn new(low_k: usize, high_k: usize) -> Self {
        GraphBandpass { low_k, high_k }
    }
}

impl GraphFilter for GraphBandpass {
    fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
        let band = self.low_k..self.high_k;
        Array1::from_iter(
            (0..gft.num_nodes()).map(|idx| if band.contains(&idx) { 1.0 } else { 0.0 }),
        )
    }

    fn apply(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
        let coeffs = gft.transform(signal)?;
        let response = self.frequency_response(gft);
        let mut band_limited = Vec::with_capacity(coeffs.len());
        for (&c, &g) in coeffs.iter().zip(response.iter()) {
            band_limited.push(c * g);
        }
        gft.inverse(&Array1::from_vec(band_limited))
    }
}
/// Graph kernel operator with spectral response `exp(-scale · λ)`.
///
/// The dense kernel `U · diag(exp(-scale · λ)) · Uᵀ` is precomputed by
/// [`GraphWavelet::new`], so [`GraphWavelet::apply`] is a matrix–vector product.
///
/// NOTE(review): this response is a low-pass heat kernel rather than a band-pass
/// wavelet generator. Confirm this is the intended design.
#[derive(Debug, Clone)]
pub struct GraphWavelet {
    /// Scale used to build the kernel. Changing it after construction does NOT
    /// rebuild the kernel.
    pub scale: f64,
    /// Dense `n × n` kernel matrix, exactly symmetric by construction.
    kernel: Array2<f64>,
}

impl GraphWavelet {
    /// Precomputes the kernel for `gft` at the given `scale`.
    ///
    /// # Errors
    /// `InvalidParameter` if `scale` is not a finite, strictly positive number.
    /// NaN fails `scale <= 0.0`, so it is caught by the finiteness test.
    /// `+inf` is rejected because `-inf · 0.0` is NaN at a zero eigenvalue.
    pub fn new(gft: &GraphFourierTransform, scale: f64) -> Result<Self> {
        if !scale.is_finite() || scale <= 0.0 {
            return Err(GraphError::InvalidParameter {
                param: "scale".into(),
                value: scale.to_string(),
                expected: "finite value > 0".into(),
                context: "GraphWavelet construction".into(),
            });
        }
        let n = gft.num_nodes();
        let h: Vec<f64> = gft
            .eigenvalues
            .iter()
            .map(|&lam| (-scale * lam).exp())
            .collect();
        let u = &gft.eigenvectors;
        let mut kernel = Array2::<f64>::zeros((n, n));
        // The kernel is symmetric: compute the upper triangle and mirror it.
        // This halves the work and makes symmetry exact rather than approximate.
        for i in 0..n {
            for j in i..n {
                let acc = (0..n).fold(0.0_f64, |acc, k| acc + u[[i, k]] * h[k] * u[[j, k]]);
                kernel[[i, j]] = acc;
                kernel[[j, i]] = acc;
            }
        }
        Ok(Self { scale, kernel })
    }

    /// Applies the precomputed kernel to `signal`: returns `K · signal`.
    ///
    /// # Errors
    /// `InvalidParameter` if `signal.len()` differs from the kernel size.
    pub fn apply(&self, signal: &Array1<f64>) -> Result<Array1<f64>> {
        let n = self.kernel.nrows();
        if signal.len() != n {
            return Err(GraphError::InvalidParameter {
                param: "signal.len()".into(),
                value: signal.len().to_string(),
                expected: n.to_string(),
                context: "GraphWavelet apply".into(),
            });
        }
        let mut out = Array1::<f64>::zeros(n);
        for i in 0..n {
            let mut acc = 0.0_f64;
            for j in 0..n {
                acc += self.kernel[[i, j]] * signal[j];
            }
            out[i] = acc;
        }
        Ok(out)
    }

    /// Returns column `s` of the kernel, i.e. the kernel's response to a unit
    /// impulse at node `s`.
    ///
    /// # Errors
    /// `InvalidParameter` if `s` is not a valid node index.
    pub fn wavelet_atom(&self, s: usize) -> Result<Array1<f64>> {
        let n = self.kernel.nrows();
        if s >= n {
            return Err(GraphError::InvalidParameter {
                param: "s".into(),
                value: s.to_string(),
                expected: format!("< {n}"),
                context: "GraphWavelet atom".into(),
            });
        }
        Ok(self.kernel.column(s).to_owned())
    }

    /// Borrow of the precomputed `n × n` kernel matrix.
    pub fn kernel(&self) -> &Array2<f64> {
        &self.kernel
    }
}
/// Tikhonov-style graph signal smoother with spectral response `1 / (1 + alpha · λ)`.
///
/// Larger `alpha` attenuates high graph frequencies more strongly.
///
/// When the GFT eigenbasis diagonalizes a symmetric Laplacian `L`,
/// [`GraphSignalSmoother::smooth`] returns `(I + alpha·L)⁻¹ y`. This is the
/// minimizer of `‖x − y‖² + alpha · xᵀLx`.
#[derive(Debug, Clone)]
pub struct GraphSignalSmoother {
    /// Regularization strength. It is finite and > 0 when built via [`GraphSignalSmoother::new`].
    pub alpha: f64,
}

impl GraphSignalSmoother {
    /// Creates a smoother with regularization strength `alpha`.
    ///
    /// # Errors
    /// `InvalidParameter` if `alpha` is not finite or not strictly positive:
    /// - NaN fails `alpha <= 0.0`, so it is caught by the finiteness test;
    /// - `+inf` is rejected because `1 / (1 + inf · 0)` is NaN at a zero eigenvalue.
    pub fn new(alpha: f64) -> Result<Self> {
        if !alpha.is_finite() || alpha <= 0.0 {
            return Err(GraphError::InvalidParameter {
                param: "alpha".into(),
                value: alpha.to_string(),
                expected: "finite value > 0".into(),
                context: "GraphSignalSmoother construction".into(),
            });
        }
        Ok(Self { alpha })
    }

    /// Smooths `signal` by scaling each GFT coefficient by `1 / (1 + alpha · λ_k)`.
    ///
    /// # Errors
    /// `InvalidParameter` if `signal.len()` differs from `gft.num_nodes()`.
    pub fn smooth(&self, gft: &GraphFourierTransform, signal: &Array1<f64>) -> Result<Array1<f64>> {
        let n = gft.num_nodes();
        if signal.len() != n {
            return Err(GraphError::InvalidParameter {
                param: "signal.len()".into(),
                value: signal.len().to_string(),
                expected: n.to_string(),
                context: "GraphSignalSmoother smooth".into(),
            });
        }
        let y_hat = gft.transform(signal)?;
        let x_hat = Array1::from_iter(
            y_hat
                .iter()
                .zip(gft.eigenvalues.iter())
                .map(|(&c, &lam)| c / (1.0 + self.alpha * lam)),
        );
        gft.inverse(&x_hat)
    }

    /// Per-component gains `1 / (1 + alpha · λ_k)`, in the order of `gft.eigenvalues`.
    pub fn frequency_response(&self, gft: &GraphFourierTransform) -> Array1<f64> {
        Array1::from_iter(
            gft.eigenvalues
                .iter()
                .map(|&lam| 1.0 / (1.0 + self.alpha * lam)),
        )
    }

    /// Quadratic graph variation `Σ_{i<j} w_ij · (x_i − x_j)²` of `signal` over `adj`.
    ///
    /// Only the strict upper triangle of `adj` is read, so each undirected edge
    /// is counted once. For a symmetric `adj` this equals `xᵀ(D − A)x`.
    ///
    /// NOTE(review): despite the name, this is the squared (Dirichlet-energy)
    /// variation, not an L1 total variation.
    ///
    /// # Errors
    /// - `InvalidGraph` if `adj` is not square (which previously could panic on
    ///   an out-of-bounds index);
    /// - `InvalidParameter` if `signal.len()` differs from the node count.
    pub fn total_variation(adj: &Array2<f64>, signal: &Array1<f64>) -> Result<f64> {
        let n = adj.nrows();
        if adj.ncols() != n {
            return Err(GraphError::InvalidGraph(
                "adjacency matrix must be square".into(),
            ));
        }
        if signal.len() != n {
            return Err(GraphError::InvalidParameter {
                param: "signal.len()".into(),
                value: signal.len().to_string(),
                expected: n.to_string(),
                context: "total_variation".into(),
            });
        }
        let mut tv = 0.0_f64;
        for i in 0..n {
            for j in (i + 1)..n {
                let w = adj[[i, j]];
                // Absent edges contribute nothing. Skipping them also avoids
                // `0 · ∞ = NaN` when the signal holds non-finite values.
                if w != 0.0 {
                    let diff = signal[i] - signal[j];
                    tv += w * diff * diff;
                }
            }
        }
        Ok(tv)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Adjacency matrix of the unweighted path graph 0 – 1 – … – (n−1).
    fn path_graph_adj(n: usize) -> Array2<f64> {
        let mut adj = Array2::<f64>::zeros((n, n));
        // Iterate from 1 so `n == 0` does not underflow.
        for i in 1..n {
            adj[[i - 1, i]] = 1.0;
            adj[[i, i - 1]] = 1.0;
        }
        adj
    }

    #[test]
    fn test_gft_reconstruction() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0, 1.0]);
        let freq = gft.transform(&signal).unwrap();
        let rec = gft.inverse(&freq).unwrap();
        for (a, b) in signal.iter().zip(rec.iter()) {
            assert!((a - b).abs() < 1e-9, "Reconstruction error: {a} vs {b}");
        }
    }

    #[test]
    fn test_low_pass_smoothing() {
        let adj = path_graph_adj(6);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0]);
        let lp = IdealLowPass::new(2);
        let smoothed = lp.apply(&gft, &signal).unwrap();
        let tv_orig = GraphSignalSmoother::total_variation(&adj, &signal).unwrap();
        let tv_smooth = GraphSignalSmoother::total_variation(&adj, &smoothed).unwrap();
        assert!(tv_smooth < tv_orig, "LP filter should reduce TV: {tv_smooth} vs {tv_orig}");
    }

    #[test]
    fn test_high_pass_removes_dc() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let dc_signal = Array1::from_vec(vec![1.0, 1.0, 1.0, 1.0, 1.0]);
        let hp = IdealHighPass::new(1);
        let out = hp.apply(&gft, &dc_signal).unwrap();
        for v in out.iter() {
            assert!(v.abs() < 1e-9, "HP filter should remove DC: got {v}");
        }
    }

    #[test]
    fn test_low_and_high_pass_partition_signal() {
        // For the same cutoff, LP and HP split the spectrum, so their sum is the input.
        let adj = path_graph_adj(6);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![3.0, -1.0, 2.0, 0.5, -2.0, 1.0]);
        for k in 0..=6 {
            let low = IdealLowPass::new(k).apply(&gft, &signal).unwrap();
            let high = IdealHighPass::new(k).apply(&gft, &signal).unwrap();
            for i in 0..6 {
                let sum = low[i] + high[i];
                assert!((sum - signal[i]).abs() < 1e-9, "k={k}, i={i}: {sum} vs {}", signal[i]);
            }
        }
    }

    #[test]
    fn test_bandpass() {
        let adj = path_graph_adj(8);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let signal = Array1::from_vec(vec![1.0, 0.5, 0.0, -0.5, -1.0, -0.5, 0.0, 0.5]);
        let bp = GraphBandpass::new(2, 5);
        let out = bp.apply(&gft, &signal).unwrap();
        assert_eq!(out.len(), 8);
        // Only spectral components 2..5 may survive, and they pass unchanged.
        let in_hat = gft.transform(&signal).unwrap();
        let out_hat = gft.transform(&out).unwrap();
        for k in 0..8 {
            let expected = if (2..5).contains(&k) { in_hat[k] } else { 0.0 };
            assert!(
                (out_hat[k] - expected).abs() < 1e-9,
                "component {k}: {} vs {expected}",
                out_hat[k]
            );
        }
    }

    #[test]
    fn test_wavelet_kernel_symmetry() {
        let adj = path_graph_adj(5);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let wv = GraphWavelet::new(&gft, 1.0).unwrap();
        let k = wv.kernel();
        for i in 0..5 {
            for j in 0..5 {
                assert!((k[[i, j]] - k[[j, i]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_smoother_reduces_variation() {
        let adj = path_graph_adj(6);
        let gft = GraphFourierTransform::from_adjacency(&adj).unwrap();
        let noisy = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
        let smoother = GraphSignalSmoother::new(5.0).unwrap();
        let smoothed = smoother.smooth(&gft, &noisy).unwrap();
        let tv_noisy = GraphSignalSmoother::total_variation(&adj, &noisy).unwrap();
        let tv_smooth = GraphSignalSmoother::total_variation(&adj, &smoothed).unwrap();
        assert!(tv_smooth < tv_noisy);
    }
}