Crate rlx_fft

Learned FFT — butterfly network trained to match reference FFT, compiled via RLX.

§Overview

This crate learns the twiddle factors of a Cooley–Tukey butterfly network so that the network's output matches rustfft on random input signals. After training, the same graph can be compiled to CPU/GPU backends for batched inference.
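To make the idea of "twiddle factors in a butterfly network" concrete, here is a minimal, dependency-free sketch of a radix-2 decimation-in-time FFT that takes its twiddles from an explicit per-stage table. The type `C32` and the functions `exact_stage_twiddles` and `butterfly_fft` are hypothetical names for illustration and are not this crate's API. In a learned setting, the entries of the table would be the trainable parameters; here they are filled with the exact values, so the network reproduces the ordinary FFT.

```rust
// Illustrative sketch only: names and layout are assumptions, not rlx_fft's API.
#[derive(Clone, Copy, Debug, PartialEq)]
struct C32 {
    re: f32,
    im: f32,
}

fn cmul(a: C32, b: C32) -> C32 {
    C32 {
        re: a.re * b.re - a.im * b.im,
        im: a.re * b.im + a.im * b.re,
    }
}

/// Exact twiddles for a stage whose butterflies span `2 * half` samples:
/// w_k = exp(-2*pi*i * k / (2 * half)) for k in 0..half.
fn exact_stage_twiddles(half: usize) -> Vec<C32> {
    (0..half)
        .map(|k| {
            let theta = -std::f32::consts::PI * k as f32 / half as f32;
            C32 { re: theta.cos(), im: theta.sin() }
        })
        .collect()
}

/// In-place radix-2 decimation-in-time FFT.
/// `twiddles[s]` holds the `2^s` factors used by stage `s`.
fn butterfly_fft(x: &mut [C32], twiddles: &[Vec<C32>]) {
    let n = x.len();
    assert!(n.is_power_of_two(), "length must be a power of two");
    let stages = n.trailing_zeros() as usize;
    assert_eq!(twiddles.len(), stages, "one twiddle table per stage");

    // Bit-reversal permutation so the stages below emit natural order.
    if n > 1 {
        let shift = usize::BITS - n.trailing_zeros();
        for i in 0..n {
            let j = i.reverse_bits() >> shift;
            if i < j {
                x.swap(i, j);
            }
        }
    }

    for s in 0..stages {
        let half = 1usize << s;
        for start in (0..n).step_by(2 * half) {
            for k in 0..half {
                // One butterfly: (u, v) -> (u + w*v, u - w*v).
                let t = cmul(twiddles[s][k], x[start + k + half]);
                let u = x[start + k];
                x[start + k] = C32 { re: u.re + t.re, im: u.im + t.im };
                x[start + k + half] = C32 { re: u.re - t.re, im: u.im - t.im };
            }
        }
    }
}
```

For an 8-point transform the table has three stages, built with `(0..3).map(|s| exact_stage_twiddles(1 << s))`. Replacing those entries with perturbed values and measuring the error against a reference FFT gives the kind of objective a training loop could minimize.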

§Quick start

use rlx_fft::{FftLearnConfig, FftLearnRunner, TrainConfig, train_butterfly};

// Build the model configuration.
let cfg = FftLearnConfig::new(256, 8)?;

// Train the butterfly twiddles for 200 steps; other settings use defaults.
let report = train_butterfly(&TrainConfig {
    model: cfg.clone(),
    steps: 200,
    ..TrainConfig::default()
})?;
println!("mse={} max_err={}", report.final_mse, report.max_error);

// Create an inference runner from the trained weights.
let runner = FftLearnRunner::with_weights(cfg, &report.weights)?;

§Welch peaks

Fast top-K spike extraction with a strategy picker that can run automatically (AutoWelchPeaks) or be forced (--strategy on bench-welch-peaks). See the Welch peaks section of crates/rlx-fft/README.md in this repo.
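As a rough illustration of what top-K spike extraction from a power spectral density can look like, the sketch below keeps the K strongest local maxima of a one-sided PSD. It assumes that a "peak" means a bin larger than its left neighbour and at least as large as its right neighbour; the crate's own definition, and the function name `top_k_peaks`, are not taken from the source and may differ.

```rust
// Illustrative sketch only: not the crate's implementation.
/// Return up to `k` local maxima of `psd` as (bin index, power),
/// ordered from strongest to weakest.
fn top_k_peaks(psd: &[f32], k: usize) -> Vec<(usize, f32)> {
    // Interior bins only; the two edge bins have a single neighbour.
    let mut peaks: Vec<(usize, f32)> = (1..psd.len().saturating_sub(1))
        .filter(|&i| psd[i] > psd[i - 1] && psd[i] >= psd[i + 1])
        .map(|i| (i, psd[i]))
        .collect();

    // Sort by descending power; total_cmp gives a total order on f32.
    peaks.sort_by(|a, b| b.1.total_cmp(&a.1));
    peaks.truncate(k);
    peaks
}
```

A full sort is the simplest choice here; for small K and long spectra, a partial selection would avoid sorting every candidate.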

Re-exports§

pub use ablation::AblationReport;
pub use ablation::AblationRow;
pub use ablation::ablation_row_ok;
pub use ablation::ablation_winners;
pub use ablation::limit_sweep_devices;
pub use ablation::merge_ablation_reports;
pub use ablation::print_ablation_table;
pub use ablation::run_ablation;
pub use ablation::run_limit_sweep;
pub use ablation::tier_summary;
pub use ablation::top5_variants_per_n_fft;
pub use ablation::write_ablation_json;
pub use ablation_csv::LIMITS_CSV;
pub use ablation_csv::META_CSV;
pub use ablation_csv::ROWS_CSV;
pub use ablation_csv::TOP5_CSV;
pub use ablation_csv::read_ablation_csv_dir;
pub use ablation_csv::read_ablation_rows_csv;
pub use ablation_csv::write_ablation_csv_dir;
pub use ablation_html::read_ablation_json;
pub use ablation_html::render_ablation_html;
pub use ablation_html::write_ablation_html;
pub use ablation_ternary::TernaryAblationOpts;
pub use ablation_ternary::TernaryAblationReport;
pub use ablation_ternary::TernaryAblationRow;
pub use ablation_ternary::TernaryArchVariantId;
pub use ablation_ternary::TernaryExecMode;
pub use ablation_ternary::TernaryParetoPoint;
pub use ablation_ternary::print_ternary_ablation_table;
pub use ablation_ternary::quick_ablation_opts;
pub use ablation_ternary::run_ternary_ablation;
pub use ablation_ternary::ternary_ablation_row_ok;
pub use ablation_ternary::ternary_aggregate_variants;
pub use ablation_ternary::ternary_pareto_frontier;
pub use ablation_ternary::ternary_recommendation;
pub use ablation_ternary::write_ternary_ablation_csv;
pub use ablation_ternary::write_ternary_ablation_json;
pub use ablation_ternary_html::read_ternary_ablation_json;
pub use ablation_ternary_html::render_ternary_ablation_html;
pub use ablation_ternary_html::write_ternary_ablation_html;
pub use bench::BenchReport;
pub use bench::bench_all;
pub use bench::bench_all_dir;
pub use bench::bench_reference_vs_learned;
pub use bench::bench_reference_vs_learned_dir;
pub use bench_encdec::EncDecBenchRow;
pub use bench_encdec::bench_encdec_weights;
pub use bench_encdec::bench_exact_baseline;
pub use bench_encdec::bench_phased_dir;
pub use bench_encdec::print_encdec_bench_table;
pub use bench_encdec::write_encdec_bench_json;
pub use bench_sweep::SweepReport;
pub use bench_sweep::SweepRow;
pub use bench_sweep::available_devices;
pub use bench_sweep::parse_batch_spec;
pub use bench_sweep::parse_csv_usize;
pub use bench_sweep::parse_k_spec;
pub use bench_sweep::print_sweep_chart;
pub use bench_sweep::run_sweep;
pub use bench_sweep::sweep_markdown_chart;
pub use bench_sweep::write_sweep_json;
pub use bench_sweep_html::read_sweep_json;
pub use bench_sweep_html::render_sweep_html;
pub use bench_sweep_html::write_sweep_html;
pub use bench_welch_peaks::WelchPeaksBenchOpts;
pub use bench_welch_peaks::WelchPeaksBenchReport;
pub use bench_welch_peaks::WelchPeaksBenchRow;
pub use bench_welch_peaks::print_welch_peaks_table;
pub use bench_welch_peaks::run_welch_peaks_batch_sweep;
pub use bench_welch_peaks::run_welch_peaks_bench;
pub use bench_welch_peaks::run_welch_peaks_bench_opts;
pub use bench_welch_peaks::run_welch_peaks_k_sweep;
pub use bench_welch_peaks::run_welch_peaks_sweep;
pub use bench_welch_peaks::write_welch_peaks_json;
pub use config::EncDecTrainConfig;
pub use config::FftLearnConfig;
pub use config::MultiTrainConfig;
pub use config::MultiTrainSchedule;
pub use config::PhasedTrainConfig;
pub use config::SUPPORTED_N_FFT;
pub use config::TrainConfig;
pub use config::TransformDir;
pub use config::parse_transform_dir;
pub use device::bench_device_label;
pub use device::ensure_backend_ready;
pub use device::normalize_device_alias;
pub use device::parse_bench_device_list;
pub use device::pick_auto_device;
pub use device::resolve_train_device;
pub use distill_compile::CompiledDistilledMel;
pub use distill_compile::compile_distilled_mel;
pub use distill_model::DistilledFftModel;
pub use distill_ternary_compile::CompiledDistilledTernaryMel;
pub use distill_ternary_compile::compile_distilled_ternary_mel;
pub use distill_ternary_model::DistilledTernaryFftModel;
pub use e2e_bench::E2eBackend;
pub use e2e_bench::E2eBatchTrainMeta;
pub use e2e_bench::E2eBenchMeta;
pub use e2e_bench::E2eBenchReport;
pub use e2e_bench::E2eBenchRow;
pub use e2e_bench::E2ePipeline;
pub use e2e_bench::merge_e2e_reports;
pub use e2e_bench::print_e2e_table;
pub use e2e_bench::read_e2e_json;
pub use e2e_bench::run_e2e_bench;
pub use e2e_bench::write_e2e_json;
pub use e2e_bench_html::render_e2e_html;
pub use e2e_bench_html::write_e2e_html;
pub use learned_model::FastLearnedFftModel;
pub use peak::DEFAULT_PEAK_K;
pub use peak::WelchPeakParams;
pub use peak::WelchPeaksScratch;
pub use peak::peak_band_mask;
pub use peak::peak_loss_grad_wrt_spectrum;
pub use peak::peak_match_loss;
pub use peak::peak_max_err;
pub use peak::peaks_from_psd_batch;
pub use peak::peaks_from_segment_spectrum_streaming;
pub use peak::topk_peaks_one;
pub use peak::welch_peaks_from_segment_spectrum;
pub use peak::welch_peaks_rustfft;
pub use peak::welch_peaks_rustfft_with_scratch;
pub use runner::FftLearnRunner;
pub use second_order::TwiddleOptState;
pub use second_order::TwiddleOptimizer;
pub use second_order::diag_gn_step;
pub use second_order::hvp_twiddles_finite_diff;
pub use study_html::StudyInputs;
pub use study_html::render_study_html;
pub use study_html::write_study_html;
pub use ternary_arch::CorrectorKind;
pub use ternary_arch::GateLayout;
pub use ternary_arch::SpectrumCorrection;
pub use ternary_arch::TernaryArchConfig;
pub use ternary_gates::GateMode;
pub use ternary_gates::compute_fraction;
pub use ternary_gates::gate_mode_counts;
pub use train::EncDecTrainResult;
pub use train::TrainResult;
pub use train::evaluate_encdec_weights;
pub use train::evaluate_weights;
pub use train::evaluate_weights_dir;
pub use train::random_complex_batch;
pub use train::train_butterfly;
pub use train::train_butterfly_dir;
pub use train::train_butterfly_eager;
pub use train::train_encdec;
pub use train::train_encdec_eager;
pub use train_distill::DistillTrainConfig;
pub use train_distill::DistillTrainReport;
pub use train_distill::distill_from_teacher;
pub use train_distill_ternary::DistillTernaryTrainConfig;
pub use train_distill_ternary::DistillTernaryTrainReport;
pub use train_distill_ternary::distill_ternary_from_distilled;
pub use train_distill_ternary::distill_ternary_from_teacher;
pub use train_e2e::E2eTrainConfig;
pub use train_e2e::E2eTrainReport;
pub use train_e2e::train_fast_learned_model;
pub use train_multi::MultiTrainEvalRow;
pub use train_multi::MultiTrainReport;
pub use train_multi::best_regime_per_eval;
pub use train_multi::print_multi_train_table;
pub use train_multi::run_multi_train;
pub use train_multi::write_multi_train_json;
pub use train_multi_html::read_multi_train_json;
pub use train_multi_html::render_multi_train_html;
pub use train_multi_html::write_multi_train_html;
pub use train_phased::PhaseMetrics;
pub use train_phased::PhasedTrainResult;
pub use train_phased::precision_encdec;
pub use train_phased::train_phased_encdec;
pub use twiddle::TwiddleSet;
pub use twiddle::exact_twiddles;
pub use twiddle::exact_twiddles_dir;
pub use twiddle_stability::lr_for_n_fft;
pub use twiddle_stability::max_twiddle_magnitude;
pub use twiddle_stability::project_twiddles_unit_circle;
pub use twiddle_stability::twiddle_drift_from_unit;
pub use variants::FftVariantId;
pub use variants::VariantState;
pub use weights::EncDecWeights;
pub use weights::WeightStore;
pub use weights::export_safetensors;
pub use weights::load_safetensors;
pub use welch_peaks_compile::CompiledLearnedWelchPeaks;
pub use welch_peaks_compile::CompiledRlxWelchPeaks;
pub use welch_peaks_compile::CompiledRlxWelchPeaksExec;
pub use welch_peaks_compile::CompiledRlxWelchPeaksFused;
pub use welch_peaks_compile::RlxWelchPeaksExecKind;
pub use welch_peaks_compile::compile_learned_welch_peaks;
pub use welch_peaks_compile::compile_rlx_welch_peaks;
pub use welch_peaks_compile::compile_welch_peaks_fused;
pub use welch_peaks_compile::default_welch_peaks_hard_threshold;
pub use welch_peaks_compile::rlx_welch_peaks_exec_kind;
pub use welch_peaks_cost::WelchPeaksCostEstimates;
pub use welch_peaks_cost::WelchPeaksFusionGateBreakdown;
pub use welch_peaks_cost::algorithm_bandwidth_gbps;
pub use welch_peaks_cost::ayala_io_cost_ns;
pub use welch_peaks_cost::estimate_welch_peaks_costs;
pub use welch_peaks_cost::fused_welch_peaks_auto_viable;
pub use welch_peaks_cost::rustfft_peaks_io_profile;
pub use welch_peaks_cost::useful_bytes_touched;
pub use welch_peaks_cost::welch_peaks_fusion_gate_breakdown;
pub use welch_peaks_cost::welch_peaks_fusion_target;
pub use welch_peaks_cost::welch_peaks_io_fusion_gate;
pub use welch_peaks_picker::AutoWelchPeaks;
pub use welch_peaks_picker::WelchPeaksPickBreakdown;
pub use welch_peaks_picker::WelchPeaksPickMode;
pub use welch_peaks_picker::WelchPeaksStrategy;
pub use welch_peaks_picker::all_welch_peaks_strategy_names;
pub use welch_peaks_picker::parse_welch_peaks_strategy;
pub use welch_peaks_picker::pick_welch_peaks_breakdown;
pub use welch_peaks_picker::pick_welch_peaks_strategy;
pub use welch_peaks_picker::resolve_welch_peaks_strategy;
pub use welch_peaks_picker::rlx_crossover_batch;
pub use welch_peaks_picker::ultra_fast_max_batch;

Modules§

ablation
Ablation study across FFT variants (Tiers A/B/C).
ablation_csv
Ablation results as CSV — source of truth for study HTML reports.
ablation_html
Self-contained HTML ablation report (Chart.js + Plotly heatmaps).
ablation_ternary
Ternary architecture ablation — speed vs per-pipeline error (Pareto).
ablation_ternary_html
HTML report for ternary architecture ablation.
band_correct
Banded wide-sparse correction after pruned butterfly (ternary distill).
bench
Benchmark learned butterfly FFT vs rustfft and native RLX Op::Fft.
bench_encdec
Encoder–decoder speed and precision benchmarks.
bench_sweep
Multi-dimensional benchmark sweep → JSON + ASCII chart.
bench_sweep_html
Self-contained HTML benchmark report (Chart.js + Plotly heatmaps / 3D).
bench_welch_peaks
Benchmark Welch vs fast top-K peaks (rustfft + compiled + learned).
butterfly
Learnable Cooley–Tukey butterfly network (eager CPU + optional IR graph).
cli
CLI for training, evaluation, and benchmarking learned FFT / IFFT.
compile
Compile RLX training backward graphs on CPU / GPU backends.
config
Configuration for learned FFT models.
denoise
Learned spectrum denoiser — per-bin affine correction (legacy / teacher paths).
device
Training device selection (--device auto|cpu|metal|…).
distill_compile
Compiled distilled deploy — fused Hann → FFT → correction → Op::LogMel.
distill_fused
Shared fused deploy helpers — Hann window, correction, log-mel, device routing.
distill_model
Distilled fast deploy model — Op::Fft + learned correction (+ optional mel adapter).
distill_ternary_compile
Compiled ternary-routed distilled deploy — pruned Hann → sparse butterfly → mel.
distill_ternary_model
Distilled ternary-routed FFT — sparse exact butterfly + correction.
domain
Domain-adaptive twiddle training (Tier C).
e2e_bench
End-to-end validation bench — mel, welch, q8, denoise vs references.
e2e_bench_html
Self-contained HTML report for end-to-end learned FFT validation benches.
fused
Fused FFT → spectral mask → IFFT (Tier A).
fused_train
Fused encoder–decoder training — single forward/backward pass, batched reference FFT.
learned_compile
Compiled learned spectrum + mel deploy (Tier D) — gated butterfly + mask + denoiser.
learned_model
Fast learned FFT model — pruned butterfly + optional Q8 + denoiser + freq mask (Tier D).
mel
Log-mel frontend from power spectrum (Whisper-style, Tier D validation).
peak
Welch peak extraction — top-K frequency spikes from PSD (fast path uses fewer segments).
pruned
Gated / pruned butterfly — skip butterflies via learnable gates (Tier D).
q8
Q8 quantized twiddles for inference (Tier B).
reference
Exact FFT reference via rustfft.
rlx_fft
Native RLX Op::Fft graphs — forward and inverse, shared by bench / ablation / variants.
runner
Compiled inference session for learned FFT / IFFT.
second_order
Second-order / adaptive twiddle optimizers (Adam, diagonal preconditioning, HVP).
stockham
DIF (decimation-in-frequency) Cooley–Tukey FFT — natural input, bit-reverse output.
study_collect
Collect training telemetry for comprehensive study HTML reports.
study_full_html
Full study HTML: ranked ablation configs, loss/error/memory charts, 3D loss maps, activation heatmaps.
study_html
Unified HTML study report — delegates to crate::study_full_html.
study_telemetry
Training telemetry for study reports: loss curves, param counts, activation heatmaps, loss landscape.
ternary_arch
Architecture knobs for ternary distilled FFT (ablation + deploy).
ternary_gates
Ternary butterfly routing — skip (0), forward (+1), reverse (−1).
train
Training loop for butterfly twiddle factors.
train_distill
Distill teacher learned model → fast Op::Fft + correction student.
train_distill_ternary
Distill teacher → ternary-routed student with compute / precision tradeoff.
train_e2e
Train Tier-D fast learned FFT model (pruned + mask + denoiser + optional Q8).
train_graph
RLX training graphs — forward + MSE loss + autodiff backward.
train_multi
Multi-n_fft encoder–decoder training study and train×eval matrix.
train_multi_html
HTML report for multi-n_fft training study (train regime × eval n_fft matrix).
train_phased
Three-phase encoder → decoder → joint training.
train_rlx
Training via compiled RLX backward graphs (all backends).
twiddle
Twiddle-factor initialization for butterfly stages.
twiddle_stability
Twiddle update stability — unit-circle projection, LR scaling, gradient clipping.
unitary
Learnable 2×2 complex butterfly mixing matrices (Tier C).
variants
FFT implementation variants for ablation (Tiers A/B/C + baselines).
weights
Named twiddle parameters for training and compiled inference.
welch
Welch PSD — windowed overlapping segments, FFT, averaged one-sided power.
welch_peaks_compile
Compiled Welch peaks — RLX Op::Fft or learned spectrum + streaming top-K.
welch_peaks_cost
Ayala-style latency–bandwidth cost model for Welch peaks strategy selection.
welch_peaks_picker
Automatic Welch peaks path — pick fastest strategy from batch size + device.
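
The twiddle_stability entry above mentions unit-circle projection. Exact FFT twiddles all have magnitude 1, so one plausible way to keep learned twiddles well-behaved is to renormalize them after each update. The sketch below shows that idea on (re, im) pairs; the names `project_unit_circle` and `max_unit_drift` are hypothetical, and this is not claimed to match the crate's `project_twiddles_unit_circle` or `twiddle_drift_from_unit`.

```rust
// Illustrative sketch only: an assumed form of unit-circle projection.
/// Rescale every twiddle to magnitude 1, keeping its phase.
fn project_unit_circle(twiddles: &mut [(f32, f32)]) {
    for (re, im) in twiddles.iter_mut() {
        let mag = (*re * *re + *im * *im).sqrt();
        if mag > f32::EPSILON {
            *re /= mag;
            *im /= mag;
        } else {
            // A (near-)zero twiddle has no usable phase; reset it to 1 + 0i.
            *re = 1.0;
            *im = 0.0;
        }
    }
}

/// Largest deviation of any twiddle magnitude from 1.
fn max_unit_drift(twiddles: &[(f32, f32)]) -> f32 {
    twiddles
        .iter()
        .map(|(re, im)| ((re * re + im * im).sqrt() - 1.0).abs())
        .fold(0.0, f32::max)
}
```

Calling `max_unit_drift` before and after `project_unit_circle` gives a quick check of how far an optimizer step has pushed the parameters away from a valid rotation.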