use std::time::Instant;
use num_complex::Complex;
use crate::FftExecutor;
/// Outcome of comparing a rival's FFT output against a reference implementation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ValidationOutcome {
    /// The largest per-element error was at most [`VALIDATION_TOLERANCE`].
    Pass,
    /// The comparison failed.
    Fail {
        /// Largest per-element error magnitude observed. Within this module, `f32::INFINITY`
        /// is used when an FFT call failed or the output shapes did not match.
        max_error: f32,
    },
}

/// Host-side timing and accuracy summary for one rival at one `(n, batch_size)` point,
/// as produced by [`benchmark_rival`].
#[derive(Debug, Clone, PartialEq)]
pub struct BenchmarkResult {
    /// The rival's `name()`.
    pub rival_name: String,
    /// Transform length.
    pub n: usize,
    /// Number of transforms in each batched call.
    pub batch_size: usize,
    /// Millions of complex samples (`n * batch_size`) processed per second.
    pub msamples_per_sec: f64,
    /// Estimated GFLOP/s, counting 5·N·log2(N) flops per length-N transform.
    pub gflops: f64,
    /// Accuracy check of the rival's output against the reference implementation.
    pub validation: ValidationOutcome,
}
/// Timing summary for one executor at one `(n, batch_size)` point, produced by
/// [`benchmark_gpu_only`] and [`benchmark_gpu_pipeline`].
#[derive(Debug, Clone, PartialEq)]
pub struct GpuOnlyResult {
    /// The executor's `name()`.
    pub rival_name: String,
    /// Transform length.
    pub n: usize,
    /// Number of transforms in each batched run.
    pub batch_size: usize,
    /// Seconds attributed to one batched transform; the throughput fields divide by this.
    pub gpu_duration_sec: f64,
    /// Millions of complex samples (`n * batch_size`) processed per second.
    pub gpu_msamples_per_sec: f64,
    /// Estimated GFLOP/s, counting 5·N·log2(N) flops per length-N transform.
    pub gpu_gflops: f64,
}
/// Untimed `fft` calls made before timing starts in [`benchmark_rival`] and
/// [`benchmark_gpu_pipeline`].
pub const WARMUP_ITERS: usize = 1;
/// Timed `fft` calls averaged by [`benchmark_rival`] and [`benchmark_gpu_pipeline`].
pub const BENCH_ITERS: usize = 10;

/// Warm-up iteration count passed to the executor's own `benchmark_gpu_only` routine.
pub const GPU_WARMUP_ITERS: usize = 3;
/// Timed iteration count passed to the executor's own `benchmark_gpu_only` routine.
pub const GPU_BENCH_ITERS: usize = 50;

/// Largest per-element error magnitude that validation still accepts as a pass.
pub const VALIDATION_TOLERANCE: f32 = 0.001;

/// Cap on `n * batch_size` for the batch sizes tried by [`sweep_rival`] (16 Mi samples).
pub const MAX_TOTAL_SAMPLES: usize = 16 << 20;
/// Measures GPU-side FFT throughput for `gpu_fft`, delegating timing to the executor's own
/// `benchmark_gpu_only` routine.
///
/// Builds `batch_size` copies of a low-amplitude synthetic signal of length `n` and converts
/// each with `prepare_input_data`. The concatenated data is written into `buf_a` of the size
/// cache returned by `get_or_build_size_cache`; that write is issued before the timed
/// routine is called. [`GPU_WARMUP_ITERS`] and [`GPU_BENCH_ITERS`] are passed as the
/// routine's iteration counts.
///
/// `log2(n)` is taken from `n.trailing_zeros()`, which is only correct when `n` is a power of
/// two. NOTE(review): callers presumably only pass power-of-two sizes; confirm.
///
/// # Errors
///
/// Returns an error if `batch_size` does not fit in a `u32`, or if the executor's
/// `benchmark_gpu_only` call fails.
pub fn benchmark_gpu_only(
    gpu_fft: &(impl crate::GpuFftTrait + crate::FftExecutor + ?Sized),
    n: usize,
    batch_size: usize,
) -> Result<GpuOnlyResult, Box<dyn std::error::Error>> {
    // The executor takes the batch count as `u32`; reject oversized batches instead of
    // silently truncating them with an `as` cast.
    let batch_u32 = u32::try_from(batch_size)?;
    let log_n = n.trailing_zeros();
    let sc = gpu_fft.get_or_build_size_cache(n, log_n);
    let inputs: Vec<Vec<Complex<f32>>> = (0..batch_size)
        .map(|_| {
            let n_f = n as f32;
            (0..n)
                .map(|i| {
                    let t = i as f32 / n_f;
                    Complex::new(t * 0.001, (t * std::f32::consts::TAU).sin() * 0.001)
                })
                .collect()
        })
        .collect();
    // Two scalars (re, im) per complex sample.
    let mut all_raw_data = Vec::with_capacity(n * 2 * batch_size);
    for input in &inputs {
        let raw = gpu_fft.prepare_input_data(input, false);
        all_raw_data.extend_from_slice(&raw);
    }
    gpu_fft
        .queue()
        .write_buffer(&sc.buf_a, 0, bytemuck::cast_slice(&all_raw_data));
    let gpu_duration_sec =
        gpu_fft.benchmark_gpu_only(&sc, batch_u32, n, GPU_WARMUP_ITERS, GPU_BENCH_ITERS)?;
    let total_samples = (n * batch_size) as f64;
    let gpu_msamples_per_sec = total_samples / gpu_duration_sec / 1_000_000.0;
    let gpu_gflops = 5.0 * total_samples * (n as f64).log2() / gpu_duration_sec / 1_000_000_000.0;
    Ok(GpuOnlyResult {
        rival_name: gpu_fft.name().to_string(),
        n,
        batch_size,
        gpu_duration_sec,
        gpu_msamples_per_sec,
        gpu_gflops,
    })
}
/// Times complete `rival.fft` calls on a synthetic batch, covering everything the executor
/// does per call.
///
/// Makes [`WARMUP_ITERS`] untimed calls, then reports the mean over [`BENCH_ITERS`] timed
/// calls. Unlike [`benchmark_gpu_only`], no work is excluded from the measured interval.
///
/// # Errors
///
/// Returns an error (naming the rival and the stage) if any warm-up or timed `fft` call
/// fails. Timing a failing call would otherwise report a meaningless throughput.
pub fn benchmark_gpu_pipeline(
    rival: &dyn FftExecutor,
    n: usize,
    batch_size: usize,
) -> Result<GpuOnlyResult, Box<dyn std::error::Error>> {
    let rival_name = rival.name().to_string();
    let inputs: Vec<Vec<Complex<f32>>> = (0..batch_size)
        .map(|_| {
            let n_f = n as f32;
            (0..n)
                .map(|i| {
                    let t = i as f32 / n_f;
                    Complex::new(t * 0.001, (t * std::f32::consts::TAU).sin() * 0.001)
                })
                .collect()
        })
        .collect();
    for _ in 0..WARMUP_ITERS {
        rival
            .fft(&inputs)
            .map_err(|e| format!("{} failed during warm-up: {:?}", rival_name, e))?;
    }
    let start = Instant::now();
    for _ in 0..BENCH_ITERS {
        // The error message is only formatted on failure, so the timed path is unaffected.
        rival
            .fft(&inputs)
            .map_err(|e| format!("{} failed during timed run: {:?}", rival_name, e))?;
    }
    let secs = (start.elapsed() / BENCH_ITERS as u32).as_secs_f64();
    let total_samples = (n * batch_size) as f64;
    let gpu_msamples_per_sec = total_samples / secs / 1_000_000.0;
    let gpu_gflops = 5.0 * total_samples * (n as f64).log2() / secs / 1_000_000_000.0;
    Ok(GpuOnlyResult {
        rival_name,
        n,
        batch_size,
        gpu_duration_sec: secs,
        gpu_msamples_per_sec,
        gpu_gflops,
    })
}
/// Runs `rival` and `reference` once on the same synthetic batch and compares their outputs.
///
/// `reference` is only invoked after `rival` succeeds. If either call returns an error, the
/// outcome is `ValidationOutcome::Fail { max_error: f32::INFINITY }`.
pub fn validate_rival(
    rival: &dyn FftExecutor,
    reference: &dyn FftExecutor,
    n: usize,
    batch_size: usize,
) -> ValidationOutcome {
    // Every row is the same deterministic, low-amplitude signal: a ramp in the real part
    // and one sine period in the imaginary part.
    let len = n as f32;
    let row: Vec<Complex<f32>> = (0..n)
        .map(|k| {
            let phase = k as f32 / len;
            Complex::new(phase * 0.001, (phase * std::f32::consts::TAU).sin() * 0.001)
        })
        .collect();
    let batch = vec![row; batch_size];

    let outputs = rival
        .fft(&batch)
        .ok()
        .and_then(|got| reference.fft(&batch).ok().map(|want| (got, want)));

    match outputs {
        Some((got, want)) => validate(&got, &want),
        None => ValidationOutcome::Fail {
            max_error: f32::INFINITY,
        },
    }
}
/// Times `rival` on a synthetic batch, then validates its output against `reference`.
///
/// Makes [`WARMUP_ITERS`] untimed calls and reports throughput averaged over
/// [`BENCH_ITERS`] timed calls. It then runs `rival` and `reference` once more on the same
/// input and compares the two outputs.
///
/// # Panics
///
/// This function has no error channel, so it panics if any `rival.fft` or `reference.fft`
/// call returns an error. The panic message names the rival and the stage that failed.
pub fn benchmark_rival(
    rival: &dyn FftExecutor,
    reference: &dyn FftExecutor,
    n: usize,
    batch_size: usize,
) -> BenchmarkResult {
    let rival_name = rival.name().to_string();
    let inputs: Vec<Vec<Complex<f32>>> = (0..batch_size)
        .map(|_| {
            let n_f = n as f32;
            (0..n)
                .map(|i| {
                    let t = i as f32 / n_f;
                    Complex::new(t * 0.001, (t * std::f32::consts::TAU).sin() * 0.001)
                })
                .collect()
        })
        .collect();
    for _ in 0..WARMUP_ITERS {
        rival
            .fft(&inputs)
            .unwrap_or_else(|e| panic!("{} failed during warm-up: {:?}", rival_name, e));
    }
    let start = Instant::now();
    for _ in 0..BENCH_ITERS {
        // The panic message is only formatted on failure, so the timed path is unaffected.
        rival
            .fft(&inputs)
            .unwrap_or_else(|e| panic!("{} failed during timed run: {:?}", rival_name, e));
    }
    let secs = (start.elapsed() / BENCH_ITERS as u32).as_secs_f64();
    let total_samples = (n * batch_size) as f64;
    let msamples_per_sec = total_samples / secs / 1_000_000.0;
    let gflops = 5.0 * total_samples * (n as f64).log2() / secs / 1_000_000_000.0;
    let rival_out = rival
        .fft(&inputs)
        .unwrap_or_else(|e| panic!("{} failed during validation run: {:?}", rival_name, e));
    let ref_out = reference.fft(&inputs).unwrap_or_else(|e| {
        panic!(
            "reference FFT failed while validating {}: {:?}",
            rival_name, e
        )
    });
    let validation = validate(&rival_out, &ref_out);
    BenchmarkResult {
        rival_name,
        n,
        batch_size,
        msamples_per_sec,
        gflops,
        validation,
    }
}
/// Benchmarks `rival` at each batch size in `batch_sizes` and returns the fastest result.
///
/// Batch sizes whose total sample count `n * batch_size` exceeds [`MAX_TOTAL_SAMPLES`] are
/// skipped. "Fastest" means highest `msamples_per_sec`; on ties, the earliest batch size wins.
/// If no batch size qualifies (including when `batch_sizes` is empty), the function falls
/// back to a single [`benchmark_rival`] run with `batch_size = 1`, regardless of the cap.
///
/// # Panics
///
/// Propagates any panic raised by [`benchmark_rival`].
pub fn sweep_rival(
    rival: &dyn FftExecutor,
    reference: &dyn FftExecutor,
    n: usize,
    batch_sizes: &[usize],
) -> BenchmarkResult {
    let mut best: Option<BenchmarkResult> = None;
    for &batch_size in batch_sizes {
        // A plain `n * batch_size` panics on overflow in debug builds and wraps in release,
        // where the wrapped value could slip under the cap. An overflowing product is
        // certainly too large, so treat it as over the limit.
        let within_cap = matches!(
            n.checked_mul(batch_size),
            Some(total) if total <= MAX_TOTAL_SAMPLES
        );
        if !within_cap {
            continue;
        }
        let result = benchmark_rival(rival, reference, n, batch_size);
        let is_better = best
            .as_ref()
            .map_or(true, |b| result.msamples_per_sec > b.msamples_per_sec);
        if is_better {
            best = Some(result);
        }
    }
    best.unwrap_or_else(|| benchmark_rival(rival, reference, n, 1))
}
/// Compares a batch of FFT outputs element-wise against a reference batch.
///
/// Returns `ValidationOutcome::Pass` when the largest per-element error magnitude
/// `|result - reference|` is at most [`VALIDATION_TOLERANCE`]. Otherwise it returns `Fail`
/// carrying that maximum. Special cases:
/// - Mismatched batch counts or row lengths fail with `max_error = f32::INFINITY`.
/// - Any NaN error (e.g. a NaN, or opposing infinities, in the outputs) fails with
///   `max_error = f32::NAN`.
fn validate(result: &[Vec<Complex<f32>>], reference: &[Vec<Complex<f32>>]) -> ValidationOutcome {
    let shape_mismatch = result.len() != reference.len()
        || result
            .iter()
            .zip(reference)
            .any(|(got, want)| got.len() != want.len());
    if shape_mismatch {
        return ValidationOutcome::Fail {
            max_error: f32::INFINITY,
        };
    }

    let mut max_err = 0.0f32;
    for (got_row, want_row) in result.iter().zip(reference) {
        for (got, want) in got_row.iter().zip(want_row) {
            let err = (got - want).norm();
            // `err > max_err` is false for NaN, so a plain max-tracking comparison would
            // silently skip NaN errors and let NaN-filled output pass. Fail explicitly.
            if err.is_nan() {
                return ValidationOutcome::Fail {
                    max_error: f32::NAN,
                };
            }
            max_err = max_err.max(err);
        }
    }
    if max_err <= VALIDATION_TOLERANCE {
        ValidationOutcome::Pass
    } else {
        ValidationOutcome::Fail { max_error: max_err }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GpuFft;

    /// Smoke test: the GPU-only benchmark path runs end to end on a small batch and reports
    /// positive, reasonable timings. Requires `GpuFft::new()` to succeed in the test
    /// environment (presumably a usable GPU adapter).
    #[test]
    fn test_gpu_only_benchmark_basic() {
        let gpu_fft = GpuFft::new().expect("Failed to create GpuFft");
        let n = 256;
        let batch_size = 4;
        // `expect` reports the underlying error on failure, which
        // `assert!(result.is_ok())` followed by `unwrap()` would discard.
        let result = benchmark_gpu_only(&gpu_fft, n, batch_size)
            .expect("GPU-only benchmark should succeed");
        assert_eq!(result.n, n);
        assert_eq!(result.batch_size, batch_size);
        assert!(
            result.gpu_duration_sec > 0.0,
            "GPU duration should be positive"
        );
        assert!(
            result.gpu_msamples_per_sec > 0.0,
            "Throughput should be positive"
        );
        assert!(result.gpu_gflops > 0.0, "GFLOPS should be positive");
        assert!(
            result.gpu_duration_sec < 1.0,
            "GPU duration should be reasonable"
        );
    }
}