use core::ops::Range;
use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
use crate::prelude::*;
/// How a [`PartialFft`] computes its requested output bins.
///
/// The strategy is selected once at construction time and exposed through
/// `PartialFft::strategy`; deriving `PartialEq`/`Eq` lets callers and tests compare it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PartialStrategy {
    /// Evaluate each listed bin directly with a Goertzel-style routine.
    Goertzel {
        /// Requested bin indices, in the order their values are written to the output.
        bins: Vec<usize>,
    },
    /// Compute only the leading bins `0..m`.
    OutputPruned {
        /// Number of leading bins to compute.
        m: usize,
    },
    /// Run one full FFT and copy out the contiguous bins in `range`.
    FullThenSlice {
        /// Contiguous, ascending range of bin indices to keep.
        range: Range<usize>,
    },
}
/// A DFT of fixed length `n` that computes only a subset of its output bins.
///
/// Construct it with `PartialFft::new_sparse` or `PartialFft::new_prefix`; the execution
/// strategy is fixed at construction time.
#[derive(Debug, Clone)]
pub struct PartialFft<T: Float> {
    /// Transform length; `execute` requires exactly this many input samples.
    n: usize,
    /// Strategy chosen at construction time.
    strategy: PartialStrategy,
    /// Ties the value to the scalar type `T` without storing a `T`.
    _phantom: core::marker::PhantomData<T>,
}
impl<T: Float> PartialFft<T> {
    /// Builds a partial FFT of length `n` that evaluates only the DFT bins listed in `bins`.
    ///
    /// The execution strategy is picked once here by `choose_strategy_sparse` and can be
    /// inspected with [`strategy`](Self::strategy). [`execute`](Self::execute) expects an
    /// output slice with one slot per entry of `bins`.
    #[must_use]
    pub fn new_sparse(n: usize, bins: &[usize]) -> Self {
        Self {
            n,
            strategy: choose_strategy_sparse(n, bins),
            _phantom: core::marker::PhantomData,
        }
    }

    /// Builds a partial FFT of length `n` that evaluates only the leading bins `0..m`.
    ///
    /// The execution strategy is picked once here by `choose_strategy_prefix`;
    /// [`execute`](Self::execute) expects an output slice of exactly `m` elements.
    #[must_use]
    pub fn new_prefix(n: usize, m: usize) -> Self {
        Self {
            n,
            strategy: choose_strategy_prefix(n, m),
            _phantom: core::marker::PhantomData,
        }
    }

    /// Returns the execution strategy selected at construction time.
    #[must_use]
    pub fn strategy(&self) -> &PartialStrategy {
        &self.strategy
    }

    /// Computes the requested DFT bins of `input` into `output`.
    ///
    /// # Panics
    ///
    /// Panics if `input.len() != n`, or if `output.len()` does not equal the number of
    /// requested bins (`bins.len()`, `m`, or the range length, depending on the strategy).
    /// On the full-FFT path it may also panic if the requested range extends past `n`.
    pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>]) {
        assert_eq!(
            input.len(),
            self.n,
            "PartialFft::execute: input length {} != n {}",
            input.len(),
            self.n
        );
        match &self.strategy {
            PartialStrategy::Goertzel { bins } => {
                assert_eq!(
                    output.len(),
                    bins.len(),
                    "PartialFft::execute: output.len() must equal bins.len()"
                );
                execute_goertzel(input, bins, output);
            }
            PartialStrategy::OutputPruned { m } => {
                assert_eq!(
                    output.len(),
                    *m,
                    "PartialFft::execute: output.len() must equal m"
                );
                execute_output_pruned(input, *m, output);
            }
            PartialStrategy::FullThenSlice { range } => {
                assert_eq!(
                    output.len(),
                    range.len(),
                    "PartialFft::execute: output.len() must equal range length"
                );
                execute_full_then_slice(input, range, output);
            }
        }
    }
}
/// Picks the execution strategy for an arbitrary list of requested bins.
///
/// Lists shorter than `log2_ceil(n)` are always evaluated bin by bin (`Goertzel`). Longer
/// lists use one full FFT plus a slice when they form a contiguous ascending run; any other
/// bin set still falls back to `Goertzel`.
fn choose_strategy_sparse(n: usize, bins: &[usize]) -> PartialStrategy {
    // NOTE(review): the threshold presumably weighs k direct per-bin evaluations against a
    // single O(n log n) FFT — confirm against benchmarks.
    let enough_bins_for_fft = bins.len() >= log2_ceil(n);
    if enough_bins_for_fft {
        if let Some(range) = bins_as_range(bins) {
            return PartialStrategy::FullThenSlice { range };
        }
    }
    PartialStrategy::Goertzel {
        bins: bins.to_vec(),
    }
}
/// Picks the execution strategy for the leading bins `0..m`.
///
/// A prefix shorter than `log2_ceil(n)` is evaluated bin by bin: as `OutputPruned` when `m`
/// is a power of two no larger than `n / 2`, otherwise as plain `Goertzel`. Longer prefixes
/// use one full FFT and keep its first `m` outputs.
fn choose_strategy_prefix(n: usize, m: usize) -> PartialStrategy {
    let short_prefix = m < log2_ceil(n);
    let prunable = m.is_power_of_two() && m <= n / 2;
    match (short_prefix, prunable) {
        (true, true) => PartialStrategy::OutputPruned { m },
        (true, false) => PartialStrategy::Goertzel {
            bins: (0..m).collect(),
        },
        (false, _) => PartialStrategy::FullThenSlice { range: 0..m },
    }
}
/// Returns `Some(start..start + bins.len())` when `bins` is a contiguous ascending run
/// `start, start + 1, …`, and `None` otherwise.
///
/// An empty list maps to the empty range `0..0`. A run whose end would exceed `usize::MAX`
/// yields `None` rather than overflowing. The previous unchecked `start + len` panicked in
/// debug builds and produced an inverted range in release builds.
fn bins_as_range(bins: &[usize]) -> Option<Range<usize>> {
    let start = match bins.first() {
        Some(&first) => first,
        None => return Some(0..0),
    };
    let end = start.checked_add(bins.len())?;
    // `start..end` has exactly `bins.len()` elements, so `eq` is a pure element-wise check.
    if bins.iter().copied().eq(start..end) {
        Some(start..end)
    } else {
        None
    }
}
/// Evaluates the bins in `bins` with `super::goertzel_multi` and copies the returned values
/// into `output` (presumably one value per bin, in `bins` order — see `goertzel_multi`).
///
/// Panics via `copy_from_slice` if `output.len()` differs from the number of returned values.
fn execute_goertzel<T: Float>(input: &[Complex<T>], bins: &[usize], output: &mut [Complex<T>]) {
    output.copy_from_slice(&super::goertzel_multi(input, bins));
}
/// Computes the leading bins `0..m` into `output`.
///
/// There is no dedicated pruned-FFT kernel here: the prefix is evaluated with the same
/// `super::goertzel_multi` routine as an explicit bin list `0..m`.
fn execute_output_pruned<T: Float>(input: &[Complex<T>], m: usize, output: &mut [Complex<T>]) {
    let prefix: Vec<usize> = (0..m).collect();
    output.copy_from_slice(&super::goertzel_multi(input, &prefix));
}
fn execute_full_then_slice<T: Float>(
input: &[Complex<T>],
range: &Range<usize>,
output: &mut [Complex<T>],
) {
let n = input.len();
let plan = match Plan::dft_1d(n, Direction::Forward, Flags::ESTIMATE) {
Some(p) => p,
None => {
for o in output.iter_mut() {
*o = Complex::<T>::zero();
}
return;
}
};
let mut full_output = vec![Complex::<T>::zero(); n];
plan.execute(input, &mut full_output);
let len = range.end - range.start;
output[..len].copy_from_slice(&full_output[range.clone()]);
}
/// Returns `⌈log2(n)⌉`, treating `0` and `1` as `0`.
///
/// For `n >= 2`, the bit width of `n - 1` is the smallest `b` with `2^b >= n`.
fn log2_ceil(n: usize) -> usize {
    match n {
        0 | 1 => 0,
        _ => (usize::BITS - (n - 1).leading_zeros()) as usize,
    }
}