Skip to main content

oxibonsai_runtime/
auto_tuner.rs

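//! Performance auto-tuning: detect hardware, select kernels, tune parameters.
//!
//! The auto-tuner:
//! 1. Detects the target architecture (x86_64, AArch64, wasm32) and CPU
//!    features (SSE4.2, AVX2, AVX-512, FMA, NEON)
//! 2. Estimates a memory budget from a caller-supplied memory limit
//!    (4096 MB by default; system memory is not probed)
//! 3. Recommends batch size, maximum context length, KV cache element type,
//!    and thread count
//! 4. Selects the best SIMD kernel tier for the detected hardware
//! 5. Exposes the memory limit as a tuning knob (`AutoTuner::with_memory_mb`)
//!    and offers a quick micro-benchmark plus a diagnostic report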
9
10use std::fmt;
11use std::time::Instant;
12
13// ---------------------------------------------------------------------------
14// CpuArch
15// ---------------------------------------------------------------------------
16
/// CPU architecture family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CpuArch {
    /// 64-bit x86.
    X86_64,
    /// 64-bit ARM.
    Aarch64,
    /// 32-bit WebAssembly.
    Wasm32,
    /// Any architecture not listed above.
    Other,
}

impl fmt::Display for CpuArch {
    /// Writes the lowercase architecture label (e.g. `x86_64`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            CpuArch::X86_64 => "x86_64",
            CpuArch::Aarch64 => "aarch64",
            CpuArch::Wasm32 => "wasm32",
            CpuArch::Other => "other",
        };
        f.write_str(label)
    }
}
36
37// ---------------------------------------------------------------------------
38// SimdTier
39// ---------------------------------------------------------------------------
40
/// SIMD tier ranking.
///
/// Variants are declared from weakest to strongest, so the derived
/// `PartialOrd`/`Ord` compare tiers according to this ranking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdTier {
    Scalar,
    Sse42,
    Neon,
    Avx2,
    Avx512,
}

impl SimdTier {
    /// Static description of a tier: `(name, vector width in bits, speed-up over scalar)`.
    fn profile(self) -> (&'static str, usize, f32) {
        match self {
            SimdTier::Scalar => ("Scalar", 64, 1.0),
            SimdTier::Sse42 => ("SSE4.2", 128, 2.0),
            SimdTier::Neon => ("NEON", 128, 2.5),
            SimdTier::Avx2 => ("AVX2", 256, 4.0),
            SimdTier::Avx512 => ("AVX-512", 512, 7.0),
        }
    }

    /// Human-readable name (e.g. `"AVX-512"`).
    pub fn name(&self) -> &'static str {
        self.profile().0
    }

    /// Native vector register width in bits.
    pub fn vector_width_bits(&self) -> usize {
        self.profile().1
    }

    /// Rough expected speed-up over pure scalar code for typical GEMV workloads.
    pub fn expected_speedup_over_scalar(&self) -> f32 {
        self.profile().2
    }
}

impl fmt::Display for SimdTier {
    /// Writes the same text as [`SimdTier::name`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}
91
92// ---------------------------------------------------------------------------
93// CpuFeatures
94// ---------------------------------------------------------------------------
95
96/// Detected CPU features.
97#[derive(Debug, Clone, PartialEq)]
98pub struct CpuFeatures {
99    pub has_avx2: bool,
100    pub has_avx512: bool,
101    pub has_neon: bool,
102    pub has_fma: bool,
103    pub has_sse42: bool,
104    pub logical_cores: usize,
105    pub physical_cores: usize,
106    pub arch: CpuArch,
107    pub cache_line_bytes: usize,
108}
109
110impl CpuFeatures {
111    /// Detect features at runtime via `cfg` and `std::thread::available_parallelism`.
112    pub fn detect() -> Self {
113        let arch = detect_arch();
114
115        let has_avx2 = cfg_has_avx2();
116        let has_avx512 = cfg_has_avx512();
117        let has_neon = cfg_has_neon();
118        let has_fma = cfg_has_fma();
119        let has_sse42 = cfg_has_sse42();
120
121        let logical_cores = std::thread::available_parallelism()
122            .map(|n| n.get())
123            .unwrap_or(1);
124        // Rough heuristic: assume hyper-threading factor of 2 on x86_64.
125        let physical_cores = match arch {
126            CpuArch::X86_64 => logical_cores.div_ceil(2),
127            _ => logical_cores,
128        };
129
130        let cache_line_bytes = match arch {
131            CpuArch::X86_64 => 64,
132            CpuArch::Aarch64 => 64,
133            _ => 64,
134        };
135
136        Self {
137            has_avx2,
138            has_avx512,
139            has_neon,
140            has_fma,
141            has_sse42,
142            logical_cores,
143            physical_cores,
144            arch,
145            cache_line_bytes,
146        }
147    }
148
149    /// Best SIMD tier available on this CPU.
150    pub fn best_simd_tier(&self) -> SimdTier {
151        if self.has_avx512 {
152            SimdTier::Avx512
153        } else if self.has_avx2 {
154            SimdTier::Avx2
155        } else if self.has_neon {
156            SimdTier::Neon
157        } else if self.has_sse42 {
158            SimdTier::Sse42
159        } else {
160            SimdTier::Scalar
161        }
162    }
163
164    /// Recommended thread count for compute-bound work.
165    ///
166    /// Uses physical cores to avoid contention on hyper-threaded siblings,
167    /// but guarantees at least 1.
168    pub fn recommended_threads(&self) -> usize {
169        self.physical_cores.max(1)
170    }
171
172    /// One-line human-readable summary.
173    pub fn summary(&self) -> String {
174        format!(
175            "arch={}, simd={}, logical_cores={}, physical_cores={}, cache_line={}B",
176            self.arch,
177            self.best_simd_tier(),
178            self.logical_cores,
179            self.physical_cores,
180            self.cache_line_bytes,
181        )
182    }
183}
184
185// ---------------------------------------------------------------------------
186// KvCacheType
187// ---------------------------------------------------------------------------
188
/// Element type used to store cached keys and values.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KvCacheType {
    /// 32-bit float, 4 bytes per element.
    Fp32,
    /// 16-bit float, 2 bytes per element.
    Fp16,
    /// 8-bit integer, 1 byte per element.
    Int8,
}

impl KvCacheType {
    /// Storage size of one key or value element, in bytes.
    pub fn bytes_per_element(&self) -> usize {
        match *self {
            KvCacheType::Fp32 => 4,
            KvCacheType::Fp16 => 2,
            KvCacheType::Int8 => 1,
        }
    }

    /// Upper-case display name (e.g. `"FP16"`).
    pub fn name(&self) -> &'static str {
        match *self {
            KvCacheType::Fp32 => "FP32",
            KvCacheType::Fp16 => "FP16",
            KvCacheType::Int8 => "INT8",
        }
    }
}

impl fmt::Display for KvCacheType {
    /// Writes the same text as [`KvCacheType::name`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}
222
223// ---------------------------------------------------------------------------
224// MemoryBudget
225// ---------------------------------------------------------------------------
226
/// Memory budget estimation for a model deployment.
///
/// All sizes are in bytes. Arithmetic here saturates or is checked, so large
/// inputs never overflow `usize`. This matters on 32-bit targets such as
/// wasm32, where 4096 MB is already 2^32 bytes.
#[derive(Debug, Clone)]
pub struct MemoryBudget {
    /// Total system memory the user wants to allocate (bytes).
    pub total_system_bytes: usize,
    /// Bytes available after subtracting model weights and runtime overhead.
    pub available_bytes: usize,
    /// Model weight footprint (bytes).
    pub model_weight_bytes: usize,
    /// Budget specifically earmarked for KV cache (bytes): 80% of `available_bytes`.
    pub kv_cache_budget: usize,
    /// Estimated runtime overhead for buffers, activations, etc. (bytes):
    /// 10% of the total, capped at 256 MiB.
    pub runtime_overhead: usize,
}

impl MemoryBudget {
    /// Bytes per mebibyte.
    const MIB: usize = 1024 * 1024;

    /// Estimate budget.
    ///
    /// * `total_available_mb` — how much RAM (MB) the user is willing to use.
    ///   A byte count that would exceed `usize::MAX` is clamped to it.
    /// * `model_params` — total number of parameters in the model.
    /// * `bits_per_weight` — quantisation bits per weight (e.g. 1.125 for Q1_0).
    ///   Negative or NaN values give a zero weight footprint, because Rust's
    ///   float-to-int `as` cast saturates.
    pub fn estimate(total_available_mb: usize, model_params: usize, bits_per_weight: f32) -> Self {
        // Saturate rather than overflow: on 32-bit targets 4096 MB alone is 2^32 bytes.
        let total_system_bytes = total_available_mb.saturating_mul(Self::MIB);

        // Weight footprint = params * bits / 8
        let model_weight_bytes = ((model_params as f64) * (bits_per_weight as f64) / 8.0) as usize;

        // Runtime overhead: ~10% of total or 256 MB, whichever is smaller
        let runtime_overhead = (total_system_bytes / 10).min(256 * Self::MIB);

        // Available = total - weights - overhead (saturating)
        let available_bytes = total_system_bytes
            .saturating_sub(model_weight_bytes)
            .saturating_sub(runtime_overhead);

        // KV cache gets 80% of the remaining budget. `x / 5 * 4 + x % 5 * 4 / 5`
        // equals `x * 4 / 5` exactly but never forms the overflow-prone `x * 4`.
        let kv_cache_budget = available_bytes / 5 * 4 + available_bytes % 5 * 4 / 5;

        Self {
            total_system_bytes,
            available_bytes,
            model_weight_bytes,
            kv_cache_budget,
            runtime_overhead,
        }
    }

    /// Maximum context length that fits in the KV cache budget.
    ///
    /// KV cache size per token = 2 (K+V) * num_layers * num_heads * head_dim * bytes_per_element.
    /// We use FP16 (2 bytes) as the element type for estimation. Returns 0
    /// when the per-token size is 0.
    pub fn max_context_length(
        &self,
        num_layers: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> usize {
        let bytes_per_token = Self::kv_bytes_per_token(num_layers, num_heads, head_dim);
        if bytes_per_token == 0 {
            return 0;
        }
        self.kv_cache_budget / bytes_per_token
    }

    /// Whether a given context length fits in the KV cache budget (FP16 elements).
    pub fn fits_context(
        &self,
        ctx_len: usize,
        num_layers: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> bool {
        let bytes_per_token = Self::kv_bytes_per_token(num_layers, num_heads, head_dim);
        // A requirement that overflows usize certainly exceeds the budget.
        ctx_len
            .checked_mul(bytes_per_token)
            .is_some_and(|needed| needed <= self.kv_cache_budget)
    }

    /// Human-readable summary.
    pub fn summary(&self) -> String {
        let mb = |b: usize| b as f64 / (1024.0 * 1024.0);
        format!(
            "total={:.0}MB, weights={:.0}MB, kv_budget={:.0}MB, overhead={:.0}MB, available={:.0}MB",
            mb(self.total_system_bytes),
            mb(self.model_weight_bytes),
            mb(self.kv_cache_budget),
            mb(self.runtime_overhead),
            mb(self.available_bytes),
        )
    }

    /// Helper: KV cache bytes per token (K+V, FP16).
    ///
    /// Saturates at `usize::MAX` for absurd shapes, so callers read the
    /// result as "does not fit" instead of panicking or wrapping.
    fn kv_bytes_per_token(num_layers: usize, num_heads: usize, head_dim: usize) -> usize {
        // 2 (K+V) * layers * heads * head_dim * 2 bytes (FP16)
        2usize
            .saturating_mul(num_layers)
            .saturating_mul(num_heads)
            .saturating_mul(head_dim)
            .saturating_mul(2)
    }
}
322
323// ---------------------------------------------------------------------------
324// TuningRecommendation
325// ---------------------------------------------------------------------------
326
327/// Tuning recommendations produced by the auto-tuner.
328#[derive(Debug, Clone)]
329pub struct TuningRecommendation {
330    /// Selected SIMD tier.
331    pub simd_tier: SimdTier,
332    /// Recommended worker thread count.
333    pub thread_count: usize,
334    /// Recommended batch size for prefill.
335    pub batch_size: usize,
336    /// Maximum context length that fits memory.
337    pub max_context: usize,
338    /// Recommended KV cache element type.
339    pub kv_cache_type: KvCacheType,
340    /// Whether to use flash-decode optimisation.
341    pub use_flash_decode: bool,
342    /// Whether to use prefix caching.
343    pub use_prefix_cache: bool,
344    /// Estimated tokens per second (rough).
345    pub estimated_tokens_per_second: f32,
346}
347
348impl TuningRecommendation {
349    /// Human-readable summary.
350    pub fn summary(&self) -> String {
351        format!(
352            "simd={}, threads={}, batch={}, max_ctx={}, kv={}, flash_decode={}, prefix_cache={}, est_tok/s={:.1}",
353            self.simd_tier,
354            self.thread_count,
355            self.batch_size,
356            self.max_context,
357            self.kv_cache_type,
358            self.use_flash_decode,
359            self.use_prefix_cache,
360            self.estimated_tokens_per_second,
361        )
362    }
363}
364
365// ---------------------------------------------------------------------------
366// KernelBenchmark
367// ---------------------------------------------------------------------------
368
369/// Kernel micro-benchmark result.
370#[derive(Debug, Clone)]
371pub struct KernelBenchmark {
372    /// SIMD tier that was benchmarked.
373    pub simd_tier: SimdTier,
374    /// Number of iterations run.
375    pub iterations: usize,
376    /// Total wall-clock time in milliseconds.
377    pub total_duration_ms: f64,
378    /// Operations per second.
379    pub ops_per_second: f64,
380    /// Estimated GFLOPS (based on a synthetic FMA-heavy workload).
381    pub gflops: f64,
382}
383
384impl KernelBenchmark {
385    /// Human-readable summary.
386    pub fn summary(&self) -> String {
387        format!(
388            "tier={}, iters={}, time={:.2}ms, ops/s={:.0}, GFLOPS={:.2}",
389            self.simd_tier,
390            self.iterations,
391            self.total_duration_ms,
392            self.ops_per_second,
393            self.gflops,
394        )
395    }
396}
397
398// ---------------------------------------------------------------------------
399// AutoTuner
400// ---------------------------------------------------------------------------
401
402/// The auto-tuner: detects hardware and recommends inference parameters.
403pub struct AutoTuner {
404    cpu: CpuFeatures,
405    memory_mb: usize,
406}
407
408impl AutoTuner {
409    /// Create a new auto-tuner that auto-detects CPU and uses system RSS as
410    /// a rough memory estimate (defaults to 4096 MB if detection fails).
411    pub fn new() -> Self {
412        let cpu = CpuFeatures::detect();
413        // Use a conservative default: 4 GB
414        let memory_mb = 4096;
415        Self { cpu, memory_mb }
416    }
417
418    /// Create an auto-tuner with an explicit memory budget (in MB).
419    pub fn with_memory_mb(memory_mb: usize) -> Self {
420        let cpu = CpuFeatures::detect();
421        Self { cpu, memory_mb }
422    }
423
424    /// Generate tuning recommendations for a specific model configuration.
425    ///
426    /// * `model_params` — total number of model parameters.
427    /// * `bits_per_weight` — quantisation bits (e.g. 1.125 for Q1_0_g128).
428    /// * `num_layers` / `num_heads` / `head_dim` — transformer architecture.
429    pub fn recommend(
430        &self,
431        model_params: usize,
432        bits_per_weight: f32,
433        num_layers: usize,
434        num_heads: usize,
435        head_dim: usize,
436    ) -> TuningRecommendation {
437        let simd_tier = self.cpu.best_simd_tier();
438        let thread_count = self.cpu.recommended_threads();
439
440        let budget = MemoryBudget::estimate(self.memory_mb, model_params, bits_per_weight);
441        let max_context = budget.max_context_length(num_layers, num_heads, head_dim);
442
443        // Batch size heuristic: scale with cores, cap by available memory.
444        let batch_size = compute_batch_size(thread_count, &budget);
445
446        // KV cache type: prefer FP16, fall back to INT8 if memory is tight.
447        let kv_cache_type = if budget.kv_cache_budget > 128 * 1024 * 1024 {
448            KvCacheType::Fp16
449        } else {
450            KvCacheType::Int8
451        };
452
453        // Flash decode is beneficial for long contexts (>= 2048 tokens).
454        let use_flash_decode = max_context >= 2048;
455
456        // Prefix caching useful when there is plenty of KV budget.
457        let use_prefix_cache = budget.kv_cache_budget > 256 * 1024 * 1024;
458
459        // Rough throughput estimate (tokens/s).
460        let base_tps: f32 = 30.0; // baseline for scalar on 1 core
461        let speedup = simd_tier.expected_speedup_over_scalar();
462        let core_factor = (thread_count as f32).sqrt(); // diminishing returns
463        let estimated_tokens_per_second = base_tps * speedup * core_factor;
464
465        TuningRecommendation {
466            simd_tier,
467            thread_count,
468            batch_size,
469            max_context,
470            kv_cache_type,
471            use_flash_decode,
472            use_prefix_cache,
473            estimated_tokens_per_second,
474        }
475    }
476
477    /// Quick micro-benchmark: run a synthetic FMA-heavy kernel to estimate
478    /// raw compute throughput.
479    pub fn benchmark_kernel(&self, iterations: usize) -> KernelBenchmark {
480        let simd_tier = self.cpu.best_simd_tier();
481        let n = 1024usize; // vector length
482        let flops_per_iter = n * 2; // one FMA = 2 flops per element
483
484        // Allocate work buffers
485        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.001).collect();
486        let b: Vec<f32> = (0..n).map(|i| 1.0 - (i as f32) * 0.0005).collect();
487        let mut acc = vec![0.0f32; n];
488
489        let start = Instant::now();
490        for _ in 0..iterations {
491            for j in 0..n {
492                // FMA: acc[j] += a[j] * b[j]
493                acc[j] += a[j] * b[j];
494            }
495            // Prevent the compiler from optimising the loop away
496            std::hint::black_box(&acc);
497        }
498        let elapsed = start.elapsed();
499        let total_duration_ms = elapsed.as_secs_f64() * 1000.0;
500
501        let total_flops = (iterations * flops_per_iter) as f64;
502        let elapsed_s = elapsed.as_secs_f64().max(1e-12);
503        let ops_per_second = iterations as f64 / elapsed_s;
504        let gflops = total_flops / elapsed_s / 1e9;
505
506        KernelBenchmark {
507            simd_tier,
508            iterations,
509            total_duration_ms,
510            ops_per_second,
511            gflops,
512        }
513    }
514
515    /// Reference to the detected CPU features.
516    pub fn cpu_features(&self) -> &CpuFeatures {
517        &self.cpu
518    }
519
520    /// Full diagnostic report.
521    pub fn report(&self) -> String {
522        let cpu_summary = self.cpu.summary();
523        let bench = self.benchmark_kernel(1000);
524        format!(
525            "OxiBonsai AutoTuner Report\n\
526             ==========================\n\
527             CPU: {}\n\
528             Benchmark: {}\n\
529             Memory budget: {} MB",
530            cpu_summary,
531            bench.summary(),
532            self.memory_mb,
533        )
534    }
535}
536
537impl Default for AutoTuner {
538    fn default() -> Self {
539        Self::new()
540    }
541}
542
543// ---------------------------------------------------------------------------
544// Helpers — feature detection
545// ---------------------------------------------------------------------------
546
547fn detect_arch() -> CpuArch {
548    #[cfg(target_arch = "x86_64")]
549    {
550        CpuArch::X86_64
551    }
552    #[cfg(target_arch = "aarch64")]
553    {
554        CpuArch::Aarch64
555    }
556    #[cfg(target_arch = "wasm32")]
557    {
558        CpuArch::Wasm32
559    }
560    #[cfg(not(any(
561        target_arch = "x86_64",
562        target_arch = "aarch64",
563        target_arch = "wasm32"
564    )))]
565    {
566        CpuArch::Other
567    }
568}
569
/// Whether AVX2 is usable: enabled at compile time, or reported by the CPU at
/// runtime (x86_64 only; always `false` elsewhere).
fn cfg_has_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // A compile-time target feature short-circuits the CPUID probe.
        cfg!(target_feature = "avx2") || std::arch::is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
593
/// Whether AVX-512 Foundation (`avx512f`) is usable: enabled at compile time,
/// or reported by the CPU at runtime (x86_64 only; always `false` elsewhere).
fn cfg_has_avx512() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // A compile-time target feature short-circuits the CPUID probe.
        cfg!(target_feature = "avx512f") || std::arch::is_x86_feature_detected!("avx512f")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
616
/// Whether NEON is available.
///
/// NEON (Advanced SIMD) is mandatory on AArch64, so the target architecture
/// alone decides; every other target reports `false`.
fn cfg_has_neon() -> bool {
    cfg!(target_arch = "aarch64")
}
628
/// Whether fused multiply-add is available.
///
/// On x86_64 this is a compile-time target feature or a runtime CPUID probe.
/// AArch64 always has FMA. Every other target reports `false`.
fn cfg_has_fma() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        cfg!(target_feature = "fma") || std::arch::is_x86_feature_detected!("fma")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        cfg!(target_arch = "aarch64")
    }
}
656
/// Whether SSE4.2 is usable: enabled at compile time, or reported by the CPU
/// at runtime (x86_64 only; always `false` elsewhere).
fn cfg_has_sse42() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // A compile-time target feature short-circuits the CPUID probe.
        cfg!(target_feature = "sse4.2") || std::arch::is_x86_feature_detected!("sse4.2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
679
680/// Compute a recommended batch size based on thread count and memory budget.
681fn compute_batch_size(thread_count: usize, budget: &MemoryBudget) -> usize {
682    // Start with a base of 1, scale up with cores
683    let core_based = thread_count;
684
685    // Cap by available memory: each batch element uses ~1 MB activation buffer
686    let activation_bytes_per_item: usize = 1024 * 1024; // 1 MB
687    let memory_based = budget
688        .available_bytes
689        .checked_div(activation_bytes_per_item)
690        .unwrap_or(1);
691
692    // Take the minimum, clamp to [1, 128]
693    core_based.min(memory_based).clamp(1, 128)
694}
695
696// ---------------------------------------------------------------------------
697// Tests
698// ---------------------------------------------------------------------------
699
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detect_arch_returns_valid() {
        // Smoke test: detection and Display formatting must not panic.
        let _rendered = detect_arch().to_string();
    }

    #[test]
    fn simd_tier_display() {
        for (tier, expected) in [(SimdTier::Scalar, "Scalar"), (SimdTier::Avx512, "AVX-512")] {
            assert_eq!(tier.name(), expected);
        }
    }

    #[test]
    fn memory_budget_zero_params() {
        let budget = MemoryBudget::estimate(1024, 0, 1.0);
        assert_eq!(budget.model_weight_bytes, 0);
        assert!(budget.available_bytes > 0);
    }
}