use crate::error::IntegrateResult;
use scirs2_core::ndarray::{Array2, Array3};
use scirs2_core::numeric::Complex;
use super::fft_operations::FFTOperations;
/// Selects how aliasing errors are treated by the dealiasing operations.
///
/// The derived default is [`DealiasingStrategy::TwoThirds`].
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum DealiasingStrategy {
    /// Leave the field untouched.
    None,
    /// Spectral truncation of high-wavenumber modes (Orszag's 2/3 rule).
    #[default]
    TwoThirds,
    /// Zero-padding based path (the "3/2 rule").
    ThreeHalves,
    /// Phase-shift (shifted-grid averaging) path.
    PhaseShift,
}
/// Namespace for stateless dealiasing routines on real 2-D and 3-D fields.
///
/// All functionality lives in associated functions; the unit value carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct DealiasingOperations;
impl DealiasingOperations {
    /// Applies `strategy` to a 2-D real field and returns the result as a new array.
    ///
    /// [`DealiasingStrategy::None`] returns an unmodified clone of `field`; every other
    /// strategy goes through [`FFTOperations`].
    ///
    /// # Errors
    /// Propagates any error returned by the [`FFTOperations`] transforms.
    pub fn apply_dealiasing_2d(
        field: &Array2<f64>,
        strategy: DealiasingStrategy,
    ) -> IntegrateResult<Array2<f64>> {
        match strategy {
            DealiasingStrategy::None => Ok(field.clone()),
            DealiasingStrategy::TwoThirds => Self::apply_two_thirds_rule_2d(field),
            DealiasingStrategy::ThreeHalves => Self::apply_three_halves_rule_2d(field),
            DealiasingStrategy::PhaseShift => Self::apply_phase_shift_2d(field),
        }
    }

    /// Applies `strategy` to a 3-D real field and returns the result as a new array.
    ///
    /// [`DealiasingStrategy::None`] returns an unmodified clone of `field`; every other
    /// strategy goes through [`FFTOperations`].
    ///
    /// # Errors
    /// Propagates any error returned by the [`FFTOperations`] transforms.
    pub fn apply_dealiasing_3d(
        field: &Array3<f64>,
        strategy: DealiasingStrategy,
    ) -> IntegrateResult<Array3<f64>> {
        match strategy {
            DealiasingStrategy::None => Ok(field.clone()),
            DealiasingStrategy::TwoThirds => Self::apply_two_thirds_rule_3d(field),
            DealiasingStrategy::ThreeHalves => Self::apply_three_halves_rule_3d(field),
            DealiasingStrategy::PhaseShift => Self::apply_phase_shift_3d(field),
        }
    }

    /// Maps an FFT output index on an axis of length `n` to its signed wavenumber.
    ///
    /// Uses the standard unshifted FFT ordering: indices `0..=n/2` are the non-negative
    /// wavenumbers and `n/2 + 1..n` wrap around to the negative ones (`index - n`). For even
    /// `n` the Nyquist index `n/2` is reported as `+n/2`.
    fn signed_wavenumber(index: usize, n: usize) -> f64 {
        if index <= n / 2 {
            index as f64
        } else {
            index as f64 - n as f64
        }
    }

    /// Returns `true` if the mode at FFT `index` on an axis of length `n` survives the 2/3 rule.
    ///
    /// A mode is kept when `|k| < n / 3`, written as `3 * |k| < n` to stay in integer
    /// arithmetic. For kept modes `p`, `q`, a quadratic product term `p + q` can alias only
    /// onto `p + q ∓ n`, whose magnitude is at least `n - 2·|k|max > |k|max`. That lies
    /// outside the kept band, so truncated products are alias-free.
    fn is_retained_two_thirds(index: usize, n: usize) -> bool {
        let magnitude = if index <= n / 2 { index } else { n - index };
        3 * magnitude < n
    }

    /// Orszag 2/3 rule for a 2-D field.
    ///
    /// Zeroes every Fourier coefficient whose wavenumber on either axis satisfies
    /// `3·|k| >= n`, then transforms back to physical space.
    fn apply_two_thirds_rule_2d(field: &Array2<f64>) -> IntegrateResult<Array2<f64>> {
        let mut field_hat = FFTOperations::fft_2d_forward(field)?;
        let (nx, ny) = field_hat.dim();
        // FFT index order is 0, 1, …, n/2, -(n/2 - 1), …, -1. The high wavenumbers therefore
        // sit in the middle of each axis, not at the end, so the mask must be computed from
        // |k| rather than from a raw index range.
        for ((i, j), coefficient) in field_hat.indexed_iter_mut() {
            let retained = Self::is_retained_two_thirds(i, nx) && Self::is_retained_two_thirds(j, ny);
            if !retained {
                *coefficient = Complex::new(0.0, 0.0);
            }
        }
        FFTOperations::fft_2d_backward(&field_hat)
    }

    /// Orszag 2/3 rule for a 3-D field.
    ///
    /// Zeroes every Fourier coefficient whose wavenumber on any axis satisfies `3·|k| >= n`,
    /// then transforms back to physical space.
    fn apply_two_thirds_rule_3d(field: &Array3<f64>) -> IntegrateResult<Array3<f64>> {
        let mut field_hat = FFTOperations::fft_3d_forward(field)?;
        let (nx, ny, nz) = field_hat.dim();
        // One pass with a combined mask, rather than three full sweeps over the spectrum.
        for ((i, j, k), coefficient) in field_hat.indexed_iter_mut() {
            let retained = Self::is_retained_two_thirds(i, nx)
                && Self::is_retained_two_thirds(j, ny)
                && Self::is_retained_two_thirds(k, nz);
            if !retained {
                *coefficient = Complex::new(0.0, 0.0);
            }
        }
        FFTOperations::fft_3d_backward(&field_hat)
    }

    /// "3/2 rule" path for a 2-D field.
    ///
    /// Embeds `field` in the centre of a zero-filled `(3·nx/2) × (3·ny/2)` physical-space
    /// array. It then runs a forward and a backward FFT on that array and crops the centre
    /// back out.
    ///
    /// NOTE(review): no operation happens between the two transforms. Assuming
    /// `FFTOperations`' forward/backward transforms are mutual inverses, the output equals the
    /// input up to round-off. A genuine 3/2 rule pads the *spectra* of the factors of a
    /// nonlinear product before multiplying. That needs both factors, so it cannot be done on a
    /// single field; confirm what callers expect here.
    fn apply_three_halves_rule_2d(field: &Array2<f64>) -> IntegrateResult<Array2<f64>> {
        let (nx, ny) = field.dim();
        let nx_pad = (3 * nx) / 2;
        let ny_pad = (3 * ny) / 2;
        // floor(3n/2) >= n for every n, so these never underflow.
        let start_x = (nx_pad - nx) / 2;
        let start_y = (ny_pad - ny) / 2;

        let mut padded_field = Array2::zeros((nx_pad, ny_pad));
        for ((i, j), &value) in field.indexed_iter() {
            padded_field[[start_x + i, start_y + j]] = value;
        }

        let padded_hat = FFTOperations::fft_2d_forward(&padded_field)?;
        let padded_result = FFTOperations::fft_2d_backward(&padded_hat)?;
        Ok(Array2::from_shape_fn((nx, ny), |(i, j)| {
            padded_result[[start_x + i, start_y + j]]
        }))
    }

    /// "3/2 rule" path for a 3-D field.
    ///
    /// Embeds `field` in the centre of a zero-filled `(3·n/2)`-per-axis physical-space array.
    /// It then runs a forward and a backward FFT on that array and crops the centre back out.
    ///
    /// NOTE(review): as in the 2-D variant, nothing happens between the two transforms. The
    /// result is therefore the input up to round-off, assuming the transforms are mutual
    /// inverses.
    fn apply_three_halves_rule_3d(field: &Array3<f64>) -> IntegrateResult<Array3<f64>> {
        let (nx, ny, nz) = field.dim();
        let nx_pad = (3 * nx) / 2;
        let ny_pad = (3 * ny) / 2;
        let nz_pad = (3 * nz) / 2;
        let start_x = (nx_pad - nx) / 2;
        let start_y = (ny_pad - ny) / 2;
        let start_z = (nz_pad - nz) / 2;

        let mut padded_field = Array3::zeros((nx_pad, ny_pad, nz_pad));
        for ((i, j, k), &value) in field.indexed_iter() {
            padded_field[[start_x + i, start_y + j, start_z + k]] = value;
        }

        let padded_hat = FFTOperations::fft_3d_forward(&padded_field)?;
        let padded_result = FFTOperations::fft_3d_backward(&padded_hat)?;
        Ok(Array3::from_shape_fn((nx, ny, nz), |(i, j, k)| {
            padded_result[[start_x + i, start_y + j, start_z + k]]
        }))
    }

    /// Phase-shift path for a 2-D field.
    ///
    /// Multiplies each Fourier coefficient by `exp(-i·π·(kx/nx + ky/ny))`, i.e. a translation
    /// by half a grid cell along both axes. It then transforms back and returns the pointwise
    /// average `0.5 · (field + shifted_field)`.
    ///
    /// NOTE(review): the shifted copy is never shifted back before averaging. The output is
    /// therefore a smoothed, quarter-cell-translated version of the input rather than the
    /// Patterson–Orszag product scheme. Also, for even `n` the Nyquist coefficient receives a
    /// phase with no conjugate partner. How the resulting imaginary part is handled depends on
    /// `FFTOperations::fft_2d_backward` (not visible here).
    fn apply_phase_shift_2d(field: &Array2<f64>) -> IntegrateResult<Array2<f64>> {
        let (nx, ny) = field.dim();
        let mut shifted_hat = FFTOperations::fft_2d_forward(field)?;
        for ((i, j), coefficient) in shifted_hat.indexed_iter_mut() {
            let kx = Self::signed_wavenumber(i, nx);
            let ky = Self::signed_wavenumber(j, ny);
            let angle = -kx * std::f64::consts::PI / (nx as f64)
                - ky * std::f64::consts::PI / (ny as f64);
            *coefficient *= Complex::new(0.0, angle).exp();
        }
        let shifted_field = FFTOperations::fft_2d_backward(&shifted_hat)?;
        Ok(Array2::from_shape_fn((nx, ny), |(i, j)| {
            0.5 * (field[[i, j]] + shifted_field[[i, j]])
        }))
    }

    /// Phase-shift path for a 3-D field.
    ///
    /// Multiplies each Fourier coefficient by `exp(-i·π·(kx/nx + ky/ny + kz/nz))`, i.e. a
    /// half-cell translation along every axis. It then transforms back and returns
    /// `0.5 · (field + shifted_field)`.
    ///
    /// NOTE(review): the same caveats as the 2-D variant apply (no shift-back before averaging;
    /// unpaired Nyquist phase for even sizes).
    fn apply_phase_shift_3d(field: &Array3<f64>) -> IntegrateResult<Array3<f64>> {
        let (nx, ny, nz) = field.dim();
        let mut shifted_hat = FFTOperations::fft_3d_forward(field)?;
        for ((i, j, k), coefficient) in shifted_hat.indexed_iter_mut() {
            let kx = Self::signed_wavenumber(i, nx);
            let ky = Self::signed_wavenumber(j, ny);
            let kz = Self::signed_wavenumber(k, nz);
            let angle = -kx * std::f64::consts::PI / (nx as f64)
                - ky * std::f64::consts::PI / (ny as f64)
                - kz * std::f64::consts::PI / (nz as f64);
            *coefficient *= Complex::new(0.0, angle).exp();
        }
        let shifted_field = FFTOperations::fft_3d_backward(&shifted_hat)?;
        Ok(Array3::from_shape_fn((nx, ny, nz), |(i, j, k)| {
            0.5 * (field[[i, j, k]] + shifted_field[[i, j, k]])
        }))
    }

    /// Reports whether the 2/3 rule would discard a significant share of `field`'s energy.
    ///
    /// Sums `|ĉ|²` over all Fourier coefficients and over those the 2/3 rule would zero
    /// (`3·|k| >= n` on either axis). Returns `true` when the discarded fraction is strictly
    /// greater than `threshold`. A field with zero spectral energy reports `false`. The
    /// verdict uses only a ratio, so it does not depend on the field's amplitude.
    ///
    /// # Errors
    /// Propagates any error returned by [`FFTOperations::fft_2d_forward`].
    pub fn needs_dealiasing(field: &Array2<f64>, threshold: f64) -> IntegrateResult<bool> {
        let field_hat = FFTOperations::fft_2d_forward(field)?;
        let (nx, ny) = field_hat.dim();
        let mut high_freq_energy = 0.0;
        let mut total_energy = 0.0;
        for ((i, j), coefficient) in field_hat.indexed_iter() {
            let energy = coefficient.norm_sqr();
            total_energy += energy;
            let retained = Self::is_retained_two_thirds(i, nx) && Self::is_retained_two_thirds(j, ny);
            if !retained {
                high_freq_energy += energy;
            }
        }
        // Guard only against division by zero; an absolute floor would make the verdict
        // depend on the field's amplitude and on the FFT normalisation.
        let high_freq_ratio = if total_energy > 0.0 {
            high_freq_energy / total_energy
        } else {
            0.0
        };
        Ok(high_freq_ratio > threshold)
    }

    /// Heuristic choice of dealiasing strategy.
    ///
    /// The first matching rule wins:
    /// 1. `reynolds_number < 100` → [`DealiasingStrategy::None`]
    /// 2. `accuracy_requirement >= 0.99` and the smaller grid dimension is at least 128 →
    ///    [`DealiasingStrategy::ThreeHalves`]
    /// 3. `accuracy_requirement > 0.95` → [`DealiasingStrategy::TwoThirds`]
    /// 4. smaller grid dimension below 64 → [`DealiasingStrategy::PhaseShift`]
    /// 5. otherwise → [`DealiasingStrategy::TwoThirds`]
    pub fn recommend_strategy(
        field_size: (usize, usize),
        reynolds_number: f64,
        accuracy_requirement: f64,
    ) -> DealiasingStrategy {
        let (nx, ny) = field_size;
        let min_size = nx.min(ny);
        if reynolds_number < 100.0 {
            return DealiasingStrategy::None;
        }
        if accuracy_requirement >= 0.99 && min_size >= 128 {
            return DealiasingStrategy::ThreeHalves;
        }
        if accuracy_requirement > 0.95 {
            return DealiasingStrategy::TwoThirds;
        }
        if min_size < 64 {
            return DealiasingStrategy::PhaseShift;
        }
        DealiasingStrategy::TwoThirds
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Evaluates `f(x, y)` on an `nx × ny` periodic grid spanning `[0, 2π)` on each axis.
    fn sample_periodic(nx: usize, ny: usize, f: impl Fn(f64, f64) -> f64) -> Array2<f64> {
        Array2::from_shape_fn((nx, ny), |(i, j)| {
            let x = i as f64 * 2.0 * std::f64::consts::PI / nx as f64;
            let y = j as f64 * 2.0 * std::f64::consts::PI / ny as f64;
            f(x, y)
        })
    }

    /// Largest absolute entry of `values` (0.0 for an empty array).
    fn max_abs(values: &Array2<f64>) -> f64 {
        values.iter().map(|v| v.abs()).fold(0.0f64, f64::max)
    }

    #[test]
    fn test_dealiasing_strategy_default() {
        assert_eq!(DealiasingStrategy::default(), DealiasingStrategy::TwoThirds);
    }

    #[test]
    fn test_two_thirds_rule_2d() {
        let field = sample_periodic(16, 16, |x, y| (8.0 * x).sin() + (8.0 * y).cos());
        let dealiased =
            DealiasingOperations::apply_dealiasing_2d(&field, DealiasingStrategy::TwoThirds)
                .expect("Operation failed");
        assert_eq!(dealiased.dim(), field.dim());
        assert!(max_abs(&dealiased) <= max_abs(&field));
    }

    #[test]
    fn test_needs_dealiasing() {
        let oscillatory = sample_periodic(16, 16, |x, y| (6.0 * x).sin() + (6.0 * y).cos());
        let flagged =
            DealiasingOperations::needs_dealiasing(&oscillatory, 0.5).expect("Operation failed");
        assert!(flagged);

        let smooth = sample_periodic(16, 16, |x, y| x.sin() + y.cos());
        let flagged_smooth =
            DealiasingOperations::needs_dealiasing(&smooth, 0.5).expect("Operation failed");
        assert!(!flagged_smooth);
    }

    #[test]
    fn test_recommend_strategy() {
        let cases = [
            ((64, 64), 50.0, 0.95, DealiasingStrategy::None),
            ((256, 256), 1000.0, 0.99, DealiasingStrategy::ThreeHalves),
            ((128, 128), 500.0, 0.96, DealiasingStrategy::TwoThirds),
            ((32, 32), 200.0, 0.90, DealiasingStrategy::PhaseShift),
        ];
        for &(size, reynolds, accuracy, expected) in cases.iter() {
            let chosen = DealiasingOperations::recommend_strategy(size, reynolds, accuracy);
            assert_eq!(chosen, expected);
        }
    }

    #[test]
    fn test_no_dealiasing() {
        let field = Array2::<f64>::ones((8, 8));
        let untouched = DealiasingOperations::apply_dealiasing_2d(&field, DealiasingStrategy::None)
            .expect("Operation failed");
        assert_eq!(untouched, field);
    }
}