use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
use super::bucket::BucketArray;
use super::decoder::PeelingDecoder;
use super::filter::{create_optimal_filter, AliasingFilter};
use super::hash::{generate_coprime_factors, FrequencyHash};
use super::problem::SparseProblem;
use super::result::SparseResult;
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Execution plan for a sparse FFT of length `n` that recovers up to `k`
/// significant frequency coefficients using several stages of bucketed
/// sub-FFTs (subsample → filter → small dense FFT → peeling decode).
pub struct SparsePlan<T: Float> {
// Transform length.
n: usize,
// Expected number of significant coefficients to recover.
k: usize,
// Bucket count chosen by the problem sizing heuristics.
num_buckets: usize,
// Number of measurement stages; bounded above by the number of
// available coprime subsampling factors (see `new`).
num_stages: usize,
// Coprime subsampling factors, one candidate per stage.
subsample_factors: Vec<usize>,
// Per-stage frequency hash; its bucket count sizes that stage's FFT.
hash_functions: Vec<FrequencyHash>,
// Filter whose coefficients are multiplied onto each stage's
// aliased signal (see `apply_filter`).
filter: AliasingFilter<T>,
// Dense FFT plan over each stage's bucket array.
bucket_plans: Vec<Plan<T>>,
// Magnitude below which bucket energy is treated as zero during decode.
threshold: T,
// Planner flags forwarded to the inner dense FFT plans.
flags: Flags,
}
impl<T: Float> SparsePlan<T> {
/// Builds a sparse FFT plan for a length-`n` transform recovering at
/// most `k` significant coefficients.
///
/// Returns `None` when the parameters are degenerate (`n == 0`,
/// `k == 0`, `k > n`), when no coprime subsampling factor is
/// available (a plan with zero stages could never decode anything),
/// or when any per-stage bucket FFT plan fails to build.
pub fn new(n: usize, k: usize, flags: Flags) -> Option<Self> {
    if n == 0 || k == 0 || k > n {
        return None;
    }
    let problem: SparseProblem<T> = SparseProblem::new(n, k, Direction::Forward);
    let num_buckets = problem.optimal_buckets();
    // At least two stages are requested so singleton detection has a
    // second measurement to cross-check against, but we can never run
    // more stages than there are coprime subsampling factors.
    let requested_stages = problem.optimal_repetitions().max(2);
    let subsample_factors = generate_coprime_factors(num_buckets, n / 2);
    let num_stages = requested_stages.min(subsample_factors.len());
    if num_stages == 0 {
        // No usable subsampling factor exists (e.g. very small n).
        // Fail construction instead of returning a useless plan that
        // would silently decode nothing.
        return None;
    }
    // One hash per stage. `num_stages <= subsample_factors.len()` is
    // guaranteed above, so indexing the factor list is always in
    // bounds (the previous fallback branch was unreachable).
    let hash_functions: Vec<_> = subsample_factors[..num_stages]
        .iter()
        .map(|&bucket_count| FrequencyHash::new(bucket_count, n))
        .collect();
    let filter = create_optimal_filter(n, k, num_buckets);
    // Dense FFT plan over each stage's buckets; a shortfall means at
    // least one plan failed to build, which aborts construction.
    let bucket_plans: Vec<_> = hash_functions
        .iter()
        .filter_map(|h| Plan::dft_1d(h.num_buckets(), Direction::Forward, flags))
        .collect();
    if bucket_plans.len() != num_stages {
        return None;
    }
    // Default magnitude threshold for treating bucket energy as zero;
    // tunable afterwards via `set_threshold`.
    let threshold = T::from_f64(1e-10);
    Some(Self {
        n,
        k,
        num_buckets,
        num_stages,
        subsample_factors,
        hash_functions,
        filter,
        bucket_plans,
        threshold,
        flags,
    })
}
/// Runs the sparse transform on `input`.
///
/// An input whose length differs from the planned `n` yields an
/// empty result rather than panicking.
pub fn execute(&self, input: &[Complex<T>]) -> SparseResult<T> {
    if input.len() == self.n {
        let mut stages = self.compute_bucket_stages(input);
        PeelingDecoder::new(self.n, self.k, self.threshold).decode(&mut stages)
    } else {
        SparseResult::empty()
    }
}
/// Computes one `BucketArray` per stage by aliasing the input into
/// buckets, applying the filter coefficients, and running that
/// stage's dense bucket FFT.
fn compute_bucket_stages(&self, input: &[Complex<T>]) -> Vec<BucketArray<T>> {
    let mut stages = Vec::with_capacity(self.num_stages);
    for stage_idx in 0..self.num_stages {
        let hash = &self.hash_functions[stage_idx];
        let bucket_count = hash.num_buckets();
        // Fall back to factor 1 (no subsampling) for stages beyond the
        // factor list; plans built by `new` never hit this, but the
        // guard keeps the method safe in isolation.
        let subsample_factor = self.subsample_factors.get(stage_idx).copied().unwrap_or(1);
        let subsampled = self.subsample(input, subsample_factor, bucket_count);
        let filtered = self.apply_filter(&subsampled);
        let bucket_fft = self.bucket_fft(&filtered, stage_idx);
        let mut buckets = BucketArray::new(bucket_count, subsample_factor, self.n);
        buckets.fill_from_fft(&bucket_fft);
        stages.push(buckets);
    }
    // Cross-check the first stage against the second to pre-mark
    // singleton buckets. After `split_at_mut(1)`, `rest` is guaranteed
    // non-empty whenever `stages.len() >= 2`, so the former inner
    // `!rest.is_empty()` check was unreachable and has been removed.
    if stages.len() >= 2 {
        let (first, rest) = stages.split_at_mut(1);
        first[0].analyze_singletons(&rest[0], self.threshold);
    }
    stages
}
/// Folds `input` down to `output_len` points by time-domain aliasing,
/// then applies a per-sample twiddle derived from `factor`.
///
/// NOTE(review): `factor` influences only the twiddle phase, not which
/// samples are summed into each bucket — confirm this matches the
/// intended subsampling scheme for the coprime-factor stages.
fn subsample(&self, input: &[Complex<T>], factor: usize, output_len: usize) -> Vec<Complex<T>> {
// Aliasing: summing x[i] into bucket i % output_len folds the
// spectrum so each output bin collects all frequencies congruent
// modulo output_len. Summation order follows input order, which
// matters for floating-point reproducibility.
let mut output = vec![Complex::<T>::zero(); output_len];
for (i, &val) in input.iter().enumerate() {
let out_idx = i % output_len;
output[out_idx] = output[out_idx] + val;
}
let two_pi = <T as Float>::PI + <T as Float>::PI;
// Apply e^{-j 2*pi i*factor / n} to each aliased sample (negative
// sign: forward-transform convention).
for (i, out) in output.iter_mut().enumerate() {
let phase = two_pi * T::from_usize(i * factor) / T::from_usize(self.n);
let (sin_p, cos_p) = Float::sin_cos(phase);
let twiddle = Complex::new(cos_p, -sin_p);
*out = *out * twiddle;
}
output
}
/// Multiplies `signal` element-wise by the plan's filter coefficients,
/// repeating the coefficient table cyclically when the signal is
/// longer than the filter.
fn apply_filter(&self, signal: &[Complex<T>]) -> Vec<Complex<T>> {
    let coeffs = &self.filter.coeffs;
    let mut filtered = Vec::with_capacity(signal.len());
    for (idx, &sample) in signal.iter().enumerate() {
        filtered.push(sample * coeffs[idx % coeffs.len()]);
    }
    filtered
}
/// Runs the stage's dense FFT over the bucket signal, first adjusting
/// the signal to exactly the stage's bucket count. Stages without a
/// plan yield an all-zero spectrum.
fn bucket_fft(&self, input: &[Complex<T>], stage_idx: usize) -> Vec<Complex<T>> {
    let bucket_count = self.hash_functions[stage_idx].num_buckets();
    // `Vec::resize` both truncates a too-long signal and zero-pads a
    // too-short one, collapsing the former two-branch copy into one.
    let mut adjusted = input.to_vec();
    adjusted.resize(bucket_count, Complex::<T>::zero());
    let mut output = vec![Complex::<T>::zero(); bucket_count];
    if let Some(plan) = self.bucket_plans.get(stage_idx) {
        plan.execute(&adjusted, &mut output);
    }
    output
}
/// Transform length this plan was built for.
pub fn n(&self) -> usize {
self.n
}
/// Number of significant coefficients the plan targets.
pub fn k(&self) -> usize {
self.k
}
/// Bucket count chosen at construction time.
pub fn num_buckets(&self) -> usize {
self.num_buckets
}
/// Number of measurement stages the plan executes.
pub fn num_stages(&self) -> usize {
self.num_stages
}
/// Planner flags supplied at construction.
pub fn flags(&self) -> Flags {
self.flags
}
/// Overrides the magnitude threshold used by the peeling decoder to
/// decide whether bucket energy counts as signal.
pub fn set_threshold(&mut self, threshold: T) {
self.threshold = threshold;
}
/// Current decode magnitude threshold.
pub fn threshold(&self) -> T {
self.threshold
}
/// Rough operation-count estimate for executing this plan: per-stage
/// bucket FFT cost plus one pass over the input plus decode work.
pub fn estimated_ops(&self) -> usize {
    let buckets = self.num_buckets;
    // ceil(log2(buckets)) via floating point; the saturating
    // float-to-int cast maps the buckets == 0 edge case to 0.
    let log_buckets = (buckets as f64).log2().ceil() as usize;
    let fft_cost = self.num_stages * buckets * log_buckets;
    let subsample_cost = self.n;
    let decode_cost = self.k * self.num_stages;
    fft_cost + subsample_cost + decode_cost
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // A valid plan reports its parameters and keeps at least two stages.
    #[test]
    fn test_sparse_plan_creation() {
        let plan: Option<SparsePlan<f64>> = SparsePlan::new(1024, 10, Flags::ESTIMATE);
        assert!(plan.is_some());
        let plan = plan.unwrap();
        assert_eq!(plan.n(), 1024);
        assert_eq!(plan.k(), 10);
        assert!(plan.num_stages() >= 2);
    }

    // Degenerate parameter combinations must be rejected.
    #[test]
    fn test_sparse_plan_invalid() {
        assert!(SparsePlan::<f64>::new(0, 10, Flags::ESTIMATE).is_none());
        assert!(SparsePlan::<f64>::new(1024, 0, Flags::ESTIMATE).is_none());
        assert!(SparsePlan::<f64>::new(10, 100, Flags::ESTIMATE).is_none());
    }

    // A single complex exponential at bin 10 should produce a
    // non-empty recovery.
    #[test]
    fn test_sparse_plan_execute() {
        let n = 256;
        let k = 5;
        let plan = SparsePlan::<f64>::new(n, k, Flags::ESTIMATE).unwrap();
        let two_pi = core::f64::consts::PI * 2.0;
        let input: Vec<Complex<f64>> = (0..n)
            .map(|i| {
                let angle = two_pi * 10.0 * (i as f64) / (n as f64);
                Complex::new(angle.cos(), angle.sin())
            })
            .collect();
        let result = plan.execute(&input);
        assert!(!result.is_empty());
    }

    // The estimate should stay well below a dense O(n log n) count.
    #[test]
    fn test_estimated_ops() {
        let plan = SparsePlan::<f64>::new(1024, 10, Flags::ESTIMATE).unwrap();
        assert!(plan.estimated_ops() < 5000);
    }

    // The detection threshold is a plain read/write accessor pair.
    #[test]
    fn test_threshold() {
        let mut plan = SparsePlan::<f64>::new(256, 5, Flags::ESTIMATE).unwrap();
        plan.set_threshold(0.001);
        assert_eq!(plan.threshold(), 0.001);
    }
}