use scirs2_core::ndarray::Array2;
use scirs2_core::Complex64;
use scirs2_fft::memory_efficient::{fft2_efficient, fft_inplace, FftMode};
use scirs2_fft::{fft, fft2, frft, rfft, PlanCache};
use std::f64::consts::PI;
use std::time::{Duration, Instant};
/// One row of profiling output: what was run, on how many elements, how long
/// it took, and the heuristic memory figure attached to it.
#[derive(Debug, Clone)]
pub struct MemoryProfile {
    /// Label identifying the profiled operation (e.g. `"fft-standard"`).
    pub operation: String,
    /// Number of input elements the operation was run on.
    pub inputsize: usize,
    /// Wall-clock time recorded for the profiled call.
    pub elapsed_time: Duration,
    /// Estimated memory footprint in mebibytes.
    pub estimated_memory_mb: f64,
}
/// Runs `f` once, times it, and attaches a heuristic memory estimate.
///
/// The memory figure is not measured. It is computed as
/// `size * element_size * multiplier / 2^20`, where the element size and the
/// multiplier are fixed constants chosen by the `operation` label. Labels that
/// are not in the table below yield an estimate of `0.0`.
///
/// # Parameters
/// * `operation` - label stored in the profile and used to pick the estimate.
/// * `size` - number of input elements; stored as `inputsize`.
/// * `f` - the work to time; it is invoked exactly once.
///
/// # Returns
/// A [`MemoryProfile`] holding the label, the size, the elapsed wall-clock
/// time of `f`, and the estimate in mebibytes.
#[allow(dead_code)]
fn profile_memory<F: FnOnce() -> R, R>(operation: &str, size: usize, f: F) -> MemoryProfile {
    const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;

    let start = Instant::now();
    // `black_box` keeps the optimizer from discarding work whose result is
    // otherwise unused, which would make the timing meaningless.
    let output = std::hint::black_box(f());
    let elapsed = start.elapsed();
    // Release the result only after the clock has stopped so that freeing the
    // output buffers is not billed to the operation being profiled.
    drop(output);

    let complex_bytes = std::mem::size_of::<Complex64>();
    let real_bytes = std::mem::size_of::<f64>();

    // (bytes per element, working-set multiplier) for each known operation.
    let footprint = match operation {
        "fft-standard" => Some((complex_bytes, 3.0)),
        "fft-efficient" => Some((complex_bytes, 1.75)),
        "fft-planned" => Some((complex_bytes, 2.2)),
        "rfft" => Some((real_bytes, 2.0)),
        "frft" => Some((complex_bytes, 4.0)),
        "fft2-standard" => Some((complex_bytes, 3.5)),
        "fft2-efficient" => Some((complex_bytes, 2.0)),
        "fft2-planned" => Some((complex_bytes, 2.3)),
        _ => None,
    };
    let estimated_memory_mb = footprint.map_or(0.0, |(element_bytes, multiplier)| {
        (size as f64 * element_bytes as f64 * multiplier) / BYTES_PER_MIB
    });

    MemoryProfile {
        operation: operation.to_string(),
        inputsize: size,
        elapsed_time: elapsed,
        estimated_memory_mb,
    }
}
/// Profiles the 1D transform variants over a set of signal lengths.
///
/// For each length in 64, 256, 1024, 4096 and 16384 a real sine signal and a
/// complex copy of it (zero imaginary part) are built, and five profiles are
/// appended in this order: `"fft-standard"`, `"fft-efficient"`,
/// `"fft-planned"`, `"rfft"`, `"frft"`.
///
/// Returns all profiles in the order they were recorded.
#[allow(dead_code)]
pub fn profile_fft_1d() -> Vec<MemoryProfile> {
    const LENGTHS: [usize; 5] = [64, 256, 1024, 4096, 16384];

    let mut results = Vec::new();
    // NOTE(review): this cache is constructed but never passed to any call
    // below; the "fft-planned" run issues the same `fft` call as the
    // "fft-standard" run. Confirm whether a planned entry point was intended.
    let _plan_cache = PlanCache::new();

    for &n in LENGTHS.iter() {
        // Ten sine periods spread across the n samples.
        let real_samples: Vec<f64> = (0..n)
            .map(|k| (2.0 * PI * 10.0 * k as f64 / n as f64).sin())
            .collect();
        let complex_samples: Vec<Complex64> = real_samples
            .iter()
            .map(|&re| Complex64::new(re, 0.0))
            .collect();

        results.push(profile_memory("fft-standard", n, || {
            fft(&complex_samples, None)
        }));

        // The scratch buffers are allocated inside the closure, so their setup
        // cost is part of the recorded time.
        results.push(profile_memory("fft-efficient", n, || {
            let mut scratch_in = complex_samples.clone();
            let mut scratch_out = vec![Complex64::new(0.0, 0.0); n];
            fft_inplace(&mut scratch_in, &mut scratch_out, FftMode::Forward, true)
        }));

        results.push(profile_memory("fft-planned", n, || {
            fft(&complex_samples, None)
        }));

        results.push(profile_memory("rfft", n, || rfft(&real_samples, None)));

        results.push(profile_memory("frft", n, || {
            frft(&real_samples, 0.5, None)
        }));
    }

    results
}
/// Profiles the 2D transform variants over a set of square grid sizes.
///
/// For each side length in 16, 32, 64 and 128 a `side x side` complex grid is
/// filled with a sine pattern (zero imaginary part), and three profiles are
/// appended in this order: `"fft2-standard"`, `"fft2-efficient"`,
/// `"fft2-planned"`. The size recorded in each profile is `side * side`.
///
/// Returns all profiles in the order they were recorded.
#[allow(dead_code)]
pub fn profile_fft_2d() -> Vec<MemoryProfile> {
    const SIDES: [usize; 4] = [16, 32, 64, 128];

    let mut results = Vec::new();
    // NOTE(review): this cache is constructed but never passed to any call
    // below; the "fft2-planned" run issues the same `fft2` call as the
    // "fft2-standard" run. Confirm whether a planned entry point was intended.
    let _plan_cache = PlanCache::new();

    for &side in SIDES.iter() {
        let grid = Array2::from_shape_fn((side, side), |(row, col)| {
            let x = row as f64 / side as f64;
            let y = col as f64 / side as f64;
            Complex64::new((2.0 * PI * (5.0 * x + 3.0 * y)).sin(), 0.0)
        });
        // Total number of grid elements, reported as the input size.
        let cells = side * side;

        results.push(profile_memory("fft2-standard", cells, || {
            fft2(&grid, None, None, None)
        }));

        results.push(profile_memory("fft2-efficient", cells, || {
            let grid_view = grid.view();
            fft2_efficient(&grid_view, None, FftMode::Forward, true)
        }));

        results.push(profile_memory("fft2-planned", cells, || {
            fft2(&grid, None, None, None)
        }));
    }

    results
}
/// Prints `profiles` to standard output as an aligned text table.
///
/// The output is a title line, a header row, a dashed separator as wide as
/// the header, and one row per profile. The header and every row share one
/// set of column widths; the operation column is widened to fit the longest
/// label present, so long labels no longer push the other columns out of line.
///
/// # Parameters
/// * `profiles` - the records to print, in the order given. An empty slice
///   prints only the title, header and separator.
#[allow(dead_code)]
pub fn generate_memory_report(profiles: &[MemoryProfile]) {
    const OPERATION_HEADING: &str = "Operation";

    // Size the first column from the data so labels longer than the heading
    // (e.g. "fft2-efficient") still line up with the remaining columns.
    let operation_width = profiles
        .iter()
        .map(|p| p.operation.chars().count())
        .max()
        .unwrap_or(0)
        .max(OPERATION_HEADING.len());

    let header = format!(
        "{:<w$} | {:>8} | {:>16} | {:>9}",
        OPERATION_HEADING,
        "Size",
        "Est. Memory (MB)",
        "Time (ms)",
        w = operation_width
    );

    println!("=== Memory Usage Report ===");
    println!("{}", header);
    println!("{}", "-".repeat(header.chars().count()));

    for profile in profiles {
        println!(
            "{:<w$} | {:>8} | {:>16.2} | {:>9.2}",
            profile.operation,
            profile.inputsize,
            profile.estimated_memory_mb,
            profile.elapsed_time.as_secs_f64() * 1000.0,
            w = operation_width
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs both profiling suites, prints the smaller cases, and checks that
    /// every retained 1D profile has a positive memory estimate and a non-zero
    /// elapsed time.
    #[test]
    fn test_memory_profiling() {
        // Run both suites up front, then keep only the smaller cases.
        let mut small_1d = profile_fft_1d();
        let mut small_2d = profile_fft_2d();
        small_1d.retain(|p| p.inputsize <= 1024);
        small_2d.retain(|p| p.inputsize <= 64 * 64);

        println!("\n1D FFT Memory Profiling:");
        generate_memory_report(&small_1d);
        println!("\n2D FFT Memory Profiling:");
        generate_memory_report(&small_2d);

        assert!(!small_1d.is_empty());
        assert!(!small_2d.is_empty());

        for profile in small_1d.iter() {
            assert!(profile.estimated_memory_mb > 0.0);
            assert!(profile.elapsed_time.as_secs_f64() > 0.0);
        }
    }
}
/// Entry point: runs the 1D and 2D profiling suites, then prints one report
/// for each under its own heading.
#[allow(dead_code)]
fn main() {
    // Both suites finish before anything is printed.
    let sections = [
        ("1D FFT Memory Profiling:", profile_fft_1d()),
        ("\n2D FFT Memory Profiling:", profile_fft_2d()),
    ];

    for (heading, profiles) in &sections {
        println!("{}", heading);
        generate_memory_report(profiles);
    }
}