use crate::result::{BatchingInfo, UnmeasurableInfo};
use crate::types::Class;
use rand::seq::SliceRandom;
use super::timer::{black_box, rdtsc, Timer};
/// Minimum timer ticks a single call should span to be individually measurable
/// (referenced by the tests as a sanity threshold for unmeasurable detection).
pub const MIN_TICKS_SINGLE_CALL: f64 = 5.0;
/// Desired number of timer ticks one measured batch should span; the batch
/// size K is chosen during the pilot phase to reach this target.
pub const TARGET_TICKS_PER_BATCH: f64 = 50.0;
/// Hard upper bound on the batch size K.
pub const MAX_BATCH_SIZE: u32 = 20;
/// Number of pilot timings taken per class to estimate the per-call cost
/// (capped by the configured warmup iteration budget).
pub const PILOT_SAMPLES: usize = 100;
/// A single timing observation: the class that was measured and the raw
/// elapsed cycle count.
#[derive(Debug, Clone, Copy)]
pub struct Sample {
    // Which side of the comparison this observation belongs to.
    pub class: Class,
    // Elapsed cycles for one call (or one whole batch when batching is enabled).
    pub cycles: u64,
}
impl Sample {
pub fn new(class: Class, cycles: u64) -> Self {
Self { class, cycles }
}
}
/// Collects interleaved timing samples for two closures (baseline vs. sample),
/// automatically batching calls when a single call is too fast for the timer.
#[derive(Debug)]
pub struct Collector {
    // Cycle timer used for all measurements and cycle-to-ns conversion.
    timer: Timer,
    // Total warmup budget; part of it is spent on the timed pilot phase.
    warmup_iterations: usize,
    // Upper bound on the batch size K (always >= 1).
    max_batch_size: u32,
    // Tick target one batch should span; drives the choice of K.
    target_ticks_per_batch: f64,
}
impl Collector {
pub fn new(warmup_iterations: usize) -> Self {
let timer = Timer::new();
Self {
timer,
warmup_iterations,
max_batch_size: MAX_BATCH_SIZE,
target_ticks_per_batch: TARGET_TICKS_PER_BATCH,
}
}
pub fn with_timer(timer: Timer, warmup_iterations: usize) -> Self {
Self {
timer,
warmup_iterations,
max_batch_size: MAX_BATCH_SIZE,
target_ticks_per_batch: TARGET_TICKS_PER_BATCH,
}
}
/// Creates a collector with an explicit cap on the batch size K.
///
/// A cap of 0 would make batching meaningless, so it is clamped up to 1.
pub fn with_max_batch_size(
    timer: Timer,
    warmup_iterations: usize,
    max_batch_size: u32,
) -> Self {
    let capped = if max_batch_size == 0 { 1 } else { max_batch_size };
    Self {
        timer,
        warmup_iterations,
        max_batch_size: capped,
        target_ticks_per_batch: TARGET_TICKS_PER_BATCH,
    }
}
/// Read-only access to the underlying timer.
pub fn timer(&self) -> &Timer {
    &self.timer
}
/// Warms up both closures and runs a short timed pilot phase to decide
/// whether the operation needs batching (timing K calls per timestamp pair).
///
/// The pilot takes up to `PILOT_SAMPLES` interleaved timings of each closure
/// (capped by `warmup_iterations`), derives the median per-call cost, and
/// compares it against the timer resolution:
/// - fast enough on its own            -> K = 1, batching disabled
/// - too fast even at `max_batch_size` -> K = 1 plus `UnmeasurableInfo`
/// - otherwise -> smallest K whose batch spans ~`target_ticks_per_batch` ticks
fn pilot_and_warmup<F, R, T>(&self, mut fixed: F, mut random: R) -> BatchingInfo
where
    F: FnMut() -> T,
    R: FnMut() -> T,
{
    // Split the warmup budget: plain (untimed) warmup first, then the pilot.
    let pilot_count = PILOT_SAMPLES.min(self.warmup_iterations);
    let warmup_only = self.warmup_iterations.saturating_sub(pilot_count);
    for _ in 0..warmup_only {
        black_box(fixed());
        black_box(random());
    }
    // Time both closures alternately so the median reflects either class.
    let mut pilot_cycles = Vec::with_capacity(pilot_count * 2);
    for _ in 0..pilot_count {
        let start = rdtsc();
        black_box(fixed());
        let end = rdtsc();
        pilot_cycles.push(end.saturating_sub(start));
        let start = rdtsc();
        black_box(random());
        let end = rdtsc();
        pilot_cycles.push(end.saturating_sub(start));
    }
    // With no pilot budget (warmup_iterations == 0) there is nothing to size
    // batching from; fall back to unbatched measurement instead of panicking
    // on the empty-vector index below.
    if pilot_cycles.is_empty() {
        return BatchingInfo {
            enabled: false,
            k: 1,
            ticks_per_batch: 0.0,
            rationale: "no pilot samples (warmup_iterations == 0); batching disabled"
                .to_string(),
            unmeasurable: None,
        };
    }
    pilot_cycles.sort_unstable();
    let median_cycles = pilot_cycles[pilot_cycles.len() / 2];
    let median_ns = self.timer.cycles_to_ns(median_cycles);
    let resolution_ns = self.timer.resolution_ns();
    // How many timer ticks a single call of the operation spans.
    let ticks_per_call = median_ns / resolution_ns;
    let (k, enabled, unmeasurable, rationale) = if ticks_per_call >= self.target_ticks_per_batch
    {
        (
            1,
            false,
            None,
            format!(
                "no batching needed ({:.1} ticks/call >= {:.0} target)",
                ticks_per_call, self.target_ticks_per_batch
            ),
        )
    } else {
        // Smallest K whose batch reaches the tick target, capped by config.
        let k_raw = (self.target_ticks_per_batch / ticks_per_call).ceil() as u32;
        let k_attempt = k_raw.clamp(1, self.max_batch_size);
        let actual_ticks = ticks_per_call * k_attempt as f64;
        if actual_ticks < self.target_ticks_per_batch {
            // Even the largest allowed batch falls short of the target: the
            // operation is too fast for this timer. Suggest a platform fix.
            #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
            let suggestion =
                ". On macOS, run with sudo to enable kperf cycle counting (~1ns resolution)";
            #[cfg(all(target_os = "linux", target_arch = "aarch64"))]
            let suggestion = ". Run with sudo and --features perf for cycle-accurate timing";
            // Exhaustive fallback: also covers aarch64 targets that are
            // neither macOS nor Linux, which previously had no binding at all
            // and failed to compile.
            #[cfg(not(any(
                all(target_os = "macos", target_arch = "aarch64"),
                all(target_os = "linux", target_arch = "aarch64")
            )))]
            let suggestion = "";
            let threshold_ns =
                resolution_ns * self.target_ticks_per_batch / k_attempt as f64;
            let rationale = format!(
                "UNMEASURABLE: {:.1} ticks/batch < {:.0} minimum even at K={} (op ~{:.2}ns, threshold ~{:.2}ns){}",
                actual_ticks,
                self.target_ticks_per_batch,
                k_attempt,
                median_ns,
                threshold_ns,
                suggestion
            );
            (
                1,
                false,
                Some(UnmeasurableInfo {
                    operation_ns: median_ns,
                    threshold_ns,
                    ticks_per_call,
                }),
                rationale,
            )
        } else {
            // Here actual_ticks >= target always holds (the shortfall case was
            // handled above), so the previous "capped at MAX_BATCH_SIZE"
            // rationale branch was unreachable and has been removed.
            let rationale = format!(
                "K={} ({:.1} ticks/batch, {:.2} ticks/call, timer res {:.1}ns)",
                k_attempt, actual_ticks, ticks_per_call, resolution_ns
            );
            (k_attempt, k_attempt > 1, None, rationale)
        }
    };
    BatchingInfo {
        enabled,
        k,
        ticks_per_batch: ticks_per_call * k as f64,
        rationale,
        unmeasurable,
    }
}
/// Runs a full measurement: pilot/warmup, a randomized balanced schedule,
/// then one timed observation per schedule entry.
///
/// Returns the raw samples together with the batching decision made during
/// the pilot phase. Failed measurements are silently dropped.
pub fn collect_with_info<F, R, T>(
    &self,
    samples_per_class: usize,
    mut fixed: F,
    mut random: R,
) -> (Vec<Sample>, BatchingInfo)
where
    F: FnMut() -> T,
    R: FnMut() -> T,
{
    let batching_info = self.pilot_and_warmup(&mut fixed, &mut random);
    let k = batching_info.k;
    let mut samples = Vec::with_capacity(samples_per_class * 2);
    for class in self.create_schedule(samples_per_class) {
        // Unbatched path times a single call; batched path times K calls
        // under one timestamp pair.
        let measured = if k == 1 {
            match class {
                Class::Baseline => self.timer.measure_cycles(&mut fixed),
                Class::Sample => self.timer.measure_cycles(&mut random),
            }
        } else {
            match class {
                Class::Baseline => self.measure_batch_total(&mut fixed, k),
                Class::Sample => self.measure_batch_total(&mut random, k),
            }
        };
        if let Ok(cycles) = measured {
            samples.push(Sample::new(class, cycles));
        }
    }
    (samples, batching_info)
}
/// Convenience wrapper around `collect_with_info` that discards the batching
/// metadata and returns only the interleaved samples.
pub fn collect<F, R, T>(&self, samples_per_class: usize, fixed: F, random: R) -> Vec<Sample>
where
    F: FnMut() -> T,
    R: FnMut() -> T,
{
    self.collect_with_info(samples_per_class, fixed, random).0
}
/// Times `k` back-to-back invocations of `f` under a single timestamp pair
/// and returns the total elapsed cycles for the whole batch.
#[inline]
fn measure_batch_total<F, T>(&self, f: &mut F, k: u32) -> super::error::MeasurementResult
where
    F: FnMut() -> T,
{
    let begin = rdtsc();
    let mut remaining = k;
    while remaining > 0 {
        black_box(f());
        remaining -= 1;
    }
    let finish = rdtsc();
    Ok(finish.saturating_sub(begin))
}
/// Builds a randomly shuffled, perfectly balanced schedule with
/// `samples_per_class` entries of each class.
fn create_schedule(&self, samples_per_class: usize) -> Vec<Class> {
    let mut order: Vec<Class> = std::iter::repeat_n(Class::Baseline, samples_per_class)
        .chain(std::iter::repeat_n(Class::Sample, samples_per_class))
        .collect();
    order.shuffle(&mut rand::rng());
    order
}
pub fn collect_separated<F, R, T>(
&self,
samples_per_class: usize,
fixed: F,
random: R,
) -> (Vec<u64>, Vec<u64>, BatchingInfo)
where
F: FnMut() -> T,
R: FnMut() -> T,
{
let (samples, batching_info) = self.collect_with_info(samples_per_class, fixed, random);
let mut fixed_samples = Vec::with_capacity(samples_per_class);
let mut random_samples = Vec::with_capacity(samples_per_class);
for sample in samples {
match sample.class {
Class::Baseline => fixed_samples.push(sample.cycles),
Class::Sample => random_samples.push(sample.cycles),
}
}
(fixed_samples, random_samples, batching_info)
}
}
impl Default for Collector {
    /// Default collector: fresh platform timer with 1000 warmup iterations.
    fn default() -> Self {
        Self::new(1000)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Sample is a plain data carrier; the constructor stores fields verbatim.
    #[test]
    fn test_sample_creation() {
        let sample = Sample::new(Class::Baseline, 1000);
        assert_eq!(sample.class, Class::Baseline);
        assert_eq!(sample.cycles, 1000);
    }

    // The schedule must contain exactly the same number of entries per class.
    #[test]
    fn test_schedule_balanced() {
        let collector = Collector::new(0);
        let schedule = collector.create_schedule(100);
        let baseline_count = schedule.iter().filter(|c| **c == Class::Baseline).count();
        let sample_count = schedule.iter().filter(|c| **c == Class::Sample).count();
        assert_eq!(baseline_count, 100);
        assert_eq!(sample_count, 100);
    }

    // End-to-end collection should yield one cycle count per scheduled slot
    // for each class.
    #[test]
    fn test_collector_basic() {
        let collector = Collector::new(10);
        let counter = std::sync::atomic::AtomicU64::new(0);
        let (fixed, random, _batching) = collector.collect_separated(
            100,
            || counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
            || counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
        );
        assert_eq!(fixed.len(), 100);
        assert_eq!(random.len(), 100);
    }

    // A trivially fast operation may be flagged as unmeasurable; when it is,
    // the BatchingInfo invariants (k == 1, disabled, rationale) must hold.
    #[test]
    fn test_unmeasurable_detection() {
        let collector = Collector::new(10);
        let (_, batching) = collector.collect_with_info(
            100,
            || 42u8,
            || 42u8,
        );
        // Coarse timers on aarch64 are the most likely to trigger this path.
        #[cfg(target_arch = "aarch64")]
        {
            if batching.ticks_per_batch < MIN_TICKS_SINGLE_CALL {
                assert!(
                    batching.unmeasurable.is_some(),
                    "Expected unmeasurable for trivial op on ARM, got: {:?}",
                    batching
                );
                let info = batching.unmeasurable.as_ref().unwrap();
                assert!(info.ticks_per_call < MIN_TICKS_SINGLE_CALL);
                assert!(batching.rationale.contains("UNMEASURABLE"));
            }
        }
        // Platform-independent invariants whenever unmeasurable is reported.
        if let Some(ref info) = batching.unmeasurable {
            assert!(info.ticks_per_call < MIN_TICKS_SINGLE_CALL);
            assert!(info.threshold_ns > 0.0);
            assert!(info.operation_ns >= 0.0);
            assert!(!batching.enabled);
            assert_eq!(batching.k, 1);
        }
    }

    // For a measurable op, K must stay within [1, MAX_BATCH_SIZE] and agree
    // with the `enabled` flag.
    #[test]
    fn test_batching_k_selection() {
        let collector = Collector::new(10);
        let (_, batching) = collector.collect_with_info(
            100,
            || {
                let mut x = 0u64;
                for i in 0..1000 {
                    x = x.wrapping_add(black_box(i));
                }
                black_box(x)
            },
            || {
                let mut x = 0u64;
                for i in 0..1000 {
                    x = x.wrapping_add(black_box(i));
                }
                black_box(x)
            },
        );
        if batching.unmeasurable.is_none() {
            assert!(batching.k >= 1, "K should be at least 1");
            if batching.enabled {
                assert!(batching.k > 1, "Batching enabled but K <= 1");
            }
            assert!(
                batching.k <= MAX_BATCH_SIZE,
                "K {} exceeds MAX_BATCH_SIZE {}",
                batching.k,
                MAX_BATCH_SIZE
            );
            assert!(batching.ticks_per_batch > 0.0);
        }
    }

    // A user-supplied cap below MAX_BATCH_SIZE must bound the chosen K.
    #[test]
    fn test_max_batch_size_cap() {
        let timer = Timer::new();
        let collector = Collector::with_max_batch_size(timer, 10, 5);
        let (_, batching) =
            collector.collect_with_info(100, || black_box(42u8), || black_box(42u8));
        assert!(
            batching.k <= 5,
            "K {} exceeded configured max_batch_size 5",
            batching.k
        );
    }

    // An operation already spanning the target ticks should not be batched.
    #[test]
    fn test_batching_disabled_for_slow_ops() {
        let collector = Collector::new(10);
        let (_, batching) = collector.collect_with_info(
            50,
            || {
                let mut x = 0u64;
                for i in 0..100_000 {
                    x = x.wrapping_add(black_box(i));
                }
                std::hint::black_box(x)
            },
            || {
                let mut x = 0u64;
                for i in 0..100_000 {
                    x = x.wrapping_add(black_box(i));
                }
                std::hint::black_box(x)
            },
        );
        if batching.unmeasurable.is_none() && batching.ticks_per_batch >= TARGET_TICKS_PER_BATCH {
            assert!(
                !batching.enabled || batching.k == 1,
                "Slow op should not need batching: {:?}",
                batching
            );
        }
    }

    // macOS/aarch64 unmeasurable rationale should point the user at kperf.
    #[test]
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    fn test_kperf_suggestion_on_macos_arm64() {
        let collector = Collector::new(10);
        let (_, batching) = collector.collect_with_info(
            100,
            || 42u8,
            || 42u8,
        );
        if batching.unmeasurable.is_some() {
            assert!(
                batching.rationale.contains("kperf") || batching.rationale.contains("macOS"),
                "macOS ARM64 unmeasurable should mention kperf, got: {}",
                batching.rationale
            );
        }
    }

    // Linux/aarch64 unmeasurable rationale should mention the perf feature.
    #[test]
    #[cfg(all(target_os = "linux", target_arch = "aarch64"))]
    fn test_suggestion_on_linux_arm64() {
        let collector = Collector::new(10);
        let (_, batching) = collector.collect_with_info(
            100,
            || 42u8,
            || 42u8,
        );
        if batching.unmeasurable.is_some() {
            assert!(
                batching.rationale.contains("--features perf"),
                "Linux ARM64 unmeasurable should mention --features perf, got: {}",
                batching.rationale
            );
        }
    }

    // Cross-field invariants of BatchingInfo: enabled <=> k > 1, rationale is
    // populated, and unmeasurable implies k == 1 with batching disabled.
    #[test]
    fn test_batching_info_consistency() {
        let collector = Collector::new(10);
        let (_, batching) = collector.collect_with_info(
            100,
            || {
                let mut x = 0u64;
                for i in 0..500 {
                    x = x.wrapping_add(black_box(i));
                }
                black_box(x)
            },
            || {
                let mut x = 0u64;
                for i in 0..500 {
                    x = x.wrapping_add(black_box(i));
                }
                black_box(x)
            },
        );
        assert_eq!(
            batching.enabled,
            batching.k > 1,
            "enabled={} should match k > 1 (k={})",
            batching.enabled,
            batching.k
        );
        assert!(
            !batching.rationale.is_empty(),
            "rationale should not be empty"
        );
        if batching.unmeasurable.is_some() {
            assert_eq!(batching.k, 1, "unmeasurable should have k=1");
            assert!(!batching.enabled, "unmeasurable should not be enabled");
        }
    }
}