Expand description
Learned FFT — a butterfly network trained to match a reference FFT, compiled via RLX.
§Overview
This crate learns the twiddle factors of a Cooley–Tukey butterfly network so that the
transform matches `rustfft` on random signals. After training, the same graph
can be compiled to CPU/GPU backends for batched inference.
§Quick start
use rlx_fft::{FftLearnConfig, FftLearnRunner, TrainConfig, train_butterfly};
let cfg = FftLearnConfig::new(256, 8)?;
let report = train_butterfly(&TrainConfig {
model: cfg.clone(),
steps: 200,
..TrainConfig::default()
})?;
println!("mse={} max_err={}", report.final_mse, report.max_error);
let runner = FftLearnRunner::with_weights(cfg, &report.weights)?;
§Welch peaks
Fast top-K spike extraction with an automatic or forced strategy picker
(`AutoWelchPeaks`, or `--strategy` on `bench-welch-peaks`). See the Welch peaks
section of `crates/rlx-fft/README.md` in this repo.
Re-exports§
pub use ablation::AblationReport;pub use ablation::AblationRow;pub use ablation::ablation_row_ok;pub use ablation::ablation_winners;pub use ablation::limit_sweep_devices;pub use ablation::merge_ablation_reports;pub use ablation::print_ablation_table;pub use ablation::run_ablation;pub use ablation::run_limit_sweep;pub use ablation::tier_summary;pub use ablation::top5_variants_per_n_fft;pub use ablation::write_ablation_json;pub use ablation_csv::LIMITS_CSV;pub use ablation_csv::META_CSV;pub use ablation_csv::ROWS_CSV;pub use ablation_csv::TOP5_CSV;pub use ablation_csv::read_ablation_csv_dir;pub use ablation_csv::read_ablation_rows_csv;pub use ablation_csv::write_ablation_csv_dir;pub use ablation_html::read_ablation_json;pub use ablation_html::render_ablation_html;pub use ablation_html::write_ablation_html;pub use ablation_ternary::TernaryAblationOpts;pub use ablation_ternary::TernaryAblationReport;pub use ablation_ternary::TernaryAblationRow;pub use ablation_ternary::TernaryArchVariantId;pub use ablation_ternary::TernaryExecMode;pub use ablation_ternary::TernaryParetoPoint;pub use ablation_ternary::print_ternary_ablation_table;pub use ablation_ternary::quick_ablation_opts;pub use ablation_ternary::run_ternary_ablation;pub use ablation_ternary::ternary_ablation_row_ok;pub use ablation_ternary::ternary_aggregate_variants;pub use ablation_ternary::ternary_pareto_frontier;pub use ablation_ternary::ternary_recommendation;pub use ablation_ternary::write_ternary_ablation_csv;pub use ablation_ternary::write_ternary_ablation_json;pub use ablation_ternary_html::read_ternary_ablation_json;pub use ablation_ternary_html::render_ternary_ablation_html;pub use ablation_ternary_html::write_ternary_ablation_html;pub use bench::BenchReport;pub use bench::bench_all;pub use bench::bench_all_dir;pub use bench::bench_reference_vs_learned;pub use bench::bench_reference_vs_learned_dir;pub use bench_encdec::EncDecBenchRow;pub use bench_encdec::bench_encdec_weights;pub use 
bench_encdec::bench_exact_baseline;pub use bench_encdec::bench_phased_dir;pub use bench_encdec::print_encdec_bench_table;pub use bench_encdec::write_encdec_bench_json;pub use bench_sweep::SweepReport;pub use bench_sweep::SweepRow;pub use bench_sweep::available_devices;pub use bench_sweep::parse_batch_spec;pub use bench_sweep::parse_csv_usize;pub use bench_sweep::parse_k_spec;pub use bench_sweep::print_sweep_chart;pub use bench_sweep::run_sweep;pub use bench_sweep::sweep_markdown_chart;pub use bench_sweep::write_sweep_json;pub use bench_sweep_html::read_sweep_json;pub use bench_sweep_html::render_sweep_html;pub use bench_sweep_html::write_sweep_html;pub use bench_welch_peaks::WelchPeaksBenchOpts;pub use bench_welch_peaks::WelchPeaksBenchReport;pub use bench_welch_peaks::WelchPeaksBenchRow;pub use bench_welch_peaks::print_welch_peaks_table;pub use bench_welch_peaks::run_welch_peaks_batch_sweep;pub use bench_welch_peaks::run_welch_peaks_bench;pub use bench_welch_peaks::run_welch_peaks_bench_opts;pub use bench_welch_peaks::run_welch_peaks_k_sweep;pub use bench_welch_peaks::run_welch_peaks_sweep;pub use bench_welch_peaks::write_welch_peaks_json;pub use config::EncDecTrainConfig;pub use config::FftLearnConfig;pub use config::MultiTrainConfig;pub use config::MultiTrainSchedule;pub use config::PhasedTrainConfig;pub use config::SUPPORTED_N_FFT;pub use config::TrainConfig;pub use config::TransformDir;pub use config::parse_transform_dir;pub use device::bench_device_label;pub use device::ensure_backend_ready;pub use device::normalize_device_alias;pub use device::parse_bench_device_list;pub use device::pick_auto_device;pub use device::resolve_train_device;pub use distill_compile::CompiledDistilledMel;pub use distill_compile::compile_distilled_mel;pub use distill_model::DistilledFftModel;pub use distill_ternary_compile::CompiledDistilledTernaryMel;pub use distill_ternary_compile::compile_distilled_ternary_mel;pub use distill_ternary_model::DistilledTernaryFftModel;pub use 
e2e_bench::E2eBackend;pub use e2e_bench::E2eBatchTrainMeta;pub use e2e_bench::E2eBenchMeta;pub use e2e_bench::E2eBenchReport;pub use e2e_bench::E2eBenchRow;pub use e2e_bench::E2ePipeline;pub use e2e_bench::merge_e2e_reports;pub use e2e_bench::print_e2e_table;pub use e2e_bench::read_e2e_json;pub use e2e_bench::run_e2e_bench;pub use e2e_bench::write_e2e_json;pub use e2e_bench_html::render_e2e_html;pub use e2e_bench_html::write_e2e_html;pub use learned_model::FastLearnedFftModel;pub use peak::DEFAULT_PEAK_K;pub use peak::WelchPeakParams;pub use peak::WelchPeaksScratch;pub use peak::peak_band_mask;pub use peak::peak_loss_grad_wrt_spectrum;pub use peak::peak_match_loss;pub use peak::peak_max_err;pub use peak::peaks_from_psd_batch;pub use peak::peaks_from_segment_spectrum_streaming;pub use peak::topk_peaks_one;pub use peak::welch_peaks_from_segment_spectrum;pub use peak::welch_peaks_rustfft;pub use peak::welch_peaks_rustfft_with_scratch;pub use runner::FftLearnRunner;pub use second_order::TwiddleOptState;pub use second_order::TwiddleOptimizer;pub use second_order::diag_gn_step;pub use second_order::hvp_twiddles_finite_diff;pub use study_html::StudyInputs;pub use study_html::render_study_html;pub use study_html::write_study_html;pub use ternary_arch::CorrectorKind;pub use ternary_arch::GateLayout;pub use ternary_arch::SpectrumCorrection;pub use ternary_arch::TernaryArchConfig;pub use ternary_gates::GateMode;pub use ternary_gates::compute_fraction;pub use ternary_gates::gate_mode_counts;pub use train::EncDecTrainResult;pub use train::TrainResult;pub use train::evaluate_encdec_weights;pub use train::evaluate_weights;pub use train::evaluate_weights_dir;pub use train::random_complex_batch;pub use train::train_butterfly;pub use train::train_butterfly_dir;pub use train::train_butterfly_eager;pub use train::train_encdec;pub use train::train_encdec_eager;pub use train_distill::DistillTrainConfig;pub use train_distill::DistillTrainReport;pub use 
train_distill::distill_from_teacher;pub use train_distill_ternary::DistillTernaryTrainConfig;pub use train_distill_ternary::DistillTernaryTrainReport;pub use train_distill_ternary::distill_ternary_from_distilled;pub use train_distill_ternary::distill_ternary_from_teacher;pub use train_e2e::E2eTrainConfig;pub use train_e2e::E2eTrainReport;pub use train_e2e::train_fast_learned_model;pub use train_multi::MultiTrainEvalRow;pub use train_multi::MultiTrainReport;pub use train_multi::best_regime_per_eval;pub use train_multi::print_multi_train_table;pub use train_multi::run_multi_train;pub use train_multi::write_multi_train_json;pub use train_multi_html::read_multi_train_json;pub use train_multi_html::render_multi_train_html;pub use train_multi_html::write_multi_train_html;pub use train_phased::PhaseMetrics;pub use train_phased::PhasedTrainResult;pub use train_phased::precision_encdec;pub use train_phased::train_phased_encdec;pub use twiddle::TwiddleSet;pub use twiddle::exact_twiddles;pub use twiddle::exact_twiddles_dir;pub use twiddle_stability::lr_for_n_fft;pub use twiddle_stability::max_twiddle_magnitude;pub use twiddle_stability::project_twiddles_unit_circle;pub use twiddle_stability::twiddle_drift_from_unit;pub use variants::FftVariantId;pub use variants::VariantState;pub use weights::EncDecWeights;pub use weights::WeightStore;pub use weights::export_safetensors;pub use weights::load_safetensors;pub use welch_peaks_compile::CompiledLearnedWelchPeaks;pub use welch_peaks_compile::CompiledRlxWelchPeaks;pub use welch_peaks_compile::CompiledRlxWelchPeaksExec;pub use welch_peaks_compile::CompiledRlxWelchPeaksFused;pub use welch_peaks_compile::RlxWelchPeaksExecKind;pub use welch_peaks_compile::compile_learned_welch_peaks;pub use welch_peaks_compile::compile_rlx_welch_peaks;pub use welch_peaks_compile::compile_welch_peaks_fused;pub use welch_peaks_compile::default_welch_peaks_hard_threshold;pub use welch_peaks_compile::rlx_welch_peaks_exec_kind;pub use 
welch_peaks_cost::WelchPeaksCostEstimates;pub use welch_peaks_cost::WelchPeaksFusionGateBreakdown;pub use welch_peaks_cost::algorithm_bandwidth_gbps;pub use welch_peaks_cost::ayala_io_cost_ns;pub use welch_peaks_cost::estimate_welch_peaks_costs;pub use welch_peaks_cost::fused_welch_peaks_auto_viable;pub use welch_peaks_cost::rustfft_peaks_io_profile;pub use welch_peaks_cost::useful_bytes_touched;pub use welch_peaks_cost::welch_peaks_fusion_gate_breakdown;pub use welch_peaks_cost::welch_peaks_fusion_target;pub use welch_peaks_cost::welch_peaks_io_fusion_gate;pub use welch_peaks_picker::AutoWelchPeaks;pub use welch_peaks_picker::WelchPeaksPickBreakdown;pub use welch_peaks_picker::WelchPeaksPickMode;pub use welch_peaks_picker::WelchPeaksStrategy;pub use welch_peaks_picker::all_welch_peaks_strategy_names;pub use welch_peaks_picker::parse_welch_peaks_strategy;pub use welch_peaks_picker::pick_welch_peaks_breakdown;pub use welch_peaks_picker::pick_welch_peaks_strategy;pub use welch_peaks_picker::resolve_welch_peaks_strategy;pub use welch_peaks_picker::rlx_crossover_batch;pub use welch_peaks_picker::ultra_fast_max_batch;
Modules§
- ablation
- Ablation study across FFT variants (Tiers A/B/C).
- ablation_csv
- Ablation results as CSV — source of truth for study HTML reports.
- ablation_html
- Self-contained HTML ablation report (Chart.js + Plotly heatmaps).
- ablation_ternary
- Ternary architecture ablation — speed vs per-pipeline error (Pareto).
- ablation_ternary_html
- HTML report for ternary architecture ablation.
- band_correct
- Banded wide-sparse correction after pruned butterfly (ternary distill).
- bench
- Benchmark learned butterfly FFT vs `rustfft` and native RLX `Op::Fft`.
- bench_encdec
- Encoder–decoder speed and precision benchmarks.
- bench_sweep
- Multi-dimensional benchmark sweep → JSON + ASCII chart.
- bench_sweep_html
- Self-contained HTML benchmark report (Chart.js + Plotly heatmaps / 3D).
- bench_welch_peaks
- Benchmark Welch vs fast top-K peaks (rustfft + compiled + learned).
- butterfly
- Learnable Cooley–Tukey butterfly network (eager CPU + optional IR graph).
- cli
- CLI for training, evaluation, and benchmarking learned FFT / IFFT.
- compile
- Compile RLX training backward graphs on CPU / GPU backends.
- config
- Configuration for learned FFT models.
- denoise
- Learned spectrum denoiser — per-bin affine correction (legacy / teacher paths).
- device
- Training device selection (`--device auto|cpu|metal|…`).
- distill_compile
- Compiled distilled deploy — fused Hann → FFT → correction → `Op::LogMel`.
- distill_fused
- Shared fused deploy helpers — Hann window, correction, log-mel, device routing.
- distill_model
- Distilled fast deploy model — Op::Fft + learned correction (+ optional mel adapter).
- distill_ternary_compile
- Compiled ternary-routed distilled deploy — pruned Hann → sparse butterfly → mel.
- distill_ternary_model
- Distilled ternary-routed FFT — sparse exact butterfly + correction.
- domain
- Domain-adaptive twiddle training (Tier C).
- e2e_bench
- End-to-end validation bench — mel, welch, q8, denoise vs references.
- e2e_bench_html
- Self-contained HTML report for end-to-end learned FFT validation benches.
- fused
- Fused FFT → spectral mask → IFFT (Tier A).
- fused_train
- Fused encoder–decoder training — single forward/backward pass, batched reference FFT.
- learned_compile
- Compiled learned spectrum + mel deploy (Tier D) — gated butterfly + mask + denoiser.
- learned_model
- Fast learned FFT model — pruned butterfly + optional Q8 + denoiser + freq mask (Tier D).
- mel
- Log-mel frontend from power spectrum (Whisper-style, Tier D validation).
- peak
- Welch peak extraction — top-K frequency spikes from PSD (fast path uses fewer segments).
- pruned
- Gated / pruned butterfly — skip butterflies via learnable gates (Tier D).
- q8
- Q8 quantized twiddles for inference (Tier B).
- reference
- Exact FFT reference via `rustfft`.
- rlx_fft
- Native RLX `Op::Fft` graphs — forward and inverse, shared by bench / ablation / variants.
- runner
- Compiled inference session for learned FFT / IFFT.
- second_order
- Second-order / adaptive twiddle optimizers (Adam, diagonal preconditioning, HVP).
- stockham
- DIF (decimation-in-frequency) Cooley–Tukey FFT — natural input, bit-reverse output.
- study_collect
- Collect training telemetry for comprehensive study HTML reports.
- study_full_html
- Full study HTML: ranked ablation configs, loss/error/memory charts, 3D loss maps, activation heatmaps.
- study_html
- Unified HTML study report — delegates to `crate::study_full_html`.
- study_telemetry
- Training telemetry for study reports: loss curves, param counts, activation heatmaps, loss landscape.
- ternary_arch
- Architecture knobs for ternary distilled FFT (ablation + deploy).
- ternary_gates
- Ternary butterfly routing — skip (0), forward (+1), reverse (−1).
- train
- Training loop for butterfly twiddle factors.
- train_distill
- Distill teacher learned model → fast Op::Fft + correction student.
- train_distill_ternary
- Distill teacher → ternary-routed student with compute / precision tradeoff.
- train_e2e
- Train Tier-D fast learned FFT model (pruned + mask + denoiser + optional Q8).
- train_graph
- RLX training graphs — forward + MSE loss + autodiff backward.
- train_multi
- Multi-n_fft encoder–decoder training study and train×eval matrix.
- train_multi_html
- HTML report for multi-n_fft training study (train regime × eval n_fft matrix).
- train_phased
- Three-phase encoder → decoder → joint training.
- train_rlx
- Training via compiled RLX backward graphs (all backends).
- twiddle
- Twiddle-factor initialization for butterfly stages.
- twiddle_stability
- Twiddle update stability — unit-circle projection, LR scaling, gradient clipping.
- unitary
- Learnable 2×2 complex butterfly mixing matrices (Tier C).
- variants
- FFT implementation variants for ablation (Tiers A/B/C + baselines).
- weights
- Named twiddle parameters for training and compiled inference.
- welch
- Welch PSD — windowed overlapping segments, FFT, averaged one-sided power.
- welch_peaks_compile
- Compiled Welch peaks — RLX Op::Fft or learned spectrum + streaming top-K.
- welch_peaks_cost
- Ayala-style latency–bandwidth cost model for Welch peaks strategy selection.
- welch_peaks_picker
- Automatic Welch peaks path — pick fastest strategy from batch size + device.