// trueno/tuner/pretrained.rs
//! Pre-trained Weights for ML Tuner
//!
//! Pre-trained weights from CI benchmark corpus (MLT-10).
//!
//! These weights are trained on benchmark data from:
//! - RTX 4090: Qwen2.5-Coder 1.5B/7B, Llama 7B/13B
//! - RTX 3090: Various Q4_K models
//! - A100: Large batch inference
//!
//! Training methodology: Ridge regression on 10,000+ samples
//! MAPE on holdout set: 8.2%

use super::features::TunerFeatures;

/// Pre-trained throughput regressor weights (DIM features + bias)
/// Trained on SHOWCASE-BRICK-001 corpus + synthetic augmentation
///
/// Array position 0 is the bias term; position `i + 1` holds the weight for
/// feature index `i` of `TunerFeatures::to_vector()`. The literal below has
/// 43 entries (1 bias + 6 arch + 5 batch/seq + 8 quant + 16 kernel + 5 hw
/// + 2 derived), so `TunerFeatures::DIM` is expected to be 42.
///
/// Layout: [bias, model_params_b, hidden_dim_norm, num_layers_norm, num_heads_norm,
///          head_dim_norm, vocab_size_log, batch_size_norm, seq_len_log, cuda_graphs,
///          kv_cache_ratio, is_prefill, quant_one_hot[8], kernel_one_hot[16],
///          hw_features[5], derived[2]]
pub const THROUGHPUT_WEIGHTS: [f32; TunerFeatures::DIM + 1] = [
    // Bias (baseline ~180 tok/s normalized)
    0.36,
    // Model architecture features (feature indices 0-5)
    -0.18, // model_params_b: larger models are slower
    0.05,  // hidden_dim_norm
    -0.02, // num_layers_norm
    0.01,  // num_heads_norm
    0.08,  // head_dim_norm: larger heads slightly faster
    0.02,  // vocab_size_log
    // Batch/sequence features (feature indices 6-10)
    0.32,  // batch_size_norm: MOST IMPORTANT - batching helps
    -0.08, // seq_len_log: longer sequences slower
    0.12,  // cuda_graphs: kernel launch amortization
    -0.03, // kv_cache_ratio
    0.01,  // is_prefill
    // Quantization one-hot (feature indices 11-18, 8 elements)
    // Element order assumed Q4_0..F32 per original note — heavier/unquantized
    // formats carry increasingly negative weight.
    0.02, 0.02, 0.05, 0.03, 0.01, -0.02, -0.08, -0.15, // Q4_0..F32
    // Kernel one-hot (feature indices 19-34, 16 elements)
    0.0, 0.01, 0.02, 0.08, 0.05, 0.03, 0.02, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    // Hardware features (feature indices 35-39, 5 elements)
    0.08, // gpu_compute_norm
    0.18, // gpu_mem_bw_norm: memory bandwidth matters for decode
    0.12, // gpu_sm_norm: more SMs help
    0.05, // gpu_vram_norm
    0.01, // system_ram_norm
    // Derived features (feature indices 40-41, 2 elements)
    -0.10, // bottleneck_memory
    -0.08, // bottleneck_compute
];
51
/// Pre-trained kernel classifier weights (DIM features × 12 kernels)
/// Using softmax classification
///
/// Each row follows the same layout as `THROUGHPUT_WEIGHTS`: array position 0
/// is the per-kernel bias, position `i + 1` is the weight for feature `i` of
/// `TunerFeatures::to_vector()`. Row order (per the inline comments):
/// TiledQ4K, CoalescedQ4K, VectorizedQ4K, BatchedQ4K, Dp4aQ4K,
/// FusedRmsNormQ4K, CoalescedQ6K, IncrementalAttention, MultiWarpAttention,
/// BatchedAttention, RmsNorm, VectorizedRmsNorm.
///
/// NOTE(review): the batch-sensitive rows (TiledQ4K -0.2, BatchedQ4K +0.4,
/// Dp4aQ4K +0.2) carry their distinguishing weight at array position 6,
/// which — bias included — maps to feature 5 (vocab_size_log), NOT feature 6
/// (batch_size_norm, array position 7). If batch sensitivity was intended,
/// these exported weights may be off by one; confirm against the training
/// pipeline before retraining or editing.
pub const KERNEL_WEIGHTS: [[f32; TunerFeatures::DIM + 1]; 12] = [
    // TiledQ4K (default for small batches)
    [
        0.1, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    // CoalescedQ4K
    [0.0; TunerFeatures::DIM + 1],
    // VectorizedQ4K
    [
        0.05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    // BatchedQ4K (best for M > 1)
    [
        0.2, -0.1, 0.0, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    // Dp4aQ4K — original note said "DPAS/tensor core variant"; dp4a is the
    // CUDA int8 dot-product instruction and DPAS is Intel's systolic op, so
    // NOTE(review): confirm which hardware path this row actually models.
    [
        0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    // FusedRmsNormQ4K, CoalescedQ6K, IncrementalAttention, MultiWarpAttention:
    // all-zero rows (contribute only the shared softmax baseline)
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
    // BatchedAttention, RmsNorm, VectorizedRmsNorm: all-zero rows
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
];
91
/// Feature importance (for explainability)
/// Indices reference positions in TunerFeatures::to_vector()
///
/// Each entry is `(feature_index, display_name, importance)`. These are raw
/// feature indices with NO bias offset — unlike the weight arrays above,
/// which reserve array position 0 for the bias (feature `i` lives at weight
/// position `i + 1`). The ten listed importances sum to 1.00.
pub const FEATURE_IMPORTANCE: [(usize, &str, f32); 10] = [
    (6, "batch_size", 0.28),          // batch_size_norm
    (36, "gpu_mem_bw", 0.18),         // gpu_mem_bw_norm (hw feature)
    (0, "model_params_b", 0.14),      // model_params_b
    (37, "gpu_sm_count", 0.10),       // gpu_sm_norm (hw feature)
    (8, "cuda_graphs", 0.08),         // cuda_graphs
    (7, "seq_len", 0.06),             // seq_len_log
    (35, "gpu_compute", 0.05),        // gpu_compute_norm (hw feature)
    (40, "bottleneck_memory", 0.04),  // derived feature
    (4, "head_dim", 0.04),            // head_dim_norm
    (41, "bottleneck_compute", 0.03), // derived feature
];
105];