#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
use super::window::WindowFunction;
use super::RingBuffer;
/// Incremental short-time Fourier analyzer.
///
/// Samples are supplied through `feed`; whenever `fft_size` samples are buffered, a windowed
/// forward spectrum is queued and the buffer is advanced by `hop_size` samples. Queued spectra
/// are retrieved, oldest first, with `pop_frame`.
pub struct StreamingFft<T: Float> {
    /// Number of samples per analysis frame, which is also the transform length.
    fft_size: usize,
    /// Number of samples the input buffer is advanced by after each analyzed frame.
    hop_size: usize,
    /// Window coefficients generated for `fft_size`; applied in both analysis and synthesis.
    window: Vec<T>,
    /// Forward transform plan; `None` if plan creation failed, in which case no spectra are queued.
    forward_plan: Option<Plan<T>>,
    /// Backward transform plan used by `synthesize_frame`; `None` if plan creation failed.
    inverse_plan: Option<Plan<T>>,
    /// Buffered input samples awaiting analysis, constructed with `RingBuffer::new(fft_size)`.
    input_buffer: RingBuffer<T>,
    /// Allocated as `2 * fft_size` zeros and only reset by `clear` in this file.
    /// NOTE(review): presumably reserved for streaming overlap-add output — TODO confirm.
    output_buffer: Vec<T>,
    /// Only reset by `clear` in this file; presumably a position into `output_buffer`.
    output_pos: usize,
    /// Spectra produced by `feed` and not yet returned by `pop_frame`, oldest first.
    pending_frames: Vec<Vec<Complex<T>>>,
}
impl<T: Float> StreamingFft<T> {
    /// Creates a streaming analyzer that queues one spectrum per `hop_size` input samples once
    /// `fft_size` samples have been buffered.
    ///
    /// The window coefficients and both transform plans are built here, once. If a plan cannot
    /// be created it is stored as `None`. In that case [`feed`](Self::feed) queues no spectra, and
    /// [`analyze_frame`](Self::analyze_frame) / [`synthesize_frame`](Self::synthesize_frame)
    /// return all-zero output.
    ///
    /// If `fft_size` or `hop_size` is zero, [`feed`](Self::feed) does nothing and returns 0.
    pub fn new(fft_size: usize, hop_size: usize, window: WindowFunction) -> Self {
        let window_coeffs = window.generate(fft_size);
        let forward_plan = Plan::dft_1d(fft_size, Direction::Forward, Flags::ESTIMATE);
        let inverse_plan = Plan::dft_1d(fft_size, Direction::Backward, Flags::ESTIMATE);
        Self {
            fft_size,
            hop_size,
            window: window_coeffs,
            forward_plan,
            inverse_plan,
            input_buffer: RingBuffer::new(fft_size),
            output_buffer: vec![T::ZERO; fft_size * 2],
            output_pos: 0,
            pending_frames: Vec::new(),
        }
    }

    /// Appends `samples` to the stream and queues the spectrum of every frame that completes.
    ///
    /// Input is copied into the ring buffer at most `fft_size - buffered` samples at a time.
    /// Complete frames are analyzed between chunks, so a slice of any length produces the same
    /// frames as feeding it piecewise. Retrieve queued spectra with [`pop_frame`](Self::pop_frame).
    ///
    /// Returns the number of spectra queued by this call. It returns 0 without consuming input
    /// when `fft_size` or `hop_size` is zero.
    ///
    /// NOTE(review): when `hop_size > fft_size` the result depends on how
    /// `RingBuffer::advance` treats advancing past its contents — TODO confirm.
    pub fn feed(&mut self, samples: &[T]) -> usize {
        // A zero hop never shrinks the buffer, and a zero frame size always counts as "full".
        // Either would make the drain loop below spin forever.
        if self.fft_size == 0 || self.hop_size == 0 {
            return 0;
        }
        let mut frames_processed = 0;
        let mut remaining = samples;
        while !remaining.is_empty() {
            // Never buffer more than one frame. `process_frame` reads via `read_last`, so any
            // surplus would be overwritten or read out of order before it was analyzed.
            let free = self.fft_size.saturating_sub(self.input_buffer.len());
            let (chunk, rest) = remaining.split_at(free.min(remaining.len()));
            self.input_buffer.push_slice(chunk);
            remaining = rest;
            while self.input_buffer.len() >= self.fft_size {
                if let Some(frame) = self.process_frame() {
                    self.pending_frames.push(frame);
                    frames_processed += 1;
                }
                self.input_buffer.advance(self.hop_size);
            }
        }
        frames_processed
    }

    /// Removes and returns the oldest queued spectrum, or `None` if the queue is empty.
    ///
    /// `Vec::remove(0)` shifts the remaining entries, which is fine for the short queues that
    /// arise when frames are popped regularly.
    pub fn pop_frame(&mut self) -> Option<Vec<Complex<T>>> {
        if self.pending_frames.is_empty() {
            None
        } else {
            Some(self.pending_frames.remove(0))
        }
    }

    /// Windows `samples` and runs the forward transform.
    ///
    /// Returns `None` when no forward plan exists. Pairs `samples` with the window via `zip`, so
    /// callers are expected to pass exactly `fft_size` samples.
    fn windowed_spectrum(&self, samples: &[T]) -> Option<Vec<Complex<T>>> {
        let plan = self.forward_plan.as_ref()?;
        let input: Vec<Complex<T>> = samples
            .iter()
            .zip(self.window.iter())
            .map(|(&s, &w)| Complex::new(s * w, T::ZERO))
            .collect();
        let mut output = vec![Complex::<T>::zero(); self.fft_size];
        plan.execute(&input, &mut output);
        Some(output)
    }

    /// Analyzes the `fft_size` samples produced by `read_last` on the input buffer.
    ///
    /// Presumably these are the most recent `fft_size` samples — TODO confirm against
    /// `RingBuffer`. Returns `None` when no forward plan exists.
    fn process_frame(&self) -> Option<Vec<Complex<T>>> {
        let mut frame = vec![T::ZERO; self.fft_size];
        self.input_buffer.read_last(&mut frame);
        self.windowed_spectrum(&frame)
    }

    /// Returns the windowed forward spectrum of a single frame.
    ///
    /// Returns `fft_size` zero bins if `frame.len() != fft_size` or no forward plan exists.
    pub fn analyze_frame(&self, frame: &[T]) -> Vec<Complex<T>> {
        if frame.len() != self.fft_size {
            return vec![Complex::<T>::zero(); self.fft_size];
        }
        self.windowed_spectrum(frame)
            .unwrap_or_else(|| vec![Complex::<T>::zero(); self.fft_size])
    }

    /// Inverse-transforms `spectrum` and applies the window as a synthesis window.
    ///
    /// Each output sample is `re(ifft[i]) / fft_size * window[i]`. Assuming the backward plan is
    /// unnormalized (as the `1 / fft_size` scale implies), feeding this the output of
    /// [`analyze_frame`](Self::analyze_frame) yields `x[i] * window[i]^2`.
    ///
    /// Returns `fft_size` zeros if `spectrum.len() != fft_size` or no backward plan exists.
    pub fn synthesize_frame(&self, spectrum: &[Complex<T>]) -> Vec<T> {
        if spectrum.len() != self.fft_size {
            return vec![T::ZERO; self.fft_size];
        }
        let plan = match &self.inverse_plan {
            Some(p) => p,
            None => return vec![T::ZERO; self.fft_size],
        };
        let mut output = vec![Complex::<T>::zero(); self.fft_size];
        plan.execute(spectrum, &mut output);
        let scale = T::ONE / T::from_usize(self.fft_size);
        output
            .iter()
            .zip(self.window.iter())
            .map(|(c, &w)| c.re * scale * w)
            .collect()
    }

    /// Frame length (transform size) in samples.
    pub fn fft_size(&self) -> usize {
        self.fft_size
    }

    /// Number of samples between the starts of consecutive frames.
    pub fn hop_size(&self) -> usize {
        self.hop_size
    }

    /// Window coefficients used for both analysis and synthesis.
    pub fn window(&self) -> &[T] {
        &self.window
    }

    /// Discards buffered input and queued spectra, and zeroes the output state.
    pub fn clear(&mut self) {
        self.input_buffer.clear();
        self.pending_frames.clear();
        self.output_buffer.fill(T::ZERO);
        self.output_pos = 0;
    }
}
/// Computes the short-time Fourier transform of `signal`.
///
/// Frames of `fft_size` samples start every `hop_size` samples from index 0. Only frames lying
/// entirely inside `signal` are produced (no padding), giving
/// `(signal.len() - fft_size) / hop_size + 1` frames. Each frame is multiplied by the window before
/// the forward transform, and every returned spectrum has `fft_size` bins.
///
/// Returns an empty vector when `fft_size` or `hop_size` is zero, when `signal` is shorter than
/// one frame, or when no forward plan can be created.
pub fn stft<T: Float>(
    signal: &[T],
    fft_size: usize,
    hop_size: usize,
    window: WindowFunction,
) -> Vec<Vec<Complex<T>>> {
    let degenerate = fft_size == 0 || hop_size == 0 || signal.len() < fft_size;
    if degenerate {
        return Vec::new();
    }
    let coeffs: Vec<T> = window.generate(fft_size);
    let Some(plan) = Plan::dft_1d(fft_size, Direction::Forward, Flags::ESTIMATE) else {
        return Vec::new();
    };
    // Scratch buffer for the windowed input, reused across frames.
    let mut windowed: Vec<Complex<T>> = Vec::with_capacity(fft_size);
    signal
        .windows(fft_size)
        .step_by(hop_size)
        .map(|segment| {
            windowed.clear();
            windowed.extend(
                segment
                    .iter()
                    .zip(&coeffs)
                    .map(|(&s, &w)| Complex::new(s * w, T::ZERO)),
            );
            let mut spectrum = vec![Complex::<T>::zero(); fft_size];
            plan.execute(&windowed, &mut spectrum);
            spectrum
        })
        .collect()
}
/// Inverts an STFT by weighted overlap-add.
///
/// The frame length is taken from the first spectrum. Every spectrum of that length is
/// inverse-transformed, scaled by `1 / fft_size`, and multiplied by the window. The result is
/// accumulated at offset `frame_index * hop_size`, while the squared window is accumulated
/// alongside it. Samples whose accumulated squared-window weight exceeds `1e-10` are then divided
/// by that weight. Spectra of any other length are skipped.
///
/// The output has `fft_size + (frames - 1) * hop_size` samples. It is empty when `spectrogram`
/// is empty, `hop_size` is zero, the first spectrum is empty, or no backward plan can be created.
pub fn istft<T: Float>(
    spectrogram: &[Vec<Complex<T>>],
    hop_size: usize,
    window: WindowFunction,
) -> Vec<T> {
    if hop_size == 0 {
        return Vec::new();
    }
    let Some(first) = spectrogram.first() else {
        return Vec::new();
    };
    let fft_size = first.len();
    if fft_size == 0 {
        return Vec::new();
    }
    let coeffs: Vec<T> = window.generate(fft_size);
    let Some(plan) = Plan::dft_1d(fft_size, Direction::Backward, Flags::ESTIMATE) else {
        return Vec::new();
    };
    let total = fft_size + (spectrogram.len() - 1) * hop_size;
    let mut signal = vec![T::ZERO; total];
    let mut weight = vec![T::ZERO; total];
    let scale = T::ONE / T::from_usize(fft_size);
    for (offset, spectrum) in spectrogram
        .iter()
        .enumerate()
        .map(|(k, s)| (k * hop_size, s))
    {
        if spectrum.len() != fft_size {
            continue;
        }
        let mut frame = vec![Complex::<T>::zero(); fft_size];
        plan.execute(spectrum, &mut frame);
        let out = &mut signal[offset..offset + fft_size];
        let acc = &mut weight[offset..offset + fft_size];
        for (((o, a), c), &w) in out
            .iter_mut()
            .zip(acc.iter_mut())
            .zip(&frame)
            .zip(&coeffs[..fft_size])
        {
            *o = *o + c.re * scale * w;
            *a = *a + w * w;
        }
    }
    let floor = T::from_f64(1e-10);
    for (o, &a) in signal.iter_mut().zip(&weight) {
        if a > floor {
            *o = *o / a;
        }
    }
    signal
}
/// Computes an STFT laid out for overlap-save block processing.
///
/// The signal is conceptually left-padded with `fft_size - hop_size` zeros. Frames of `fft_size`
/// samples are then taken every `hop_size` samples. Frame `k` therefore ends with the `hop_size`
/// "new" samples `signal[k * hop_size..(k + 1) * hop_size]` (zero past the end), preceded by the
/// `fft_size - hop_size` samples before them. Each frame is windowed before the forward transform.
///
/// The frame count is
/// `max((fft_size - hop_size + signal.len()) / hop_size, ceil(signal.len() / hop_size))`.
/// This guarantees every input sample falls in the new-sample region of some frame.
///
/// Returns an empty vector if `fft_size` or `hop_size` is zero, `hop_size >= fft_size`, `signal`
/// is empty, or no forward plan can be created.
pub fn stft_overlap_save<T: Float>(
    signal: &[T],
    fft_size: usize,
    hop_size: usize,
    window: WindowFunction,
) -> Vec<Vec<Complex<T>>> {
    if fft_size == 0 || hop_size == 0 || hop_size >= fft_size || signal.is_empty() {
        return Vec::new();
    }
    let window_coeffs: Vec<T> = window.generate(fft_size);
    let plan = match Plan::dft_1d(fft_size, Direction::Forward, Flags::ESTIMATE) {
        Some(p) => p,
        None => return Vec::new(),
    };
    let n_overlap = fft_size - hop_size;
    // The padded-length count alone falls short whenever `n_overlap < hop_size - 1`.
    // For example, fft=64, hop=48, len=100 gives 2 frames and never sees samples 96..100,
    // and len=10 gives 0 frames. The ceil term restores full coverage.
    let padded_frames = (n_overlap + signal.len()) / hop_size;
    let num_frames = padded_frames.max(signal.len().div_ceil(hop_size));
    let mut spectrogram = Vec::with_capacity(num_frames);
    let mut frame_buf = vec![Complex::<T>::zero(); fft_size];
    for frame_idx in 0..num_frames {
        let buf_start = frame_idx * hop_size;
        for (i, (slot, &w)) in frame_buf
            .iter_mut()
            .zip(&window_coeffs[..fft_size])
            .enumerate()
        {
            // Position in the padded stream; the first `n_overlap` positions are zero padding,
            // as is anything past the end of `signal`.
            let sample = (buf_start + i)
                .checked_sub(n_overlap)
                .and_then(|pos| signal.get(pos).copied())
                .unwrap_or(T::ZERO);
            *slot = Complex::new(sample * w, T::ZERO);
        }
        let mut spectrum = vec![Complex::<T>::zero(); fft_size];
        plan.execute(&frame_buf, &mut spectrum);
        spectrogram.push(spectrum);
    }
    spectrogram
}
/// Reconstructs a signal from spectra laid out like those of `stft_overlap_save`.
///
/// Each spectrum is inverse-transformed and scaled by `1 / fft_size`, and only its final
/// `hop_size` samples are kept. Each kept sample is divided by the matching window coefficient
/// to undo the analysis window. Coefficients whose magnitude is at most `1e-12` are not divided
/// by, so those samples are returned as the raw scaled value. Spectra whose length is not
/// `fft_size` contribute `hop_size` zeros. The output always has `spectra.len() * hop_size` samples.
///
/// NOTE(review): a window that is exactly zero inside the kept tail (e.g. a symmetric window
/// ending in 0) cannot be undone there — TODO confirm which `WindowFunction`s this affects.
///
/// Returns an empty vector if `spectra` is empty, `fft_size` or `hop_size` is zero,
/// `hop_size >= fft_size`, or no backward plan can be created.
pub fn istft_overlap_save<T: Float>(
    spectra: &[Vec<Complex<T>>],
    fft_size: usize,
    hop_size: usize,
    window: WindowFunction,
) -> Vec<T> {
    let unusable = spectra.is_empty() || fft_size == 0 || hop_size == 0 || hop_size >= fft_size;
    if unusable {
        return Vec::new();
    }
    let window_coeffs: Vec<T> = window.generate(fft_size);
    let Some(plan) = Plan::dft_1d(fft_size, Direction::Backward, Flags::ESTIMATE) else {
        return Vec::new();
    };
    let keep_from = fft_size - hop_size;
    let scale = T::ONE / T::from_usize(fft_size);
    let eps = T::from_f64(1e-12);
    let neg_eps = T::ZERO - eps;
    let mut signal = Vec::with_capacity(spectra.len() * hop_size);
    let mut time_domain = vec![Complex::<T>::zero(); fft_size];
    for spectrum in spectra {
        if spectrum.len() == fft_size {
            plan.execute(spectrum, &mut time_domain);
            let tail_window = &window_coeffs[keep_from..fft_size];
            signal.extend(
                time_domain[keep_from..]
                    .iter()
                    .zip(tail_window)
                    .map(|(c, &w)| {
                        let value = c.re * scale;
                        if w > eps || w < neg_eps {
                            value / w
                        } else {
                            value
                        }
                    }),
            );
        } else {
            signal.resize(signal.len() + hop_size, T::ZERO);
        }
    }
    signal
}
/// Maps each complex bin to its `Complex::norm`, keeping one row per frame.
pub fn magnitude_spectrogram<T: Float>(spectrogram: &[Vec<Complex<T>>]) -> Vec<Vec<T>> {
    let mut magnitudes = Vec::with_capacity(spectrogram.len());
    for frame in spectrogram {
        let mut row = Vec::with_capacity(frame.len());
        for bin in frame {
            row.push(bin.norm());
        }
        magnitudes.push(row);
    }
    magnitudes
}
/// Maps each complex bin to its `Complex::norm_sqr` (squared magnitude), one row per frame.
pub fn power_spectrogram<T: Float>(spectrogram: &[Vec<Complex<T>>]) -> Vec<Vec<T>> {
    spectrogram
        .iter()
        .map(|frame| {
            let mut row = Vec::with_capacity(frame.len());
            row.extend(frame.iter().map(|bin| bin.norm_sqr()));
            row
        })
        .collect()
}
/// Maps each complex bin to its phase angle `atan2(im, re)`, one row per frame.
pub fn phase_spectrogram<T: Float>(spectrogram: &[Vec<Complex<T>>]) -> Vec<Vec<T>> {
    let mut phases: Vec<Vec<T>> = Vec::with_capacity(spectrogram.len());
    for frame in spectrogram {
        let row: Vec<T> = frame.iter().map(|bin| bin.im.atan2(bin.re)).collect();
        phases.push(row);
    }
    phases
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Index of the largest-magnitude bin in `frame` (0 for an empty frame).
    fn peak_bin(frame: &[Complex<f64>]) -> usize {
        frame
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                a.norm()
                    .partial_cmp(&b.norm())
                    .unwrap_or(core::cmp::Ordering::Equal)
            })
            .map(|(k, _)| k)
            .unwrap_or(0)
    }

    #[test]
    fn test_streaming_fft_basic() {
        let mut processor: StreamingFft<f64> = StreamingFft::new(8, 4, WindowFunction::Hann);
        let samples = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let frames = processor.feed(&samples);
        assert!(frames > 0);
        // Every spectrum counted by `feed` must be retrievable, each with `fft_size` bins.
        let mut popped = 0;
        while let Some(spectrum) = processor.pop_frame() {
            assert_eq!(spectrum.len(), 8);
            popped += 1;
        }
        assert_eq!(popped, frames);
    }

    #[test]
    fn test_streaming_fft_clear_drops_pending_frames() {
        let mut processor: StreamingFft<f64> = StreamingFft::new(8, 4, WindowFunction::Hann);
        let samples = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert!(processor.feed(&samples) > 0);
        processor.clear();
        assert!(processor.pop_frame().is_none());
    }

    #[test]
    fn test_streaming_fft_analyze_synthesize() {
        let processor: StreamingFft<f64> = StreamingFft::new(8, 4, WindowFunction::Hann);
        let frame = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let spectrum = processor.analyze_frame(&frame);
        let reconstructed = processor.synthesize_frame(&spectrum);
        assert_eq!(spectrum.len(), 8);
        assert_eq!(reconstructed.len(), 8);
        // Analysis windows once and synthesis windows again, so the round trip yields x * w^2.
        for (i, ((&r, &x), &w)) in reconstructed
            .iter()
            .zip(&frame)
            .zip(processor.window())
            .enumerate()
        {
            let expected = x * w * w;
            assert!(
                (r - expected).abs() < 1e-9,
                "Mismatch at {i}: {r} vs {expected}"
            );
        }
    }

    #[test]
    fn test_stft_basic() {
        let signal: Vec<f64> = vec![0.0; 128];
        let spectrogram = stft(&signal, 32, 16, WindowFunction::Hann);
        assert_eq!(spectrogram.len(), 7);
        assert_eq!(spectrogram[0].len(), 32);
    }

    #[test]
    fn test_stft_degenerate_inputs_are_empty() {
        let signal = vec![0.0f64; 10];
        assert!(stft(&signal, 32, 16, WindowFunction::Hann).is_empty());
        assert!(stft(&signal, 0, 16, WindowFunction::Hann).is_empty());
        assert!(stft(&signal, 8, 0, WindowFunction::Hann).is_empty());
    }

    #[test]
    fn test_istft_degenerate_inputs_are_empty() {
        assert!(istft::<f64>(&[], 16, WindowFunction::Hann).is_empty());
        let spectrogram = stft(&[1.0f64; 64], 16, 8, WindowFunction::Hann);
        assert!(!spectrogram.is_empty());
        assert!(istft(&spectrogram, 0, WindowFunction::Hann).is_empty());
    }

    #[test]
    fn test_stft_istft_roundtrip() {
        let n = 256;
        let signal: Vec<f64> = (0..n).map(|i| (f64::from(i) / 10.0).sin()).collect();
        let fft_size = 64;
        let hop_size = 16;
        let window = WindowFunction::Hann;
        let spectrogram = stft(&signal, fft_size, hop_size, window.clone());
        let reconstructed = istft(&spectrogram, hop_size, window);
        let start = fft_size;
        let end = reconstructed.len().saturating_sub(fft_size);
        if end > start {
            for i in start..end.min(signal.len()) {
                let diff = (reconstructed[i] - signal[i]).abs();
                assert!(
                    diff < 0.1,
                    "Mismatch at {}: {} vs {}",
                    i,
                    reconstructed[i],
                    signal[i]
                );
            }
        }
    }

    #[test]
    fn test_magnitude_spectrogram() {
        let signal: Vec<f64> = vec![1.0; 64];
        let spectrogram = stft(&signal, 16, 8, WindowFunction::Rectangular);
        let magnitudes = magnitude_spectrogram(&spectrogram);
        assert!(!magnitudes.is_empty());
        assert!(magnitudes[0][0] > 0.0);
    }

    #[test]
    fn test_power_spectrogram() {
        let signal: Vec<f64> = vec![1.0; 64];
        let spectrogram = stft(&signal, 16, 8, WindowFunction::Rectangular);
        let powers = power_spectrogram(&spectrogram);
        assert!(!powers.is_empty());
        assert!(powers[0][0] >= 0.0);
    }

    #[test]
    fn test_phase_spectrogram() {
        let signal: Vec<f64> = (0..64).map(|i| f64::from(i).sin()).collect();
        let spectrogram = stft(&signal, 16, 8, WindowFunction::Hann);
        let phases = phase_spectrogram(&spectrogram);
        assert!(!phases.is_empty());
        for frame in &phases {
            for &phase in frame {
                assert!(phase >= -core::f64::consts::PI - 0.01);
                assert!(phase <= core::f64::consts::PI + 0.01);
            }
        }
    }

    #[test]
    fn test_stft_overlap_save_basic_shape() {
        let signal: Vec<f64> = vec![0.0; 256];
        let fft_size = 64;
        let hop_size = 16;
        let spectra = stft_overlap_save(&signal, fft_size, hop_size, WindowFunction::Hann);
        assert!(!spectra.is_empty(), "Should produce at least one frame");
        for frame in &spectra {
            assert_eq!(frame.len(), fft_size);
        }
    }

    /// The two layouts do not produce equal frame counts: the plain STFT only emits frames that
    /// fit inside the signal, while overlap-save left-pads and must cover every input sample.
    #[test]
    fn test_stft_overlap_save_same_frame_count_as_stft() {
        let signal: Vec<f64> = (0..512)
            .map(|i| (f64::from(i) / 8.0 * core::f64::consts::TAU).sin())
            .collect();
        let fft_size = 64;
        let hop_size = 16;
        let oa_spectra = stft(&signal, fft_size, hop_size, WindowFunction::Hann);
        let os_spectra = stft_overlap_save(&signal, fft_size, hop_size, WindowFunction::Hann);
        assert_eq!(oa_spectra.len(), (signal.len() - fft_size) / hop_size + 1);
        assert!(
            os_spectra.len() * hop_size >= signal.len(),
            "Overlap-save frames must cover every input sample"
        );
        for frame in &os_spectra {
            assert_eq!(frame.len(), fft_size);
        }
    }

    #[test]
    fn test_stft_overlap_save_roundtrip_rectangular() {
        let n = 256;
        let fft_size = 64;
        let hop_size = 32;
        let window = WindowFunction::Rectangular;
        let signal: Vec<f64> = (0..n).map(|i| (i as f64 * 0.0731_f64).sin()).collect();
        let spectra = stft_overlap_save(&signal, fft_size, hop_size, window.clone());
        let recovered = istft_overlap_save(&spectra, fft_size, hop_size, window);
        assert_eq!(recovered.len(), spectra.len() * hop_size);
        assert!(recovered.len() >= n, "Reconstruction must cover the whole signal");
        let max_err = recovered
            .iter()
            .zip(&signal)
            .map(|(r, s)| (r - s).abs())
            .fold(0.0f64, f64::max);
        assert!(
            max_err < 1e-9,
            "Max roundtrip error {max_err} exceeds threshold with rectangular window"
        );
    }

    #[test]
    fn test_stft_overlap_save_magnitude_similar_to_overlap_add() {
        let n = 512;
        let fft_size = 64;
        let hop_size = 16;
        let window = WindowFunction::Hann;
        let freq_bin = 4usize;
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                (2.0 * core::f64::consts::PI * freq_bin as f64 * i as f64 / fft_size as f64).sin()
            })
            .collect();
        let os_spectra = stft_overlap_save(&signal, fft_size, hop_size, window.clone());
        let oa_spectra = stft(&signal, fft_size, hop_size, window);
        let os_peak = peak_bin(&os_spectra[os_spectra.len() / 2]);
        let oa_peak = peak_bin(&oa_spectra[oa_spectra.len() / 2]);
        let mirror = fft_size - freq_bin;
        assert!(
            os_peak == freq_bin || os_peak == mirror,
            "Overlap-save peak at bin {os_peak}, expected {freq_bin} or {mirror}"
        );
        assert!(
            oa_peak == freq_bin || oa_peak == mirror,
            "Overlap-add peak at bin {oa_peak}, expected {freq_bin} or {mirror}"
        );
    }

    #[test]
    fn test_stft_overlap_save_empty_signal() {
        let spectra = stft_overlap_save::<f64>(&[], 64, 16, WindowFunction::Hann);
        assert!(spectra.is_empty());
    }

    #[test]
    fn test_stft_overlap_save_invalid_params() {
        let signal = vec![0.0f64; 128];
        let spectra = stft_overlap_save(&signal, 32, 32, WindowFunction::Hann);
        assert!(spectra.is_empty());
        let spectra = stft_overlap_save(&signal, 0, 16, WindowFunction::Hann);
        assert!(spectra.is_empty());
    }
}