Skip to main content

rlx_fft/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Learned FFT — a butterfly network trained to match a reference FFT, compiled via RLX.
17//!
18//! # Overview
19//!
20//! This crate learns twiddle factors in a Cooley–Tukey butterfly network so the
21//! transform matches [`rustfft`] on random signals. After training, the same graph
22//! can be compiled to CPU/GPU backends for batched inference.
23//!
24//! # Quick start
25//!
26//! ```no_run
27//! use rlx_fft::{FftLearnConfig, FftLearnRunner, TrainConfig, train_butterfly};
28//!
29//! # fn main() -> anyhow::Result<()> {
30//! let cfg = FftLearnConfig::new(256, 8)?;
31//! let report = train_butterfly(&TrainConfig {
32//!     model: cfg.clone(),
33//!     steps: 200,
34//!     ..TrainConfig::default()
35//! })?;
36//! println!("mse={} max_err={}", report.final_mse, report.max_error);
37//!
38//! let runner = FftLearnRunner::with_weights(cfg, &report.weights)?;
39//! # Ok(())
40//! # }
41//! ```
42//!
43//! # Welch peaks
44//!
45//! Fast top-K spike extraction with a strategy picker that either chooses automatically
46//! ([`AutoWelchPeaks`]) or is forced (the `--strategy` flag of `bench-welch-peaks`). See the
47//! Welch peaks section of `crates/rlx-fft/README.md` in this repository.
48
49pub mod ablation;
50pub mod ablation_csv;
51pub mod ablation_html;
52pub mod ablation_ternary;
53pub mod ablation_ternary_html;
54pub mod band_correct;
55pub mod bench;
56pub mod bench_encdec;
57#[cfg(feature = "dev")]
58pub mod bench_fusion_phases;
59pub mod bench_sweep;
60pub mod bench_sweep_html;
61pub mod bench_welch_peaks;
62pub mod butterfly;
63pub mod compile;
64pub mod config;
65pub mod denoise;
66pub mod device;
67pub mod distill_compile;
68pub mod distill_fused;
69pub mod distill_model;
70pub mod distill_ternary_compile;
71pub mod distill_ternary_model;
72pub mod domain;
73pub mod e2e_bench;
74pub mod e2e_bench_html;
75pub mod fused;
76pub mod fused_train;
77pub mod learned_compile;
78pub mod learned_model;
79pub mod mel;
80pub mod peak;
81pub mod pruned;
82pub mod q8;
83pub mod reference;
84pub mod rlx_fft;
85pub mod runner;
86pub mod second_order;
87pub mod stockham;
88pub mod study_collect;
89pub mod study_full_html;
90pub mod study_html;
91pub mod study_telemetry;
92pub mod ternary_arch;
93pub mod ternary_gates;
94pub mod train;
95pub mod train_distill;
96pub mod train_distill_ternary;
97pub mod train_e2e;
98pub mod train_graph;
99pub mod train_multi;
100pub mod train_multi_html;
101pub mod train_phased;
102pub mod train_rlx;
103pub mod twiddle;
104pub mod twiddle_stability;
105pub mod unitary;
106pub mod variants;
107pub mod weights;
108pub mod welch;
109pub mod welch_peaks_compile;
110pub mod welch_peaks_cost;
111pub mod welch_peaks_picker;
112
113pub mod cli;
114
115pub use ablation::{
116    AblationReport, AblationRow, ablation_row_ok, ablation_winners, limit_sweep_devices,
117    merge_ablation_reports, print_ablation_table, run_ablation, run_limit_sweep, tier_summary,
118    top5_variants_per_n_fft, write_ablation_json,
119};
120pub use ablation_csv::{
121    LIMITS_CSV, META_CSV, ROWS_CSV, TOP5_CSV, read_ablation_csv_dir, read_ablation_rows_csv,
122    write_ablation_csv_dir,
123};
124pub use ablation_html::{read_ablation_json, render_ablation_html, write_ablation_html};
125pub use ablation_ternary::{
126    TernaryAblationOpts, TernaryAblationReport, TernaryAblationRow, TernaryArchVariantId,
127    TernaryExecMode, TernaryParetoPoint, print_ternary_ablation_table, quick_ablation_opts,
128    run_ternary_ablation, ternary_ablation_row_ok, ternary_aggregate_variants,
129    ternary_pareto_frontier, ternary_recommendation, write_ternary_ablation_csv,
130    write_ternary_ablation_json,
131};
132pub use ablation_ternary_html::{
133    read_ternary_ablation_json, render_ternary_ablation_html, write_ternary_ablation_html,
134};
135pub use bench::{
136    BenchReport, bench_all, bench_all_dir, bench_reference_vs_learned,
137    bench_reference_vs_learned_dir,
138};
139pub use bench_encdec::{
140    EncDecBenchRow, bench_encdec_weights, bench_exact_baseline, bench_phased_dir,
141    print_encdec_bench_table, write_encdec_bench_json,
142};
143pub use bench_sweep::{
144    SweepReport, SweepRow, available_devices, parse_batch_spec, parse_csv_usize, parse_k_spec,
145    print_sweep_chart, run_sweep, sweep_markdown_chart, write_sweep_json,
146};
147pub use bench_sweep_html::{read_sweep_json, render_sweep_html, write_sweep_html};
148pub use bench_welch_peaks::{
149    WelchPeaksBenchOpts, WelchPeaksBenchReport, WelchPeaksBenchRow, print_welch_peaks_table,
150    run_welch_peaks_batch_sweep, run_welch_peaks_bench, run_welch_peaks_bench_opts,
151    run_welch_peaks_k_sweep, run_welch_peaks_sweep, write_welch_peaks_json,
152};
153pub use config::{
154    EncDecTrainConfig, FftLearnConfig, MultiTrainConfig, MultiTrainSchedule, PhasedTrainConfig,
155    SUPPORTED_N_FFT, TrainConfig, TransformDir, parse_transform_dir,
156};
157pub use device::{
158    bench_device_label, ensure_backend_ready, normalize_device_alias, parse_bench_device_list,
159    pick_auto_device, resolve_train_device,
160};
161pub use distill_compile::{CompiledDistilledMel, compile_distilled_mel};
162pub use distill_model::DistilledFftModel;
163pub use distill_ternary_compile::{CompiledDistilledTernaryMel, compile_distilled_ternary_mel};
164pub use distill_ternary_model::DistilledTernaryFftModel;
165pub use e2e_bench::{
166    E2eBackend, E2eBatchTrainMeta, E2eBenchMeta, E2eBenchReport, E2eBenchRow, E2ePipeline,
167    merge_e2e_reports, print_e2e_table, read_e2e_json, run_e2e_bench, write_e2e_json,
168};
169pub use e2e_bench_html::{render_e2e_html, write_e2e_html};
170pub use learned_model::FastLearnedFftModel;
171pub use peak::{
172    DEFAULT_PEAK_K, WelchPeakParams, WelchPeaksScratch, peak_band_mask,
173    peak_loss_grad_wrt_spectrum, peak_match_loss, peak_max_err, peaks_from_psd_batch,
174    peaks_from_segment_spectrum_streaming, topk_peaks_one, welch_peaks_from_segment_spectrum,
175    welch_peaks_rustfft, welch_peaks_rustfft_with_scratch,
176};
177pub use runner::FftLearnRunner;
178pub use second_order::{TwiddleOptState, TwiddleOptimizer, diag_gn_step, hvp_twiddles_finite_diff};
179pub use study_html::{StudyInputs, render_study_html, write_study_html};
180pub use ternary_arch::{CorrectorKind, GateLayout, SpectrumCorrection, TernaryArchConfig};
181pub use ternary_gates::{GateMode, compute_fraction, gate_mode_counts};
182pub use train::{
183    EncDecTrainResult, TrainResult, evaluate_encdec_weights, evaluate_weights,
184    evaluate_weights_dir, random_complex_batch, train_butterfly, train_butterfly_dir,
185    train_butterfly_eager, train_encdec, train_encdec_eager,
186};
187pub use train_distill::{DistillTrainConfig, DistillTrainReport, distill_from_teacher};
188pub use train_distill_ternary::{
189    DistillTernaryTrainConfig, DistillTernaryTrainReport, distill_ternary_from_distilled,
190    distill_ternary_from_teacher,
191};
192pub use train_e2e::{E2eTrainConfig, E2eTrainReport, train_fast_learned_model};
193pub use train_multi::{
194    MultiTrainEvalRow, MultiTrainReport, best_regime_per_eval, print_multi_train_table,
195    run_multi_train, write_multi_train_json,
196};
197pub use train_multi_html::{
198    read_multi_train_json, render_multi_train_html, write_multi_train_html,
199};
200pub use train_phased::{PhaseMetrics, PhasedTrainResult, precision_encdec, train_phased_encdec};
201pub use twiddle::{TwiddleSet, exact_twiddles, exact_twiddles_dir};
202pub use twiddle_stability::{
203    lr_for_n_fft, max_twiddle_magnitude, project_twiddles_unit_circle, twiddle_drift_from_unit,
204};
205pub use variants::{FftVariantId, VariantState};
206pub use weights::{EncDecWeights, WeightStore, export_safetensors, load_safetensors};
207pub use welch_peaks_compile::{
208    CompiledLearnedWelchPeaks, CompiledRlxWelchPeaks, CompiledRlxWelchPeaksExec,
209    CompiledRlxWelchPeaksFused, RlxWelchPeaksExecKind, compile_learned_welch_peaks,
210    compile_rlx_welch_peaks, compile_welch_peaks_fused, default_welch_peaks_hard_threshold,
211    rlx_welch_peaks_exec_kind,
212};
213pub use welch_peaks_cost::{
214    WelchPeaksCostEstimates, WelchPeaksFusionGateBreakdown, algorithm_bandwidth_gbps,
215    ayala_io_cost_ns, estimate_welch_peaks_costs, fused_welch_peaks_auto_viable,
216    rustfft_peaks_io_profile, useful_bytes_touched, welch_peaks_fusion_gate_breakdown,
217    welch_peaks_fusion_target, welch_peaks_io_fusion_gate,
218};
219pub use welch_peaks_picker::{
220    AutoWelchPeaks, WelchPeaksPickBreakdown, WelchPeaksPickMode, WelchPeaksStrategy,
221    all_welch_peaks_strategy_names, parse_welch_peaks_strategy, pick_welch_peaks_breakdown,
222    pick_welch_peaks_strategy, resolve_welch_peaks_strategy, rlx_crossover_batch,
223    ultra_fast_max_batch,
224};