pub use crate::new_modules::fft::*;
/// Sub-module that re-exports every public item of `crate::new_modules::fft_enhanced`,
/// making it reachable through this module's `enhanced` path.
pub mod enhanced {
pub use crate::new_modules::fft_enhanced::*;
}
use crate::array::Array;
use crate::error::Result;
use num_traits::{Float, Zero};
use scirs2_core::parallel_ops::*;
use std::ops::{Add, Mul};
/// Discrete linear convolution of two 1-D arrays.
///
/// The full convolution has length `a.len() + b.len() - 1`, with
/// `full[n] = Σ_m a[m] * b[n - m]` over the indices where both factors exist.
/// `mode` selects which part of it is returned:
/// - `"full"`: the whole full convolution;
/// - `"valid"`: only points where the shorter input overlaps the longer one completely
///   (length `max(len) - min(len) + 1`);
/// - `"same"`: length `max(len)`, starting at index `(min(len) - 1) / 2` of the full result.
///
/// Only the requested window is computed; outputs of 1000 samples or more are evaluated
/// in parallel. An empty input yields an empty result.
///
/// # Errors
/// - `DimensionMismatch` if either input is not 1-D.
/// - `InvalidOperation` if `mode` is not `"full"`, `"valid"` or `"same"`.
pub fn convolve<T>(a: &Array<T>, b: &Array<T>, mode: &str) -> Result<Array<T>>
where
    T: Float + Clone + Zero + Send + Sync,
{
    const PARALLEL_THRESHOLD: usize = 1000;

    if a.ndim() != 1 || b.ndim() != 1 {
        return Err(crate::error::NumRs2Error::DimensionMismatch(
            "convolve requires 1D arrays".to_string(),
        ));
    }
    let a_len = a.len();
    let b_len = b.len();
    if a_len == 0 || b_len == 0 {
        return Ok(Array::from_vec(vec![]));
    }

    let full_len = a_len + b_len - 1;
    let shorter = a_len.min(b_len);
    let longer = a_len.max(b_len);
    // (number of outputs, index of the first output inside the full convolution)
    let (output_size, start_idx) = match mode {
        "full" => (full_len, 0),
        "valid" => (longer - shorter + 1, shorter - 1),
        "same" => (longer, (shorter - 1) / 2),
        _ => {
            return Err(crate::error::NumRs2Error::InvalidOperation(format!(
                "Invalid mode '{}'. Use 'full', 'valid', or 'same'",
                mode
            )))
        }
    };

    let a_vec = a.to_vec();
    let b_vec = b.to_vec();
    // Full-convolution value at index `n`: m runs over the overlap so both a[m] and
    // b[n - m] are in range; terms are summed in ascending m.
    let full_at = |n: usize| -> T {
        let m_start = n.saturating_sub(b_len - 1);
        let m_end = (n + 1).min(a_len);
        (m_start..m_end).fold(T::zero(), |sum, m| sum + a_vec[m] * b_vec[n - m])
    };

    let indices = start_idx..start_idx + output_size;
    let result: Vec<T> = if output_size >= PARALLEL_THRESHOLD {
        indices.into_par_iter().map(&full_at).collect()
    } else {
        indices.map(&full_at).collect()
    };
    Ok(Array::from_vec(result))
}
pub fn correlate<T>(a: &Array<T>, b: &Array<T>, mode: &str) -> Result<Array<T>>
where
T: Float + Clone + Zero + Send + Sync,
{
if a.ndim() != 1 || b.ndim() != 1 {
return Err(crate::error::NumRs2Error::DimensionMismatch(
"correlate requires 1D arrays".to_string(),
));
}
let b_reversed = Array::from_vec(b.to_vec().into_iter().rev().collect());
convolve(a, &b_reversed, mode)
}
/// 2-D discrete linear convolution of `input` (shape `m x n`) with `kernel` (shape `p x q`).
///
/// The kernel is flipped in both axes (true convolution, not correlation):
/// `full[i][j] = Σ_{k,l} input[i - k][j - l] * kernel[k][l]`, where input samples outside
/// the array contribute nothing. `mode` selects a window of that full result:
/// - `"full"`: shape `(m + p - 1, n + q - 1)`;
/// - `"valid"`: shape `(m - p + 1, n - q + 1)`, i.e. positions where the kernel lies entirely
///   inside the input; an empty `0 x 0` array if the kernel is larger than the input in
///   either dimension;
/// - `"same"`: shape `(m, n)`, starting at offset `((p - 1) / 2, (q - 1) / 2)` of the full
///   result.
///
/// If any input or kernel dimension is zero, an empty `0 x 0` array is returned. Outputs
/// with 256 or more cells are evaluated in parallel.
///
/// # Errors
/// - `DimensionMismatch` if either argument is not 2-D.
/// - `InvalidOperation` if `mode` is not `"full"`, `"valid"` or `"same"`.
pub fn convolve2d<T>(input: &Array<T>, kernel: &Array<T>, mode: &str) -> Result<Array<T>>
where
    T: Float + Clone + Zero + Add<Output = T> + Mul<Output = T> + Send + Sync,
{
    const PARALLEL_THRESHOLD: usize = 256;

    if input.ndim() != 2 || kernel.ndim() != 2 {
        return Err(crate::error::NumRs2Error::DimensionMismatch(
            "convolve2d requires 2D arrays".to_string(),
        ));
    }
    let input_shape = input.shape();
    let kernel_shape = kernel.shape();
    let (m, n) = (input_shape[0], input_shape[1]);
    let (p, q) = (kernel_shape[0], kernel_shape[1]);
    if m == 0 || n == 0 || p == 0 || q == 0 {
        return Ok(Array::zeros(&[0, 0]));
    }
    // Output shape and the position of output cell (0, 0) inside the full convolution.
    let (output_rows, output_cols, row_offset, col_offset) = match mode {
        "full" => (m + p - 1, n + q - 1, 0, 0),
        "valid" => {
            if m < p || n < q {
                return Ok(Array::zeros(&[0, 0]));
            }
            (m - p + 1, n - q + 1, p - 1, q - 1)
        }
        "same" => (m, n, (p - 1) / 2, (q - 1) / 2),
        _ => {
            return Err(crate::error::NumRs2Error::InvalidOperation(format!(
                "Invalid mode '{}'. Use 'full', 'valid', or 'same'",
                mode
            )))
        }
    };

    let input_data = input.to_vec();
    let kernel_data = kernel.to_vec();

    // Value of output cell `idx` (row-major). Only kernel taps with
    // 0 <= full_i - k_i < m and 0 <= full_j - k_j < n touch real input samples.
    let cell = |idx: usize| -> T {
        let full_i = idx / output_cols + row_offset;
        let full_j = idx % output_cols + col_offset;
        let mut sum = T::zero();
        for k_i in (full_i + 1).saturating_sub(m)..(full_i + 1).min(p) {
            let in_row = (full_i - k_i) * n;
            let k_row = k_i * q;
            for k_j in (full_j + 1).saturating_sub(n)..(full_j + 1).min(q) {
                sum = sum + input_data[in_row + full_j - k_j] * kernel_data[k_row + k_j];
            }
        }
        sum
    };

    let total_output = output_rows * output_cols;
    let result: Vec<T> = if total_output >= PARALLEL_THRESHOLD {
        (0..total_output).into_par_iter().map(&cell).collect()
    } else {
        (0..total_output).map(&cell).collect()
    };
    Ok(Array::from_vec(result).reshape(&[output_rows, output_cols]))
}
/// 2-D cross-correlation of `input` with `template`.
///
/// The template is rotated by 180° (flipped along both axes) and passed to `convolve2d`
/// with the same `mode`, so output shapes and mode semantics follow that function.
///
/// # Errors
/// `DimensionMismatch` if either argument is not 2-D; otherwise any error from `convolve2d`.
pub fn correlate2d<T>(input: &Array<T>, template: &Array<T>, mode: &str) -> Result<Array<T>>
where
    T: Float + Clone + Zero + Add<Output = T> + Mul<Output = T> + Send + Sync,
{
    if input.ndim() != 2 || template.ndim() != 2 {
        return Err(crate::error::NumRs2Error::DimensionMismatch(
            "correlate2d requires 2D arrays".to_string(),
        ));
    }
    let dims = template.shape();
    // Element (i, j) at flat index i*cols + j moves to (rows-1-i, cols-1-j), which is flat
    // index rows*cols - 1 - (i*cols + j): a 180° rotation is a reversal of the flat buffer.
    let rotated: Vec<T> = template.to_vec().into_iter().rev().collect();
    let rotated_template = Array::from_vec(rotated).reshape(&[dims[0], dims[1]]);
    convolve2d(input, &rotated_template, mode)
}
/// Unwraps a phase signal by removing jumps of at least `discont`, following `numpy.unwrap`.
///
/// For each pair of consecutive samples the raw jump `d = x[i] - x[i-1]` is computed. When
/// `|d| >= discont`, the later samples are shifted by the amount that maps `d` into
/// `[-period/2, period/2]`; a positive jump landing exactly on `-period/2` is mapped to
/// `+period/2` instead, as NumPy does. Corrections accumulate, so every sample after a jump
/// carries the sum of all earlier corrections.
///
/// # Arguments
/// * `phase` - input samples.
/// * `discont` - minimum jump size that is treated as a wrap; defaults to `period / 2`.
/// * `axis` - axis to unwrap along. `None` unwraps the data returned by `to_vec` as a single
///   sequence and restores the original shape; an empty input then yields an empty 1-D array.
/// * `period` - size of the range the input wraps over; defaults to `2π`.
///
/// # Errors
/// `DimensionMismatch` if `axis` is out of bounds for `phase`.
pub fn unwrap<T>(
    phase: &Array<T>,
    discont: Option<T>,
    axis: Option<usize>,
    period: Option<T>,
) -> Result<Array<T>>
where
    T: Float + std::fmt::Debug,
{
    let pi = T::from(std::f64::consts::PI).expect("PI is representable as Float");
    let period = period.unwrap_or(pi + pi);
    let half_period = period / T::from(2.0).expect("2.0 is representable as Float");
    // NumPy's default threshold is half a period (π for the default 2π period).
    let discont = discont.unwrap_or(half_period);
    let correction = |jump: T| unwrap_correction(jump, discont, period, half_period);

    let shape = phase.shape();
    match axis {
        None => {
            let mut data = phase.to_vec();
            if data.is_empty() {
                return Ok(Array::from_vec(vec![]));
            }
            let len = data.len();
            unwrap_lane(&mut data, 0, 1, len, &correction);
            Ok(Array::from_vec(data).reshape(&shape))
        }
        Some(ax) => {
            if ax >= shape.len() {
                return Err(crate::error::NumRs2Error::DimensionMismatch(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax,
                    shape.len()
                )));
            }
            let axis_len = shape[ax];
            // Row-major layout: `outer` blocks of `axis_len * inner` elements; inside a block,
            // consecutive samples along `ax` are `inner` elements apart.
            let outer: usize = shape[..ax].iter().product();
            let inner: usize = shape[ax + 1..].iter().product();
            if axis_len <= 1 || outer == 0 || inner == 0 {
                return Ok(phase.clone());
            }
            let mut data = phase.to_vec();
            for block in 0..outer {
                let block_start = block * axis_len * inner;
                for offset in 0..inner {
                    unwrap_lane(&mut data, block_start + offset, inner, axis_len, &correction);
                }
            }
            Ok(Array::from_vec(data).reshape(&shape))
        }
    }
}

/// Correction for one raw jump, mirroring `numpy.unwrap`.
///
/// Returns zero when `|jump| < discont` or `jump` is NaN. Otherwise it returns the value
/// that, added to `jump`, maps it into `[-period/2, period/2]`.
fn unwrap_correction<T: Float>(jump: T, discont: T, period: T, half_period: T) -> T {
    // Skipping NaN jumps keeps one NaN sample from poisoning every later value.
    if jump.is_nan() || jump.abs() < discont {
        return T::zero();
    }
    // Floored modulo of `jump + period/2`; Rust's `%` truncates toward zero like C `fmod`.
    let mut rem = (jump + half_period) % period;
    if rem != T::zero() && (rem < T::zero()) != (period < T::zero()) {
        rem = rem + period;
    }
    let mut wrapped = rem - half_period;
    // Ambiguous boundary: a positive jump stays at +period/2 rather than flipping sign.
    if wrapped == -half_period && jump > T::zero() {
        wrapped = half_period;
    }
    wrapped - jump
}

/// Unwraps, in place, the `len` samples of `data` starting at `start` and `stride` apart.
fn unwrap_lane<T: Float, F: Fn(T) -> T>(
    data: &mut [T],
    start: usize,
    stride: usize,
    len: usize,
    correction: &F,
) {
    if len == 0 {
        return;
    }
    let mut prev_raw = data[start];
    let mut cumulative = T::zero();
    for i in 1..len {
        let idx = start + i * stride;
        let raw = data[idx];
        // Jumps are measured on the raw input, as in NumPy's `diff`.
        cumulative = cumulative + correction(raw - prev_raw);
        data[idx] = raw + cumulative;
        prev_raw = raw;
    }
}
/// Bartlett (triangular) window of length `m`.
///
/// Sample `i` is `2i / (m - 1)` for `i < m / 2` and `2 - 2i / (m - 1)` afterwards. The
/// end points are 0, and the peak is exactly 1 when `m` is odd. `m == 0` yields an empty
/// array and `m == 1` yields `[1]`.
///
/// # Errors
/// Never returns an error; the `Result` is kept for API consistency.
pub fn bartlett<T>(m: usize) -> Result<Array<T>>
where
    T: Float + Clone + std::fmt::Debug + num_traits::FromPrimitive,
{
    match m {
        0 => Ok(Array::from_vec(vec![])),
        1 => Ok(Array::from_vec(vec![T::one()])),
        _ => {
            let denom = (m - 1) as f64;
            let midpoint = m / 2;
            let window: Vec<T> = (0..m)
                .map(|i| {
                    let ramp = 2.0 * i as f64 / denom;
                    let value = if i < midpoint { ramp } else { 2.0 - ramp };
                    T::from_f64(value).unwrap_or(T::zero())
                })
                .collect();
            Ok(Array::from_vec(window))
        }
    }
}
/// Blackman window of length `m`.
///
/// `w[i] = 0.42 - 0.5 cos(θ) + 0.08 cos(2θ)` with `θ = 2πi / (m - 1)`. Each value is
/// clamped to `[0, 1]` so rounding at the end points cannot produce tiny negative values.
/// `m == 0` yields an empty array and `m == 1` yields `[1]`.
///
/// # Errors
/// Never returns an error; the `Result` is kept for API consistency.
pub fn blackman<T>(m: usize) -> Result<Array<T>>
where
    T: Float + Clone + std::fmt::Debug + num_traits::FromPrimitive,
{
    let window: Vec<T> = if m < 2 {
        // Zero samples, or the single sample [1].
        vec![T::one(); m]
    } else {
        let span = (m - 1) as f64;
        let mut samples = Vec::with_capacity(m);
        for n in 0..m {
            let theta = 2.0 * std::f64::consts::PI * n as f64 / span;
            let raw = 0.42 - 0.5 * theta.cos() + 0.08 * (2.0 * theta).cos();
            samples.push(T::from_f64(raw.clamp(0.0, 1.0)).unwrap_or(T::zero()));
        }
        samples
    };
    Ok(Array::from_vec(window))
}
/// Hann ("Hanning") window of length `m`: `w[i] = 0.5 * (1 - cos(2πi / (m - 1)))`.
///
/// The end points are 0, up to rounding at the last sample. `m == 0` yields an empty array
/// and `m == 1` yields `[1]`.
///
/// # Errors
/// Never returns an error; the `Result` is kept for API consistency.
pub fn hanning<T>(m: usize) -> Result<Array<T>>
where
    T: Float + Clone + std::fmt::Debug + num_traits::FromPrimitive,
{
    let window: Vec<T> = match m {
        0 => Vec::new(),
        1 => vec![T::one()],
        _ => {
            let denom = (m - 1) as f64;
            (0..m)
                .map(|i| {
                    let theta = 2.0 * std::f64::consts::PI * i as f64 / denom;
                    T::from_f64(0.5 * (1.0 - theta.cos())).unwrap_or(T::zero())
                })
                .collect()
        }
    };
    Ok(Array::from_vec(window))
}
/// Hamming window of length `m`: `w[i] = 0.54 - 0.46 cos(2πi / (m - 1))`.
///
/// Unlike the Hann window, the end points are 0.08 rather than 0. `m == 0` yields an
/// empty array and `m == 1` yields `[1]`.
///
/// # Errors
/// Never returns an error; the `Result` is kept for API consistency.
pub fn hamming<T>(m: usize) -> Result<Array<T>>
where
    T: Float + Clone + std::fmt::Debug + num_traits::FromPrimitive,
{
    if m <= 1 {
        // Zero samples, or the single sample [1].
        return Ok(Array::from_vec(vec![T::one(); m]));
    }
    let last = (m - 1) as f64;
    let mut window = Vec::with_capacity(m);
    window.extend((0..m).map(|k| {
        let phase = 2.0 * std::f64::consts::PI * k as f64 / last;
        T::from_f64(0.54 - 0.46 * phase.cos()).unwrap_or(T::zero())
    }));
    Ok(Array::from_vec(window))
}
/// Kaiser window of length `m` with shape parameter `beta` (default 8.6).
///
/// `w[i] = I0(beta * sqrt(1 - x²)) / I0(beta)` with `x = 2i / (m - 1) - 1`, where `I0` is
/// `modified_bessel_i0`. Larger `beta` gives a narrower window. `m == 0` yields an empty
/// array and `m == 1` yields `[1]`.
///
/// # Errors
/// Never returns an error; the `Result` is kept for API consistency.
pub fn kaiser<T>(m: usize, beta: Option<T>) -> Result<Array<T>>
where
    T: Float + Clone + std::fmt::Debug + num_traits::FromPrimitive,
{
    if m <= 1 {
        let window = if m == 0 { vec![] } else { vec![T::one()] };
        return Ok(Array::from_vec(window));
    }
    let beta = match beta {
        Some(b) => b,
        None => T::from_f64(8.6).unwrap_or(T::zero()),
    };
    let beta = beta.to_f64().unwrap_or(8.6);
    let norm = modified_bessel_i0(beta);
    let denom = (m - 1) as f64;
    let window: Vec<T> = (0..m)
        .map(|i| {
            // Position mapped onto [-1, 1].
            let x = 2.0 * i as f64 / denom - 1.0;
            let value = modified_bessel_i0(beta * (1.0 - x * x).sqrt()) / norm;
            T::from_f64(value).unwrap_or(T::zero())
        })
        .collect();
    Ok(Array::from_vec(window))
}
/// Modified Bessel function of the first kind, order zero: `I0(x)`.
///
/// Evaluated with the power series `I0(x) = Σ_{k≥0} ((x/2)^k / k!)²`, which converges for
/// every finite `x`. All of its terms are positive, so there is no cancellation and the
/// relative error stays near machine precision for the moderate arguments used by Kaiser
/// windows. Summation stops once the latest term no longer affects the sum at `f64`
/// precision. The series depends only on `x²`, so `I0(-x) == I0(x)` exactly.
///
/// NaN input returns NaN. `±∞`, or an argument whose terms overflow, returns `+∞`.
fn modified_bessel_i0(x: f64) -> f64 {
    // A NaN term would never satisfy the stopping test below.
    if x.is_nan() {
        return x;
    }
    let quarter_x_sq = 0.25 * x * x; // (x/2)^2
    let mut term = 1.0_f64;
    let mut sum = 1.0_f64;
    let mut k = 0.0_f64;
    loop {
        k += 1.0;
        // term_k = term_{k-1} * (x/2)^2 / k^2
        term *= quarter_x_sq / (k * k);
        sum += term;
        if term <= sum * f64::EPSILON {
            return sum;
        }
    }
}
#[cfg(test)]
mod window_tests {
    //! Tests for the window functions and the `I0` helper used by `kaiser`.
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_bartlett_window() {
        let window = bartlett::<f64>(10).expect("bartlett(10) should succeed");
        let data = window.to_vec();
        assert_eq!(data.len(), 10);
        assert_relative_eq!(data[0], 0.0, epsilon = 1e-10);
        assert_relative_eq!(data[9], 0.0, epsilon = 1e-10);
        for i in 0..5 {
            assert_relative_eq!(data[i], data[9 - i], epsilon = 1e-10);
        }
        let window_odd = bartlett::<f64>(9).expect("bartlett(9) should succeed");
        let data_odd = window_odd.to_vec();
        assert_relative_eq!(data_odd[4], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_blackman_window() {
        let window = blackman::<f64>(10).expect("blackman(10) should succeed");
        let data = window.to_vec();
        assert_eq!(data.len(), 10);
        assert!(data[0] < 0.01);
        assert!(data[9] < 0.01);
        for i in 0..5 {
            assert_relative_eq!(data[i], data[9 - i], epsilon = 1e-10);
        }
        for &val in &data {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_hanning_window() {
        let window = hanning::<f64>(10).expect("hanning(10) should succeed");
        let data = window.to_vec();
        assert_eq!(data.len(), 10);
        assert_relative_eq!(data[0], 0.0, epsilon = 1e-10);
        assert_relative_eq!(data[9], 0.0, epsilon = 1e-10);
        for i in 0..5 {
            assert_relative_eq!(data[i], data[9 - i], epsilon = 1e-10);
        }
        for &val in &data {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_hamming_window() {
        let window = hamming::<f64>(10).expect("hamming(10) should succeed");
        let data = window.to_vec();
        assert_eq!(data.len(), 10);
        assert!(data[0] > 0.0);
        assert!(data[9] > 0.0);
        for i in 0..5 {
            assert_relative_eq!(data[i], data[9 - i], epsilon = 1e-10);
        }
        for &val in &data {
            assert!(val > 0.0 && val <= 1.0);
        }
        assert_relative_eq!(data[0], 0.08, epsilon = 1e-10);
        assert_relative_eq!(data[9], 0.08, epsilon = 1e-10);
    }

    #[test]
    fn test_kaiser_window() {
        let window = kaiser::<f64>(10, Some(8.6)).expect("kaiser(10, 8.6) should succeed");
        let data = window.to_vec();
        assert_eq!(data.len(), 10);
        for i in 0..5 {
            assert_relative_eq!(data[i], data[9 - i], epsilon = 1e-10);
        }
        for &val in &data {
            assert!(val > 0.0 && val <= 1.0);
        }
        let max_val = data.iter().fold(0.0, |a, &b| a.max(b));
        assert_relative_eq!(max_val, 1.0, epsilon = 1e-1);

        // With an odd length the centre sample is I0(beta) / I0(beta), i.e. exactly 1.
        let centred = kaiser::<f64>(11, Some(8.6))
            .expect("kaiser(11, 8.6) should succeed")
            .to_vec();
        assert_relative_eq!(centred[5], 1.0, epsilon = 1e-12);

        // End points equal 1 / I0(beta), so a larger beta gives smaller (but positive) edges.
        let narrow = kaiser::<f64>(10, Some(14.0))
            .expect("kaiser(10, 14.0) should succeed")
            .to_vec();
        let wide = kaiser::<f64>(10, Some(2.0))
            .expect("kaiser(10, 2.0) should succeed")
            .to_vec();
        assert_eq!(narrow.len(), 10);
        assert_eq!(wide.len(), 10);
        assert!(narrow[0] > 0.0);
        assert!(narrow[0] < wide[0]);
    }

    #[test]
    fn test_window_edge_cases() {
        assert_eq!(
            bartlett::<f64>(0)
                .expect("bartlett(0) should succeed")
                .len(),
            0
        );
        assert_eq!(
            blackman::<f64>(0)
                .expect("blackman(0) should succeed")
                .len(),
            0
        );
        assert_eq!(
            hanning::<f64>(0).expect("hanning(0) should succeed").len(),
            0
        );
        assert_eq!(
            hamming::<f64>(0).expect("hamming(0) should succeed").len(),
            0
        );
        assert_eq!(
            kaiser::<f64>(0, Some(8.6))
                .expect("kaiser(0) should succeed")
                .len(),
            0
        );
        assert_relative_eq!(
            bartlett::<f64>(1)
                .expect("bartlett(1) should succeed")
                .get(&[0])
                .expect("index 0 should be valid"),
            1.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            blackman::<f64>(1)
                .expect("blackman(1) should succeed")
                .get(&[0])
                .expect("index 0 should be valid"),
            1.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            hanning::<f64>(1)
                .expect("hanning(1) should succeed")
                .get(&[0])
                .expect("index 0 should be valid"),
            1.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            hamming::<f64>(1)
                .expect("hamming(1) should succeed")
                .get(&[0])
                .expect("index 0 should be valid"),
            1.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            kaiser::<f64>(1, Some(8.6))
                .expect("kaiser(1) should succeed")
                .get(&[0])
                .expect("index 0 should be valid"),
            1.0,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_modified_bessel_i0() {
        assert_relative_eq!(modified_bessel_i0(0.0), 1.0, epsilon = 1e-10);
        let val = 2.5;
        assert_relative_eq!(
            modified_bessel_i0(val),
            modified_bessel_i0(-val),
            epsilon = 1e-10
        );
        assert!(modified_bessel_i0(1.0) > 1.0);
        assert!(modified_bessel_i0(5.0) > modified_bessel_i0(1.0));

        // Reference values of I0; the tolerance also admits a ~1e-7 polynomial approximation.
        for &(x, expected) in &[
            (1.0, 1.266_065_878),
            (2.0, 2.279_585_302),
            (5.0, 27.239_871_82),
            (10.0, 2_815.716_628),
        ] {
            assert_relative_eq!(modified_bessel_i0(x), expected, max_relative = 1e-5);
        }
    }
}