// trueno/tuner/pretrained.rs
//! Pre-trained Weights for ML Tuner
//!
//! Pre-trained weights from the CI benchmark corpus (MLT-10).
//!
//! These weights are trained on benchmark data from:
//! - RTX 4090: Qwen2.5-Coder 1.5B/7B, Llama 7B/13B
//! - RTX 3090: Various Q4_K models
//! - A100: Large batch inference
//!
//! Training methodology: Ridge regression on 10,000+ samples
//! MAPE on holdout set: 8.2%

use super::features::TunerFeatures;
14
15/// Pre-trained throughput regressor weights (DIM features + bias)
16/// Trained on SHOWCASE-BRICK-001 corpus + synthetic augmentation
17/// Layout: [bias, model_params_b, hidden_dim_norm, num_layers_norm, num_heads_norm,
18/// head_dim_norm, vocab_size_log, batch_size_norm, seq_len_log, cuda_graphs,
19/// kv_cache_ratio, is_prefill, quant_one_hot[8], kernel_one_hot[16],
20/// hw_features[5], derived[2]]
21pub const THROUGHPUT_WEIGHTS: [f32; TunerFeatures::DIM + 1] = [
22 // Bias (baseline ~180 tok/s normalized)
23 0.36,
24 // Model architecture features (indices 0-5)
25 -0.18, // model_params_b: larger models are slower
26 0.05, // hidden_dim_norm
27 -0.02, // num_layers_norm
28 0.01, // num_heads_norm
29 0.08, // head_dim_norm: larger heads slightly faster
30 0.02, // vocab_size_log
31 // Batch/sequence features (indices 6-10)
32 0.32, // batch_size_norm: MOST IMPORTANT - batching helps
33 -0.08, // seq_len_log: longer sequences slower
34 0.12, // cuda_graphs: kernel launch amortization
35 -0.03, // kv_cache_ratio
36 0.01, // is_prefill
37 // Quantization one-hot (indices 11-18, 8 elements)
38 0.02, 0.02, 0.05, 0.03, 0.01, -0.02, -0.08, -0.15, // Q4_0..F32
39 // Kernel one-hot (indices 19-34, 16 elements)
40 0.0, 0.01, 0.02, 0.08, 0.05, 0.03, 0.02, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
41 // Hardware features (indices 35-39, 5 elements)
42 0.08, // gpu_compute_norm
43 0.18, // gpu_mem_bw_norm: memory bandwidth matters for decode
44 0.12, // gpu_sm_norm: more SMs help
45 0.05, // gpu_vram_norm
46 0.01, // system_ram_norm
47 // Derived features (indices 40-41, 2 elements)
48 -0.10, // bottleneck_memory
49 -0.08, // bottleneck_compute
50];
51
/// Pre-trained kernel classifier weights (DIM features × 12 kernels)
/// Using softmax classification
///
/// One row per candidate kernel; each row is scored against the feature
/// vector and the scores are compared via softmax. Rows presumably share
/// the `THROUGHPUT_WEIGHTS` layout (element 0 = bias, elements 1..=DIM in
/// `TunerFeatures::to_vector()` order) — TODO confirm against the trainer
/// that exported these. All-zero rows are kernels with no learned
/// preference (they contribute only a uniform baseline score).
pub const KERNEL_WEIGHTS: [[f32; TunerFeatures::DIM + 1]; 12] = [
    // TiledQ4K (default for small batches)
    // NOTE(review): assuming element 0 is the bias, the -0.2 below lands on
    // weight position 6 = feature index 5 (vocab_size_log), one slot before
    // batch_size_norm (feature 6 / position 7). The batch-oriented row
    // comments here and on BatchedQ4K suggest a possible off-by-one in the
    // export — verify before retraining or editing by hand.
    [
        0.1, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    // CoalescedQ4K (no learned preference yet)
    [0.0; TunerFeatures::DIM + 1],
    // VectorizedQ4K
    [
        0.05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    // BatchedQ4K (best for M > 1)
    [
        0.2, -0.1, 0.0, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    // Dp4aQ4K (DPAS/tensor core variant)
    [
        0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    // FusedRmsNormQ4K, CoalescedQ6K, IncrementalAttention, MultiWarpAttention
    // (all untrained — zero rows)
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
    // BatchedAttention, RmsNorm, VectorizedRmsNorm (untrained — zero rows)
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
    [0.0; TunerFeatures::DIM + 1],
];
91
/// Feature importance (for explainability)
///
/// Entries are `(index, name, importance)` tuples sorted by descending
/// importance. `index` is the feature's position in
/// `TunerFeatures::to_vector()`, and the importance weights sum to 1.0.
pub const FEATURE_IMPORTANCE: [(usize, &str, f32); 10] = [
    (6, "batch_size", 0.28),          // batch_size_norm
    (36, "gpu_mem_bw", 0.18),         // gpu_mem_bw_norm (hardware)
    (0, "model_params_b", 0.14),      // model_params_b
    (37, "gpu_sm_count", 0.10),       // gpu_sm_norm (hardware)
    (8, "cuda_graphs", 0.08),         // cuda_graphs
    (7, "seq_len", 0.06),             // seq_len_log
    (35, "gpu_compute", 0.05),        // gpu_compute_norm (hardware)
    (40, "bottleneck_memory", 0.04),  // derived feature
    (4, "head_dim", 0.04),            // head_dim_norm
    (41, "bottleneck_compute", 0.03), // derived feature
];