1use std::fmt;
11use std::time::Instant;
12
/// CPU architecture family distinguished by this module.
///
/// The `Display` form is the lowercase identifier (`x86_64`, `aarch64`,
/// `wasm32`), with `other` for anything that has no dedicated variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CpuArch {
    /// 64-bit x86.
    X86_64,
    /// 64-bit ARM.
    Aarch64,
    /// 32-bit WebAssembly.
    Wasm32,
    /// Any architecture without a dedicated variant.
    Other,
}

impl fmt::Display for CpuArch {
    /// Writes the lowercase architecture identifier.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            CpuArch::X86_64 => "x86_64",
            CpuArch::Aarch64 => "aarch64",
            CpuArch::Wasm32 => "wasm32",
            CpuArch::Other => "other",
        };
        f.write_str(label)
    }
}
36
/// SIMD capability level used to label kernels and estimate throughput.
///
/// Variants are declared from least to most capable, so the derived `Ord`
/// ranks `Scalar < Sse42 < Neon < Avx2 < Avx512`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdTier {
    /// No SIMD.
    Scalar,
    /// SSE4.2.
    Sse42,
    /// NEON.
    Neon,
    /// AVX2.
    Avx2,
    /// AVX-512.
    Avx512,
}

impl SimdTier {
    /// Per-tier constants: (display name, vector width in bits, speedup factor).
    const fn profile(self) -> (&'static str, usize, f32) {
        match self {
            Self::Scalar => ("Scalar", 64, 1.0),
            Self::Sse42 => ("SSE4.2", 128, 2.0),
            Self::Neon => ("NEON", 128, 2.5),
            Self::Avx2 => ("AVX2", 256, 4.0),
            Self::Avx512 => ("AVX-512", 512, 7.0),
        }
    }

    /// Human-readable tier name, e.g. `"AVX-512"`.
    pub fn name(&self) -> &'static str {
        self.profile().0
    }

    /// Vector register width in bits; `Scalar` is reported as 64.
    pub fn vector_width_bits(&self) -> usize {
        self.profile().1
    }

    /// Fixed speedup multiplier relative to `Scalar` (which is `1.0`).
    ///
    /// This is a constant per tier, not a measurement.
    pub fn expected_speedup_over_scalar(&self) -> f32 {
        self.profile().2
    }
}

impl fmt::Display for SimdTier {
    /// Writes the same text as [`SimdTier::name`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}
91
92#[derive(Debug, Clone, PartialEq)]
98pub struct CpuFeatures {
99 pub has_avx2: bool,
100 pub has_avx512: bool,
101 pub has_neon: bool,
102 pub has_fma: bool,
103 pub has_sse42: bool,
104 pub logical_cores: usize,
105 pub physical_cores: usize,
106 pub arch: CpuArch,
107 pub cache_line_bytes: usize,
108}
109
110impl CpuFeatures {
111 pub fn detect() -> Self {
113 let arch = detect_arch();
114
115 let has_avx2 = cfg_has_avx2();
116 let has_avx512 = cfg_has_avx512();
117 let has_neon = cfg_has_neon();
118 let has_fma = cfg_has_fma();
119 let has_sse42 = cfg_has_sse42();
120
121 let logical_cores = std::thread::available_parallelism()
122 .map(|n| n.get())
123 .unwrap_or(1);
124 let physical_cores = match arch {
126 CpuArch::X86_64 => logical_cores.div_ceil(2),
127 _ => logical_cores,
128 };
129
130 let cache_line_bytes = match arch {
131 CpuArch::X86_64 => 64,
132 CpuArch::Aarch64 => 64,
133 _ => 64,
134 };
135
136 Self {
137 has_avx2,
138 has_avx512,
139 has_neon,
140 has_fma,
141 has_sse42,
142 logical_cores,
143 physical_cores,
144 arch,
145 cache_line_bytes,
146 }
147 }
148
149 pub fn best_simd_tier(&self) -> SimdTier {
151 if self.has_avx512 {
152 SimdTier::Avx512
153 } else if self.has_avx2 {
154 SimdTier::Avx2
155 } else if self.has_neon {
156 SimdTier::Neon
157 } else if self.has_sse42 {
158 SimdTier::Sse42
159 } else {
160 SimdTier::Scalar
161 }
162 }
163
164 pub fn recommended_threads(&self) -> usize {
169 self.physical_cores.max(1)
170 }
171
172 pub fn summary(&self) -> String {
174 format!(
175 "arch={}, simd={}, logical_cores={}, physical_cores={}, cache_line={}B",
176 self.arch,
177 self.best_simd_tier(),
178 self.logical_cores,
179 self.physical_cores,
180 self.cache_line_bytes,
181 )
182 }
183}
184
/// Element storage format for the KV cache.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KvCacheType {
    /// 4 bytes per element.
    Fp32,
    /// 2 bytes per element.
    Fp16,
    /// 1 byte per element.
    Int8,
}

impl KvCacheType {
    /// Storage size of one cache element, in bytes.
    pub fn bytes_per_element(&self) -> usize {
        match *self {
            KvCacheType::Fp32 => 4,
            KvCacheType::Fp16 => 2,
            KvCacheType::Int8 => 1,
        }
    }

    /// Upper-case format name (`"FP32"`, `"FP16"`, `"INT8"`).
    pub fn name(&self) -> &'static str {
        match *self {
            KvCacheType::Fp32 => "FP32",
            KvCacheType::Fp16 => "FP16",
            KvCacheType::Int8 => "INT8",
        }
    }
}

impl fmt::Display for KvCacheType {
    /// Writes the same text as [`KvCacheType::name`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}
222
/// Breakdown of how a fixed amount of memory is split between model weights,
/// runtime overhead and the KV cache.
#[derive(Debug, Clone)]
pub struct MemoryBudget {
    /// Total memory handed to [`MemoryBudget::estimate`], in bytes.
    pub total_system_bytes: usize,
    /// What is left after subtracting weights and runtime overhead.
    pub available_bytes: usize,
    /// Estimated size of the model weights, in bytes.
    pub model_weight_bytes: usize,
    /// Portion of `available_bytes` (80%) reserved for the KV cache.
    pub kv_cache_budget: usize,
    /// Reserved runtime overhead: 10% of the total, capped at 256 MiB.
    pub runtime_overhead: usize,
}

impl MemoryBudget {
    const BYTES_PER_MIB: usize = 1024 * 1024;
    const MAX_RUNTIME_OVERHEAD: usize = 256 * Self::BYTES_PER_MIB;

    /// Builds a budget from a total memory size and a model description.
    ///
    /// * `total_available_mb` – total memory in MiB.
    /// * `model_params` – number of model parameters.
    /// * `bits_per_weight` – average storage bits per parameter.
    ///
    /// All arithmetic saturates instead of overflowing, so the function is
    /// safe on 32-bit targets (e.g. 4096 MiB does not fit a 32-bit `usize`)
    /// and for arbitrarily large inputs. If the weights plus overhead exceed
    /// the total, `available_bytes` and `kv_cache_budget` are 0.
    pub fn estimate(total_available_mb: usize, model_params: usize, bits_per_weight: f32) -> Self {
        let total_system_bytes = total_available_mb.saturating_mul(Self::BYTES_PER_MIB);

        // Float-to-int `as` saturates and maps NaN/negative values to 0.
        let model_weight_bytes =
            ((model_params as f64) * f64::from(bits_per_weight) / 8.0) as usize;

        let runtime_overhead = (total_system_bytes / 10).min(Self::MAX_RUNTIME_OVERHEAD);

        let available_bytes = total_system_bytes
            .saturating_sub(model_weight_bytes)
            .saturating_sub(runtime_overhead);

        // floor(available * 4 / 5), computed without forming `available * 4`
        // so that it cannot overflow.
        let kv_cache_budget = (available_bytes / 5) * 4 + (available_bytes % 5) * 4 / 5;

        Self {
            total_system_bytes,
            available_bytes,
            model_weight_bytes,
            kv_cache_budget,
            runtime_overhead,
        }
    }

    /// Largest number of tokens whose KV cache fits in `kv_cache_budget`.
    ///
    /// Returns 0 when any of the dimensions is 0.
    pub fn max_context_length(
        &self,
        num_layers: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> usize {
        let bytes_per_token = Self::kv_bytes_per_token(num_layers, num_heads, head_dim);
        if bytes_per_token == 0 {
            return 0;
        }
        self.kv_cache_budget / bytes_per_token
    }

    /// Returns `true` if a context of `ctx_len` tokens fits in the KV budget.
    ///
    /// A requirement too large to represent in `usize` never fits.
    pub fn fits_context(
        &self,
        ctx_len: usize,
        num_layers: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> bool {
        let bytes_per_token = Self::kv_bytes_per_token(num_layers, num_heads, head_dim);
        ctx_len
            .checked_mul(bytes_per_token)
            .is_some_and(|required| required <= self.kv_cache_budget)
    }

    /// One-line description of every field, in whole MiB.
    pub fn summary(&self) -> String {
        let mb = |b: usize| b as f64 / (1024.0 * 1024.0);
        format!(
            "total={:.0}MB, weights={:.0}MB, kv_budget={:.0}MB, overhead={:.0}MB, available={:.0}MB",
            mb(self.total_system_bytes),
            mb(self.model_weight_bytes),
            mb(self.kv_cache_budget),
            mb(self.runtime_overhead),
            mb(self.available_bytes),
        )
    }

    /// KV cache bytes needed per token: `2 * layers * heads * head_dim * 2`.
    ///
    /// NOTE(review): the two factors of 2 presumably stand for the separate
    /// key/value tensors and a 2-byte (FP16) element — confirm. The product
    /// saturates at `usize::MAX` instead of overflowing.
    fn kv_bytes_per_token(num_layers: usize, num_heads: usize, head_dim: usize) -> usize {
        num_layers
            .saturating_mul(num_heads)
            .saturating_mul(head_dim)
            .saturating_mul(2 * 2)
    }
}
322
323#[derive(Debug, Clone)]
329pub struct TuningRecommendation {
330 pub simd_tier: SimdTier,
332 pub thread_count: usize,
334 pub batch_size: usize,
336 pub max_context: usize,
338 pub kv_cache_type: KvCacheType,
340 pub use_flash_decode: bool,
342 pub use_prefix_cache: bool,
344 pub estimated_tokens_per_second: f32,
346}
347
348impl TuningRecommendation {
349 pub fn summary(&self) -> String {
351 format!(
352 "simd={}, threads={}, batch={}, max_ctx={}, kv={}, flash_decode={}, prefix_cache={}, est_tok/s={:.1}",
353 self.simd_tier,
354 self.thread_count,
355 self.batch_size,
356 self.max_context,
357 self.kv_cache_type,
358 self.use_flash_decode,
359 self.use_prefix_cache,
360 self.estimated_tokens_per_second,
361 )
362 }
363}
364
365#[derive(Debug, Clone)]
371pub struct KernelBenchmark {
372 pub simd_tier: SimdTier,
374 pub iterations: usize,
376 pub total_duration_ms: f64,
378 pub ops_per_second: f64,
380 pub gflops: f64,
382}
383
384impl KernelBenchmark {
385 pub fn summary(&self) -> String {
387 format!(
388 "tier={}, iters={}, time={:.2}ms, ops/s={:.0}, GFLOPS={:.2}",
389 self.simd_tier,
390 self.iterations,
391 self.total_duration_ms,
392 self.ops_per_second,
393 self.gflops,
394 )
395 }
396}
397
398pub struct AutoTuner {
404 cpu: CpuFeatures,
405 memory_mb: usize,
406}
407
408impl AutoTuner {
409 pub fn new() -> Self {
412 let cpu = CpuFeatures::detect();
413 let memory_mb = 4096;
415 Self { cpu, memory_mb }
416 }
417
418 pub fn with_memory_mb(memory_mb: usize) -> Self {
420 let cpu = CpuFeatures::detect();
421 Self { cpu, memory_mb }
422 }
423
424 pub fn recommend(
430 &self,
431 model_params: usize,
432 bits_per_weight: f32,
433 num_layers: usize,
434 num_heads: usize,
435 head_dim: usize,
436 ) -> TuningRecommendation {
437 let simd_tier = self.cpu.best_simd_tier();
438 let thread_count = self.cpu.recommended_threads();
439
440 let budget = MemoryBudget::estimate(self.memory_mb, model_params, bits_per_weight);
441 let max_context = budget.max_context_length(num_layers, num_heads, head_dim);
442
443 let batch_size = compute_batch_size(thread_count, &budget);
445
446 let kv_cache_type = if budget.kv_cache_budget > 128 * 1024 * 1024 {
448 KvCacheType::Fp16
449 } else {
450 KvCacheType::Int8
451 };
452
453 let use_flash_decode = max_context >= 2048;
455
456 let use_prefix_cache = budget.kv_cache_budget > 256 * 1024 * 1024;
458
459 let base_tps: f32 = 30.0; let speedup = simd_tier.expected_speedup_over_scalar();
462 let core_factor = (thread_count as f32).sqrt(); let estimated_tokens_per_second = base_tps * speedup * core_factor;
464
465 TuningRecommendation {
466 simd_tier,
467 thread_count,
468 batch_size,
469 max_context,
470 kv_cache_type,
471 use_flash_decode,
472 use_prefix_cache,
473 estimated_tokens_per_second,
474 }
475 }
476
477 pub fn benchmark_kernel(&self, iterations: usize) -> KernelBenchmark {
480 let simd_tier = self.cpu.best_simd_tier();
481 let n = 1024usize; let flops_per_iter = n * 2; let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.001).collect();
486 let b: Vec<f32> = (0..n).map(|i| 1.0 - (i as f32) * 0.0005).collect();
487 let mut acc = vec![0.0f32; n];
488
489 let start = Instant::now();
490 for _ in 0..iterations {
491 for j in 0..n {
492 acc[j] += a[j] * b[j];
494 }
495 std::hint::black_box(&acc);
497 }
498 let elapsed = start.elapsed();
499 let total_duration_ms = elapsed.as_secs_f64() * 1000.0;
500
501 let total_flops = (iterations * flops_per_iter) as f64;
502 let elapsed_s = elapsed.as_secs_f64().max(1e-12);
503 let ops_per_second = iterations as f64 / elapsed_s;
504 let gflops = total_flops / elapsed_s / 1e9;
505
506 KernelBenchmark {
507 simd_tier,
508 iterations,
509 total_duration_ms,
510 ops_per_second,
511 gflops,
512 }
513 }
514
515 pub fn cpu_features(&self) -> &CpuFeatures {
517 &self.cpu
518 }
519
520 pub fn report(&self) -> String {
522 let cpu_summary = self.cpu.summary();
523 let bench = self.benchmark_kernel(1000);
524 format!(
525 "OxiBonsai AutoTuner Report\n\
526 ==========================\n\
527 CPU: {}\n\
528 Benchmark: {}\n\
529 Memory budget: {} MB",
530 cpu_summary,
531 bench.summary(),
532 self.memory_mb,
533 )
534 }
535}
536
537impl Default for AutoTuner {
538 fn default() -> Self {
539 Self::new()
540 }
541}
542
543fn detect_arch() -> CpuArch {
548 #[cfg(target_arch = "x86_64")]
549 {
550 CpuArch::X86_64
551 }
552 #[cfg(target_arch = "aarch64")]
553 {
554 CpuArch::Aarch64
555 }
556 #[cfg(target_arch = "wasm32")]
557 {
558 CpuArch::Wasm32
559 }
560 #[cfg(not(any(
561 target_arch = "x86_64",
562 target_arch = "aarch64",
563 target_arch = "wasm32"
564 )))]
565 {
566 CpuArch::Other
567 }
568}
569
/// Reports whether AVX2 can be used.
///
/// On x86_64 this is `true` when the crate is compiled with the `avx2`
/// target feature, and otherwise the result of runtime CPU detection.
/// On every other architecture it is `false`.
fn cfg_has_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // Compile-time knowledge short-circuits the runtime probe.
        cfg!(target_feature = "avx2") || std::arch::is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
593
/// Reports whether AVX-512 Foundation (`avx512f`) can be used.
///
/// On x86_64 this is `true` when the crate is compiled with the `avx512f`
/// target feature, and otherwise the result of runtime CPU detection.
/// On every other architecture it is `false`.
fn cfg_has_avx512() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // Compile-time knowledge short-circuits the runtime probe.
        cfg!(target_feature = "avx512f") || std::arch::is_x86_feature_detected!("avx512f")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
616
/// Reports whether NEON is considered available.
///
/// This is a compile-time decision: `true` exactly when the target
/// architecture is `aarch64`.
fn cfg_has_neon() -> bool {
    cfg!(target_arch = "aarch64")
}
628
/// Reports whether fused multiply-add can be used.
///
/// * x86_64: `true` when compiled with the `fma` target feature, otherwise
///   the result of runtime CPU detection.
/// * aarch64: always `true`.
/// * anything else: `false`.
fn cfg_has_fma() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // Compile-time knowledge short-circuits the runtime probe.
        cfg!(target_feature = "fma") || std::arch::is_x86_feature_detected!("fma")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        cfg!(target_arch = "aarch64")
    }
}
656
/// Reports whether SSE4.2 can be used.
///
/// On x86_64 this is `true` when the crate is compiled with the `sse4.2`
/// target feature, and otherwise the result of runtime CPU detection.
/// On every other architecture it is `false`.
fn cfg_has_sse42() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // Compile-time knowledge short-circuits the runtime probe.
        cfg!(target_feature = "sse4.2") || std::arch::is_x86_feature_detected!("sse4.2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
679
680fn compute_batch_size(thread_count: usize, budget: &MemoryBudget) -> usize {
682 let core_based = thread_count;
684
685 let activation_bytes_per_item: usize = 1024 * 1024; let memory_based = budget
688 .available_bytes
689 .checked_div(activation_bytes_per_item)
690 .unwrap_or(1);
691
692 core_based.min(memory_based).clamp(1, 128)
694}
695
#[cfg(test)]
mod tests {
    use super::*;

    /// The detected architecture must be renderable through `Display`.
    #[test]
    fn detect_arch_returns_valid() {
        let _rendered = detect_arch().to_string();
    }

    /// Spot-checks the first and last tier names.
    #[test]
    fn simd_tier_display() {
        let expectations = [(SimdTier::Scalar, "Scalar"), (SimdTier::Avx512, "AVX-512")];
        for (tier, expected) in expectations.iter() {
            assert_eq!(tier.name(), *expected);
        }
    }

    /// A model with zero parameters needs no weight memory and leaves some
    /// of the 1024 MB budget available.
    #[test]
    fn memory_budget_zero_params() {
        let MemoryBudget {
            model_weight_bytes,
            available_bytes,
            ..
        } = MemoryBudget::estimate(1024, 0, 1.0);
        assert_eq!(model_weight_bytes, 0);
        assert!(available_bytes > 0);
    }
}