#[cfg(feature = "std")]
use std::time::Instant;
use crate::api::wisdom::WisdomCache;
use crate::api::{Direction, Flags};
use crate::kernel::WisdomEntry;
use crate::kernel::{Complex, Float};
use super::types::Plan;
/// Outcome of timing a single 1-D complex DFT plan for one size and direction.
///
/// Produced by `tune_size`.
#[derive(Debug, Clone)]
pub struct TuneResult {
/// Transform length that was timed.
pub n: usize,
/// Direction the timed plan was built for.
pub direction: Direction,
/// Solver name reported by the plan via `wisdom_solver_name()`.
pub algorithm_name: String,
/// Median wall-clock duration of one `execute` call, in nanoseconds.
pub elapsed_ns: u64,
}
/// Builds `n` deterministic pseudo-random complex samples for benchmarking.
///
/// A fixed-seed 64-bit linear congruential generator is stepped twice per
/// sample (real part first, then imaginary part). Its upper 32 bits are mapped
/// onto `[-1.0, 1.0]`, so every call with the same `n` yields the same data.
fn make_test_input<T: Float>(n: usize) -> Vec<Complex<T>> {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    let mut lcg: u64 = 0xdead_beef_cafe_1234;
    // Advance the generator one step and scale its high word into [-1, 1].
    let mut next_unit = || {
        lcg = lcg.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        let high_word = (lcg >> 32) as f64;
        high_word / u32::MAX as f64 * 2.0 - 1.0
    };

    let mut samples = Vec::with_capacity(n);
    for _ in 0..n {
        let re = T::from_f64(next_unit());
        let im = T::from_f64(next_unit());
        samples.push(Complex::new(re, im));
    }
    samples
}
/// Returns the median of an ascending-sorted slice of nanosecond timings.
///
/// For an even number of samples the result is the floor of the mean of the
/// two middle values, computed without overflowing `u64`. An empty slice
/// yields `0`.
///
/// The caller must sort `timings` beforehand; the slice is only borrowed, so
/// this function cannot sort it.
fn median_ns(timings: &[u64]) -> u64 {
    let len = timings.len();
    if len == 0 {
        return 0;
    }
    let mid = len / 2;
    if len % 2 == 0 {
        let (lo, hi) = (timings[mid - 1], timings[mid]);
        // Overflow-free floor((lo + hi) / 2): the bits both share, plus half of
        // the bits that differ. Halving each operand separately would lose the
        // carry when both are odd (e.g. [1, 3] would give 1 instead of 2).
        (lo & hi) + ((lo ^ hi) >> 1)
    } else {
        timings[mid]
    }
}
/// Runs `plan` repeatedly on `input` and returns the median per-call time in nanoseconds.
///
/// Four untimed warm-up executions precede `max_iters` timed executions. All
/// of them write into one reused output buffer of length `plan.size()`.
/// Returns `0` when `max_iters` is `0`, because no samples are collected.
#[cfg(feature = "std")]
fn time_plan<T: Float>(plan: &Plan<T>, input: &[Complex<T>], max_iters: usize) -> u64 {
    // Untimed runs so that first-call effects are not part of the measured samples.
    const WARMUP: usize = 4;

    let mut output = vec![Complex::<T>::new(T::ZERO, T::ZERO); plan.size()];
    for _ in 0..WARMUP {
        plan.execute(input, &mut output);
    }

    let mut timings: Vec<u64> = Vec::with_capacity(max_iters);
    for _ in 0..max_iters {
        let t0 = Instant::now();
        plan.execute(input, &mut output);
        // `as_nanos` returns u128; saturate instead of silently truncating on the cast.
        let ns = t0.elapsed().as_nanos().min(u128::from(u64::MAX));
        timings.push(ns as u64);
    }

    // `median_ns` requires ascending order.
    timings.sort_unstable();
    median_ns(&timings)
}
/// Builds an `ESTIMATE` 1-D DFT plan of length `n` and times it.
///
/// `max_iters` is clamped to `4..=10_000` timed iterations. Returns `None` when
/// `Plan::dft_1d` cannot build a plan for this size and direction; otherwise
/// returns the plan's solver name and its median execution time.
#[cfg(feature = "std")]
pub fn tune_size<T: Float>(n: usize, direction: Direction, max_iters: usize) -> Option<TuneResult> {
    let timed_runs = max_iters.clamp(4, 10_000);
    let plan = Plan::<T>::dft_1d(n, direction, Flags::ESTIMATE)?;
    let elapsed_ns = time_plan(&plan, &make_test_input::<T>(n), timed_runs);
    let algorithm_name = plan.wisdom_solver_name();
    Some(TuneResult {
        n,
        direction,
        algorithm_name,
        elapsed_ns,
    })
}
/// Times forward transforms for every length in `min_n..=max_n` and collects the results.
///
/// Each size for which `tune_size` succeeds is stored as a `WisdomEntry`. The
/// key is `problem_hash = n`, the solver is the plan's solver name, and the cost
/// is the median time in nanoseconds. Sizes that fail to plan are skipped.
/// `on_progress(n)` is invoked after each size, whether or not it succeeded.
/// An empty range (`min_n > max_n`) returns an empty cache.
///
/// NOTE(review): the key encodes only `n`. It does not distinguish direction
/// (always `Forward` here) or the float type `T`. Confirm this matches how
/// `WisdomCache` lookups are keyed elsewhere.
#[cfg(feature = "std")]
pub fn tune_range<T: Float>(
    min_n: usize,
    max_n: usize,
    reps_per_size: usize,
    mut on_progress: impl FnMut(usize),
) -> WisdomCache {
    let mut cache = WisdomCache::new();
    for n in min_n..=max_n {
        // Destructure to move the owned solver name into the entry instead of cloning it.
        if let Some(TuneResult {
            algorithm_name,
            elapsed_ns,
            ..
        }) = tune_size::<T>(n, Direction::Forward, reps_per_size)
        {
            cache.store(WisdomEntry {
                problem_hash: n as u64,
                solver_name: algorithm_name,
                cost: elapsed_ns as f64,
            });
        }
        on_progress(n);
    }
    cache
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn tune_size_returns_result() {
        let r = tune_size::<f64>(64, Direction::Forward, 16)
            .expect("tune_size should succeed for n=64");
        assert_eq!(r.n, 64);
        assert!(r.elapsed_ns > 0, "elapsed_ns must be non-zero");
        assert!(
            !r.algorithm_name.is_empty(),
            "algorithm_name must be non-empty"
        );
    }

    #[test]
    fn tune_size_various_sizes() {
        // A power of two, a prime and a composite non-power-of-two length.
        for &n in &[8usize, 17, 12] {
            assert!(
                tune_size::<f64>(n, Direction::Forward, 8).is_some(),
                "tune_size should succeed for n={n}"
            );
        }
    }

    #[test]
    fn tune_range_covers_all_sizes() {
        let cache = tune_range::<f64>(2, 32, 8, |_| {});
        assert_eq!(
            cache.entry_count(),
            31,
            "expected 31 entries for range 2..=32, got {}",
            cache.entry_count()
        );
    }

    #[test]
    fn tune_range_reports_progress_for_every_size() {
        let mut seen = Vec::new();
        let _ = tune_range::<f64>(2, 5, 4, |n| seen.push(n));
        assert_eq!(seen, vec![2, 3, 4, 5]);
    }

    #[test]
    fn binary_wisdom_round_trip() {
        let cache = tune_range::<f64>(2, 8, 4, |_| {});
        let bytes = cache.to_binary();
        assert!(!bytes.is_empty(), "binary output must not be empty");
        let restored =
            WisdomCache::from_binary(&bytes).expect("from_binary should succeed on valid data");
        assert_eq!(
            restored.entry_count(),
            cache.entry_count(),
            "entry count must be preserved in round-trip"
        );
    }

    #[test]
    fn binary_wisdom_round_trip_solver_name_content() {
        let r = tune_size::<f64>(6, Direction::Forward, 4)
            .expect("tune_size for n=6 (MixedRadix) must succeed");
        assert!(
            r.algorithm_name.starts_with("mixed-radix-"),
            "n=6 must map to mixed-radix, got: {}",
            r.algorithm_name
        );

        let mut cache = WisdomCache::new();
        cache.store(WisdomEntry {
            problem_hash: 6,
            solver_name: r.algorithm_name.clone(),
            cost: r.elapsed_ns as f64,
        });
        let bytes = cache.to_binary();
        let restored = WisdomCache::from_binary(&bytes)
            .expect("from_binary must succeed for mixed-radix entry");
        let entry = restored
            .lookup(6)
            .expect("entry for n=6 must survive round-trip");
        let restored_name = entry.solver_name.clone();
        assert!(
            restored_name.starts_with("mixed-radix-"),
            "round-tripped solver name must still be mixed-radix, got: {restored_name}"
        );
        assert_eq!(
            restored_name, r.algorithm_name,
            "solver name must be identical after binary round-trip"
        );
    }

    #[test]
    fn estimate_does_not_tune() {
        let start = Instant::now();
        let _plan = Plan::<f64>::dft_1d(64, Direction::Forward, Flags::ESTIMATE);
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_millis() < 100,
            "ESTIMATE should be fast, took {elapsed:?}"
        );
    }

    #[test]
    fn make_test_input_length() {
        let v = make_test_input::<f64>(64);
        assert_eq!(v.len(), 64);
    }

    #[test]
    fn median_ns_basic() {
        let mut v = vec![3u64, 1, 4, 1, 5];
        v.sort_unstable();
        assert_eq!(median_ns(&v), 3);
        let mut even = vec![2u64, 4];
        even.sort_unstable();
        assert_eq!(median_ns(&even), 3);
        assert_eq!(median_ns(&[7]), 7);
        assert_eq!(median_ns(&[]), 0);
    }
}