rlx-fft 0.2.7

Learned FFT via butterfly networks — train for reference precision, run compiled on RLX backends
Documentation
//! Phase-by-phase Welch peaks fusion benchmark (rlx + rlx-fft).

use crate::device::{ensure_backend_ready, resolve_train_device};
use crate::peak::WelchPeakParams;
use crate::train::random_batch;
use crate::welch_peaks_compile::{
    CompiledRlxWelchPeaksFused, build_welch_peaks_fused_graph, compile_rlx_welch_peaks,
};
use crate::welch_peaks_cost::welch_peaks_fusion_target;
use anyhow::{Context, Result, ensure};
use rand::prelude::*;
use rlx_compile::fusion_benefit::fusion_benefit;
use rlx_compile::{FusionOptions, run_fusion_pipeline, supported_for_target, supports_op};
use rlx_ir::OpKind;
use rlx_runtime::graph_io::{GraphIoProfile, profile_graph_io};
use rlx_runtime::{Device, Session};
use serde::{Deserialize, Serialize};
use std::time::Instant;

/// One measured (or model-only) phase of the Welch-peaks fusion benchmark.
///
/// Rows are produced by `run_fusion_phase_bench` and serialized as part of
/// [`FusionPhaseReport::rows`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionPhaseRow {
    /// Phase label, e.g. `"baseline_interleaved_readback"` or `"phase2_fused_welch_peaks_op"`.
    pub phase: String,
    /// Batch size the phase was run with.
    pub batch: usize,
    /// Number of peaks requested per signal.
    pub k: usize,
    /// Device name exactly as passed by the caller (not the resolved device).
    pub device: String,
    /// Milliseconds reported by `time_iters` for this phase; `0.0` when the phase was not executed.
    pub ms: f64,
    /// Static graph I/O profile (from `profile_graph_io`) associated with this phase.
    pub io: GraphIoProfile,
    /// Cost-model estimate in nanoseconds for the fused graph; `0.0` for rows that carry no estimate.
    pub predicted_cost_ns: f64,
}

/// Full result of one fusion-phase benchmark run for a single `(n_fft, batch, k)` configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionPhaseReport {
    /// FFT length used for the Welch segments.
    pub n_fft: usize,
    /// Batch size benchmarked.
    pub batch: usize,
    /// Number of peaks requested per signal.
    pub k: usize,
    /// Device name exactly as passed by the caller.
    pub device: String,
    /// Raw `profile_graph_io` profile of the FFT-only reference graph.
    pub fft_only_graph: GraphIoProfile,
    /// Raw `profile_graph_io` profile of the fused Welch-peaks graph.
    pub fused_graph: GraphIoProfile,
    /// Kernel launches saved by fusion, per `fusion_benefit` on the cost-model I/O profiles
    /// (these profiles are distinct from the raw graph profiles above).
    pub fusion_launches_saved: isize,
    /// Sync points saved by fusion, per `fusion_benefit`.
    pub fusion_sync_saved: isize,
    /// Host readback bytes saved by fusion, per `fusion_benefit`.
    pub fusion_readback_saved: i64,
    /// Whether the fusion gate decided fusion is worthwhile for this configuration.
    pub fusion_gate_fuse: bool,
    /// Gate score in nanoseconds, as reported by the fusion gate breakdown.
    pub fusion_gate_score_ns: f64,
    /// Minimum gain (nanoseconds) the gate requires before it fuses.
    pub fusion_gate_min_gain_ns: f64,
    /// Whether the fused Welch-peaks path is considered viable for automatic selection on this device.
    pub fused_auto_viable: bool,
    /// Per-phase rows, in benchmark order (baseline first).
    pub rows: Vec<FusionPhaseRow>,
}

/// Collection of [`FusionPhaseReport`]s produced by sweeping over several batch sizes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionPhaseSweepReport {
    /// One report per batch size, in the order the batch spec was parsed.
    pub reports: Vec<FusionPhaseReport>,
}

/// Runs [`run_fusion_phase_bench`] once per batch size listed in `batch_csv`.
///
/// Very large batches (4096 and up) are capped at 15 iterations to keep the sweep's wall time
/// bounded. The first failing batch aborts the sweep and its error is returned.
pub fn run_fusion_phase_sweep(
    n_fft: usize,
    batch_csv: &str,
    k: usize,
    device_name: &str,
    iters: usize,
    seed: u64,
) -> Result<FusionPhaseSweepReport> {
    use crate::bench_sweep::parse_batch_spec;

    // Batches at or above this size are expensive enough to cap the iteration count.
    const LARGE_BATCH: usize = 4096;
    const LARGE_BATCH_MAX_ITERS: usize = 15;

    let batches = parse_batch_spec(batch_csv, "--batch")?;
    let reports = batches
        .iter()
        .map(|&batch| {
            let iters_for_batch = if batch >= LARGE_BATCH {
                iters.min(LARGE_BATCH_MAX_ITERS)
            } else {
                iters
            };
            run_fusion_phase_bench(n_fft, batch, k, device_name, iters_for_batch, seed)
        })
        .collect::<Result<Vec<_>>>()?;

    Ok(FusionPhaseSweepReport { reports })
}

/// Times `f` and returns the mean wall-clock duration of one call, in milliseconds.
///
/// `f` is invoked `iters` times in total (at least once). When `iters >= 2`, the first call is
/// an untimed warm-up, so one-off costs such as lazy compilation or first-touch allocation are
/// excluded. The remaining `iters - 1` calls are timed together and averaged. Averaging several
/// calls gives a far less noisy figure than timing a single call.
///
/// # Errors
/// Returns the first error produced by `f`; no further calls are made after it.
fn time_iters<F>(iters: usize, mut f: F) -> Result<f64>
where
    F: FnMut() -> Result<()>,
{
    let iters = iters.max(1);
    let timed_calls = if iters >= 2 {
        f()?;
        iters - 1
    } else {
        1
    };
    let t0 = Instant::now();
    for _ in 0..timed_calls {
        f()?;
    }
    Ok(t0.elapsed().as_secs_f64() * 1000.0 / timed_calls as f64)
}

pub fn run_fusion_phase_bench(
    n_fft: usize,
    batch: usize,
    k: usize,
    device_name: &str,
    iters: usize,
    seed: u64,
) -> Result<FusionPhaseReport> {
    ensure!(iters >= 1 && k >= 1);
    let device = resolve_train_device(Some(device_name))?;
    ensure_backend_ready(device)?;

    let peak_params = WelchPeakParams::fast_for_n_fft(n_fft, k);
    let frame = peak_params.frame_len();

    let mut rng = StdRng::seed_from_u64(seed);
    let signal = random_batch(&mut rng, batch, frame);
    let mut scratch = crate::peak::WelchPeaksScratch::new(batch, peak_params.n_bins());

    let fft_graph = {
        let mut g = rlx_ir::Graph::new("fft_only");
        use rlx_ir::infer::GraphExt;
        let segs = g.input(
            "segs",
            rlx_ir::Shape::new(
                &[batch * peak_params.welch.n_segments, n_fft],
                rlx_ir::DType::F32,
            ),
        );
        let zeros = g.sub(segs, segs);
        let block = g.concat_(vec![segs, zeros], 1);
        let y = g.fft(block, false);
        g.set_outputs(vec![y]);
        g
    };
    let fused_graph = build_welch_peaks_fused_graph(batch, peak_params);
    let io_fft_raw = profile_graph_io(&fft_graph);
    let io_fused_raw = profile_graph_io(&fused_graph);
    let gate_bd =
        crate::welch_peaks_cost::welch_peaks_fusion_gate_breakdown(device, batch, n_fft, k);
    let fusion_gate_fuse = gate_bd.should_fuse;
    let (io_fft, io_fused) =
        crate::welch_peaks_cost::welch_peaks_fusion_io_profiles(batch, n_fft, k, device);
    let benefit = fusion_benefit(&io_fft, &io_fused);

    let predicted_ns =
        crate::welch_peaks_cost::estimate_fused_graph_ns(batch, peak_params, device, 1.0);

    let mut rows = Vec::new();

    let mut legacy = compile_rlx_welch_peaks(batch, peak_params, device)?;
    let ms = time_iters(iters, || {
        let _ = legacy.welch_peaks_batch(&signal, &mut scratch)?;
        Ok(())
    })?;
    rows.push(FusionPhaseRow {
        phase: "baseline_interleaved_readback".into(),
        batch,
        k,
        device: device_name.into(),
        ms,
        io: io_fft_raw,
        predicted_cost_ns: predicted_ns,
    });

    let mut block_path = compile_rlx_welch_peaks(batch, peak_params, device)?;
    let ms = time_iters(iters, || {
        let _ = block_path.welch_peaks_batch_block(&signal, &mut scratch)?;
        Ok(())
    })?;
    rows.push(FusionPhaseRow {
        phase: "phase1_block_layout".into(),
        batch,
        k,
        device: device_name.into(),
        ms,
        io: io_fft_raw,
        predicted_cost_ns: predicted_ns,
    });

    let mut fused = CompiledRlxWelchPeaksFused::compile(batch, peak_params, device)?;
    let ms = time_iters(iters, || {
        let _ = fused.welch_peaks_batch(&signal)?;
        Ok(())
    })?;
    rows.push(FusionPhaseRow {
        phase: "phase2_fused_welch_peaks_op".into(),
        batch,
        k,
        device: device_name.into(),
        ms,
        io: io_fused_raw,
        predicted_cost_ns: predicted_ns,
    });

    let mut dual = build_welch_peaks_fused_graph(batch, peak_params);
    let peaks_id = dual.outputs[0];
    let spec_id = dual.node(peaks_id).inputs[0];
    dual.set_outputs(vec![spec_id, peaks_id]);
    let io_dual = profile_graph_io(&dual);
    let target = welch_peaks_fusion_target(device);
    let mut supported: Vec<OpKind> = supported_for_target(target).to_vec();
    if !supports_op(&supported, OpKind::Fft) {
        supported.push(OpKind::Fft);
    }
    if !supports_op(&supported, OpKind::WelchPeaks) {
        supported.push(OpKind::WelchPeaks);
    }
    let optimized = run_fusion_pipeline(dual, target, &supported, FusionOptions::default());
    let io_compile = profile_graph_io(&optimized);
    rows.push(FusionPhaseRow {
        phase: "phase3_compile_peaks_output_gate".into(),
        batch,
        k,
        device: device_name.into(),
        ms: 0.0,
        io: io_compile,
        predicted_cost_ns: 0.0,
    });
    if fusion_gate_fuse && io_dual.host_output_bytes > io_compile.host_output_bytes {
        let mut exec = Session::new(device).compile(optimized);
        let window = crate::welch::hann_window(peak_params.welch.n_fft);
        let ms = time_iters(iters, || {
            let segs =
                crate::welch::welch_windowed_segments(&signal, batch, peak_params.welch, &window)?;
            let _ = exec.run(&[("segs", &segs)]);
            Ok(())
        })?;
        if let Some(row) = rows
            .iter_mut()
            .find(|r| r.phase == "phase3_compile_peaks_output_gate")
        {
            row.ms = ms;
        }
    }

    let _ = Session::new(device);

    Ok(FusionPhaseReport {
        n_fft,
        batch,
        k,
        device: device_name.into(),
        fft_only_graph: io_fft_raw,
        fused_graph: io_fused_raw,
        fusion_launches_saved: benefit.launches_saved,
        fusion_sync_saved: benefit.sync_points_saved,
        fusion_readback_saved: benefit.host_readback_bytes_saved,
        fusion_gate_fuse,
        fusion_gate_score_ns: gate_bd.score_ns,
        fusion_gate_min_gain_ns: gate_bd.min_gain_ns,
        fused_auto_viable: crate::welch_peaks_cost::fused_welch_peaks_auto_viable(device),
        rows,
    })
}

/// Prints a human-readable summary of one [`FusionPhaseReport`] to stderr.
///
/// Output covers the raw graph I/O profiles, the modelled fusion benefit and gate decision, and
/// a per-phase timing table with speedups relative to the first (baseline) row. Rows that were
/// not timed (`ms == 0.0`, e.g. a skipped phase 3) show `speedup=n/a` instead of a meaningless
/// ratio. If a phase-2 row exists, its measured time is compared against the IO-model
/// prediction and the picker's rlx estimate.
pub fn print_fusion_phase_report(report: &FusionPhaseReport) {
    eprintln!(
        "\n=== Fusion phase bench n_fft={} batch={} k={} device={} ===",
        report.n_fft, report.batch, report.k, report.device
    );
    eprintln!(
        "  IO fft-only: launches={} sync={} host_out={} B device_traffic={} B",
        report.fft_only_graph.kernel_launches,
        report.fft_only_graph.sync_points,
        report.fft_only_graph.host_output_bytes,
        report.fft_only_graph.device_traffic_bytes,
    );
    eprintln!(
        "  IO fused:    launches={} sync={} host_out={} B device_traffic={} B",
        report.fused_graph.kernel_launches,
        report.fused_graph.sync_points,
        report.fused_graph.host_output_bytes,
        report.fused_graph.device_traffic_bytes,
    );
    eprintln!(
        "  fusion benefit: launches_saved={} sync_saved={} readback_saved={} B \
         gate_score={:.3}ms min_gain={:.3}ms gate_fuse={} auto_viable={}",
        report.fusion_launches_saved,
        report.fusion_sync_saved,
        report.fusion_readback_saved,
        report.fusion_gate_score_ns / 1e6,
        report.fusion_gate_min_gain_ns / 1e6,
        report.fusion_gate_fuse,
        report.fused_auto_viable,
    );
    let base = report.rows.first().map(|r| r.ms).unwrap_or(1.0).max(1e-9);
    for r in &report.rows {
        // A row with ms == 0.0 was never executed; dividing by the 1e-9 clamp would print ~1e9x.
        let speedup = if r.ms > 0.0 {
            format!("{:.2}x", base / r.ms.max(1e-9))
        } else {
            "n/a".to_owned()
        };
        eprintln!(
            "  {:32} ms={:8.4} speedup={} host_out={} B",
            r.phase, r.ms, speedup, r.io.host_output_bytes,
        );
    }
    if let Some(p2) = report
        .rows
        .iter()
        .find(|r| r.phase == "phase2_fused_welch_peaks_op")
    {
        // Resolve the device once; both model queries below need it.
        let device = resolve_train_device(Some(&report.device)).ok();
        if p2.predicted_cost_ns > 0.0 {
            let pred_ms = p2.predicted_cost_ns / 1e6;
            let ratio = p2.ms / pred_ms.max(1e-9);
            eprintln!(
                "  phase2 IO-model: predicted={pred_ms:.3}ms measured={:.3}ms ratio={ratio:.2}x",
                p2.ms,
            );
            if let Some(device) = device {
                let cur = crate::welch_peaks_cost::fused_io_compute_scale_for_calibration(device);
                eprintln!(
                    "  suggested fused_io_compute_scale: {:.2} (current {:.2})",
                    cur * ratio,
                    cur,
                );
            }
        }
        let picker_pred_ms = crate::welch_peaks_cost::estimate_welch_peaks_costs(
            device.unwrap_or(Device::Cpu),
            report.batch,
            report.n_fft,
            report.k,
            false,
            None,
            0,
        )
        .rlx_ns
            / 1e6;
        eprintln!("  picker rlx estimate: {picker_pred_ms:.3}ms (calibrated fused scale)");
    }
    eprintln!();
}

/// Prints a per-batch crossover table (phase-1/phase-2 speedups vs. baseline) to stderr.
///
/// Nothing is printed for sweeps with zero or one report, since there is no crossover to show.
/// A missing phase row is shown as a speedup of `0.00x`.
pub fn print_fusion_phase_sweep_summary(sweep: &FusionPhaseSweepReport) {
    if sweep.reports.len() <= 1 {
        return;
    }
    eprintln!("=== Fusion phase crossover (speedup vs baseline) ===");
    eprintln!(
        "{:>8} {:>8} {:>8} {:>10} {:>6} {:>6}",
        "batch", "phase1", "phase2", "gate_ms", "fuse", "auto"
    );
    for r in &sweep.reports {
        let base = r.rows.first().map(|x| x.ms).unwrap_or(1.0).max(1e-9);
        let speedup = |phase: &str| {
            r.rows
                .iter()
                .find(|x| x.phase == phase)
                .map(|x| base / x.ms.max(1e-9))
                .unwrap_or(0.0)
        };
        // Value widths match the header: 7 digits + 'x' = 8, and 10 for gate_ms.
        eprintln!(
            "{:>8} {:>7.2}x {:>7.2}x {:>10.2} {:>6} {:>6}",
            r.batch,
            speedup("phase1_block_layout"),
            speedup("phase2_fused_welch_peaks_op"),
            r.fusion_gate_score_ns / 1e6,
            r.fusion_gate_fuse,
            r.fused_auto_viable,
        );
    }
    eprintln!();
}

pub fn write_fusion_phase_json(path: &std::path::Path, report: &FusionPhaseReport) -> Result<()> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    std::fs::write(path, serde_json::to_vec_pretty(report)?)
        .with_context(|| format!("write {}", path.display()))
}

pub fn write_fusion_phase_sweep_json(
    path: &std::path::Path,
    sweep: &FusionPhaseSweepReport,
) -> Result<()> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    std::fs::write(path, serde_json::to_vec_pretty(sweep)?)
        .with_context(|| format!("write {}", path.display()))
}