use crate::config::RuntimeConfig;
/// A scalar cost estimate wrapped in a newtype.
///
/// The accessor is named [`Cost::ns`], so the value is presumably expressed in
/// nanoseconds — confirm against the code that constructs `Cost` values.
#[derive(Debug, Clone, Copy)]
pub struct Cost(pub f64);

impl Cost {
    /// Returns the wrapped `f64` unchanged.
    pub fn ns(self) -> f64 {
        let Cost(value) = self;
        value
    }
}
/// Coarse analytical hardware model used to choose between alternative
/// kernel strategies (NEON vs. BLAS, sequential vs. parallel, fused vs.
/// unfused) by comparing estimated costs.
pub struct HwModel {
    /// Throughput assumed for the NEON code path, in floating-point
    /// operations per second.
    pub neon_flops: f64,
    /// Throughput assumed for the BLAS code path, in floating-point
    /// operations per second.
    pub blas_flops: f64,
    /// Fixed cost charged per BLAS call, in nanoseconds.
    pub blas_overhead_ns: f64,
    /// Fixed cost charged for dispatching a parallel loop, in nanoseconds.
    pub par_for_overhead_ns: f64,
    /// L1 cache capacity in bytes (not read by the methods in this block).
    pub l1_bytes: usize,
    /// L2 cache capacity in bytes; see [`HwModel::prefer_fused_layer`].
    pub l2_bytes: usize,
    /// Memory bandwidth. NOTE(review): the unit is not established here; the
    /// built-in values (30.0 / 50.0) suggest GB/s — confirm with the users of
    /// this field.
    pub mem_bw: f64,
    /// Number of threads assumed to share parallel work.
    pub num_threads: usize,
}

impl HwModel {
    /// Builds the model for the compilation target.
    ///
    /// One set of constants is used on aarch64 macOS and another on every
    /// other target. `num_threads` is always `cfg.pool_workers + 1`
    /// (presumably the pool workers plus the calling thread — confirm).
    pub fn from_config(cfg: &RuntimeConfig) -> Self {
        let num_threads = cfg.pool_workers + 1;

        #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
        let model = HwModel {
            neon_flops: 72e9,
            blas_flops: 1000e9,
            blas_overhead_ns: 500.0,
            par_for_overhead_ns: 5000.0,
            l1_bytes: 65536,
            l2_bytes: 4 * 1024 * 1024,
            mem_bw: 50.0,
            num_threads,
        };

        #[cfg(not(all(target_arch = "aarch64", target_os = "macos")))]
        let model = HwModel {
            neon_flops: 32e9,
            blas_flops: 200e9,
            blas_overhead_ns: 300.0,
            par_for_overhead_ns: 3000.0,
            l1_bytes: 32768,
            l2_bytes: 1024 * 1024,
            mem_bw: 30.0,
            num_threads,
        };

        model
    }

    /// Returns `true` when the NEON path is estimated to be faster than BLAS
    /// for an `m x k` by `k x n` single-precision matrix product.
    ///
    /// The product is costed at `2 * m * k * n` floating-point operations.
    /// BLAS additionally pays its fixed per-call overhead, so NEON wins for
    /// small problems despite its lower throughput.
    pub fn prefer_neon_sgemm(&self, m: usize, k: usize, n: usize) -> bool {
        let flops = 2.0 * m as f64 * k as f64 * n as f64;
        // Both estimates are in seconds; the overhead is converted from ns.
        let blas_time = flops / self.blas_flops + self.blas_overhead_ns * 1e-9;
        let neon_time = flops / self.neon_flops;
        neon_time < blas_time
    }

    /// Returns `true` when splitting `total_elements` items, each costing
    /// `cost_per_element_ns` nanoseconds, across `num_threads` threads is
    /// estimated to beat running them sequentially once the parallel dispatch
    /// overhead is included.
    pub fn prefer_parallel(&self, total_elements: usize, cost_per_element_ns: f64) -> bool {
        // Everything here is in nanoseconds.
        let seq_time = total_elements as f64 * cost_per_element_ns;
        let par_time = seq_time / self.num_threads as f64 + self.par_for_overhead_ns;
        par_time < seq_time
    }

    /// Returns `true` when the BLAS implementation of scaled dot-product
    /// attention is estimated to be faster than the NEON one.
    ///
    /// Each head is costed as two `seq x seq x head_dim` matrix products
    /// (presumably the score and the weighted-value products), so the BLAS
    /// path pays two per-call overheads per head. The BLAS estimate is divided
    /// across `num_threads` and charged one parallel dispatch overhead; the
    /// NEON estimate is not divided.
    /// NOTE(review): that asymmetry looks like it models a single-threaded
    /// NEON path — confirm.
    pub fn prefer_blas_sdpa(
        &self,
        batch: usize,
        seq: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> bool {
        // Multiply in f64 so that very large `batch * num_heads` cannot
        // overflow `usize` (panic in debug builds, wrap in release builds).
        let total_heads = batch as f64 * num_heads as f64;
        let per_head_flops = 2.0 * seq as f64 * seq as f64 * head_dim as f64 * 2.0;

        // All times below are in seconds.
        let blas_per_head = per_head_flops / self.blas_flops + 2.0 * self.blas_overhead_ns * 1e-9;
        let neon_per_head = per_head_flops / self.neon_flops;

        let blas_total = blas_per_head * total_heads / self.num_threads as f64
            + self.par_for_overhead_ns * 1e-9;
        let neon_total = neon_per_head * total_heads;
        blas_total < neon_total
    }

    /// Returns `true` when the estimated working set of one layer fits in
    /// half of the L2 cache.
    ///
    /// With `m = batch * seq` rows, the working set is counted as one
    /// `m x 3*hidden` buffer, two `m x hidden` buffers and one
    /// `m x intermediate` buffer, at 4 bytes per element (presumably `f32`).
    ///
    /// The arithmetic saturates, so shapes whose byte count exceeds
    /// `usize::MAX` report "does not fit" instead of overflowing.
    pub fn prefer_fused_layer(
        &self,
        batch: usize,
        seq: usize,
        hidden: usize,
        intermediate: usize,
    ) -> bool {
        const BYTES_PER_ELEM: usize = 4;

        let rows = batch.saturating_mul(seq);
        let buffer_bytes =
            |width: usize| rows.saturating_mul(width).saturating_mul(BYTES_PER_ELEM);

        let qkv_bytes = buffer_bytes(hidden.saturating_mul(3));
        let attn_bytes = buffer_bytes(hidden);
        let ffn_bytes = buffer_bytes(intermediate);

        let total_bytes = qkv_bytes
            .saturating_add(attn_bytes.saturating_mul(2))
            .saturating_add(ffn_bytes);
        total_bytes <= self.l2_bytes / 2
    }
}
/// Returns the process-wide [`HwModel`].
///
/// The model is built from `RuntimeConfig::global()` on the first call and
/// cached in a `OnceLock`; every later call returns the same instance.
pub fn hw_model() -> &'static HwModel {
    static MODEL: std::sync::OnceLock<HwModel> = std::sync::OnceLock::new();

    MODEL.get_or_init(|| {
        let cfg = RuntimeConfig::global();
        HwModel::from_config(cfg)
    })
}