use scirs2_core::numeric::Complex64;
use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use super::filter_bank::{FilterBank, FilterBankConfig};
/// Parameters of a [`ScatteringTransform`].
///
/// Create one with [`ScatteringConfig::new`] and adjust it with the `with_*`
/// builder methods, each of which consumes and returns the configuration.
#[derive(Debug, Clone)]
pub struct ScatteringConfig {
    /// Maximum scale exponent. It is forwarded to the filter bank and, when
    /// `average` is set, fixes the base subsampling factor `2^j_max`.
    pub j_max: usize,
    /// Quality factors forwarded to the filter bank.
    /// NOTE(review): presumably one entry per filter-bank order; confirm against
    /// `FilterBankConfig`.
    pub quality_factors: Vec<usize>,
    /// Highest scattering order to compute: `0`, `1` or `2`. Larger values set
    /// directly on the field behave like `2`.
    pub max_order: usize,
    /// When `true`, outputs are subsampled after the low-pass filter. When
    /// `false`, every sample is kept (the low-pass filter is still applied).
    pub average: bool,
    /// Divides the subsampling factor by `2^oversampling`, with the exponent
    /// capped at `j_max`. It has no effect when `average` is `false`.
    pub oversampling: usize,
}
impl ScatteringConfig {
    /// Creates a configuration with the defaults `max_order = 2`,
    /// `average = true` and `oversampling = 0`.
    pub fn new(j_max: usize, quality_factors: Vec<usize>) -> Self {
        Self {
            quality_factors,
            j_max,
            oversampling: 0,
            average: true,
            max_order: 2,
        }
    }
    /// Sets the highest scattering order to compute. Values above `2` are
    /// clamped to `2`.
    #[must_use]
    pub fn with_max_order(self, order: usize) -> Self {
        let max_order = if order > 2 { 2 } else { order };
        Self { max_order, ..self }
    }
    /// Enables or disables subsampling of the averaged outputs.
    #[must_use]
    pub fn with_average(self, average: bool) -> Self {
        Self { average, ..self }
    }
    /// Sets the oversampling exponent, which reduces the subsampling factor.
    #[must_use]
    pub fn with_oversampling(self, oversampling: usize) -> Self {
        Self {
            oversampling,
            ..self
        }
    }
}
/// Identifies the scattering path that a coefficient vector belongs to.
#[derive(Debug, Clone)]
pub enum ScatteringOrder {
    /// Low-pass average of the input signal.
    Zeroth,
    /// Path through the first-order wavelet at index `lambda1`.
    First { lambda1: usize },
    /// Path through first-order wavelet `lambda1`, then second-order wavelet
    /// `lambda2`.
    Second { lambda1: usize, lambda2: usize },
}
/// Coefficients produced along a single scattering path.
#[derive(Debug, Clone)]
pub struct ScatteringCoefficients {
    /// The path that produced `values`.
    pub order: ScatteringOrder,
    /// Coefficient values along the path.
    pub values: Vec<f64>,
}
/// All scattering coefficients of one signal.
#[derive(Debug, Clone)]
pub struct ScatteringResult {
    /// Every path, stored as the zeroth-order entries, then the first-order
    /// entries, then the second-order entries.
    pub coefficients: Vec<ScatteringCoefficients>,
    /// Number of leading zeroth-order entries in `coefficients`.
    pub num_zeroth: usize,
    /// Number of first-order entries following the zeroth-order ones.
    pub num_first: usize,
    /// Number of trailing second-order entries.
    pub num_second: usize,
    /// Length reported for each coefficient vector.
    pub output_length: usize,
}
impl ScatteringResult {
    /// Returns the zeroth-order paths.
    ///
    /// # Panics
    /// Panics if `num_zeroth` exceeds `coefficients.len()`.
    pub fn zeroth_order(&self) -> &[ScatteringCoefficients] {
        let end = self.num_zeroth;
        &self.coefficients[..end]
    }
    /// Returns the first-order paths.
    ///
    /// # Panics
    /// Panics if the counts do not fit inside `coefficients`.
    pub fn first_order(&self) -> &[ScatteringCoefficients] {
        let start = self.num_zeroth;
        let end = start + self.num_first;
        &self.coefficients[start..end]
    }
    /// Returns every path after the zeroth- and first-order ones.
    ///
    /// # Panics
    /// Panics if `num_zeroth + num_first` exceeds `coefficients.len()`.
    pub fn second_order(&self) -> &[ScatteringCoefficients] {
        let start = self.num_zeroth + self.num_first;
        &self.coefficients[start..]
    }
    /// Concatenates the values of all paths, in storage order, into one
    /// feature vector.
    pub fn flatten(&self) -> Vec<f64> {
        self.coefficients
            .iter()
            .flat_map(|path| path.values.iter().copied())
            .collect()
    }
    /// Sum of squared values over every path.
    pub fn total_energy(&self) -> f64 {
        self.coefficients
            .iter()
            .flat_map(|path| &path.values)
            .map(|&value| value * value)
            .sum()
    }
}
/// Wavelet scattering transform for real 1-D signals of a fixed length.
///
/// Stores the user configuration together with the [`FilterBank`] built for
/// that length. The filter bank holds a low-pass `phi` plus one or more
/// band-pass wavelet families.
#[derive(Debug, Clone)]
pub struct ScatteringTransform {
    config: ScatteringConfig,
    filter_bank: FilterBank,
}
impl ScatteringTransform {
    /// Builds a transform for signals of `signal_length` samples.
    ///
    /// # Errors
    ///
    /// Returns [`FFTError::ValueError`] if `signal_length` is zero. Any error
    /// from [`FilterBank::new`] is propagated.
    pub fn new(config: ScatteringConfig, signal_length: usize) -> FFTResult<Self> {
        if signal_length == 0 {
            return Err(FFTError::ValueError(
                "signal_length must be positive".to_string(),
            ));
        }
        let fb_config =
            FilterBankConfig::new(config.j_max, config.quality_factors.clone(), signal_length);
        let filter_bank = FilterBank::new(fb_config)?;
        Ok(Self {
            config,
            filter_bank,
        })
    }

    /// Returns the filter bank used by this transform.
    pub fn filter_bank(&self) -> &FilterBank {
        &self.filter_bank
    }

    /// Computes the scattering coefficients of `signal`, up to `max_order`.
    ///
    /// The signal is zero-padded, or truncated, to the filter bank's FFT
    /// length. The orders are computed as follows:
    /// * Order 0 is the low-pass `phi` applied to the signal.
    /// * Order 1 applies `phi` to `|x * psi_lambda1|` for every first-order
    ///   wavelet.
    /// * Order 2 applies `phi` to `||x * psi_lambda1| * psi_lambda2|`. It only
    ///   covers pairs whose second-order scale `j` is strictly greater than
    ///   the first-order one.
    ///
    /// Every output is subsampled by the factor described on
    /// [`ScatteringConfig`].
    ///
    /// # Errors
    ///
    /// * [`FFTError::ValueError`] if `signal` is empty.
    /// * [`FFTError::ValueError`] if averaging is enabled and `2^j_max` does
    ///   not fit in `usize`.
    /// * [`FFTError::ComputationError`] if `max_order >= 1` and the filter bank
    ///   has no wavelet family.
    /// * Any error from the underlying FFT routines.
    pub fn transform(&self, signal: &[f64]) -> FFTResult<ScatteringResult> {
        if signal.is_empty() {
            return Err(FFTError::ValueError(
                "Input signal must not be empty".to_string(),
            ));
        }
        let fft_size = self.filter_bank.fft_size;
        let phi = &self.filter_bank.phi;
        let max_order = self.config.max_order;

        // Zero-pad (or truncate) the input to the filter bank's FFT length.
        let mut padded = vec![0.0_f64; fft_size];
        let copy_len = signal.len().min(fft_size);
        padded[..copy_len].copy_from_slice(&signal[..copy_len]);
        let x_hat = fft(&padded, Some(fft_size))?;

        let subsample = self.subsample_factor()?;
        let output_length = fft_size.div_ceil(subsample);

        let num_zeroth = 1;
        let mut num_first = 0;
        let mut num_second = 0;

        // Order 0: low-pass average of the signal itself.
        let mut coefficients = vec![ScatteringCoefficients {
            order: ScatteringOrder::Zeroth,
            values: convolve_and_subsample(&x_hat, phi, fft_size, subsample)?,
        }];

        if max_order >= 1 {
            let first_order_wavelets = self.filter_bank.wavelets.first().ok_or_else(|| {
                FFTError::ComputationError("No first-order wavelets".to_string())
            })?;
            // The first-order modulus spectra are reused only by the second order.
            let keep_u1 = max_order >= 2;
            let mut u1_hats: Vec<Vec<Complex64>> = Vec::new();

            for (lambda1, wavelet) in first_order_wavelets.iter().enumerate() {
                let convolved: Vec<Complex64> = x_hat
                    .iter()
                    .zip(wavelet.freq_response.iter())
                    .map(|(x, w)| x * w)
                    .collect();
                let u1_hat = Self::modulus_spectrum(&convolved, fft_size)?;
                coefficients.push(ScatteringCoefficients {
                    order: ScatteringOrder::First { lambda1 },
                    values: convolve_and_subsample(&u1_hat, phi, fft_size, subsample)?,
                });
                num_first += 1;
                if keep_u1 {
                    u1_hats.push(u1_hat);
                }
            }

            if keep_u1 {
                // Use the second wavelet family if one exists, else reuse the first.
                let second_order_wavelets = self
                    .filter_bank
                    .wavelets
                    .get(1)
                    .unwrap_or(first_order_wavelets);
                for (lambda1, (u1_hat, wavelet1)) in u1_hats
                    .iter()
                    .zip(first_order_wavelets.iter())
                    .enumerate()
                {
                    let first_scale = wavelet1.j;
                    // `lambda2` keeps indexing the whole second-order family,
                    // including the skipped wavelets.
                    for (lambda2, wavelet2) in second_order_wavelets
                        .iter()
                        .enumerate()
                        .filter(|(_, w2)| w2.j > first_scale)
                    {
                        let convolved2: Vec<Complex64> = u1_hat
                            .iter()
                            .zip(wavelet2.freq_response.iter())
                            .map(|(u, w)| u * w)
                            .collect();
                        let u2_hat = Self::modulus_spectrum(&convolved2, fft_size)?;
                        coefficients.push(ScatteringCoefficients {
                            order: ScatteringOrder::Second { lambda1, lambda2 },
                            values: convolve_and_subsample(&u2_hat, phi, fft_size, subsample)?,
                        });
                        num_second += 1;
                    }
                }
            }
        }

        Ok(ScatteringResult {
            coefficients,
            num_zeroth,
            num_first,
            num_second,
            output_length,
        })
    }

    /// Computes the transform and returns all coefficients as one flat vector.
    ///
    /// # Errors
    ///
    /// Same as [`ScatteringTransform::transform`].
    pub fn features(&self, signal: &[f64]) -> FFTResult<Vec<f64>> {
        let result = self.transform(signal)?;
        Ok(result.flatten())
    }

    /// Output subsampling factor.
    ///
    /// Without averaging the factor is `1`. With averaging it is
    /// `2^j_max >> min(oversampling, j_max)`, which is always at least 1.
    ///
    /// # Errors
    ///
    /// Returns [`FFTError::ValueError`] when `2^j_max` would overflow `usize`.
    /// The previous `2_usize.pow` panicked in that case in debug builds and
    /// wrapped to 0 in release builds, leading to a division by zero.
    fn subsample_factor(&self) -> FFTResult<usize> {
        if !self.config.average {
            return Ok(1);
        }
        let j_max = self.config.j_max;
        if j_max >= usize::BITS as usize {
            return Err(FFTError::ValueError(format!(
                "j_max ({j_max}) is too large: 2^j_max does not fit in usize"
            )));
        }
        let base = 1_usize << j_max;
        Ok(base >> self.config.oversampling.min(j_max))
    }

    /// Returns `FFT(|IFFT(product)|)`, zero-padded to `fft_size`.
    ///
    /// `product` is a filtered spectrum. The result is the spectrum of that
    /// filtered signal's modulus.
    fn modulus_spectrum(product: &[Complex64], fft_size: usize) -> FFTResult<Vec<Complex64>> {
        let filtered = ifft(product, None)?;
        let modulus: Vec<f64> = filtered.iter().map(|c| c.norm()).collect();
        fft(&modulus, Some(fft_size))
    }
}
fn convolve_and_subsample(
x_hat: &[Complex64],
filter_hat: &[Complex64],
fft_size: usize,
subsample: usize,
) -> FFTResult<Vec<f64>> {
let product: Vec<Complex64> = x_hat
.iter()
.zip(filter_hat.iter())
.map(|(x, f)| x * f)
.collect();
let time_domain = ifft(&product, None)?;
let output_len = fft_size.div_ceil(subsample);
let mut result = Vec::with_capacity(output_len);
for i in 0..output_len {
let idx = i * subsample;
if idx < time_domain.len() {
result.push(time_domain[idx].re);
} else {
result.push(0.0);
}
}
Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Sum of squared values along one scattering path.
    fn path_energy(coeff: &ScatteringCoefficients) -> f64 {
        coeff.values.iter().map(|v| v * v).sum()
    }

    #[test]
    fn test_scattering_basic() {
        let config = ScatteringConfig::new(3, vec![2, 1]);
        let st = ScatteringTransform::new(config, 256)
            .expect("scattering transform creation should succeed");
        let signal: Vec<f64> = (0..256)
            .map(|i| (2.0 * PI * 10.0 * i as f64 / 256.0).sin())
            .collect();
        let result = st.transform(&signal).expect("transform should succeed");
        assert_eq!(result.num_zeroth, 1);
        assert!(result.num_first > 0);
    }

    #[test]
    fn test_translation_invariance() {
        let config = ScatteringConfig::new(3, vec![4, 1]).with_max_order(1);
        let n = 512;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");
        // Gaussian bump centred on sample 128.
        let signal1: Vec<f64> = (0..n)
            .map(|i| {
                let t = (i as f64 - 128.0) / 20.0;
                (-0.5 * t * t).exp()
            })
            .collect();
        // The same bump, circularly shifted right by `shift` samples.
        let shift = 64;
        let signal2: Vec<f64> = (0..n).map(|i| signal1[(i + n - shift) % n]).collect();
        let r1 = st.transform(&signal1).expect("transform should succeed");
        let r2 = st.transform(&signal2).expect("transform should succeed");
        let total_e1: f64 = r1.first_order().iter().map(path_energy).sum();
        let total_e2: f64 = r2.first_order().iter().map(path_energy).sum();
        if total_e1 > 1e-15 {
            let rel_error = ((total_e1 - total_e2) / total_e1).abs();
            assert!(
                rel_error < 0.3,
                "First-order total energy should be approximately translation invariant, \
                 rel_error={:.4} (e1={:.4}, e2={:.4})",
                rel_error,
                total_e1,
                total_e2
            );
        }
    }

    #[test]
    fn test_output_dimensions() {
        let j = 3;
        let q1 = 4;
        let q2 = 1;
        let config = ScatteringConfig::new(j, vec![q1, q2]);
        let n = 256;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 5.0 * i as f64 / n as f64).sin())
            .collect();
        let result = st.transform(&signal).expect("transform should succeed");
        assert_eq!(result.num_first, j * q1);
        assert_eq!(
            result.coefficients.len(),
            result.num_zeroth + result.num_first + result.num_second
        );
        assert_eq!(result.zeroth_order().len(), result.num_zeroth);
        assert_eq!(result.first_order().len(), result.num_first);
        assert_eq!(result.second_order().len(), result.num_second);
        let expected_len = result.output_length;
        for coeff in &result.coefficients {
            assert_eq!(
                coeff.values.len(),
                expected_len,
                "coefficient output length mismatch"
            );
        }
    }

    #[test]
    fn test_energy_approximate_preservation() {
        let config = ScatteringConfig::new(3, vec![4, 1]);
        let n = 256;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                (2.0 * PI * 8.0 * t).sin() + 0.5 * (2.0 * PI * 32.0 * t).cos()
            })
            .collect();
        let input_energy: f64 = signal.iter().map(|v| v * v).sum();
        let result = st.transform(&signal).expect("transform should succeed");
        let scatter_energy = result.total_energy();
        // The exact energy ratio depends on the filter-bank normalisation and
        // the subsampling, so only sanity bounds are checked here.
        assert!(input_energy > 0.0, "test signal should carry energy");
        assert!(
            scatter_energy > 0.0 && scatter_energy.is_finite(),
            "scattering energy should be positive and finite \
             (input={input_energy:.4}, scattering={scatter_energy:.4})"
        );
    }

    #[test]
    fn test_sine_wave_first_order() {
        let config = ScatteringConfig::new(4, vec![8]).with_max_order(1);
        let n = 1024;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");
        let freq = 20.0;
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * freq * i as f64 / n as f64).sin())
            .collect();
        let result = st.transform(&signal).expect("transform should succeed");
        let first = result.first_order();
        assert!(!first.is_empty(), "should have first-order coefficients");
        let max_path = first
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                path_energy(a)
                    .partial_cmp(&path_energy(b))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(idx, _)| idx);
        assert!(max_path.is_some(), "should find a path with maximum energy");
    }

    #[test]
    fn test_zeroth_order_only() {
        let config = ScatteringConfig::new(3, vec![4]).with_max_order(0);
        let n = 128;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");
        let signal: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
        let result = st.transform(&signal).expect("transform should succeed");
        assert_eq!(result.num_zeroth, 1);
        assert_eq!(result.num_first, 0);
        assert_eq!(result.num_second, 0);
    }

    #[test]
    fn test_no_averaging_keeps_full_length() {
        let config = ScatteringConfig::new(3, vec![4, 1]).with_average(false);
        let n = 128;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");
        let signal: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 3.0 * i as f64 / n as f64).cos())
            .collect();
        let result = st.transform(&signal).expect("transform should succeed");
        assert_eq!(result.output_length, st.filter_bank().fft_size);
        for coeff in &result.coefficients {
            assert_eq!(coeff.values.len(), result.output_length);
        }
    }

    #[test]
    fn test_oversampling_reduces_subsampling() {
        // j_max = 3 and oversampling = 1 give a subsampling factor of 8 >> 1 = 4.
        let config = ScatteringConfig::new(3, vec![4])
            .with_max_order(0)
            .with_oversampling(1);
        let n = 128;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");
        let signal: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
        let result = st.transform(&signal).expect("transform should succeed");
        assert_eq!(result.output_length, st.filter_bank().fft_size.div_ceil(4));
        assert_eq!(result.zeroth_order()[0].values.len(), result.output_length);
    }

    #[test]
    fn test_empty_signal_error() {
        let config = ScatteringConfig::new(3, vec![4]);
        let st = ScatteringTransform::new(config, 128)
            .expect("scattering transform creation should succeed");
        let result = st.transform(&[]);
        assert!(result.is_err());
    }
}