use std::sync::Arc;
use num_traits::Float;
use realfft::num_complex::Complex;
use realfft::{ComplexToReal, FftNum, RealFftPlanner, RealToComplex};
use crate::Error;
use crate::config::DerivedConfig;
use crate::lpc::{ExtrapolateFallback, extrapolate_backward, extrapolate_forward};
/// Streaming state of the resampler: FFT plans, working buffers, optional
/// boundary context and per-stream bookkeeping.
pub(crate) struct ArdftsrcCore<T = f64>
where
    T: Float + FftNum,
{
    // ---- configuration and FFT plans -------------------------------------
    /// Derived configuration (chunk and FFT sizes, offsets, taper, sample rates).
    derived: DerivedConfig<T>,
    /// Real-to-complex plan, built for `derived.input_fft_size`.
    forward: Arc<dyn RealToComplex<T>>,
    /// Complex-to-real plan, built for `derived.output_fft_size`.
    inverse: Arc<dyn ComplexToReal<T>>,

    // ---- working buffers -------------------------------------------------
    /// FFT input/output buffers reused by every transform.
    scratch: Scratch<T>,
    /// Unscaled second half of the previous inverse transform, one output chunk
    /// long; added to the first half of the next transform.
    overlap: Vec<T>,
    /// Scaled output samples handed back to the caller, one output chunk long.
    output_block: Vec<T>,
    /// Input history, two input chunks long; the first half holds the most
    /// recently processed full input chunk.
    prev_input_window: Vec<T>,

    // ---- optional boundary context ---------------------------------------
    /// Samples preceding the stream, used before falling back to extrapolation.
    pre: Option<Vec<T>>,
    /// Samples following the stream, used before falling back to extrapolation.
    post: Option<Vec<T>>,

    // ---- stream bookkeeping ----------------------------------------------
    /// Set once the caller has submitted the final input chunk.
    final_input_seen: bool,
    /// Set once `finalize` has run for the current stream.
    finalized: bool,
    /// Output samples still to be dropped from the start of the stream
    /// (starts at `derived.output_offset`).
    trim_remaining: usize,
    /// Overlap samples requested from `finalize` (starts at `derived.output_offset`).
    flush_remaining: usize,
    /// Total input samples consumed in the current stream.
    input_sample_count: usize,
    /// Total output samples emitted in the current stream.
    output_sample_count: usize,
}
/// Reusable FFT buffers, allocated once from the plans in `ArdftsrcCore::new`.
struct Scratch<T>
where
    T: Float + FftNum,
{
    // ---- time domain -----------------------------------------------------
    /// Input to the forward (real-to-complex) transform.
    rdft_in: Vec<T>,
    /// Output of the inverse (complex-to-real) transform.
    rdft_out: Vec<T>,

    // ---- frequency domain ------------------------------------------------
    /// Output of the forward transform.
    spectrum: Vec<Complex<T>>,
    /// Tapered spectrum fed to the inverse transform.
    resampled_spectrum: Vec<Complex<T>>,
}
/// Selects what `transform_chunk` does with the inverse-transform result.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TransformMode {
    /// Priming pass on synthetic pre-roll: only the second half is stored as
    /// overlap; nothing is emitted.
    Start,
    /// Regular pass: emits the first half plus the stored overlap, then stores
    /// the second half as the new overlap.
    Normal,
    /// Closing pass: accumulates the first half into the stored overlap so it
    /// can be flushed by `finalize`.
    End,
}
impl<T> ArdftsrcCore<T>
where
T: Float + FftNum,
{
#[inline]
fn input_chunk_len_samples(&self) -> usize {
self.derived.input_chunk_frames
}
#[inline]
fn output_chunk_len_samples(&self) -> usize {
self.derived.output_chunk_frames
}
pub fn new(derived: DerivedConfig<T>) -> Self {
let mut planner = RealFftPlanner::<T>::new();
let forward = planner.plan_fft_forward(derived.input_fft_size);
let inverse = planner.plan_fft_inverse(derived.output_fft_size);
let output_offset = derived.output_offset;
let scratch = Scratch {
rdft_in: forward.make_input_vec(),
spectrum: forward.make_output_vec(),
resampled_spectrum: inverse.make_input_vec(),
rdft_out: inverse.make_output_vec(),
};
let overlap = vec![T::zero(); derived.output_chunk_frames];
let output_block = vec![T::zero(); derived.output_chunk_frames];
let prev_input_window = vec![T::zero(); derived.input_chunk_frames * 2];
Self {
derived,
forward,
inverse,
scratch,
overlap,
output_block,
prev_input_window,
final_input_seen: false,
finalized: false,
trim_remaining: output_offset,
flush_remaining: output_offset,
pre: None,
post: None,
input_sample_count: 0,
output_sample_count: 0,
}
}
#[inline]
pub(crate) fn input_sample_processed(&self) -> usize {
self.input_sample_count
}
#[inline]
pub(crate) fn output_sample_processed(&self) -> usize {
self.output_sample_count
}
#[inline]
pub(crate) fn input_chunk_samples(&self) -> usize {
self.input_chunk_len_samples()
}
#[inline]
pub(crate) fn input_buffer_size(&self) -> usize {
self.input_chunk_samples()
}
#[inline]
pub fn pre(&mut self, pre: Vec<T>) {
self.pre = self.normalize_context(pre);
}
#[inline]
pub fn post(&mut self, post: Vec<T>) {
self.post = self.normalize_context(post);
}
#[inline]
pub fn output_sample_count_for_input(&self, input_samples: usize) -> usize {
(input_samples * self.derived.output_sample_rate).div_ceil(self.derived.input_sample_rate)
}
#[inline]
pub fn output_sample_count(&self, input_samples: usize) -> usize {
self.output_sample_count_for_input(input_samples)
}
pub fn process_all<'a>(&mut self, input: &[T]) -> Result<Vec<T>, Error> {
let expected_samples = self.output_sample_count(input.len());
let mut output = Vec::with_capacity(expected_samples);
let mut offset = 0;
let input_chunk_size = self.input_buffer_size();
while offset + input_chunk_size <= input.len() {
let chunk_output = self.process_chunk(&input[offset..offset + input_chunk_size], false)?;
output.extend_from_slice(chunk_output);
offset += input_chunk_size;
}
let final_chunk_output = self.process_chunk(&input[offset..], true)?;
output.extend_from_slice(&final_chunk_output);
let finalize_output = self.finalize()?;
output.extend_from_slice(finalize_output);
Ok(output)
}
pub fn reset(&mut self) {
let zero = Complex::new(T::zero(), T::zero());
self.scratch.rdft_in.fill(T::zero());
self.scratch.spectrum.fill(zero);
self.scratch.resampled_spectrum.fill(zero);
self.scratch.rdft_out.fill(T::zero());
self.overlap.fill(T::zero());
self.output_block.fill(T::zero());
self.prev_input_window.fill(T::zero());
self.final_input_seen = false;
self.finalized = false;
self.trim_remaining = self.derived.output_offset;
self.flush_remaining = self.derived.output_offset;
self.input_sample_count = 0;
self.output_sample_count = 0;
self.pre = None;
self.post = None;
}
pub fn finalize<'a>(&'a mut self) -> Result<&'a [T], Error> {
if self.finalized {
return Err(Error::AlreadyFlushed);
}
self.final_input_seen = true;
let written = if self.is_passthrough() || self.input_sample_count == 0 {
self.finalized = true;
0
} else {
let flush_candidate = self.flush_remaining;
let written_samples = self.cap_write_to_output_budget(flush_candidate);
self.finalized = true;
self.add_synthetic_finalize_tail_to_overlap()?;
let scale = T::from(self.output_chunk_len_samples()).unwrap_or(T::one())
/ T::from(self.input_chunk_len_samples()).unwrap_or(T::one());
for (dst, src) in self.output_block[..written_samples]
.iter_mut()
.zip(self.overlap[..written_samples].iter())
{
*dst = *src * scale;
}
written_samples
};
self.output_sample_count += written;
Ok(&self.output_block[..written])
}
#[inline]
fn expected_total_output_samples(&self) -> Option<usize> {
if !self.final_input_seen {
return None;
}
Some(self.output_sample_count_for_input(self.input_sample_count))
}
#[inline]
fn remaining_output_budget_samples(&self) -> Option<usize> {
self.expected_total_output_samples()
.map(|expected_total| expected_total.saturating_sub(self.output_sample_count))
}
#[inline]
fn cap_write_to_output_budget(&self, candidate_samples: usize) -> usize {
self.remaining_output_budget_samples()
.map_or(candidate_samples, |remaining| candidate_samples.min(remaining))
}
pub(crate) fn process_chunk<'a>(&'a mut self, input: &'a [T], is_final: bool) -> Result<&'a [T], Error> {
if self.finalized {
self.reset();
}
let input_samples = input.len();
if self.final_input_seen {
return Err(Error::StreamAlreadyFinalized);
}
if self.is_passthrough() {
if is_final {
self.final_input_seen = true;
}
self.input_sample_count += input_samples;
let written_samples = self.cap_write_to_output_budget(input_samples);
self.output_sample_count += written_samples;
return Ok(&input[..written_samples]);
}
if is_final {
self.final_input_seen = true;
if input_samples == 0 {
return Ok(&input);
}
}
self.copy_input_to_window(input, input_samples);
let is_first_input = self.input_sample_count == 0;
if is_first_input {
self.synthesize_start_context(input_samples)?;
self.copy_input_to_window(input, input_samples);
}
let is_short_final = is_final && input_samples < self.input_chunk_len_samples();
if is_short_final {
self.synthesize_final_block_missing_samples(input, input_samples);
}
self.transform_chunk(TransformMode::Normal)?;
if !is_short_final {
self.save_current_window();
}
self.input_sample_count += input_samples;
let skip_samples = self.trim_remaining.min(self.output_chunk_len_samples());
self.trim_remaining -= skip_samples;
let chunk_samples_after_trim = self.output_chunk_len_samples() - skip_samples;
let candidate_samples = chunk_samples_after_trim;
let written_samples = self.cap_write_to_output_budget(candidate_samples);
let src_start = skip_samples;
self.output_sample_count += written_samples;
return Ok(&self.output_block[src_start..src_start + written_samples]);
}
#[inline]
fn is_passthrough(&self) -> bool {
self.derived.input_sample_rate == self.derived.output_sample_rate
}
#[inline]
fn normalize_context(&self, context: Vec<T>) -> Option<Vec<T>> {
if context.is_empty() { None } else { Some(context) }
}
#[inline]
fn copy_input_to_window(&mut self, input: &[T], input_samples: usize) {
self.scratch.rdft_in.fill(T::zero());
let dst = &mut self.scratch.rdft_in[self.derived.input_offset..self.derived.input_offset + input_samples];
dst.copy_from_slice(&input[..input_samples]);
}
fn copy_pre_tail(&self, dst: &mut [T]) -> usize {
let Some(pre) = &self.pre else {
return 0;
};
let copied = pre.len().min(dst.len());
let start = pre.len() - copied;
let dst_start = dst.len() - copied;
dst[dst_start..].copy_from_slice(&pre[start..start + copied]);
copied
}
fn copy_post_head(&self, dst: &mut [T]) -> usize {
let Some(post) = &self.post else {
return 0;
};
let copied = post.len().min(dst.len());
dst[..copied].copy_from_slice(&post[..copied]);
copied
}
fn synthesize_start_context(&mut self, input_samples: usize) -> Result<(), Error> {
if input_samples == 0 {
return Ok(());
}
let input_start = self.derived.input_offset;
let input_end = input_start + input_samples;
let mut predicted = vec![T::zero(); input_start];
let copied = self.copy_pre_tail(&mut predicted);
if copied < input_start {
let fallback_len = input_start - copied;
let fallback = extrapolate_backward(
&self.scratch.rdft_in[input_start..input_end],
fallback_len,
ExtrapolateFallback::Hold,
);
predicted[..fallback_len].copy_from_slice(&fallback);
}
self.scratch.rdft_in.fill(T::zero());
let tail_start = self.input_chunk_len_samples();
self.scratch.rdft_in[tail_start..tail_start + predicted.len()].copy_from_slice(&predicted);
self.transform_chunk(TransformMode::Start)?;
Ok(())
}
fn build_tail_prediction(&self, base: &[T], needed: usize) -> Vec<T> {
let mut predicted = vec![T::zero(); needed];
let copied = self.copy_post_head(&mut predicted);
if copied < needed {
let mut seed = Vec::with_capacity(base.len() + copied);
seed.extend_from_slice(base);
seed.extend_from_slice(&predicted[..copied]);
let fallback = extrapolate_forward(&seed, needed - copied, ExtrapolateFallback::Hold);
predicted[copied..].copy_from_slice(&fallback);
}
predicted
}
fn assemble_short_final_work_window(
&self,
input: &[T],
input_samples: usize,
chunk_samples: usize,
pad_samples: usize,
) -> Vec<T> {
let mut work = vec![T::zero(); chunk_samples * 2];
work[..pad_samples].copy_from_slice(&self.prev_input_window[input_samples..input_samples + pad_samples]);
work[pad_samples..pad_samples + input_samples].copy_from_slice(&input[..input_samples]);
work
}
fn fill_short_final_predicted_tail(&self, work: &mut [T], chunk_samples: usize) -> Vec<T> {
let predicted = self.build_tail_prediction(&work[..chunk_samples], chunk_samples);
work[chunk_samples..chunk_samples * 2].copy_from_slice(&predicted);
predicted
}
fn commit_short_final_history(
&mut self,
input_samples: usize,
pad_samples: usize,
predicted: &[T],
chunk_samples: usize,
) {
if input_samples == 0 {
return;
}
self.prev_input_window[..input_samples].copy_from_slice(&predicted[pad_samples..pad_samples + input_samples]);
self.prev_input_window[input_samples..chunk_samples].fill(T::zero());
self.prev_input_window[chunk_samples..chunk_samples * 2].fill(T::zero());
}
fn stage_short_final_rdft_input_from_work(&mut self, work: &[T], pad_samples: usize, chunk_samples: usize) {
self.scratch.rdft_in.fill(T::zero());
let window_start = self.derived.input_offset;
self.scratch.rdft_in[window_start..window_start + chunk_samples]
.copy_from_slice(&work[pad_samples..pad_samples + chunk_samples]);
}
fn synthesize_final_block_missing_samples(&mut self, input: &[T], input_samples: usize) {
let chunk_samples = self.input_chunk_len_samples();
let pad_samples = chunk_samples - input_samples;
let mut work = self.assemble_short_final_work_window(input, input_samples, chunk_samples, pad_samples);
let predicted = self.fill_short_final_predicted_tail(&mut work, chunk_samples);
self.commit_short_final_history(input_samples, pad_samples, &predicted, chunk_samples);
self.stage_short_final_rdft_input_from_work(&work, pad_samples, chunk_samples);
}
fn transform_chunk(&mut self, mode: TransformMode) -> Result<(), Error> {
self.forward
.process(&mut self.scratch.rdft_in, &mut self.scratch.spectrum)
.map_err(|err| Error::Fft(err.to_string()))?;
let zero = Complex::new(T::zero(), T::zero());
let bins = self
.scratch
.resampled_spectrum
.len()
.min(self.scratch.spectrum.len())
.min(self.derived.taper.len());
for (dst, (src, taper)) in self.scratch.resampled_spectrum[..bins].iter_mut().zip(
self.scratch.spectrum[..bins]
.iter()
.zip(self.derived.taper[..bins].iter()),
) {
*dst = *src * *taper;
}
if bins < self.scratch.resampled_spectrum.len() {
self.scratch.resampled_spectrum[bins..].fill(zero);
}
if let Some(dc_bin) = self.scratch.resampled_spectrum.get_mut(0) {
dc_bin.im = T::zero();
}
if self.scratch.resampled_spectrum.len() > 1 {
let nyquist_bin = self.scratch.resampled_spectrum.len() - 1;
self.scratch.resampled_spectrum[nyquist_bin].im = T::zero();
}
self.inverse
.process(&mut self.scratch.resampled_spectrum, &mut self.scratch.rdft_out)
.map_err(|err| Error::Fft(err.to_string()))?;
let normalize = T::one() / T::from(self.derived.output_fft_size).unwrap_or(T::one());
let scale = T::from(self.output_chunk_len_samples()).unwrap_or(T::one())
/ T::from(self.input_chunk_len_samples()).unwrap_or(T::one());
let output_chunk_samples = self.output_chunk_len_samples();
if matches!(mode, TransformMode::Normal) {
for sample_idx in 0..output_chunk_samples {
self.output_block[sample_idx] =
(self.scratch.rdft_out[sample_idx] * normalize + self.overlap[sample_idx]) * scale;
}
}
if matches!(mode, TransformMode::End) {
for (overlap, rdft) in self.overlap[..output_chunk_samples]
.iter_mut()
.zip(self.scratch.rdft_out[..output_chunk_samples].iter())
{
*overlap = *overlap + *rdft * normalize;
}
}
if matches!(mode, TransformMode::Normal | TransformMode::Start) {
for (overlap, rdft) in self.overlap[..output_chunk_samples]
.iter_mut()
.zip(self.scratch.rdft_out[output_chunk_samples..output_chunk_samples * 2].iter())
{
*overlap = *rdft * normalize;
}
}
Ok(())
}
fn save_current_window(&mut self) {
let chunk_samples = self.input_chunk_len_samples();
let history_start = self.derived.input_offset;
let history_end = history_start + chunk_samples;
self.prev_input_window[..chunk_samples].copy_from_slice(&self.scratch.rdft_in[history_start..history_end]);
self.prev_input_window[chunk_samples..].fill(T::zero());
}
fn add_synthetic_finalize_tail_to_overlap(&mut self) -> Result<(), Error> {
if self.final_input_seen && !self.input_sample_count.is_multiple_of(self.input_buffer_size()) {
return Ok(());
}
self.scratch.rdft_in.fill(T::zero());
let chunk_samples = self.input_chunk_len_samples();
let input_offset = self.derived.input_offset;
let base = self.prev_input_window[..chunk_samples].to_vec();
let predicted = self.build_tail_prediction(&base, input_offset);
self.prev_input_window[chunk_samples..chunk_samples * 2].fill(T::zero());
self.prev_input_window[chunk_samples..chunk_samples + predicted.len()].copy_from_slice(&predicted);
let input_start = self.derived.input_offset;
self.scratch.rdft_in[input_start..input_start + chunk_samples]
.copy_from_slice(&self.prev_input_window[chunk_samples..chunk_samples + chunk_samples]);
self.transform_chunk(TransformMode::End)?;
Ok(())
}
}