use crate::device::{ensure_backend_ready, resolve_train_device};
use crate::learned_model::FastLearnedFftModel;
use crate::peak::{WelchPeakParams, WelchPeaksScratch, welch_peaks_rustfft_with_scratch};
use crate::pruned::DEFAULT_GATE_THRESHOLD;
use crate::welch::WelchParams;
use crate::welch_peaks_compile::{
CompiledLearnedWelchPeaks, CompiledRlxWelchPeaksFused, compile_learned_welch_peaks,
default_welch_peaks_hard_threshold,
};
use crate::welch_peaks_cost::{WelchPeaksCostEstimates, estimate_welch_peaks_costs, is_gpu_device};
use anyhow::{Result, bail, ensure};
use rlx_runtime::Device;
use serde::{Deserialize, Serialize};
/// Execution strategy used to compute Welch peaks for a batch of frames.
///
/// `AutoWelchPeaks` dispatches on this value: the first two variants run through
/// `welch_peaks_rustfft_with_scratch`, the last two through pre-compiled graphs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WelchPeaksStrategy {
/// rustfft path using `WelchPeakParams::ultra_fast_for_n_fft` parameters.
UltraFast,
/// rustfft path using `WelchPeakParams::fast_for_n_fft` parameters.
FastStreaming,
/// Runs through a `CompiledRlxWelchPeaksFused` graph compiled for the target device.
RlxCompiled,
/// Runs through a graph compiled from a `FastLearnedFftModel`; requires a model.
LearnedCompiled,
}
impl WelchPeaksStrategy {
pub fn label(self) -> &'static str {
match self {
Self::UltraFast => "ultra_fast_rustfft",
Self::FastStreaming => "fast_streaming_rustfft",
Self::RlxCompiled => "rlx_compiled",
Self::LearnedCompiled => "learned_compiled",
}
}
}
/// Smallest batch size at which the legacy picker switches to a compiled strategy.
///
/// Returns `8192` for devices that `is_gpu_device` reports as GPUs and
/// `usize::MAX` for everything else.
pub fn rlx_crossover_batch(device: Device) -> usize {
    if !is_gpu_device(device) {
        return usize::MAX;
    }
    8192
}
/// Largest batch size for which the legacy picker still chooses
/// `WelchPeaksStrategy::UltraFast`: `128` on GPU devices, `256` otherwise.
pub fn ultra_fast_max_batch(device: Device) -> usize {
    const GPU_LIMIT: usize = 128;
    const NON_GPU_LIMIT: usize = 256;

    match is_gpu_device(device) {
        true => GPU_LIMIT,
        false => NON_GPU_LIMIT,
    }
}
/// Result of an automatic strategy pick together with the cost estimates behind it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WelchPeaksPickBreakdown {
/// Per-strategy cost estimates the pick was based on. `pick_welch_peaks_breakdown`
/// fills every field with `0.0` when the `RLX_FFT_LEGACY_PICKER` flag is set, because
/// the legacy picker does not use the cost model.
pub costs: WelchPeaksCostEstimates,
/// Strategy selected by the picker.
pub picked: WelchPeaksStrategy,
}
/// Threshold-based picker, used by `pick_welch_peaks_breakdown` when the
/// `RLX_FFT_LEGACY_PICKER` flag is set.
///
/// Decision order:
/// 1. GPU device with `batch >= rlx_crossover_batch(device)`: `LearnedCompiled` when a
///    learned model is available and sparse, otherwise `RlxCompiled`.
/// 2. `batch <= ultra_fast_max_batch(device)`: `UltraFast`.
/// 3. Anything else: `FastStreaming`.
fn legacy_pick_welch_peaks_strategy(
    device: Device,
    batch: usize,
    learned_available: bool,
    learned_active_gates: Option<usize>,
    learned_total_gates: usize,
) -> WelchPeaksStrategy {
    // A model counts as sparse when fewer than a quarter of its gates are active.
    let model_is_sparse = matches!(
        learned_active_gates,
        Some(active) if learned_total_gates > 0 && active * 4 < learned_total_gates
    );
    let gpu_large_batch = batch >= rlx_crossover_batch(device) && is_gpu_device(device);

    if gpu_large_batch {
        return if learned_available && model_is_sparse {
            WelchPeaksStrategy::LearnedCompiled
        } else {
            WelchPeaksStrategy::RlxCompiled
        };
    }

    if batch <= ultra_fast_max_batch(device) {
        WelchPeaksStrategy::UltraFast
    } else {
        WelchPeaksStrategy::FastStreaming
    }
}
/// Picks the strategy with the smallest estimated cost.
///
/// Candidates are compared in the order ultra, fast, rlx, learned using a strict `<`,
/// so on a tie the earlier candidate wins. A NaN estimate never compares as smaller
/// and therefore never displaces the current best.
fn pick_from_costs(costs: WelchPeaksCostEstimates) -> WelchPeaksStrategy {
    let candidates = [
        (costs.ultra_ns, WelchPeaksStrategy::UltraFast),
        (costs.fast_ns, WelchPeaksStrategy::FastStreaming),
        (costs.rlx_ns, WelchPeaksStrategy::RlxCompiled),
        (costs.learned_ns, WelchPeaksStrategy::LearnedCompiled),
    ];

    let (_, winner) = candidates[1..]
        .iter()
        .fold(candidates[0], |leader, &challenger| {
            if challenger.0 < leader.0 {
                challenger
            } else {
                leader
            }
        });
    winner
}
/// Chooses a strategy automatically and reports the cost estimates behind the choice.
///
/// With the `RLX_FFT_LEGACY_PICKER` flag set, the threshold-based legacy picker is used
/// and every cost in the result is `0.0` (`n_fft` and `k` are ignored on that path).
/// Otherwise the costs come from `estimate_welch_peaks_costs` and the cheapest
/// strategy is picked.
pub fn pick_welch_peaks_breakdown(
    device: Device,
    batch: usize,
    n_fft: usize,
    k: usize,
    learned_available: bool,
    learned_active_gates: Option<usize>,
    learned_total_gates: usize,
) -> WelchPeaksPickBreakdown {
    let use_legacy = rlx_ir::env::flag("RLX_FFT_LEGACY_PICKER");

    let (costs, picked) = if use_legacy {
        let choice = legacy_pick_welch_peaks_strategy(
            device,
            batch,
            learned_available,
            learned_active_gates,
            learned_total_gates,
        );
        // The legacy picker has no cost model; report zeros as placeholders.
        let placeholder = WelchPeaksCostEstimates {
            ultra_ns: 0.0,
            fast_ns: 0.0,
            rlx_ns: 0.0,
            learned_ns: 0.0,
        };
        (placeholder, choice)
    } else {
        let estimated = estimate_welch_peaks_costs(
            device,
            batch,
            n_fft,
            k,
            learned_available,
            learned_active_gates,
            learned_total_gates,
        );
        (estimated, pick_from_costs(estimated))
    };

    WelchPeaksPickBreakdown { costs, picked }
}
pub fn pick_welch_peaks_strategy(
device: Device,
batch: usize,
n_fft: usize,
k: usize,
learned_available: bool,
learned_active_gates: Option<usize>,
learned_total_gates: usize,
) -> WelchPeaksStrategy {
pick_welch_peaks_breakdown(
device,
batch,
n_fft,
k,
learned_available,
learned_active_gates,
learned_total_gates,
)
.picked
}
/// How the strategy is decided: by the automatic picker or by an explicit override.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WelchPeaksPickMode {
/// Let the picker choose (see `pick_welch_peaks_strategy`).
Auto,
/// Use exactly this strategy, bypassing the picker's choice.
Force(WelchPeaksStrategy),
}
impl WelchPeaksPickMode {
    /// Returns `true` for `Auto` and `false` for any `Force(..)` value.
    pub fn is_auto(self) -> bool {
        self == Self::Auto
    }
}
/// Parses a user-supplied strategy name into a pick mode.
///
/// The name is trimmed, lower-cased (ASCII) and has `-` replaced by `_` before
/// matching, so `"Ultra-Fast"` and `" ultra_fast "` are equivalent. An empty string
/// or `"auto"` yields `WelchPeaksPickMode::Auto`; every other accepted alias yields
/// `WelchPeaksPickMode::Force(..)`.
///
/// # Errors
///
/// Returns an error naming the rejected input when it matches no known alias.
pub fn parse_welch_peaks_strategy(name: &str) -> Result<WelchPeaksPickMode> {
    let normalized = name.trim().to_ascii_lowercase().replace('-', "_");

    let forced = match normalized.as_str() {
        "" | "auto" => return Ok(WelchPeaksPickMode::Auto),
        "ultra" | "ultra_fast" | "ultra_fast_rustfft" | "1seg" => WelchPeaksStrategy::UltraFast,
        "fast" | "streaming" | "fast_streaming" | "fast_streaming_rustfft" | "rustfft"
        | "2seg" => WelchPeaksStrategy::FastStreaming,
        "rlx" | "rlx_compiled" | "compiled" | "gpu" => WelchPeaksStrategy::RlxCompiled,
        "learned" | "learned_compiled" => WelchPeaksStrategy::LearnedCompiled,
        other => bail!("unknown welch peaks strategy {other:?} (try auto|ultra|fast|rlx|learned)"),
    };

    Ok(WelchPeaksPickMode::Force(forced))
}
/// Short strategy names, one per pick mode, each of which is accepted by
/// `parse_welch_peaks_strategy`.
pub fn all_welch_peaks_strategy_names() -> &'static [&'static str] {
    const NAMES: &[&str] = &["auto", "ultra", "fast", "rlx", "learned"];
    NAMES
}
/// Turns a pick mode into a concrete strategy.
///
/// A forced mode is returned as-is without consulting the picker; `Auto` delegates
/// to `pick_welch_peaks_strategy` with the remaining arguments.
pub fn resolve_welch_peaks_strategy(
    mode: WelchPeaksPickMode,
    device: Device,
    batch: usize,
    n_fft: usize,
    k: usize,
    learned_available: bool,
    learned_active_gates: Option<usize>,
    learned_total_gates: usize,
) -> WelchPeaksStrategy {
    if let WelchPeaksPickMode::Force(forced) = mode {
        return forced;
    }

    pick_welch_peaks_strategy(
        device,
        batch,
        n_fft,
        k,
        learned_available,
        learned_active_gates,
        learned_total_gates,
    )
}
/// Welch-peaks runner that fixes a strategy at construction time and reuses its
/// scratch buffers / compiled graphs across calls to `welch_peaks_batch`.
pub struct AutoWelchPeaks {
/// Strategy selected in `with_options`. `welch_peaks_batch` dispatches on this field,
/// while parameters and compiled state are only built for the strategy chosen at
/// construction, so it should be treated as read-only.
pub strategy: WelchPeaksStrategy,
/// Device resolved by `resolve_train_device` during construction.
pub device: Device,
// Number of frames expected per `welch_peaks_batch` call.
batch: usize,
// Frame length from `WelchParams::for_n_fft(n_fft).frame_len()`; the input signal must
// contain `batch * full_frame` samples.
full_frame: usize,
// Peak parameters matching the chosen strategy (ultra-fast vs. fast variant).
peak_params: WelchPeakParams,
// Reusable scratch sized for `batch` and `peak_params.n_bins()`.
scratch: WelchPeaksScratch,
// Compiled graph; `Some` only when the strategy is `RlxCompiled`.
rlx: Option<CompiledRlxWelchPeaksFused>,
// Compiled learned graph; `Some` only when the strategy is `LearnedCompiled`.
learned: Option<CompiledLearnedWelchPeaks>,
}
impl AutoWelchPeaks {
    /// Builds a runner with an automatically picked strategy and no learned model.
    ///
    /// Equivalent to `with_options(batch, n_fft, k, device, None, WelchPeaksPickMode::Auto)`.
    pub fn new(batch: usize, n_fft: usize, k: usize, device: Option<&str>) -> Result<Self> {
        Self::with_options(batch, n_fft, k, device, None, WelchPeaksPickMode::Auto)
    }

    /// Builds a runner with an automatically picked strategy, offering `model` to the
    /// picker as a candidate for `WelchPeaksStrategy::LearnedCompiled`.
    pub fn with_learned(
        batch: usize,
        n_fft: usize,
        k: usize,
        device: Option<&str>,
        model: Option<&FastLearnedFftModel>,
    ) -> Result<Self> {
        Self::with_options(batch, n_fft, k, device, model, WelchPeaksPickMode::Auto)
    }

    /// Builds a runner that uses exactly `strategy`, without a learned model.
    ///
    /// # Errors
    ///
    /// Fails for `WelchPeaksStrategy::LearnedCompiled`, since no model is supplied.
    pub fn with_strategy(
        batch: usize,
        n_fft: usize,
        k: usize,
        device: Option<&str>,
        strategy: WelchPeaksStrategy,
    ) -> Result<Self> {
        Self::with_options(
            batch,
            n_fft,
            k,
            device,
            None,
            WelchPeaksPickMode::Force(strategy),
        )
    }

    /// General constructor: resolves the device, decides the strategy, and prepares
    /// parameters, scratch space and (for compiled strategies) the compiled graph.
    ///
    /// The cost breakdown is computed even when `mode` forces a strategy, so that the
    /// optional trace (enabled by the `RLX_FFT_PICKER_TRACE` flag, written to stderr)
    /// can always show the estimates.
    ///
    /// # Errors
    ///
    /// Returns an error when `batch` or `k` is zero, when the device cannot be resolved
    /// or its backend is not ready, when `LearnedCompiled` is selected without a model,
    /// or when compiling the graph for a compiled strategy fails.
    pub fn with_options(
        batch: usize,
        n_fft: usize,
        k: usize,
        device: Option<&str>,
        model: Option<&FastLearnedFftModel>,
        mode: WelchPeaksPickMode,
    ) -> Result<Self> {
        ensure!(
            batch >= 1 && k >= 1,
            "welch peaks requires batch >= 1 and k >= 1 (got batch={batch}, k={k})"
        );
        let device = resolve_train_device(device)?;
        ensure_backend_ready(device)?;

        let learned_available = model.is_some();
        let (active, total) = model
            .map(|m| (Some(m.active_gates(DEFAULT_GATE_THRESHOLD)), m.gates.len()))
            .unwrap_or((None, 0));
        let breakdown =
            pick_welch_peaks_breakdown(device, batch, n_fft, k, learned_available, active, total);
        let strategy = match mode {
            WelchPeaksPickMode::Force(s) => s,
            WelchPeaksPickMode::Auto => breakdown.picked,
        };

        if rlx_ir::env::flag("RLX_FFT_PICKER_TRACE") {
            // Nanoseconds to milliseconds; any non-finite estimate is shown as infinity.
            fn ms(ns: f64) -> f64 {
                if ns.is_finite() {
                    ns / 1e6
                } else {
                    f64::INFINITY
                }
            }
            eprintln!(
                "[welch-peaks] io pick batch={batch} device={device:?} \
                 ultra={:.2}ms fast={:.2}ms rlx={:.2}ms learned={:.2}ms -> {}",
                ms(breakdown.costs.ultra_ns),
                ms(breakdown.costs.fast_ns),
                ms(breakdown.costs.rlx_ns),
                ms(breakdown.costs.learned_ns),
                strategy.label(),
            );
        }

        // Reject a learned strategy without a model before doing any further setup.
        if strategy == WelchPeaksStrategy::LearnedCompiled && model.is_none() {
            bail!("--strategy learned requires a trained model (--train-steps > 0)");
        }

        // Exhaustive on purpose: a new strategy variant must pick its parameters here.
        let peak_params = match strategy {
            WelchPeaksStrategy::UltraFast => WelchPeakParams::ultra_fast_for_n_fft(n_fft, k),
            WelchPeaksStrategy::FastStreaming
            | WelchPeaksStrategy::RlxCompiled
            | WelchPeaksStrategy::LearnedCompiled => WelchPeakParams::fast_for_n_fft(n_fft, k),
        };
        let full_frame = WelchParams::for_n_fft(n_fft).frame_len();
        let scratch = WelchPeaksScratch::new(batch, peak_params.n_bins());

        let (rlx, learned) = match strategy {
            WelchPeaksStrategy::UltraFast | WelchPeaksStrategy::FastStreaming => (None, None),
            WelchPeaksStrategy::RlxCompiled => {
                let compiled = CompiledRlxWelchPeaksFused::compile(batch, peak_params, device)?;
                (Some(compiled), None)
            }
            WelchPeaksStrategy::LearnedCompiled => {
                // Already guarded above; reported as an error rather than a panic so the
                // constructor never aborts the process.
                let m = match model {
                    Some(m) => m,
                    None => {
                        bail!("--strategy learned requires a trained model (--train-steps > 0)")
                    }
                };
                // Compile from a copy with the hard gate threshold set, leaving the
                // caller's model untouched.
                let mut hard = m.clone();
                hard.hard_gate_threshold = Some(DEFAULT_GATE_THRESHOLD);
                let compiled = compile_learned_welch_peaks(
                    &hard,
                    batch,
                    peak_params,
                    device,
                    default_welch_peaks_hard_threshold(),
                )?;
                (None, Some(compiled))
            }
        };

        Ok(Self {
            strategy,
            device,
            batch,
            full_frame,
            peak_params,
            scratch,
            rlx,
            learned,
        })
    }

    /// Label of the active strategy (see `WelchPeaksStrategy::label`).
    pub fn strategy_label(&self) -> &'static str {
        self.strategy.label()
    }

    /// Peak parameters chosen for the active strategy at construction time.
    pub fn peak_params(&self) -> WelchPeakParams {
        self.peak_params
    }

    /// Computes Welch peaks for one batch of frames.
    ///
    /// `signal` must contain exactly `batch * full_frame` samples, where `full_frame`
    /// is `WelchParams::for_n_fft(n_fft).frame_len()`. The signal is first passed
    /// through `truncate_batch` (presumably trimming each frame to the length the
    /// chosen parameters use; confirm in `WelchParams`) and then dispatched to the
    /// backend for `self.strategy`.
    ///
    /// # Errors
    ///
    /// Returns an error when the signal length is wrong, when truncation or the
    /// backend fails, or when `self.strategy` names a compiled strategy whose state
    /// was not built at construction (possible only if the public `strategy` field was
    /// modified afterwards).
    pub fn welch_peaks_batch(&mut self, signal: &[f32]) -> Result<Vec<f32>> {
        let expected_len = self.batch * self.full_frame;
        ensure!(
            signal.len() == expected_len,
            "signal has {} samples, expected batch * frame_len = {} * {} = {}",
            signal.len(),
            self.batch,
            self.full_frame,
            expected_len
        );
        let fast_signal =
            self.peak_params
                .welch
                .truncate_batch(signal, self.batch, self.full_frame)?;

        match self.strategy {
            WelchPeaksStrategy::UltraFast | WelchPeaksStrategy::FastStreaming => {
                welch_peaks_rustfft_with_scratch(
                    &fast_signal,
                    self.batch,
                    self.peak_params,
                    Some(&mut self.scratch),
                )
            }
            WelchPeaksStrategy::RlxCompiled => match self.rlx.as_mut() {
                Some(compiled) => compiled.welch_peaks_batch(&fast_signal),
                None => bail!(
                    "strategy is rlx_compiled but no compiled graph exists; \
                     `strategy` must not be changed after construction"
                ),
            },
            WelchPeaksStrategy::LearnedCompiled => match self.learned.as_mut() {
                Some(compiled) => compiled.welch_peaks_batch(&fast_signal, &mut self.scratch),
                None => bail!(
                    "strategy is learned_compiled but no compiled model exists; \
                     `strategy` must not be changed after construction"
                ),
            },
        }
    }
}
// Unit tests for the strategy picker and the strategy-name parser.
//
// The `pick_welch_peaks_strategy` expectations go through `pick_welch_peaks_breakdown`,
// which switches to the legacy threshold picker when the `RLX_FFT_LEGACY_PICKER` flag is
// set, so they may not hold with that flag enabled.
#[cfg(test)]
mod tests {
use super::*;
use rlx_runtime::Device;
// CPU, batch 32: expected pick is UltraFast.
#[test]
fn small_batch_picks_ultra_on_cpu() {
assert_eq!(
pick_welch_peaks_strategy(Device::Cpu, 32, 256, 16, false, None, 0),
WelchPeaksStrategy::UltraFast
);
}
// CPU, batch 512: still expected to be UltraFast. The legacy picker would not pick it
// here, because 512 exceeds `ultra_fast_max_batch`'s non-GPU limit of 256.
#[test]
fn mid_batch_picks_ultra_on_cpu_io_model() {
assert_eq!(
pick_welch_peaks_strategy(Device::Cpu, 512, 256, 16, false, None, 0),
WelchPeaksStrategy::UltraFast
);
}
// Metal, batch 8192: expected pick is RlxCompiled.
#[test]
fn large_batch_picks_rlx_on_metal() {
assert_eq!(
pick_welch_peaks_strategy(Device::Metal, 8192, 256, 16, false, None, 0),
WelchPeaksStrategy::RlxCompiled
);
}
// Metal, batch 1024: expected pick is already RlxCompiled.
#[test]
fn metal_mid_batch_picks_rlx() {
assert_eq!(
pick_welch_peaks_strategy(Device::Metal, 1024, 256, 16, false, None, 0),
WelchPeaksStrategy::RlxCompiled
);
}
// Metal, batch 256: expected pick is the rustfft FastStreaming path.
#[test]
fn metal_small_batch_stays_rustfft() {
assert_eq!(
pick_welch_peaks_strategy(Device::Metal, 256, 256, 16, false, None, 0),
WelchPeaksStrategy::FastStreaming
);
}
// Metal, batch 4096: expected pick is RlxCompiled.
#[test]
fn metal_upper_mid_batch_picks_rlx() {
assert_eq!(
pick_welch_peaks_strategy(Device::Metal, 4096, 256, 16, false, None, 0),
WelchPeaksStrategy::RlxCompiled
);
}
// Ignored: see the attribute's reason string; kept to document the intended behavior.
#[test]
#[ignore = "IO cost model picks RlxCompiled for large WGPU batches; revisit after WgpuCostModel calibration"]
fn wgpu_stays_rustfft_through_large_batch() {
for batch in [256usize, 1024, 4096, 8192] {
assert_eq!(
pick_welch_peaks_strategy(Device::Gpu, batch, 256, 16, false, None, 0),
WelchPeaksStrategy::FastStreaming,
"batch={batch}"
);
}
}
// Calls the legacy picker directly: batch 8192 equals the GPU crossover returned by
// `rlx_crossover_batch`, and with no learned model the result is RlxCompiled.
#[test]
fn legacy_picker_matches_old_thresholds() {
assert_eq!(
legacy_pick_welch_peaks_strategy(Device::Metal, 8192, false, None, 0),
WelchPeaksStrategy::RlxCompiled
);
}
// Parser: "auto" maps to Auto, and '-' is normalized to '_' before alias matching.
#[test]
fn parse_strategy_aliases() {
assert!(parse_welch_peaks_strategy("auto").unwrap().is_auto());
assert_eq!(
parse_welch_peaks_strategy("ultra-fast").unwrap(),
WelchPeaksPickMode::Force(WelchPeaksStrategy::UltraFast)
);
assert_eq!(
parse_welch_peaks_strategy("rlx").unwrap(),
WelchPeaksPickMode::Force(WelchPeaksStrategy::RlxCompiled)
);
}
// A forced mode wins even for inputs where the automatic pick is RlxCompiled
// (compare `large_batch_picks_rlx_on_metal`).
#[test]
fn forced_overrides_auto() {
assert_eq!(
resolve_welch_peaks_strategy(
WelchPeaksPickMode::Force(WelchPeaksStrategy::FastStreaming),
Device::Metal,
8192,
256,
16,
false,
None,
0,
),
WelchPeaksStrategy::FastStreaming
);
}
}